diff --git a/asgiref/current_thread_executor.py b/asgiref/current_thread_executor.py index 67a7926f..4ecc4585 100644 --- a/asgiref/current_thread_executor.py +++ b/asgiref/current_thread_executor.py @@ -55,7 +55,8 @@ class CurrentThreadExecutor(Executor): def __init__(self) -> None: self._work_thread = threading.current_thread() - self._work_queue: queue.Queue[Union[_WorkItem, "Future[Any]"]] = queue.Queue() + self._work_queue: queue.Queue[Union[_WorkItem, None]] = queue.Queue() + self._level = 0 self._broken = False def run_until_future(self, future: "Future[Any]") -> None: @@ -68,20 +69,21 @@ def run_until_future(self, future: "Future[Any]") -> None: raise RuntimeError( "You cannot run CurrentThreadExecutor from a different thread" ) - future.add_done_callback(self._work_queue.put) - # Keep getting and running work items until we get the future we're waiting for - # back via the future's done callback. + future.add_done_callback(lambda future: self._work_queue.put(None)) + self._level += 1 + # Keep getting and running work items until the future we're waiting for is + # done. try: - while True: + while not future.done(): # Get a work item and run it work_item = self._work_queue.get() - if work_item is future: - return - assert isinstance(work_item, _WorkItem) + if work_item is None: + continue work_item.run() del work_item finally: - self._broken = True + self._level -= 1 + self._broken = self._level == 0 def _submit( self, diff --git a/asgiref/sync.py b/asgiref/sync.py index 377075d1..9bfaf182 100644 --- a/asgiref/sync.py +++ b/asgiref/sync.py @@ -192,12 +192,14 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _R: # Make a future for the return information call_result: "Future[_R]" = Future() - # Make a CurrentThreadExecutor we'll use to idle in this thread - we - # need one for every sync frame, even if there's one above us in the - # same thread. + # Make a CurrentThreadExecutor we'll use to idle in this thread, unless + # we can reuse one above us in the same thread. old_executor = getattr(self.executors, "current", None) - current_executor = CurrentThreadExecutor() - self.executors.current = current_executor + if old_executor is None: + current_executor = CurrentThreadExecutor() + self.executors.current = current_executor + else: + current_executor = old_executor # Wrapping context in list so it can be reassigned from within # `main_wrap`. diff --git a/tests/test_sync.py b/tests/test_sync.py index 0c67308c..11667413 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -1208,3 +1208,33 @@ def test_function(**kwargs: Any) -> None: # SyncToAsync.__call__.loop.run_in_executor has a param named `task_context`. await test_function(task_context=1) + + +def test_nested_task() -> None: + async def inner() -> asyncio.Task[None]: + return asyncio.create_task(sync_to_async(print)("inner")) + + async def main() -> None: + task = await sync_to_async(async_to_sync(inner))() + await task + + async_to_sync(main)() + + +def test_nested_task_later() -> None: + def later(fut: asyncio.Future[asyncio.Task[None]]) -> None: + task = asyncio.create_task(sync_to_async(print)("later")) + fut.set_result(task) + + async def inner() -> asyncio.Future[asyncio.Task[None]]: + loop = asyncio.get_running_loop() + fut = loop.create_future() + loop.call_later(0.1, later, fut) + return fut + + async def main() -> None: + fut = await sync_to_async(async_to_sync(inner))() + task = await fut + await task + + async_to_sync(main)()