Skip to content

Commit

Permalink
Merge pull request #4 from Mastercard/feature/decrypt_issue
Browse files Browse the repository at this point in the history
Decrypt issue
  • Loading branch information
pqlMC authored Jun 20, 2019
2 parents 3a7aac2 + c9b5781 commit 4f780a7
Show file tree
Hide file tree
Showing 10 changed files with 90 additions and 55 deletions.
37 changes: 17 additions & 20 deletions client_encryption/api_encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,25 +18,23 @@ def __init__(self, encryption_conf_file):
self._encryption_conf = FieldLevelEncryptionConfig(json_file.read())

def field_encryption(self, func):
"""Decorator for API request. func is APIClient.request"""
"""Decorator for API call_api. func is APIClient.call_api"""

@wraps(func)
def request_function(*args, **kwargs):
"""Wrap request and add field encryption layer to it."""
def call_api_function(*args, **kwargs):
"""Wrap call_api and add field encryption layer to it."""

in_body = kwargs.get("body", None)
kwargs["body"] = self._encrypt_payload(kwargs.get("headers", None), in_body) if in_body else in_body
kwargs["body"] = self._encrypt_payload(kwargs.get("header_params", None), in_body) if in_body else in_body
kwargs["_preload_content"] = False

response = func(*args, **kwargs)

if type(response.data) is not str:
response_body = self._decrypt_payload(response.getheaders(), response.json())
response._content = json.dumps(response_body, indent=4).encode('utf-8')
response._body = self._decrypt_payload(response.getheaders(), response.data)

return response

request_function.__fle__ = True
return request_function
call_api_function.__fle__ = True
return call_api_function

def _encrypt_payload(self, headers, body):
"""Encryption enforcement based on configuration - encrypt and add session key params to header or body"""
Expand All @@ -62,6 +60,7 @@ def _decrypt_payload(self, headers, body):
"""Encryption enforcement based on configuration - decrypt using session key params from header or body"""

conf = self._encryption_conf
params = None

if conf.use_http_headers:
if conf.iv_field_name in headers and conf.encrypted_key_field_name in headers:
Expand All @@ -75,30 +74,28 @@ def _decrypt_payload(self, headers, body):
del headers[conf.encryption_key_fingerprint_field_name]

params = SessionKeyParams(conf, encrypted_key, iv, oaep_digest_algo)
payload = decrypt_payload(body, conf, params)
else:
# skip decryption if not iv nor key is in headers
payload = body
else:
payload = decrypt_payload(body, conf)
# skip decryption and return original body if not iv nor key is in headers
return body

decrypted_body = decrypt_payload(body, conf, params)
payload = json.dumps(decrypted_body).encode('utf-8')

return payload


def add_encryption_layer(api_client, encryption_conf_file):
"""Decorate APIClient.request with field level encryption"""
"""Decorate APIClient.call_api with field level encryption"""

api_encryption = ApiEncryption(encryption_conf_file)
api_client.request = api_encryption.field_encryption(api_client.request)
api_client.call_api = api_encryption.field_encryption(api_client.call_api)

__check_oauth(api_client) # warn the user if authentication layer is missing/not set


def __check_oauth(api_client):
try:
oauth_layer = getattr(api_client.request, "__wrapped__").__oauth__
if not oauth_layer or type(oauth_layer) is not bool:
__oauth_warn()
api_client.request.__wrapped__
except AttributeError:
__oauth_warn()

Expand Down
12 changes: 9 additions & 3 deletions client_encryption/encryption_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,15 +54,21 @@ def __get_crypto_file_type(file_content):
return FILETYPE_ASN1


def load_hash_algorithm(algo_str):
"""Load a hash algorithm object of Crypto.Hash from a list of supported ones."""
def validate_hash_algorithm(algo_str):
"""Validate a hash algorithm against a list of supported ones."""

if algo_str:
algo_key = algo_str.replace("-", "").upper()

if algo_key in _SUPPORTED_HASH:
return _SUPPORTED_HASH[algo_key]
return algo_key
else:
raise HashAlgorithmError("Hash algorithm invalid or not supported.")
else:
raise HashAlgorithmError("No hash algorithm provided.")


def load_hash_algorithm(algo_str):
"""Load a hash algorithm object of Crypto.Hash from a list of supported ones."""

return _SUPPORTED_HASH[validate_hash_algorithm(algo_str)]
17 changes: 10 additions & 7 deletions client_encryption/field_level_encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@ def encrypt_payload(payload, config, _params=None):
"""Encrypt some fields of a JSON payload using the given configuration."""

try:
if type(payload) is str:
json_payload = json.loads(payload)
else:
if type(payload) is dict:
json_payload = copy.deepcopy(payload)
else:
json_payload = json.loads(payload)

for elem, target in config.paths["$"].to_encrypt.items():
if not _params:
Expand Down Expand Up @@ -50,10 +50,13 @@ def decrypt_payload(payload, config, _params=None):
"""Decrypt some fields of a JSON payload using the given configuration."""

try:
if type(payload) is str:
json_payload = json.loads(payload)
else:
if type(payload) is dict:
json_payload = payload
else:
try:
json_payload = json.loads(payload)
except json.JSONDecodeError: # not a json response - return it as is
return payload

for elem, target in config.paths["$"].to_decrypt.items():
try:
Expand Down Expand Up @@ -90,7 +93,7 @@ def decrypt_payload(payload, config, _params=None):
return json_payload

except (IOError, ValueError, TypeError) as e:
raise EncryptionError("Payload encryption failed!", e)
raise EncryptionError("Payload decryption failed!", e)


def _encrypt_value(_key, iv, node_str):
Expand Down
6 changes: 2 additions & 4 deletions client_encryption/field_level_encryption_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from OpenSSL.crypto import dump_certificate, FILETYPE_ASN1, dump_publickey
from Crypto.Hash import SHA256
from client_encryption.encoding_utils import Encoding
from client_encryption.encryption_utils import load_encryption_certificate, load_decryption_key, load_hash_algorithm
from client_encryption.encryption_utils import load_encryption_certificate, load_decryption_key, validate_hash_algorithm


class FieldLevelEncryptionConfig(object):
Expand Down Expand Up @@ -44,9 +44,7 @@ def __init__(self, conf):
else:
self._decryption_key = None

digest_algo = json_config["oaepPaddingDigestAlgorithm"]
if load_hash_algorithm(digest_algo) is not None:
self._oaep_padding_digest_algorithm = digest_algo
self._oaep_padding_digest_algorithm = validate_hash_algorithm(json_config["oaepPaddingDigestAlgorithm"])

data_enc = Encoding(json_config["dataEncoding"].upper())
self._data_encoding = data_enc
Expand Down
6 changes: 3 additions & 3 deletions client_encryption/session_key_params.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from binascii import Error
from Crypto.Cipher import PKCS1_OAEP, AES
from Crypto import Random
from Crypto.Random import get_random_bytes
from Crypto.PublicKey import RSA
from client_encryption.encoding_utils import encode_bytes, decode_value
from client_encryption.encryption_utils import load_hash_algorithm
Expand Down Expand Up @@ -62,11 +62,11 @@ def generate(config):
encoding = config.data_encoding

# Generate a random IV
iv = Random.new().read(SessionKeyParams._BLOCK_SIZE)
iv = get_random_bytes(SessionKeyParams._BLOCK_SIZE)
iv_encoded = encode_bytes(iv, encoding)

# Generate an AES secret key
secret_key = Random.new().read(SessionKeyParams._KEY_SIZE)
secret_key = get_random_bytes(SessionKeyParams._KEY_SIZE)

# Encrypt the secret key
secret_key_encrypted = SessionKeyParams.__wrap_secret_key(secret_key, config)
Expand Down
2 changes: 1 addition & 1 deletion client_encryption/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
__version__ = "1.0.3"
__version__ = "1.1.0"
8 changes: 4 additions & 4 deletions tests/test_api_encryption.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,14 +62,14 @@ def test_decrypt_payload_with_params_in_body(self):

test_headers = {"Content-Type": "application/json"}

decrypted = api_encryption._decrypt_payload(body={
decrypted = json.loads(api_encryption._decrypt_payload(body={
"encryptedData": {
"iv": "uldLBySPY3VrznePihFYGQ==",
"encryptedKey": "Jmh/bQPScUVFHSC9qinMGZ4lM7uetzUXcuMdEpC5g4C0Pb9HuaM3zC7K/509n7RTBZUPEzgsWtgi7m33nhpXsUo8WMcQkBIZlKn3ce+WRyZpZxcYtVoPqNn3benhcv7cq7yH1ktamUiZ5Dq7Ga+oQCaQEsOXtbGNS6vA5Bwa1pjbmMiRIbvlstInz8XTw8h/T0yLBLUJ0yYZmzmt+9i8qL8KFQ/PPDe5cXOCr1Aq2NTSixe5F2K/EI00q6D7QMpBDC7K6zDWgAOvINzifZ0DTkxVe4EE6F+FneDrcJsj+ZeIabrlRcfxtiFziH6unnXktta0sB1xcszIxXdMDbUcJA==",
"encryptedValue": "KGfmdUWy89BwhQChzqZJ4w==",
"oaepHashingAlgo": "SHA256"
}
}, headers=test_headers)
}, headers=test_headers))

self.assertNotIn("encryptedData", decrypted)
self.assertDictEqual({"data": {}}, decrypted)
Expand Down Expand Up @@ -112,11 +112,11 @@ def test_decrypt_payload_with_params_in_headers(self):
}

api_encryption = to_test.ApiEncryption(self._json_config)
decrypted = api_encryption._decrypt_payload(body={
decrypted = json.loads(api_encryption._decrypt_payload(body={
"encryptedData": {
"encryptedValue": "KGfmdUWy89BwhQChzqZJ4w=="
}
}, headers=test_headers)
}, headers=test_headers))

self.assertNotIn("encryptedData", decrypted)
self.assertDictEqual({"data": {}}, decrypted)
Expand Down
24 changes: 24 additions & 0 deletions tests/test_encryption_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,30 @@ def test_load_hash_algorithm_underscore(self):
def test_load_hash_algorithm_none(self):
self.assertRaises(HashAlgorithmError, to_test.load_hash_algorithm, None)

def test_validate_hash_algorithm(self):
hash_algo = to_test.validate_hash_algorithm("SHA224")

self.assertEqual(hash_algo, "SHA224")

def test_validate_hash_algorithm_dash(self):
hash_algo = to_test.validate_hash_algorithm("SHA-512")

self.assertEqual(hash_algo, "SHA512")

def test_validate_hash_algorithm_lowercase(self):
hash_algo = to_test.validate_hash_algorithm("sha384")

self.assertEqual(hash_algo, "SHA384")

def test_validate_hash_algorithm_not_supported(self):
self.assertRaises(HashAlgorithmError, to_test.validate_hash_algorithm, "MD5")

def test_validate_hash_algorithm_underscore(self):
self.assertRaises(HashAlgorithmError, to_test.validate_hash_algorithm, "SHA_512")

def test_validate_hash_algorithm_none(self):
self.assertRaises(HashAlgorithmError, to_test.validate_hash_algorithm, None)

@staticmethod
def __strip_key(rsa_key):
return rsa_key.export_key(pkcs=8).decode('utf-8').replace("\n", "")[27:-25]
2 changes: 1 addition & 1 deletion tests/test_field_level_encryption_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def test_load_config_SHA512_oaep_padding_algorithm(self):
json_conf["oaepPaddingDigestAlgorithm"] = oaep_algo_test

conf = to_test.FieldLevelEncryptionConfig(json_conf)
self.__check_configuration(conf, oaep_algo=oaep_algo_test)
self.__check_configuration(conf, oaep_algo="SHA512")

def test_load_config_wrong_oaep_padding_algorithm(self):
oaep_algo_test = "sha_512"
Expand Down
31 changes: 19 additions & 12 deletions tests/utils/api_encryption_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,22 @@ def __init__(self, api_client=None):
self.api_client = api_client

def do_something_get(self, **kwargs):
return self.api_client.request("GET", "localhost/testservice", headers=kwargs["headers"])
return self.api_client.call_api("testservice", "GET", header_params=kwargs["headers"])

def do_something_post(self, **kwargs):
return self.api_client.request("POST", "localhost/testservice", headers=kwargs["headers"], body=kwargs["body"])
return self.api_client.call_api("testservice", "POST", header_params=kwargs["headers"], body=kwargs["body"])

def do_something_delete(self, **kwargs):
return self.api_client.request("DELETE", "localhost/testservice", headers=kwargs["headers"], body=kwargs["body"])
return self.api_client.call_api("testservice", "DELETE", header_params=kwargs["headers"], body=kwargs["body"])

def do_something_get_use_headers(self, **kwargs):
return self.api_client.request("GET", "localhost/testservice/headers", headers=kwargs["headers"])
return self.api_client.call_api("testservice/headers", "GET", header_params=kwargs["headers"])

def do_something_post_use_headers(self, **kwargs):
return self.api_client.request("POST", "localhost/testservice/headers", headers=kwargs["headers"], body=kwargs["body"])
return self.api_client.call_api("testservice/headers", "POST", header_params=kwargs["headers"], body=kwargs["body"])

def do_something_delete_use_headers(self, **kwargs):
return self.api_client.request("DELETE", "localhost/testservice/headers", headers=kwargs["headers"], body=kwargs["body"])
return self.api_client.call_api("testservice/headers", "DELETE", header_params=kwargs["headers"], body=kwargs["body"])


class MockApiClient(object):
Expand All @@ -56,13 +56,21 @@ def __init__(self, configuration=None, header_name=None, header_value=None,
def request(self, method, url, query_params=None, headers=None,
post_params=None, body=None, _preload_content=True,
_request_timeout=None):
pass

def call_api(self, resource_path, method,
path_params=None, query_params=None, header_params=None,
body=None, post_params=None, files=None,
response_type=None, auth_settings=None, async_req=None,
_return_http_data_only=None, collection_formats=None,
_preload_content=True, _request_timeout=None):
check = -1

if body:
if url == "localhost/testservice/headers":
iv = headers["x-iv"]
encrypted_key = headers["x-key"]
oaep_digest_algo = headers["x-oaep-digest"] if "x-oaep-digest" in headers else None
if resource_path == "testservice/headers":
iv = header_params["x-iv"]
encrypted_key = header_params["x-key"]
oaep_digest_algo = header_params["x-oaep-digest"] if "x-oaep-digest" in header_params else None

params = SessionKeyParams(self._config, encrypted_key, iv, oaep_digest_algo)
else:
Expand All @@ -74,7 +82,7 @@ def request(self, method, url, query_params=None, headers=None,
else:
res = {"data": {"secret": [53, 84, 75]}}

if url == "localhost/testservice/headers" and method in ["GET", "POST", "PUT"]:
if resource_path == "testservice/headers" and method in ["GET", "POST", "PUT"]:
params = SessionKeyParams.generate(self._config)
json_resp = encryption.encrypt_payload(res, self._config, params)

Expand All @@ -94,7 +102,6 @@ def request(self, method, url, query_params=None, headers=None,

if method in ["GET", "POST", "PUT"]:
response.data = json_resp
response.json = Mock(return_value=json_resp)
else:
response.data = "OK" if check == 0 else "KO"

Expand Down

0 comments on commit 4f780a7

Please sign in to comment.