Skip to content

Commit

Permalink
Removing dependency of six
Browse files Browse the repository at this point in the history
Adding missing arguments to api call

Use cryptography lower bound as low as 0.6

Add test cases for _str2bytes()

Choose cryptography upper bound as <4
  • Loading branch information
rayluo committed Oct 30, 2020
1 parent 5a82af1 commit de618ba
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 40 deletions.
28 changes: 14 additions & 14 deletions msal/application.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
import functools
import json
import time

import six
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

try: # Python 2
from urlparse import urljoin
except: # Python 3
Expand Down Expand Up @@ -95,6 +90,14 @@ def _merge_claims_challenge_and_capabilities(capabilities, claims_challenge):
return json.dumps(claims_dict)


def _str2bytes(raw):
# A conversion based on duck-typing rather than six.text_type
try:
return raw.encode(encoding="utf-8")
except:
return raw


class ClientApplication(object):

ACQUIRE_TOKEN_SILENT_ID = "84"
Expand Down Expand Up @@ -261,16 +264,13 @@ def _build_client(self, client_credential, authority):
if not client_credential.get("passphrase"):
unencrypted_private_key = client_credential['private_key']
else:
if isinstance(client_credential['private_key'], six.text_type):
private_key = client_credential['private_key'].encode(encoding="utf-8")
else:
private_key = client_credential['private_key']
if isinstance(client_credential['passphrase'], six.text_type):
password = client_credential['passphrase'].encode(encoding="utf-8")
else:
password = client_credential['passphrase']
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.backends import default_backend
unencrypted_private_key = serialization.load_pem_private_key(
private_key, password=password, backend=default_backend())
_str2bytes(client_credential["private_key"]),
_str2bytes(client_credential["passphrase"]),
backend=default_backend(), # It was a required param until 2020
)
assertion = JwtAssertionCreator(
unencrypted_private_key, algorithm="RS256",
sha1_thumbprint=client_credential.get("thumbprint"), headers=headers)
Expand Down
12 changes: 10 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,16 @@
install_requires=[
'requests>=2.0.0,<3',
'PyJWT[crypto]>=1.0.0,<2',
'six>=1.6',
'cryptography>=2.1.4'

'cryptography>=0.6,<4',
# load_pem_private_key() is available since 0.6
# https://github.com/pyca/cryptography/blob/master/CHANGELOG.rst#06---2014-09-29
#
# Not sure what should be used as an upper bound here
# https://github.com/pyca/cryptography/issues/5532
# We will go with "<4" for now, which is also what our another dependency,
# pyjwt, currently use.

]
)

30 changes: 6 additions & 24 deletions tests/test_application.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Note: Since Aug 2019 we move all e2e tests into test_e2e.py,
# so this test_application file contains only unit tests without dependency.
from msal.application import *
from msal.application import _str2bytes
import msal
from msal.application import _merge_claims_challenge_and_capabilities
from tests import unittest
Expand Down Expand Up @@ -39,31 +40,12 @@ def test_extract_multiple_tag_enclosed_certs(self):
self.assertEqual(["my_cert1", "my_cert2"], extract_certs(pem))


class TestEncryptedKeyAsClientCredential(unittest.TestCase):
# Internally, we use serialization.load_pem_private_key() to load an encrypted private key with a passphrase
# This function takes in encrypted key in bytes and passphrase in bytes too
# Our code handles such a conversion, adding test cases to verify such a conversion is needed
class TestBytesConversion(unittest.TestCase):
def test_string_to_bytes(self):
self.assertEqual(type(_str2bytes("some string")), type(b"bytes"))

def test_encyrpted_key_in_bytes_and_string_password_should_error(self):
private_key = b"""
-----BEGIN ENCRYPTED PRIVATE KEY-----
test_private_key
-----END ENCRYPTED PRIVATE KEY-----
"""
with self.assertRaises(TypeError):
# Using a unicode string for Python 2 to identify it as a string and not default to bytes
serialization.load_pem_private_key(
private_key, password=u"string_password", backend=default_backend())

def test_encyrpted_key_is_string_and_bytes_password_should_error(self):
private_key = u"""
-----BEGIN ENCRYPTED PRIVATE KEY-----
test_private_key
-----END ENCRYPTED PRIVATE KEY-----
"""
with self.assertRaises(TypeError):
serialization.load_pem_private_key(
private_key, password=b"byte_password", backend=default_backend())
def test_bytes_to_bytes(self):
self.assertEqual(type(_str2bytes(b"some bytes")), type(b"bytes"))


class TestClientApplicationAcquireTokenSilentErrorBehaviors(unittest.TestCase):
Expand Down

0 comments on commit de618ba

Please sign in to comment.