diff --git a/backoff/_decorator.py b/backoff/_decorator.py index 77ed8c2..c281798 100644 --- a/backoff/_decorator.py +++ b/backoff/_decorator.py @@ -18,7 +18,7 @@ _Jitterer, _MaybeCallable, _MaybeLogger, - _MaybeSequence, + _MaybeTuple, _Predicate, _WaitGenerator, ) @@ -120,8 +120,23 @@ def decorate(target): return decorate +def _check_exception_type(exception: _MaybeTuple[Type[Exception]]) -> None: + """ + Raise TypeError if exception is not a valid type, otherwise return None + """ + if isinstance(exception, Type) and issubclass(exception, Exception): + return + if isinstance(exception, tuple) and all(isinstance(e, Type) + and issubclass(e, Exception) for e in exception): + return + raise TypeError( + f"exception '{exception}' of type {type(exception)} is not" + " an Exception type or a tuple of Exception types" + ) + + def on_exception(wait_gen: _WaitGenerator, - exception: _MaybeSequence[Type[Exception]], + exception: _MaybeTuple[Type[Exception]], *, max_tries: Optional[_MaybeCallable[int]] = None, max_time: Optional[_MaybeCallable[float]] = None, @@ -180,6 +195,7 @@ def on_exception(wait_gen: _WaitGenerator, args will first be evaluated and their return values passed. This is useful for runtime configuration. """ + _check_exception_type(exception) def decorate(target): nonlocal logger, on_success, on_backoff, on_giveup diff --git a/backoff/_typing.py b/backoff/_typing.py index 20446d4..cce37d1 100644 --- a/backoff/_typing.py +++ b/backoff/_typing.py @@ -1,7 +1,7 @@ # coding:utf-8 import logging import sys -from typing import (Any, Callable, Coroutine, Dict, Generator, Sequence, Tuple, +from typing import (Any, Callable, Coroutine, Dict, Generator, Tuple, TypeVar, Union) if sys.version_info >= (3, 8): # pragma: no cover @@ -39,6 +39,6 @@ class Details(_Details, total=False): _Jitterer = Callable[[float], float] _MaybeCallable = Union[T, Callable[[], T]] _MaybeLogger = Union[str, logging.Logger, None] -_MaybeSequence = Union[T, Sequence[T]] +_MaybeTuple = Union[T, Tuple[T, ...]] _Predicate = Callable[[T], bool] _WaitGenerator = Callable[..., Generator[float, None, None]] diff --git a/tests/test_typing.py b/tests/test_typing.py index 7f53459..2a58722 100644 --- a/tests/test_typing.py +++ b/tests/test_typing.py @@ -32,3 +32,17 @@ def bar(): ) def baz(): pass + + +# Type Successes +for exception in OSError, tuple([OSError]), (OSError, ValueError): + backoff.on_exception(backoff.expo, exception) + + +# Type Failures +for exception in OSError(), [OSError], (OSError, ValueError()), "hi", (2, 3): + try: + backoff.on_exception(backoff.expo, exception) + raise AssertionError(f"Expected TypeError for {exception}") + except TypeError: + pass