Skip to content

Commit 0cf4fc7

Browse files
committed
Support decorating async coroutines with retry
1 parent 0aa8c9b commit 0cf4fc7

File tree

4 files changed

+130
-44
lines changed

4 files changed

+130
-44
lines changed

indico_toolkit/retry.py

Lines changed: 67 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
1+
import asyncio
12
import time
23
from functools import wraps
4+
from inspect import iscoroutinefunction
35
from random import random
4-
from typing import TYPE_CHECKING
6+
from typing import TYPE_CHECKING, overload
57

68
if TYPE_CHECKING:
7-
from collections.abc import Callable
8-
from typing import TypeVar
9+
from collections.abc import Awaitable, Callable
10+
from typing import ParamSpec, TypeVar
911

10-
ReturnType = TypeVar("ReturnType")
12+
ArgumentsType = ParamSpec("ArgumentsType")
13+
OuterReturnType = TypeVar("OuterReturnType")
14+
InnerReturnType = TypeVar("InnerReturnType")
1115

1216

1317
class MaxRetriesExceeded(Exception):
@@ -22,14 +26,14 @@ def retry(
2226
wait: float = 1,
2327
backoff: float = 4,
2428
jitter: float = 0.5,
25-
) -> "Callable[[Callable[..., ReturnType]], Callable[..., ReturnType]]":
29+
) -> "Callable[[Callable[ArgumentsType, OuterReturnType]], Callable[ArgumentsType, OuterReturnType]]": # noqa: E501
2630
"""
27-
Decorate a function to automatically retry when it raises specific errors,
31+
Decorate a function or coroutine to retry when it raises specified errors,
2832
apply exponential backoff and jitter to the wait time,
2933
and raise `MaxRetriesExceeded` after it retries too many times.
3034
31-
By default, the decorated method will be retried up to 4 times over the course of
32-
~2 minutes (waiting 1, 4, 16, and 64 seconds; plus up to 50% jitter)
35+
By default, the decorated function or coroutine will be retried up to 4 times over
36+
the course of ~2 minutes (waiting 1, 4, 16, and 64 seconds; plus up to 50% jitter)
3337
before raising `MaxRetriesExceeded` from the last error.
3438
3539
Arguments:
@@ -41,22 +45,61 @@ def retry(
4145
to the wait time to prevent simultaneous retries.
4246
"""
4347

48+
def wait_time(times_retried: int) -> float:
49+
"""
50+
Calculate the sleep time based on number of times retried.
51+
"""
52+
return wait * backoff**times_retried * (1 + jitter * random())
53+
54+
@overload
55+
def retry_decorator(
56+
decorated: "Callable[ArgumentsType, Awaitable[InnerReturnType]]",
57+
) -> "Callable[ArgumentsType, Awaitable[InnerReturnType]]": ...
58+
@overload
59+
def retry_decorator(
60+
decorated: "Callable[ArgumentsType, InnerReturnType]",
61+
) -> "Callable[ArgumentsType, InnerReturnType]": ...
4462
def retry_decorator(
45-
function: "Callable[..., ReturnType]",
46-
) -> "Callable[..., ReturnType]":
47-
@wraps(function)
48-
def retrying_function(*args: object, **kwargs: object) -> "ReturnType":
49-
for times_retried in range(count + 1):
50-
try:
51-
return function(*args, **kwargs)
52-
except errors as error:
53-
last_error = error
54-
55-
if times_retried >= count:
56-
raise MaxRetriesExceeded() from last_error
57-
58-
time.sleep(wait * backoff**times_retried * (1 + jitter * random()))
59-
60-
return retrying_function
63+
decorated: "Callable[ArgumentsType, InnerReturnType]",
64+
) -> "Callable[ArgumentsType, Awaitable[InnerReturnType]] | Callable[ArgumentsType, InnerReturnType]": # noqa: E501
65+
"""
66+
Decorate either a function or coroutine as appropriate.
67+
"""
68+
if iscoroutinefunction(decorated):
69+
70+
@wraps(decorated)
71+
async def retrying_coroutine( # type: ignore[return]
72+
*args: "ArgumentsType.args", **kwargs: "ArgumentsType.kwargs"
73+
) -> "InnerReturnType":
74+
for times_retried in range(count + 1):
75+
try:
76+
return await decorated(*args, **kwargs) # type: ignore[no-any-return]
77+
except errors as error:
78+
last_error = error
79+
80+
if times_retried >= count:
81+
raise MaxRetriesExceeded() from last_error
82+
83+
await asyncio.sleep(wait_time(times_retried))
84+
85+
return retrying_coroutine
86+
else:
87+
88+
@wraps(decorated)
89+
def retrying_function( # type: ignore[return]
90+
*args: "ArgumentsType.args", **kwargs: "ArgumentsType.kwargs"
91+
) -> "InnerReturnType":
92+
for times_retried in range(count + 1):
93+
try:
94+
return decorated(*args, **kwargs)
95+
except errors as error:
96+
last_error = error
97+
98+
if times_retried >= count:
99+
raise MaxRetriesExceeded() from last_error
100+
101+
time.sleep(wait_time(times_retried))
102+
103+
return retrying_function
61104

62105
return retry_decorator

pyproject.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@ requires = [
2222

2323
[tool.flit.metadata.requires-extra]
2424
test = [
25-
"pytest>=5.2.1",
26-
"requests-mock>=1.7.0-7",
27-
"pytest-dependency==0.5.1"
25+
"pytest==8.3.4",
26+
"pytest-asyncio==0.25.2",
27+
"pytest-dependency==0.6.0",
28+
"requests-mock>=1.7.0-7"
2829
]
2930
full = [
3031
"PyMuPDF==1.19.6", "spacy>=3.1.4,<4"

requirements.txt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@ indico-client>=5.1.4
22
python-dateutil==2.8.1
33
PyMuPDF==1.19.6
44
pytz==2021.1
5-
pytest==6.2.2
6-
pytest-dependency==0.5.1
5+
pytest==8.3.4
6+
pytest-asyncio==0.25.2
7+
pytest-dependency==0.6.0
78
black==22.3
89
plotly==5.2.1
910
tqdm==4.50.0

tests/test_retry.py

Lines changed: 56 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -3,30 +3,71 @@
33
from indico_toolkit.retry import retry, MaxRetriesExceeded
44

55

6-
@retry(Exception)
7-
def no_exceptions():
8-
return True
6+
def test_no_errors() -> None:
7+
@retry(Exception)
8+
def no_errors() -> bool:
9+
return True
910

11+
assert no_errors()
1012

11-
def test_retry_decorator_returns() -> None:
12-
assert no_exceptions() is True
1313

14+
def test_raises_errors() -> None:
15+
calls = 0
16+
17+
@retry(RuntimeError, count=4, wait=0)
18+
def raises_errors() -> None:
19+
nonlocal calls
20+
calls += 1
21+
raise RuntimeError()
22+
23+
with pytest.raises(MaxRetriesExceeded):
24+
raises_errors()
25+
26+
assert calls == 5
27+
28+
29+
def test_raises_other_errors() -> None:
30+
calls = 0
1431

15-
calls = 0
32+
@retry(RuntimeError, count=4, wait=0)
33+
def raises_errors() -> None:
34+
nonlocal calls
35+
calls += 1
36+
raise ValueError()
1637

38+
with pytest.raises(ValueError):
39+
raises_errors()
1740

18-
@retry(RuntimeError, count=5, wait=0)
19-
def raises_exceptions():
20-
global calls
21-
calls += 1
22-
raise RuntimeError()
41+
assert calls == 1
2342

2443

25-
def test_retry_max_exceeded() -> None:
26-
global calls
44+
@pytest.mark.asyncio
45+
async def test_raises_errors_async() -> None:
2746
calls = 0
2847

48+
@retry(RuntimeError, count=4, wait=0)
49+
async def raises_errors() -> None:
50+
nonlocal calls
51+
calls += 1
52+
raise RuntimeError()
53+
2954
with pytest.raises(MaxRetriesExceeded):
30-
raises_exceptions()
55+
await raises_errors()
56+
57+
assert calls == 5
58+
59+
60+
@pytest.mark.asyncio
61+
async def test_raises_other_errors_async() -> None:
62+
calls = 0
63+
64+
@retry(RuntimeError, count=4, wait=0)
65+
async def raises_errors() -> None:
66+
nonlocal calls
67+
calls += 1
68+
raise ValueError()
69+
70+
with pytest.raises(ValueError):
71+
await raises_errors()
3172

32-
assert calls == 6
73+
assert calls == 1

0 commit comments

Comments
 (0)