Skip to content

Commit

Permalink
Add support for acquiring a token with a pre-signed JWT (#271)
Browse files Browse the repository at this point in the history
* Add support for acquiring a token with a client provided, pre-signed
JWT.

Useful for where the signing takes place externally for example using
Azure Key Vault (AKV).

AKV sample included.

* Changes to parameter name for #271

* Address comment in #271 "No need to repeat this statement twice in both if and else"

* merge rayluo / microsoft-authentication-library-for-python:patch1

* Update msal/application.py

Co-authored-by: Ray Luo <[email protected]>

* Update tests/test_e2e.py

Co-authored-by: Ray Luo <[email protected]>

* Resolve merge conflict

Co-authored-by: David Freedman <[email protected]>
Co-authored-by: Ray Luo <[email protected]>
  • Loading branch information
3 people authored Jun 7, 2021
1 parent a433b71 commit 9082dc1
Show file tree
Hide file tree
Showing 4 changed files with 186 additions and 23 deletions.
54 changes: 33 additions & 21 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,14 @@ def __init__(
"The provided signature value did not match the expected signature value",
you may try use only the leaf cert (in PEM/str format) instead.
*Added in version 1.13.0*:
It can also be a completly pre-signed assertion that you've assembled yourself.
Simply pass a container containing only the key "client_assertion", like this::
{
"client_assertion": "...a JWT with claims aud, exp, iss, jti, nbf, and sub..."
}
:param dict client_claims:
*Added in version 0.5.0*:
It is a dictionary of extra claims that would be signed by
Expand Down Expand Up @@ -391,28 +399,32 @@ def _build_client(self, client_credential, authority):
default_headers['x-app-ver'] = self.app_version
default_body = {"client_info": 1}
if isinstance(client_credential, dict):
assert ("private_key" in client_credential
and "thumbprint" in client_credential)
headers = {}
if 'public_certificate' in client_credential:
headers["x5c"] = extract_certs(client_credential['public_certificate'])
if not client_credential.get("passphrase"):
unencrypted_private_key = client_credential['private_key']
else:
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
unencrypted_private_key = serialization.load_pem_private_key(
_str2bytes(client_credential["private_key"]),
_str2bytes(client_credential["passphrase"]),
backend=default_backend(), # It was a required param until 2020
)
assertion = JwtAssertionCreator(
unencrypted_private_key, algorithm="RS256",
sha1_thumbprint=client_credential.get("thumbprint"), headers=headers)
client_assertion = assertion.create_regenerative_assertion(
audience=authority.token_endpoint, issuer=self.client_id,
additional_claims=self.client_claims or {})
assert (("private_key" in client_credential
and "thumbprint" in client_credential) or
"client_assertion" in client_credential)
client_assertion_type = Client.CLIENT_ASSERTION_TYPE_JWT
if 'client_assertion' in client_credential:
client_assertion = client_credential['client_assertion']
else:
headers = {}
if 'public_certificate' in client_credential:
headers["x5c"] = extract_certs(client_credential['public_certificate'])
if not client_credential.get("passphrase"):
unencrypted_private_key = client_credential['private_key']
else:
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
unencrypted_private_key = serialization.load_pem_private_key(
_str2bytes(client_credential["private_key"]),
_str2bytes(client_credential["passphrase"]),
backend=default_backend(), # It was a required param until 2020
)
assertion = JwtAssertionCreator(
unencrypted_private_key, algorithm="RS256",
sha1_thumbprint=client_credential.get("thumbprint"), headers=headers)
client_assertion = assertion.create_regenerative_assertion(
audience=authority.token_endpoint, issuer=self.client_id,
additional_claims=self.client_claims or {})
else:
default_body['client_secret'] = client_credential
central_configuration = {
Expand Down
134 changes: 134 additions & 0 deletions sample/vault_jwt_sample.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""
The configuration file would look like this (sans those // comments):
{
"tenant": "your_tenant_name",
// Your target tenant, DNS name
"client_id": "your_client_id",
// Target app ID in Azure AD
"scope": ["https://graph.microsoft.com/.default"],
// Specific to Client Credentials Grant i.e. acquire_token_for_client(),
// you don't specify, in the code, the individual scopes you want to access.
// Instead, you statically declared them when registering your application.
// Therefore the only possible scope is "resource/.default"
// (here "https://graph.microsoft.com/.default")
// which means "the static permissions defined in the application".
"vault_tenant": "your_vault_tenant_name",
// Your Vault tenant may be different to your target tenant
// If that's not the case, you can set this to the same
// as "tenant"
"vault_clientid": "your_vault_client_id",
// Client ID of your vault app in your vault tenant
"vault_clientsecret": "your_vault_client_secret",
// Secret for your vault app
"vault_url": "your_vault_url",
// URL of your vault app
"cert": "your_cert_name",
// Name of your certificate in your vault
"cert_thumb": "your_cert_thumbprint",
// Thumbprint of your certificate
"endpoint": "https://graph.microsoft.com/v1.0/users"
// For this resource to work, you need to visit Application Permissions
// page in portal, declare scope User.Read.All, which needs admin consent
// https://github.com/Azure-Samples/ms-identity-python-daemon/blob/master/2-Call-MsGraph-WithCertificate/README.md
}
You can then run this sample with a JSON configuration file:
python sample.py parameters.json
"""

import base64
import json
import logging
import requests
import sys
import time
import uuid
import msal

# Optional logging
# logging.basicConfig(level=logging.DEBUG) # Enable DEBUG log for entire script
# logging.getLogger("msal").setLevel(logging.INFO) # Optionally disable MSAL DEBUG logs

from azure.keyvault import KeyVaultClient, KeyVaultAuthentication
from azure.common.credentials import ServicePrincipalCredentials
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes

config = json.load(open(sys.argv[1]))

def auth_vault_callback(server, resource, scope):
credentials = ServicePrincipalCredentials(
client_id=config['vault_clientid'],
secret=config['vault_clientsecret'],
tenant=config['vault_tenant'],
resource='https://vault.azure.net'
)
token = credentials.token
return token['token_type'], token['access_token']


def make_vault_jwt():

header = {
'alg': 'RS256',
'typ': 'JWT',
'x5t': base64.b64encode(
config['cert_thumb'].decode('hex'))
}
header_b64 = base64.b64encode(json.dumps(header).encode('utf-8'))

body = {
'aud': "https://login.microsoftonline.com/%s/oauth2/token" %
config['tenant'],
'exp': (int(time.time()) + 600),
'iss': config['client_id'],
'jti': str(uuid.uuid4()),
'nbf': int(time.time()),
'sub': config['client_id']
}
body_b64 = base64.b64encode(json.dumps(body).encode('utf-8'))

full_b64 = b'.'.join([header_b64, body_b64])

client = KeyVaultClient(KeyVaultAuthentication(auth_vault_callback))
chosen_hash = hashes.SHA256()
hasher = hashes.Hash(chosen_hash, default_backend())
hasher.update(full_b64)
digest = hasher.finalize()
signed_digest = client.sign(config['vault_url'],
config['cert'], '', 'RS256',
digest).result

full_token = b'.'.join([full_b64, base64.b64encode(signed_digest)])

return full_token


authority = "https://login.microsoftonline.com/%s" % config['tenant']

app = msal.ConfidentialClientApplication(
config['client_id'], authority=authority, client_credential={"client_assertion": make_vault_jwt()}
)

# The pattern to acquire a token looks like this.
result = None

# Firstly, looks up a token from cache
# Since we are looking for token for the current app, NOT for an end user,
# notice we give account parameter as None.
result = app.acquire_token_silent(config["scope"], account=None)

if not result:
logging.info("No suitable token exists in cache. Let's get a new one from AAD.")
result = app.acquire_token_for_client(scopes=config["scope"])

if "access_token" in result:
# Calling graph using the access token
graph_data = requests.get( # Use token to call downstream service
config["endpoint"],
headers={'Authorization': 'Bearer ' + result['access_token']},).json()
print("Graph API call result: %s" % json.dumps(graph_data, indent=2))
else:
print(result.get("error"))
print(result.get("error_description"))
print(result.get("correlation_id")) # You may need this when reporting a bug

10 changes: 9 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,15 @@ class TestClient(Oauth2TestCase):
@classmethod
def setUpClass(cls):
http_client = MinimalHttpClient()
if "client_certificate" in CONFIG:
if "client_assertion" in CONFIG:
cls.client = Client(
CONFIG["openid_configuration"],
CONFIG['client_id'],
http_client=http_client,
client_assertion=CONFIG["client_assertion"],
client_assertion_type=Client.CLIENT_ASSERTION_TYPE_JWT,
)
elif "client_certificate" in CONFIG:
private_key_path = CONFIG["client_certificate"]["private_key_path"]
with open(os.path.join(THIS_FOLDER, private_key_path)) as f:
private_key = f.read() # Expecting PEM format
Expand Down
11 changes: 10 additions & 1 deletion tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,16 @@ def test_subject_name_issuer_authentication(self):
self.assertIn('access_token', result)
self.assertCacheWorksForApp(result, scope)

def test_client_assertion(self):
self.skipUnlessWithConfig(["client_id", "client_assertion"])
self.app = msal.ConfidentialClientApplication(
self.config['client_id'], authority=self.config["authority"],
client_credential={"client_assertion": self.config["client_assertion"]},
http_client=MinimalHttpClient())
scope = self.config.get("scope", [])
result = self.app.acquire_token_for_client(scope)
self.assertIn('access_token', result)
self.assertCacheWorksForApp(result, scope)

@unittest.skipUnless(os.path.exists(CONFIG), "Optional %s not found" % CONFIG)
class DeviceFlowTestCase(E2eTestCase): # A leaf class so it will be run only once
Expand Down Expand Up @@ -882,4 +892,3 @@ def test_acquire_token_silent_with_an_empty_cache_should_return_none(self):

if __name__ == "__main__":
unittest.main()

0 comments on commit 9082dc1

Please sign in to comment.