Skip to content

Commit

Permalink
Support decorating async coroutines with retry
Browse files Browse the repository at this point in the history
  • Loading branch information
mawelborn committed Jan 13, 2025
1 parent 0aa8c9b commit 0cf4fc7
Show file tree
Hide file tree
Showing 4 changed files with 130 additions and 44 deletions.
91 changes: 67 additions & 24 deletions indico_toolkit/retry.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import asyncio
import time
from functools import wraps
from inspect import iscoroutinefunction
from random import random
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, overload

if TYPE_CHECKING:
from collections.abc import Callable
from typing import TypeVar
from collections.abc import Awaitable, Callable
from typing import ParamSpec, TypeVar

ReturnType = TypeVar("ReturnType")
ArgumentsType = ParamSpec("ArgumentsType")
OuterReturnType = TypeVar("OuterReturnType")
InnerReturnType = TypeVar("InnerReturnType")


class MaxRetriesExceeded(Exception):
Expand All @@ -22,14 +26,14 @@ def retry(
wait: float = 1,
backoff: float = 4,
jitter: float = 0.5,
) -> "Callable[[Callable[..., ReturnType]], Callable[..., ReturnType]]":
) -> "Callable[[Callable[ArgumentsType, OuterReturnType]], Callable[ArgumentsType, OuterReturnType]]": # noqa: E501
"""
Decorate a function to automatically retry when it raises specific errors,
Decorate a function or coroutine to retry when it raises specified errors,
apply exponential backoff and jitter to the wait time,
and raise `MaxRetriesExceeded` after it retries too many times.
By default, the decorated method will be retried up to 4 times over the course of
~2 minutes (waiting 1, 4, 16, and 64 seconds; plus up to 50% jitter)
By default, the decorated function or coroutine will be retried up to 4 times over
the course of ~2 minutes (waiting 1, 4, 16, and 64 seconds; plus up to 50% jitter)
before raising `MaxRetriesExceeded` from the last error.
Arguments:
Expand All @@ -41,22 +45,61 @@ def retry(
to the wait time to prevent simultaneous retries.
"""

def wait_time(times_retried: int) -> float:
"""
Calculate the sleep time based on number of times retried.
"""
return wait * backoff**times_retried * (1 + jitter * random())

@overload
def retry_decorator(
decorated: "Callable[ArgumentsType, Awaitable[InnerReturnType]]",
) -> "Callable[ArgumentsType, Awaitable[InnerReturnType]]": ...
@overload
def retry_decorator(
decorated: "Callable[ArgumentsType, InnerReturnType]",
) -> "Callable[ArgumentsType, InnerReturnType]": ...
def retry_decorator(
function: "Callable[..., ReturnType]",
) -> "Callable[..., ReturnType]":
@wraps(function)
def retrying_function(*args: object, **kwargs: object) -> "ReturnType":
for times_retried in range(count + 1):
try:
return function(*args, **kwargs)
except errors as error:
last_error = error

if times_retried >= count:
raise MaxRetriesExceeded() from last_error

time.sleep(wait * backoff**times_retried * (1 + jitter * random()))

return retrying_function
decorated: "Callable[ArgumentsType, InnerReturnType]",
) -> "Callable[ArgumentsType, Awaitable[InnerReturnType]] | Callable[ArgumentsType, InnerReturnType]": # noqa: E501
"""
Decorate either a function or coroutine as appropriate.
"""
if iscoroutinefunction(decorated):

@wraps(decorated)
async def retrying_coroutine( # type: ignore[return]
*args: "ArgumentsType.args", **kwargs: "ArgumentsType.kwargs"
) -> "InnerReturnType":
for times_retried in range(count + 1):
try:
return await decorated(*args, **kwargs) # type: ignore[no-any-return]
except errors as error:
last_error = error

if times_retried >= count:
raise MaxRetriesExceeded() from last_error

await asyncio.sleep(wait_time(times_retried))

return retrying_coroutine
else:

@wraps(decorated)
def retrying_function( # type: ignore[return]
*args: "ArgumentsType.args", **kwargs: "ArgumentsType.kwargs"
) -> "InnerReturnType":
for times_retried in range(count + 1):
try:
return decorated(*args, **kwargs)
except errors as error:
last_error = error

if times_retried >= count:
raise MaxRetriesExceeded() from last_error

time.sleep(wait_time(times_retried))

return retrying_function

return retry_decorator
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ requires = [

[tool.flit.metadata.requires-extra]
test = [
"pytest>=5.2.1",
"requests-mock>=1.7.0-7",
"pytest-dependency==0.5.1"
"pytest==8.3.4",
"pytest-asyncio==0.25.2",
"pytest-dependency==0.6.0",
"requests-mock>=1.7.0-7"
]
full = [
"PyMuPDF==1.19.6", "spacy>=3.1.4,<4"
Expand Down
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ indico-client>=5.1.4
python-dateutil==2.8.1
PyMuPDF==1.19.6
pytz==2021.1
pytest==6.2.2
pytest-dependency==0.5.1
pytest==8.3.4
pytest-asyncio==0.25.2
pytest-dependency==0.6.0
black==22.3
plotly==5.2.1
tqdm==4.50.0
Expand Down
71 changes: 56 additions & 15 deletions tests/test_retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,71 @@
from indico_toolkit.retry import retry, MaxRetriesExceeded


@retry(Exception)
def no_exceptions():
return True
def test_no_errors() -> None:
@retry(Exception)
def no_errors() -> bool:
return True

assert no_errors()

def test_retry_decorator_returns() -> None:
assert no_exceptions() is True

def test_raises_errors() -> None:
calls = 0

@retry(RuntimeError, count=4, wait=0)
def raises_errors() -> None:
nonlocal calls
calls += 1
raise RuntimeError()

with pytest.raises(MaxRetriesExceeded):
raises_errors()

assert calls == 5


def test_raises_other_errors() -> None:
calls = 0

calls = 0
@retry(RuntimeError, count=4, wait=0)
def raises_errors() -> None:
nonlocal calls
calls += 1
raise ValueError()

with pytest.raises(ValueError):
raises_errors()

@retry(RuntimeError, count=5, wait=0)
def raises_exceptions():
global calls
calls += 1
raise RuntimeError()
assert calls == 1


def test_retry_max_exceeded() -> None:
global calls
@pytest.mark.asyncio
async def test_raises_errors_async() -> None:
calls = 0

@retry(RuntimeError, count=4, wait=0)
async def raises_errors() -> None:
nonlocal calls
calls += 1
raise RuntimeError()

with pytest.raises(MaxRetriesExceeded):
raises_exceptions()
await raises_errors()

assert calls == 5


@pytest.mark.asyncio
async def test_raises_other_errors_async() -> None:
calls = 0

@retry(RuntimeError, count=4, wait=0)
async def raises_errors() -> None:
nonlocal calls
calls += 1
raise ValueError()

with pytest.raises(ValueError):
await raises_errors()

assert calls == 6
assert calls == 1

0 comments on commit 0cf4fc7

Please sign in to comment.