diff --git a/client_encryption/api_encryption.py b/client_encryption/api_encryption.py index 8e507e6..0eb2707 100644 --- a/client_encryption/api_encryption.py +++ b/client_encryption/api_encryption.py @@ -18,25 +18,23 @@ def __init__(self, encryption_conf_file): self._encryption_conf = FieldLevelEncryptionConfig(json_file.read()) def field_encryption(self, func): - """Decorator for API request. func is APIClient.request""" + """Decorator for API call_api. func is APIClient.call_api""" @wraps(func) - def request_function(*args, **kwargs): - """Wrap request and add field encryption layer to it.""" + def call_api_function(*args, **kwargs): + """Wrap call_api and add field encryption layer to it.""" in_body = kwargs.get("body", None) - kwargs["body"] = self._encrypt_payload(kwargs.get("headers", None), in_body) if in_body else in_body + kwargs["body"] = self._encrypt_payload(kwargs.get("header_params", None), in_body) if in_body else in_body + kwargs["_preload_content"] = False response = func(*args, **kwargs) - - if type(response.data) is not str: - response_body = self._decrypt_payload(response.getheaders(), response.json()) - response._content = json.dumps(response_body, indent=4).encode('utf-8') + response._body = self._decrypt_payload(response.getheaders(), response.data) return response - request_function.__fle__ = True - return request_function + call_api_function.__fle__ = True + return call_api_function def _encrypt_payload(self, headers, body): """Encryption enforcement based on configuration - encrypt and add session key params to header or body""" @@ -62,6 +60,7 @@ def _decrypt_payload(self, headers, body): """Encryption enforcement based on configuration - decrypt using session key params from header or body""" conf = self._encryption_conf + params = None if conf.use_http_headers: if conf.iv_field_name in headers and conf.encrypted_key_field_name in headers: @@ -75,30 +74,28 @@ def _decrypt_payload(self, headers, body): del headers[conf.encryption_key_fingerprint_field_name] params = SessionKeyParams(conf, encrypted_key, iv, oaep_digest_algo) - payload = decrypt_payload(body, conf, params) else: - # skip decryption if not iv nor key is in headers - payload = body - else: - payload = decrypt_payload(body, conf) + # skip decryption and return original body if not iv nor key is in headers + return body + + decrypted_body = decrypt_payload(body, conf, params) + payload = json.dumps(decrypted_body).encode('utf-8') return payload def add_encryption_layer(api_client, encryption_conf_file): - """Decorate APIClient.request with field level encryption""" + """Decorate APIClient.call_api with field level encryption""" api_encryption = ApiEncryption(encryption_conf_file) - api_client.request = api_encryption.field_encryption(api_client.request) + api_client.call_api = api_encryption.field_encryption(api_client.call_api) __check_oauth(api_client) # warn the user if authentication layer is missing/not set def __check_oauth(api_client): try: - oauth_layer = getattr(api_client.request, "__wrapped__").__oauth__ - if not oauth_layer or type(oauth_layer) is not bool: - __oauth_warn() + api_client.request.__wrapped__ except AttributeError: __oauth_warn() diff --git a/client_encryption/encryption_utils.py b/client_encryption/encryption_utils.py index 0440ef3..b776591 100644 --- a/client_encryption/encryption_utils.py +++ b/client_encryption/encryption_utils.py @@ -54,15 +54,21 @@ def __get_crypto_file_type(file_content): return FILETYPE_ASN1 -def load_hash_algorithm(algo_str): - """Load a hash algorithm object of Crypto.Hash from a list of supported ones.""" +def validate_hash_algorithm(algo_str): + """Validate a hash algorithm against a list of supported ones.""" if algo_str: algo_key = algo_str.replace("-", "").upper() if algo_key in _SUPPORTED_HASH: - return _SUPPORTED_HASH[algo_key] + return algo_key else: raise HashAlgorithmError("Hash algorithm invalid or not supported.") else: raise HashAlgorithmError("No hash algorithm provided.") + + +def load_hash_algorithm(algo_str): + """Load a hash algorithm object of Crypto.Hash from a list of supported ones.""" + + return _SUPPORTED_HASH[validate_hash_algorithm(algo_str)] diff --git a/client_encryption/field_level_encryption.py b/client_encryption/field_level_encryption.py index 4b6a7de..197527d 100644 --- a/client_encryption/field_level_encryption.py +++ b/client_encryption/field_level_encryption.py @@ -12,10 +12,10 @@ def encrypt_payload(payload, config, _params=None): """Encrypt some fields of a JSON payload using the given configuration.""" try: - if type(payload) is str: - json_payload = json.loads(payload) - else: + if type(payload) is dict: json_payload = copy.deepcopy(payload) + else: + json_payload = json.loads(payload) for elem, target in config.paths["$"].to_encrypt.items(): if not _params: @@ -50,10 +50,13 @@ def decrypt_payload(payload, config, _params=None): """Decrypt some fields of a JSON payload using the given configuration.""" try: - if type(payload) is str: - json_payload = json.loads(payload) - else: + if type(payload) is dict: json_payload = payload + else: + try: + json_payload = json.loads(payload) + except json.JSONDecodeError: # not a json response - return it as is + return payload for elem, target in config.paths["$"].to_decrypt.items(): try: @@ -90,7 +93,7 @@ def decrypt_payload(payload, config, _params=None): return json_payload except (IOError, ValueError, TypeError) as e: - raise EncryptionError("Payload encryption failed!", e) + raise EncryptionError("Payload decryption failed!", e) def _encrypt_value(_key, iv, node_str): diff --git a/client_encryption/field_level_encryption_config.py b/client_encryption/field_level_encryption_config.py index 817215d..124bc91 100644 --- a/client_encryption/field_level_encryption_config.py +++ b/client_encryption/field_level_encryption_config.py @@ -2,7 +2,7 @@ from OpenSSL.crypto import dump_certificate, FILETYPE_ASN1, dump_publickey from Crypto.Hash import SHA256 from client_encryption.encoding_utils import Encoding -from client_encryption.encryption_utils import load_encryption_certificate, load_decryption_key, load_hash_algorithm +from client_encryption.encryption_utils import load_encryption_certificate, load_decryption_key, validate_hash_algorithm class FieldLevelEncryptionConfig(object): @@ -44,9 +44,7 @@ def __init__(self, conf): else: self._decryption_key = None - digest_algo = json_config["oaepPaddingDigestAlgorithm"] - if load_hash_algorithm(digest_algo) is not None: - self._oaep_padding_digest_algorithm = digest_algo + self._oaep_padding_digest_algorithm = validate_hash_algorithm(json_config["oaepPaddingDigestAlgorithm"]) data_enc = Encoding(json_config["dataEncoding"].upper()) self._data_encoding = data_enc diff --git a/client_encryption/session_key_params.py b/client_encryption/session_key_params.py index 156f396..b822b71 100644 --- a/client_encryption/session_key_params.py +++ b/client_encryption/session_key_params.py @@ -1,6 +1,6 @@ from binascii import Error from Crypto.Cipher import PKCS1_OAEP, AES -from Crypto import Random +from Crypto.Random import get_random_bytes from Crypto.PublicKey import RSA from client_encryption.encoding_utils import encode_bytes, decode_value from client_encryption.encryption_utils import load_hash_algorithm @@ -62,11 +62,11 @@ def generate(config): encoding = config.data_encoding # Generate a random IV - iv = Random.new().read(SessionKeyParams._BLOCK_SIZE) + iv = get_random_bytes(SessionKeyParams._BLOCK_SIZE) iv_encoded = encode_bytes(iv, encoding) # Generate an AES secret key - secret_key = Random.new().read(SessionKeyParams._KEY_SIZE) + secret_key = get_random_bytes(SessionKeyParams._KEY_SIZE) # Encrypt the secret key secret_key_encrypted = SessionKeyParams.__wrap_secret_key(secret_key, config) diff --git a/client_encryption/version.py b/client_encryption/version.py index da840a7..112cedb 100644 --- a/client_encryption/version.py +++ b/client_encryption/version.py @@ -1,3 +1,3 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -__version__ = "1.0.3" +__version__ = "1.1.0" diff --git a/tests/test_api_encryption.py b/tests/test_api_encryption.py index c82fee7..e7cac92 100644 --- a/tests/test_api_encryption.py +++ b/tests/test_api_encryption.py @@ -62,14 +62,14 @@ def test_decrypt_payload_with_params_in_body(self): test_headers = {"Content-Type": "application/json"} - decrypted = api_encryption._decrypt_payload(body={ + decrypted = json.loads(api_encryption._decrypt_payload(body={ "encryptedData": { "iv": "uldLBySPY3VrznePihFYGQ==", "encryptedKey": "Jmh/bQPScUVFHSC9qinMGZ4lM7uetzUXcuMdEpC5g4C0Pb9HuaM3zC7K/509n7RTBZUPEzgsWtgi7m33nhpXsUo8WMcQkBIZlKn3ce+WRyZpZxcYtVoPqNn3benhcv7cq7yH1ktamUiZ5Dq7Ga+oQCaQEsOXtbGNS6vA5Bwa1pjbmMiRIbvlstInz8XTw8h/T0yLBLUJ0yYZmzmt+9i8qL8KFQ/PPDe5cXOCr1Aq2NTSixe5F2K/EI00q6D7QMpBDC7K6zDWgAOvINzifZ0DTkxVe4EE6F+FneDrcJsj+ZeIabrlRcfxtiFziH6unnXktta0sB1xcszIxXdMDbUcJA==", "encryptedValue": "KGfmdUWy89BwhQChzqZJ4w==", "oaepHashingAlgo": "SHA256" } - }, headers=test_headers) + }, headers=test_headers)) self.assertNotIn("encryptedData", decrypted) self.assertDictEqual({"data": {}}, decrypted) @@ -112,11 +112,11 @@ def test_decrypt_payload_with_params_in_headers(self): } api_encryption = to_test.ApiEncryption(self._json_config) - decrypted = api_encryption._decrypt_payload(body={ + decrypted = json.loads(api_encryption._decrypt_payload(body={ "encryptedData": { "encryptedValue": "KGfmdUWy89BwhQChzqZJ4w==" } - }, headers=test_headers) + }, headers=test_headers)) self.assertNotIn("encryptedData", decrypted) self.assertDictEqual({"data": {}}, decrypted) diff --git a/tests/test_encryption_utils.py b/tests/test_encryption_utils.py index 58cdaed..bee6343 100644 --- a/tests/test_encryption_utils.py +++ b/tests/test_encryption_utils.py @@ -131,6 +131,30 @@ def test_load_hash_algorithm_underscore(self): def test_load_hash_algorithm_none(self): self.assertRaises(HashAlgorithmError, to_test.load_hash_algorithm, None) + def test_validate_hash_algorithm(self): + hash_algo = to_test.validate_hash_algorithm("SHA224") + + self.assertEqual(hash_algo, "SHA224") + + def test_validate_hash_algorithm_dash(self): + hash_algo = to_test.validate_hash_algorithm("SHA-512") + + self.assertEqual(hash_algo, "SHA512") + + def test_validate_hash_algorithm_lowercase(self): + hash_algo = to_test.validate_hash_algorithm("sha384") + + self.assertEqual(hash_algo, "SHA384") + + def test_validate_hash_algorithm_not_supported(self): + self.assertRaises(HashAlgorithmError, to_test.validate_hash_algorithm, "MD5") + + def test_validate_hash_algorithm_underscore(self): + self.assertRaises(HashAlgorithmError, to_test.validate_hash_algorithm, "SHA_512") + + def test_validate_hash_algorithm_none(self): + self.assertRaises(HashAlgorithmError, to_test.validate_hash_algorithm, None) + @staticmethod def __strip_key(rsa_key): return rsa_key.export_key(pkcs=8).decode('utf-8').replace("\n", "")[27:-25] diff --git a/tests/test_field_level_encryption_config.py b/tests/test_field_level_encryption_config.py index c3d195b..7e9d964 100644 --- a/tests/test_field_level_encryption_config.py +++ b/tests/test_field_level_encryption_config.py @@ -151,7 +151,7 @@ def test_load_config_SHA512_oaep_padding_algorithm(self): json_conf["oaepPaddingDigestAlgorithm"] = oaep_algo_test conf = to_test.FieldLevelEncryptionConfig(json_conf) - self.__check_configuration(conf, oaep_algo=oaep_algo_test) + self.__check_configuration(conf, oaep_algo="SHA512") def test_load_config_wrong_oaep_padding_algorithm(self): oaep_algo_test = "sha_512" diff --git a/tests/utils/api_encryption_test_utils.py b/tests/utils/api_encryption_test_utils.py index 83cd4fd..42c6807 100644 --- a/tests/utils/api_encryption_test_utils.py +++ b/tests/utils/api_encryption_test_utils.py @@ -25,22 +25,22 @@ def __init__(self, api_client=None): self.api_client = api_client def do_something_get(self, **kwargs): - return self.api_client.request("GET", "localhost/testservice", headers=kwargs["headers"]) + return self.api_client.call_api("testservice", "GET", header_params=kwargs["headers"]) def do_something_post(self, **kwargs): - return self.api_client.request("POST", "localhost/testservice", headers=kwargs["headers"], body=kwargs["body"]) + return self.api_client.call_api("testservice", "POST", header_params=kwargs["headers"], body=kwargs["body"]) def do_something_delete(self, **kwargs): - return self.api_client.request("DELETE", "localhost/testservice", headers=kwargs["headers"], body=kwargs["body"]) + return self.api_client.call_api("testservice", "DELETE", header_params=kwargs["headers"], body=kwargs["body"]) def do_something_get_use_headers(self, **kwargs): - return self.api_client.request("GET", "localhost/testservice/headers", headers=kwargs["headers"]) + return self.api_client.call_api("testservice/headers", "GET", header_params=kwargs["headers"]) def do_something_post_use_headers(self, **kwargs): - return self.api_client.request("POST", "localhost/testservice/headers", headers=kwargs["headers"], body=kwargs["body"]) + return self.api_client.call_api("testservice/headers", "POST", header_params=kwargs["headers"], body=kwargs["body"]) def do_something_delete_use_headers(self, **kwargs): - return self.api_client.request("DELETE", "localhost/testservice/headers", headers=kwargs["headers"], body=kwargs["body"]) + return self.api_client.call_api("testservice/headers", "DELETE", header_params=kwargs["headers"], body=kwargs["body"]) class MockApiClient(object): @@ -56,13 +56,21 @@ def __init__(self, configuration=None, header_name=None, header_value=None, def request(self, method, url, query_params=None, headers=None, post_params=None, body=None, _preload_content=True, _request_timeout=None): + pass + + def call_api(self, resource_path, method, + path_params=None, query_params=None, header_params=None, + body=None, post_params=None, files=None, + response_type=None, auth_settings=None, async_req=None, + _return_http_data_only=None, collection_formats=None, + _preload_content=True, _request_timeout=None): check = -1 if body: - if url == "localhost/testservice/headers": - iv = headers["x-iv"] - encrypted_key = headers["x-key"] - oaep_digest_algo = headers["x-oaep-digest"] if "x-oaep-digest" in headers else None + if resource_path == "testservice/headers": + iv = header_params["x-iv"] + encrypted_key = header_params["x-key"] + oaep_digest_algo = header_params["x-oaep-digest"] if "x-oaep-digest" in header_params else None params = SessionKeyParams(self._config, encrypted_key, iv, oaep_digest_algo) else: @@ -74,7 +82,7 @@ def request(self, method, url, query_params=None, headers=None, else: res = {"data": {"secret": [53, 84, 75]}} - if url == "localhost/testservice/headers" and method in ["GET", "POST", "PUT"]: + if resource_path == "testservice/headers" and method in ["GET", "POST", "PUT"]: params = SessionKeyParams.generate(self._config) json_resp = encryption.encrypt_payload(res, self._config, params) @@ -94,7 +102,6 @@ def request(self, method, url, query_params=None, headers=None, if method in ["GET", "POST", "PUT"]: response.data = json_resp - response.json = Mock(return_value=json_resp) else: response.data = "OK" if check == 0 else "KO"