1
+ import asyncio
1
2
import time
2
3
from functools import wraps
4
+ from inspect import iscoroutinefunction
3
5
from random import random
4
- from typing import TYPE_CHECKING
6
+ from typing import TYPE_CHECKING , overload
5
7
6
8
if TYPE_CHECKING :
7
- from collections .abc import Callable
8
- from typing import TypeVar
9
+ from collections .abc import Awaitable , Callable
10
+ from typing import ParamSpec , TypeVar
9
11
10
- ReturnType = TypeVar ("ReturnType" )
12
+ ArgumentsType = ParamSpec ("ArgumentsType" )
13
+ OuterReturnType = TypeVar ("OuterReturnType" )
14
+ InnerReturnType = TypeVar ("InnerReturnType" )
11
15
12
16
13
17
class MaxRetriesExceeded (Exception ):
@@ -22,14 +26,14 @@ def retry(
22
26
wait : float = 1 ,
23
27
backoff : float = 4 ,
24
28
jitter : float = 0.5 ,
25
- ) -> "Callable[[Callable[..., ReturnType ]], Callable[..., ReturnType ]]" :
29
+ ) -> "Callable[[Callable[ArgumentsType, OuterReturnType ]], Callable[ArgumentsType, OuterReturnType ]]" : # noqa: E501
26
30
"""
27
- Decorate a function to automatically retry when it raises specific errors,
31
+ Decorate a function or coroutine to retry when it raises specified errors,
28
32
apply exponential backoff and jitter to the wait time,
29
33
and raise `MaxRetriesExceeded` after it retries too many times.
30
34
31
- By default, the decorated method will be retried up to 4 times over the course of
32
- ~2 minutes (waiting 1, 4, 16, and 64 seconds; plus up to 50% jitter)
35
+ By default, the decorated function or coroutine will be retried up to 4 times over
36
+ the course of ~2 minutes (waiting 1, 4, 16, and 64 seconds; plus up to 50% jitter)
33
37
before raising `MaxRetriesExceeded` from the last error.
34
38
35
39
Arguments:
@@ -41,22 +45,61 @@ def retry(
41
45
to the wait time to prevent simultaneous retries.
42
46
"""
43
47
48
+ def wait_time (times_retried : int ) -> float :
49
+ """
50
+ Calculate the sleep time based on number of times retried.
51
+ """
52
+ return wait * backoff ** times_retried * (1 + jitter * random ())
53
+
54
+ @overload
55
+ def retry_decorator (
56
+ decorated : "Callable[ArgumentsType, Awaitable[InnerReturnType]]" ,
57
+ ) -> "Callable[ArgumentsType, Awaitable[InnerReturnType]]" : ...
58
+ @overload
59
+ def retry_decorator (
60
+ decorated : "Callable[ArgumentsType, InnerReturnType]" ,
61
+ ) -> "Callable[ArgumentsType, InnerReturnType]" : ...
44
62
def retry_decorator (
45
- function : "Callable[..., ReturnType]" ,
46
- ) -> "Callable[..., ReturnType]" :
47
- @wraps (function )
48
- def retrying_function (* args : object , ** kwargs : object ) -> "ReturnType" :
49
- for times_retried in range (count + 1 ):
50
- try :
51
- return function (* args , ** kwargs )
52
- except errors as error :
53
- last_error = error
54
-
55
- if times_retried >= count :
56
- raise MaxRetriesExceeded () from last_error
57
-
58
- time .sleep (wait * backoff ** times_retried * (1 + jitter * random ()))
59
-
60
- return retrying_function
63
+ decorated : "Callable[ArgumentsType, InnerReturnType]" ,
64
+ ) -> "Callable[ArgumentsType, Awaitable[InnerReturnType]] | Callable[ArgumentsType, InnerReturnType]" : # noqa: E501
65
+ """
66
+ Decorate either a function or coroutine as appropriate.
67
+ """
68
+ if iscoroutinefunction (decorated ):
69
+
70
+ @wraps (decorated )
71
+ async def retrying_coroutine ( # type: ignore[return]
72
+ * args : "ArgumentsType.args" , ** kwargs : "ArgumentsType.kwargs"
73
+ ) -> "InnerReturnType" :
74
+ for times_retried in range (count + 1 ):
75
+ try :
76
+ return await decorated (* args , ** kwargs ) # type: ignore[no-any-return]
77
+ except errors as error :
78
+ last_error = error
79
+
80
+ if times_retried >= count :
81
+ raise MaxRetriesExceeded () from last_error
82
+
83
+ await asyncio .sleep (wait_time (times_retried ))
84
+
85
+ return retrying_coroutine
86
+ else :
87
+
88
+ @wraps (decorated )
89
+ def retrying_function ( # type: ignore[return]
90
+ * args : "ArgumentsType.args" , ** kwargs : "ArgumentsType.kwargs"
91
+ ) -> "InnerReturnType" :
92
+ for times_retried in range (count + 1 ):
93
+ try :
94
+ return decorated (* args , ** kwargs )
95
+ except errors as error :
96
+ last_error = error
97
+
98
+ if times_retried >= count :
99
+ raise MaxRetriesExceeded () from last_error
100
+
101
+ time .sleep (wait_time (times_retried ))
102
+
103
+ return retrying_function
61
104
62
105
return retry_decorator
0 commit comments