26
26
SERVICE_FABRIC ,
27
27
DEFAULT_TO_VM ,
28
28
)
29
+ from msal .token_cache import is_subdict_of
29
30
30
31
31
32
class ManagedIdentityTestCase (unittest .TestCase ):
@@ -60,7 +61,7 @@ def setUp(self):
60
61
http_client = requests .Session (),
61
62
)
62
63
63
- def _test_token_cache (self , app ):
64
+ def assertCacheStatus (self , app ):
64
65
cache = app ._token_cache ._cache
65
66
self .assertEqual (1 , len (cache .get ("AccessToken" , [])), "Should have 1 AT" )
66
67
at = list (cache ["AccessToken" ].values ())[0 ]
@@ -70,30 +71,55 @@ def _test_token_cache(self, app):
70
71
"Should have expected client_id" )
71
72
self .assertEqual ("managed_identity" , at ["realm" ], "Should have expected realm" )
72
73
73
- def _test_happy_path (self , app , mocked_http ):
74
- result = app .acquire_token_for_client (resource = "R" )
74
+ def _test_happy_path (self , app , mocked_http , expires_in , resource = "R" ):
75
+ result = app .acquire_token_for_client (resource = resource )
75
76
mocked_http .assert_called ()
76
- self .assertEqual ({
77
+ call_count = mocked_http .call_count
78
+ expected_result = {
77
79
"access_token" : "AT" ,
78
- "expires_in" : 1234 ,
79
- "resource" : "R" ,
80
80
"token_type" : "Bearer" ,
81
- }, result , "Should obtain a token response" )
81
+ }
82
+ self .assertTrue (
83
+ is_subdict_of (expected_result , result ), # We will test refresh_on later
84
+ "Should obtain a token response" )
85
+ self .assertEqual (expires_in , result ["expires_in" ], "Should have expected expires_in" )
86
+ if expires_in >= 7200 :
87
+ expected_refresh_on = int (time .time () + expires_in / 2 )
88
+ self .assertTrue (
89
+ expected_refresh_on - 1 <= result ["refresh_on" ] <= expected_refresh_on + 1 ,
90
+ "Should have a refresh_on time around the middle of the token's life" )
82
91
self .assertEqual (
83
92
result ["access_token" ],
84
- app .acquire_token_for_client (resource = "R" ).get ("access_token" ),
93
+ app .acquire_token_for_client (resource = resource ).get ("access_token" ),
85
94
"Should hit the same token from cache" )
86
- self ._test_token_cache (app )
95
+
96
+ self .assertCacheStatus (app )
97
+
98
+ result = app .acquire_token_for_client (resource = resource )
99
+ self .assertEqual (
100
+ call_count , mocked_http .call_count ,
101
+ "No new call to the mocked http should be made for a cache hit" )
102
+ self .assertTrue (
103
+ is_subdict_of (expected_result , result ), # We will test refresh_on later
104
+ "Should obtain a token response" )
105
+ self .assertTrue (
106
+ expires_in - 5 < result ["expires_in" ] <= expires_in ,
107
+ "Should have similar expires_in" )
108
+ if expires_in >= 7200 :
109
+ self .assertTrue (
110
+ expected_refresh_on - 5 < result ["refresh_on" ] <= expected_refresh_on ,
111
+ "Should have a refresh_on time around the middle of the token's life" )
87
112
88
113
89
114
class VmTestCase (ClientTestCase ):
90
115
91
116
def test_happy_path (self ):
117
+ expires_in = 7890 # We test a bigger than 7200 value here
92
118
with patch .object (self .app ._http_client , "get" , return_value = MinimalResponse (
93
119
status_code = 200 ,
94
- text = '{"access_token": "AT", "expires_in": "1234 ", "resource": "R"}' ,
120
+ text = '{"access_token": "AT", "expires_in": "%s ", "resource": "R"}' % expires_in ,
95
121
)) as mocked_method :
96
- self ._test_happy_path (self .app , mocked_method )
122
+ self ._test_happy_path (self .app , mocked_method , expires_in )
97
123
98
124
def test_vm_error_should_be_returned_as_is (self ):
99
125
raw_error = '{"raw": "error format is undefined"}'
@@ -110,12 +136,13 @@ def test_vm_error_should_be_returned_as_is(self):
110
136
class AppServiceTestCase (ClientTestCase ):
111
137
112
138
def test_happy_path (self ):
139
+ expires_in = 1234
113
140
with patch .object (self .app ._http_client , "get" , return_value = MinimalResponse (
114
141
status_code = 200 ,
115
142
text = '{"access_token": "AT", "expires_on": "%s", "resource": "R"}' % (
116
- int (time .time ()) + 1234 ),
143
+ int (time .time ()) + expires_in ),
117
144
)) as mocked_method :
118
- self ._test_happy_path (self .app , mocked_method )
145
+ self ._test_happy_path (self .app , mocked_method , expires_in )
119
146
120
147
def test_app_service_error_should_be_normalized (self ):
121
148
raw_error = '{"statusCode": 500, "message": "error content is undefined"}'
@@ -134,12 +161,13 @@ def test_app_service_error_should_be_normalized(self):
134
161
class MachineLearningTestCase (ClientTestCase ):
135
162
136
163
def test_happy_path (self ):
164
+ expires_in = 1234
137
165
with patch .object (self .app ._http_client , "get" , return_value = MinimalResponse (
138
166
status_code = 200 ,
139
167
text = '{"access_token": "AT", "expires_on": "%s", "resource": "R"}' % (
140
- int (time .time ()) + 1234 ),
168
+ int (time .time ()) + expires_in ),
141
169
)) as mocked_method :
142
- self ._test_happy_path (self .app , mocked_method )
170
+ self ._test_happy_path (self .app , mocked_method , expires_in )
143
171
144
172
def test_machine_learning_error_should_be_normalized (self ):
145
173
raw_error = '{"error": "placeholder", "message": "placeholder"}'
@@ -162,12 +190,14 @@ def test_machine_learning_error_should_be_normalized(self):
162
190
class ServiceFabricTestCase (ClientTestCase ):
163
191
164
192
def _test_happy_path (self , app ):
193
+ expires_in = 1234
165
194
with patch .object (app ._http_client , "get" , return_value = MinimalResponse (
166
195
status_code = 200 ,
167
196
text = '{"access_token": "AT", "expires_on": %s, "resource": "R", "token_type": "Bearer"}' % (
168
- int (time .time ()) + 1234 ),
197
+ int (time .time ()) + expires_in ),
169
198
)) as mocked_method :
170
- super (ServiceFabricTestCase , self )._test_happy_path (app , mocked_method )
199
+ super (ServiceFabricTestCase , self )._test_happy_path (
200
+ app , mocked_method , expires_in )
171
201
172
202
def test_happy_path (self ):
173
203
self ._test_happy_path (self .app )
@@ -212,15 +242,16 @@ class ArcTestCase(ClientTestCase):
212
242
})
213
243
214
244
def test_happy_path (self , mocked_stat ):
245
+ expires_in = 1234
215
246
with patch .object (self .app ._http_client , "get" , side_effect = [
216
247
self .challenge ,
217
248
MinimalResponse (
218
249
status_code = 200 ,
219
- text = '{"access_token": "AT", "expires_in": "1234 ", "resource": "R"}' ,
250
+ text = '{"access_token": "AT", "expires_in": "%s ", "resource": "R"}' % expires_in ,
220
251
),
221
252
]) as mocked_method :
222
253
try :
223
- super ( ArcTestCase , self ) ._test_happy_path (self .app , mocked_method )
254
+ self ._test_happy_path (self .app , mocked_method , expires_in )
224
255
mocked_stat .assert_called_with (os .path .join (
225
256
_supported_arc_platforms_and_their_prefixes [sys .platform ],
226
257
"foo.key" ))
0 commit comments