From a574cad7d838ad8c224ddf184c707b55e296fad8 Mon Sep 17 00:00:00 2001 From: Shige Takeda Date: Mon, 30 Mar 2020 16:29:49 -0700 Subject: [PATCH] Prep 2.2.3 (#294) * updated Python Connector reqs * SNOW-102876 secure sso python copy * SNOW-141822 bumped pandas to newest versions * SNOW-141822 bumped pandas to newest versions * SNOW-141932 build manylinux1 wheels * SNOW-118103 fix unclosed file issue * SNOW-144663 added missing test directories to tox commans * SNOW-145906 update python docs * SNOW-143923 tox housekeeping * SNOW-146266 updated Python test * SNOW-146266 fix import ordering * SNOW-145814 wrongly default keyring package * SNOW-146213 Add google storage api url to whitelist for ocsp validation * SNOW-67159 update column size python connector * SNOW-145814 fix mac sso unit test with mock * SNOW-83085 use_openssl_only mode for Python connector * SNOW-147687 in band telemetry update python * SNOW-144043: Add new optional config client_store_temporary_credential into SnowSQL and made it the same in python connector * SNOW-144043 fix lint error * SNOW-148015 Added type checking workaround for Python 3.5.1 * SNOW-121925 Adding a test to verify that the Python connector supports dashed URLs * Revert SNOW-121925 Adding a test to verify that the Python connector supports dashed URLs * python connector version bump to 2.2.3 * skip new sso tests on Travis * reenabled pandas tests --- DESCRIPTION.rst | 9 + auth.py | 204 +++--- auth_okta.py | 8 +- auth_webbrowser.py | 8 +- compat.py | 1 + connection.py | 31 +- cursor.py | 8 +- encryption_util.py | 74 ++- file_util.py | 20 +- ocsp_asn1crypto.py | 88 ++- ocsp_pyasn1.py | 596 ------------------ ocsp_snowflake.py | 1 + options.py | 13 + scripts/build_linux.sh | 1 - scripts/build_pyarrow_linux.sh | 1 - scripts/install.sh | 2 +- scripts/run_travis.sh | 2 +- scripts/test.bat | 2 +- scripts/test.sh | 4 +- scripts/test_darwin.sh | 2 +- setup.py | 4 +- test/{ => sso}/test_connection_manual.py | 17 +- test/sso/test_unit_sso_connection.py | 142 +++++ test/test_connection.py | 30 + test/test_dbapi.py | 4 +- test/test_unit_auth.py | 1 + test/test_unit_auth_okta.py | 1 + test/test_unit_auth_webbrowser.py | 1 + test/test_unit_connection.py | 118 +--- tested_requirements/requirements_35.txt | 2 +- tested_requirements/requirements_36.txt | 2 +- tested_requirements/requirements_37.txt | 2 +- ...86_64.whl.reqs.txt => requirements_38.txt} | 2 +- tox.ini | 26 +- version.py | 2 +- 35 files changed, 550 insertions(+), 879 deletions(-) delete mode 100644 ocsp_pyasn1.py rename test/{ => sso}/test_connection_manual.py (91%) create mode 100644 test/sso/test_unit_sso_connection.py rename tested_requirements/{requirements_38-linux_x86_64.whl.reqs.txt => requirements_38.txt} (95%) diff --git a/DESCRIPTION.rst b/DESCRIPTION.rst index 552742426..8424c4d0b 100644 --- a/DESCRIPTION.rst +++ b/DESCRIPTION.rst @@ -9,6 +9,15 @@ Source code is also available at: https://github.com/snowflakedb/snowflake-conne Release Notes ------------------------------------------------------------------------------- +- v2.2.3(March 30,2020) + + - Secure SSO ID Token + - Add use_openssl_only connection parameter, which disables the usage of pure Python cryptographic libraries for FIPS + - Add manylinux1 as well as manylinux2010 + - Fix a bug where a certificate file was opened and never closed in snowflake-connector-python. + - Fix python connector skips validating GCP URLs + - Adds additional client driver config information to in band telemetry. + - v2.2.2(March 9,2020) - Fix retry with chunck_downloader.py for stability. diff --git a/auth.py b/auth.py index 8b8143509..eb916b91b 100644 --- a/auth.py +++ b/auth.py @@ -8,7 +8,6 @@ import copy import json import logging -import platform import tempfile import time import uuid @@ -18,18 +17,24 @@ from threading import Lock, Thread from .auth_keypair import AuthByKeyPair -from .compat import IS_LINUX, TO_UNICODE, urlencode +from .compat import IS_LINUX, IS_WINDOWS, IS_MACOS, TO_UNICODE, urlencode from .constants import ( HTTP_HEADER_ACCEPT, HTTP_HEADER_CONTENT_TYPE, HTTP_HEADER_SERVICE_NAME, HTTP_HEADER_USER_AGENT, PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL, - PARAMETER_CLIENT_USE_SECURE_STORAGE_FOR_TEMPORARY_CREDENTIAL, ) from .description import COMPILER, IMPLEMENTATION, OPERATING_SYSTEM, PLATFORM, PYTHON_VERSION from .errorcode import ER_FAILED_TO_CONNECT_TO_DB -from .errors import BadGatewayError, DatabaseError, Error, ForbiddenError, ServiceUnavailableError +from .errors import ( + BadGatewayError, + DatabaseError, + Error, + ForbiddenError, + ServiceUnavailableError, + MissingDependencyError, +) from .network import ( ACCEPT_TYPE_APPLICATION_SNOWFLAKE, CONTENT_TYPE_APPLICATION_JSON, @@ -41,13 +46,19 @@ logger = logging.getLogger(__name__) +try: + import keyring +except ImportError as ie: + keyring = None + logger.debug('Failed to import keyring module. err=[%s]', ie) + # Cache directory CACHE_ROOT_DIR = getenv('SF_TEMPORARY_CREDENTIAL_CACHE_DIR') or \ expanduser("~") or tempfile.gettempdir() -if platform.system() == 'Windows': +if IS_WINDOWS: CACHE_DIR = path.join(CACHE_ROOT_DIR, 'AppData', 'Local', 'Snowflake', 'Caches') -elif platform.system() == 'Darwin': +elif IS_MACOS: CACHE_DIR = path.join(CACHE_ROOT_DIR, 'Library', 'Caches', 'Snowflake') else: CACHE_DIR = path.join(CACHE_ROOT_DIR, '.cache', 'snowflake') @@ -77,6 +88,7 @@ # keyring KEYRING_SERVICE_NAME = "net.snowflake.temporary_token" KEYRING_USER = "temp_token" +KEYRING_DRIVER_NAME = "SNOWFLAKE-PYTHON-DRIVER" class Auth(object): @@ -91,22 +103,28 @@ def __init__(self, rest): def base_auth_data(user, account, application, internal_application_name, internal_application_version, - ocsp_mode): + ocsp_mode, login_timeout, + network_timeout=None, + store_temp_cred=None): return { - u'data': { - u"CLIENT_APP_ID": internal_application_name, - u"CLIENT_APP_VERSION": internal_application_version, - u"SVN_REVISION": VERSION[3], - u"ACCOUNT_NAME": account, - u"LOGIN_NAME": user, - u"CLIENT_ENVIRONMENT": { - u"APPLICATION": application, - u"OS": OPERATING_SYSTEM, - u"OS_VERSION": PLATFORM, - u"PYTHON_VERSION": PYTHON_VERSION, - u"PYTHON_RUNTIME": IMPLEMENTATION, - u"PYTHON_COMPILER": COMPILER, - u"OCSP_MODE": ocsp_mode.name, + 'data': { + "CLIENT_APP_ID": internal_application_name, + "CLIENT_APP_VERSION": internal_application_version, + "SVN_REVISION": VERSION[3], + "ACCOUNT_NAME": account, + "LOGIN_NAME": user, + "CLIENT_ENVIRONMENT": { + "APPLICATION": application, + "OS": OPERATING_SYSTEM, + "OS_VERSION": PLATFORM, + "PYTHON_VERSION": PYTHON_VERSION, + "PYTHON_RUNTIME": IMPLEMENTATION, + "PYTHON_COMPILER": COMPILER, + "OCSP_MODE": ocsp_mode.name, + "TRACING": logger.getEffectiveLevel(), + "LOGIN_TIMEOUT": login_timeout, + "NETWORK_TIMEOUT": network_timeout, + "CLIENT_STORE_TEMPORARY_CREDENTIAL": store_temp_cred, } }, } @@ -132,11 +150,22 @@ def authenticate(self, auth_instance, account, user, headers[HTTP_HEADER_SERVICE_NAME] = \ session_parameters[HTTP_HEADER_SERVICE_NAME] url = u"/session/v1/login-request" + if session_parameters is not None \ + and PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL in session_parameters: + store_temp_cred = session_parameters[ + PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL] + else: + store_temp_cred = None + body_template = Auth.base_auth_data( user, account, self._rest._connection.application, self._rest._connection._internal_application_name, self._rest._connection._internal_application_version, - self._rest._connection._ocsp_mode()) + self._rest._connection._ocsp_mode(), + self._rest._connection._login_timeout, + self._rest._connection._network_timeout, + store_temp_cred, + ) body = copy.deepcopy(body_template) # updating request body @@ -317,10 +346,10 @@ def post_request_wrapper(self, url, headers, body): id_token=ret[u'data'].get(u'idToken') ) if self._rest._connection.consent_cache_id_token: - write_temporary_credential_file( - account, user, self._rest.id_token, + write_temporary_credential( + self._rest._host, account, user, self._rest.id_token, session_parameters.get( - PARAMETER_CLIENT_USE_SECURE_STORAGE_FOR_TEMPORARY_CREDENTIAL)) + PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL)) if u'sessionId' in ret[u'data']: self._rest._connection._session_id = ret[u'data'][u'sessionId'] if u'sessionInfo' in ret[u'data']: @@ -333,17 +362,26 @@ def post_request_wrapper(self, url, headers, body): return session_parameters - def read_temporary_credential(self, account, user, session_parameters): - if session_parameters.get(PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL): - read_temporary_credential_file( - session_parameters.get( - PARAMETER_CLIENT_USE_SECURE_STORAGE_FOR_TEMPORARY_CREDENTIAL) - ) - id_token = TEMPORARY_CREDENTIAL.get( - account.upper(), {}).get(user.upper()) + def read_temporary_credential(self, host, account, user, session_parameters): + if session_parameters.get(PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL, False): + id_token = None + if IS_MACOS or IS_WINDOWS: + if not keyring: + # we will leave the exception for write_temporary_credential function to raise + return False + new_target = convert_target(host, user) + try: + id_token = keyring.get_password(new_target, user.upper()) + except keyring.errors.KeyringError as ke: + logger.debug("Could not retrieve id_token from secure storage : {}".format(str(ke))) + elif IS_LINUX: + read_temporary_credential_file() + id_token = TEMPORARY_CREDENTIAL.get( + account.upper(), {}).get(user.upper()) + else: + logger.debug("connection parameter enable_sso_temporary_credential not set or OS not support") if id_token: self._rest.id_token = id_token - if self._rest.id_token: try: self._rest._id_token_session() return True @@ -354,11 +392,31 @@ def read_temporary_credential(self, account, user, session_parameters): return False -def write_temporary_credential_file( - account, user, id_token, - use_secure_storage_for_temporary_credential=False): - if not CACHE_DIR or not id_token: - # no cache is enabled or no id_token is given +def write_temporary_credential(host, account, user, id_token, store_temporary_credential=False): + if not id_token: + logger.debug("no ID token is given when try to store temporary credential") + return + if IS_MACOS or IS_WINDOWS: + if not keyring: + raise MissingDependencyError("Please install keyring module to enable SSO token cache feature.") + + new_target = convert_target(host, user) + try: + keyring.set_password(new_target, user.upper(), id_token) + except keyring.errors.KeyringError as ke: + logger.debug("Could not store id_token to keyring, %s", str(ke)) + elif IS_LINUX and store_temporary_credential: + write_temporary_credential_file(host, account, user, id_token) + else: + logger.debug("connection parameter client_store_temporary_credential not set or OS not support") + + +def write_temporary_credential_file(host, account, user, id_token): + """ + Write temporary credential file when OS is Linux + """ + if not CACHE_DIR: + # no cache is enabled return global TEMPORARY_CREDENTIAL global TEMPORARY_CREDENTIAL_LOCK @@ -377,16 +435,9 @@ def write_temporary_credential_file( "write the temporary credential file: %s", TEMPORARY_CREDENTIAL_FILE) try: - if IS_LINUX or not use_secure_storage_for_temporary_credential: - with codecs.open(TEMPORARY_CREDENTIAL_FILE, 'w', - encoding='utf-8', errors='ignore') as f: - json.dump(TEMPORARY_CREDENTIAL, f) - else: - import keyring - keyring.set_password( - KEYRING_SERVICE_NAME, KEYRING_USER, - json.dumps(TEMPORARY_CREDENTIAL)) - + with codecs.open(TEMPORARY_CREDENTIAL_FILE, 'w', + encoding='utf-8', errors='ignore') as f: + json.dump(TEMPORARY_CREDENTIAL, f) except Exception as ex: logger.debug("Failed to write a credential file: " "file=[%s], err=[%s]", TEMPORARY_CREDENTIAL_FILE, ex) @@ -394,10 +445,9 @@ def write_temporary_credential_file( unlock_temporary_credential_file() -def read_temporary_credential_file( - use_secure_storage_for_temporary_credential=False): +def read_temporary_credential_file(): """ - Read temporary credential file + Read temporary credential file when OS is Linux """ if not CACHE_DIR: # no cache is enabled @@ -416,15 +466,9 @@ def read_temporary_credential_file( "write the temporary credential file: %s", TEMPORARY_CREDENTIAL_FILE) try: - if IS_LINUX or not use_secure_storage_for_temporary_credential: - with codecs.open(TEMPORARY_CREDENTIAL_FILE, 'r', - encoding='utf-8', errors='ignore') as f: - TEMPORARY_CREDENTIAL = json.load(f) - else: - import keyring - f = keyring.get_password( - KEYRING_SERVICE_NAME, KEYRING_USER) or "{}" - TEMPORARY_CREDENTIAL = json.loads(f) + with codecs.open(TEMPORARY_CREDENTIAL_FILE, 'r', + encoding='utf-8', errors='ignore') as f: + TEMPORARY_CREDENTIAL = json.load(f) return TEMPORARY_CREDENTIAL except Exception as ex: logger.debug("Failed to read a credential file. The file may not" @@ -456,26 +500,34 @@ def unlock_temporary_credential_file(): return False -def delete_temporary_credential_file( - use_secure_storage_for_temporary_credential=False): - """ - Delete temporary credential file and its lock file - """ - global TEMPORARY_CREDENTIAL_FILE - if IS_LINUX or not use_secure_storage_for_temporary_credential: +def delete_temporary_credential(host, user, store_temporary_credential=False): + if (IS_MACOS or IS_WINDOWS) and keyring: + new_target = convert_target(host, user) try: - remove(TEMPORARY_CREDENTIAL_FILE) - except Exception as ex: - logger.debug("Failed to delete a credential file: " - "file=[%s], err=[%s]", TEMPORARY_CREDENTIAL_FILE, ex) - else: - try: - import keyring - keyring.delete_password(KEYRING_SERVICE_NAME, KEYRING_USER) + keyring.delete_password(new_target, user.upper()) except Exception as ex: logger.debug("Failed to delete credential in the keyring: err=[%s]", ex) + elif IS_LINUX and store_temporary_credential: + delete_temporary_credential_file() + + +def delete_temporary_credential_file(): + """ + Delete temporary credential file and its lock file + """ + global TEMPORARY_CREDENTIAL_FILE + try: + remove(TEMPORARY_CREDENTIAL_FILE) + except Exception as ex: + logger.debug("Failed to delete a credential file: " + "file=[%s], err=[%s]", TEMPORARY_CREDENTIAL_FILE, ex) try: removedirs(TEMPORARY_CREDENTIAL_FILE_LOCK) except Exception as ex: logger.debug("Failed to delete credential lock file: err=[%s]", ex) + + +def convert_target(host, user): + return "{host}:{user}:{driver}".format( + host=host.upper(), user=user.upper(), driver=KEYRING_DRIVER_NAME) diff --git a/auth_okta.py b/auth_okta.py index 41d70780c..ee854c54c 100644 --- a/auth_okta.py +++ b/auth_okta.py @@ -9,7 +9,8 @@ from .auth import Auth from .auth_by_plugin import AuthByPlugin from .compat import unescape, urlencode, urlsplit -from .constants import HTTP_HEADER_ACCEPT, HTTP_HEADER_CONTENT_TYPE, HTTP_HEADER_SERVICE_NAME, HTTP_HEADER_USER_AGENT +from .constants import HTTP_HEADER_ACCEPT, HTTP_HEADER_CONTENT_TYPE, \ + HTTP_HEADER_SERVICE_NAME, HTTP_HEADER_USER_AGENT from .errorcode import ER_IDP_CONNECTION_ERROR, ER_INCORRECT_DESTINATION from .errors import DatabaseError, Error from .network import CONTENT_TYPE_APPLICATION_JSON, PYTHON_CONNECTOR_USER_AGENT @@ -121,7 +122,10 @@ def _step1(self, authenticator, service_name, account, user): self._rest._connection.application, self._rest._connection._internal_application_name, self._rest._connection._internal_application_version, - self._rest._connection._ocsp_mode()) + self._rest._connection._ocsp_mode(), + self._rest._connection._login_timeout, + self._rest._connection._network_timeout, + ) body[u"data"][u"AUTHENTICATOR"] = authenticator logger.debug( diff --git a/auth_webbrowser.py b/auth_webbrowser.py index ecf99913d..3ed3a4a48 100644 --- a/auth_webbrowser.py +++ b/auth_webbrowser.py @@ -12,7 +12,8 @@ from .auth import Auth from .auth_by_plugin import AuthByPlugin from .compat import parse_qs, urlparse, urlsplit -from .constants import HTTP_HEADER_ACCEPT, HTTP_HEADER_CONTENT_TYPE, HTTP_HEADER_SERVICE_NAME, HTTP_HEADER_USER_AGENT +from .constants import HTTP_HEADER_ACCEPT, HTTP_HEADER_CONTENT_TYPE, HTTP_HEADER_SERVICE_NAME, \ + HTTP_HEADER_USER_AGENT from .errorcode import ER_IDP_CONNECTION_ERROR, ER_NO_HOSTNAME_FOUND, ER_UNABLE_TO_OPEN_BROWSER from .errors import OperationalError from .network import CONTENT_TYPE_APPLICATION_JSON, EXTERNAL_BROWSER_AUTHENTICATOR, PYTHON_CONNECTOR_USER_AGENT @@ -291,7 +292,10 @@ def _get_sso_url( self._rest._connection.application, self._rest._connection._internal_application_name, self._rest._connection._internal_application_version, - self._rest._connection._ocsp_mode()) + self._rest._connection._ocsp_mode(), + self._rest._connection._login_timeout, + self._rest._connection._network_timeout, + ) body[u'data'][u'AUTHENTICATOR'] = authenticator body[u'data'][u"BROWSER_MODE_REDIRECT_PORT"] = str(callback_port) diff --git a/compat.py b/compat.py index a7b72bc5d..5a899584a 100644 --- a/compat.py +++ b/compat.py @@ -17,6 +17,7 @@ IS_LINUX = platform.system() == 'Linux' IS_WINDOWS = platform.system() == 'Windows' +IS_MACOS = platform.system() == 'Darwin' NUM_DATA_TYPES = [] try: diff --git a/connection.py b/connection.py index dfe44a011..5d254e5db 100644 --- a/connection.py +++ b/connection.py @@ -13,7 +13,6 @@ from threading import Lock from time import strptime -from .incident import IncidentAPI from . import errors from . import proxy from .auth import Auth @@ -27,7 +26,7 @@ DEFAULT_CLIENT_PREFETCH_THREADS, MAX_CLIENT_PREFETCH_THREADS) from .compat import ( - TO_UNICODE, urlencode, PY_ISSUE_23517, IS_WINDOWS) + TO_UNICODE, urlencode, PY_ISSUE_23517, IS_LINUX, IS_WINDOWS) from .constants import ( PARAMETER_AUTOCOMMIT, PARAMETER_CLIENT_SESSION_KEEP_ALIVE_HEARTBEAT_FREQUENCY, @@ -56,6 +55,7 @@ ER_NOT_IMPLICITY_SNOWFLAKE_DATATYPE, ER_NO_NUMPY) from .errors import Error, ProgrammingError, DatabaseError +from .incident import IncidentAPI from .network import ( DEFAULT_AUTHENTICATOR, EXTERNAL_BROWSER_AUTHENTICATOR, @@ -142,6 +142,8 @@ def DefaultConverterClass(): u'support_negative_year': True, # snowflake u'log_max_query_length': LOG_MAX_QUERY_LENGTH, # snowflake u'disable_request_pooling': False, # snowflake + u'client_store_temporary_credential': False, # enable temporary credential file for Linux, default false. Mac/Win will overlook this + 'use_openssl_only': False, # only use openssl instead of python only crypto modules } APPLICATION_RE = re.compile(r'[\w\d_]+') @@ -467,6 +469,13 @@ def log_max_query_length(self): def disable_request_pooling(self): return self._disable_request_pooling + @property + def use_openssl_only(self): + """ + Use OpenSSL only instead of PYthon libraries for signature verification and encryption purposes + """ + return self._use_openssl_only + @disable_request_pooling.setter def disable_request_pooling(self, value): self._disable_request_pooling = True if value else False @@ -691,11 +700,13 @@ def __open_connection(self): if self._authenticator == EXTERNAL_BROWSER_AUTHENTICATOR: # enable storing temporary credential in a file self._session_parameters[ - PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL] = True + PARAMETER_CLIENT_STORE_TEMPORARY_CREDENTIAL] = \ + self._client_store_temporary_credential if IS_LINUX else True auth = Auth(self.rest) if not auth.read_temporary_credential( - self.account, self.user, self._session_parameters): + self.host, self.account, self.user, + self._session_parameters): self.__authenticate(auth_instance) else: # set the current objects as the session is derived from the id @@ -820,6 +831,18 @@ def __config(self, **kwargs): u'CERTIFICATE REVOCATION STATUS WILL NOT BE ' u'CHECKED.') + if 'USE_OPENSSL_ONLY' not in os.environ: + logger.info( + 'Setting use_openssl_only mode to %s', self.use_openssl_only + ) + os.environ['USE_OPENSSL_ONLY'] = str(self.use_openssl_only) + else: + logger.warning( + 'Mode use_openssl_only is already set to: %s, ignoring set request to: %s', + os.environ['USE_OPENSSL_ONLY'], + self.use_openssl_only + ) + def cmd_query(self, sql, sequence_counter, request_id, binding_params=None, is_file_transfer=False, statement_params=None, diff --git a/cursor.py b/cursor.py index fbcc8d813..7f394625f 100644 --- a/cursor.py +++ b/cursor.py @@ -11,6 +11,9 @@ from logging import getLogger from threading import (Timer, Lock) +MYPY = False +if MYPY: # from typing import TYPE_CHECKING once 3.5 is deprecated + from .connection import SnowflakeConnection from .compat import (BASE_EXCEPTION_CLASS) from .constants import ( FIELD_NAME_TO_ID, @@ -81,7 +84,10 @@ class SnowflakeCursor(object): r'alter\s+session\s+set\s+(.*)=\'?([^\']+)\'?\s*;', flags=re.IGNORECASE | re.MULTILINE | re.DOTALL) - def __init__(self, connection, use_dict_result=False, json_result_class=JsonResult): + def __init__(self, + connection: 'SnowflakeConnection', + use_dict_result: bool = False, + json_result_class: object = JsonResult): """ :param connection: connection created this cursor :param use_dict_result: whether use dict result or not. This variable only applied to diff --git a/encryption_util.py b/encryption_util.py index a2499a484..07a064cc5 100644 --- a/encryption_util.py +++ b/encryption_util.py @@ -13,10 +13,14 @@ from logging import getLogger from Cryptodome.Cipher import AES +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from .compat import PKCS5_OFFSET, PKCS5_PAD, PKCS5_UNPAD, TO_UNICODE from .constants import UTF8 +block_size = int(algorithms.AES.block_size / 8) # in bytes + def matdesc_to_unicode(matdesc): """ @@ -61,10 +65,9 @@ def get_secure_random(byte_length): @staticmethod def encrypt_file(encryption_material, in_filename, - chunk_size=AES.block_size * 4 * 1024, tmp_dir=None): + chunk_size=block_size * 4 * 1024, tmp_dir=None): """ Encrypts a file - :param s3_metadata: S3 metadata output :param encryption_material: encryption material :param in_filename: input file name :param chunk_size: read chunk size @@ -72,15 +75,21 @@ def encrypt_file(encryption_material, in_filename, :return: a encrypted file """ logger = getLogger(__name__) + use_openssl_only = os.getenv('USE_OPENSSL_ONLY', 'False') == 'True' decoded_key = base64.standard_b64decode( encryption_material.query_stage_master_key) key_size = len(decoded_key) logger.debug(u'key_size = %s', key_size) # Generate key for data encryption - iv_data = SnowflakeEncryptionUtil.get_secure_random(AES.block_size) + iv_data = SnowflakeEncryptionUtil.get_secure_random(block_size) file_key = SnowflakeEncryptionUtil.get_secure_random(key_size) - data_cipher = AES.new(key=file_key, mode=AES.MODE_CBC, IV=iv_data) + if not use_openssl_only: + data_cipher = AES.new(key=file_key, mode=AES.MODE_CBC, IV=iv_data) + else: + backend = default_backend() + cipher = Cipher(algorithms.AES(file_key), modes.CBC(iv_data), backend=backend) + encryptor = cipher.encryptor() temp_output_fd, temp_output_file = tempfile.mkstemp( text=False, dir=tmp_dir, @@ -94,17 +103,31 @@ def encrypt_file(encryption_material, in_filename, chunk = infile.read(chunk_size) if len(chunk) == 0: break - elif len(chunk) % AES.block_size != 0: - chunk = PKCS5_PAD(chunk, AES.block_size) + elif len(chunk) % block_size != 0: + chunk = PKCS5_PAD(chunk, block_size) padded = True - outfile.write(data_cipher.encrypt(chunk)) + if not use_openssl_only: + outfile.write(data_cipher.encrypt(chunk)) + else: + outfile.write(encryptor.update(chunk)) if not padded: - outfile.write(data_cipher.encrypt( - AES.block_size * chr(AES.block_size).encode(UTF8))) + if not use_openssl_only: + outfile.write(data_cipher.encrypt( + block_size * chr(block_size).encode(UTF8))) + else: + outfile.write(encryptor.update( + block_size * chr(block_size).encode(UTF8))) + if use_openssl_only: + outfile.write(encryptor.finalize()) # encrypt key with QRMK - key_cipher = AES.new(key=decoded_key, mode=AES.MODE_ECB) - enc_kek = key_cipher.encrypt(PKCS5_PAD(file_key, AES.block_size)) + if not use_openssl_only: + key_cipher = AES.new(key=decoded_key, mode=AES.MODE_ECB) + enc_kek = key_cipher.encrypt(PKCS5_PAD(file_key, block_size)) + else: + cipher = Cipher(algorithms.AES(decoded_key), modes.ECB(), backend=backend) + encryptor = cipher.encryptor() + enc_kek = encryptor.update(PKCS5_PAD(file_key, block_size)) + encryptor.finalize() mat_desc = MaterialDescriptor( smk_id=encryption_material.smk_id, @@ -115,11 +138,11 @@ def encrypt_file(encryption_material, in_filename, iv=base64.b64encode(iv_data).decode('utf-8'), matdesc=matdesc_to_unicode(mat_desc), ) - return (metadata, temp_output_file) + return metadata, temp_output_file @staticmethod def decrypt_file(metadata, encryption_material, in_filename, - chunk_size=AES.block_size * 4 * 1024, tmp_dir=None): + chunk_size=block_size * 4 * 1024, tmp_dir=None): """ Decrypts a file and stores the output in the temporary directory :param metadata: metadata input @@ -130,6 +153,7 @@ def decrypt_file(metadata, encryption_material, in_filename, :return: a decrypted file name """ logger = getLogger(__name__) + use_openssl_only = os.getenv('USE_OPENSSL_ONLY', 'False') == 'True' key_base64 = metadata.key iv_base64 = metadata.iv decoded_key = base64.standard_b64decode( @@ -137,10 +161,17 @@ def decrypt_file(metadata, encryption_material, in_filename, key_bytes = base64.standard_b64decode(key_base64) iv_bytes = base64.standard_b64decode(iv_base64) - key_cipher = AES.new(key=decoded_key, mode=AES.MODE_ECB) - file_key = PKCS5_UNPAD(key_cipher.decrypt(key_bytes)) - - data_cipher = AES.new(key=file_key, mode=AES.MODE_CBC, IV=iv_bytes) + if not use_openssl_only: + key_cipher = AES.new(key=decoded_key, mode=AES.MODE_ECB) + file_key = PKCS5_UNPAD(key_cipher.decrypt(key_bytes)) + data_cipher = AES.new(key=file_key, mode=AES.MODE_CBC, IV=iv_bytes) + else: + backend = default_backend() + cipher = Cipher(algorithms.AES(decoded_key), modes.ECB(), backend=backend) + decryptor = cipher.decryptor() + file_key = PKCS5_UNPAD(decryptor.update(key_bytes) + decryptor.finalize()) + cipher = Cipher(algorithms.AES(file_key), modes.CBC(iv_bytes), backend=backend) + decryptor = cipher.decryptor() temp_output_fd, temp_output_file = tempfile.mkstemp( text=False, dir=tmp_dir, @@ -148,7 +179,7 @@ def decrypt_file(metadata, encryption_material, in_filename, total_file_size = 0 prev_chunk = None logger.debug(u'encrypted file: %s, tmp file: %s', - in_filename, temp_output_file) + in_filename, temp_output_file) with open(in_filename, u'rb') as infile: with os.fdopen(temp_output_fd, u'wb') as outfile: while True: @@ -156,10 +187,15 @@ def decrypt_file(metadata, encryption_material, in_filename, if len(chunk) == 0: break total_file_size += len(chunk) - d = data_cipher.decrypt(chunk) + if not use_openssl_only: + d = data_cipher.decrypt(chunk) + else: + d = decryptor.update(chunk) outfile.write(d) prev_chunk = d if prev_chunk is not None: total_file_size -= PKCS5_OFFSET(prev_chunk) + if use_openssl_only: + outfile.write(decryptor.finalize()) outfile.truncate(total_file_size) return temp_output_file diff --git a/file_util.py b/file_util.py index 389d15422..4fdd5d2ab 100644 --- a/file_util.py +++ b/file_util.py @@ -10,6 +10,8 @@ from logging import getLogger from Cryptodome.Hash import SHA256 +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import hashes from .constants import UTF8 @@ -69,18 +71,30 @@ def get_digest_and_size_for_file(file_name): :param file_name: a file name :return: """ + use_openssl_only = os.getenv('USE_OPENSSL_ONLY', 'False') == 'True' CHUNK_SIZE = 16 * 4 * 1024 f = open(file_name, 'rb') - m = SHA256.new() + if not use_openssl_only: + m = SHA256.new() + else: + backend = default_backend() + chosen_hash = hashes.SHA256() + hasher = hashes.Hash(chosen_hash, backend) while True: chunk = f.read(CHUNK_SIZE) if chunk == b'': break - m.update(chunk) + if not use_openssl_only: + m.update(chunk) + else: + hasher.update(chunk) statinfo = os.stat(file_name) file_size = statinfo.st_size - digest = base64.standard_b64encode(m.digest()).decode(UTF8) + if not use_openssl_only: + digest = base64.standard_b64encode(m.digest()).decode(UTF8) + else: + digest = base64.standard_b64encode(hasher.finalize()).decode(UTF8) logger = getLogger(__name__) logger.debug(u'getting digest and size: %s, %s, file=%s', digest, file_size, file_name) diff --git a/ocsp_asn1crypto.py b/ocsp_asn1crypto.py index c63583ba6..307d2a0f2 100644 --- a/ocsp_asn1crypto.py +++ b/ocsp_asn1crypto.py @@ -3,6 +3,7 @@ # # Copyright (c) 2012-2019 Snowflake Computing Inc. All right reserved. # +import os import platform import sys import warnings @@ -12,13 +13,18 @@ from logging import getLogger from os import getenv +from Cryptodome.Hash import SHA1, SHA256, SHA384, SHA512 +from Cryptodome.PublicKey import RSA +from Cryptodome.Signature import PKCS1_v1_5 from asn1crypto.algos import DigestAlgorithm from asn1crypto.core import Integer, OctetString from asn1crypto.ocsp import CertId, OCSPRequest, OCSPResponse, Request, Requests, TBSRequest, Version from asn1crypto.x509 import Certificate -from Cryptodome.Hash import SHA1, SHA256, SHA384, SHA512 -from Cryptodome.PublicKey import RSA -from Cryptodome.Signature import PKCS1_v1_5 +from cryptography.exceptions import InvalidSignature +from cryptography.hazmat.backends import default_backend +from cryptography.hazmat.primitives import serialization, hashes +from cryptography.hazmat.primitives.asymmetric import padding, utils + from snowflake.connector.errorcode import ER_INVALID_OCSP_RESPONSE, ER_INVALID_OCSP_RESPONSE_CODE from snowflake.connector.errors import RevocationCheckError from snowflake.connector.ocsp_snowflake import SnowflakeOCSP @@ -49,6 +55,12 @@ class SnowflakeOCSPAsn1Crypto(SnowflakeOCSP): 'sha512': SHA512, } + SIGNATURE_ALGORITHM_TO_DIGEST_CLASS_OPENSSL = { + 'sha256': hashes.SHA256, + 'sha384': hashes.SHA3_384, + 'sha512': hashes.SHA3_512, + } + WILDCARD_CERTID = None def __init__(self, **kwargs): @@ -88,15 +100,14 @@ def read_cert_bundle(self, ca_bundle_file, storage=None): if storage is None: storage = SnowflakeOCSP.ROOT_CERTIFICATES_DICT logger.debug('reading certificate bundle: %s', ca_bundle_file) - all_certs = open(ca_bundle_file, 'rb').read() - - # don't lock storage - from asn1crypto import pem - pem_certs = pem.unarmor(all_certs, multiple=True) - for type_name, _, der_bytes in pem_certs: - if type_name == 'CERTIFICATE': - crt = Certificate.load(der_bytes) - storage[crt.subject.sha256] = crt + with open(ca_bundle_file, 'rb') as all_certs: + # don't lock storage + from asn1crypto import pem + pem_certs = pem.unarmor(all_certs.read(), multiple=True) + for type_name, _, der_bytes in pem_certs: + if type_name == 'CERTIFICATE': + crt = Certificate.load(der_bytes) + storage[crt.subject.sha256] = crt def create_ocsp_request(self, issuer, subject): """ @@ -320,21 +331,48 @@ def process_ocsp_response(self, issuer, cert_id, ocsp_response): raise RevocationCheckError(msg=debug_msg, errno=op_er.errno) def verify_signature(self, signature_algorithm, signature, cert, data): - pubkey = asymmetric.load_public_key(cert.public_key).unwrap().dump() - rsakey = RSA.importKey(pubkey) - signer = PKCS1_v1_5.new(rsakey) - if signature_algorithm in SnowflakeOCSPAsn1Crypto.SIGNATURE_ALGORITHM_TO_DIGEST_CLASS: - digest = \ - SnowflakeOCSPAsn1Crypto.SIGNATURE_ALGORITHM_TO_DIGEST_CLASS[ + use_openssl_only = os.getenv('USE_OPENSSL_ONLY', 'False') == 'True' + if not use_openssl_only: + pubkey = asymmetric.load_public_key(cert.public_key).unwrap().dump() + rsakey = RSA.importKey(pubkey) + signer = PKCS1_v1_5.new(rsakey) + if signature_algorithm in SnowflakeOCSPAsn1Crypto.SIGNATURE_ALGORITHM_TO_DIGEST_CLASS: + digest = \ + SnowflakeOCSPAsn1Crypto.SIGNATURE_ALGORITHM_TO_DIGEST_CLASS[ signature_algorithm].new() + else: + # the last resort. should not happen. + digest = SHA1.new() + digest.update(data.dump()) + if not signer.verify(digest, signature): + raise RevocationCheckError( + msg="Failed to verify the signature", + errno=ER_INVALID_OCSP_RESPONSE) + else: - # the last resort. should not happen. - digest = SHA1.new() - digest.update(data.dump()) - if not signer.verify(digest, signature): - raise RevocationCheckError( - msg="Failed to verify the signature", - errno=ER_INVALID_OCSP_RESPONSE) + backend = default_backend() + public_key = serialization.load_der_public_key(cert.public_key.dump(), backend=default_backend()) + if signature_algorithm in SnowflakeOCSPAsn1Crypto.SIGNATURE_ALGORITHM_TO_DIGEST_CLASS: + chosen_hash = \ + SnowflakeOCSPAsn1Crypto.SIGNATURE_ALGORITHM_TO_DIGEST_CLASS_OPENSSL[ + signature_algorithm]() + else: + # the last resort. should not happen. + chosen_hash = hashes.SHA1() + hasher = hashes.Hash(chosen_hash, backend) + hasher.update(data.dump()) + digest = hasher.finalize() + try: + public_key.verify( + signature, + digest, + padding.PKCS1v15(), + utils.Prehashed(chosen_hash) + ) + except InvalidSignature: + raise RevocationCheckError( + msg="Failed to verify the signature", + errno=ER_INVALID_OCSP_RESPONSE) def extract_certificate_chain(self, connection): """ diff --git a/ocsp_pyasn1.py b/ocsp_pyasn1.py deleted file mode 100644 index 72ac11321..000000000 --- a/ocsp_pyasn1.py +++ /dev/null @@ -1,596 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -# -# Copyright (c) 2012-2019 Snowflake Computing Inc. All right reserved. -# - -import hashlib -import pytz -from base64 import b64encode, b64decode -from collections import OrderedDict -from datetime import datetime -from logging import getLogger -from threading import Lock -from os import getenv - -import pyasn1 -from Cryptodome.Hash import SHA256, SHA384, SHA1, SHA512 -from Cryptodome.PublicKey import RSA -from Cryptodome.Signature import PKCS1_v1_5 -from OpenSSL.crypto import ( - FILETYPE_PEM, - FILETYPE_ASN1, - load_certificate, dump_certificate) -from pyasn1.codec.der import decoder as der_decoder -from pyasn1.codec.der import encoder as der_encoder -from pyasn1.codec.native.encoder import encode as nat_encoder -from pyasn1.type import (univ, tag) -from pyasn1_modules import (rfc2459, rfc2437, rfc2560) - -from snowflake.connector.ocsp_snowflake import SnowflakeOCSP -from .errorcode import (ER_INVALID_OCSP_RESPONSE, ER_INVALID_OCSP_RESPONSE_CODE) -from .errors import (RevocationCheckError) -from .rfc6960 import ( - OCSPRequest, - OCSPResponse, - TBSRequest, - CertID, - Request, - OCSPResponseStatus, - BasicOCSPResponse, - Version) - -from snowflake.connector.ssd_internal_keys import ret_wildcard_hkey - -logger = getLogger(__name__) - - -class SnowflakeOCSPPyasn1(SnowflakeOCSP): - """ - OCSP checks by pyasn1 - """ - - PYASN1_VERSION_LOCK = Lock() - PYASN1_VERSION = None - - # Signature Hash Algorithm - sha1WithRSAEncryption = univ.ObjectIdentifier('1.2.840.113549.1.1.5') - sha256WithRSAEncryption = univ.ObjectIdentifier('1.2.840.113549.1.1.11') - sha384WithRSAEncryption = univ.ObjectIdentifier('1.2.840.113549.1.1.12') - sha512WithRSAEncryption = univ.ObjectIdentifier('1.2.840.113549.1.1.13') - - SIGNATURE_HASH_ALGO_TO_DIGEST_CLASS = { - sha1WithRSAEncryption: SHA1, - sha256WithRSAEncryption: SHA256, - sha384WithRSAEncryption: SHA384, - sha512WithRSAEncryption: SHA512, - } - - WILDCARD_CERTID = None - - @staticmethod - def _get_pyasn1_version(): - with SnowflakeOCSPPyasn1.PYASN1_VERSION_LOCK: - if SnowflakeOCSPPyasn1.PYASN1_VERSION is not None: - return SnowflakeOCSPPyasn1.PYASN1_VERSION - - v = pyasn1.__version__ - vv = [int(x, 10) for x in v.split('.')] - vv.reverse() - SnowflakeOCSPPyasn1.PYASN1_VERSION = sum( - x * (1000 ** i) for i, x in enumerate(vv)) - return SnowflakeOCSPPyasn1.PYASN1_VERSION - - def __init__(self, **kwargs): - super(SnowflakeOCSPPyasn1, self).__init__(**kwargs) - self.WILDCARD_CERTID = self.encode_cert_id_key(ret_wildcard_hkey()) - - def encode_cert_id_key(self, hkey): - issuer_name_hash, issuer_key_hash, serial_number = hkey - issuer_name_hash, _ = der_decoder.decode(issuer_name_hash) - issuer_key_hash, _ = der_decoder.decode(issuer_key_hash) - serial_number, _ = der_decoder.decode(serial_number) - cert_id = CertID() - cert_id.setComponentByName( - 'hashAlgorithm', - rfc2459.AlgorithmIdentifier().setComponentByName( - 'algorithm', rfc2437.id_sha1)) - cert_id.setComponentByName('issuerNameHash', issuer_name_hash) - cert_id.setComponentByName('issuerKeyHash', issuer_key_hash) - cert_id.setComponentByName('serialNumber', serial_number) - return cert_id - - def decode_cert_id_key(self, cert_id): - return ( - der_encoder.encode(cert_id.getComponentByName('issuerNameHash')), - der_encoder.encode(cert_id.getComponentByName('issuerKeyHash')), - der_encoder.encode(cert_id.getComponentByName('serialNumber'))) - - def encode_cert_id_base64(self, hkey): - return b64encode(der_encoder.encode( - self.encode_cert_id_key(hkey))).decode('ascii') - - def decode_cert_id_base64(self, cert_id_base64): - cert_id, _ = der_decoder.decode(b64decode(cert_id_base64), CertID()) - return cert_id - - def read_cert_bundle(self, ca_bundle_file, storage=None): - """ - Reads a certificate file including certificates in PEM format - """ - if storage is None: - storage = SnowflakeOCSP.ROOT_CERTIFICATES_DICT - logger.debug('reading certificate bundle: %s', ca_bundle_file) - all_certs = open(ca_bundle_file, 'rb').read() - - state = 0 - contents = [] - for line in all_certs.split(b'\n'): - if state == 0 and line.startswith(b'-----BEGIN CERTIFICATE-----'): - state = 1 - contents.append(line) - elif state == 1: - contents.append(line) - if line.startswith(b'-----END CERTIFICATE-----'): - cert_openssl = load_certificate( - FILETYPE_PEM, - b'\n'.join(contents)) - cert = self._convert_openssl_to_pyasn1_certificate( - cert_openssl) - storage[self._get_subject_hash(cert)] = cert - state = 0 - contents = [] - - def _convert_openssl_to_pyasn1_certificate(self, cert_openssl): - cert_der = dump_certificate(FILETYPE_ASN1, cert_openssl) - cert = der_decoder.decode( - cert_der, asn1Spec=rfc2459.Certificate())[0] - return cert - - def _convert_pyasn1_to_openssl_certificate(self, cert): - cert_der = der_encoder.encode(cert) - cert_openssl = load_certificate(FILETYPE_ASN1, cert_der) - return cert_openssl - - def _get_name_hash(self, cert): - sha1_hash = hashlib.sha1() - sha1_hash.update(der_encoder.encode(self._get_subject(cert))) - return sha1_hash.hexdigest() - - def _get_key_hash(self, cert): - sha1_hash = hashlib.sha1() - h = SnowflakeOCSPPyasn1.bit_string_to_bytearray( - cert.getComponentByName('tbsCertificate').getComponentByName( - 'subjectPublicKeyInfo').getComponentByName('subjectPublicKey')) - sha1_hash.update(h) - return sha1_hash.hexdigest() - - def create_ocsp_request(self, issuer, subject): - """ - Create CertID and OCSPRequest - """ - hashAlgorithm = rfc2459.AlgorithmIdentifier() - hashAlgorithm.setComponentByName("algorithm", rfc2437.id_sha1) - hashAlgorithm.setComponentByName( - "parameters", univ.Any(hexValue='0500')) - - cert_id = CertID() - cert_id.setComponentByName( - 'hashAlgorithm', hashAlgorithm) - cert_id.setComponentByName( - 'issuerNameHash', - univ.OctetString(hexValue=self._get_name_hash(issuer))) - cert_id.setComponentByName( - 'issuerKeyHash', - univ.OctetString(hexValue=self._get_key_hash(issuer))) - cert_id.setComponentByName( - 'serialNumber', - subject.getComponentByName( - 'tbsCertificate').getComponentByName('serialNumber')) - - request = Request() - request.setComponentByName('reqCert', cert_id) - - request_list = univ.SequenceOf(componentType=Request()) - request_list.setComponentByPosition(0, request) - - tbs_request = TBSRequest() - tbs_request.setComponentByName('requestList', request_list) - tbs_request.setComponentByName('version', Version(0).subtype( - explicitTag=tag.Tag( - tag.tagClassContext, tag.tagFormatSimple, 0))) - - ocsp_request = OCSPRequest() - ocsp_request.setComponentByName('tbsRequest', tbs_request) - - return cert_id, ocsp_request - - def extract_certificate_chain(self, connection): - """ - Gets certificate chain and extract the key info from OpenSSL connection - """ - cert_map = OrderedDict() - logger.debug( - "# of certificates: %s", - len(connection.get_peer_cert_chain())) - - for cert_openssl in connection.get_peer_cert_chain(): - cert_der = dump_certificate(FILETYPE_ASN1, cert_openssl) - cert = der_decoder.decode( - cert_der, asn1Spec=rfc2459.Certificate())[0] - subject_sha256 = self._get_subject_hash(cert) - logger.debug( - u'subject: %s, issuer: %s', - nat_encoder(self._get_subject(cert)), - nat_encoder(self._get_issuer(cert))) - cert_map[subject_sha256] = cert - - return self.create_pair_issuer_subject(cert_map) - - def _get_subject(self, cert): - return cert.getComponentByName( - 'tbsCertificate').getComponentByName('subject') - - def _get_issuer(self, cert): - return cert.getComponentByName( - 'tbsCertificate').getComponentByName('issuer') - - def _get_subject_hash(self, cert): - sha256_hash = hashlib.sha256() - sha256_hash.update( - der_encoder.encode(self._get_subject(cert))) - return sha256_hash.digest() - - def _get_issuer_hash(self, cert): - sha256_hash = hashlib.sha256() - sha256_hash.update( - der_encoder.encode(self._get_issuer(cert))) - return sha256_hash.digest() - - def create_pair_issuer_subject(self, cert_map): - """ - Creates pairs of issuer and subject certificates - """ - issuer_subject = [] - for subject_der in cert_map: - cert = cert_map[subject_der] - - nocheck, is_ca, ocsp_urls = self._extract_extensions(cert) - if nocheck or is_ca and not ocsp_urls: - # Root certificate will not be validated - # but it is used to validate the subject certificate - continue - issuer_hash = self._get_issuer_hash(cert) - if issuer_hash not in cert_map: - # IF NO ROOT certificate is attached in the certificate chain - # read it from the local disk - self._lazy_read_ca_bundle() - logger.debug( - 'not found issuer_der: %s', self._get_issuer_hash(cert)) - if issuer_hash not in SnowflakeOCSP.ROOT_CERTIFICATES_DICT: - raise RevocationCheckError( - msg="CA certificate is NOT found in the root " - "certificate list. Make sure you use the latest " - "Python Connector package and the URL is valid.") - issuer = SnowflakeOCSP.ROOT_CERTIFICATES_DICT[issuer_hash] - else: - issuer = cert_map[issuer_hash] - - issuer_subject.append((issuer, cert)) - return issuer_subject - - def _extract_extensions(self, cert): - extensions = cert.getComponentByName( - 'tbsCertificate').getComponentByName('extensions') - is_ca = False - ocsp_urls = [] - nocheck = False - for e in extensions: - oid = e.getComponentByName('extnID') - if oid == rfc2459.id_ce_basicConstraints: - constraints = der_decoder.decode( - e.getComponentByName('extnValue'), - asn1Spec=rfc2459.BasicConstraints())[0] - is_ca = constraints.getComponentByPosition(0) - elif oid == rfc2459.id_pe_authorityInfoAccess: - auth_info = der_decoder.decode( - e.getComponentByName('extnValue'), - asn1Spec=rfc2459.AuthorityInfoAccessSyntax())[0] - for a in auth_info: - if a.getComponentByName('accessMethod') == \ - rfc2560.id_pkix_ocsp: - url = nat_encoder( - a.getComponentByName( - 'accessLocation').getComponentByName( - 'uniformResourceIdentifier')) - ocsp_urls.append(url) - elif oid == rfc2560.id_pkix_ocsp_nocheck: - nocheck = True - - return nocheck, is_ca, ocsp_urls - - def subject_name(self, cert): - return nat_encoder(self._get_subject(cert)) - - def extract_ocsp_url(self, cert): - _, _, ocsp_urls = self._extract_extensions(cert) - return ocsp_urls[0] if ocsp_urls else None - - def decode_ocsp_request(self, ocsp_request): - return der_encoder.encode(ocsp_request) - - def decode_ocsp_request_b64(self, ocsp_request): - data = self.decode_ocsp_request(ocsp_request) - b64data = b64encode(data).decode('ascii') - return b64data - - def extract_good_status(self, single_response): - """ - Extract GOOD status - """ - this_update_native = \ - self._convert_generalized_time_to_datetime( - single_response.getComponentByName('thisUpdate')) - next_update_native = \ - self._convert_generalized_time_to_datetime( - single_response.getComponentByName('nextUpdate')) - return this_update_native, next_update_native - - def extract_revoked_status(self, single_response): - """ - Extract REVOKED status - """ - cert_status = single_response.getComponentByName('certStatus') - revoked = cert_status.getComponentByName('revoked') - revocation_time = \ - self._convert_generalized_time_to_datetime( - revoked.getComponentByName('revocationTime')) - revocation_reason = revoked.getComponentByName('revocationReason') - try: - revocation_reason_str = str(revocation_reason) - except Exception: - revocation_reason_str = 'n/a' - return revocation_time, revocation_reason_str - - def _convert_generalized_time_to_datetime(self, gentime): - return datetime.strptime(str(gentime), '%Y%m%d%H%M%SZ') - - def check_cert_time_validity(self, cur_time, tbs_certificate): - cert_validity = tbs_certificate.getComponentByName('validity') - cert_not_after = cert_validity.getComponentByName('notAfter') - val_end = cert_not_after.getComponentByName('utcTime').asDateTime - cert_not_before = cert_validity.getComponentByName('notBefore') - val_start = cert_not_before.getComponentByName('utcTime').asDateTime - - if cur_time > val_end or cur_time < val_start: - debug_msg = "Certificate attached to OCSP Response is invalid. " \ - "OCSP response current time - {} certificate not " \ - "before time - {} certificate not after time - {}. ". \ - format(cur_time, val_start, val_end) - return False, debug_msg - else: - return True, None - - """ - is_valid_time - checks various components of the OCSP Response - for expiry. - :param cert_id - certificate id corresponding to OCSP Response - :param ocsp_response - :return True/False depending on time validity within the response - """ - def is_valid_time(self, cert_id, ocsp_response): - res = der_decoder.decode(ocsp_response, OCSPResponse())[0] - - if res.getComponentByName('responseStatus') != OCSPResponseStatus( - 'successful'): - raise RevocationCheckError( - msg="Invalid Status: {}".format( - res.getComponentByName('response_status')), - errno=ER_INVALID_OCSP_RESPONSE) - - response_bytes = res.getComponentByName('responseBytes') - basic_ocsp_response = der_decoder.decode( - response_bytes.getComponentByName('response'), - BasicOCSPResponse())[0] - - attached_certs = basic_ocsp_response.getComponentByName('certs') - if self._has_certs_in_ocsp_response(attached_certs): - logger.debug("Certificate is attached in Basic OCSP Response") - cert_der = der_encoder.encode(attached_certs[0]) - cert_openssl = load_certificate(FILETYPE_ASN1, cert_der) - ocsp_cert = self._convert_openssl_to_pyasn1_certificate( - cert_openssl) - - cur_time = datetime.utcnow().replace(tzinfo=pytz.utc) - tbs_certificate = ocsp_cert.getComponentByName('tbsCertificate') - - """ - Note: - We purposefully do not verify certificate signature here. - The OCSP Response is extracted from the OCSP Response Cache - which is expected to have OCSP Responses with verified - attached signature. Moreover this OCSP Response is eventually - going to be processed by the driver before being consumed by - the driver. - This step ensures that the OCSP Response cache does not have - any invalid entries. - """ - - cert_valid, debug_msg = self.check_cert_time_validity(cur_time, - tbs_certificate) - if not cert_valid: - logger.debug(debug_msg) - return False - - tbs_response_data = basic_ocsp_response.getComponentByName( - 'tbsResponseData') - single_response = tbs_response_data.getComponentByName('responses')[0] - cert_status = single_response.getComponentByName('certStatus') - try: - if cert_status.getName() == 'good': - self._process_good_status(single_response, cert_id, ocsp_response) - except Exception as ex: - logger.debug("Failed to validate ocsp response %s", ex) - return False - - return True - - def process_ocsp_response(self, issuer, cert_id, ocsp_response): - try: - res = der_decoder.decode(ocsp_response, OCSPResponse())[0] - if self.test_mode is not None: - ocsp_load_failure = getenv("SF_TEST_OCSP_FORCE_BAD_OCSP_RESPONSE") - if ocsp_load_failure is not None: - raise RevocationCheckError("Force fail") - except Exception: - raise RevocationCheckError( - msg='Invalid OCSP Response', - errno=ER_INVALID_OCSP_RESPONSE - ) - - if res.getComponentByName('responseStatus') != OCSPResponseStatus( - 'successful'): - raise RevocationCheckError( - msg="Invalid Status: {}".format( - res.getComponentByName('response_status')), - errno=ER_INVALID_OCSP_RESPONSE) - - response_bytes = res.getComponentByName('responseBytes') - basic_ocsp_response = der_decoder.decode( - response_bytes.getComponentByName('response'), - BasicOCSPResponse())[0] - - attached_certs = basic_ocsp_response.getComponentByName('certs') - if self._has_certs_in_ocsp_response(attached_certs): - logger.debug("Certificate is attached in Basic OCSP Response") - cert_der = der_encoder.encode(attached_certs[0]) - cert_openssl = load_certificate(FILETYPE_ASN1, cert_der) - ocsp_cert = self._convert_openssl_to_pyasn1_certificate(cert_openssl) - - cur_time = datetime.utcnow().replace(tzinfo=pytz.utc) - tbs_certificate = ocsp_cert.getComponentByName('tbsCertificate') - - """ - Signature verification should happen before any kind of - validation - """ - - self.verify_signature( - ocsp_cert.getComponentByName('signatureAlgorithm'), - ocsp_cert.getComponentByName('signatureValue'), - issuer, - ocsp_cert.getComponentByName('tbsCertificate')) - - cert_valid, debug_msg = self.check_cert_time_validity(cur_time, - tbs_certificate) - if not cert_valid: - raise RevocationCheckError( - msg=debug_msg, - errno=ER_INVALID_OCSP_RESPONSE_CODE - ) - else: - logger.debug("Certificate is NOT attached in Basic OCSP Response. " - "Using issuer's certificate") - ocsp_cert = issuer - - tbs_response_data = basic_ocsp_response.getComponentByName( - 'tbsResponseData') - - logger.debug("Verifying the OCSP response is signed by the issuer.") - self.verify_signature( - basic_ocsp_response.getComponentByName('signatureAlgorithm'), - basic_ocsp_response.getComponentByName('signature'), - ocsp_cert, - tbs_response_data - ) - - single_response = tbs_response_data.getComponentByName('responses')[0] - cert_status = single_response.getComponentByName('certStatus') - - if self.test_mode is not None: - test_cert_status = getenv("SF_TEST_OCSP_CERT_STATUS") - if test_cert_status == 'revoked': - cert_status = 'revoked' - elif test_cert_status == 'unknown': - cert_status = 'unknown' - elif test_cert_status == 'good': - cert_status = 'good' - - try: - if cert_status.getName() == 'good': - self._process_good_status(single_response, cert_id, ocsp_response) - SnowflakeOCSP.OCSP_CACHE.update_cache(self, cert_id, ocsp_response) - elif cert_status.getName() == 'revoked': - self._process_revoked_status(single_response, cert_id) - elif cert_status.getName() == 'unknown': - self._process_unknown_status(cert_id) - else: - debug_msg = "Unknown revocation status was returned. " \ - "OCSP response may be malformed: {}. ".format(cert_status) - raise RevocationCheckError( - msg=debug_msg, - errno=ER_INVALID_OCSP_RESPONSE_CODE) - except RevocationCheckError as op_er: - if not self.debug_ocsp_failure_url: - debug_msg = op_er.msg - else: - debug_msg = "{} Consider running curl -o ocsp.der {}".\ - format(op_er.msg, - self.debug_ocsp_failure_url) - raise RevocationCheckError( - msg=debug_msg, - errno=op_er.errno) - - def verify_signature(self, signature_algorithm, signature, cert, data): - """ - Verifies the signature - """ - sig = SnowflakeOCSPPyasn1.bit_string_to_bytearray(signature) - sig = sig.decode('latin-1').encode('latin-1') - - pubkey = SnowflakeOCSPPyasn1.bit_string_to_bytearray( - cert.getComponentByName( - 'tbsCertificate').getComponentByName( - 'subjectPublicKeyInfo').getComponentByName('subjectPublicKey')) - pubkey = pubkey.decode('latin-1').encode('latin-1') - - rsakey = RSA.importKey(pubkey) - signer = PKCS1_v1_5.new(rsakey) - - algorithm = signature_algorithm[0] - if algorithm in SnowflakeOCSPPyasn1.SIGNATURE_HASH_ALGO_TO_DIGEST_CLASS: - digest = SnowflakeOCSPPyasn1.SIGNATURE_HASH_ALGO_TO_DIGEST_CLASS[ - algorithm].new() - else: - digest = SHA1.new() - - data = der_encoder.encode(data) - digest.update(data) - if not signer.verify(digest, sig): - raise RevocationCheckError( - msg="Failed to verify the signature", - errno=ER_INVALID_OCSP_RESPONSE) - - def _has_certs_in_ocsp_response(self, certs): - """ - Check if the certificate is attached to OCSP response - """ - if SnowflakeOCSPPyasn1._get_pyasn1_version() <= 3000: - return certs is not None - else: - # behavior changed. - return certs is not None and certs.hasValue() and certs[ - 0].hasValue() - - @staticmethod - def bit_string_to_bytearray(bit_string): - """ - Converts Bitstring to bytearray - """ - ret = [] - for idx in range(int(len(bit_string) / 8)): - v = 0 - for idx0, bit in enumerate(bit_string[idx * 8:idx * 8 + 8]): - v = v | (bit << (7 - idx0)) - ret.append(v) - return bytearray(ret) diff --git a/ocsp_snowflake.py b/ocsp_snowflake.py index 0149e139f..eeae2183f 100644 --- a/ocsp_snowflake.py +++ b/ocsp_snowflake.py @@ -866,6 +866,7 @@ class SnowflakeOCSP(object): r'(.*\.snowflakecomputing\.com$' r'|(?:|.*\.)s3.*\.amazonaws\.com$' # start with s3 or .s3 in the middle r'|.*\.okta\.com$' + r'|(?:|.*\.)storage\.googleapis\.com$' r'|.*\.blob\.core\.windows\.net$' r'|.*\.blob\.core\.usgovcloudapi\.net$)') diff --git a/options.py b/options.py index 997c1992e..3886bb3ad 100644 --- a/options.py +++ b/options.py @@ -2,6 +2,7 @@ # Flags to see whether optional dependencies were installed installed_pandas = False +installed_keyring = False class MissingPandas(object): @@ -10,6 +11,12 @@ def __getattr__(self, item): raise MissingDependencyError('pandas') +class MissingKeyring(object): + + def __getattr__(self, item): + raise MissingDependencyError('keyring') + + try: import pandas import pyarrow @@ -17,3 +24,9 @@ def __getattr__(self, item): except ImportError: pandas = MissingPandas() pyarrow = MissingPandas() + +try: + import keyring + installed_keyring = True +except ImportError: + keyring = MissingKeyring() diff --git a/scripts/build_linux.sh b/scripts/build_linux.sh index 71376657d..f18e7e9ea 100755 --- a/scripts/build_linux.sh +++ b/scripts/build_linux.sh @@ -17,4 +17,3 @@ unset ENABLE_EXT_MODULES mkdir -p $THIS_DIR/../dist/docker/repaired_wheels auditwheel repair --plat manylinux2010_x86_64 -L connector $THIS_DIR/../dist/docker/$PYTHON_VERSION/*.whl -w $THIS_DIR/../dist/docker/repaired_wheels -rm $THIS_DIR/../dist/docker/repaired_wheels/*manylinux1_x86_64.whl || true diff --git a/scripts/build_pyarrow_linux.sh b/scripts/build_pyarrow_linux.sh index d9664866a..c736cca48 100755 --- a/scripts/build_pyarrow_linux.sh +++ b/scripts/build_pyarrow_linux.sh @@ -22,7 +22,6 @@ function build_connector_with_python() { # audit wheel files mkdir -p $CONNECTOR_DIR/dist/docker/repaired_wheels auditwheel repair --plat manylinux2010_x86_64 -L connector $CONNECTOR_DIR/dist/docker/$PYTHON/*.whl -w $CONNECTOR_DIR/dist/docker/repaired_wheels - rm $CONNECTOR_DIR/dist/docker/repaired_wheels/*manylinux1_x86_64.whl || true deactivate } diff --git a/scripts/install.sh b/scripts/install.sh index e39cd5242..1e65790d3 100755 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -38,7 +38,7 @@ if [ "$TRAVIS_OS_NAME" == "osx" ]; then else pv=${TRAVIS_PYTHON_VERSION} $THIS_DIR/build_inside_docker.sh $pv - CONNECTOR_WHL=$(ls $THIS_DIR/../dist/docker/repaired_wheels/snowflake_connector_python*cp${PYTHON_ENV}*.whl | sort -r | head -n 1) + CONNECTOR_WHL=$(ls $THIS_DIR/../dist/docker/repaired_wheels/snowflake_connector_python*cp${PYTHON_ENV}*manylinux2010*.whl | sort -r | head -n 1) pip install -U ${CONNECTOR_WHL}[pandas,development] cd $THIS_DIR/.. fi diff --git a/scripts/run_travis.sh b/scripts/run_travis.sh index cc606bb6c..809e1aec8 100755 --- a/scripts/run_travis.sh +++ b/scripts/run_travis.sh @@ -31,7 +31,7 @@ else # shellcheck disable=SC2068 ${TIMEOUT_CMD[@]} py.test -vvv --cov=snowflake.connector \ --cov-report=xml:python_connector_${TRAVIS_PYTHON_VERSION}_coverage.xml \ - test || ret=$? + --ignore=test/sso test || ret=$? fi # TIMEOUT or SUCCESS diff --git a/scripts/test.bat b/scripts/test.bat index 76dfda776..886186f36 100644 --- a/scripts/test.bat +++ b/scripts/test.bat @@ -63,7 +63,7 @@ if %errorlevel% neq 0 goto :error set JUNIT_REPORT_DIR=%workspace% set COV_REPORT_DIR=%workspace% -tox -e py%pv%-ci,py%pv%-pandas-ci,coverage --external_wheels ..\..\..\%connector_whl% -- --basetemp=%workspace%\pytest-tmp\ +tox -e py%pv%-ci,py%pv%-pandas-ci,py%pv%-sso-ci,coverage --external_wheels ..\..\..\%connector_whl% -- --basetemp=%workspace%\pytest-tmp\ if %errorlevel% neq 0 goto :error call deactivate diff --git a/scripts/test.sh b/scripts/test.sh index 1d60ec05c..dcfb83ce2 100755 --- a/scripts/test.sh +++ b/scripts/test.sh @@ -7,8 +7,8 @@ THIS_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" # shellcheck disable=SC1090 source "${THIS_DIR}/py_exec.sh" CONNECTOR_DIR="$( dirname "${THIS_DIR}")" -CONNECTOR_WHL=$(ls $CONNECTOR_DIR/dist/docker/repaired_wheels/snowflake_connector_python*cp${PYTHON_ENV}*.whl | sort -r | head -n 1) -TEST_ENVLIST=fix_lint,py${PYTHON_ENV}-ci,py${PYTHON_ENV}-pandas-ci,coverage +CONNECTOR_WHL=$(ls $CONNECTOR_DIR/dist/docker/repaired_wheels/snowflake_connector_python*cp${PYTHON_ENV}*manylinux2010*.whl | sort -r | head -n 1) +TEST_ENVLIST=fix_lint,py${PYTHON_ENV}-ci,py${PYTHON_ENV}-pandas-ci,py${PYTHON_ENV}-sso-ci,coverage if [[ -n "$PIP_INDEX_URL" ]]; then echo "PIP_INDEX_URL before now: ${PIP_INDEX_URL}" diff --git a/scripts/test_darwin.sh b/scripts/test_darwin.sh index 1c20d4fc9..c73d2d49b 100755 --- a/scripts/test_darwin.sh +++ b/scripts/test_darwin.sh @@ -48,7 +48,7 @@ aws s3 cp --only-show-errors \ log INFO "Testing Connector in python${PYTHON_ENV}" CONNECTOR_WHL=$(ls ${WORKSPACE}/snowflake_connector_python*cp${PYTHON_ENV}*.whl) -TEST_ENVLIST=fix_lint,py${PYTHON_ENV}-ci,py${PYTHON_ENV}-pandas-ci,coverage +TEST_ENVLIST=fix_lint,py${PYTHON_ENV}-ci,py${PYTHON_ENV}-pandas-ci,py${PYTHON_ENV}-sso-ci,coverage cd $CONNECTOR_DIR diff --git a/setup.py b/setup.py index 1beb85142..82bb46a71 100644 --- a/setup.py +++ b/setup.py @@ -221,13 +221,13 @@ def _get_arrow_lib_as_linker_input(self): }, extras_require={ "secure-local-storage": [ - 'keyring!=16.1.0' + 'keyring<22.0.0,!=16.1.0', ], "pandas": [ 'pyarrow>=0.15.1,<0.16.0;python_version=="3.5" and platform_system=="Windows"', 'pyarrow>=0.16.0,<0.17.0;python_version!="3.5" or platform_system!="Windows"', 'pandas==0.24.2;python_version=="3.5"', - 'pandas<1.0.0;python_version>"3.5"', + 'pandas>=1.0.0,<1.1.0;python_version>"3.5"', ], "development": [ 'pytest', diff --git a/test/test_connection_manual.py b/test/sso/test_connection_manual.py similarity index 91% rename from test/test_connection_manual.py rename to test/sso/test_connection_manual.py index a0dcf578d..a121150d9 100644 --- a/test/test_connection_manual.py +++ b/test/sso/test_connection_manual.py @@ -19,7 +19,11 @@ import pytest import snowflake.connector -from snowflake.connector.auth import delete_temporary_credential_file +from snowflake.connector.auth import delete_temporary_credential +import os +import sys + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) try: from parameters import (CONNECTION_PARAMETERS_SSO) @@ -67,12 +71,11 @@ def test_connect_externalbrowser(token_validity_test_values): should not create popups. """ - delete_temporary_credential_file(True) # delete secure storage - delete_temporary_credential_file(False) # delete file cache - CONNECTION_PARAMETERS_SSO['session_parameters'] = \ - { - "CLIENT_USE_SECURE_STORAGE_FOR_TEMPORARY_CREDENTAIL": True, - } + delete_temporary_credential( + host=CONNECTION_PARAMETERS_SSO['host'], + user=CONNECTION_PARAMETERS_SSO['user'], + store_temporary_credential=True) # delete existing temporary credential + CONNECTION_PARAMETERS_SSO['client_store_temporary_credential'] = True # change database and schema to non-default one print("[INFO] 1st connection gets id token and stores in the cache file. " diff --git a/test/sso/test_unit_sso_connection.py b/test/sso/test_unit_sso_connection.py new file mode 100644 index 000000000..919cd06df --- /dev/null +++ b/test/sso/test_unit_sso_connection.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2012-2019 Snowflake Computing Inc. All right reserved. +# +import os +import snowflake.connector +from mock import patch, Mock +from snowflake.connector.auth import delete_temporary_credential +from snowflake.connector.compat import IS_MACOS + +@patch( + 'snowflake.connector.auth_webbrowser.AuthByWebBrowser.authenticate') +@patch( + 'snowflake.connector.network.SnowflakeRestful._post_request' +) +def test_connect_externalbrowser( + mockSnowflakeRestfulPostRequest, + mockAuthByBrowserAuthenticate): + """ + Connect with authentictor=externalbrowser mock. + """ + + os.environ['SF_TEMPORARY_CREDENTIAL_CACHE_DIR'] = os.getenv( + "WORKSPACE", os.path.expanduser("~")) + + def mock_post_request(url, headers, json_body, **kwargs): + global mock_post_req_cnt + ret = None + if mock_post_req_cnt == 0: + # return from /v1/login-request + ret = { + u'success': True, + u'message': None, + u'data': { + u'token': u'TOKEN', + u'masterToken': u'MASTER_TOKEN', + u'idToken': u'ID_TOKEN', + }} + elif mock_post_req_cnt == 1: + # return from /token-request + ret = { + u'success': True, + u'message': None, + u'data': { + u'sessionToken': u'NEW_TOKEN', + }} + elif mock_post_req_cnt == 2: + # return from USE WAREHOUSE TESTWH_NEW + ret = { + u'success': True, + u'message': None, + u'data': { + u'finalDatabase': 'TESTDB', + u'finalWarehouse': 'TESTWH_NEW', + }} + elif mock_post_req_cnt == 3: + # return from USE DATABASE TESTDB_NEW + ret = { + u'success': True, + u'message': None, + u'data': { + u'finalDatabase': 'TESTDB_NEW', + u'finalWarehouse': 'TESTWH_NEW', + }} + elif mock_post_req_cnt == 4: + # return from SELECT 1 + ret = { + u'success': True, + u'message': None, + u'data': { + u'finalDatabase': 'TESTDB_NEW', + u'finalWarehouse': 'TESTWH_NEW', + }} + mock_post_req_cnt += 1 + return ret + + def mock_get_password(service, user): + global mock_get_pwd_cnt + ret = None + if mock_get_pwd_cnt == 1: + # second connection + ret = 'ID_TOKEN' + mock_get_pwd_cnt += 1 + return ret + + global mock_post_req_cnt, mock_get_pwd_cnt + mock_post_req_cnt, mock_get_pwd_cnt = 0, 0 + + # pre-authentication doesn't matter + mockAuthByBrowserAuthenticate.return_value = None + + # POST requests mock + mockSnowflakeRestfulPostRequest.side_effect = mock_post_request + + def test_body(): + account = 'testaccount' + user = 'testuser' + authenticator = 'externalbrowser' + host = 'testaccount.snowflakecomputing.com' + + delete_temporary_credential( + host=host, user=user, store_temporary_credential=True) + + # first connection + con = snowflake.connector.connect( + account=account, + user=user, + host=host, + authenticator=authenticator, + database='TESTDB', + warehouse='TESTWH', + client_store_temporary_credential=True, + ) + assert con._rest.token == u'TOKEN' + assert con._rest.master_token == u'MASTER_TOKEN' + assert con._rest.id_token == u'ID_TOKEN' + + # second connection that uses the id token to get the session token + con = snowflake.connector.connect( + account=account, + user=user, + host=host, + authenticator=authenticator, + database='TESTDB_NEW', # override the database + warehouse='TESTWH_NEW', # override the warehouse + client_store_temporary_credential=True, + ) + + assert con._rest.token == u'NEW_TOKEN' + assert con._rest.master_token is None + assert con._rest.id_token == 'ID_TOKEN' + assert con.database == 'TESTDB_NEW' + assert con.warehouse == 'TESTWH_NEW' + + if IS_MACOS: + with patch('keyring.delete_password', Mock(return_value=None) + ), patch('keyring.set_password', Mock(return_value=None) + ), patch('keyring.get_password', Mock(side_effect=mock_get_password)): + test_body() + else: + test_body() diff --git a/test/test_connection.py b/test/test_connection.py index 53eda29ac..a90afb59b 100644 --- a/test/test_connection.py +++ b/test/test_connection.py @@ -29,6 +29,8 @@ def test_basic(conn_testaccount): Basic Connection test """ assert conn_testaccount, 'invalid cnx' + # Test default values + assert not conn_testaccount.use_openssl_only conn_testaccount._set_current_objects() @@ -652,3 +654,31 @@ def mock_auth(self, auth_instance): authenticator=orig_authenticator, ) assert cnx + + +def test_use_openssl_only(db_parameters): + cnx = snowflake.connector.connect( + user=db_parameters['user'], + password=db_parameters['password'], + host=db_parameters['host'], + port=db_parameters['port'], + account=db_parameters['account'], + protocol=db_parameters['protocol'], + use_openssl_only=True, + ) + assert cnx + assert 'USE_OPENSSL_ONLY' in os.environ + # Note during testing conftest will default this value to False, so if testing this we need to manually clear it + # Let's test it again, after clearing it + del os.environ['USE_OPENSSL_ONLY'] + cnx = snowflake.connector.connect( + user=db_parameters['user'], + password=db_parameters['password'], + host=db_parameters['host'], + port=db_parameters['port'], + account=db_parameters['account'], + protocol=db_parameters['protocol'], + use_openssl_only=True, + ) + assert cnx + assert os.environ['USE_OPENSSL_ONLY'] == 'True' diff --git a/test/test_dbapi.py b/test/test_dbapi.py index bb127ed1a..1889f01c4 100644 --- a/test/test_dbapi.py +++ b/test/test_dbapi.py @@ -572,6 +572,8 @@ def test_setoutputsize_basic( def test_description2(conn_local): try: with conn_local() as con: + # ENABLE_FIX_67159 changes the column size to the actual size. By default it is disabled at the moment. + expected_column_size = 26 if not con.account.startswith("sfctest0") else 16777216 cur = con.cursor() executeDDL1(cur) assert len( @@ -603,7 +605,7 @@ def test_description2(conn_local): # number (FIXED) ('COL1', 0, None, None, 9, 4, False), # decimal - ('COL2', 2, None, 16777216, None, None, False), + ('COL2', 2, None, expected_column_size, None, None, False), # string ('COL3', 3, None, None, None, None, True), # date diff --git a/test/test_unit_auth.py b/test/test_unit_auth.py index c05295e64..0ba515e11 100644 --- a/test/test_unit_auth.py +++ b/test/test_unit_auth.py @@ -16,6 +16,7 @@ def _init_rest(application, post_requset): connection = MagicMock() connection._login_timeout = 120 + connection._network_timeout = None connection.errorhandler = Mock(return_value=None) connection._ocsp_mode = Mock(return_value=OCSPMode.FAIL_OPEN) type(connection).application = PropertyMock(return_value=application) diff --git a/test/test_unit_auth_okta.py b/test/test_unit_auth_okta.py index 2a2b15b5e..145f4116b 100644 --- a/test/test_unit_auth_okta.py +++ b/test/test_unit_auth_okta.py @@ -247,6 +247,7 @@ def post_request(url, headers, body, **kwargs): connection = MagicMock() connection._login_timeout = 120 + connection._network_timeout = None connection.errorhandler = Mock(return_value=None) connection._ocsp_mode = Mock(return_value=OCSPMode.FAIL_OPEN) type(connection).application = PropertyMock(return_value=CLIENT_NAME) diff --git a/test/test_unit_auth_webbrowser.py b/test/test_unit_auth_webbrowser.py index 93bd85bf2..58a86449b 100644 --- a/test/test_unit_auth_webbrowser.py +++ b/test/test_unit_auth_webbrowser.py @@ -185,6 +185,7 @@ def post_request(url, headers, body, **kwargs): connection = MagicMock() connection._login_timeout = 120 + connection._network_timeout = None connection.errorhandler = Mock(return_value=None) connection._ocsp_mode = Mock(return_value=OCSPMode.FAIL_OPEN) type(connection).application = PropertyMock(return_value=CLIENT_NAME) diff --git a/test/test_unit_connection.py b/test/test_unit_connection.py index 3276fc88b..2f6fc1baf 100644 --- a/test/test_unit_connection.py +++ b/test/test_unit_connection.py @@ -3,123 +3,9 @@ # # Copyright (c) 2012-2019 Snowflake Computing Inc. All right reserved. # -import os - import pytest import snowflake.connector from mock import patch -from snowflake.connector.auth import delete_temporary_credential_file - - -@patch( - 'snowflake.connector.auth_webbrowser.AuthByWebBrowser.authenticate') -@patch( - 'snowflake.connector.network.SnowflakeRestful._post_request' -) -def test_connect_externalbrowser( - mockSnowflakeRestfulPostRequest, - mockAuthByBrowserAuthenticate): - """ - Connect with authentictor=externalbrowser mock. - """ - - os.environ['SF_TEMPORARY_CREDENTIAL_CACHE_DIR'] = os.getenv( - "WORKSPACE", os.path.expanduser("~")) - - def mock_post_request(url, headers, json_body, **kwargs): - global mock_cnt - ret = None - if mock_cnt == 0: - # return from /v1/login-request - ret = { - u'success': True, - u'message': None, - u'data': { - u'token': u'TOKEN', - u'masterToken': u'MASTER_TOKEN', - u'idToken': u'ID_TOKEN', - }} - elif mock_cnt == 1: - # return from /token-request - ret = { - u'success': True, - u'message': None, - u'data': { - u'sessionToken': u'NEW_TOKEN', - }} - elif mock_cnt == 2: - # return from USE WAREHOUSE TESTWH_NEW - ret = { - u'success': True, - u'message': None, - u'data': { - u'finalDatabase': 'TESTDB', - u'finalWarehouse': 'TESTWH_NEW', - }} - elif mock_cnt == 3: - # return from USE DATABASE TESTDB_NEW - ret = { - u'success': True, - u'message': None, - u'data': { - u'finalDatabase': 'TESTDB_NEW', - u'finalWarehouse': 'TESTWH_NEW', - }} - elif mock_cnt == 4: - # return from SELECT 1 - ret = { - u'success': True, - u'message': None, - u'data': { - u'finalDatabase': 'TESTDB_NEW', - u'finalWarehouse': 'TESTWH_NEW', - }} - mock_cnt += 1 - return ret - - global mock_cnt - mock_cnt = 0 - - # pre-authentication doesn't matter - mockAuthByBrowserAuthenticate.return_value = None - - # POST requests mock - mockSnowflakeRestfulPostRequest.side_effect = mock_post_request - - delete_temporary_credential_file() - - mock_cnt = 0 - - account = 'testaccount' - user = 'testuser' - authenticator = 'externalbrowser' - - # first connection - con = snowflake.connector.connect( - account=account, - user=user, - authenticator=authenticator, - database='TESTDB', - warehouse='TESTWH', - ) - assert con._rest.token == u'TOKEN' - assert con._rest.master_token == u'MASTER_TOKEN' - assert con._rest.id_token == u'ID_TOKEN' - - # second connection that uses the id token to get the session token - con = snowflake.connector.connect( - account=account, - user=user, - authenticator=authenticator, - database='TESTDB_NEW', # override the database - warehouse='TESTWH_NEW', # override the warehouse - ) - - assert con._rest.token == u'NEW_TOKEN' - assert con._rest.master_token is None - assert con._rest.id_token == 'ID_TOKEN' - assert con.database == 'TESTDB_NEW' - assert con.warehouse == 'TESTWH_NEW' @patch( @@ -137,7 +23,7 @@ def mock_post_request(url, headers, json_body, **kwargs): u'data': { u'token': u'TOKEN', u'masterToken': u'MASTER_TOKEN', - u'idToken': u'ID_TOKEN', + u'idToken': None, u'parameters': [ {'name': 'SERVICE_NAME', 'value': "FAKE_SERVICE_NAME"} ], @@ -180,7 +66,7 @@ def mock_post_request(url, headers, json_body, **kwargs): u'data': { u'token': u'TOKEN', u'masterToken': u'MASTER_TOKEN', - u'idToken': u'ID_TOKEN', + u'idToken': None, u'parameters': [ {'name': 'SERVICE_NAME', 'value': "FAKE_SERVICE_NAME"} ], diff --git a/tested_requirements/requirements_35.txt b/tested_requirements/requirements_35.txt index cb574391f..6f8d6230d 100644 --- a/tested_requirements/requirements_35.txt +++ b/tested_requirements/requirements_35.txt @@ -1,5 +1,5 @@ asn1crypto==1.3.0 -azure-common==1.1.24 +azure-common==1.1.25 azure-storage-blob==2.1.0 azure-storage-common==2.1.0 boto3==1.11.17 diff --git a/tested_requirements/requirements_36.txt b/tested_requirements/requirements_36.txt index cb574391f..6f8d6230d 100644 --- a/tested_requirements/requirements_36.txt +++ b/tested_requirements/requirements_36.txt @@ -1,5 +1,5 @@ asn1crypto==1.3.0 -azure-common==1.1.24 +azure-common==1.1.25 azure-storage-blob==2.1.0 azure-storage-common==2.1.0 boto3==1.11.17 diff --git a/tested_requirements/requirements_37.txt b/tested_requirements/requirements_37.txt index cb574391f..6f8d6230d 100644 --- a/tested_requirements/requirements_37.txt +++ b/tested_requirements/requirements_37.txt @@ -1,5 +1,5 @@ asn1crypto==1.3.0 -azure-common==1.1.24 +azure-common==1.1.25 azure-storage-blob==2.1.0 azure-storage-common==2.1.0 boto3==1.11.17 diff --git a/tested_requirements/requirements_38-linux_x86_64.whl.reqs.txt b/tested_requirements/requirements_38.txt similarity index 95% rename from tested_requirements/requirements_38-linux_x86_64.whl.reqs.txt rename to tested_requirements/requirements_38.txt index cb574391f..6f8d6230d 100644 --- a/tested_requirements/requirements_38-linux_x86_64.whl.reqs.txt +++ b/tested_requirements/requirements_38.txt @@ -1,5 +1,5 @@ asn1crypto==1.3.0 -azure-common==1.1.24 +azure-common==1.1.25 azure-storage-blob==2.1.0 azure-storage-common==2.1.0 boto3==1.11.17 diff --git a/tox.ini b/tox.ini index 0190a2c9e..1de7ddc45 100644 --- a/tox.ini +++ b/tox.ini @@ -1,25 +1,29 @@ [tox] envlist = fix_lint, - py{35,36,37,38}{-pandas,}, + py{35,36,37,38}{-pandas,-sso,}, coverage skip_missing_interpreters = true +requires = + tox-external-wheels>=0.1.6 [testenv] description = run the tests with pytest under {basepython} extras = development pandas: pandas + sso: secure-local-storage deps = pip >= 19.3.1 install_command = python -m pip install -U {opts} {packages} external_wheels = -; external wheels location can be overriden in command line - py35: dist/docker/repaired_wheels/*cp35*.whl - py36: dist/docker/repaired_wheels/*cp36*.whl - py37: dist/docker/repaired_wheels/*cp37*.whl - py38: dist/docker/repaired_wheels/*cp38*.whl +; external wheels location can be overwritten in command line + py35: ../../dist/docker/repaired_wheels/*cp35*.whl + py36: ../../dist/docker/repaired_wheels/*cp36*.whl + py37: ../../dist/docker/repaired_wheels/*cp37*.whl + py38: ../../dist/docker/repaired_wheels/*cp38*.whl setenv = COVERAGE_FILE = {env:COVERAGE_FILE:{toxworkdir}/.coverage.{envname}} + ci: SNOWFLAKE_PYTEST_OPTS = -vvv passenv = AWS_ACCESS_KEY_ID AWS_SECRET_ACCESS_KEY @@ -27,10 +31,9 @@ passenv = ; This is required on windows. Otherwise pwd module won't be imported successfully, see https://github.com/tox-dev/tox/issues/1455 USERNAME commands = - !pandas-!ci: pytest --cov "snowflake.connector" --junitxml {env:JUNIT_REPORT_DIR:{toxworkdir}}/junit.{envname}.xml --ignore=test/pandas {posargs:test} - pandas-!ci: pytest --cov "snowflake.connector" --junitxml {env:JUNIT_REPORT_DIR:{toxworkdir}}/junit.{envname}.xml {posargs:test/pandas} - !pandas-ci: pytest -vvv --cov "snowflake.connector" --junitxml {env:JUNIT_REPORT_DIR:{toxworkdir}}/junit.{envname}.xml --ignore=test/pandas {posargs} test - pandas-ci: pytest -vvv --cov "snowflake.connector" --junitxml {env:JUNIT_REPORT_DIR:{toxworkdir}}/junit.{envname}.xml {posargs} test/pandas + !pandas-!sso: pytest {env:SNOWFLAKE_PYTEST_OPTS:} --cov "snowflake.connector" --junitxml {env:JUNIT_REPORT_DIR:{toxworkdir}}/junit.{envname}.xml --ignore=test/pandas --ignore=test/sso {posargs} test + pandas: pytest {env:SNOWFLAKE_PYTEST_OPTS:} --cov "snowflake.connector" --junitxml {env:JUNIT_REPORT_DIR:{toxworkdir}}/junit.{envname}.xml {posargs} test/pandas + sso: pytest {env:SNOWFLAKE_PYTEST_OPTS:} --cov "snowflake.connector" --junitxml {env:JUNIT_REPORT_DIR:{toxworkdir}}/junit.{envname}.xml {posargs} test/sso [testenv:coverage] description = [run locally after tests]: combine coverage data and create report; @@ -55,7 +58,6 @@ deps = flake8 commands = flake8 {posargs} [testenv:fix_lint] -extras = description = format the code base to adhere to our styles, and complain about what we cannot do automatically passenv = PROGRAMDATA @@ -68,7 +70,7 @@ commands = /bin/bash -c 'pre-commit run --files **/*' [testenv:dev] description = create dev environment -extras = pandas, development +extras = pandas, development, sso usedevelop = True commands = python -m pip list --format=columns python -c "print(r'{envpython}')" diff --git a/version.py b/version.py index 81ce4930a..4d8deeee2 100644 --- a/version.py +++ b/version.py @@ -1,3 +1,3 @@ # Update this for the versions # Don't change the forth version number from None -VERSION = (2, 2, 2, None) +VERSION = (2, 2, 3, None)