Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MHPY-23 Retry a 429 response #17

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 26 additions & 2 deletions mediahaven/mediahaven.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
OAuth2Grant,
RefreshTokenError,
)
from mediahaven.retry import RetryException, retry_exponential, TooManyRetriesException

API_PATH = "/mediahaven-rest-api/v2/"

Expand Down Expand Up @@ -45,12 +46,29 @@ class ContentType(Enum):


class MediaHavenClient:
"""The MediaHaven client class to communicate with MediaHaven."""
"""The MediaHaven client class to communicate with MediaHaven.

def __init__(self, mh_base_url: str, grant: OAuth2Grant):
Attributes:
- grant: The OAuth2.0 grant.
- mh_base_url: The MediaHaven base URL (https://{host}:{port}).
- mh_api_url: The mh_base_url concatenated with the API path including the version.
- retry_rate_limit: Indicates if when hitting a rate limit, the request should be retried.
"""

def __init__(
self, mh_base_url: str, grant: OAuth2Grant, retry_rate_limit: bool = False
):
"""Initialize a MediaHaven client.

Args:
- mh_base_url: The MediaHaven base URL (https://{host}:{port}).
- grant: The OAuth2.0 grant.
- retry_rate_limit: Indicates if, when hitting a rate limit, the request should be retried.
"""
self.grant = grant
self.mh_base_url = mh_base_url
self.mh_api_url = urljoin(self.mh_base_url, API_PATH)
self.retry_rate_limit = retry_rate_limit

def _raise_mediahaven_exception_if_needed(self, response):
"""Raise a MediaHaven exception if the response status >= 400.
Expand All @@ -69,6 +87,7 @@ def _raise_mediahaven_exception_if_needed(self, response):
error_message = {"response": response.text}
raise MediaHavenException(error_message, status_code=response.status_code)

@retry_exponential((RetryException), 1, 2, 10)
def _execute_request(self, **kwargs):
"""Execute an authorized request.

Expand All @@ -86,6 +105,7 @@ def _execute_request(self, **kwargs):
NoTokenError: If a token has not yet been requested.
RefreshTokenError: If an error occurred when refreshing the token.
requests.RequestException: Reraise if a RequestException happen.
TooManyRetriesException: If all the (re)tries have been exhausted.
"""
# Get a session with a valid auth
try:
Expand All @@ -110,7 +130,11 @@ def _execute_request(self, **kwargs):
return response
except RequestException:
raise
except TooManyRetriesException:
raise
else:
if self.retry_rate_limit and response.status_code == 429:
raise RetryException
return response

def _build_headers(self, accept_format: AcceptFormat = None) -> dict:
Expand Down
52 changes: 52 additions & 0 deletions mediahaven/retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import functools
import time


class RetryException(Exception):
"""Exception raised when an action needs to be retried
in combination with retry_exponential decorator"""

pass


class TooManyRetriesException(Exception):
"""Exception raised when all the tries are exhausted"""

pass


DELAY = 1
BACKOFF = 2
NUMBER_OF_TRIES = 10


def retry_exponential(
exceptions: list[BaseException],
delay_seconds: int = DELAY,
backoff: int = BACKOFF,
number_of_tries: int = NUMBER_OF_TRIES,
):
"""A decorator allowing for a function to be retried via an exponential backoff

Raises:
- TooManyRetriesException: When all the (re)tries are exhausted.
"""

def decorator_retry(func):
@functools.wraps(func)
def wrapper(self, *args, **kwargs):
number_of_tries_over = number_of_tries
delay = delay_seconds
while number_of_tries_over:
number_of_tries_over -= 1
try:
return func(self, *args, **kwargs)
except exceptions as error:
# Todo: log
time.sleep(delay)
delay *= backoff
raise TooManyRetriesException(f"Too many retries: {number_of_tries}")

return wrapper

return decorator_retry
62 changes: 62 additions & 0 deletions tests/test_retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
from unittest.mock import MagicMock, patch
from mediahaven.retry import (
retry_exponential,
RetryException,
TooManyRetriesException,
NUMBER_OF_TRIES,
DELAY,
BACKOFF,
)
import pytest


@patch("time.sleep")
def test_retry_defaults(time_sleep_mock):
function_mock = MagicMock()
function_mock.side_effect = RetryException

@retry_exponential((RetryException))
def func(self):
function_mock()

# Execute the decorated method
with pytest.raises(TooManyRetriesException):
func(MagicMock())

# Test if function was executed multiple times
assert function_mock.call_count == NUMBER_OF_TRIES

# Test if time.sleep was executed multiple times
assert time_sleep_mock.call_count == NUMBER_OF_TRIES

# Test exponential backoff
assert time_sleep_mock.call_args_list[0][0][0] == DELAY
for i in range(1, NUMBER_OF_TRIES):
prev_val = time_sleep_mock.call_args_list[i - 1][0][0]
assert time_sleep_mock.call_args_list[i][0][0] == prev_val * BACKOFF


@patch("time.sleep")
def test_retry(time_sleep_mock):
function_mock = MagicMock()
function_mock.side_effect = RetryException

@retry_exponential((RetryException), 2, 4, 5)
def func(self):
function_mock()

# Execute the decorated method
with pytest.raises(TooManyRetriesException):
func(MagicMock())

# Test if function was executed multiple times
assert function_mock.call_count == 5

# Test if time.sleep was executed multiple times
assert time_sleep_mock.call_count == 5

# Test exponential backoff
assert time_sleep_mock.call_args_list[0][0][0] == 2
for i in range(1, 5):
prev_val = time_sleep_mock.call_args_list[i - 1][0][0]
assert time_sleep_mock.call_args_list[i][0][0] == prev_val * 4