Skip to content

asgiref.sync trio support #509

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 114 additions & 0 deletions asgiref/_asyncio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
__all__ = [
"get_running_loop",
"create_task_threadsafe",
"wrap_task_context",
"run_in_executor",
]

import asyncio
import concurrent.futures
import contextvars
import functools
import sys
import types
from asyncio import get_running_loop
from collections.abc import Awaitable, Callable, Coroutine
from typing import Any, Generic, Protocol, TypeVar, Union

from ._context import restore_context as _restore_context

_R = TypeVar("_R")

Coro = Coroutine[Any, Any, _R]


def create_task_threadsafe(
loop: asyncio.AbstractEventLoop, awaitable: Coro[object]
) -> None:
loop.call_soon_threadsafe(loop.create_task, awaitable)


async def wrap_task_context(
loop: asyncio.AbstractEventLoop,
task_context: list[asyncio.Task[Any]],
awaitable: Awaitable[_R],
) -> _R:
if task_context is None:
return await awaitable

current_task = asyncio.current_task(loop)
if current_task is None:
return await awaitable

task_context.append(current_task)
try:
return await awaitable
finally:
task_context.remove(current_task)


ExcInfo = Union[
tuple[type[BaseException], BaseException, types.TracebackType],
tuple[None, None, None],
]


class ThreadHandlerType(Protocol, Generic[_R]):
def __call__(
self,
loop: asyncio.AbstractEventLoop,
exc_info: ExcInfo,
task_context: list[asyncio.Task[Any]],
func: Callable[[Callable[[], _R]], _R],
child: Callable[[], _R],
) -> _R:
...


async def run_in_executor(
*,
loop: asyncio.AbstractEventLoop,
executor: concurrent.futures.ThreadPoolExecutor,
thread_handler: ThreadHandlerType[_R],
child: Callable[[], _R],
) -> _R:
context = contextvars.copy_context()
func = context.run
task_context: list[asyncio.Task[Any]] = []

# Run the code in the right thread
exec_coro = loop.run_in_executor(
executor,
functools.partial(
thread_handler,
loop,
sys.exc_info(),
task_context,
func,
child,
),
)
ret: _R
try:
ret = await asyncio.shield(exec_coro)
except asyncio.CancelledError:
cancel_parent = True
try:
task = task_context[0]
task.cancel()
try:
await task
cancel_parent = False
except asyncio.CancelledError:
pass
except IndexError:
pass
if exec_coro.done():
raise
if cancel_parent:
exec_coro.cancel()
ret = await exec_coro
finally:
_restore_context(context)

return ret
13 changes: 13 additions & 0 deletions asgiref/_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import contextvars


def restore_context(context: contextvars.Context) -> None:
# Check for changes in contextvars, and set them to the current
# context for downstream consumers
for cvar in context:
cvalue = context.get(cvar)
try:
if cvar.get() != cvalue:
cvar.set(cvalue)
except LookupError:
cvar.set(cvalue)
176 changes: 176 additions & 0 deletions asgiref/_trio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import asyncio
import concurrent.futures
import contextvars
import functools
import sys
import types
from collections.abc import Awaitable, Callable, Coroutine
from typing import Any, Generic, Protocol, TypeVar, Union

import sniffio
import trio.lowlevel
import trio.to_thread

from . import _asyncio
from ._context import restore_context as _restore_context

_R = TypeVar("_R")

Coro = Coroutine[Any, Any, _R]

Loop = Union[asyncio.AbstractEventLoop, trio.lowlevel.TrioToken]
TaskContext = list[Any]


class TrioThreadCancelled(BaseException):
pass


def get_running_loop() -> Loop:

try:
asynclib = sniffio.current_async_library()
except sniffio.AsyncLibraryNotFoundError:
return asyncio.get_running_loop()

if asynclib == "asyncio":
return asyncio.get_running_loop()
if asynclib == "trio":
return trio.lowlevel.current_trio_token()
raise RuntimeError(f"unsupported library {asynclib}")


@trio.lowlevel.disable_ki_protection
async def wrap_awaitable(awaitable: Awaitable[_R]) -> _R:
return await awaitable


def create_task_threadsafe(loop: Loop, awaitable: Coro[_R]) -> None:
if isinstance(loop, trio.lowlevel.TrioToken):
try:
loop.run_sync_soon(
trio.lowlevel.spawn_system_task,
wrap_awaitable,
awaitable,
)
except trio.RunFinishedError:
raise RuntimeError("trio loop no-longer running")
return

_asyncio.create_task_threadsafe(loop, awaitable)


ExcInfo = Union[
tuple[type[BaseException], BaseException, types.TracebackType],
tuple[None, None, None],
]


class ThreadHandlerType(Protocol, Generic[_R]):
def __call__(
self,
loop: Loop,
exc_info: ExcInfo,
task_context: TaskContext,
func: Callable[[Callable[[], _R]], _R],
child: Callable[[], _R],
) -> _R:
...


async def run_in_executor(
*,
loop: Loop,
executor: concurrent.futures.ThreadPoolExecutor,
thread_handler: ThreadHandlerType[_R],
child: Callable[[], _R],
) -> _R:
if isinstance(loop, trio.lowlevel.TrioToken):
context = contextvars.copy_context()
func = context.run
task_context: TaskContext = []

# Run the code in the right thread
full_func = functools.partial(
thread_handler,
loop,
sys.exc_info(),
task_context,
func,
child,
)
try:
if executor is None:

async def handle_cancel() -> None:
try:
await trio.sleep_forever()
except trio.Cancelled:
if task_context:
task_context[0].cancel()
raise

async with trio.open_nursery() as nursery:
nursery.start_soon(handle_cancel)
try:
return await trio.to_thread.run_sync(
full_func, abandon_on_cancel=False
)
except TrioThreadCancelled:
pass
finally:
nursery.cancel_scope.cancel()
assert False
else:
event = trio.Event()

def callback(fut: object) -> None:
loop.run_sync_soon(event.set)

fut = executor.submit(full_func)
fut.add_done_callback(callback)

async def handle_cancel_fut() -> None:
try:
await trio.sleep_forever()
except trio.Cancelled:
fut.cancel()
if task_context:
task_context[0].cancel()
raise

async with trio.open_nursery() as nursery:
nursery.start_soon(handle_cancel_fut)
with trio.CancelScope(shield=True):
await event.wait()
nursery.cancel_scope.cancel()
try:
return fut.result()
except TrioThreadCancelled:
pass
assert False
finally:
_restore_context(context)

else:
return await _asyncio.run_in_executor(
loop=loop, executor=executor, thread_handler=thread_handler, child=child
)


async def wrap_task_context(
loop: Loop, task_context: Union[TaskContext, None], awaitable: Awaitable[_R]
) -> _R:
if task_context is None:
return await awaitable

if isinstance(loop, trio.lowlevel.TrioToken):
with trio.CancelScope() as scope:
task_context.append(scope)
try:
return await awaitable
finally:
task_context.remove(scope)
raise TrioThreadCancelled

return await _asyncio.wrap_task_context(loop, task_context, awaitable)
47 changes: 35 additions & 12 deletions asgiref/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,32 @@
from typing import Any, Union


def _is_asyncio_running():
try:
asyncio.get_running_loop()
except RuntimeError:
return False
else:
return True


try:
import sniffio
except ModuleNotFoundError:
_is_async = _is_asyncio_running
else:

def _is_async():
try:
sniffio.current_async_library()
except sniffio.AsyncLibraryNotFoundError:
pass
else:
return True

return _is_asyncio_running()


class _CVar:
"""Storage utility for Local."""

Expand Down Expand Up @@ -83,18 +109,9 @@ def __init__(self, thread_critical: bool = False) -> None:
def _lock_storage(self):
# Thread safe access to storage
if self._thread_critical:
try:
# this is a test for are we in a async or sync
# thread - will raise RuntimeError if there is
# no current loop
asyncio.get_running_loop()
except RuntimeError:
# We are in a sync thread, the storage is
# just the plain thread local (i.e, "global within
# this thread" - it doesn't matter where you are
# in a call stack you see the same storage)
yield self._storage
else:
# this is a test for are we in a async or sync
# thread
if _is_async():
# We are in an async thread - storage is still
# local to this thread, but additionally should
# behave like a context var (is only visible with
Expand All @@ -108,6 +125,12 @@ def _lock_storage(self):
# can't be accessed in another thread (we don't
# need any locks)
yield self._storage.cvar
else:
# We are in a sync thread, the storage is
# just the plain thread local (i.e, "global within
# this thread" - it doesn't matter where you are
# in a call stack you see the same storage)
yield self._storage
else:
# Lock for thread_critical=False as other threads
# can access the exact same storage object
Expand Down
Loading