diff --git a/msal/application.py b/msal/application.py index 87d8347..a17d359 100644 --- a/msal/application.py +++ b/msal/application.py @@ -1,3 +1,5 @@ +import functools +import json import time try: # Python 2 from urlparse import urljoin @@ -54,11 +56,11 @@ def decorate_scope( CLIENT_CURRENT_TELEMETRY = 'x-client-current-telemetry' def _get_new_correlation_id(): - return str(uuid.uuid4()) + return str(uuid.uuid4()) def _build_current_telemetry_request_header(public_api_id, force_refresh=False): - return "1|{},{}|".format(public_api_id, "1" if force_refresh else "0") + return "1|{},{}|".format(public_api_id, "1" if force_refresh else "0") def extract_certs(public_cert_content): @@ -92,6 +94,7 @@ def __init__( self, client_id, client_credential=None, authority=None, validate_authority=True, token_cache=None, + http_client=None, verify=True, proxies=None, timeout=None, client_claims=None, app_name=None, app_version=None): """Create an instance of application. @@ -151,18 +154,24 @@ def __init__( :param TokenCache cache: Sets the token cache used by this ClientApplication instance. By default, an in-memory cache will be created and used. + :param http_client: (optional) + Your implementation of abstract class HttpClient + Defaults to a requests session instance :param verify: (optional) It will be passed to the `verify parameter in the underlying requests library `_ + This does not apply if you have chosen to pass your own Http client :param proxies: (optional) It will be passed to the `proxies parameter in the underlying requests library `_ + This does not apply if you have chosen to pass your own Http client :param timeout: (optional) It will be passed to the `timeout parameter in the underlying requests library `_ + This does not apply if you have chosen to pass your own Http client :param app_name: (optional) You can provide your application name for Microsoft telemetry purposes. Default value is None, means it will not be passed to Microsoft. @@ -173,14 +182,21 @@ def __init__( self.client_id = client_id self.client_credential = client_credential self.client_claims = client_claims - self.verify = verify - self.proxies = proxies - self.timeout = timeout + if http_client: + self.http_client = http_client + else: + self.http_client = requests.Session() + self.http_client.verify = verify + self.http_client.proxies = proxies + # Requests, does not support session - wide timeout + # But you can patch that (https://github.com/psf/requests/issues/3341): + self.http_client.request = functools.partial( + self.http_client.request, timeout=timeout) self.app_name = app_name self.app_version = app_version self.authority = Authority( authority or "https://login.microsoftonline.com/common/", - validate_authority, verify=verify, proxies=proxies, timeout=timeout) + self.http_client, validate_authority=validate_authority) # Here the self.authority is not the same type as authority in input self.token_cache = token_cache or TokenCache() self.client = self._build_client(client_credential, self.authority) @@ -223,14 +239,14 @@ def _build_client(self, client_credential, authority): return Client( server_configuration, self.client_id, + http_client=self.http_client, default_headers=default_headers, default_body=default_body, client_assertion=client_assertion, client_assertion_type=client_assertion_type, on_obtaining_tokens=self.token_cache.add, on_removing_rt=self.token_cache.remove_rt, - on_updating_rt=self.token_cache.update_rt, - verify=self.verify, proxies=self.proxies, timeout=self.timeout) + on_updating_rt=self.token_cache.update_rt) def get_authorization_request_url( self, @@ -288,12 +304,13 @@ def get_authorization_request_url( # Multi-tenant app can use new authority on demand the_authority = Authority( authority, - verify=self.verify, proxies=self.proxies, timeout=self.timeout, + self.http_client ) if authority else self.authority client = Client( {"authorization_endpoint": the_authority.authorization_endpoint}, - self.client_id) + self.client_id, + http_client=self.http_client) return client.build_auth_request_uri( response_type=response_type, redirect_uri=redirect_uri, state=state, login_hint=login_hint, @@ -399,13 +416,12 @@ def _find_msal_accounts(self, environment): def _get_authority_aliases(self, instance): if not self.authority_groups: - resp = requests.get( + resp = self.http_client.get( "https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize", - headers={'Accept': 'application/json'}, - verify=self.verify, proxies=self.proxies, timeout=self.timeout) + headers={'Accept': 'application/json'}) resp.raise_for_status() self.authority_groups = [ - set(group['aliases']) for group in resp.json()['metadata']] + set(group['aliases']) for group in json.loads(resp.text)['metadata']] for group in self.authority_groups: if instance in group: return [alias for alias in group if alias != instance] @@ -524,7 +540,7 @@ def acquire_token_silent_with_error( warnings.warn("We haven't decided how/if this method will accept authority parameter") # the_authority = Authority( # authority, - # verify=self.verify, proxies=self.proxies, timeout=self.timeout, + # self.http_client, # ) if authority else self.authority result = self._acquire_token_silent_from_cache_and_possibly_refresh_it( scopes, account, self.authority, force_refresh=force_refresh, @@ -536,8 +552,8 @@ def acquire_token_silent_with_error( for alias in self._get_authority_aliases(self.authority.instance): the_authority = Authority( "https://" + alias + "/" + self.authority.tenant, - validate_authority=False, - verify=self.verify, proxies=self.proxies, timeout=self.timeout) + self.http_client, + validate_authority=False) result = self._acquire_token_silent_from_cache_and_possibly_refresh_it( scopes, account, the_authority, force_refresh=force_refresh, correlation_id=correlation_id, @@ -780,13 +796,11 @@ def acquire_token_by_username_password( def _acquire_token_by_username_password_federated( self, user_realm_result, username, password, scopes=None, **kwargs): - verify = kwargs.pop("verify", self.verify) - proxies = kwargs.pop("proxies", self.proxies) wstrust_endpoint = {} if user_realm_result.get("federation_metadata_url"): wstrust_endpoint = mex_send_request( user_realm_result["federation_metadata_url"], - verify=verify, proxies=proxies) + self.http_client) if wstrust_endpoint is None: raise ValueError("Unable to find wstrust endpoint from MEX. " "This typically happens when attempting MSA accounts. " @@ -798,7 +812,7 @@ def _acquire_token_by_username_password_federated( wstrust_endpoint.get("address", # Fallback to an AAD supplied endpoint user_realm_result.get("federation_active_auth_url")), - wstrust_endpoint.get("action"), verify=verify, proxies=proxies) + wstrust_endpoint.get("action"), self.http_client) if not ("token" in wstrust_result and "type" in wstrust_result): raise RuntimeError("Unsuccessful RSTR. %s" % wstrust_result) GRANT_TYPE_SAML1_1 = 'urn:ietf:params:oauth:grant-type:saml1_1-bearer' diff --git a/msal/authority.py b/msal/authority.py index d8221ec..94caaab 100644 --- a/msal/authority.py +++ b/msal/authority.py @@ -1,11 +1,10 @@ +import json try: from urllib.parse import urlparse except ImportError: # Fall back to Python 2 from urlparse import urlparse import logging -import requests - from .exceptions import MsalServiceError @@ -25,6 +24,7 @@ "b2clogin.de", ] + class Authority(object): """This class represents an (already-validated) authority. @@ -33,9 +33,7 @@ class Authority(object): """ _domains_without_user_realm_discovery = set([]) - def __init__(self, authority_url, validate_authority=True, - verify=True, proxies=None, timeout=None, - ): + def __init__(self, authority_url, http_client, validate_authority=True): """Creates an authority instance, and also validates it. :param validate_authority: @@ -44,9 +42,7 @@ def __init__(self, authority_url, validate_authority=True, This parameter only controls whether an instance discovery will be performed. """ - self.verify = verify - self.proxies = proxies - self.timeout = timeout + self.http_client = http_client authority, self.instance, tenant = canonicalize(authority_url) parts = authority.path.split('/') is_b2c = any(self.instance.endswith("." + d) for d in WELL_KNOWN_B2C_HOSTS) or ( @@ -56,7 +52,7 @@ def __init__(self, authority_url, validate_authority=True, payload = instance_discovery( "https://{}{}/oauth2/v2.0/authorize".format( self.instance, authority.path), - verify=verify, proxies=proxies, timeout=timeout) + self.http_client) if payload.get("error") == "invalid_instance": raise ValueError( "invalid_instance: " @@ -75,7 +71,7 @@ def __init__(self, authority_url, validate_authority=True, )) openid_config = tenant_discovery( tenant_discovery_endpoint, - verify=verify, proxies=proxies, timeout=timeout) + self.http_client) logger.debug("openid_config = %s", openid_config) self.authorization_endpoint = openid_config['authorization_endpoint'] self.token_endpoint = openid_config['token_endpoint'] @@ -87,15 +83,14 @@ def user_realm_discovery(self, username, correlation_id=None, response=None): # "federation_protocol", "cloud_audience_urn", # "federation_metadata_url", "federation_active_auth_url", etc. if self.instance not in self.__class__._domains_without_user_realm_discovery: - resp = response or requests.get( + resp = response or self.http_client.get( "https://{netloc}/common/userrealm/{username}?api-version=1.0".format( netloc=self.instance, username=username), - headers={'Accept':'application/json', - 'client-request-id': correlation_id}, - verify=self.verify, proxies=self.proxies, timeout=self.timeout) + headers={'Accept': 'application/json', + 'client-request-id': correlation_id},) if resp.status_code != 404: resp.raise_for_status() - return resp.json() + return json.loads(resp.text) self.__class__._domains_without_user_realm_discovery.add(self.instance) return {} # This can guide the caller to fall back normal ROPC flow @@ -113,20 +108,21 @@ def canonicalize(authority_url): % authority_url) return authority, authority.hostname, parts[1] -def instance_discovery(url, **kwargs): - return requests.get( # Note: This URL seemingly returns V1 endpoint only +def instance_discovery(url, http_client, **kwargs): + resp = http_client.get( # Note: This URL seemingly returns V1 endpoint only 'https://{}/common/discovery/instance'.format( WORLD_WIDE # Historically using WORLD_WIDE. Could use self.instance too # See https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadInstanceDiscovery.cs#L101-L103 # and https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadAuthority.cs#L19-L33 ), params={'authorization_endpoint': url, 'api-version': '1.0'}, - **kwargs).json() + **kwargs) + return json.loads(resp.text) -def tenant_discovery(tenant_discovery_endpoint, **kwargs): +def tenant_discovery(tenant_discovery_endpoint, http_client, **kwargs): # Returns Openid Configuration - resp = requests.get(tenant_discovery_endpoint, **kwargs) - payload = resp.json() + resp = http_client.get(tenant_discovery_endpoint, **kwargs) + payload = json.loads(resp.text) if 'authorization_endpoint' in payload and 'token_endpoint' in payload: return payload raise MsalServiceError(status_code=resp.status_code, **payload) diff --git a/msal/mex.py b/msal/mex.py index caf5e3e..684d50e 100644 --- a/msal/mex.py +++ b/msal/mex.py @@ -34,15 +34,14 @@ except ImportError: from xml.etree import ElementTree as ET -import requests - def _xpath_of_root(route_to_leaf): # Construct an xpath suitable to find a root node which has a specified leaf return '/'.join(route_to_leaf + ['..'] * (len(route_to_leaf)-1)) -def send_request(mex_endpoint, **kwargs): - mex_document = requests.get( + +def send_request(mex_endpoint, http_client, **kwargs): + mex_document = http_client.get( mex_endpoint, headers={'Content-Type': 'application/soap+xml'}, **kwargs).text return Mex(mex_document).get_wstrust_username_password_endpoint() diff --git a/msal/oauth2cli/oauth2.py b/msal/oauth2cli/oauth2.py index 9a94739..fac35f1 100644 --- a/msal/oauth2cli/oauth2.py +++ b/msal/oauth2cli/oauth2.py @@ -1,6 +1,7 @@ """This OAuth2 client implementation aims to be spec-compliant, and generic.""" # OAuth2 spec https://tools.ietf.org/html/rfc6749 +import json try: from urllib.parse import urlencode, parse_qs except ImportError: @@ -11,6 +12,7 @@ import time import base64 import sys +import functools import requests @@ -35,6 +37,7 @@ def __init__( self, server_configuration, # type: dict client_id, # type: str + http_client=None, # We insert it here to match the upcoming async API client_secret=None, # type: Optional[str] client_assertion=None, # type: Union[bytes, callable, None] client_assertion_type=None, # type: Optional[str] @@ -57,6 +60,9 @@ def __init__( or https://example.com/.../.well-known/openid-configuration client_id (str): The client's id, issued by the authorization server + http_client (http.HttpClient): + Your implementation of abstract class :class:`http.HttpClient`. + Defaults to a requests session instance. client_secret (str): Triggers HTTP AUTH for Confidential Client client_assertion (bytes, callable): The client assertion to authenticate this client, per RFC 7521. @@ -76,20 +82,51 @@ def __init__( you could choose to set this as {"client_secret": "your secret"} if your authorization server wants it to be in the request body (rather than in the request header). + + verify (boolean): + It will be passed to the + `verify parameter in the underlying requests library + `_ + This does not apply if you have chosen to pass your own Http client. + proxies (dict): + It will be passed to the + `proxies parameter in the underlying requests library + `_ + This does not apply if you have chosen to pass your own Http client. + timeout (object): + It will be passed to the + `timeout parameter in the underlying requests library + `_ + This does not apply if you have chosen to pass your own Http client. + + There is no session-wide `timeout` parameter defined here. + The timeout behavior is determined by the actual http client you use. + If you happen to use Requests, it chose to not support session-wide timeout + (https://github.com/psf/requests/issues/3341), but you can patch that by: + + s = requests.Session() + s.request = functools.partial(s.request, timeout=3) + + and then feed that patched session instance to this class. """ self.configuration = server_configuration self.client_id = client_id self.client_secret = client_secret self.client_assertion = client_assertion + self.default_headers = default_headers or {} self.default_body = default_body or {} if client_assertion_type is not None: self.default_body["client_assertion_type"] = client_assertion_type self.logger = logging.getLogger(__name__) - self.session = s = requests.Session() - s.headers.update(default_headers or {}) - s.verify = verify - s.proxies = proxies or {} - self.timeout = timeout + if http_client: + self.http_client = http_client + else: + self.http_client = requests.Session() + self.http_client.verify = verify + self.http_client.proxies = proxies + self.http_client.request = functools.partial( + # A workaround for requests not supporting session-wide timeout + self.http_client.request, timeout=timeout) def _build_auth_request_params(self, response_type, **kwargs): # response_type is a string defined in @@ -110,7 +147,6 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 params=None, # a dict to be sent as query string to the endpoint data=None, # All relevant data, which will go into the http body headers=None, # a dict to be sent as request headers - timeout=None, post=None, # A callable to replace requests.post(), for testing. # Such as: lambda url, **kwargs: # Mock(status_code=200, json=Mock(return_value={})) @@ -128,11 +164,15 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 _data.update(self.default_body) # It may contain authen parameters _data.update(data or {}) # So the content in data param prevails - # We don't have to clean up None values here, because requests lib will. + _data = {k: v for k, v in _data.items() if v} # Clean up None values if _data.get('scope'): _data['scope'] = self._stringify(_data['scope']) + _headers = {'Accept': 'application/json'} + _headers.update(self.default_headers) + _headers.update(headers or {}) + # Quoted from https://tools.ietf.org/html/rfc6749#section-2.3.1 # Clients in possession of a client password MAY use the HTTP Basic # authentication. @@ -140,18 +180,16 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 # the authorization server MAY support including the # client credentials in the request-body using the following # parameters: client_id, client_secret. - auth = None if self.client_secret and self.client_id: - auth = (self.client_id, self.client_secret) # for HTTP Basic Auth + _headers["Authorization"] = "Basic " + base64.b64encode( + "{}:{}".format(self.client_id, self.client_secret) + .encode("ascii")).decode("ascii") if "token_endpoint" not in self.configuration: raise ValueError("token_endpoint not found in configuration") - _headers = {'Accept': 'application/json'} - _headers.update(headers or {}) - resp = (post or self.session.post)( + resp = (post or self.http_client.post)( self.configuration["token_endpoint"], - headers=_headers, params=params, data=_data, auth=auth, - timeout=timeout or self.timeout, + headers=_headers, params=params, data=_data, **kwargs) if resp.status_code >= 500: resp.raise_for_status() # TODO: Will probably retry here @@ -159,7 +197,7 @@ def _obtain_token( # The verb "obtain" is influenced by OAUTH2 RFC 6749 # The spec (https://tools.ietf.org/html/rfc6749#section-5.2) says # even an error response will be a valid json structure, # so we simply return it here, without needing to invent an exception. - return resp.json() + return json.loads(resp.text) except ValueError: self.logger.exception( "Token response is not in json format: %s", resp.text) @@ -200,7 +238,7 @@ class Client(BaseClient): # We choose to implement all 4 grants in 1 class grant_assertion_encoders = {GRANT_TYPE_SAML2: BaseClient.encode_saml_assertion} - def initiate_device_flow(self, scope=None, timeout=None, **kwargs): + def initiate_device_flow(self, scope=None, **kwargs): # type: (list, **dict) -> dict # The naming of this method is following the wording of this specs # https://tools.ietf.org/html/draft-ietf-oauth-device-flow-12#section-3.1 @@ -218,10 +256,11 @@ def initiate_device_flow(self, scope=None, timeout=None, **kwargs): DAE = "device_authorization_endpoint" if not self.configuration.get(DAE): raise ValueError("You need to provide device authorization endpoint") - flow = self.session.post(self.configuration[DAE], + resp = self.http_client.post(self.configuration[DAE], data={"client_id": self.client_id, "scope": self._stringify(scope or [])}, - timeout=timeout or self.timeout, - **kwargs).json() + headers=dict(self.default_headers, **kwargs.pop("headers", {})), + **kwargs) + flow = json.loads(resp.text) flow["interval"] = int(flow.get("interval", 5)) # Some IdP returns string flow["expires_in"] = int(flow.get("expires_in", 1800)) flow["expires_at"] = time.time() + flow["expires_in"] # We invent this diff --git a/msal/wstrust_request.py b/msal/wstrust_request.py index 84c0384..b2898f7 100644 --- a/msal/wstrust_request.py +++ b/msal/wstrust_request.py @@ -29,16 +29,13 @@ from datetime import datetime, timedelta import logging -import requests - from .mex import Mex from .wstrust_response import parse_response - logger = logging.getLogger(__name__) def send_request( - username, password, cloud_audience_urn, endpoint_address, soap_action, + username, password, cloud_audience_urn, endpoint_address, soap_action, http_client, **kwargs): if not endpoint_address: raise ValueError("WsTrust endpoint address can not be empty") @@ -51,7 +48,7 @@ def send_request( "Unsupported soap action: %s" % soap_action) data = _build_rst( username, password, cloud_audience_urn, endpoint_address, soap_action) - resp = requests.post(endpoint_address, data=data, headers={ + resp = http_client.post(endpoint_address, data=data, headers={ 'Content-type':'application/soap+xml; charset=utf-8', 'SOAPAction': soap_action, }, **kwargs) @@ -61,11 +58,13 @@ def send_request( # resp.raise_for_status() return parse_response(resp.text) + def escape_password(password): return (password.replace('&', '&').replace('"', '"') .replace("'", ''') # the only one not provided by cgi.escape(s, True) .replace('<', '<').replace('>', '>')) + def wsu_time_format(datetime_obj): # WsTrust (http://docs.oasis-open.org/ws-sx/ws-trust/v1.4/ws-trust.html) # does not seem to define timestamp format, but we see YYYY-mm-ddTHH:MM:SSZ @@ -74,6 +73,7 @@ def wsu_time_format(datetime_obj): # https://docs.python.org/2/library/datetime.html#datetime.datetime.isoformat return datetime_obj.strftime('%Y-%m-%dT%H:%M:%SZ') + def _build_rst(username, password, cloud_audience_urn, endpoint_address, soap_action): now = datetime.utcnow() return """ diff --git a/requirements.txt b/requirements.txt index 61a6510..9c558e3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1 @@ . -mock; python_version < '3.3' diff --git a/tests/http_client.py b/tests/http_client.py new file mode 100644 index 0000000..4bff9b4 --- /dev/null +++ b/tests/http_client.py @@ -0,0 +1,30 @@ +import requests + + +class MinimalHttpClient: + + def __init__(self, verify=True, proxies=None, timeout=None): + self.session = requests.Session() + self.session.verify = verify + self.session.proxies = proxies + self.timeout = timeout + + def post(self, url, params=None, data=None, headers=None, **kwargs): + return MinimalResponse(requests_resp=self.session.post( + url, params=params, data=data, headers=headers, + timeout=self.timeout)) + + def get(self, url, params=None, headers=None, **kwargs): + return MinimalResponse(requests_resp=self.session.get( + url, params=params, headers=headers, timeout=self.timeout)) + + +class MinimalResponse(object): # Not for production use + def __init__(self, requests_resp=None, status_code=None, text=None): + self.status_code = status_code or requests_resp.status_code + self.text = text or requests_resp.text + self._raw_resp = requests_resp + + def raise_for_status(self): + if self._raw_resp: + self._raw_resp.raise_for_status() diff --git a/tests/test_application.py b/tests/test_application.py index 4d7c288..39becd5 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -1,18 +1,10 @@ # Note: Since Aug 2019 we move all e2e tests into test_e2e.py, # so this test_application file contains only unit tests without dependency. -import os -import json -import logging - -try: - from unittest.mock import * # Python 3 -except: - from mock import * # Need an external mock package - from msal.application import * import msal from tests import unittest from tests.test_token_cache import TokenCacheTestCase +from tests.http_client import MinimalHttpClient, MinimalResponse logger = logging.getLogger(__name__) @@ -50,7 +42,8 @@ class TestClientApplicationAcquireTokenSilentErrorBehaviors(unittest.TestCase): def setUp(self): self.authority_url = "https://login.microsoftonline.com/common" - self.authority = msal.authority.Authority(self.authority_url) + self.authority = msal.authority.Authority( + self.authority_url, MinimalHttpClient()) self.scopes = ["s1", "s2"] self.uid = "my_uid" self.utid = "my_utid" @@ -76,31 +69,31 @@ def test_cache_empty_will_be_returned_as_None(self): None, self.app.acquire_token_silent_with_error(['cache_miss'], self.account)) def test_acquire_token_silent_will_suppress_error(self): - error_response = {"error": "invalid_grant", "suberror": "xyz"} + error_response = '{"error": "invalid_grant", "suberror": "xyz"}' def tester(url, **kwargs): - return Mock(status_code=400, json=Mock(return_value=error_response)) + return MinimalResponse(status_code=400, text=error_response) self.assertEqual(None, self.app.acquire_token_silent( self.scopes, self.account, post=tester)) def test_acquire_token_silent_with_error_will_return_error(self): - error_response = {"error": "invalid_grant", "error_description": "xyz"} + error_response = '{"error": "invalid_grant", "error_description": "xyz"}' def tester(url, **kwargs): - return Mock(status_code=400, json=Mock(return_value=error_response)) - self.assertEqual(error_response, self.app.acquire_token_silent_with_error( + return MinimalResponse(status_code=400, text=error_response) + self.assertEqual(json.loads(error_response), self.app.acquire_token_silent_with_error( self.scopes, self.account, post=tester)) def test_atswe_will_map_some_suberror_to_classification_as_is(self): - error_response = {"error": "invalid_grant", "suberror": "basic_action"} + error_response = '{"error": "invalid_grant", "suberror": "basic_action"}' def tester(url, **kwargs): - return Mock(status_code=400, json=Mock(return_value=error_response)) + return MinimalResponse(status_code=400, text=error_response) result = self.app.acquire_token_silent_with_error( self.scopes, self.account, post=tester) self.assertEqual("basic_action", result.get("classification")) def test_atswe_will_map_some_suberror_to_classification_to_empty_string(self): - error_response = {"error": "invalid_grant", "suberror": "client_mismatch"} + error_response = '{"error": "invalid_grant", "suberror": "client_mismatch"}' def tester(url, **kwargs): - return Mock(status_code=400, json=Mock(return_value=error_response)) + return MinimalResponse(status_code=400, text=error_response) result = self.app.acquire_token_silent_with_error( self.scopes, self.account, post=tester) self.assertEqual("", result.get("classification")) @@ -109,7 +102,8 @@ class TestClientApplicationAcquireTokenSilentFociBehaviors(unittest.TestCase): def setUp(self): self.authority_url = "https://login.microsoftonline.com/common" - self.authority = msal.authority.Authority(self.authority_url) + self.authority = msal.authority.Authority( + self.authority_url, MinimalHttpClient()) self.scopes = ["s1", "s2"] self.uid = "my_uid" self.utid = "my_utid" @@ -131,11 +125,10 @@ def test_unknown_orphan_app_will_attempt_frt_and_not_remove_it(self): app = ClientApplication( "unknown_orphan", authority=self.authority_url, token_cache=self.cache) logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) + error_response = '{"error": "invalid_grant","error_description": "Was issued to another client"}' def tester(url, data=None, **kwargs): self.assertEqual(self.frt, data.get("refresh_token"), "Should attempt the FRT") - return Mock(status_code=400, json=Mock(return_value={ - "error": "invalid_grant", - "error_description": "Was issued to another client"})) + return MinimalResponse(status_code=400, text=error_response) app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( self.authority, self.scopes, self.account, post=tester) self.assertNotEqual([], app.token_cache.find( @@ -156,7 +149,7 @@ def test_known_orphan_app_will_skip_frt_and_only_use_its_own_rt(self): logger.debug("%s.cache = %s", self.id(), self.cache.serialize()) def tester(url, data=None, **kwargs): self.assertEqual(rt, data.get("refresh_token"), "Should attempt the RT") - return Mock(status_code=200, json=Mock(return_value={})) + return MinimalResponse(status_code=200, text='{}') app._acquire_token_silent_by_finding_rt_belongs_to_me_or_my_family( self.authority, self.scopes, self.account, post=tester) @@ -164,9 +157,8 @@ def test_unknown_family_app_will_attempt_frt_and_join_family(self): def tester(url, data=None, **kwargs): self.assertEqual( self.frt, data.get("refresh_token"), "Should attempt the FRT") - return Mock( - status_code=200, - json=Mock(return_value=TokenCacheTestCase.build_response( + return MinimalResponse( + status_code=200, text=json.dumps(TokenCacheTestCase.build_response( uid=self.uid, utid=self.utid, foci="1", access_token="at"))) app = ClientApplication( "unknown_family_app", authority=self.authority_url, token_cache=self.cache) diff --git a/tests/test_authority.py b/tests/test_authority.py index d1e75ef..15a0eb5 100644 --- a/tests/test_authority.py +++ b/tests/test_authority.py @@ -1,8 +1,8 @@ import os from msal.authority import * -from msal.exceptions import MsalServiceError from tests import unittest +from tests.http_client import MinimalHttpClient @unittest.skipIf(os.getenv("TRAVIS_TAG"), "Skip network io during tagged release") @@ -11,7 +11,8 @@ class TestAuthority(unittest.TestCase): def test_wellknown_host_and_tenant(self): # Assert all well known authority hosts are using their own "common" tenant for host in WELL_KNOWN_AUTHORITY_HOSTS: - a = Authority('https://{}/common'.format(host)) + a = Authority( + 'https://{}/common'.format(host), MinimalHttpClient()) self.assertEqual( a.authorization_endpoint, 'https://%s/common/oauth2/v2.0/authorize' % host) @@ -24,18 +25,22 @@ def test_lessknown_host_will_return_a_set_of_v1_endpoints(self): # It is probably not a strict API contract. I simply mention it here. less_known = 'login.windows.net' # less.known.host/ v1_token_endpoint = 'https://{}/common/oauth2/token'.format(less_known) - a = Authority('https://{}/common'.format(less_known)) + a = Authority( + 'https://{}/common'.format(less_known), MinimalHttpClient()) self.assertEqual(a.token_endpoint, v1_token_endpoint) self.assertNotIn('v2.0', a.token_endpoint) def test_unknown_host_wont_pass_instance_discovery(self): _assert = getattr(self, "assertRaisesRegex", self.assertRaisesRegexp) # Hack with _assert(ValueError, "invalid_instance"): - Authority('https://example.com/tenant_doesnt_matter_in_this_case') + Authority('https://example.com/tenant_doesnt_matter_in_this_case', + MinimalHttpClient()) def test_invalid_host_skipping_validation_can_be_turned_off(self): try: - Authority('https://example.com/invalid', validate_authority=False) + Authority( + 'https://example.com/invalid', + MinimalHttpClient(), validate_authority=False) except ValueError as e: if "invalid_instance" in str(e): # Imprecise but good enough self.fail("validate_authority=False should turn off validation") @@ -79,7 +84,7 @@ def test_memorize(self): # We use a real authority so the constructor can finish tenant discovery authority = "https://login.microsoftonline.com/common" self.assertNotIn(authority, Authority._domains_without_user_realm_discovery) - a = Authority(authority, validate_authority=False) + a = Authority(authority, MinimalHttpClient(), validate_authority=False) # We now pretend this authority supports no User Realm Discovery class MockResponse(object): @@ -91,4 +96,3 @@ class MockResponse(object): "user_realm_discovery() should memorize domains not supporting URD") a.user_realm_discovery("john.doe@example.com", response="This would cause exception if memorization did not work") - diff --git a/tests/test_client.py b/tests/test_client.py index d1de2b6..75cdfc9 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -12,6 +12,7 @@ from msal.oauth2cli import Client, JwtSigner from msal.oauth2cli.authcode import obtain_auth_code from tests import unittest, Oauth2TestCase +from tests.http_client import MinimalHttpClient logging.basicConfig(level=logging.DEBUG) @@ -83,6 +84,7 @@ class TestClient(Oauth2TestCase): @classmethod def setUpClass(cls): + http_client = MinimalHttpClient() if "client_certificate" in CONFIG: private_key_path = CONFIG["client_certificate"]["private_key_path"] with open(os.path.join(THIS_FOLDER, private_key_path)) as f: @@ -90,6 +92,7 @@ def setUpClass(cls): cls.client = Client( CONFIG["openid_configuration"], CONFIG['client_id'], + http_client=http_client, client_assertion=JwtSigner( private_key, algorithm="RS256", @@ -103,6 +106,7 @@ def setUpClass(cls): else: cls.client = Client( CONFIG["openid_configuration"], CONFIG['client_id'], + http_client=http_client, client_secret=CONFIG.get('client_secret')) @unittest.skipIf( diff --git a/tests/test_e2e.py b/tests/test_e2e.py index 0d74eb1..28383cd 100644 --- a/tests/test_e2e.py +++ b/tests/test_e2e.py @@ -7,7 +7,7 @@ import requests import msal - +from tests.http_client import MinimalHttpClient logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) @@ -21,7 +21,8 @@ def _get_app_and_auth_code( scopes=["https://graph.microsoft.com/.default"], # Microsoft Graph **kwargs): from msal.oauth2cli.authcode import obtain_auth_code - app = msal.ClientApplication(client_id, client_secret, authority=authority) + app = msal.ClientApplication( + client_id, client_secret, authority=authority, http_client=MinimalHttpClient()) redirect_uri = "http://localhost:%d" % port ac = obtain_auth_code(port, auth_uri=app.get_authorization_request_url( scopes, redirect_uri=redirect_uri, **kwargs)) @@ -92,7 +93,8 @@ def _test_username_password(self, authority=None, client_id=None, username=None, password=None, scope=None, **ignored): assert authority and client_id and username and password and scope - self.app = msal.PublicClientApplication(client_id, authority=authority) + self.app = msal.PublicClientApplication( + client_id, authority=authority, http_client=MinimalHttpClient()) result = self.app.acquire_token_by_username_password( username, password, scopes=scope) self.assertLoosely(result) @@ -106,7 +108,7 @@ def _test_device_flow( self, client_id=None, authority=None, scope=None, **ignored): assert client_id and authority and scope self.app = msal.PublicClientApplication( - client_id, authority=authority) + client_id, authority=authority, http_client=MinimalHttpClient()) flow = self.app.initiate_device_flow(scopes=scope) assert "user_code" in flow, "DF does not seem to be provisioned: %s".format( json.dumps(flow, indent=4)) @@ -225,13 +227,13 @@ def test_ssh_cert(self): self.assertEqual(refreshed_ssh_cert["token_type"], "ssh-cert") self.assertNotEqual(result["access_token"], refreshed_ssh_cert['access_token']) - def test_client_secret(self): self.skipUnlessWithConfig(["client_id", "client_secret"]) self.app = msal.ConfidentialClientApplication( self.config["client_id"], client_credential=self.config.get("client_secret"), - authority=self.config.get("authority")) + authority=self.config.get("authority"), + http_client=MinimalHttpClient()) scope = self.config.get("scope", []) result = self.app.acquire_token_for_client(scope) self.assertIn('access_token', result) @@ -245,7 +247,8 @@ def test_client_certificate(self): private_key = f.read() # Should be in PEM format self.app = msal.ConfidentialClientApplication( self.config['client_id'], - {"private_key": private_key, "thumbprint": client_cert["thumbprint"]}) + {"private_key": private_key, "thumbprint": client_cert["thumbprint"]}, + http_client=MinimalHttpClient()) scope = self.config.get("scope", []) result = self.app.acquire_token_for_client(scope) self.assertIn('access_token', result) @@ -267,7 +270,8 @@ def test_subject_name_issuer_authentication(self): "private_key": private_key, "thumbprint": self.config["thumbprint"], "public_certificate": public_certificate, - }) + }, + http_client=MinimalHttpClient()) scope = self.config.get("scope", []) result = self.app.acquire_token_for_client(scope) self.assertIn('access_token', result) @@ -311,7 +315,7 @@ def get_lab_app( return msal.ConfidentialClientApplication(client_id, client_secret, authority="https://login.microsoftonline.com/" "72f988bf-86f1-41af-91ab-2d7cd011db47", # Microsoft tenant ID - ) + http_client=MinimalHttpClient()) def get_session(lab_app, scopes): # BTW, this infrastructure tests the confidential client flow logger.info("Creating session") @@ -398,7 +402,8 @@ def _test_acquire_token_by_auth_code( def _test_acquire_token_obo(self, config_pca, config_cca): # 1. An app obtains a token representing a user, for our mid-tier service pca = msal.PublicClientApplication( - config_pca["client_id"], authority=config_pca["authority"]) + config_pca["client_id"], authority=config_pca["authority"], + http_client=MinimalHttpClient()) pca_result = pca.acquire_token_by_username_password( config_pca["username"], config_pca["password"], @@ -413,6 +418,7 @@ def _test_acquire_token_obo(self, config_pca, config_cca): config_cca["client_id"], client_credential=config_cca["client_secret"], authority=config_cca["authority"], + http_client=MinimalHttpClient(), # token_cache= ..., # Default token cache is all-tokens-store-in-memory. # That's fine if OBO app uses short-lived msal instance per session. # Otherwise, the OBO app need to implement a one-cache-per-user setup. @@ -439,7 +445,8 @@ def _test_acquire_token_by_client_secret( **ignored): assert client_id and client_secret and authority and scope app = msal.ConfidentialClientApplication( - client_id, client_credential=client_secret, authority=authority) + client_id, client_credential=client_secret, authority=authority, + http_client=MinimalHttpClient()) result = app.acquire_token_for_client(scope) self.assertIsNotNone(result.get("access_token"), "Got %s instead" % result)