From 0cf4fc7be850ed3c5e79729e38e6a45b8b994402 Mon Sep 17 00:00:00 2001 From: Michael Welborn Date: Mon, 13 Jan 2025 15:00:14 -0600 Subject: [PATCH] Support decorating async coroutines with retry --- indico_toolkit/retry.py | 91 ++++++++++++++++++++++++++++++----------- pyproject.toml | 7 ++-- requirements.txt | 5 ++- tests/test_retry.py | 71 +++++++++++++++++++++++++------- 4 files changed, 130 insertions(+), 44 deletions(-) diff --git a/indico_toolkit/retry.py b/indico_toolkit/retry.py index 4d90811a..24362344 100644 --- a/indico_toolkit/retry.py +++ b/indico_toolkit/retry.py @@ -1,13 +1,17 @@ +import asyncio import time from functools import wraps +from inspect import iscoroutinefunction from random import random -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, overload if TYPE_CHECKING: - from collections.abc import Callable - from typing import TypeVar + from collections.abc import Awaitable, Callable + from typing import ParamSpec, TypeVar - ReturnType = TypeVar("ReturnType") + ArgumentsType = ParamSpec("ArgumentsType") + OuterReturnType = TypeVar("OuterReturnType") + InnerReturnType = TypeVar("InnerReturnType") class MaxRetriesExceeded(Exception): @@ -22,14 +26,14 @@ def retry( wait: float = 1, backoff: float = 4, jitter: float = 0.5, -) -> "Callable[[Callable[..., ReturnType]], Callable[..., ReturnType]]": +) -> "Callable[[Callable[ArgumentsType, OuterReturnType]], Callable[ArgumentsType, OuterReturnType]]": # noqa: E501 """ - Decorate a function to automatically retry when it raises specific errors, + Decorate a function or coroutine to retry when it raises specified errors, apply exponential backoff and jitter to the wait time, and raise `MaxRetriesExceeded` after it retries too many times. - By default, the decorated method will be retried up to 4 times over the course of - ~2 minutes (waiting 1, 4, 16, and 64 seconds; plus up to 50% jitter) + By default, the decorated function or coroutine will be retried up to 4 times over + the course of ~2 minutes (waiting 1, 4, 16, and 64 seconds; plus up to 50% jitter) before raising `MaxRetriesExceeded` from the last error. Arguments: @@ -41,22 +45,61 @@ def retry( to the wait time to prevent simultaneous retries. """ + def wait_time(times_retried: int) -> float: + """ + Calculate the sleep time based on number of times retried. + """ + return wait * backoff**times_retried * (1 + jitter * random()) + + @overload + def retry_decorator( + decorated: "Callable[ArgumentsType, Awaitable[InnerReturnType]]", + ) -> "Callable[ArgumentsType, Awaitable[InnerReturnType]]": ... + @overload + def retry_decorator( + decorated: "Callable[ArgumentsType, InnerReturnType]", + ) -> "Callable[ArgumentsType, InnerReturnType]": ... def retry_decorator( - function: "Callable[..., ReturnType]", - ) -> "Callable[..., ReturnType]": - @wraps(function) - def retrying_function(*args: object, **kwargs: object) -> "ReturnType": - for times_retried in range(count + 1): - try: - return function(*args, **kwargs) - except errors as error: - last_error = error - - if times_retried >= count: - raise MaxRetriesExceeded() from last_error - - time.sleep(wait * backoff**times_retried * (1 + jitter * random())) - - return retrying_function + decorated: "Callable[ArgumentsType, InnerReturnType]", + ) -> "Callable[ArgumentsType, Awaitable[InnerReturnType]] | Callable[ArgumentsType, InnerReturnType]": # noqa: E501 + """ + Decorate either a function or coroutine as appropriate. + """ + if iscoroutinefunction(decorated): + + @wraps(decorated) + async def retrying_coroutine( # type: ignore[return] + *args: "ArgumentsType.args", **kwargs: "ArgumentsType.kwargs" + ) -> "InnerReturnType": + for times_retried in range(count + 1): + try: + return await decorated(*args, **kwargs) # type: ignore[no-any-return] + except errors as error: + last_error = error + + if times_retried >= count: + raise MaxRetriesExceeded() from last_error + + await asyncio.sleep(wait_time(times_retried)) + + return retrying_coroutine + else: + + @wraps(decorated) + def retrying_function( # type: ignore[return] + *args: "ArgumentsType.args", **kwargs: "ArgumentsType.kwargs" + ) -> "InnerReturnType": + for times_retried in range(count + 1): + try: + return decorated(*args, **kwargs) + except errors as error: + last_error = error + + if times_retried >= count: + raise MaxRetriesExceeded() from last_error + + time.sleep(wait_time(times_retried)) + + return retrying_function return retry_decorator diff --git a/pyproject.toml b/pyproject.toml index e535cc50..536abcfd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,9 +22,10 @@ requires = [ [tool.flit.metadata.requires-extra] test = [ - "pytest>=5.2.1", - "requests-mock>=1.7.0-7", - "pytest-dependency==0.5.1" + "pytest==8.3.4", + "pytest-asyncio==0.25.2", + "pytest-dependency==0.6.0", + "requests-mock>=1.7.0-7" ] full = [ "PyMuPDF==1.19.6", "spacy>=3.1.4,<4" diff --git a/requirements.txt b/requirements.txt index 8043b984..a19b51b6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,8 +2,9 @@ indico-client>=5.1.4 python-dateutil==2.8.1 PyMuPDF==1.19.6 pytz==2021.1 -pytest==6.2.2 -pytest-dependency==0.5.1 +pytest==8.3.4 +pytest-asyncio==0.25.2 +pytest-dependency==0.6.0 black==22.3 plotly==5.2.1 tqdm==4.50.0 diff --git a/tests/test_retry.py b/tests/test_retry.py index ba9f3105..f78c10c7 100644 --- a/tests/test_retry.py +++ b/tests/test_retry.py @@ -3,30 +3,71 @@ from indico_toolkit.retry import retry, MaxRetriesExceeded -@retry(Exception) -def no_exceptions(): - return True +def test_no_errors() -> None: + @retry(Exception) + def no_errors() -> bool: + return True + assert no_errors() -def test_retry_decorator_returns() -> None: - assert no_exceptions() is True +def test_raises_errors() -> None: + calls = 0 + + @retry(RuntimeError, count=4, wait=0) + def raises_errors() -> None: + nonlocal calls + calls += 1 + raise RuntimeError() + + with pytest.raises(MaxRetriesExceeded): + raises_errors() + + assert calls == 5 + + +def test_raises_other_errors() -> None: + calls = 0 -calls = 0 + @retry(RuntimeError, count=4, wait=0) + def raises_errors() -> None: + nonlocal calls + calls += 1 + raise ValueError() + with pytest.raises(ValueError): + raises_errors() -@retry(RuntimeError, count=5, wait=0) -def raises_exceptions(): - global calls - calls += 1 - raise RuntimeError() + assert calls == 1 -def test_retry_max_exceeded() -> None: - global calls +@pytest.mark.asyncio +async def test_raises_errors_async() -> None: calls = 0 + @retry(RuntimeError, count=4, wait=0) + async def raises_errors() -> None: + nonlocal calls + calls += 1 + raise RuntimeError() + with pytest.raises(MaxRetriesExceeded): - raises_exceptions() + await raises_errors() + + assert calls == 5 + + +@pytest.mark.asyncio +async def test_raises_other_errors_async() -> None: + calls = 0 + + @retry(RuntimeError, count=4, wait=0) + async def raises_errors() -> None: + nonlocal calls + calls += 1 + raise ValueError() + + with pytest.raises(ValueError): + await raises_errors() - assert calls == 6 + assert calls == 1