Skip to content

Commit

Permalink
refactor connection tests (#18881)
Browse files Browse the repository at this point in the history
  • Loading branch information
danarwix authored Oct 11, 2021
1 parent eae7b2e commit d59be1e
Showing 1 changed file with 65 additions and 97 deletions.
162 changes: 65 additions & 97 deletions tests/providers/trino/hooks/test_trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,72 +30,85 @@
from airflow.models import Connection
from airflow.providers.trino.hooks.trino import TrinoHook

HOOK_GET_CONNECTION = 'airflow.providers.trino.hooks.trino.TrinoHook.get_connection'
BASIC_AUTHENTICATION = 'airflow.providers.trino.hooks.trino.trino.auth.BasicAuthentication'
KERBEROS_AUTHENTICATION = 'airflow.providers.trino.hooks.trino.trino.auth.KerberosAuthentication'
TRINO_DBAPI_CONNECT = 'airflow.providers.trino.hooks.trino.trino.dbapi.connect'


class TestTrinoHookConn(unittest.TestCase):
@patch('airflow.providers.trino.hooks.trino.trino.auth.BasicAuthentication')
@patch('airflow.providers.trino.hooks.trino.trino.dbapi.connect')
@patch('airflow.providers.trino.hooks.trino.TrinoHook.get_connection')
@patch(BASIC_AUTHENTICATION)
@patch(TRINO_DBAPI_CONNECT)
@patch(HOOK_GET_CONNECTION)
def test_get_conn_basic_auth(self, mock_get_connection, mock_connect, mock_basic_auth):
mock_get_connection.return_value = Connection(
login='login', password='password', host='host', schema='hive'
)

conn = TrinoHook().get_conn()
mock_connect.assert_called_once_with(
catalog='hive',
host='host',
port=None,
http_scheme='http',
schema='hive',
source='airflow',
user='login',
isolation_level=0,
auth=mock_basic_auth.return_value,
verify=True,
)
self.set_get_connection_return_value(mock_get_connection, password='password')
TrinoHook().get_conn()
self.assert_connection_called_with(mock_connect, auth=mock_basic_auth)
mock_basic_auth.assert_called_once_with('login', 'password')
assert mock_connect.return_value == conn

@patch('airflow.providers.trino.hooks.trino.TrinoHook.get_connection')
@patch(HOOK_GET_CONNECTION)
def test_get_conn_invalid_auth(self, mock_get_connection):
mock_get_connection.return_value = Connection(
login='login',
extras = {'auth': 'kerberos'}
self.set_get_connection_return_value(
mock_get_connection,
password='password',
host='host',
schema='hive',
extra=json.dumps({'auth': 'kerberos'}),
extra=json.dumps(extras),
)
with pytest.raises(
AirflowException, match=re.escape("Kerberos authorization doesn't support password.")
):
TrinoHook().get_conn()

@patch('airflow.providers.trino.hooks.trino.trino.auth.KerberosAuthentication')
@patch('airflow.providers.trino.hooks.trino.trino.dbapi.connect')
@patch('airflow.providers.trino.hooks.trino.TrinoHook.get_connection')
@patch(KERBEROS_AUTHENTICATION)
@patch(TRINO_DBAPI_CONNECT)
@patch(HOOK_GET_CONNECTION)
def test_get_conn_kerberos_auth(self, mock_get_connection, mock_connect, mock_auth):
mock_get_connection.return_value = Connection(
login='login',
host='host',
schema='hive',
extra=json.dumps(
{
'auth': 'kerberos',
'kerberos__config': 'TEST_KERBEROS_CONFIG',
'kerberos__service_name': 'TEST_SERVICE_NAME',
'kerberos__mutual_authentication': 'TEST_MUTUAL_AUTHENTICATION',
'kerberos__force_preemptive': True,
'kerberos__hostname_override': 'TEST_HOSTNAME_OVERRIDE',
'kerberos__sanitize_mutual_error_response': True,
'kerberos__principal': 'TEST_PRINCIPAL',
'kerberos__delegate': 'TEST_DELEGATE',
'kerberos__ca_bundle': 'TEST_CA_BUNDLE',
'verify': 'true',
}
),
extras = {
'auth': 'kerberos',
'kerberos__config': 'TEST_KERBEROS_CONFIG',
'kerberos__service_name': 'TEST_SERVICE_NAME',
'kerberos__mutual_authentication': 'TEST_MUTUAL_AUTHENTICATION',
'kerberos__force_preemptive': True,
'kerberos__hostname_override': 'TEST_HOSTNAME_OVERRIDE',
'kerberos__sanitize_mutual_error_response': True,
'kerberos__principal': 'TEST_PRINCIPAL',
'kerberos__delegate': 'TEST_DELEGATE',
'kerberos__ca_bundle': 'TEST_CA_BUNDLE',
'verify': 'true',
}
self.set_get_connection_return_value(
mock_get_connection,
extra=json.dumps(extras),
)
TrinoHook().get_conn()
self.assert_connection_called_with(mock_connect, auth=mock_auth)

conn = TrinoHook().get_conn()
@parameterized.expand(
[
('False', False),
('false', False),
('true', True),
('true', True),
('/tmp/cert.crt', '/tmp/cert.crt'),
]
)
@patch(HOOK_GET_CONNECTION)
@patch(TRINO_DBAPI_CONNECT)
def test_get_conn_verify(self, current_verify, expected_verify, mock_connect, mock_get_connection):
extras = {'verify': current_verify}
self.set_get_connection_return_value(mock_get_connection, extra=json.dumps(extras))
TrinoHook().get_conn()
self.assert_connection_called_with(mock_connect, verify=expected_verify)

@staticmethod
def set_get_connection_return_value(mock_get_connection, extra=None, password=None):
mocked_connection = Connection(
login='login', password=password, host='host', schema='hive', extra=extra or '{}'
)
mock_get_connection.return_value = mocked_connection

@staticmethod
def assert_connection_called_with(mock_connect, auth=None, verify=True):
mock_connect.assert_called_once_with(
catalog='hive',
host='host',
Expand All @@ -105,54 +118,9 @@ def test_get_conn_kerberos_auth(self, mock_get_connection, mock_connect, mock_au
source='airflow',
user='login',
isolation_level=0,
auth=mock_auth.return_value,
verify=True,
auth=None if not auth else auth.return_value,
verify=verify,
)
mock_auth.assert_called_once_with(
ca_bundle='TEST_CA_BUNDLE',
config='TEST_KERBEROS_CONFIG',
delegate='TEST_DELEGATE',
force_preemptive=True,
hostname_override='TEST_HOSTNAME_OVERRIDE',
mutual_authentication='TEST_MUTUAL_AUTHENTICATION',
principal='TEST_PRINCIPAL',
sanitize_mutual_error_response=True,
service_name='TEST_SERVICE_NAME',
)
assert mock_connect.return_value == conn

@parameterized.expand(
[
('False', False),
('false', False),
('true', True),
('true', True),
('/tmp/cert.crt', '/tmp/cert.crt'),
]
)
def test_get_conn_verify(self, current_verify, expected_verify):
patcher_connect = patch('airflow.providers.trino.hooks.trino.trino.dbapi.connect')
patcher_get_connections = patch('airflow.providers.trino.hooks.trino.TrinoHook.get_connection')

with patcher_connect as mock_connect, patcher_get_connections as mock_get_connection:
mock_get_connection.return_value = Connection(
login='login', host='host', schema='hive', extra=json.dumps({'verify': current_verify})
)

conn = TrinoHook().get_conn()
mock_connect.assert_called_once_with(
catalog='hive',
host='host',
port=None,
http_scheme='http',
schema='hive',
source='airflow',
user='login',
auth=None,
isolation_level=0,
verify=expected_verify,
)
assert mock_connect.return_value == conn


class TestTrinoHook(unittest.TestCase):
Expand Down

0 comments on commit d59be1e

Please sign in to comment.