Skip to content

added AsyncSingleThreadContext #511

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 65 additions & 2 deletions asgiref/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,45 @@ def markcoroutinefunction(func: _F) -> _F:
return func


class AsyncSingleThreadContext:
"""Context manager to run async code inside the same thread.

Normally, AsyncToSync functions run either inside a separate ThreadPoolExecutor or
the main event loop if it exists. This context manager ensures that all AsyncToSync
functions execute within the same thread.

This context manager is re-entrant, so only the outer-most call to
AsyncSingleThreadContext will set the context.

Usage:

>>> import asyncio
>>> with AsyncSingleThreadContext():
... async_to_sync(asyncio.sleep(1))()
"""

def __init__(self):
self.token = None

def __enter__(self):
try:
AsyncToSync.async_single_thread_context.get()
except LookupError:
self.token = AsyncToSync.async_single_thread_context.set(self)

return self

def __exit__(self, exc, value, tb):
if not self.token:
return

executor = AsyncToSync.context_to_thread_executor.pop(self, None)
if executor:
executor.shutdown()

AsyncToSync.async_single_thread_context.reset(self.token)


class ThreadSensitiveContext:
"""Async context manager to manage context for thread sensitive mode

Expand Down Expand Up @@ -131,6 +170,14 @@ class AsyncToSync(Generic[_P, _R]):
# inside create_task, we'll look it up here from the running event loop.
loop_thread_executors: "Dict[asyncio.AbstractEventLoop, CurrentThreadExecutor]" = {}

async_single_thread_context: "contextvars.ContextVar[AsyncSingleThreadContext]" = (
contextvars.ContextVar("async_single_thread_context")
)

context_to_thread_executor: "weakref.WeakKeyDictionary[AsyncSingleThreadContext, ThreadPoolExecutor]" = (
weakref.WeakKeyDictionary()
)

def __init__(
self,
awaitable: Union[
Expand Down Expand Up @@ -246,8 +293,24 @@ async def new_loop_wrap() -> None:
running_in_main_event_loop = False

if not running_in_main_event_loop:
# Make our own event loop - in a new thread - and run inside that.
loop_executor = ThreadPoolExecutor(max_workers=1)
loop_executor = None

if self.async_single_thread_context.get(None):
single_thread_context = self.async_single_thread_context.get()

if single_thread_context in self.context_to_thread_executor:
loop_executor = self.context_to_thread_executor[
single_thread_context
]
else:
loop_executor = ThreadPoolExecutor(max_workers=1)
self.context_to_thread_executor[
single_thread_context
] = loop_executor
else:
# Make our own event loop - in a new thread - and run inside that.
loop_executor = ThreadPoolExecutor(max_workers=1)

loop_future = loop_executor.submit(asyncio.run, new_loop_wrap())
# Run the CurrentThreadExecutor until the future is done.
current_executor.run_until_future(loop_future)
Expand Down
94 changes: 94 additions & 0 deletions tests/test_sync.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import contextvars
import functools
import multiprocessing
import sys
Expand All @@ -13,6 +14,7 @@
import pytest

from asgiref.sync import (
AsyncSingleThreadContext,
ThreadSensitiveContext,
async_to_sync,
iscoroutinefunction,
Expand Down Expand Up @@ -544,6 +546,98 @@ def inner(result):
assert result_1["thread"] == result_2["thread"]


def test_async_single_thread_context_matches():
"""
Tests that functions wrapped with async_to_sync and executed within an
AsyncSingleThreadContext run on the same thread, even without a main_event_loop.
"""
result_1 = {}
result_2 = {}

async def store_thread_async(result):
result["thread"] = threading.current_thread()

with AsyncSingleThreadContext():
async_to_sync(store_thread_async)(result_1)
async_to_sync(store_thread_async)(result_2)

# They should not have run in the main thread, and on the same threads
assert result_1["thread"] != threading.current_thread()
assert result_1["thread"] == result_2["thread"]


def test_async_single_thread_nested_context():
"""
Tests that behavior remains the same when using nested context managers.
"""
result_1 = {}
result_2 = {}

@async_to_sync
async def store_thread(result):
result["thread"] = threading.current_thread()

with AsyncSingleThreadContext():
store_thread(result_1)

with AsyncSingleThreadContext():
store_thread(result_2)

# They should not have run in the main thread, and on the same threads
assert result_1["thread"] != threading.current_thread()
assert result_1["thread"] == result_2["thread"]


def test_async_single_thread_context_without_async_work():
"""
Tests everything works correctly without any async_to_sync calls.
"""
with AsyncSingleThreadContext():
pass


def test_async_single_thread_context_success_share_context():
"""
Tests that we share context between different async_to_sync functions.
"""
connection = contextvars.ContextVar("connection")
connection.set(0)

async def handler():
connection.set(connection.get(0) + 1)

with AsyncSingleThreadContext():
async_to_sync(handler)()
async_to_sync(handler)()

assert connection.get() == 2


@pytest.mark.asyncio
async def test_async_single_thread_context_matches_from_async_thread():
"""
Tests that we use main_event_loop for running async_to_sync functions executed
within an AsyncSingleThreadContext.
"""
result_1 = {}
result_2 = {}

@async_to_sync
async def store_thread_async(result):
result["thread"] = threading.current_thread()

def inner():
with AsyncSingleThreadContext():
store_thread_async(result_1)
store_thread_async(result_2)

await sync_to_async(inner)()

# They should not have run in the main thread, and on the same threads
assert result_1["thread"] == threading.current_thread()
assert result_1["thread"] == result_2["thread"]


@pytest.mark.asyncio
async def test_thread_sensitive_with_context_matches():
result_1 = {}
Expand Down