Skip to content

Adding cache level refreshNow functionality #47

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion src/aws_secretsmanager_caching/cache/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# pylint: disable=super-with-arguments

import threading
import time
from abc import ABCMeta, abstractmethod
from copy import deepcopy
from datetime import datetime, timedelta, timezone
Expand All @@ -24,7 +25,8 @@

class SecretCacheObject: # pylint: disable=too-many-instance-attributes
"""Secret cache object that handles the common refresh logic."""

# Jitter max for refresh now
FORCE_REFRESH_JITTER_SLEEP = 5000
__metaclass__ = ABCMeta

def __init__(self, config, client, secret_id):
Expand Down Expand Up @@ -121,6 +123,26 @@ def get_secret_value(self, version_stage=None):
if not value and self._exception:
raise self._exception
return deepcopy(value)

def refresh_secret_now(self):
"""Force a refresh of the cached secret.
:rtype: None
:return: None
"""
self._refresh_needed = True

# Generate a random number to have a sleep jitter to not get stuck in a retry loop
sleep = randint(int(self.FORCE_REFRESH_JITTER_SLEEP / 2), self.FORCE_REFRESH_JITTER_SLEEP + 1)

if self._exception is not None:
current_time_millis = int(datetime.now(timezone.utc).timestamp() * 1000)
exception_sleep = self._next_retry_time - current_time_millis
sleep = max(exception_sleep, sleep)

# Divide by 1000 for millis
time.sleep(sleep / 1000)

self._execute_refresh()

def _get_result(self):
"""Get the stored result using a hook if present"""
Expand Down
8 changes: 8 additions & 0 deletions src/aws_secretsmanager_caching/secret_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,3 +99,11 @@ def get_secret_binary(self, secret_id, version_stage=None):
if secret is None:
return secret
return secret.get("SecretBinary")

def refresh_secret_now(self, secret_id):
"""Immediately refresh the secret in the cache.

:type secret_id: str
:param secret_id: The secret identifier
"""
self._get_cached_secret(secret_id).refresh_secret_now()
25 changes: 25 additions & 0 deletions test/unit/test_aws_secretsmanager_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,31 @@ def test_get_secret_binary_no_versions(self):
cache = SecretCache(client=self.get_client())
self.assertIsNone(cache.get_secret_binary('test'))

def test_refresh_secret_now(self):
secret = 'mysecret'
response = {}
versions = {
'01234567890123456789012345678901': ['AWSCURRENT']
}
version_response = {'SecretString': secret}
cache = SecretCache(client=self.get_client(response,
versions,
version_response))
secret = cache._get_cached_secret('test')
self.assertIsNotNone(secret)

old_refresh_time = secret._next_refresh_time

secret = cache._get_cached_secret('test')
self.assertTrue(old_refresh_time == secret._next_refresh_time)

cache.refresh_secret_now('test')

secret = cache._get_cached_secret('test')

new_refresh_time = secret._next_refresh_time
self.assertTrue(new_refresh_time > old_refresh_time)

def test_get_secret_string_exception(self):
client = botocore.session.get_session().create_client(
'secretsmanager', region_name='us-west-2')
Expand Down
24 changes: 24 additions & 0 deletions test/unit/test_items.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,30 @@ def test_simple_2(self):
sco._exception = Exception("test")
self.assertRaises(Exception, sco.get_secret_value)

def test_refresh_now(self):
config = SecretCacheConfig()

client_mock = Mock()
client_mock.describe_secret = Mock()
client_mock.describe_secret.return_value = "test"
secret_cache_item = SecretCacheItem(config, client_mock, None)
secret_cache_item._next_refresh_time = datetime.now(timezone.utc) + timedelta(days=30)
secret_cache_item._refresh_needed = False
self.assertFalse(secret_cache_item._is_refresh_needed())

old_refresh_time = secret_cache_item._next_refresh_time
self.assertTrue(old_refresh_time > datetime.now(timezone.utc) + timedelta(days=29))

secret_cache_item.refresh_secret_now()
new_refresh_time = secret_cache_item._next_refresh_time

ttl = config.secret_refresh_interval

# New refresh time will use the ttl and will be less than the old refresh time that was artificially set a month ahead
# The new refresh time will be between now + ttl and now + (ttl / 2) if the secret was immediately refreshed
self.assertTrue(new_refresh_time < old_refresh_time and new_refresh_time < datetime.now(timezone.utc) + timedelta(ttl))


def test_datetime_fix_is_refresh_needed(self):
secret_cached_object = TestSecretCacheObject.TestObject(SecretCacheConfig(), None, None)

Expand Down
Loading