diff --git a/asgiref/sync.py b/asgiref/sync.py index 0c6ea98e..7813c26f 100644 --- a/asgiref/sync.py +++ b/asgiref/sync.py @@ -69,6 +69,45 @@ def markcoroutinefunction(func: _F) -> _F: return func +class AsyncSingleThreadContext: + """Context manager to run async code inside the same thread. + + Normally, AsyncToSync functions run either inside a separate ThreadPoolExecutor or + the main event loop if it exists. This context manager ensures that all AsyncToSync + functions execute within the same thread. + + This context manager is re-entrant, so only the outer-most call to + AsyncSingleThreadContext will set the context. + + Usage: + + >>> import asyncio + >>> with AsyncSingleThreadContext(): + ... async_to_sync(asyncio.sleep(1))() + """ + + def __init__(self): + self.token = None + + def __enter__(self): + try: + AsyncToSync.async_single_thread_context.get() + except LookupError: + self.token = AsyncToSync.async_single_thread_context.set(self) + + return self + + def __exit__(self, exc, value, tb): + if not self.token: + return + + executor = AsyncToSync.context_to_thread_executor.pop(self, None) + if executor: + executor.shutdown() + + AsyncToSync.async_single_thread_context.reset(self.token) + + class ThreadSensitiveContext: """Async context manager to manage context for thread sensitive mode @@ -131,6 +170,14 @@ class AsyncToSync(Generic[_P, _R]): # inside create_task, we'll look it up here from the running event loop. loop_thread_executors: "Dict[asyncio.AbstractEventLoop, CurrentThreadExecutor]" = {} + async_single_thread_context: "contextvars.ContextVar[AsyncSingleThreadContext]" = ( + contextvars.ContextVar("async_single_thread_context") + ) + + context_to_thread_executor: "weakref.WeakKeyDictionary[AsyncSingleThreadContext, ThreadPoolExecutor]" = ( + weakref.WeakKeyDictionary() + ) + def __init__( self, awaitable: Union[ @@ -246,8 +293,24 @@ async def new_loop_wrap() -> None: running_in_main_event_loop = False if not running_in_main_event_loop: - # Make our own event loop - in a new thread - and run inside that. - loop_executor = ThreadPoolExecutor(max_workers=1) + loop_executor = None + + if self.async_single_thread_context.get(None): + single_thread_context = self.async_single_thread_context.get() + + if single_thread_context in self.context_to_thread_executor: + loop_executor = self.context_to_thread_executor[ + single_thread_context + ] + else: + loop_executor = ThreadPoolExecutor(max_workers=1) + self.context_to_thread_executor[ + single_thread_context + ] = loop_executor + else: + # Make our own event loop - in a new thread - and run inside that. + loop_executor = ThreadPoolExecutor(max_workers=1) + loop_future = loop_executor.submit(asyncio.run, new_loop_wrap()) # Run the CurrentThreadExecutor until the future is done. current_executor.run_until_future(loop_future) diff --git a/tests/test_sync.py b/tests/test_sync.py index 0c67308c..592e8681 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -1,4 +1,5 @@ import asyncio +import contextvars import functools import multiprocessing import sys @@ -13,6 +14,7 @@ import pytest from asgiref.sync import ( + AsyncSingleThreadContext, ThreadSensitiveContext, async_to_sync, iscoroutinefunction, @@ -544,6 +546,98 @@ def inner(result): assert result_1["thread"] == result_2["thread"] +def test_async_single_thread_context_matches(): + """ + Tests that functions wrapped with async_to_sync and executed within an + AsyncSingleThreadContext run on the same thread, even without a main_event_loop. + """ + result_1 = {} + result_2 = {} + + async def store_thread_async(result): + result["thread"] = threading.current_thread() + + with AsyncSingleThreadContext(): + async_to_sync(store_thread_async)(result_1) + async_to_sync(store_thread_async)(result_2) + + # They should not have run in the main thread, and on the same threads + assert result_1["thread"] != threading.current_thread() + assert result_1["thread"] == result_2["thread"] + + +def test_async_single_thread_nested_context(): + """ + Tests that behavior remains the same when using nested context managers. + """ + result_1 = {} + result_2 = {} + + @async_to_sync + async def store_thread(result): + result["thread"] = threading.current_thread() + + with AsyncSingleThreadContext(): + store_thread(result_1) + + with AsyncSingleThreadContext(): + store_thread(result_2) + + # They should not have run in the main thread, and on the same threads + assert result_1["thread"] != threading.current_thread() + assert result_1["thread"] == result_2["thread"] + + +def test_async_single_thread_context_without_async_work(): + """ + Tests everything works correctly without any async_to_sync calls. + """ + with AsyncSingleThreadContext(): + pass + + +def test_async_single_thread_context_success_share_context(): + """ + Tests that we share context between different async_to_sync functions. + """ + connection = contextvars.ContextVar("connection") + connection.set(0) + + async def handler(): + connection.set(connection.get(0) + 1) + + with AsyncSingleThreadContext(): + async_to_sync(handler)() + async_to_sync(handler)() + + assert connection.get() == 2 + + +@pytest.mark.asyncio +async def test_async_single_thread_context_matches_from_async_thread(): + """ + Tests that we use main_event_loop for running async_to_sync functions executed + within an AsyncSingleThreadContext. + """ + result_1 = {} + result_2 = {} + + @async_to_sync + async def store_thread_async(result): + result["thread"] = threading.current_thread() + + def inner(): + with AsyncSingleThreadContext(): + store_thread_async(result_1) + store_thread_async(result_2) + + await sync_to_async(inner)() + + # They should not have run in the main thread, and on the same threads + assert result_1["thread"] == threading.current_thread() + assert result_1["thread"] == result_2["thread"] + + @pytest.mark.asyncio async def test_thread_sensitive_with_context_matches(): result_1 = {}