From 1a982d6f33a335c5ba94672fbe0c45e5add0a530 Mon Sep 17 00:00:00 2001 From: William Lewis Date: Tue, 30 Apr 2024 20:57:32 +0000 Subject: [PATCH] Adding cache level refreshNow functionality --- src/aws_secretsmanager_caching/cache/items.py | 24 +++++++++++++++++- .../secret_cache.py | 8 ++++++ test/unit/test_aws_secretsmanager_caching.py | 25 +++++++++++++++++++ test/unit/test_items.py | 24 ++++++++++++++++++ 4 files changed, 80 insertions(+), 1 deletion(-) diff --git a/src/aws_secretsmanager_caching/cache/items.py b/src/aws_secretsmanager_caching/cache/items.py index 70d9a3e..b4d5cc4 100644 --- a/src/aws_secretsmanager_caching/cache/items.py +++ b/src/aws_secretsmanager_caching/cache/items.py @@ -14,6 +14,7 @@ # pylint: disable=super-with-arguments import threading +import time from abc import ABCMeta, abstractmethod from copy import deepcopy from datetime import datetime, timedelta, timezone @@ -24,7 +25,8 @@ class SecretCacheObject: # pylint: disable=too-many-instance-attributes """Secret cache object that handles the common refresh logic.""" - + # Jitter max for refresh now + FORCE_REFRESH_JITTER_SLEEP = 5000 __metaclass__ = ABCMeta def __init__(self, config, client, secret_id): @@ -121,6 +123,26 @@ def get_secret_value(self, version_stage=None): if not value and self._exception: raise self._exception return deepcopy(value) + + def refresh_secret_now(self): + """Force a refresh of the cached secret. + :rtype: None + :return: None + """ + self._refresh_needed = True + + # Generate a random number to have a sleep jitter to not get stuck in a retry loop + sleep = randint(int(self.FORCE_REFRESH_JITTER_SLEEP / 2), self.FORCE_REFRESH_JITTER_SLEEP + 1) + + if self._exception is not None: + current_time_millis = int(datetime.now(timezone.utc).timestamp() * 1000) + exception_sleep = self._next_retry_time - current_time_millis + sleep = max(exception_sleep, sleep) + + # Divide by 1000 for millis + time.sleep(sleep / 1000) + + self._execute_refresh() def _get_result(self): """Get the stored result using a hook if present""" diff --git a/src/aws_secretsmanager_caching/secret_cache.py b/src/aws_secretsmanager_caching/secret_cache.py index b90f22c..ca72996 100644 --- a/src/aws_secretsmanager_caching/secret_cache.py +++ b/src/aws_secretsmanager_caching/secret_cache.py @@ -99,3 +99,11 @@ def get_secret_binary(self, secret_id, version_stage=None): if secret is None: return secret return secret.get("SecretBinary") + + def refresh_secret_now(self, secret_id): + """Immediately refresh the secret in the cache. + + :type secret_id: str + :param secret_id: The secret identifier + """ + self._get_cached_secret(secret_id).refresh_secret_now() diff --git a/test/unit/test_aws_secretsmanager_caching.py b/test/unit/test_aws_secretsmanager_caching.py index a111241..f00c560 100644 --- a/test/unit/test_aws_secretsmanager_caching.py +++ b/test/unit/test_aws_secretsmanager_caching.py @@ -169,6 +169,31 @@ def test_get_secret_binary_no_versions(self): cache = SecretCache(client=self.get_client()) self.assertIsNone(cache.get_secret_binary('test')) + def test_refresh_secret_now(self): + secret = 'mysecret' + response = {} + versions = { + '01234567890123456789012345678901': ['AWSCURRENT'] + } + version_response = {'SecretString': secret} + cache = SecretCache(client=self.get_client(response, + versions, + version_response)) + secret = cache._get_cached_secret('test') + self.assertIsNotNone(secret) + + old_refresh_time = secret._next_refresh_time + + secret = cache._get_cached_secret('test') + self.assertTrue(old_refresh_time == secret._next_refresh_time) + + cache.refresh_secret_now('test') + + secret = cache._get_cached_secret('test') + + new_refresh_time = secret._next_refresh_time + self.assertTrue(new_refresh_time > old_refresh_time) + def test_get_secret_string_exception(self): client = botocore.session.get_session().create_client( 'secretsmanager', region_name='us-west-2') diff --git a/test/unit/test_items.py b/test/unit/test_items.py index 8d9a8fb..4f50c3e 100644 --- a/test/unit/test_items.py +++ b/test/unit/test_items.py @@ -50,6 +50,30 @@ def test_simple_2(self): sco._exception = Exception("test") self.assertRaises(Exception, sco.get_secret_value) + def test_refresh_now(self): + config = SecretCacheConfig() + + client_mock = Mock() + client_mock.describe_secret = Mock() + client_mock.describe_secret.return_value = "test" + secret_cache_item = SecretCacheItem(config, client_mock, None) + secret_cache_item._next_refresh_time = datetime.now(timezone.utc) + timedelta(days=30) + secret_cache_item._refresh_needed = False + self.assertFalse(secret_cache_item._is_refresh_needed()) + + old_refresh_time = secret_cache_item._next_refresh_time + self.assertTrue(old_refresh_time > datetime.now(timezone.utc) + timedelta(days=29)) + + secret_cache_item.refresh_secret_now() + new_refresh_time = secret_cache_item._next_refresh_time + + ttl = config.secret_refresh_interval + + # New refresh time will use the ttl and will be less than the old refresh time that was artificially set a month ahead + # The new refresh time will be between now + ttl and now + (ttl / 2) if the secret was immediately refreshed + self.assertTrue(new_refresh_time < old_refresh_time and new_refresh_time < datetime.now(timezone.utc) + timedelta(ttl)) + + def test_datetime_fix_is_refresh_needed(self): secret_cached_object = TestSecretCacheObject.TestObject(SecretCacheConfig(), None, None)