Skip to content

Commit

Permalink
Allowing transport layer to be customized (#169)
Browse files Browse the repository at this point in the history
  • Loading branch information
abhidnya13 authored Apr 21, 2020
1 parent 099d576 commit f2340a4
Show file tree
Hide file tree
Showing 11 changed files with 200 additions and 116 deletions.
56 changes: 35 additions & 21 deletions msal/application.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import functools
import json
import time
try: # Python 2
from urlparse import urljoin
Expand Down Expand Up @@ -54,11 +56,11 @@ def decorate_scope(
CLIENT_CURRENT_TELEMETRY = 'x-client-current-telemetry'

def _get_new_correlation_id():
return str(uuid.uuid4())
return str(uuid.uuid4())


def _build_current_telemetry_request_header(public_api_id, force_refresh=False):
return "1|{},{}|".format(public_api_id, "1" if force_refresh else "0")
return "1|{},{}|".format(public_api_id, "1" if force_refresh else "0")


def extract_certs(public_cert_content):
Expand Down Expand Up @@ -92,6 +94,7 @@ def __init__(
self, client_id,
client_credential=None, authority=None, validate_authority=True,
token_cache=None,
http_client=None,
verify=True, proxies=None, timeout=None,
client_claims=None, app_name=None, app_version=None):
"""Create an instance of application.
Expand Down Expand Up @@ -151,18 +154,24 @@ def __init__(
:param TokenCache cache:
Sets the token cache used by this ClientApplication instance.
By default, an in-memory cache will be created and used.
:param http_client: (optional)
Your implementation of abstract class HttpClient <msal.oauth2cli.http.http_client>
Defaults to a requests session instance
:param verify: (optional)
It will be passed to the
`verify parameter in the underlying requests library
<http://docs.python-requests.org/en/v2.9.1/user/advanced/#ssl-cert-verification>`_
This does not apply if you have chosen to pass your own Http client
:param proxies: (optional)
It will be passed to the
`proxies parameter in the underlying requests library
<http://docs.python-requests.org/en/v2.9.1/user/advanced/#proxies>`_
This does not apply if you have chosen to pass your own Http client
:param timeout: (optional)
It will be passed to the
`timeout parameter in the underlying requests library
<http://docs.python-requests.org/en/v2.9.1/user/advanced/#timeouts>`_
This does not apply if you have chosen to pass your own Http client
:param app_name: (optional)
You can provide your application name for Microsoft telemetry purposes.
Default value is None, means it will not be passed to Microsoft.
Expand All @@ -173,14 +182,21 @@ def __init__(
self.client_id = client_id
self.client_credential = client_credential
self.client_claims = client_claims
self.verify = verify
self.proxies = proxies
self.timeout = timeout
if http_client:
self.http_client = http_client
else:
self.http_client = requests.Session()
self.http_client.verify = verify
self.http_client.proxies = proxies
# Requests, does not support session - wide timeout
# But you can patch that (https://github.com/psf/requests/issues/3341):
self.http_client.request = functools.partial(
self.http_client.request, timeout=timeout)
self.app_name = app_name
self.app_version = app_version
self.authority = Authority(
authority or "https://login.microsoftonline.com/common/",
validate_authority, verify=verify, proxies=proxies, timeout=timeout)
self.http_client, validate_authority=validate_authority)
# Here the self.authority is not the same type as authority in input
self.token_cache = token_cache or TokenCache()
self.client = self._build_client(client_credential, self.authority)
Expand Down Expand Up @@ -223,14 +239,14 @@ def _build_client(self, client_credential, authority):
return Client(
server_configuration,
self.client_id,
http_client=self.http_client,
default_headers=default_headers,
default_body=default_body,
client_assertion=client_assertion,
client_assertion_type=client_assertion_type,
on_obtaining_tokens=self.token_cache.add,
on_removing_rt=self.token_cache.remove_rt,
on_updating_rt=self.token_cache.update_rt,
verify=self.verify, proxies=self.proxies, timeout=self.timeout)
on_updating_rt=self.token_cache.update_rt)

def get_authorization_request_url(
self,
Expand Down Expand Up @@ -288,12 +304,13 @@ def get_authorization_request_url(
# Multi-tenant app can use new authority on demand
the_authority = Authority(
authority,
verify=self.verify, proxies=self.proxies, timeout=self.timeout,
self.http_client
) if authority else self.authority

client = Client(
{"authorization_endpoint": the_authority.authorization_endpoint},
self.client_id)
self.client_id,
http_client=self.http_client)
return client.build_auth_request_uri(
response_type=response_type,
redirect_uri=redirect_uri, state=state, login_hint=login_hint,
Expand Down Expand Up @@ -399,13 +416,12 @@ def _find_msal_accounts(self, environment):

def _get_authority_aliases(self, instance):
if not self.authority_groups:
resp = requests.get(
resp = self.http_client.get(
"https://login.microsoftonline.com/common/discovery/instance?api-version=1.1&authorization_endpoint=https://login.microsoftonline.com/common/oauth2/authorize",
headers={'Accept': 'application/json'},
verify=self.verify, proxies=self.proxies, timeout=self.timeout)
headers={'Accept': 'application/json'})
resp.raise_for_status()
self.authority_groups = [
set(group['aliases']) for group in resp.json()['metadata']]
set(group['aliases']) for group in json.loads(resp.text)['metadata']]
for group in self.authority_groups:
if instance in group:
return [alias for alias in group if alias != instance]
Expand Down Expand Up @@ -524,7 +540,7 @@ def acquire_token_silent_with_error(
warnings.warn("We haven't decided how/if this method will accept authority parameter")
# the_authority = Authority(
# authority,
# verify=self.verify, proxies=self.proxies, timeout=self.timeout,
# self.http_client,
# ) if authority else self.authority
result = self._acquire_token_silent_from_cache_and_possibly_refresh_it(
scopes, account, self.authority, force_refresh=force_refresh,
Expand All @@ -536,8 +552,8 @@ def acquire_token_silent_with_error(
for alias in self._get_authority_aliases(self.authority.instance):
the_authority = Authority(
"https://" + alias + "/" + self.authority.tenant,
validate_authority=False,
verify=self.verify, proxies=self.proxies, timeout=self.timeout)
self.http_client,
validate_authority=False)
result = self._acquire_token_silent_from_cache_and_possibly_refresh_it(
scopes, account, the_authority, force_refresh=force_refresh,
correlation_id=correlation_id,
Expand Down Expand Up @@ -780,13 +796,11 @@ def acquire_token_by_username_password(

def _acquire_token_by_username_password_federated(
self, user_realm_result, username, password, scopes=None, **kwargs):
verify = kwargs.pop("verify", self.verify)
proxies = kwargs.pop("proxies", self.proxies)
wstrust_endpoint = {}
if user_realm_result.get("federation_metadata_url"):
wstrust_endpoint = mex_send_request(
user_realm_result["federation_metadata_url"],
verify=verify, proxies=proxies)
self.http_client)
if wstrust_endpoint is None:
raise ValueError("Unable to find wstrust endpoint from MEX. "
"This typically happens when attempting MSA accounts. "
Expand All @@ -798,7 +812,7 @@ def _acquire_token_by_username_password_federated(
wstrust_endpoint.get("address",
# Fallback to an AAD supplied endpoint
user_realm_result.get("federation_active_auth_url")),
wstrust_endpoint.get("action"), verify=verify, proxies=proxies)
wstrust_endpoint.get("action"), self.http_client)
if not ("token" in wstrust_result and "type" in wstrust_result):
raise RuntimeError("Unsuccessful RSTR. %s" % wstrust_result)
GRANT_TYPE_SAML1_1 = 'urn:ietf:params:oauth:grant-type:saml1_1-bearer'
Expand Down
38 changes: 17 additions & 21 deletions msal/authority.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import json
try:
from urllib.parse import urlparse
except ImportError: # Fall back to Python 2
from urlparse import urlparse
import logging

import requests

from .exceptions import MsalServiceError


Expand All @@ -25,6 +24,7 @@
"b2clogin.de",
]


class Authority(object):
"""This class represents an (already-validated) authority.
Expand All @@ -33,9 +33,7 @@ class Authority(object):
"""
_domains_without_user_realm_discovery = set([])

def __init__(self, authority_url, validate_authority=True,
verify=True, proxies=None, timeout=None,
):
def __init__(self, authority_url, http_client, validate_authority=True):
"""Creates an authority instance, and also validates it.
:param validate_authority:
Expand All @@ -44,9 +42,7 @@ def __init__(self, authority_url, validate_authority=True,
This parameter only controls whether an instance discovery will be
performed.
"""
self.verify = verify
self.proxies = proxies
self.timeout = timeout
self.http_client = http_client
authority, self.instance, tenant = canonicalize(authority_url)
parts = authority.path.split('/')
is_b2c = any(self.instance.endswith("." + d) for d in WELL_KNOWN_B2C_HOSTS) or (
Expand All @@ -56,7 +52,7 @@ def __init__(self, authority_url, validate_authority=True,
payload = instance_discovery(
"https://{}{}/oauth2/v2.0/authorize".format(
self.instance, authority.path),
verify=verify, proxies=proxies, timeout=timeout)
self.http_client)
if payload.get("error") == "invalid_instance":
raise ValueError(
"invalid_instance: "
Expand All @@ -75,7 +71,7 @@ def __init__(self, authority_url, validate_authority=True,
))
openid_config = tenant_discovery(
tenant_discovery_endpoint,
verify=verify, proxies=proxies, timeout=timeout)
self.http_client)
logger.debug("openid_config = %s", openid_config)
self.authorization_endpoint = openid_config['authorization_endpoint']
self.token_endpoint = openid_config['token_endpoint']
Expand All @@ -87,15 +83,14 @@ def user_realm_discovery(self, username, correlation_id=None, response=None):
# "federation_protocol", "cloud_audience_urn",
# "federation_metadata_url", "federation_active_auth_url", etc.
if self.instance not in self.__class__._domains_without_user_realm_discovery:
resp = response or requests.get(
resp = response or self.http_client.get(
"https://{netloc}/common/userrealm/{username}?api-version=1.0".format(
netloc=self.instance, username=username),
headers={'Accept':'application/json',
'client-request-id': correlation_id},
verify=self.verify, proxies=self.proxies, timeout=self.timeout)
headers={'Accept': 'application/json',
'client-request-id': correlation_id},)
if resp.status_code != 404:
resp.raise_for_status()
return resp.json()
return json.loads(resp.text)
self.__class__._domains_without_user_realm_discovery.add(self.instance)
return {} # This can guide the caller to fall back normal ROPC flow

Expand All @@ -113,20 +108,21 @@ def canonicalize(authority_url):
% authority_url)
return authority, authority.hostname, parts[1]

def instance_discovery(url, **kwargs):
return requests.get( # Note: This URL seemingly returns V1 endpoint only
def instance_discovery(url, http_client, **kwargs):
resp = http_client.get( # Note: This URL seemingly returns V1 endpoint only
'https://{}/common/discovery/instance'.format(
WORLD_WIDE # Historically using WORLD_WIDE. Could use self.instance too
# See https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadInstanceDiscovery.cs#L101-L103
# and https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadAuthority.cs#L19-L33
),
params={'authorization_endpoint': url, 'api-version': '1.0'},
**kwargs).json()
**kwargs)
return json.loads(resp.text)

def tenant_discovery(tenant_discovery_endpoint, **kwargs):
def tenant_discovery(tenant_discovery_endpoint, http_client, **kwargs):
# Returns Openid Configuration
resp = requests.get(tenant_discovery_endpoint, **kwargs)
payload = resp.json()
resp = http_client.get(tenant_discovery_endpoint, **kwargs)
payload = json.loads(resp.text)
if 'authorization_endpoint' in payload and 'token_endpoint' in payload:
return payload
raise MsalServiceError(status_code=resp.status_code, **payload)
Expand Down
7 changes: 3 additions & 4 deletions msal/mex.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,14 @@
except ImportError:
from xml.etree import ElementTree as ET

import requests


def _xpath_of_root(route_to_leaf):
# Construct an xpath suitable to find a root node which has a specified leaf
return '/'.join(route_to_leaf + ['..'] * (len(route_to_leaf)-1))

def send_request(mex_endpoint, **kwargs):
mex_document = requests.get(

def send_request(mex_endpoint, http_client, **kwargs):
mex_document = http_client.get(
mex_endpoint, headers={'Content-Type': 'application/soap+xml'},
**kwargs).text
return Mex(mex_document).get_wstrust_username_password_endpoint()
Expand Down
Loading

0 comments on commit f2340a4

Please sign in to comment.