diff --git a/asgiref/_asyncio.py b/asgiref/_asyncio.py new file mode 100644 index 00000000..2b450bee --- /dev/null +++ b/asgiref/_asyncio.py @@ -0,0 +1,114 @@ +__all__ = [ + "get_running_loop", + "create_task_threadsafe", + "wrap_task_context", + "run_in_executor", +] + +import asyncio +import concurrent.futures +import contextvars +import functools +import sys +import types +from asyncio import get_running_loop +from collections.abc import Awaitable, Callable, Coroutine +from typing import Any, Generic, Protocol, TypeVar, Union + +from ._context import restore_context as _restore_context + +_R = TypeVar("_R") + +Coro = Coroutine[Any, Any, _R] + + +def create_task_threadsafe( + loop: asyncio.AbstractEventLoop, awaitable: Coro[object] +) -> None: + loop.call_soon_threadsafe(loop.create_task, awaitable) + + +async def wrap_task_context( + loop: asyncio.AbstractEventLoop, + task_context: list[asyncio.Task[Any]], + awaitable: Awaitable[_R], +) -> _R: + if task_context is None: + return await awaitable + + current_task = asyncio.current_task(loop) + if current_task is None: + return await awaitable + + task_context.append(current_task) + try: + return await awaitable + finally: + task_context.remove(current_task) + + +ExcInfo = Union[ + tuple[type[BaseException], BaseException, types.TracebackType], + tuple[None, None, None], +] + + +class ThreadHandlerType(Protocol, Generic[_R]): + def __call__( + self, + loop: asyncio.AbstractEventLoop, + exc_info: ExcInfo, + task_context: list[asyncio.Task[Any]], + func: Callable[[Callable[[], _R]], _R], + child: Callable[[], _R], + ) -> _R: + ... + + +async def run_in_executor( + *, + loop: asyncio.AbstractEventLoop, + executor: concurrent.futures.ThreadPoolExecutor, + thread_handler: ThreadHandlerType[_R], + child: Callable[[], _R], +) -> _R: + context = contextvars.copy_context() + func = context.run + task_context: list[asyncio.Task[Any]] = [] + + # Run the code in the right thread + exec_coro = loop.run_in_executor( + executor, + functools.partial( + thread_handler, + loop, + sys.exc_info(), + task_context, + func, + child, + ), + ) + ret: _R + try: + ret = await asyncio.shield(exec_coro) + except asyncio.CancelledError: + cancel_parent = True + try: + task = task_context[0] + task.cancel() + try: + await task + cancel_parent = False + except asyncio.CancelledError: + pass + except IndexError: + pass + if exec_coro.done(): + raise + if cancel_parent: + exec_coro.cancel() + ret = await exec_coro + finally: + _restore_context(context) + + return ret diff --git a/asgiref/_context.py b/asgiref/_context.py new file mode 100644 index 00000000..08af5153 --- /dev/null +++ b/asgiref/_context.py @@ -0,0 +1,13 @@ +import contextvars + + +def restore_context(context: contextvars.Context) -> None: + # Check for changes in contextvars, and set them to the current + # context for downstream consumers + for cvar in context: + cvalue = context.get(cvar) + try: + if cvar.get() != cvalue: + cvar.set(cvalue) + except LookupError: + cvar.set(cvalue) diff --git a/asgiref/_trio.py b/asgiref/_trio.py new file mode 100644 index 00000000..2c6803e0 --- /dev/null +++ b/asgiref/_trio.py @@ -0,0 +1,176 @@ +import asyncio +import concurrent.futures +import contextvars +import functools +import sys +import types +from collections.abc import Awaitable, Callable, Coroutine +from typing import Any, Generic, Protocol, TypeVar, Union + +import sniffio +import trio.lowlevel +import trio.to_thread + +from . import _asyncio +from ._context import restore_context as _restore_context + +_R = TypeVar("_R") + +Coro = Coroutine[Any, Any, _R] + +Loop = Union[asyncio.AbstractEventLoop, trio.lowlevel.TrioToken] +TaskContext = list[Any] + + +class TrioThreadCancelled(BaseException): + pass + + +def get_running_loop() -> Loop: + + try: + asynclib = sniffio.current_async_library() + except sniffio.AsyncLibraryNotFoundError: + return asyncio.get_running_loop() + + if asynclib == "asyncio": + return asyncio.get_running_loop() + if asynclib == "trio": + return trio.lowlevel.current_trio_token() + raise RuntimeError(f"unsupported library {asynclib}") + + +@trio.lowlevel.disable_ki_protection +async def wrap_awaitable(awaitable: Awaitable[_R]) -> _R: + return await awaitable + + +def create_task_threadsafe(loop: Loop, awaitable: Coro[_R]) -> None: + if isinstance(loop, trio.lowlevel.TrioToken): + try: + loop.run_sync_soon( + trio.lowlevel.spawn_system_task, + wrap_awaitable, + awaitable, + ) + except trio.RunFinishedError: + raise RuntimeError("trio loop no-longer running") + return + + _asyncio.create_task_threadsafe(loop, awaitable) + + +ExcInfo = Union[ + tuple[type[BaseException], BaseException, types.TracebackType], + tuple[None, None, None], +] + + +class ThreadHandlerType(Protocol, Generic[_R]): + def __call__( + self, + loop: Loop, + exc_info: ExcInfo, + task_context: TaskContext, + func: Callable[[Callable[[], _R]], _R], + child: Callable[[], _R], + ) -> _R: + ... + + +async def run_in_executor( + *, + loop: Loop, + executor: concurrent.futures.ThreadPoolExecutor, + thread_handler: ThreadHandlerType[_R], + child: Callable[[], _R], +) -> _R: + if isinstance(loop, trio.lowlevel.TrioToken): + context = contextvars.copy_context() + func = context.run + task_context: TaskContext = [] + + # Run the code in the right thread + full_func = functools.partial( + thread_handler, + loop, + sys.exc_info(), + task_context, + func, + child, + ) + try: + if executor is None: + + async def handle_cancel() -> None: + try: + await trio.sleep_forever() + except trio.Cancelled: + if task_context: + task_context[0].cancel() + raise + + async with trio.open_nursery() as nursery: + nursery.start_soon(handle_cancel) + try: + return await trio.to_thread.run_sync( + full_func, abandon_on_cancel=False + ) + except TrioThreadCancelled: + pass + finally: + nursery.cancel_scope.cancel() + assert False + else: + event = trio.Event() + + def callback(fut: object) -> None: + loop.run_sync_soon(event.set) + + fut = executor.submit(full_func) + fut.add_done_callback(callback) + + async def handle_cancel_fut() -> None: + try: + await trio.sleep_forever() + except trio.Cancelled: + fut.cancel() + if task_context: + task_context[0].cancel() + raise + + async with trio.open_nursery() as nursery: + nursery.start_soon(handle_cancel_fut) + with trio.CancelScope(shield=True): + await event.wait() + nursery.cancel_scope.cancel() + try: + return fut.result() + except TrioThreadCancelled: + pass + assert False + finally: + _restore_context(context) + + else: + return await _asyncio.run_in_executor( + loop=loop, executor=executor, thread_handler=thread_handler, child=child + ) + + +async def wrap_task_context( + loop: Loop, task_context: Union[TaskContext, None], awaitable: Awaitable[_R] +) -> _R: + if task_context is None: + return await awaitable + + if isinstance(loop, trio.lowlevel.TrioToken): + with trio.CancelScope() as scope: + task_context.append(scope) + try: + return await awaitable + finally: + task_context.remove(scope) + raise TrioThreadCancelled + + return await _asyncio.wrap_task_context(loop, task_context, awaitable) diff --git a/asgiref/local.py b/asgiref/local.py index 7d228aeb..11721c9f 100644 --- a/asgiref/local.py +++ b/asgiref/local.py @@ -5,6 +5,32 @@ from typing import Any, Union +def _is_asyncio_running(): + try: + asyncio.get_running_loop() + except RuntimeError: + return False + else: + return True + + +try: + import sniffio +except ModuleNotFoundError: + _is_async = _is_asyncio_running +else: + + def _is_async(): + try: + sniffio.current_async_library() + except sniffio.AsyncLibraryNotFoundError: + pass + else: + return True + + return _is_asyncio_running() + + class _CVar: """Storage utility for Local.""" @@ -83,18 +109,9 @@ def __init__(self, thread_critical: bool = False) -> None: def _lock_storage(self): # Thread safe access to storage if self._thread_critical: - try: - # this is a test for are we in a async or sync - # thread - will raise RuntimeError if there is - # no current loop - asyncio.get_running_loop() - except RuntimeError: - # We are in a sync thread, the storage is - # just the plain thread local (i.e, "global within - # this thread" - it doesn't matter where you are - # in a call stack you see the same storage) - yield self._storage - else: + # this is a test for are we in a async or sync + # thread + if _is_async(): # We are in an async thread - storage is still # local to this thread, but additionally should # behave like a context var (is only visible with @@ -108,6 +125,12 @@ def _lock_storage(self): # can't be accessed in another thread (we don't # need any locks) yield self._storage.cvar + else: + # We are in a sync thread, the storage is + # just the plain thread local (i.e, "global within + # this thread" - it doesn't matter where you are + # in a call stack you see the same storage) + yield self._storage else: # Lock for thread_critical=False as other threads # can access the exact same storage object diff --git a/asgiref/sync.py b/asgiref/sync.py index 0c6ea98e..5de00983 100644 --- a/asgiref/sync.py +++ b/asgiref/sync.py @@ -1,6 +1,6 @@ -import asyncio import asyncio.coroutines import contextvars +import enum import functools import inspect import os @@ -24,6 +24,7 @@ overload, ) +from ._context import restore_context as _restore_context from .current_thread_executor import CurrentThreadExecutor from .local import Local @@ -36,23 +37,35 @@ # This is not available to import at runtime from _typeshed import OptExcInfo + from ._trio import ( + create_task_threadsafe, + get_running_loop, + run_in_executor, + wrap_task_context, + ) +else: + try: + __import__("trio") + except ModuleNotFoundError: + from ._asyncio import ( + create_task_threadsafe, + get_running_loop, + run_in_executor, + wrap_task_context, + ) + else: + from ._trio import ( + create_task_threadsafe, + get_running_loop, + run_in_executor, + wrap_task_context, + ) + _F = TypeVar("_F", bound=Callable[..., Any]) _P = ParamSpec("_P") _R = TypeVar("_R") -def _restore_context(context: contextvars.Context) -> None: - # Check for changes in contextvars, and set them to the current - # context for downstream consumers - for cvar in context: - cvalue = context.get(cvar) - try: - if cvar.get() != cvalue: - cvar.set(cvalue) - except LookupError: - cvar.set(cvalue) - - # Python 3.12 deprecates asyncio.iscoroutinefunction() as an alias for # inspect.iscoroutinefunction(), whilst also removing the _is_coroutine marker. # The latter is replaced with the inspect.markcoroutinefunction decorator. @@ -110,6 +123,19 @@ async def __aexit__(self, exc, value, tb): SyncToAsync.thread_sensitive_context.reset(self.token) +class LoopType(enum.Enum): + ASYNCIO = enum.auto() + TRIO = enum.auto() + + +def run(async_backend, callable, /, *args): + if async_backend is LoopType.TRIO: + import trio + + return trio.run(callable, *args) + return asyncio.run(callable(*args)) + + class AsyncToSync(Generic[_P, _R]): """ Utility class which turns an awaitable that only works on the thread with @@ -129,7 +155,7 @@ class AsyncToSync(Generic[_P, _R]): # When we can't find a CurrentThreadExecutor from the context, such as # inside create_task, we'll look it up here from the running event loop. - loop_thread_executors: "Dict[asyncio.AbstractEventLoop, CurrentThreadExecutor]" = {} + loop_thread_executors: "Dict[object, CurrentThreadExecutor]" = {} def __init__( self, @@ -137,8 +163,11 @@ def __init__( Callable[_P, Coroutine[Any, Any, _R]], Callable[_P, Awaitable[_R]], ], - force_new_loop: bool = False, + force_new_loop: Union[LoopType, bool] = False, ): + if force_new_loop and not isinstance(force_new_loop, LoopType): + force_new_loop = LoopType.ASYNCIO + if not callable(awaitable) or ( not iscoroutinefunction(awaitable) and not iscoroutinefunction(getattr(awaitable, "__call__", awaitable)) @@ -156,7 +185,7 @@ def __init__( self.force_new_loop = force_new_loop self.main_event_loop = None try: - self.main_event_loop = asyncio.get_running_loop() + self.main_event_loop = get_running_loop() except RuntimeError: # There's no event loop in this thread. pass @@ -179,7 +208,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: # You can't call AsyncToSync from a thread with a running event loop try: - asyncio.get_running_loop() + get_running_loop() except RuntimeError: pass else: @@ -224,7 +253,7 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: ) async def new_loop_wrap() -> None: - loop = asyncio.get_running_loop() + loop = get_running_loop() self.loop_thread_executors[loop] = current_executor try: await awaitable @@ -233,8 +262,9 @@ async def new_loop_wrap() -> None: if self.main_event_loop is not None: try: - self.main_event_loop.call_soon_threadsafe( - self.main_event_loop.create_task, awaitable + create_task_threadsafe( + self.main_event_loop, + awaitable, ) except RuntimeError: running_in_main_event_loop = False @@ -248,7 +278,9 @@ async def new_loop_wrap() -> None: if not running_in_main_event_loop: # Make our own event loop - in a new thread - and run inside that. loop_executor = ThreadPoolExecutor(max_workers=1) - loop_future = loop_executor.submit(asyncio.run, new_loop_wrap()) + loop_future = loop_executor.submit( + run, self.force_new_loop, new_loop_wrap + ) # Run the CurrentThreadExecutor until the future is done. current_executor.run_until_future(loop_future) # Wait for future and/or allow for exception propagation @@ -283,13 +315,11 @@ async def main_wrap( __traceback_hide__ = True # noqa: F841 + loop = get_running_loop() if context is not None: _restore_context(context[0]) - current_task = asyncio.current_task() - if current_task is not None and task_context is not None: - task_context.append(current_task) - + result: _R try: # If we have an exception, run the function inside the except block # after raising it so exc_info is correctly populated. @@ -297,16 +327,14 @@ async def main_wrap( try: raise exc_info[1] except BaseException: - result = await awaitable + result = await wrap_task_context(loop, task_context, awaitable) else: - result = await awaitable + result = await wrap_task_context(loop, task_context, awaitable) except BaseException as e: call_result.set_exception(e) else: call_result.set_result(result) finally: - if current_task is not None and task_context is not None: - task_context.remove(current_task) context[0] = contextvars.copy_context() @@ -382,7 +410,7 @@ def __init__( async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: __traceback_hide__ = True # noqa: F841 - loop = asyncio.get_running_loop() + loop = get_running_loop() # Work out what thread to run the code in if self._thread_sensitive: @@ -417,49 +445,16 @@ async def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: # Use the passed in executor, or the loop's default if it is None executor = self._executor - context = contextvars.copy_context() - child = functools.partial(self.func, *args, **kwargs) - func = context.run - task_context: List[asyncio.Task[Any]] = [] - - # Run the code in the right thread - exec_coro = loop.run_in_executor( - executor, - functools.partial( - self.thread_handler, - loop, - sys.exc_info(), - task_context, - func, - child, - ), - ) - ret: _R try: - ret = await asyncio.shield(exec_coro) - except asyncio.CancelledError: - cancel_parent = True - try: - task = task_context[0] - task.cancel() - try: - await task - cancel_parent = False - except asyncio.CancelledError: - pass - except IndexError: - pass - if exec_coro.done(): - raise - if cancel_parent: - exec_coro.cancel() - ret = await exec_coro + return await run_in_executor( + loop=loop, + executor=executor, + thread_handler=self.thread_handler, + child=functools.partial(self.func, *args, **kwargs), + ) finally: - _restore_context(context) self.deadlock_context.set(False) - return ret - def __get__( self, parent: Any, objtype: Any ) -> Callable[_P, Coroutine[Any, Any, _R]]: @@ -496,7 +491,7 @@ def thread_handler(self, loop, exc_info, task_context, func, *args, **kwargs): @overload def async_to_sync( *, - force_new_loop: bool = False, + force_new_loop: Union[LoopType, bool] = False, ) -> Callable[ [Union[Callable[_P, Coroutine[Any, Any, _R]], Callable[_P, Awaitable[_R]]]], Callable[_P, _R], @@ -511,7 +506,7 @@ def async_to_sync( Callable[_P, Awaitable[_R]], ], *, - force_new_loop: bool = False, + force_new_loop: Union[LoopType, bool] = False, ) -> Callable[_P, _R]: ... @@ -524,7 +519,7 @@ def async_to_sync( ] ] = None, *, - force_new_loop: bool = False, + force_new_loop: Union[LoopType, bool] = False, ) -> Union[ Callable[ [Union[Callable[_P, Coroutine[Any, Any, _R]], Callable[_P, Awaitable[_R]]]], diff --git a/setup.cfg b/setup.cfg index ef0a4314..8d10df10 100644 --- a/setup.cfg +++ b/setup.cfg @@ -39,6 +39,7 @@ zip_safe = false tests = pytest pytest-asyncio + anyio[trio] mypy>=1.14.0 [tool:pytest] diff --git a/tests/test_sync.py b/tests/test_sync.py index 0c67308c..974e1a80 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -1,7 +1,6 @@ import asyncio import functools import multiprocessing -import sys import threading import time import warnings @@ -10,7 +9,9 @@ from typing import Any from unittest import TestCase +import anyio import pytest +import trio.to_thread from asgiref.sync import ( ThreadSensitiveContext, @@ -21,7 +22,7 @@ from asgiref.timeout import timeout -@pytest.mark.asyncio +@pytest.mark.anyio async def test_sync_to_async(): """ Tests we can call sync functions from an async thread @@ -41,6 +42,16 @@ def sync_function(): end = time.monotonic() assert result == 42 assert end - start >= 1 + + +@pytest.mark.asyncio +async def test_sync_to_async_one_worker(): + # Define sync function + @sync_to_async + def async_function(): + time.sleep(1) + return 42 + # Set workers to 1, call it twice and make sure that works right loop = asyncio.get_running_loop() old_executor = loop._default_executor or ThreadPoolExecutor() @@ -72,7 +83,7 @@ def test_sync_to_async_fail_non_function(): ) -@pytest.mark.asyncio +@pytest.mark.anyio async def test_sync_to_async_fail_async(): """ sync_to_async raises a TypeError when applied to a sync function. @@ -88,7 +99,7 @@ async def test_function(): ) -@pytest.mark.asyncio +@pytest.mark.anyio async def test_async_to_sync_fail_partial(): """ sync_to_async raises a TypeError when applied to a sync partial. @@ -106,7 +117,7 @@ async def test_function(*args): ) -@pytest.mark.asyncio +@pytest.mark.anyio async def test_sync_to_async_raises_typeerror_for_async_callable_instance(): class CallableClass: async def __call__(self): @@ -118,7 +129,7 @@ async def __call__(self): sync_to_async(CallableClass()) -@pytest.mark.asyncio +@pytest.mark.anyio async def test_sync_to_async_decorator(): """ Tests sync_to_async as a decorator @@ -134,7 +145,7 @@ def test_function(): assert result == 43 -@pytest.mark.asyncio +@pytest.mark.anyio async def test_nested_sync_to_async_retains_wrapped_function_attributes(): """ Tests that attributes of functions wrapped by sync_to_async are retained @@ -157,7 +168,7 @@ def test_function(): assert test_function.__name__ == "test_function" -@pytest.mark.asyncio +@pytest.mark.anyio async def test_sync_to_async_method_decorator(): """ Tests sync_to_async as a method decorator @@ -175,7 +186,7 @@ def test_method(self): assert result == 44 -@pytest.mark.asyncio +@pytest.mark.anyio async def test_sync_to_async_method_self_attribute(): """ Tests sync_to_async on a method copies __self__ @@ -197,7 +208,7 @@ def test_method(self): assert method.__self__ == instance -@pytest.mark.asyncio +@pytest.mark.anyio async def test_async_to_sync_to_async(): """ Tests we can call async functions from a sync thread created by async_to_sync @@ -225,7 +236,7 @@ def sync_function(): assert result["thread"] == threading.current_thread() -@pytest.mark.asyncio +@pytest.mark.anyio async def test_async_to_sync_to_async_decorator(): """ Test async_to_sync as a function decorator uses the outer thread @@ -253,9 +264,8 @@ def sync_function(): assert result["thread"] == threading.current_thread() -@pytest.mark.asyncio -@pytest.mark.skipif(sys.version_info < (3, 9), reason="requires python3.9") -async def test_async_to_sync_to_thread_decorator(): +@pytest.mark.anyio +async def test_async_to_sync_to_thread_decorator(anyio_backend_name): """ Test async_to_sync as a function decorator uses the outer thread when used inside another sync thread. @@ -270,7 +280,10 @@ async def inner_async_function(): return 42 # Check it works right - number = await asyncio.to_thread(inner_async_function) + if anyio_backend_name == "trio": + number = await trio.to_thread.run_sync(inner_async_function) + else: + number = await asyncio.to_thread(inner_async_function) assert number == 42 assert result["worked"] # Make sure that it didn't needlessly make a new async loop @@ -363,7 +376,7 @@ async def test_function(self): assert result["worked"] -@pytest.mark.asyncio +@pytest.mark.anyio async def test_async_to_sync_in_async(): """ Makes sure async_to_sync bails if you try to call it from an async loop @@ -509,7 +522,7 @@ def inner_task(): assert result["thread2"] == threading.current_thread() -@pytest.mark.asyncio +@pytest.mark.anyio async def test_thread_sensitive_outside_async(): """ Tests that thread_sensitive SyncToAsync where the outside is async code runs @@ -535,16 +548,16 @@ def inner(result): result["thread"] = threading.current_thread() # Run it (in supposed parallel!) - await asyncio.wait( - [asyncio.create_task(outer(result_1)), asyncio.create_task(inner(result_2))] - ) + async with anyio.create_task_group() as tg: + tg.start_soon(outer, result_1) + await inner(result_2) # They should not have run in the main thread, but in the same thread assert result_1["thread"] != threading.current_thread() assert result_1["thread"] == result_2["thread"] -@pytest.mark.asyncio +@pytest.mark.anyio async def test_thread_sensitive_with_context_matches(): result_1 = {} result_2 = {} @@ -557,12 +570,9 @@ def store_thread(result): async def fn(): async with ThreadSensitiveContext(): # Run it (in supposed parallel!) - await asyncio.wait( - [ - asyncio.create_task(store_thread_async(result_1)), - asyncio.create_task(store_thread_async(result_2)), - ] - ) + async with anyio.create_task_group() as tg: + tg.start_soon(store_thread_async, result_1) + await store_thread_async(result_2) await fn() @@ -571,7 +581,7 @@ async def fn(): assert result_1["thread"] == result_2["thread"] -@pytest.mark.asyncio +@pytest.mark.anyio async def test_thread_sensitive_nested_context(): result_1 = {} result_2 = {} @@ -590,7 +600,7 @@ def store_thread(result): assert result_1["thread"] == result_2["thread"] -@pytest.mark.asyncio +@pytest.mark.anyio async def test_thread_sensitive_context_without_sync_work(): async with ThreadSensitiveContext(): pass @@ -629,7 +639,7 @@ def level4(): assert result["thread"] == threading.current_thread() -@pytest.mark.asyncio +@pytest.mark.anyio async def test_thread_sensitive_double_nested_async(): """ Tests that thread_sensitive SyncToAsync nests inside itself where the @@ -729,7 +739,7 @@ def fork_first(): return queue.get(True, 1) -@pytest.mark.asyncio +@pytest.mark.anyio async def test_multiprocessing(): """ Tests that a forked process can use async_to_sync without it looking for @@ -738,7 +748,7 @@ async def test_multiprocessing(): assert await sync_to_async(fork_first)() == 42 -@pytest.mark.asyncio +@pytest.mark.anyio async def test_sync_to_async_uses_executor(): """ Tests that SyncToAsync uses the passed in executor correctly. @@ -834,7 +844,7 @@ async def async_process_that_triggers_event(): await trigger_task -@pytest.mark.asyncio +@pytest.mark.anyio async def test_sync_to_async_with_blocker_non_thread_sensitive(): """ Tests sync_to_async running on a long-time blocker in a non_thread_sensitive context. @@ -850,23 +860,20 @@ async def async_process_waiting_on_event(): async def async_process_that_triggers_event(): """Sleep, then set the event.""" - await asyncio.sleep(1) + await anyio.sleep(1) await sync_to_async(event.set)() - # Run the event setter as a task. - trigger_task = asyncio.ensure_future(async_process_that_triggers_event()) + async with anyio.create_task_group() as tg: + # Run the event setter as a task. + tg.start_soon(async_process_that_triggers_event) - try: - # wait on the event waiter, which is now blocking the event setter. - async with timeout(delay + 1): - assert await async_process_waiting_on_event() == 42 - except asyncio.TimeoutError: - # In case of timeout, set the event to unblock things, else - # downstream tests will get fouled up. - event.set() - raise - finally: - await trigger_task + try: + with anyio.fail_after(delay + 1): + assert await async_process_waiting_on_event() == 42 + except TimeoutError: + # In case of timeout, set the event to unblock things, else + # downstream tests will get fouled up. + event.set() @pytest.mark.asyncio @@ -1194,7 +1201,7 @@ async def test_function(**kwargs: Any) -> None: test_function(context=1) -@pytest.mark.asyncio +@pytest.mark.anyio async def test_sync_to_async_overlapping_kwargs() -> None: """ Tests that SyncToAsync correctly passes through kwargs to the wrapped function, diff --git a/tox.ini b/tox.ini index 49c49bce..1eb1cc7c 100644 --- a/tox.ini +++ b/tox.ini @@ -1,6 +1,6 @@ [tox] envlist = - py{38,39,310,311,312,313}-{test,mypy} + py{38,39,310,311,312,313}-{test,mypy,trio} qa [testenv]