diff --git a/mediahaven/mediahaven.py b/mediahaven/mediahaven.py index 0033a45..7e9888f 100644 --- a/mediahaven/mediahaven.py +++ b/mediahaven/mediahaven.py @@ -18,6 +18,7 @@ OAuth2Grant, RefreshTokenError, ) +from mediahaven.retry import RetryException, retry_exponential, TooManyRetriesException API_PATH = "/mediahaven-rest-api/v2/" @@ -45,12 +46,29 @@ class ContentType(Enum): class MediaHavenClient: - """The MediaHaven client class to communicate with MediaHaven.""" + """The MediaHaven client class to communicate with MediaHaven. - def __init__(self, mh_base_url: str, grant: OAuth2Grant): + Attributes: + - grant: The OAuth2.0 grant. + - mh_base_url: The MediaHaven base URL (https://{host}:{port}). + - mh_api_url: The mh_base_url concatenated with the API path including the version. + - retry_rate_limit: Indicates if when hitting a rate limit, the request should be retried. + """ + + def __init__( + self, mh_base_url: str, grant: OAuth2Grant, retry_rate_limit: bool = False + ): + """Initialize a MediaHaven client. + + Args: + - mh_base_url: The MediaHaven base URL (https://{host}:{port}). + - grant: The OAuth2.0 grant. + - retry_rate_limit: Indicates if, when hitting a rate limit, the request should be retried. + """ self.grant = grant self.mh_base_url = mh_base_url self.mh_api_url = urljoin(self.mh_base_url, API_PATH) + self.retry_rate_limit = retry_rate_limit def _raise_mediahaven_exception_if_needed(self, response): """Raise a MediaHaven exception if the response status >= 400. @@ -69,6 +87,7 @@ def _raise_mediahaven_exception_if_needed(self, response): error_message = {"response": response.text} raise MediaHavenException(error_message, status_code=response.status_code) + @retry_exponential((RetryException), 1, 2, 10) def _execute_request(self, **kwargs): """Execute an authorized request. @@ -86,6 +105,7 @@ def _execute_request(self, **kwargs): NoTokenError: If a token has not yet been requested. RefreshTokenError: If an error occurred when refreshing the token. requests.RequestException: Reraise if a RequestException happen. + TooManyRetriesException: If all the (re)tries have been exhausted. """ # Get a session with a valid auth try: @@ -110,7 +130,11 @@ def _execute_request(self, **kwargs): return response except RequestException: raise + except TooManyRetriesException: + raise else: + if self.retry_rate_limit and response.status_code == 429: + raise RetryException return response def _build_headers(self, accept_format: AcceptFormat = None) -> dict: diff --git a/mediahaven/retry.py b/mediahaven/retry.py new file mode 100644 index 0000000..287ded3 --- /dev/null +++ b/mediahaven/retry.py @@ -0,0 +1,52 @@ +import functools +import time + + +class RetryException(Exception): + """Exception raised when an action needs to be retried + in combination with retry_exponential decorator""" + + pass + + +class TooManyRetriesException(Exception): + """Exception raised when all the tries are exhausted""" + + pass + + +DELAY = 1 +BACKOFF = 2 +NUMBER_OF_TRIES = 10 + + +def retry_exponential( + exceptions: list[BaseException], + delay_seconds: int = DELAY, + backoff: int = BACKOFF, + number_of_tries: int = NUMBER_OF_TRIES, +): + """A decorator allowing for a function to be retried via an exponential backoff + + Raises: + - TooManyRetriesException: When all the (re)tries are exhausted. + """ + + def decorator_retry(func): + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + number_of_tries_over = number_of_tries + delay = delay_seconds + while number_of_tries_over: + number_of_tries_over -= 1 + try: + return func(self, *args, **kwargs) + except exceptions as error: + # Todo: log + time.sleep(delay) + delay *= backoff + raise TooManyRetriesException(f"Too many retries: {number_of_tries}") + + return wrapper + + return decorator_retry diff --git a/tests/test_retry.py b/tests/test_retry.py new file mode 100644 index 0000000..dfea0cf --- /dev/null +++ b/tests/test_retry.py @@ -0,0 +1,62 @@ +from unittest.mock import MagicMock, patch +from mediahaven.retry import ( + retry_exponential, + RetryException, + TooManyRetriesException, + NUMBER_OF_TRIES, + DELAY, + BACKOFF, +) +import pytest + + +@patch("time.sleep") +def test_retry_defaults(time_sleep_mock): + function_mock = MagicMock() + function_mock.side_effect = RetryException + + @retry_exponential((RetryException)) + def func(self): + function_mock() + + # Execute the decorated method + with pytest.raises(TooManyRetriesException): + func(MagicMock()) + + # Test if function was executed multiple times + assert function_mock.call_count == NUMBER_OF_TRIES + + # Test if time.sleep was executed multiple times + assert time_sleep_mock.call_count == NUMBER_OF_TRIES + + # Test exponential backoff + assert time_sleep_mock.call_args_list[0][0][0] == DELAY + for i in range(1, NUMBER_OF_TRIES): + prev_val = time_sleep_mock.call_args_list[i - 1][0][0] + assert time_sleep_mock.call_args_list[i][0][0] == prev_val * BACKOFF + + +@patch("time.sleep") +def test_retry(time_sleep_mock): + function_mock = MagicMock() + function_mock.side_effect = RetryException + + @retry_exponential((RetryException), 2, 4, 5) + def func(self): + function_mock() + + # Execute the decorated method + with pytest.raises(TooManyRetriesException): + func(MagicMock()) + + # Test if function was executed multiple times + assert function_mock.call_count == 5 + + # Test if time.sleep was executed multiple times + assert time_sleep_mock.call_count == 5 + + # Test exponential backoff + assert time_sleep_mock.call_args_list[0][0][0] == 2 + for i in range(1, 5): + prev_val = time_sleep_mock.call_args_list[i - 1][0][0] + assert time_sleep_mock.call_args_list[i][0][0] == prev_val * 4