Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moved after_execute after async_execute, added buggy rowcount #413

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions gino/dialects/asyncpg.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ def __init__(self, dbapi_conn):
self._conn = dbapi_conn
self._attributes = None
self._status = None
self._rowcount = 0

async def prepare(self, context, clause=None):
timeout = context.timeout
Expand Down Expand Up @@ -183,12 +184,31 @@ def executor(state, timeout_):
self._attributes = []
if not many:
result, self._status = result[:2]
try:
# Refs https://git.io/fphKg
parts = self._status.split()
if parts[0] in (b'INSERT', b'SELECT', b'UPDATE', b'DELETE',
b'MOVE', b'FETCH', b'COPY'):
self._rowcount += int(parts[-1])
except (AttributeError, IndexError, ValueError):
pass
return result

@property
def description(self):
return [((a[0], a[1][0]) + (None,) * 5) for a in self._attributes]

@property
def rowcount(self):
"""Simulate DB-API rowcount.

This has several known issues:
* Execute with limit (e.g. first() or scalar()) may not trigger a
CommandComplete, thus rowcount is always 0
* It's always 0 for executemany(), iterate() and prepare()
"""
return self._rowcount

def get_statusmsg(self):
return self._status.decode()

Expand Down
22 changes: 18 additions & 4 deletions gino/dialects/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ def executemany(self, statement, parameters):
def description(self):
raise NotImplementedError

@property
def rowcount(self):
raise NotImplementedError

async def prepare(self, context, clause=None):
raise NotImplementedError

Expand Down Expand Up @@ -198,7 +202,7 @@ async def execute(self, one=False, return_model=True, status=False):

cursor = context.cursor
if context.executemany:
return await cursor.async_execute(
rv = await cursor.async_execute(
context.statement, context.timeout, param_groups,
many=True)
else:
Expand All @@ -213,19 +217,29 @@ async def execute(self, one=False, return_model=True, status=False):
item = None
if status:
item = cursor.get_statusmsg(), item
return item
rv = item
context.root_connection.after_execute()
return rv

def iterate(self):
if self._context.executemany:
raise ValueError('too many multiparams')
return _IterableCursor(self._context)
rv = _IterableCursor(self._context)
self._context.root_connection.after_execute()
return rv

async def prepare(self, clause):
return await self._context.cursor.prepare(self._context, clause)
rv = await self._context.cursor.prepare(self._context, clause)
self._context.root_connection.after_execute()
return rv

def _soft_close(self):
pass

@property
def rowcount(self):
return self.context.rowcount


class Cursor:
async def many(self, n, *, timeout=DEFAULT):
Expand Down
27 changes: 27 additions & 0 deletions gino/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,35 @@ def keys(self):
_bypass_no_param = _bypass_no_param()


class _InterceptedListener:
__slots__ = '_orig', '_args'

def __init__(self, orig):
self._orig = orig
self._args = None

def __getattr__(self, item):
return getattr(self._orig, item)

def __call__(self, *args, **kwargs):
self._args = args, kwargs

def call(self):
if self._args:
return self._orig(*self._args[0], **self._args[1])


# noinspection PyAbstractClass
class _SAConnection(Connection):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.dispatch.after_execute = _InterceptedListener(
self.dispatch.after_execute)

def after_execute(self):
if self._has_events or self.engine._has_events:
self.dispatch.after_execute.call()

def _execute_context(self, dialect, constructor,
statement, parameters,
*args):
Expand Down
35 changes: 35 additions & 0 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,3 +392,38 @@ async def task():
await asyncio.gather(*[task() for _ in range(5)])

assert bind._ctx.get() is None


async def test_issue_412(bind):
sql = 'SELECT now()'

@sa.event.listens_for(bind._sa_engine, 'after_execute')
def after_exec(conn, clauseelement, multiparams, params, result):
nonlocal rowcount
rowcount = result.rowcount

for i in range(4):
rowcount = None
await bind.all(sql)
assert rowcount == 1

rowcount = None
await bind.first(sql)
assert rowcount == 0

rowcount = None
await bind.all(sql, [(), ()])
assert rowcount == 0

async with bind.transaction() as tx:
rowcount = None
async for _ in bind.iterate(sql):
assert rowcount == 0

rowcount = None
stmt = await tx.connection.prepare(sql)
assert rowcount == 0

rowcount = None
await stmt.all()
assert rowcount is None