diff --git a/later/unittest/case.py b/later/unittest/case.py index 3d2b771..e4c55e1 100644 --- a/later/unittest/case.py +++ b/later/unittest/case.py @@ -25,6 +25,7 @@ import asyncio.tasks import sys import unittest.mock as mock +import weakref from functools import wraps from typing import Any, Callable, TypeVar @@ -39,6 +40,7 @@ _IGNORE_TASK_LEAKS_ATTR = "__later_testcase_ignore_tasks__" _IGNORE_AIO_ERRS_ATTR = "__later_testcase_ignore_asyncio__" atleastpy38: bool = sys.version_info[:2] >= (3, 8) +_unmanaged_tasks: weakref.WeakSet[asyncio.Task] = weakref.WeakSet() class TestTask(asyncio.Task): @@ -48,6 +50,7 @@ class TestTask(asyncio.Task): def __init__(self, coro, *args, **kws) -> None: # pyre-fixme[16]: Module `coroutines` has no attribute `_format_coroutine`. self._coro_repr = asyncio.coroutines._format_coroutine(coro) + _unmanaged_tasks.add(self) super().__init__(coro, *args, **kws) def __repr__(self) -> str: @@ -60,24 +63,29 @@ def __repr__(self) -> str: repr_info[1] = coro # pragma: nocover return f"<{self.__class__.__name__} {' '.join(repr_info)}>" + def _mark_managed(self): + if not self._managed: + self._managed = True + _unmanaged_tasks.remove(self) + def __await__(self): - self._managed = True + self._mark_managed() return super().__await__() def result(self): if self.done(): - self._managed = True + self._mark_managed() return super().result() def exception(self): if self.done(): - self._managed = True + self._mark_managed() return super().exception() def add_done_callback(self, fn, *, context=None) -> None: @wraps(fn) def mark_managed(fut): - self._managed = True + self._mark_managed() return fn(fut) super().add_done_callback(mark_managed, context=context) @@ -135,7 +143,7 @@ def all_tasks(loop): i = 0 while True: try: - tasks = list(_all_tasks) + tasks = list(_all_tasks) + list(_unmanaged_tasks) except RuntimeError: # pragma: nocover i += 1 if i >= 1000: