From dea01ccdaeb85a053021cfc00dee2bae55081d56 Mon Sep 17 00:00:00 2001 From: Evan Sims Date: Sun, 31 Mar 2024 23:44:24 -0500 Subject: [PATCH] feat: retry 5xx auth network requests --- openfga_sdk/api/open_fga_api.py | 1 - openfga_sdk/oauth2.py | 85 ++++++++++++++---- openfga_sdk/sync/oauth2.py | 85 ++++++++++++++---- test-requirements.txt | 2 +- test/test_credentials.py | 8 +- test/test_oauth2.py | 148 +++++++++++++++++++++++++++++++- test/test_oauth2_sync.py | 148 +++++++++++++++++++++++++++++++- test/test_open_fga_api.py | 4 +- test/test_open_fga_api_sync.py | 5 +- 9 files changed, 439 insertions(+), 47 deletions(-) diff --git a/openfga_sdk/api/open_fga_api.py b/openfga_sdk/api/open_fga_api.py index f1af609..2b88230 100644 --- a/openfga_sdk/api/open_fga_api.py +++ b/openfga_sdk/api/open_fga_api.py @@ -10,7 +10,6 @@ NOTE: This file was auto generated by OpenAPI Generator (https://openapi-generator.tech). DO NOT EDIT. """ - from openfga_sdk.api_client import ApiClient from openfga_sdk.exceptions import ApiValueError, FgaValidationException from openfga_sdk.oauth2 import OAuth2Client diff --git a/openfga_sdk/oauth2.py b/openfga_sdk/oauth2.py index 8609aab..4a2c87a 100644 --- a/openfga_sdk/oauth2.py +++ b/openfga_sdk/oauth2.py @@ -10,22 +10,47 @@ NOTE: This file was auto generated by OpenAPI Generator (https://openapi-generator.tech). DO NOT EDIT. """ +import asyncio import json +import math +import random +import sys from datetime import datetime, timedelta import urllib3 +from openfga_sdk.configuration import Configuration from openfga_sdk.credentials import Credentials from openfga_sdk.exceptions import AuthenticationError +def jitter(loop_count, min_wait_in_ms): + """ + Generate a random jitter value for exponential backoff + """ + minimum = math.ceil(2**loop_count * min_wait_in_ms) + maximum = math.ceil(2 ** (loop_count + 1) * min_wait_in_ms) + jitter = random.randrange(minimum, maximum) / 1000 + + # If running in pytest, set jitter to 0 to speed up tests + if "pytest" in sys.modules: + jitter = 0 + + return jitter + + class OAuth2Client: - def __init__(self, credentials: Credentials): + def __init__(self, credentials: Credentials, configuration=None): self._credentials = credentials self._access_token = None self._access_expiry_time = None + if configuration is None: + configuration = Configuration.get_default_copy() + + self.configuration = configuration + def _token_valid(self): """ Return whether token is valid @@ -41,13 +66,16 @@ async def _obtain_token(self, client): Perform OAuth2 and obtain token """ configuration = self._credentials.configuration + token_url = f"https://{configuration.api_issuer}/oauth/token" + post_params = { "client_id": configuration.client_id, "client_secret": configuration.client_secret, "audience": configuration.api_audience, "grant_type": "client_credentials", } + headers = urllib3.response.HTTPHeaderDict( { "Accept": "application/json", @@ -55,23 +83,48 @@ async def _obtain_token(self, client): "User-Agent": "openfga-sdk (python) 0.4.1", } ) - raw_response = await client.POST( - token_url, headers=headers, post_params=post_params + + max_retry = ( + self.configuration.retry_params.max_retry + if ( + self.configuration.retry_params is not None + and self.configuration.retry_params.max_retry is not None + ) + else 0 ) - if 200 <= raw_response.status <= 299: - try: - api_response = json.loads(raw_response.data) - except: - raise AuthenticationError(http_resp=raw_response) - if not api_response.get("expires_in") or not api_response.get( - "access_token" - ): - raise AuthenticationError(http_resp=raw_response) - self._access_expiry_time = datetime.now() + timedelta( - seconds=int(api_response.get("expires_in")) + + min_wait_in_ms = ( + self.configuration.retry_params.min_wait_in_ms + if ( + self.configuration.retry_params is not None + and self.configuration.retry_params.min_wait_in_ms is not None + ) + else 0 + ) + + for attempt in range(max_retry + 1): + raw_response = await client.POST( + token_url, headers=headers, post_params=post_params ) - self._access_token = api_response.get("access_token") - else: + + if 500 <= raw_response.status <= 599 or raw_response.status == 429: + if attempt < max_retry and raw_response.status != 501: + await asyncio.sleep(jitter(attempt, min_wait_in_ms)) + continue + + if 200 <= raw_response.status <= 299: + try: + api_response = json.loads(raw_response.data) + except: + raise AuthenticationError(http_resp=raw_response) + + if api_response.get("expires_in") and api_response.get("access_token"): + self._access_expiry_time = datetime.now() + timedelta( + seconds=int(api_response.get("expires_in")) + ) + self._access_token = api_response.get("access_token") + break + raise AuthenticationError(http_resp=raw_response) async def get_authentication_header(self, client): diff --git a/openfga_sdk/sync/oauth2.py b/openfga_sdk/sync/oauth2.py index 4abbae0..523fa16 100644 --- a/openfga_sdk/sync/oauth2.py +++ b/openfga_sdk/sync/oauth2.py @@ -11,21 +11,46 @@ """ import json +import math +import random +import sys +import time from datetime import datetime, timedelta import urllib3 +from openfga_sdk.configuration import Configuration from openfga_sdk.credentials import Credentials from openfga_sdk.exceptions import AuthenticationError +def jitter(loop_count, min_wait_in_ms): + """ + Generate a random jitter value for exponential backoff + """ + minimum = math.ceil(2**loop_count * min_wait_in_ms) + maximum = math.ceil(2 ** (loop_count + 1) * min_wait_in_ms) + jitter = random.randrange(minimum, maximum) / 1000 + + # If running in pytest, set jitter to 0 to speed up tests + if "pytest" in sys.modules: + jitter = 0 + + return jitter + + class OAuth2Client: - def __init__(self, credentials: Credentials): + def __init__(self, credentials: Credentials, configuration=None): self._credentials = credentials self._access_token = None self._access_expiry_time = None + if configuration is None: + configuration = Configuration.get_default_copy() + + self.configuration = configuration + def _token_valid(self): """ Return whether token is valid @@ -41,13 +66,16 @@ def _obtain_token(self, client): Perform OAuth2 and obtain token """ configuration = self._credentials.configuration + token_url = f"https://{configuration.api_issuer}/oauth/token" + post_params = { "client_id": configuration.client_id, "client_secret": configuration.client_secret, "audience": configuration.api_audience, "grant_type": "client_credentials", } + headers = urllib3.response.HTTPHeaderDict( { "Accept": "application/json", @@ -55,21 +83,48 @@ def _obtain_token(self, client): "User-Agent": "openfga-sdk (python) 0.4.1", } ) - raw_response = client.POST(token_url, headers=headers, post_params=post_params) - if 200 <= raw_response.status <= 299: - try: - api_response = json.loads(raw_response.data) - except: - raise AuthenticationError(http_resp=raw_response) - if not api_response.get("expires_in") or not api_response.get( - "access_token" - ): - raise AuthenticationError(http_resp=raw_response) - self._access_expiry_time = datetime.now() + timedelta( - seconds=int(api_response.get("expires_in")) + + max_retry = ( + self.configuration.retry_params.max_retry + if ( + self.configuration.retry_params is not None + and self.configuration.retry_params.max_retry is not None + ) + else 0 + ) + + min_wait_in_ms = ( + self.configuration.retry_params.min_wait_in_ms + if ( + self.configuration.retry_params is not None + and self.configuration.retry_params.min_wait_in_ms is not None + ) + else 0 + ) + + for attempt in range(max_retry + 1): + raw_response = client.POST( + token_url, headers=headers, post_params=post_params ) - self._access_token = api_response.get("access_token") - else: + + if 500 <= raw_response.status <= 599 or raw_response.status == 429: + if attempt < max_retry and raw_response.status != 501: + time.sleep(jitter(attempt, min_wait_in_ms)) + continue + + if 200 <= raw_response.status <= 299: + try: + api_response = json.loads(raw_response.data) + except: + raise AuthenticationError(http_resp=raw_response) + + if api_response.get("expires_in") and api_response.get("access_token"): + self._access_expiry_time = datetime.now() + timedelta( + seconds=int(api_response.get("expires_in")) + ) + self._access_token = api_response.get("access_token") + break + raise AuthenticationError(http_resp=raw_response) def get_authentication_header(self, client): diff --git a/test-requirements.txt b/test-requirements.txt index b3567ca..37be0b6 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -2,5 +2,5 @@ mock >= 5.1.0, < 6 flake8 >= 7.0.0, < 8 -pytest-cov >= 4.1.0, < 5 +pytest-cov >= 5, < 6 griffe >= 0.41.2, < 1 diff --git a/test/test_credentials.py b/test/test_credentials.py index 933c170..57467fd 100644 --- a/test/test_credentials.py +++ b/test/test_credentials.py @@ -98,7 +98,7 @@ def test_configuration_client_credentials(self): configuration=CredentialConfiguration( client_id="myclientid", client_secret="mysecret", - api_issuer="www.testme.com", + api_issuer="issuer.fga.example", api_audience="myaudience", ), ) @@ -121,7 +121,7 @@ def test_configuration_client_credentials_missing_client_id(self): method="client_credentials", configuration=CredentialConfiguration( client_secret="mysecret", - api_issuer="www.testme.com", + api_issuer="issuer.fga.example", api_audience="myaudience", ), ) @@ -136,7 +136,7 @@ def test_configuration_client_credentials_missing_client_secret(self): method="client_credentials", configuration=CredentialConfiguration( client_id="myclientid", - api_issuer="www.testme.com", + api_issuer="issuer.fga.example", api_audience="myaudience", ), ) @@ -167,7 +167,7 @@ def test_configuration_client_credentials_missing_api_audience(self): configuration=CredentialConfiguration( client_id="myclientid", client_secret="mysecret", - api_issuer="www.testme.com", + api_issuer="issuer.fga.example", ), ) with self.assertRaises(openfga_sdk.ApiValueError): diff --git a/test/test_oauth2.py b/test/test_oauth2.py index e9f2ce0..7795dd5 100644 --- a/test/test_oauth2.py +++ b/test/test_oauth2.py @@ -67,7 +67,7 @@ async def test_get_authentication_obtain_client_credentials(self, mock_request): configuration=CredentialConfiguration( client_id="myclientid", client_secret="mysecret", - api_issuer="www.testme.com", + api_issuer="issuer.fga.example", api_audience="myaudience", ), ) @@ -89,7 +89,7 @@ async def test_get_authentication_obtain_client_credentials(self, mock_request): ) mock_request.assert_called_once_with( "POST", - "https://www.testme.com/oauth/token", + "https://issuer.fga.example/oauth/token", headers=expected_header, query_params=None, body=None, @@ -123,7 +123,7 @@ async def test_get_authentication_obtain_client_credentials_failed( configuration=CredentialConfiguration( client_id="myclientid", client_secret="mysecret", - api_issuer="www.testme.com", + api_issuer="issuer.fga.example", api_audience="myaudience", ), ) @@ -132,3 +132,145 @@ async def test_get_authentication_obtain_client_credentials_failed( with self.assertRaises(AuthenticationError): await client.get_authentication_header(rest_client) await rest_client.close() + + @patch.object(rest.RESTClientObject, "request") + async def test_get_authentication_obtain_with_expired_client_credentials_failed( + self, mock_request + ): + """ + Expired token should trigger a new token request + """ + + response_body = """ +{ + "reason": "Unauthorized" +} + """ + mock_request.return_value = mock_response(response_body, 403) + + credentials = Credentials( + method="client_credentials", + configuration=CredentialConfiguration( + client_id="myclientid", + client_secret="mysecret", + api_issuer="issuer.fga.example", + api_audience="myaudience", + ), + ) + rest_client = rest.RESTClientObject(Configuration()) + client = OAuth2Client(credentials) + + client._access_token = "XYZ123" + client._access_expiry_time = datetime.now() - timedelta(seconds=240) + + with self.assertRaises(AuthenticationError): + await client.get_authentication_header(rest_client) + await rest_client.close() + + @patch.object(rest.RESTClientObject, "request") + async def test_get_authentication_unexpected_response_fails(self, mock_request): + """ + Receiving an unexpected response from the server should raise an exception + """ + + response_body = """ +This is not a JSON response + """ + mock_request.return_value = mock_response(response_body, 200) + + credentials = Credentials( + method="client_credentials", + configuration=CredentialConfiguration( + client_id="myclientid", + client_secret="mysecret", + api_issuer="issuer.fga.example", + api_audience="myaudience", + ), + ) + rest_client = rest.RESTClientObject(Configuration()) + client = OAuth2Client(credentials) + + with self.assertRaises(AuthenticationError): + await client.get_authentication_header(rest_client) + await rest_client.close() + + @patch.object(rest.RESTClientObject, "request") + async def test_get_authentication_erroneous_response_fails(self, mock_request): + """ + Receiving an erroneous response from the server that's missing properties should raise an exception + """ + + response_body = """ +{ + "access_token": "AABBCCDD" +} + """ + mock_request.return_value = mock_response(response_body, 200) + + credentials = Credentials( + method="client_credentials", + configuration=CredentialConfiguration( + client_id="myclientid", + client_secret="mysecret", + api_issuer="issuer.fga.example", + api_audience="myaudience", + ), + ) + rest_client = rest.RESTClientObject(Configuration()) + client = OAuth2Client(credentials) + + with self.assertRaises(AuthenticationError): + await client.get_authentication_header(rest_client) + await rest_client.close() + + @patch.object(rest.RESTClientObject, "request") + async def test_get_authentication_retries_5xx_responses(self, mock_request): + """ + Receiving a 5xx response from the server should be retried + """ + + error_response_body = """ +{ + "code": "rate_limit_exceeded", + "message": "Rate Limit exceeded" +} + """ + + response_body = """ +{ + "expires_in": 120, + "access_token": "AABBCCDD" +} + """ + + mock_request.side_effect = [ + mock_response(error_response_body, 429), + mock_response(error_response_body, 429), + mock_response(error_response_body, 429), + mock_response(response_body, 200), + ] + + credentials = Credentials( + method="client_credentials", + configuration=CredentialConfiguration( + client_id="myclientid", + client_secret="mysecret", + api_issuer="issuer.fga.example", + api_audience="myaudience", + ), + ) + + configuration = Configuration() + configuration.retry_params.max_retry = 5 + configuration.retry_params.retry_interval = 0 + + rest_client = rest.RESTClientObject(configuration) + client = OAuth2Client(credentials, configuration) + + auth_header = await client.get_authentication_header(rest_client) + + mock_request.assert_called() + self.assertEqual(mock_request.call_count, 4) # 3 retries, 1 success + self.assertEqual(auth_header, {"Authorization": "Bearer AABBCCDD"}) + + await rest_client.close() diff --git a/test/test_oauth2_sync.py b/test/test_oauth2_sync.py index 5ebcc18..df87ed5 100644 --- a/test/test_oauth2_sync.py +++ b/test/test_oauth2_sync.py @@ -67,7 +67,7 @@ def test_get_authentication_obtain_client_credentials(self, mock_request): configuration=CredentialConfiguration( client_id="myclientid", client_secret="mysecret", - api_issuer="www.testme.com", + api_issuer="issuer.fga.example", api_audience="myaudience", ), ) @@ -89,7 +89,7 @@ def test_get_authentication_obtain_client_credentials(self, mock_request): ) mock_request.assert_called_once_with( "POST", - "https://www.testme.com/oauth/token", + "https://issuer.fga.example/oauth/token", headers=expected_header, query_params=None, body=None, @@ -121,7 +121,7 @@ def test_get_authentication_obtain_client_credentials_failed(self, mock_request) configuration=CredentialConfiguration( client_id="myclientid", client_secret="mysecret", - api_issuer="www.testme.com", + api_issuer="issuer.fga.example", api_audience="myaudience", ), ) @@ -130,3 +130,145 @@ def test_get_authentication_obtain_client_credentials_failed(self, mock_request) with self.assertRaises(AuthenticationError): client.get_authentication_header(rest_client) rest_client.close() + + @patch.object(rest.RESTClientObject, "request") + async def test_get_authentication_obtain_with_expired_client_credentials_failed( + self, mock_request + ): + """ + Expired token should trigger a new token request + """ + + response_body = """ +{ + "reason": "Unauthorized" +} + """ + mock_request.return_value = mock_response(response_body, 403) + + credentials = Credentials( + method="client_credentials", + configuration=CredentialConfiguration( + client_id="myclientid", + client_secret="mysecret", + api_issuer="issuer.fga.example", + api_audience="myaudience", + ), + ) + rest_client = rest.RESTClientObject(Configuration()) + client = OAuth2Client(credentials) + + client._access_token = "XYZ123" + client._access_expiry_time = datetime.now() - timedelta(seconds=240) + + with self.assertRaises(AuthenticationError): + client.get_authentication_header(rest_client) + rest_client.close() + + @patch.object(rest.RESTClientObject, "request") + async def test_get_authentication_unexpected_response_fails(self, mock_request): + """ + Receiving an unexpected response from the server should raise an exception + """ + + response_body = """ +This is not a JSON response + """ + mock_request.return_value = mock_response(response_body, 200) + + credentials = Credentials( + method="client_credentials", + configuration=CredentialConfiguration( + client_id="myclientid", + client_secret="mysecret", + api_issuer="issuer.fga.example", + api_audience="myaudience", + ), + ) + rest_client = rest.RESTClientObject(Configuration()) + client = OAuth2Client(credentials) + + with self.assertRaises(AuthenticationError): + client.get_authentication_header(rest_client) + rest_client.close() + + @patch.object(rest.RESTClientObject, "request") + async def test_get_authentication_erroneous_response_fails(self, mock_request): + """ + Receiving an erroneous response from the server that's missing properties should raise an exception + """ + + response_body = """ +{ + "access_token": "AABBCCDD" +} + """ + mock_request.return_value = mock_response(response_body, 200) + + credentials = Credentials( + method="client_credentials", + configuration=CredentialConfiguration( + client_id="myclientid", + client_secret="mysecret", + api_issuer="issuer.fga.example", + api_audience="myaudience", + ), + ) + rest_client = rest.RESTClientObject(Configuration()) + client = OAuth2Client(credentials) + + with self.assertRaises(AuthenticationError): + client.get_authentication_header(rest_client) + rest_client.close() + + @patch.object(rest.RESTClientObject, "request") + async def test_get_authentication_retries_5xx_responses(self, mock_request): + """ + Receiving a 5xx response from the server should be retried + """ + + error_response_body = """ +{ + "code": "rate_limit_exceeded", + "message": "Rate Limit exceeded" +} + """ + + response_body = """ +{ + "expires_in": 120, + "access_token": "AABBCCDD" +} + """ + + mock_request.side_effect = [ + mock_response(error_response_body, 429), + mock_response(error_response_body, 429), + mock_response(error_response_body, 429), + mock_response(response_body, 200), + ] + + credentials = Credentials( + method="client_credentials", + configuration=CredentialConfiguration( + client_id="myclientid", + client_secret="mysecret", + api_issuer="issuer.fga.example", + api_audience="myaudience", + ), + ) + + configuration = Configuration() + configuration.retry_params.max_retry = 5 + configuration.retry_params.retry_interval = 0 + + rest_client = rest.RESTClientObject(configuration) + client = OAuth2Client(credentials, configuration) + + auth_header = client.get_authentication_header(rest_client) + + mock_request.assert_called() + self.assertEqual(mock_request.call_count, 4) # 3 retries, 1 success + self.assertEqual(auth_header, {"Authorization": "Bearer AABBCCDD"}) + + rest_client.close() diff --git a/test/test_open_fga_api.py b/test/test_open_fga_api.py index 0717faf..3e88433 100644 --- a/test/test_open_fga_api.py +++ b/test/test_open_fga_api.py @@ -1250,10 +1250,10 @@ async def test_500_error(self, mock_request): http_resp=http_mock_response(response_body, 500) ) - retry = openfga_sdk.configuration.RetryParams(0, 10) configuration = self.configuration configuration.store_id = store_id - configuration.retry_params = retry + configuration.retry_params.max_retry = 0 + async with openfga_sdk.ApiClient(configuration) as api_client: api_instance = open_fga_api.OpenFgaApi(api_client) body = CheckRequest( diff --git a/test/test_open_fga_api_sync.py b/test/test_open_fga_api_sync.py index 5ed38a9..629f4dd 100644 --- a/test/test_open_fga_api_sync.py +++ b/test/test_open_fga_api_sync.py @@ -1247,10 +1247,9 @@ async def test_500_error(self, mock_request): http_resp=http_mock_response(response_body, 500) ) - retry = openfga_sdk.configuration.RetryParams(0, 10) configuration = self.configuration configuration.store_id = store_id - configuration.retry_params = retry + configuration.retry_params.max_retry = 0 with ApiClient(configuration) as api_client: api_instance = open_fga_api.OpenFgaApi(api_client) @@ -1276,6 +1275,8 @@ async def test_500_error(self, mock_request): api_exception.exception.parsed_exception.message, "Internal Server Error", ) + mock_request.assert_called() + self.assertEqual(mock_request.call_count, 1) @patch.object(rest.RESTClientObject, "request") async def test_500_error(self, mock_request):