From c150a41125a2d4d10e9eb43272fe8e467da96273 Mon Sep 17 00:00:00 2001 From: generatedunixname89002005287564 Date: Sun, 20 Oct 2024 14:32:25 -0700 Subject: [PATCH] later Reviewed By: itamaro Differential Revision: D64526256 fbshipit-source-id: 39354717bb2e2ad6ff403a487d4e84506cbc5341 --- later/task.py | 59 ++++++++++++++----------------- later/tests/unittest/test_case.py | 2 +- later/unittest/case.py | 11 ++---- 3 files changed, 30 insertions(+), 42 deletions(-) diff --git a/later/task.py b/later/task.py index 0745d53..368dc61 100644 --- a/later/task.py +++ b/later/task.py @@ -20,6 +20,7 @@ import functools import logging import threading +from collections.abc import Awaitable, Callable, Coroutine, Hashable, Mapping, Sequence from functools import partial, wraps from inspect import isawaitable @@ -27,20 +28,14 @@ from typing import ( AbstractSet, Any, - Awaitable, - Callable, cast, - Coroutine, Dict, - Hashable, List, - Mapping, NewType, Optional, overload, ParamSpec, Protocol, - Sequence, Tuple, Type, TypeVar, @@ -91,7 +86,7 @@ async def cancel(fut: asyncio.Future) -> None: if fut.done(): return # nothing to do fut.cancel() - exc: Optional[asyncio.CancelledError] = None + exc: asyncio.CancelledError | None = None while not fut.done(): shielded = asyncio.shield(fut) try: @@ -156,13 +151,13 @@ class WatcherError(RuntimeError): class Watcher: - _tasks: Dict[asyncio.Future, Optional[FixerType]] - _scheduled: List[FixerType] + _tasks: dict[asyncio.Future, FixerType | None] + _scheduled: list[FixerType] _tasks_changed: BiDirectionalEvent _cancelled: asyncio.Event _cancel_timeout: float - _preexit_callbacks: List[Callable[[], None]] - _shielded_tasks: Dict[asyncio.Task, asyncio.Future] + _preexit_callbacks: list[Callable[[], None]] + _shielded_tasks: dict[asyncio.Task, asyncio.Future] # pyre-ignore[13]: loop is initialized in __aenter__ loop: asyncio.AbstractEventLoop running: bool @@ -188,12 +183,12 @@ def __init__( if context: WATCHER_CONTEXT.set(self) self._cancel_timeout = cancel_timeout - self._tasks: Dict[asyncio.Future, Optional[FixerType]] = {} - self._scheduled: List[FixerType] = [] + self._tasks: dict[asyncio.Future, FixerType | None] = {} + self._scheduled: list[FixerType] = [] self._tasks_changed = BiDirectionalEvent() self._cancelled = asyncio.Event() self._preexit_callbacks = [] - self._shielded_tasks: Dict[asyncio.Task, asyncio.Future] = {} + self._shielded_tasks: dict[asyncio.Task, asyncio.Future] = {} self.running = False self.done_ok = done_ok @@ -213,7 +208,7 @@ async def _run_scheduled(self) -> None: async def unwatch( self, task: asyncio.Task = START_TASK, - fixer: Optional[FixerType] = None, + fixer: FixerType | None = None, *, shield: bool = False, ) -> bool: @@ -258,7 +253,7 @@ async def tasks_changed() -> None: def watch( self, task: asyncio.Task = START_TASK, - fixer: Optional[FixerType] = None, + fixer: FixerType | None = None, *, shield: bool = False, ) -> None: @@ -325,16 +320,16 @@ def _run_preexit_callbacks(self) -> None: f"ignoring exception from pre-exit callback {callback}: {e}" ) - async def __aenter__(self) -> "Watcher": + async def __aenter__(self) -> Watcher: WATCHER_CONTEXT.set(self) self.loop = asyncio.get_running_loop() return self async def __aexit__( self, - exc_type: Optional[Type[BaseException]], - exc: Optional[BaseException], - tb: Optional[TracebackType], + exc_type: type[BaseException] | None, + exc: BaseException | None, + tb: TracebackType | None, ) -> bool: cancel_task: asyncio.Task = self.loop.create_task(self._cancelled.wait()) changed_task: asyncio.Task = START_TASK @@ -411,7 +406,7 @@ async def _handle_cancel(self) -> None: task.cancel() done, pending = await asyncio.wait(tasks, timeout=self._cancel_timeout) - bad_tasks: List[asyncio.Future] = [] + bad_tasks: list[asyncio.Future] = [] for task in done: if task.cancelled(): continue @@ -433,26 +428,26 @@ async def _handle_cancel(self) -> None: class _CountTask: """So herd can track herd size and task together for cancellation""" - task: Optional[asyncio.Task] = None + task: asyncio.Task | None = None count: int = 0 -def _get_local(local: threading.local, field: str) -> Dict[CacheKey, object]: +def _get_local(local: threading.local, field: str) -> dict[CacheKey, object]: """ helper for attempting to fetch a named attr from a threading.local """ try: - return cast(Dict[CacheKey, object], getattr(local, field)) + return cast(dict[CacheKey, object], getattr(local, field)) except AttributeError: - container: Dict[CacheKey, object] = {} + container: dict[CacheKey, object] = {} setattr(local, field, container) return container def _build_key( - args: Tuple[object, ...], + args: tuple[object, ...], kwargs: Mapping[str, object], - ignored_args: Optional[AbstractSet[ArgID]] = None, + ignored_args: AbstractSet[ArgID] | None = None, ) -> CacheKey: """ Build a key for caching Hashable args and kwargs. @@ -466,7 +461,7 @@ def _build_key( ( tuple((value for idx, value in enumerate(args) if idx not in ignored_args)), tuple( - (item for item in sorted(kwargs.items()) if item[0] not in ignored_args) + item for item in sorted(kwargs.items()) if item[0] not in ignored_args ), ) ) @@ -483,7 +478,7 @@ def __call__( def herd( fn: Callable[TParams, Coroutine[object, object, T]], *, - ignored_args: Optional[AbstractSet[ArgID]] = None, + ignored_args: AbstractSet[ArgID] | None = None, ) -> Callable[TParams, Coroutine[object, object, T]]: # pragma: nocover ... @@ -492,7 +487,7 @@ def herd( def herd( fn: None = None, *, - ignored_args: Optional[AbstractSet[ArgID]] = None, + ignored_args: AbstractSet[ArgID] | None = None, ) -> AsyncCallable: # pragma: nocover ... @@ -500,7 +495,7 @@ def herd( def herd( fn: Callable[TParams, Coroutine[object, object, T]] | None = None, *, - ignored_args: Optional[AbstractSet[ArgID]] = None, + ignored_args: AbstractSet[ArgID] | None = None, ) -> ( Callable[TParams, Coroutine[object, object, T]] | Callable[ @@ -528,7 +523,7 @@ def decorator( @functools.wraps(fn) async def wrapped(*args: TParams.args, **kwargs: TParams.kwargs) -> T: - pending = cast(Dict[CacheKey, _CountTask], _get_local(local, "pending")) + pending = cast(dict[CacheKey, _CountTask], _get_local(local, "pending")) request = _build_key(tuple(args), kwargs, ignored_args) count_task = pending.setdefault(request, _CountTask()) count_task.count += 1 diff --git a/later/tests/unittest/test_case.py b/later/tests/unittest/test_case.py index 577c758..5af4b10 100644 --- a/later/tests/unittest/test_case.py +++ b/later/tests/unittest/test_case.py @@ -23,7 +23,7 @@ # This is a place to purposefully produce leaked tasks -saved_tasks: List[asyncio.Task] = [] +saved_tasks: list[asyncio.Task] = [] class TestTestCase(TestCase): diff --git a/later/unittest/case.py b/later/unittest/case.py index b306243..8abfadf 100644 --- a/later/unittest/case.py +++ b/later/unittest/case.py @@ -30,17 +30,10 @@ import sys import unittest.mock as mock import weakref +from collections.abc import Callable, Coroutine, Generator from contextvars import Context from functools import wraps -from typing import ( - AbstractSet, - Callable, - Coroutine, - Generator, - Generic, - TYPE_CHECKING, - TypeVar, -) +from typing import AbstractSet, Generic, TYPE_CHECKING, TypeVar from unittest import IsolatedAsyncioTestCase as AsyncioTestCase