From 5d3dc047e95213ee773416782d100932256d7440 Mon Sep 17 00:00:00 2001 From: Fantix King Date: Wed, 19 Dec 2018 18:40:43 +0800 Subject: [PATCH] Fixed #412, moved `after_execute` after `async_execute`, added buggy `rowcount`. --- gino/dialects/asyncpg.py | 20 ++++++++++++++++++++ gino/dialects/base.py | 22 ++++++++++++++++++---- gino/engine.py | 27 +++++++++++++++++++++++++++ tests/test_engine.py | 35 +++++++++++++++++++++++++++++++++++ 4 files changed, 100 insertions(+), 4 deletions(-) diff --git a/gino/dialects/asyncpg.py b/gino/dialects/asyncpg.py index 97b53fac..f946d3b6 100644 --- a/gino/dialects/asyncpg.py +++ b/gino/dialects/asyncpg.py @@ -137,6 +137,7 @@ def __init__(self, dbapi_conn): self._conn = dbapi_conn self._attributes = None self._status = None + self._rowcount = 0 async def prepare(self, context, clause=None): timeout = context.timeout @@ -183,12 +184,31 @@ def executor(state, timeout_): self._attributes = [] if not many: result, self._status = result[:2] + try: + # Refs https://git.io/fphKg + parts = self._status.split() + if parts[0] in (b'INSERT', b'SELECT', b'UPDATE', b'DELETE', + b'MOVE', b'FETCH', b'COPY'): + self._rowcount += int(parts[-1]) + except (AttributeError, IndexError, ValueError): + pass return result @property def description(self): return [((a[0], a[1][0]) + (None,) * 5) for a in self._attributes] + @property + def rowcount(self): + """Simulate DB-API rowcount. + + This has several known issues: + * Execute with limit (e.g. first() or scalar()) may not trigger a + CommandComplete, thus rowcount is always 0 + * It's always 0 for executemany(), iterate() and prepare() + """ + return self._rowcount + def get_statusmsg(self): return self._status.decode() diff --git a/gino/dialects/base.py b/gino/dialects/base.py index 18ae9e67..73161b56 100644 --- a/gino/dialects/base.py +++ b/gino/dialects/base.py @@ -31,6 +31,10 @@ def executemany(self, statement, parameters): def description(self): raise NotImplementedError + @property + def rowcount(self): + raise NotImplementedError + async def prepare(self, context, clause=None): raise NotImplementedError @@ -198,7 +202,7 @@ async def execute(self, one=False, return_model=True, status=False): cursor = context.cursor if context.executemany: - return await cursor.async_execute( + rv = await cursor.async_execute( context.statement, context.timeout, param_groups, many=True) else: @@ -213,19 +217,29 @@ async def execute(self, one=False, return_model=True, status=False): item = None if status: item = cursor.get_statusmsg(), item - return item + rv = item + context.root_connection.after_execute() + return rv def iterate(self): if self._context.executemany: raise ValueError('too many multiparams') - return _IterableCursor(self._context) + rv = _IterableCursor(self._context) + self._context.root_connection.after_execute() + return rv async def prepare(self, clause): - return await self._context.cursor.prepare(self._context, clause) + rv = await self._context.cursor.prepare(self._context, clause) + self._context.root_connection.after_execute() + return rv def _soft_close(self): pass + @property + def rowcount(self): + return self.context.rowcount + class Cursor: async def many(self, n, *, timeout=DEFAULT): diff --git a/gino/engine.py b/gino/engine.py index 1ff1f1bd..ecbf1e82 100644 --- a/gino/engine.py +++ b/gino/engine.py @@ -106,8 +106,35 @@ def keys(self): _bypass_no_param = _bypass_no_param() +class _InterceptedListener: + __slots__ = '_orig', '_args' + + def __init__(self, orig): + self._orig = orig + self._args = None + + def __getattr__(self, item): + return getattr(self._orig, item) + + def __call__(self, *args, **kwargs): + self._args = args, kwargs + + def call(self): + if self._args: + return self._orig(*self._args[0], **self._args[1]) + + # noinspection PyAbstractClass class _SAConnection(Connection): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.dispatch.after_execute = _InterceptedListener( + self.dispatch.after_execute) + + def after_execute(self): + if self._has_events or self.engine._has_events: + self.dispatch.after_execute.call() + def _execute_context(self, dialect, constructor, statement, parameters, *args): diff --git a/tests/test_engine.py b/tests/test_engine.py index 7c656648..f7d6a402 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -392,3 +392,38 @@ async def task(): await asyncio.gather(*[task() for _ in range(5)]) assert bind._ctx.get() is None + + +async def test_issue_412(bind): + sql = 'SELECT now()' + + @sa.event.listens_for(bind._sa_engine, 'after_execute') + def after_exec(conn, clauseelement, multiparams, params, result): + nonlocal rowcount + rowcount = result.rowcount + + for i in range(4): + rowcount = None + await bind.all(sql) + assert rowcount == 1 + + rowcount = None + await bind.first(sql) + assert rowcount == 0 + + rowcount = None + await bind.all(sql, [(), ()]) + assert rowcount == 0 + + async with bind.transaction() as tx: + rowcount = None + async for _ in bind.iterate(sql): + assert rowcount == 0 + + rowcount = None + stmt = await tx.connection.prepare(sql) + assert rowcount == 0 + + rowcount = None + await stmt.all() + assert rowcount is None