Skip to content

Commit

Permalink
Expose refresh_on (if any) to fresh or cached response
Browse files Browse the repository at this point in the history
  • Loading branch information
rayluo committed Jul 17, 2024
1 parent 57dce47 commit 3a4f44f
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 29 deletions.
13 changes: 9 additions & 4 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,14 @@ def _clean_up(result):
"msalruntime_telemetry": result.get("_msalruntime_telemetry"),
"msal_python_telemetry": result.get("_msal_python_telemetry"),
}, separators=(",", ":"))
return {
return_value = {
k: result[k] for k in result
if k != "refresh_in" # MSAL handled refresh_in, customers need not
and not k.startswith('_') # Skim internal properties
}
if "refresh_in" in result: # To encourage proactive refresh
return_value["refresh_on"] = int(time.time() + result["refresh_in"])
return return_value
return result # It could be None


Expand Down Expand Up @@ -1507,9 +1510,11 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
"expires_in": int(expires_in), # OAuth2 specs defines it as int
self._TOKEN_SOURCE: self._TOKEN_SOURCE_CACHE,
}
if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging
refresh_reason = msal.telemetry.AT_AGING
break # With a fallback in hand, we break here to go refresh
if "refresh_on" in entry:
access_token_from_cache["refresh_on"] = int(entry["refresh_on"])
if int(entry["refresh_on"]) < now: # aging
refresh_reason = msal.telemetry.AT_AGING
break # With a fallback in hand, we break here to go refresh
self._build_telemetry_context(-1).hit_an_access_token()
return access_token_from_cache # It is still good as new
else:
Expand Down
8 changes: 6 additions & 2 deletions msal/managed_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,10 @@ def acquire_token_for_client(self, *, resource): # We may support scope in the
"token_type": entry.get("token_type", "Bearer"),
"expires_in": int(expires_in), # OAuth2 specs defines it as int
}
if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging
break # With a fallback in hand, we break here to go refresh
if "refresh_on" in entry:
access_token_from_cache["refresh_on"] = int(entry["refresh_on"])
if int(entry["refresh_on"]) < now: # aging
break # With a fallback in hand, we break here to go refresh
return access_token_from_cache # It is still good as new
try:
result = _obtain_token(self._http_client, self._managed_identity, resource)
Expand All @@ -290,6 +292,8 @@ def acquire_token_for_client(self, *, resource): # We may support scope in the
params={},
data={},
))
if "refresh_in" in result:
result["refresh_on"] = int(now + result["refresh_in"])
if (result and "error" not in result) or (not access_token_from_cache):
return result
except: # The exact HTTP exception is transportation-layer dependent
Expand Down
25 changes: 21 additions & 4 deletions tests/test_application.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Note: Since Aug 2019 we move all e2e tests into test_e2e.py,
# so this test_application file contains only unit tests without dependency.
import sys
import time
from msal.application import *
from msal.application import _str2bytes
import msal
Expand Down Expand Up @@ -353,10 +354,18 @@ def populate_cache(self, access_token="at", expires_in=86400, refresh_in=43200):
uid=self.uid, utid=self.utid, refresh_token=self.rt),
})

def assertRefreshOn(self, result, refresh_in):
refresh_on = int(time.time() + refresh_in)
self.assertTrue(
refresh_on - 1 < result.get("refresh_on", 0) < refresh_on + 1,
"refresh_on should be set properly")

def test_fresh_token_should_be_returned_from_cache(self):
# a.k.a. Return unexpired token that is not above token refresh expiration threshold
refresh_in = 450
access_token = "An access token prepopulated into cache"
self.populate_cache(access_token=access_token, expires_in=900, refresh_in=450)
self.populate_cache(
access_token=access_token, expires_in=900, refresh_in=refresh_in)
result = self.app.acquire_token_silent(
['s1'], self.account,
post=lambda url, *args, **kwargs: # Utilize the undocumented test feature
Expand All @@ -365,32 +374,38 @@ def test_fresh_token_should_be_returned_from_cache(self):
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_CACHE)
self.assertEqual(access_token, result.get("access_token"))
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")
self.assertRefreshOn(result, refresh_in)

def test_aging_token_and_available_aad_should_return_new_token(self):
# a.k.a. Attempt to refresh unexpired token when AAD available
self.populate_cache(access_token="old AT", expires_in=3599, refresh_in=-1)
new_access_token = "new AT"
new_refresh_in = 123
def mock_post(url, headers=None, *args, **kwargs):
self.assertEqual("4|84,4|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
return MinimalResponse(status_code=200, text=json.dumps({
"access_token": new_access_token,
"refresh_in": 123,
"refresh_in": new_refresh_in,
}))
result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
self.assertEqual(new_access_token, result.get("access_token"))
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")
self.assertRefreshOn(result, new_refresh_in)

def test_aging_token_and_unavailable_aad_should_return_old_token(self):
# a.k.a. Attempt refresh unexpired token when AAD unavailable
refresh_in = -1
old_at = "old AT"
self.populate_cache(access_token=old_at, expires_in=3599, refresh_in=-1)
self.populate_cache(
access_token=old_at, expires_in=3599, refresh_in=refresh_in)
def mock_post(url, headers=None, *args, **kwargs):
self.assertEqual("4|84,4|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
return MinimalResponse(status_code=400, text=json.dumps({"error": "foo"}))
result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_CACHE)
self.assertEqual(old_at, result.get("access_token"))
self.assertRefreshOn(result, refresh_in)

def test_expired_token_and_unavailable_aad_should_return_error(self):
# a.k.a. Attempt refresh expired token when AAD unavailable
Expand All @@ -407,16 +422,18 @@ def test_expired_token_and_available_aad_should_return_new_token(self):
# a.k.a. Attempt refresh expired token when AAD available
self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900)
new_access_token = "new AT"
new_refresh_in = 123
def mock_post(url, headers=None, *args, **kwargs):
self.assertEqual("4|84,3|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
return MinimalResponse(status_code=200, text=json.dumps({
"access_token": new_access_token,
"refresh_in": 123,
"refresh_in": new_refresh_in,
}))
result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post)
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
self.assertEqual(new_access_token, result.get("access_token"))
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")
self.assertRefreshOn(result, new_refresh_in)


class TestTelemetryMaintainingOfflineState(unittest.TestCase):
Expand Down
69 changes: 50 additions & 19 deletions tests/test_mi.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
SERVICE_FABRIC,
DEFAULT_TO_VM,
)
from msal.token_cache import is_subdict_of


class ManagedIdentityTestCase(unittest.TestCase):
Expand Down Expand Up @@ -60,7 +61,7 @@ def setUp(self):
http_client=requests.Session(),
)

def _test_token_cache(self, app):
def assertCacheStatus(self, app):
cache = app._token_cache._cache
self.assertEqual(1, len(cache.get("AccessToken", [])), "Should have 1 AT")
at = list(cache["AccessToken"].values())[0]
Expand All @@ -70,30 +71,55 @@ def _test_token_cache(self, app):
"Should have expected client_id")
self.assertEqual("managed_identity", at["realm"], "Should have expected realm")

def _test_happy_path(self, app, mocked_http):
result = app.acquire_token_for_client(resource="R")
def _test_happy_path(self, app, mocked_http, expires_in, resource="R"):
result = app.acquire_token_for_client(resource=resource)
mocked_http.assert_called()
self.assertEqual({
call_count = mocked_http.call_count
expected_result = {
"access_token": "AT",
"expires_in": 1234,
"resource": "R",
"token_type": "Bearer",
}, result, "Should obtain a token response")
}
self.assertTrue(
is_subdict_of(expected_result, result), # We will test refresh_on later
"Should obtain a token response")
self.assertEqual(expires_in, result["expires_in"], "Should have expected expires_in")
if expires_in >= 7200:
expected_refresh_on = int(time.time() + expires_in / 2)
self.assertTrue(
expected_refresh_on - 1 <= result["refresh_on"] <= expected_refresh_on + 1,
"Should have a refresh_on time around the middle of the token's life")
self.assertEqual(
result["access_token"],
app.acquire_token_for_client(resource="R").get("access_token"),
app.acquire_token_for_client(resource=resource).get("access_token"),
"Should hit the same token from cache")
self._test_token_cache(app)

self.assertCacheStatus(app)

result = app.acquire_token_for_client(resource=resource)
self.assertEqual(
call_count, mocked_http.call_count,
"No new call to the mocked http should be made for a cache hit")
self.assertTrue(
is_subdict_of(expected_result, result), # We will test refresh_on later
"Should obtain a token response")
self.assertTrue(
expires_in - 5 < result["expires_in"] <= expires_in,
"Should have similar expires_in")
if expires_in >= 7200:
self.assertTrue(
expected_refresh_on - 5 < result["refresh_on"] <= expected_refresh_on,
"Should have a refresh_on time around the middle of the token's life")


class VmTestCase(ClientTestCase):

def test_happy_path(self):
expires_in = 7890 # We test a bigger than 7200 value here
with patch.object(self.app._http_client, "get", return_value=MinimalResponse(
status_code=200,
text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}',
text='{"access_token": "AT", "expires_in": "%s", "resource": "R"}' % expires_in,
)) as mocked_method:
self._test_happy_path(self.app, mocked_method)
self._test_happy_path(self.app, mocked_method, expires_in)

def test_vm_error_should_be_returned_as_is(self):
raw_error = '{"raw": "error format is undefined"}'
Expand All @@ -110,12 +136,13 @@ def test_vm_error_should_be_returned_as_is(self):
class AppServiceTestCase(ClientTestCase):

def test_happy_path(self):
expires_in = 1234
with patch.object(self.app._http_client, "get", return_value=MinimalResponse(
status_code=200,
text='{"access_token": "AT", "expires_on": "%s", "resource": "R"}' % (
int(time.time()) + 1234),
int(time.time()) + expires_in),
)) as mocked_method:
self._test_happy_path(self.app, mocked_method)
self._test_happy_path(self.app, mocked_method, expires_in)

def test_app_service_error_should_be_normalized(self):
raw_error = '{"statusCode": 500, "message": "error content is undefined"}'
Expand All @@ -134,12 +161,13 @@ def test_app_service_error_should_be_normalized(self):
class MachineLearningTestCase(ClientTestCase):

def test_happy_path(self):
expires_in = 1234
with patch.object(self.app._http_client, "get", return_value=MinimalResponse(
status_code=200,
text='{"access_token": "AT", "expires_on": "%s", "resource": "R"}' % (
int(time.time()) + 1234),
int(time.time()) + expires_in),
)) as mocked_method:
self._test_happy_path(self.app, mocked_method)
self._test_happy_path(self.app, mocked_method, expires_in)

def test_machine_learning_error_should_be_normalized(self):
raw_error = '{"error": "placeholder", "message": "placeholder"}'
Expand All @@ -162,12 +190,14 @@ def test_machine_learning_error_should_be_normalized(self):
class ServiceFabricTestCase(ClientTestCase):

def _test_happy_path(self, app):
expires_in = 1234
with patch.object(app._http_client, "get", return_value=MinimalResponse(
status_code=200,
text='{"access_token": "AT", "expires_on": %s, "resource": "R", "token_type": "Bearer"}' % (
int(time.time()) + 1234),
int(time.time()) + expires_in),
)) as mocked_method:
super(ServiceFabricTestCase, self)._test_happy_path(app, mocked_method)
super(ServiceFabricTestCase, self)._test_happy_path(
app, mocked_method, expires_in)

def test_happy_path(self):
self._test_happy_path(self.app)
Expand Down Expand Up @@ -212,15 +242,16 @@ class ArcTestCase(ClientTestCase):
})

def test_happy_path(self, mocked_stat):
expires_in = 1234
with patch.object(self.app._http_client, "get", side_effect=[
self.challenge,
MinimalResponse(
status_code=200,
text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}',
text='{"access_token": "AT", "expires_in": "%s", "resource": "R"}' % expires_in,
),
]) as mocked_method:
try:
super(ArcTestCase, self)._test_happy_path(self.app, mocked_method)
self._test_happy_path(self.app, mocked_method, expires_in)
mocked_stat.assert_called_with(os.path.join(
_supported_arc_platforms_and_their_prefixes[sys.platform],
"foo.key"))
Expand Down

0 comments on commit 3a4f44f

Please sign in to comment.