Skip to content

Commit

Permalink
Use custom implementation of ContextVar that is not copied automatica…
Browse files Browse the repository at this point in the history
…lly to new tasks (#1)
  • Loading branch information
masipcat committed Nov 11, 2021
1 parent 9e1fdb5 commit e1306d9
Show file tree
Hide file tree
Showing 13 changed files with 218 additions and 20 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ CHANGELOG

6.3.16 (unreleased)
-------------------
- Use custom implementation of ContextVar that is not copied automatically to new tasks
[masipcat]

- Use custom implementation of Transaction/TransacionManager to support real SQL transactions
[masipcat]
Expand All @@ -16,7 +18,7 @@ CHANGELOG
6.3.15 (2021-08-05)
-------------------

- fix: Add MIMEMultipart('alternative') to attach message in parent MIMEMultipart to render only html body.
- fix: Add MIMEMultipart('alternative') to attach message in parent MIMEMultipart to render only html body.
[rboixaderg]

6.3.14 (2021-08-04)
Expand Down
3 changes: 2 additions & 1 deletion guillotina/contrib/cache/strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from guillotina.exceptions import NoPubSubUtility
from guillotina.interfaces import ICacheUtility
from guillotina.profile import profilable
from guillotina.task_vars import copy_context
from typing import Any
from typing import Dict
from typing import List
Expand Down Expand Up @@ -129,7 +130,7 @@ async def close(self, invalidate=True, publish=True):
await self.fill_cache()
if len(self._keys_to_publish) > 0 and self._utility._subscriber is not None:
keys = self._keys_to_publish
asyncio.ensure_future(self.synchronize(keys))
asyncio.ensure_future(copy_context(self.synchronize(keys)))
else:
self._stored_objects.clear()
else:
Expand Down
3 changes: 2 additions & 1 deletion guillotina/contrib/mailer/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from guillotina.contrib.mailer.exceptions import NoEndpointDefinedException
from guillotina.interfaces import IMailEndpoint
from guillotina.interfaces import IMailer
from guillotina.task_vars import copy_context
from guillotina.utils import get_random_string
from zope.interface import implementer

Expand Down Expand Up @@ -135,7 +136,7 @@ def get_endpoint(self, endpoint_name):
else:
raise NoEndpointDefinedException("{} mail endpoint not defined".format(endpoint_name))
utility.from_settings(settings)
asyncio.ensure_future(utility.initialize())
asyncio.ensure_future(copy_context(utility.initialize()))
self._endpoints[endpoint_name] = utility
return self._endpoints[endpoint_name]

Expand Down
1 change: 1 addition & 0 deletions guillotina/contrib/pubsub/utility.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from guillotina.contrib.pubsub.exceptions import NoPubSubDriver
from guillotina.profile import profilable
from guillotina.task_vars import copy_context
from guillotina.utils import resolve_dotted_name
from typing import Any
from typing import Callable
Expand Down
5 changes: 3 additions & 2 deletions guillotina/db/storages/pg.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from guillotina.exceptions import ConflictIdOnContainer
from guillotina.exceptions import TIDConflictError
from guillotina.profile import profilable
from guillotina.task_vars import copy_context
from zope.interface import implementer

import asyncio
Expand Down Expand Up @@ -403,7 +404,7 @@ async def _initialize(self):
try:
oid, table_name = await self._queue.get()
self._active = True
await shield(self.vacuum(oid, table_name))
await shield(copy_context(self.vacuum(oid, table_name)))
except (concurrent.futures.CancelledError, RuntimeError):
raise
except Exception:
Expand Down Expand Up @@ -555,7 +556,7 @@ async def initialize(self, loop=None, **kw):

if self._autovacuum:
self._vacuum = self._vacuum_class(self, loop)
self._vacuum_task = asyncio.Task(self._vacuum.initialize(), loop=loop)
self._vacuum_task = asyncio.Task(copy_context(self._vacuum.initialize()), loop=loop)

async def restart(self, timeout=2):
# needs to be used with lock
Expand Down
5 changes: 3 additions & 2 deletions guillotina/db/transaction_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from guillotina.exceptions import TIDConflictError
from guillotina.exceptions import TransactionNotFound
from guillotina.profile import profilable
from guillotina.task_vars import copy_context
from guillotina.transactions import transaction
from guillotina.utils import get_authenticated_user_id
from zope.interface import implementer
Expand Down Expand Up @@ -102,7 +103,7 @@ async def begin(self, read_only: bool = False, lazy: bool = False) -> ITransacti
return txn

async def commit(self, *, txn: typing.Optional[ITransaction] = None) -> None:
return await shield(self._commit(txn=txn))
return await shield(copy_context(self._commit(txn=txn)))

async def _commit(self, *, txn: typing.Optional[ITransaction] = None) -> None:
""" Commit the last transaction
Expand Down Expand Up @@ -157,7 +158,7 @@ async def _close_txn(self, txn: typing.Optional[ITransaction]):

async def abort(self, *, txn: typing.Optional[ITransaction] = None) -> None:
try:
return await shield(self._abort(txn=txn))
return await shield(copy_context(self._abort(txn=txn)))
except asyncio.CancelledError:
pass

Expand Down
3 changes: 2 additions & 1 deletion guillotina/factory/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from guillotina.db.transaction_manager import TransactionManager
from guillotina.interfaces import IApplication
from guillotina.interfaces import IDatabase
from guillotina.task_vars import copy_context
from guillotina.transactions import get_transaction
from guillotina.utils import apply_coroutine
from guillotina.utils import import_class
Expand Down Expand Up @@ -69,7 +70,7 @@ def add_async_utility(
kw["name"] = config["name"]
provide_utility(utility_object, interface, **kw)
if hasattr(utility_object, "initialize"):
func = lazy_apply(utility_object.initialize, app=self.app)
func = copy_context(lazy_apply(utility_object.initialize, app=self.app))

task = asyncio.ensure_future(notice_on_error(key, func), loop=loop or self._loop)
self.add_async_task(key, task, config)
Expand Down
158 changes: 148 additions & 10 deletions guillotina/task_vars.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from contextvars import ContextVar
from guillotina.db.interfaces import ITransaction
from guillotina.db.interfaces import ITransactionManager
from guillotina.interfaces import IContainer
Expand All @@ -8,17 +7,156 @@
from guillotina.interfaces import IRequest
from guillotina.interfaces import ISecurityPolicy
from typing import Dict
from typing import Generic
from typing import Optional
from typing import TypeVar

import asyncio
import contextvars
import weakref

request: ContextVar[Optional[IRequest]] = ContextVar("g_request", default=None)
txn: ContextVar[Optional[ITransaction]] = ContextVar("g_txn", default=None)
tm: ContextVar[Optional[ITransactionManager]] = ContextVar("g_tm", default=None)
futures: ContextVar[Optional[Dict]] = ContextVar("g_futures", default=None)
authenticated_user: ContextVar[Optional[IPrincipal]] = ContextVar("g_authenticated_user", default=None)
security_policies: ContextVar[Optional[Dict[str, ISecurityPolicy]]] = ContextVar(

# This global dictionary keeps all the contextvars for each task.
# When a task finishes and is destroyed, the context is destroyed as well
_context = weakref.WeakKeyDictionary() # type: ignore


class FakeTask:
"""
This class is necessary because we need an object to use as a key in the WeakKeyDictionary _context.
We can't use built-in objects because they don't have the `__weakref__` in the `__slots__` and without
this attribute, weakreaf.ref doesn't work
"""


_no_task_fallback = FakeTask()


def copy_context(coro):
"""
This function it's similar to contextvars.copy_context() but has a slightly different
signature and it's not called by default when a new task/future is created.
To copy the context from the current task to a new one you need to call this
funcion explicitly, like this:
async def worker():
...
asyncio.create_task(copy_context(worker()))
"""
try:
from_task = asyncio.current_task()
except RuntimeError:
assert _no_task_fallback is not None
from_task = _no_task_fallback

if from_task in _context:
# The _context value type is a dict so we need to copy the dict to avoid
# sharing the same context value in different tasks
new_context = _context[from_task].copy()
else:
new_context = {}

return _run_coro_with_ctx(coro, new_context)


async def _run_coro_with_ctx(coro, new_context):
task = asyncio.current_task()
_context[task] = new_context
return await coro


_T = TypeVar("_T")
_NO_DEFAULT = object()


class Token:
"""
Reimplementation of contextvars.Token
"""

MISSING = contextvars.Token.MISSING

def __init__(self, var, old_value) -> None:
self._var = var
self._old_value = old_value

@property
def var(self):
return self._var

@property
def old_value(self):
return self._old_value


class ShyContextVar(Generic[_T]):
"""
Reimplementation of contextvars.ContextVar but stores the values to the global `_context`
instead of storing it to the PyContext
"""

def __init__(self, name: str, default=_NO_DEFAULT):
self._name = name
self._default = default

@property
def name(self):
return self._name

def get(self, default=_NO_DEFAULT):
ctx = self._get_ctx_data()
if self._name in ctx:
return ctx[self._name]
elif default != _NO_DEFAULT:
return default
elif self._default != _NO_DEFAULT:
return self._default
else:
raise LookupError(self)

def set(self, value) -> Token:
data = self._get_ctx_data()
name = self._name
if name in data:
t = Token(self, data[name])
else:
t = Token(self, Token.MISSING)
data[self._name] = value
return t

def reset(self, token):
if token.old_value == Token.MISSING:
ctx = self._get_ctx_data()
if ctx and self._name in ctx:
del ctx[self._name]
else:
self.set(token.old_value)

def _get_ctx_data(self):
try:
task = asyncio.current_task()
except RuntimeError:
task = _no_task_fallback
try:
data = _context[task]
except KeyError:
# Initialize _context value for this task
data = {}
_context[task] = data
return data


request: ShyContextVar[Optional[IRequest]] = ShyContextVar("g_request", default=None)
txn: ShyContextVar[Optional[ITransaction]] = ShyContextVar("g_txn", default=None)
tm: ShyContextVar[Optional[ITransactionManager]] = ShyContextVar("g_tm", default=None)
futures: ShyContextVar[Optional[Dict]] = ShyContextVar("g_futures", default=None)
authenticated_user: ShyContextVar[Optional[IPrincipal]] = ShyContextVar("g_authenticated_user", default=None)
security_policies: ShyContextVar[Optional[Dict[str, ISecurityPolicy]]] = ShyContextVar(
"g_security_policy", default=None
)
container: ContextVar[Optional[IContainer]] = ContextVar("g_container", default=None)
registry: ContextVar[Optional[IRegistry]] = ContextVar("g_container", default=None)
db: ContextVar[Optional[IDatabase]] = ContextVar("g_database", default=None)
container: ShyContextVar[Optional[IContainer]] = ShyContextVar("g_container", default=None)
registry: ShyContextVar[Optional[IRegistry]] = ShyContextVar("g_registry", default=None)
db: ShyContextVar[Optional[IDatabase]] = ShyContextVar("g_database", default=None)
8 changes: 8 additions & 0 deletions guillotina/tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from guillotina.interfaces import IDatabase
from guillotina.tests import mocks
from guillotina.tests.utils import ContainerRequesterAsyncContextManager
from guillotina.tests.utils import copy_global_ctx
from guillotina.tests.utils import get_mocked_request
from guillotina.tests.utils import login
from guillotina.tests.utils import logout
Expand Down Expand Up @@ -386,6 +387,7 @@ def clear_task_vars():
@pytest.fixture(scope="function")
async def dummy_guillotina(event_loop, request):
globalregistry.reset()
task_vars._no_task_fallback = task_vars.FakeTask()
app = make_app(settings=get_dummy_settings(request.node), loop=event_loop)
async with TestClient(app):
yield app
Expand Down Expand Up @@ -429,6 +431,10 @@ def __init__(self, request):
self.txn = None

async def __aenter__(self):
# This is a hack to copy contextvars defined in fixture dummy_request
# (oustide event loop) to this asyncio task
copy_global_ctx()

tm = get_tm()
self.txn = await tm.begin()
self.root = await tm.get_root()
Expand Down Expand Up @@ -484,6 +490,7 @@ async def _clear_dbs(root):
@pytest.fixture(scope="function")
async def app(event_loop, db, request):
globalregistry.reset()
task_vars._no_task_fallback = task_vars.FakeTask()
settings = get_db_settings(request.node)
app = make_app(settings=settings, loop=event_loop)

Expand Down Expand Up @@ -516,6 +523,7 @@ async def app(event_loop, db, request):
@pytest.fixture(scope="function")
async def app_client(event_loop, db, request):
globalregistry.reset()
task_vars._no_task_fallback = task_vars.FakeTask()
app = make_app(settings=get_db_settings(request.node), loop=event_loop)
async with TestClient(app, timeout=30) as client:
await _clear_dbs(app.app.root)
Expand Down
Loading

0 comments on commit e1306d9

Please sign in to comment.