Skip to content

Commit 0a8a199

Browse files
committed
Fixed #313, remove stack when empty
1 parent 64b28c7 commit 0a8a199

File tree

2 files changed

+70
-22
lines changed

2 files changed

+70
-22
lines changed

gino/engine.py

+43-20
Original file line numberDiff line numberDiff line change
@@ -270,14 +270,9 @@ async def release(self, *, permanent=True):
270270
271271
"""
272272
if permanent and self._stack is not None:
273-
for i in range(len(self._stack)):
274-
if self._stack[-1].gino_conn is self:
275-
dbapi_conn = self._stack.pop()
276-
self._stack.rotate(-i)
277-
await dbapi_conn.release(True)
278-
break
279-
else:
280-
self._stack.rotate()
273+
dbapi_conn = self._stack.remove(lambda x: x.gino_conn is self)
274+
if dbapi_conn:
275+
await dbapi_conn.release(True)
281276
else:
282277
raise ValueError('This connection is already released.')
283278
else:
@@ -493,6 +488,39 @@ async def prepare(self, clause):
493488
clause, (_bypass_no_param,), {}).prepare(clause)
494489

495490

491+
class _ContextualStack:
492+
__slots__ = ('_ctx', '_stack')
493+
494+
def __init__(self, ctx):
495+
self._ctx = ctx
496+
self._stack = ctx.get()
497+
if self._stack is None:
498+
self._stack = collections.deque()
499+
ctx.set(self._stack)
500+
501+
def __bool__(self):
502+
return bool(self._stack)
503+
504+
@property
505+
def top(self):
506+
return self._stack[-1]
507+
508+
def push(self, value):
509+
self._stack.append(value)
510+
511+
def remove(self, checker):
512+
for i in range(len(self._stack)):
513+
if checker(self._stack[-1]):
514+
rv = self._stack.pop()
515+
if self._stack:
516+
self._stack.rotate(-i)
517+
else:
518+
self._ctx.set(None)
519+
return rv
520+
else:
521+
self._stack.rotate(1)
522+
523+
496524
class GinoEngine:
497525
"""
498526
Connects a :class:`~.dialects.base.Pool` and
@@ -522,7 +550,7 @@ def __init__(self, dialect, pool, loop,
522550
self._dialect = dialect
523551
self._pool = pool
524552
self._loop = loop
525-
self._ctx = ContextVar('gino')
553+
self._ctx = ContextVar('gino', default=None)
526554

527555
@property
528556
def dialect(self):
@@ -608,14 +636,10 @@ def acquire(self, *, timeout=None, reuse=False, lazy=False, reusable=True):
608636
self._acquire, timeout, reuse, lazy, reusable))
609637

610638
async def _acquire(self, timeout, reuse, lazy, reusable):
611-
try:
612-
stack = self._ctx.get()
613-
except LookupError:
614-
stack = collections.deque()
615-
self._ctx.set(stack)
639+
stack = _ContextualStack(self._ctx)
616640
if reuse and stack:
617641
dbapi_conn = _ReusingDBAPIConnection(self._dialect.cursor_cls,
618-
stack[-1])
642+
stack.top)
619643
reusable = False
620644
else:
621645
dbapi_conn = _DBAPIConnection(self._dialect.cursor_cls, self._pool)
@@ -626,7 +650,7 @@ async def _acquire(self, timeout, reuse, lazy, reusable):
626650
if not lazy:
627651
await dbapi_conn.acquire(timeout=timeout)
628652
if reusable:
629-
stack.append(dbapi_conn)
653+
stack.push(dbapi_conn)
630654
return rv
631655

632656
@property
@@ -638,10 +662,9 @@ def current_connection(self):
638662
:return: :class:`.GinoConnection`
639663
640664
"""
641-
try:
642-
return self._ctx.get()[-1].gino_conn
643-
except (LookupError, IndexError):
644-
pass
665+
stack = self._ctx.get()
666+
if stack:
667+
return stack[-1].gino_conn
645668

646669
async def close(self):
647670
"""

tests/test_engine.py

+27-2
Original file line numberDiff line numberDiff line change
@@ -231,15 +231,15 @@ async def test_lazy(mocker):
231231
async with engine.acquire(lazy=True):
232232
assert qsize(engine) == init_size
233233
assert len(engine._ctx.get()) == 1
234-
assert len(engine._ctx.get()) == 0
234+
assert engine._ctx.get() is None
235235
assert qsize(engine) == init_size
236236
async with engine.acquire(lazy=True):
237237
assert qsize(engine) == init_size
238238
assert len(engine._ctx.get()) == 1
239239
assert await engine.scalar('select 1')
240240
assert qsize(engine) == init_size - 1
241241
assert len(engine._ctx.get()) == 1
242-
assert len(engine._ctx.get()) == 0
242+
assert engine._ctx.get() is None
243243
assert qsize(engine) == init_size
244244

245245
loop = asyncio.get_event_loop()
@@ -367,3 +367,28 @@ async def test_ssl():
367367

368368
e = await gino.create_engine(PG_URL, ssl=ctx)
369369
await e.close()
370+
371+
372+
async def test_issue_313(bind):
373+
assert bind._ctx.get() is None
374+
375+
async with db.acquire():
376+
pass
377+
378+
assert bind._ctx.get() is None
379+
380+
async def task():
381+
async with db.acquire(reuse=True):
382+
await db.scalar('SELECT now()')
383+
384+
await asyncio.gather(*[task() for _ in range(5)])
385+
386+
assert bind._ctx.get() is None
387+
388+
async def task():
389+
async with db.transaction():
390+
await db.scalar('SELECT now()')
391+
392+
await asyncio.gather(*[task() for _ in range(5)])
393+
394+
assert bind._ctx.get() is None

0 commit comments

Comments
 (0)