From fbd76acb3a3baeb1a8c27f80c77b91638f582d00 Mon Sep 17 00:00:00 2001 From: ttzhou Date: Thu, 21 Nov 2024 12:16:48 -0500 Subject: [PATCH 1/2] Allow json_result_force_utf8_decoding specification in SnowflakeHook extra dict --- .../providers/snowflake/hooks/snowflake.py | 15 ++++++++++++- .../tests/snowflake/hooks/test_snowflake.py | 22 +++++++++++++++++++ 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/providers/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/src/airflow/providers/snowflake/hooks/snowflake.py index 81785338e7ba5..bb77063b29cf0 100644 --- a/providers/src/airflow/providers/snowflake/hooks/snowflake.py +++ b/providers/src/airflow/providers/snowflake/hooks/snowflake.py @@ -200,6 +200,7 @@ def _get_conn_params(self) -> dict[str, str | None]: region = self._get_field(extra_dict, "region") or "" role = self._get_field(extra_dict, "role") or "" insecure_mode = _try_to_boolean(self._get_field(extra_dict, "insecure_mode")) + json_result_force_utf8_decoding = _try_to_boolean(self._get_field(extra_dict, "json_result_force_utf8_decoding")) schema = conn.schema or "" client_request_mfa_token = _try_to_boolean(self._get_field(extra_dict, "client_request_mfa_token")) @@ -224,6 +225,9 @@ def _get_conn_params(self) -> dict[str, str | None]: if insecure_mode: conn_config["insecure_mode"] = insecure_mode + if json_result_force_utf8_decoding: + conn_config["json_result_force_utf8_decoding"] = json_result_force_utf8_decoding + if client_request_mfa_token: conn_config["client_request_mfa_token"] = client_request_mfa_token @@ -301,7 +305,13 @@ def _conn_params_to_sqlalchemy_uri(self, conn_params: dict) -> str: for k, v in conn_params.items() if v and k - not in ["session_parameters", "insecure_mode", "private_key", "client_request_mfa_token"] + not in [ + "session_parameters", + "insecure_mode", + "private_key", + "client_request_mfa_token", + "json_result_force_utf8_decoding", + ] } ) @@ -323,6 +333,9 @@ def get_sqlalchemy_engine(self, 
engine_kwargs=None): if "insecure_mode" in conn_params: engine_kwargs.setdefault("connect_args", {}) engine_kwargs["connect_args"]["insecure_mode"] = True + if "json_result_force_utf8_decoding" in conn_params: + engine_kwargs.setdefault("connect_args", {}) + engine_kwargs["connect_args"]["json_result_force_utf8_decoding"] = True for key in ["session_parameters", "private_key"]: if conn_params.get(key): engine_kwargs.setdefault("connect_args", {}) diff --git a/providers/tests/snowflake/hooks/test_snowflake.py b/providers/tests/snowflake/hooks/test_snowflake.py index b7c9382654be0..d75f1a4baf14c 100644 --- a/providers/tests/snowflake/hooks/test_snowflake.py +++ b/providers/tests/snowflake/hooks/test_snowflake.py @@ -138,6 +138,7 @@ class TestPytestSnowflakeHook: "extra__snowflake__region": "af_region", "extra__snowflake__role": "af_role", "extra__snowflake__insecure_mode": "True", + "extra__snowflake__json_result_force_utf8_decoding": "True", "extra__snowflake__client_request_mfa_token": "True", }, }, @@ -158,6 +159,7 @@ class TestPytestSnowflakeHook: "user": "user", "warehouse": "af_wh", "insecure_mode": True, + "json_result_force_utf8_decoding": True, "client_request_mfa_token": True, }, ), @@ -171,6 +173,7 @@ class TestPytestSnowflakeHook: "extra__snowflake__region": "af_region", "extra__snowflake__role": "af_role", "extra__snowflake__insecure_mode": "False", + "extra__snowflake__json_result_force_utf8_decoding": "False", "extra__snowflake__client_request_mfa_token": "False", }, }, @@ -247,6 +250,7 @@ class TestPytestSnowflakeHook: "extra": { **BASE_CONNECTION_KWARGS["extra"], "extra__snowflake__insecure_mode": False, + "extra__snowflake__json_result_force_utf8_decoding": True, "extra__snowflake__client_request_mfa_token": False, }, }, @@ -266,6 +270,7 @@ class TestPytestSnowflakeHook: "session_parameters": None, "user": "user", "warehouse": "af_wh", + "json_result_force_utf8_decoding": True, }, ), ], @@ -473,6 +478,23 @@ def 
test_get_sqlalchemy_engine_should_support_insecure_mode(self): ) assert mock_create_engine.return_value == conn + def test_get_sqlalchemy_engine_should_support_json_result_force_utf8_decoding(self): + connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS) + connection_kwargs["extra"]["extra__snowflake__json_result_force_utf8_decoding"] = "True" + + with ( + mock.patch.dict("os.environ", AIRFLOW_CONN_TEST_CONN=Connection(**connection_kwargs).get_uri()), + mock.patch("airflow.providers.snowflake.hooks.snowflake.create_engine") as mock_create_engine, + ): + hook = SnowflakeHook(snowflake_conn_id="test_conn") + conn = hook.get_sqlalchemy_engine() + mock_create_engine.assert_called_once_with( + "snowflake://user:pw@airflow.af_region/db/public" + "?application=AIRFLOW&authenticator=snowflake&role=af_role&warehouse=af_wh", + connect_args={"json_result_force_utf8_decoding": True}, + ) + assert mock_create_engine.return_value == conn + def test_get_sqlalchemy_engine_should_support_session_parameters(self): connection_kwargs = deepcopy(BASE_CONNECTION_KWARGS) connection_kwargs["extra"]["session_parameters"] = {"TEST_PARAM": "AA", "TEST_PARAM_B": 123} From fd6783792ceb333865a543d3e5516b8d88ae4cc8 Mon Sep 17 00:00:00 2001 From: ttzhou Date: Thu, 21 Nov 2024 14:20:36 -0500 Subject: [PATCH 2/2] Use a set for the not in --- .../src/airflow/providers/snowflake/hooks/snowflake.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/providers/src/airflow/providers/snowflake/hooks/snowflake.py b/providers/src/airflow/providers/snowflake/hooks/snowflake.py index bb77063b29cf0..0575ffa2b1cef 100644 --- a/providers/src/airflow/providers/snowflake/hooks/snowflake.py +++ b/providers/src/airflow/providers/snowflake/hooks/snowflake.py @@ -200,7 +200,9 @@ def _get_conn_params(self) -> dict[str, str | None]: region = self._get_field(extra_dict, "region") or "" role = self._get_field(extra_dict, "role") or "" insecure_mode = _try_to_boolean(self._get_field(extra_dict, 
"insecure_mode")) - json_result_force_utf8_decoding = _try_to_boolean(self._get_field(extra_dict, "json_result_force_utf8_decoding")) + json_result_force_utf8_decoding = _try_to_boolean( + self._get_field(extra_dict, "json_result_force_utf8_decoding") + ) schema = conn.schema or "" client_request_mfa_token = _try_to_boolean(self._get_field(extra_dict, "client_request_mfa_token")) @@ -305,13 +307,13 @@ def _conn_params_to_sqlalchemy_uri(self, conn_params: dict) -> str: for k, v in conn_params.items() if v and k - not in [ + not in { "session_parameters", "insecure_mode", "private_key", "client_request_mfa_token", "json_result_force_utf8_decoding", - ] + } } )