Skip to content

Commit 3a4f44f

Browse files
committed
Expose refresh_on (if any) to fresh or cached response
1 parent 57dce47 commit 3a4f44f

File tree

4 files changed

+86
-29
lines changed

4 files changed

+86
-29
lines changed

msal/application.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,11 +104,14 @@ def _clean_up(result):
104104
"msalruntime_telemetry": result.get("_msalruntime_telemetry"),
105105
"msal_python_telemetry": result.get("_msal_python_telemetry"),
106106
}, separators=(",", ":"))
107-
return {
107+
return_value = {
108108
k: result[k] for k in result
109109
if k != "refresh_in" # MSAL handled refresh_in, customers need not
110110
and not k.startswith('_') # Skim internal properties
111111
}
112+
if "refresh_in" in result: # To encourage proactive refresh
113+
return_value["refresh_on"] = int(time.time() + result["refresh_in"])
114+
return return_value
112115
return result # It could be None
113116

114117

@@ -1507,9 +1510,11 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
15071510
"expires_in": int(expires_in), # OAuth2 specs defines it as int
15081511
self._TOKEN_SOURCE: self._TOKEN_SOURCE_CACHE,
15091512
}
1510-
if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging
1511-
refresh_reason = msal.telemetry.AT_AGING
1512-
break # With a fallback in hand, we break here to go refresh
1513+
if "refresh_on" in entry:
1514+
access_token_from_cache["refresh_on"] = int(entry["refresh_on"])
1515+
if int(entry["refresh_on"]) < now: # aging
1516+
refresh_reason = msal.telemetry.AT_AGING
1517+
break # With a fallback in hand, we break here to go refresh
15131518
self._build_telemetry_context(-1).hit_an_access_token()
15141519
return access_token_from_cache # It is still good as new
15151520
else:

msal/managed_identity.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -273,8 +273,10 @@ def acquire_token_for_client(self, *, resource): # We may support scope in the
273273
"token_type": entry.get("token_type", "Bearer"),
274274
"expires_in": int(expires_in), # OAuth2 specs defines it as int
275275
}
276-
if "refresh_on" in entry and int(entry["refresh_on"]) < now: # aging
277-
break # With a fallback in hand, we break here to go refresh
276+
if "refresh_on" in entry:
277+
access_token_from_cache["refresh_on"] = int(entry["refresh_on"])
278+
if int(entry["refresh_on"]) < now: # aging
279+
break # With a fallback in hand, we break here to go refresh
278280
return access_token_from_cache # It is still good as new
279281
try:
280282
result = _obtain_token(self._http_client, self._managed_identity, resource)
@@ -290,6 +292,8 @@ def acquire_token_for_client(self, *, resource): # We may support scope in the
290292
params={},
291293
data={},
292294
))
295+
if "refresh_in" in result:
296+
result["refresh_on"] = int(now + result["refresh_in"])
293297
if (result and "error" not in result) or (not access_token_from_cache):
294298
return result
295299
except: # The exact HTTP exception is transportation-layer dependent

tests/test_application.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# Note: Since Aug 2019 we move all e2e tests into test_e2e.py,
22
# so this test_application file contains only unit tests without dependency.
33
import sys
4+
import time
45
from msal.application import *
56
from msal.application import _str2bytes
67
import msal
@@ -353,10 +354,18 @@ def populate_cache(self, access_token="at", expires_in=86400, refresh_in=43200):
353354
uid=self.uid, utid=self.utid, refresh_token=self.rt),
354355
})
355356

357+
def assertRefreshOn(self, result, refresh_in):
358+
refresh_on = int(time.time() + refresh_in)
359+
self.assertTrue(
360+
refresh_on - 1 < result.get("refresh_on", 0) < refresh_on + 1,
361+
"refresh_on should be set properly")
362+
356363
def test_fresh_token_should_be_returned_from_cache(self):
357364
# a.k.a. Return unexpired token that is not above token refresh expiration threshold
365+
refresh_in = 450
358366
access_token = "An access token prepopulated into cache"
359-
self.populate_cache(access_token=access_token, expires_in=900, refresh_in=450)
367+
self.populate_cache(
368+
access_token=access_token, expires_in=900, refresh_in=refresh_in)
360369
result = self.app.acquire_token_silent(
361370
['s1'], self.account,
362371
post=lambda url, *args, **kwargs: # Utilize the undocumented test feature
@@ -365,32 +374,38 @@ def test_fresh_token_should_be_returned_from_cache(self):
365374
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_CACHE)
366375
self.assertEqual(access_token, result.get("access_token"))
367376
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")
377+
self.assertRefreshOn(result, refresh_in)
368378

369379
def test_aging_token_and_available_aad_should_return_new_token(self):
370380
# a.k.a. Attempt to refresh unexpired token when AAD available
371381
self.populate_cache(access_token="old AT", expires_in=3599, refresh_in=-1)
372382
new_access_token = "new AT"
383+
new_refresh_in = 123
373384
def mock_post(url, headers=None, *args, **kwargs):
374385
self.assertEqual("4|84,4|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
375386
return MinimalResponse(status_code=200, text=json.dumps({
376387
"access_token": new_access_token,
377-
"refresh_in": 123,
388+
"refresh_in": new_refresh_in,
378389
}))
379390
result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post)
380391
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
381392
self.assertEqual(new_access_token, result.get("access_token"))
382393
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")
394+
self.assertRefreshOn(result, new_refresh_in)
383395

384396
def test_aging_token_and_unavailable_aad_should_return_old_token(self):
385397
# a.k.a. Attempt refresh unexpired token when AAD unavailable
398+
refresh_in = -1
386399
old_at = "old AT"
387-
self.populate_cache(access_token=old_at, expires_in=3599, refresh_in=-1)
400+
self.populate_cache(
401+
access_token=old_at, expires_in=3599, refresh_in=refresh_in)
388402
def mock_post(url, headers=None, *args, **kwargs):
389403
self.assertEqual("4|84,4|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
390404
return MinimalResponse(status_code=400, text=json.dumps({"error": "foo"}))
391405
result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post)
392406
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_CACHE)
393407
self.assertEqual(old_at, result.get("access_token"))
408+
self.assertRefreshOn(result, refresh_in)
394409

395410
def test_expired_token_and_unavailable_aad_should_return_error(self):
396411
# a.k.a. Attempt refresh expired token when AAD unavailable
@@ -407,16 +422,18 @@ def test_expired_token_and_available_aad_should_return_new_token(self):
407422
# a.k.a. Attempt refresh expired token when AAD available
408423
self.populate_cache(access_token="expired at", expires_in=-1, refresh_in=-900)
409424
new_access_token = "new AT"
425+
new_refresh_in = 123
410426
def mock_post(url, headers=None, *args, **kwargs):
411427
self.assertEqual("4|84,3|", (headers or {}).get(CLIENT_CURRENT_TELEMETRY))
412428
return MinimalResponse(status_code=200, text=json.dumps({
413429
"access_token": new_access_token,
414-
"refresh_in": 123,
430+
"refresh_in": new_refresh_in,
415431
}))
416432
result = self.app.acquire_token_silent(['s1'], self.account, post=mock_post)
417433
self.assertEqual(result[self.app._TOKEN_SOURCE], self.app._TOKEN_SOURCE_IDP)
418434
self.assertEqual(new_access_token, result.get("access_token"))
419435
self.assertNotIn("refresh_in", result, "Customers need not know refresh_in")
436+
self.assertRefreshOn(result, new_refresh_in)
420437

421438

422439
class TestTelemetryMaintainingOfflineState(unittest.TestCase):

tests/test_mi.py

Lines changed: 50 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
SERVICE_FABRIC,
2727
DEFAULT_TO_VM,
2828
)
29+
from msal.token_cache import is_subdict_of
2930

3031

3132
class ManagedIdentityTestCase(unittest.TestCase):
@@ -60,7 +61,7 @@ def setUp(self):
6061
http_client=requests.Session(),
6162
)
6263

63-
def _test_token_cache(self, app):
64+
def assertCacheStatus(self, app):
6465
cache = app._token_cache._cache
6566
self.assertEqual(1, len(cache.get("AccessToken", [])), "Should have 1 AT")
6667
at = list(cache["AccessToken"].values())[0]
@@ -70,30 +71,55 @@ def _test_token_cache(self, app):
7071
"Should have expected client_id")
7172
self.assertEqual("managed_identity", at["realm"], "Should have expected realm")
7273

73-
def _test_happy_path(self, app, mocked_http):
74-
result = app.acquire_token_for_client(resource="R")
74+
def _test_happy_path(self, app, mocked_http, expires_in, resource="R"):
75+
result = app.acquire_token_for_client(resource=resource)
7576
mocked_http.assert_called()
76-
self.assertEqual({
77+
call_count = mocked_http.call_count
78+
expected_result = {
7779
"access_token": "AT",
78-
"expires_in": 1234,
79-
"resource": "R",
8080
"token_type": "Bearer",
81-
}, result, "Should obtain a token response")
81+
}
82+
self.assertTrue(
83+
is_subdict_of(expected_result, result), # We will test refresh_on later
84+
"Should obtain a token response")
85+
self.assertEqual(expires_in, result["expires_in"], "Should have expected expires_in")
86+
if expires_in >= 7200:
87+
expected_refresh_on = int(time.time() + expires_in / 2)
88+
self.assertTrue(
89+
expected_refresh_on - 1 <= result["refresh_on"] <= expected_refresh_on + 1,
90+
"Should have a refresh_on time around the middle of the token's life")
8291
self.assertEqual(
8392
result["access_token"],
84-
app.acquire_token_for_client(resource="R").get("access_token"),
93+
app.acquire_token_for_client(resource=resource).get("access_token"),
8594
"Should hit the same token from cache")
86-
self._test_token_cache(app)
95+
96+
self.assertCacheStatus(app)
97+
98+
result = app.acquire_token_for_client(resource=resource)
99+
self.assertEqual(
100+
call_count, mocked_http.call_count,
101+
"No new call to the mocked http should be made for a cache hit")
102+
self.assertTrue(
103+
is_subdict_of(expected_result, result), # We will test refresh_on later
104+
"Should obtain a token response")
105+
self.assertTrue(
106+
expires_in - 5 < result["expires_in"] <= expires_in,
107+
"Should have similar expires_in")
108+
if expires_in >= 7200:
109+
self.assertTrue(
110+
expected_refresh_on - 5 < result["refresh_on"] <= expected_refresh_on,
111+
"Should have a refresh_on time around the middle of the token's life")
87112

88113

89114
class VmTestCase(ClientTestCase):
90115

91116
def test_happy_path(self):
117+
expires_in = 7890 # We test a bigger than 7200 value here
92118
with patch.object(self.app._http_client, "get", return_value=MinimalResponse(
93119
status_code=200,
94-
text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}',
120+
text='{"access_token": "AT", "expires_in": "%s", "resource": "R"}' % expires_in,
95121
)) as mocked_method:
96-
self._test_happy_path(self.app, mocked_method)
122+
self._test_happy_path(self.app, mocked_method, expires_in)
97123

98124
def test_vm_error_should_be_returned_as_is(self):
99125
raw_error = '{"raw": "error format is undefined"}'
@@ -110,12 +136,13 @@ def test_vm_error_should_be_returned_as_is(self):
110136
class AppServiceTestCase(ClientTestCase):
111137

112138
def test_happy_path(self):
139+
expires_in = 1234
113140
with patch.object(self.app._http_client, "get", return_value=MinimalResponse(
114141
status_code=200,
115142
text='{"access_token": "AT", "expires_on": "%s", "resource": "R"}' % (
116-
int(time.time()) + 1234),
143+
int(time.time()) + expires_in),
117144
)) as mocked_method:
118-
self._test_happy_path(self.app, mocked_method)
145+
self._test_happy_path(self.app, mocked_method, expires_in)
119146

120147
def test_app_service_error_should_be_normalized(self):
121148
raw_error = '{"statusCode": 500, "message": "error content is undefined"}'
@@ -134,12 +161,13 @@ def test_app_service_error_should_be_normalized(self):
134161
class MachineLearningTestCase(ClientTestCase):
135162

136163
def test_happy_path(self):
164+
expires_in = 1234
137165
with patch.object(self.app._http_client, "get", return_value=MinimalResponse(
138166
status_code=200,
139167
text='{"access_token": "AT", "expires_on": "%s", "resource": "R"}' % (
140-
int(time.time()) + 1234),
168+
int(time.time()) + expires_in),
141169
)) as mocked_method:
142-
self._test_happy_path(self.app, mocked_method)
170+
self._test_happy_path(self.app, mocked_method, expires_in)
143171

144172
def test_machine_learning_error_should_be_normalized(self):
145173
raw_error = '{"error": "placeholder", "message": "placeholder"}'
@@ -162,12 +190,14 @@ def test_machine_learning_error_should_be_normalized(self):
162190
class ServiceFabricTestCase(ClientTestCase):
163191

164192
def _test_happy_path(self, app):
193+
expires_in = 1234
165194
with patch.object(app._http_client, "get", return_value=MinimalResponse(
166195
status_code=200,
167196
text='{"access_token": "AT", "expires_on": %s, "resource": "R", "token_type": "Bearer"}' % (
168-
int(time.time()) + 1234),
197+
int(time.time()) + expires_in),
169198
)) as mocked_method:
170-
super(ServiceFabricTestCase, self)._test_happy_path(app, mocked_method)
199+
super(ServiceFabricTestCase, self)._test_happy_path(
200+
app, mocked_method, expires_in)
171201

172202
def test_happy_path(self):
173203
self._test_happy_path(self.app)
@@ -212,15 +242,16 @@ class ArcTestCase(ClientTestCase):
212242
})
213243

214244
def test_happy_path(self, mocked_stat):
245+
expires_in = 1234
215246
with patch.object(self.app._http_client, "get", side_effect=[
216247
self.challenge,
217248
MinimalResponse(
218249
status_code=200,
219-
text='{"access_token": "AT", "expires_in": "1234", "resource": "R"}',
250+
text='{"access_token": "AT", "expires_in": "%s", "resource": "R"}' % expires_in,
220251
),
221252
]) as mocked_method:
222253
try:
223-
super(ArcTestCase, self)._test_happy_path(self.app, mocked_method)
254+
self._test_happy_path(self.app, mocked_method, expires_in)
224255
mocked_stat.assert_called_with(os.path.join(
225256
_supported_arc_platforms_and_their_prefixes[sys.platform],
226257
"foo.key"))

0 commit comments

Comments
 (0)