用mock测试service

requirements.txt

lessweb
sqlalchemy
mysql-connector-python

index.py

import os

from lessweb import Application
from lessweb.plugin import database
from controller import list_reply

# Take the database URI from the DB_URI environment variable so real
# credentials need not be committed; fall back to the local development default.
database.init(dburi=os.environ.get(
    'DB_URI', 'mysql+mysqlconnector://root:pwd@localhost/db'))
app = Application()
# Map GET /reply/list to controller.list_reply.
app.add_get_mapping('/reply/list', list_reply)

if __name__ == '__main__':
    app.run()

model.py

from datetime import datetime
from sqlalchemy import Column, Integer, String, Text, DateTime
from lessweb import Model
from lessweb.plugin.database import DbModel

class Pager(Model):
    """Paging request parameters plus the paging metadata computed for them.

    ``pageNo`` (1-based) and ``pageSize`` describe the requested page;
    ``total`` and ``totalPage`` are filled in by :meth:`slice`.
    """
    pageNo: int = 1
    pageSize: int = 10
    total: int = 0
    totalPage: int = 1

    def slice(self, ordered_query):
        """Return the rows of the requested page from ``ordered_query``.

        ``ordered_query`` must support ``count()``, ``offset()``, ``limit()``
        and ``all()`` (e.g. an ordered SQLAlchemy query).

        Side effects on ``self``: ``total`` becomes the query's row count,
        ``pageNo`` is raised to at least 1, ``pageSize`` is clamped to
        [1, 200], and ``totalPage`` becomes ceil(total / pageSize), which is
        0 when there are no rows.
        """
        self.total = ordered_query.count()
        self.pageNo = max(1, self.pageNo)
        # Upper bound presumably keeps a client from requesting an unbounded page.
        self.pageSize = max(1, min(self.pageSize, 200))
        # Ceiling division: -(-a // b) == ceil(a / b) for b > 0.
        self.totalPage = -(-self.total // self.pageSize)
        first_row = (self.pageNo - 1) * self.pageSize
        return ordered_query.offset(first_row).limit(self.pageSize).all()

class Reply(Model):
    """Reply record as seen by the controller and service layers.

    Serves as the filter input of ``list_reply`` (the service reads only
    ``nickname`` and ``age``) and as the element type produced by
    ``cast_models(Reply, rows)`` from ``TblReply`` rows.
    """
    id: int
    nickname: str
    age: int
    message: str
    # Same name as the ``TblReply.create_at`` column; presumably cast_models
    # copies fields by name, so keep the two in sync — verify.
    create_at: datetime

class TblReply(DbModel):
    """ORM mapping of the ``reply`` table (declared on lessweb's DbModel base)."""
    __tablename__ = 'reply'
    id = Column(Integer, primary_key=True)
    nickname = Column(String(200))
    age = Column(Integer)
    # NOTE(review): ``Reply`` declares no ``gender`` field, so this column is
    # presumably not exposed by cast_models(Reply, ...); its integer encoding
    # is not defined here — confirm.
    gender = Column(Integer)
    message = Column(Text)
    create_at = Column(DateTime)

controller.py

from lessweb.plugin.database import DbServ

import service
from model import Reply, Pager

def list_reply(serv:DbServ, reply:Reply, pager:Pager):
    """Return one page of replies filtered by the fields set on ``reply``.

    Registered for ``/reply/list`` in index.py. lessweb supplies ``serv``,
    ``reply`` and ``pager`` based on their annotations (presumably bound from
    the request — verify against lessweb docs).

    Returns a dict with ``code`` 0, the ``Reply`` objects of the page under
    ``list``, and ``pager`` (updated in place by the service) under ``page``.
    """
    page_items = service.list_reply(serv, reply, pager)
    return dict(code=0, list=page_items, page=pager)

service.py

from sqlalchemy import desc
from lessweb.plugin.database import DbServ, cast_models
from model import Reply, TblReply, Pager

def list_reply(serv:DbServ, reply:Reply, pager: Pager):
    """Query one page of ``reply`` rows, newest (highest id) first.

    ``reply.nickname`` and ``reply.age`` become equality filters only when
    truthy, so empty strings, 0 and None all mean "no filter" for that field
    (NOTE(review): an age of 0 therefore cannot be filtered on).
    ``pager`` is normalized and its ``total``/``totalPage`` filled in by
    ``Pager.slice``.

    Returns the page's rows converted to ``Reply`` models.
    """
    conditions = []
    if reply.nickname:
        conditions.append(TblReply.nickname == reply.nickname)
    if reply.age:
        conditions.append(TblReply.age == reply.age)

    query = (
        serv.db.query(TblReply)
        .filter(*conditions)
        .order_by(desc(TblReply.id))
    )
    page_rows = pager.slice(ordered_query=query)
    return cast_models(Reply, page_rows)


test.py

from typing import cast
from unittest import TestCase
from unittest.mock import ANY, DEFAULT, patch
from lessweb import Storage, ChainMock
from lessweb.plugin.database import DbServ

from controller import list_reply
from service import list_reply as service_list_reply
from model import Reply, Pager, TblReply, datetime

class TestListReplyController(TestCase):
    """Controller test with ``service.list_reply`` patched out."""

    @patch('controller.service.list_reply')
    def test_list_reply(self, list_reply_mock):
        def fake_service(_serv, reply_arg, pager_arg):
            # The controller must forward the filter and paging values unchanged.
            self.assertEqual((reply_arg.nickname, reply_arg.age), ('nn', 33))
            self.assertEqual((pager_arg.pageNo, pager_arg.pageSize), (3, 4))
            return []

        list_reply_mock.side_effect = fake_service
        pager = Pager()
        pager.pageNo = 3
        pager.pageSize = 4
        reply = Reply()
        reply.nickname = 'nn'
        reply.age = 33
        # CALL controller.list_reply
        serv = cast(DbServ, 'serv')  # opaque stand-in; the patched service never uses it
        ret = list_reply(serv, reply, pager)
        self.assertEqual(ret, {'code': 0, 'list': [], 'page': pager})
        # Exactly one service call, with the very objects the controller received.
        list_reply_mock.assert_called_once_with(serv, reply, pager)

class TestListReplyService(TestCase):
    """Service test with the database session replaced by a ChainMock."""

    def setUp(self):
        self.serv = cast(DbServ, Storage())

    def test_list_reply_success(self):
        tbl_reply = TblReply(id=1, nickname='qq', age=25, message='cc', create_at=datetime(2015, 1, 31, 0, 0))

        # Two call chains share the query.filter.order_by prefix: the paged
        # fetch yields [tbl_reply] and count() yields 3 (see assertions below).
        mock = ChainMock('query.filter.order_by.offset.limit.all', [tbl_reply]).join('query.filter.order_by.count', 3)

        def check_order_by(clause):
            self.assertEqual(str(clause), 'reply.id DESC')
            # DEFAULT makes the mock return its normal (chained) return value.
            return DEFAULT

        def check_filter(nickname_cond, age_cond):
            self.assertEqual(str(nickname_cond), 'reply.nickname = :nickname_1')
            self.assertEqual(nickname_cond.compile().params, {'nickname_1': 'aa'})
            self.assertEqual(str(age_cond), 'reply.age = :age_1')
            self.assertEqual(age_cond.compile().params, {'age_1': 25})
            return DEFAULT

        mock('query.filter.order_by').side_effect = check_order_by
        mock('query.filter').side_effect = check_filter

        self.serv.db = Storage(query=mock('query'))
        reply = Reply()
        reply.nickname = 'aa'
        reply.age = 25
        reply_expect = Reply()
        reply_expect.id = 1
        reply_expect.nickname = 'qq'
        reply_expect.age = 25
        reply_expect.message = 'cc'
        reply_expect.create_at = datetime(2015, 1, 31, 0, 0)
        pager = Pager()
        pager.pageNo = 2
        pager.pageSize = 10
        # CALL service.list_reply
        replys = service_list_reply(self.serv, reply, pager)
        self.assertEqual(replys, [reply_expect])
        self.assertEqual(pager.total, 3)
        # ceil(3 / 10) == 1
        self.assertEqual(pager.totalPage, 1)
        mock('query.filter.order_by.count').assert_called_once_with()
        # Page 2 with size 10 skips the first 10 rows.
        mock('query.filter.order_by.offset').assert_called_once_with(10)
        mock('query.filter.order_by.offset.limit').assert_called_once_with(10)
        mock('query.filter.order_by.offset.limit.all').assert_called_once_with()

测试命令: nosetests test.py（nose 已停止维护，无法在 Python 3.10+ 上运行；可使用等价命令 `python -m unittest test`）
测试结果:

..
----------------------------------------------------------------------
Ran 2 tests in 0.004s

OK

参考文档