From 3a4f44fbc599e653cfc45114e9a14b977dfe18c1 Mon Sep 17 00:00:00 2001 From: Ray Luo Date: Fri, 12 Jul 2024 12:34:56 -0700 Subject: [PATCH] Expose refresh_on (if any) to fresh or cached response --- msal/application.py | 13 +++++--- msal/managed_identity.py | 8 +++-- tests/test_application.py | 25 +++++++++++--- tests/test_mi.py | 69 ++++++++++++++++++++++++++++----------- 4 files changed, 86 insertions(+), 29 deletions(-) diff --git a/msal/application.py b/msal/application.py index ba10cd39..8f30eb1c 100644 --- a/msal/application.py +++ b/msal/application.py @@ -104,11 +104,14 @@ def _clean_up(result): "msalruntime_telemetry": result.get("_msalruntime_telemetry"), "msal_python_telemetry": result.get("_msal_python_telemetry"), }, separators=(",", ":")) - return { + return_value = { k: result[k] for k in result if k != "refresh_in" # MSAL handled refresh_in, customers need not and not k.startswith('_') # Skim internal properties } + if "refresh_in" in result: # To encourage proactive refresh + return_value["refresh_on"] = int(time.time() + result["refresh_in"]) + return return_value return result # It could be None @@ -1507,9 +1510,11 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it( "expires_in": int(expires_in), # OAuth2 specs defines it as int self._TOKEN_SOURCE: self._TOKEN_SOURCE_CACHE, } - if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging - refresh_reason = msal.telemetry.AT_AGING - break # With a fallback in hand, we break here to go refresh + if "refresh_on" in entry: + access_token_from_cache["refresh_on"] = int(entry["refresh_on"]) + if int(entry["refresh_on"]) < now: # aging + refresh_reason = msal.telemetry.AT_AGING + break # With a fallback in hand, we break here to go refresh self._build_telemetry_context(-1).hit_an_access_token() return access_token_from_cache # It is still good as new else: diff --git a/msal/managed_identity.py b/msal/managed_identity.py index 354fee52..aee57ca3 100644 --- a/msal/managed_identity.py +++ b/msal/managed_identity.py @@ -273,8 +273,10 @@ def acquire_token_for_client(self, *, resource): # We may support scope in the "token_type": entry.get("token_type", "Bearer"), "expires_in": int(expires_in), # OAuth2 specs defines it as int } - if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging - break # With a fallback in hand, we break here to go refresh + if "refresh_on" in entry: + access_token_from_cache["refresh_on"] = int(entry["refresh_on"]) + if int(entry["refresh_on"]) < now: # aging + break # With a fallback in hand, we break here to go refresh return access_token_from_cache # It is still good as new try: result = _obtain_token(self._http_client, self._managed_identity, resource) @@ -290,6 +292,8 @@ def acquire_token_for_client(self, *, resource): # We may support scope in the params={}, data={}, )) + if "refresh_in" in result: + result["refresh_on"] = int(now + result["refresh_in"]) if (result and "error" not in result) or (not access_token_from_cache): return result except: # The exact HTTP exception is transportation-layer dependent diff --git a/tests/test_application.py b/tests/test_application.py index cebc7225..71dc16ea 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1,6 +1,7 @@ # Note: Since Aug 2019 we move all e2e tests into test_e2e.py, # so this test_application file contains only unit tests without dependency. import sys +import time from msal.application import * from msal.application import _str2bytes import msal @@ -353,10 +354,18 @@ def populate_cache(self, access_token="at", expires_in=86400, refresh_in=43200): uid=self.uid, utid=self.utid, refresh_token=self.rt), }) + def assertRefreshOn(self, result, refresh_in): + refresh_on = int(time.time() + refresh_in) + self.assertTrue( + refresh_on - 1 < result.get("refresh_on", 0) < refresh_on + 1, + "refresh_on should be set properly") + def test_fresh_token_should_be_returned_from_cache(self): # a.k.a. Return unexpired token that is not above token refresh expiration threshold + refresh_in = 450 access_token = "An access token prepopulated into cache" - self.populate_cache(access_token=access_token, expires_in=900, refresh_in=450) + self.populate_cache( + access_token=access_token, expires_in=900, refresh_in=refresh_in) result = self.app.acquire_token_silent( ['s1'], self.account, post=lambda url, *args, **kwargs: # Utilize the undocumented test feature @@ -365,32 +374,38 @@ def test_fresh_token_should_be_returned_from_cache(self): self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_CACHE) self.assertEqual(access_token, result.get("access_token")) self.assertNotIn("refresh_in", result, "Customers need not know refresh_in") + self.assertRefreshOn(result, refresh_in) def test_aging_token_and_available_aad_should_return_new_token(self): # a.k.a. Attempt to refresh unexpired token when AAD available self.populate_cache(access_token="old AT", expires_in=3599, refresh_in=-1) new_access_token = "new AT" + new_refresh_in = 123 def mock_post(url, headers=None, *args, **kwargs): self.assertEqual("4|84,4|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) return MinimalResponse(status_code=200, text=json.dumps({ "access_token": new_access_token, - "refresh_in": 123, + "refresh_in": new_refresh_in, })) result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post) self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP) self.assertEqual(new_access_token, result.get("access_token")) self.assertNotIn("refresh_in", result, "Customers need not know refresh_in") + self.assertRefreshOn(result, new_refresh_in) def test_aging_token_and_unavailable_aad_should_return_old_token(self): # a.k.a. Attempt refresh unexpired token when AAD unavailable + refresh_in = -1 old_at = "old AT" - self.populate_cache(access_token=old_at, expires_in=3599, refresh_in=-1) + self.populate_cache( + access_token=old_at, expires_in=3599, refresh_in=refresh_in) def mock_post(url, headers=None, *args, **kwargs): self.assertEqual("4|84,4|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) return MinimalResponse(status_code=400, text=json.dumps({"error": "foo"})) result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post) self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_CACHE) self.assertEqual(old_at, result.get("access_token")) + self.assertRefreshOn(result, refresh_in) def test_expired_token_and_unavailable_aad_should_return_error(self): # a.k.a. Attempt refresh expired token when AAD unavailable @@ -407,16 +422,18 @@ def test_expired_token_and_available_aad_should_return_new_token(self): # a.k.a. Attempt refresh expired token when AAD available self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900) new_access_token = "new AT" + new_refresh_in = 123 def mock_post(url, headers=None, *args, **kwargs): self.assertEqual("4|84,3|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY)) return MinimalResponse(status_code=200, text=json.dumps({ "access_token": new_access_token, - "refresh_in": 123, + "refresh_in": new_refresh_in, })) result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post) self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP) self.assertEqual(new_access_token, result.get("access_token")) self.assertNotIn("refresh_in", result, "Customers need not know refresh_in") + self.assertRefreshOn(result, new_refresh_in) class TestTelemetryMaintainingOfflineState(unittest.TestCase): diff --git a/tests/test_mi.py b/tests/test_mi.py index d6dcc159..f3182c7b 100644 --- a/tests/test_mi.py +++ b/tests/test_mi.py @@ -26,6 +26,7 @@ SERVICE_FABRIC, DEFAULT_TO_VM, ) +from msal.token_cache import is_subdict_of class ManagedIdentityTestCase(unittest.TestCase): @@ -60,7 +61,7 @@ def setUp(self): http_client=requests.Session(), ) - def _test_token_cache(self, app): + def assertCacheStatus(self, app): cache = app._token_cache._cache self.assertEqual(1, len(cache.get("AccessToken", [])), "Should have 1 AT") at = list(cache["AccessToken"].values())[0] @@ -70,30 +71,55 @@ def _test_token_cache(self, app): "Should have expected client_id") self.assertEqual("managed_identity", at["realm"], "Should have expected realm") - def _test_happy_path(self, app, mocked_http): - result = app.acquire_token_for_client(resource="R") + def _test_happy_path(self, app, mocked_http, expires_in, resource="R"): + result = app.acquire_token_for_client(resource=resource) mocked_http.assert_called() - self.assertEqual({ + call_count = mocked_http.call_count + expected_result = { "access_token": "AT", - "expires_in": 1234, - "resource": "R", "token_type": "Bearer", - }, result, "Should obtain a token response") + } + self.assertTrue( + is_subdict_of(expected_result, result), # We will test refresh_on later + "Should obtain a token response") + self.assertEqual(expires_in, result["expires_in"], "Should have expected expires_in") + if expires_in >= 7200: + expected_refresh_on = int(time.time() + expires_in / 2) + self.assertTrue( + expected_refresh_on - 1 <= result["refresh_on"] <= expected_refresh_on + 1, + "Should have a refresh_on time around the middle of the token's life") self.assertEqual( result["access_token"], - app.acquire_token_for_client(resource="R").get("access_token"), + app.acquire_token_for_client(resource=resource).get("access_token"), "Should hit the same token from cache") - self._test_token_cache(app) + + self.assertCacheStatus(app) + + result = app.acquire_token_for_client(resource=resource) + self.assertEqual( + call_count, mocked_http.call_count, + "No new call to the mocked http should be made for a cache hit") + self.assertTrue( + is_subdict_of(expected_result, result), # We will test refresh_on later + "Should obtain a token response") + self.assertTrue( + expires_in - 5 < result["expires_in"] <= expires_in, + "Should have similar expires_in") + if expires_in >= 7200: + self.assertTrue( + expected_refresh_on - 5 < result["refresh_on"] <= expected_refresh_on, + "Should have a refresh_on time around the middle of the token's life") class VmTestCase(ClientTestCase): def test_happy_path(self): + expires_in = 7890 # We test a bigger than 7200 value here with patch.object(self.app._http_client, "get", return_value=MinimalResponse( status_code=200, - text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}', + text='{"access_token": "AT", "expires_in": "%s", "resource": "R"}' % expires_in, )) as mocked_method: - self._test_happy_path(self.app, mocked_method) + self._test_happy_path(self.app, mocked_method, expires_in) def test_vm_error_should_be_returned_as_is(self): raw_error = '{"raw": "error format is undefined"}' @@ -110,12 +136,13 @@ def test_vm_error_should_be_returned_as_is(self): class AppServiceTestCase(ClientTestCase): def test_happy_path(self): + expires_in = 1234 with patch.object(self.app._http_client, "get", return_value=MinimalResponse( status_code=200, text='{"access_token": "AT", "expires_on": "%s", "resource": "R"}' % ( - int(time.time()) + 1234), + int(time.time()) + expires_in), )) as mocked_method: - self._test_happy_path(self.app, mocked_method) + self._test_happy_path(self.app, mocked_method, expires_in) def test_app_service_error_should_be_normalized(self): raw_error = '{"statusCode": 500, "message": "error content is undefined"}' @@ -134,12 +161,13 @@ def test_app_service_error_should_be_normalized(self): class MachineLearningTestCase(ClientTestCase): def test_happy_path(self): + expires_in = 1234 with patch.object(self.app._http_client, "get", return_value=MinimalResponse( status_code=200, text='{"access_token": "AT", "expires_on": "%s", "resource": "R"}' % ( - int(time.time()) + 1234), + int(time.time()) + expires_in), )) as mocked_method: - self._test_happy_path(self.app, mocked_method) + self._test_happy_path(self.app, mocked_method, expires_in) def test_machine_learning_error_should_be_normalized(self): raw_error = '{"error": "placeholder", "message": "placeholder"}' @@ -162,12 +190,14 @@ def test_machine_learning_error_should_be_normalized(self): class ServiceFabricTestCase(ClientTestCase): def _test_happy_path(self, app): + expires_in = 1234 with patch.object(app._http_client, "get", return_value=MinimalResponse( status_code=200, text='{"access_token": "AT", "expires_on": %s, "resource": "R", "token_type": "Bearer"}' % ( - int(time.time()) + 1234), + int(time.time()) + expires_in), )) as mocked_method: - super(ServiceFabricTestCase, self)._test_happy_path(app, mocked_method) + super(ServiceFabricTestCase, self)._test_happy_path( + app, mocked_method, expires_in) def test_happy_path(self): self._test_happy_path(self.app) @@ -212,15 +242,16 @@ class ArcTestCase(ClientTestCase): }) def test_happy_path(self, mocked_stat): + expires_in = 1234 with patch.object(self.app._http_client, "get", side_effect=[ self.challenge, MinimalResponse( status_code=200, - text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}', + text='{"access_token": "AT", "expires_in": "%s", "resource": "R"}' % expires_in, ), ]) as mocked_method: try: - super(ArcTestCase, self)._test_happy_path(self.app, mocked_method) + self._test_happy_path(self.app, mocked_method, expires_in) mocked_stat.assert_called_with(os.path.join( _supported_arc_platforms_and_their_prefixes[sys.platform], "foo.key"))