Skip to content

Commit

Permalink
cinder compatability for differences in all_tasks
Browse files Browse the repository at this point in the history
Summary:
cinder runtime's tasks.all_tasks has some mechanism where done() tasks don't show up in all_tasks but we need them to check managed tasks so instead we track our own tasks in our task class for tests.

This seems to make all the runtimes happy.

Reviewed By: aleivag

Differential Revision: D51681518

fbshipit-source-id: 455484fc9a52f7f07135ff2c52d569979ecd15f5
  • Loading branch information
fried authored and facebook-github-bot committed Nov 29, 2023
1 parent ec92482 commit c4e6cd9
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions later/unittest/case.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import asyncio.tasks
import sys
import unittest.mock as mock
import weakref
from functools import wraps
from typing import Any, Callable, TypeVar

Expand All @@ -39,6 +40,7 @@
_IGNORE_TASK_LEAKS_ATTR = "__later_testcase_ignore_tasks__"
_IGNORE_AIO_ERRS_ATTR = "__later_testcase_ignore_asyncio__"
atleastpy38: bool = sys.version_info[:2] >= (3, 8)
_unmanaged_tasks: weakref.WeakSet[asyncio.Task] = weakref.WeakSet()


class TestTask(asyncio.Task):
Expand All @@ -48,6 +50,7 @@ class TestTask(asyncio.Task):
def __init__(self, coro, *args, **kws) -> None:
# pyre-fixme[16]: Module `coroutines` has no attribute `_format_coroutine`.
self._coro_repr = asyncio.coroutines._format_coroutine(coro)
_unmanaged_tasks.add(self)
super().__init__(coro, *args, **kws)

def __repr__(self) -> str:
Expand All @@ -60,24 +63,29 @@ def __repr__(self) -> str:
repr_info[1] = coro # pragma: nocover
return f"<{self.__class__.__name__} {' '.join(repr_info)}>"

def _mark_managed(self):
if not self._managed:
self._managed = True
_unmanaged_tasks.remove(self)

def __await__(self):
self._managed = True
self._mark_managed()
return super().__await__()

def result(self):
if self.done():
self._managed = True
self._mark_managed()
return super().result()

def exception(self):
if self.done():
self._managed = True
self._mark_managed()
return super().exception()

def add_done_callback(self, fn, *, context=None) -> None:
@wraps(fn)
def mark_managed(fut):
self._managed = True
self._mark_managed()
return fn(fut)

super().add_done_callback(mark_managed, context=context)
Expand Down Expand Up @@ -135,7 +143,7 @@ def all_tasks(loop):
i = 0
while True:
try:
tasks = list(_all_tasks)
tasks = list(_all_tasks) + list(_unmanaged_tasks)
except RuntimeError: # pragma: nocover
i += 1
if i >= 1000:
Expand Down

0 comments on commit c4e6cd9

Please sign in to comment.