From 6141f77ac4d3ab56dd986f2aa356d4bf6dcc8df0 Mon Sep 17 00:00:00 2001
From: nilan3
Date: Tue, 27 Aug 2024 12:30:14 +0100
Subject: [PATCH 1/2] add support for extra odbc connection properties

---
 .github/workflows/integration.yml             |  1 +
 .github/workflows/release-internal.yml        |  1 +
 .github/workflows/release-prep.yml            |  1 +
 dagger/run_dbt_spark_tests.py                 |  2 +-
 dbt/adapters/spark/connections.py             | 56 ++++++++++++-------
 tests/conftest.py                             | 15 +++++
 .../test_incremental_on_schema_change.py      |  4 +-
 .../test_incremental_strategies.py            |  6 +-
 tests/functional/adapter/test_constraints.py  |  6 +-
 tests/functional/adapter/test_python_model.py |  8 +--
 .../adapter/test_store_test_failures.py       |  2 +-
 11 files changed, 67 insertions(+), 35 deletions(-)

diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml
index 699d45391..35bd9cae0 100644
--- a/.github/workflows/integration.yml
+++ b/.github/workflows/integration.yml
@@ -76,6 +76,7 @@ jobs:
         test:
           - "apache_spark"
           - "spark_session"
+          - "spark_http_odbc"
           - "databricks_sql_endpoint"
           - "databricks_cluster"
           - "databricks_http_cluster"
diff --git a/.github/workflows/release-internal.yml b/.github/workflows/release-internal.yml
index d4e7a3c93..1a5090312 100644
--- a/.github/workflows/release-internal.yml
+++ b/.github/workflows/release-internal.yml
@@ -79,6 +79,7 @@ jobs:
         test:
           - "apache_spark"
           - "spark_session"
+          - "spark_http_odbc"
           - "databricks_sql_endpoint"
           - "databricks_cluster"
           - "databricks_http_cluster"
diff --git a/.github/workflows/release-prep.yml b/.github/workflows/release-prep.yml
index 9cb2c3e19..9937463d3 100644
--- a/.github/workflows/release-prep.yml
+++ b/.github/workflows/release-prep.yml
@@ -482,6 +482,7 @@ jobs:
         test:
           - "apache_spark"
           - "spark_session"
+          - "spark_http_odbc"
           - "databricks_sql_endpoint"
           - "databricks_cluster"
           - "databricks_http_cluster"
diff --git a/dagger/run_dbt_spark_tests.py b/dagger/run_dbt_spark_tests.py
index 15f9cf2c2..67fa56587 100644
--- a/dagger/run_dbt_spark_tests.py
+++ b/dagger/run_dbt_spark_tests.py
@@ -137,7 +137,7 @@ async def test_spark(test_args):
         spark_ctr, spark_host = get_spark_container(client)
         tst_container = tst_container.with_service_binding(alias=spark_host, service=spark_ctr)
 
-    elif test_profile in ["databricks_cluster", "databricks_sql_endpoint"]:
+    elif test_profile in ["databricks_cluster", "databricks_sql_endpoint", "spark_http_odbc"]:
         tst_container = (
             tst_container.with_workdir("/")
             .with_exec(["./scripts/configure_odbc.sh"])
diff --git a/dbt/adapters/spark/connections.py b/dbt/adapters/spark/connections.py
index 0405eaf5b..f5dcbe6a1 100644
--- a/dbt/adapters/spark/connections.py
+++ b/dbt/adapters/spark/connections.py
@@ -78,6 +78,7 @@ class SparkCredentials(Credentials):
     auth: Optional[str] = None
     kerberos_service_name: Optional[str] = None
     organization: str = "0"
+    connection_str_extra: Optional[str] = None
     connect_retries: int = 0
     connect_timeout: int = 10
     use_ssl: bool = False
@@ -154,7 +155,7 @@ def __post_init__(self) -> None:
                 f"ImportError({e.msg})"
             ) from e
 
-        if self.method != SparkConnectionMethod.SESSION:
+        if self.method != SparkConnectionMethod.SESSION and self.host is not None:
             self.host = self.host.rstrip("/")
 
         self.server_side_parameters = {
@@ -483,38 +484,51 @@ def open(cls, connection: Connection) -> Connection:
                         http_path = cls.SPARK_SQL_ENDPOINT_HTTP_PATH.format(
                             endpoint=creds.endpoint
                         )
+                elif creds.connection_str_extra is not None:
+                    required_fields = ["driver", "host", "port", "connection_str_extra"]
                 else:
                     raise DbtConfigError(
-                        "Either `cluster` or `endpoint` must set when"
+                        "Either `cluster`, `endpoint`, or `connection_str_extra` must be set when"
                         " using the odbc method to connect to Spark"
                     )
 
                 cls.validate_creds(creds, required_fields)
-
                 dbt_spark_version = __version__.version
                 user_agent_entry = (
                     f"dbt-labs-dbt-spark/{dbt_spark_version} (Databricks)"  # noqa
                 )
-
                 # http://simba.wpengine.com/products/Spark/doc/ODBC_InstallGuide/unix/content/odbc/hi/configuring/serverside.htm
                 ssp = {f"SSP_{k}": f"{{{v}}}" for k, v in creds.server_side_parameters.items()}
-
-                # https://www.simba.com/products/Spark/doc/v2/ODBC_InstallGuide/unix/content/odbc/options/driver.htm
-                connection_str = _build_odbc_connnection_string(
-                    DRIVER=creds.driver,
-                    HOST=creds.host,
-                    PORT=creds.port,
-                    UID="token",
-                    PWD=creds.token,
-                    HTTPPath=http_path,
-                    AuthMech=3,
-                    SparkServerType=3,
-                    ThriftTransport=2,
-                    SSL=1,
-                    UserAgentEntry=user_agent_entry,
-                    LCaseSspKeyName=0 if ssp else 1,
-                    **ssp,
-                )
+                if creds.token is not None:
+                    # https://www.simba.com/products/Spark/doc/v2/ODBC_InstallGuide/unix/content/odbc/options/driver.htm
+                    connection_str = _build_odbc_connnection_string(
+                        DRIVER=creds.driver,
+                        HOST=creds.host,
+                        PORT=creds.port,
+                        UID="token",
+                        PWD=creds.token,
+                        HTTPPath=http_path,
+                        AuthMech=3,
+                        SparkServerType=3,
+                        ThriftTransport=2,
+                        SSL=1,
+                        UserAgentEntry=user_agent_entry,
+                        LCaseSspKeyName=0 if ssp else 1,
+                        **ssp,
+                    )
+                else:
+                    connection_str = _build_odbc_connnection_string(
+                        DRIVER=creds.driver,
+                        HOST=creds.host,
+                        PORT=creds.port,
+                        ThriftTransport=2,
+                        SSL=1,
+                        UserAgentEntry=user_agent_entry,
+                        LCaseSspKeyName=0 if ssp else 1,
+                        **ssp,
+                    )
+                if creds.connection_str_extra is not None:
+                    connection_str = connection_str + ";" + creds.connection_str_extra
 
                 conn = pyodbc.connect(connection_str, autocommit=True)
                 handle = PyodbcConnectionWrapper(conn)
diff --git a/tests/conftest.py b/tests/conftest.py
index efba41a5f..fe4174f5c 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -30,6 +30,8 @@ def dbt_profile_target(request):
         target = databricks_http_cluster_target()
     elif profile_type == "spark_session":
         target = spark_session_target()
+    elif profile_type == "spark_http_odbc":
+        target = spark_http_odbc_target()
     else:
         raise ValueError(f"Invalid profile type '{profile_type}'")
     return target
@@ -101,6 +103,19 @@ def spark_session_target():
         "method": "session",
     }
 
+
+def spark_http_odbc_target():
+    return {
+        "type": "spark",
+        "method": "odbc",
+        "host": os.getenv("DBT_DATABRICKS_HOST_NAME"),
+        "port": 443,
+        "driver": os.getenv("ODBC_DRIVER"),
+        "connection_str_extra": f'UID=token;PWD={os.getenv("DBT_DATABRICKS_TOKEN")};HTTPPath=/sql/1.0/endpoints/{os.getenv("DBT_DATABRICKS_ENDPOINT")};AuthMech=3;SparkServerType=3',
+        "connect_retries": 3,
+        "connect_timeout": 5,
+        "retry_all": True,
+    }
 
 @pytest.fixture(autouse=True)
 def skip_by_profile_type(request):
diff --git a/tests/functional/adapter/incremental/test_incremental_on_schema_change.py b/tests/functional/adapter/incremental/test_incremental_on_schema_change.py
index 478329668..7e05290ad 100644
--- a/tests/functional/adapter/incremental/test_incremental_on_schema_change.py
+++ b/tests/functional/adapter/incremental/test_incremental_on_schema_change.py
@@ -21,7 +21,7 @@ def test_run_incremental_fail_on_schema_change(self, project):
         assert "Compilation Error" in results_two[1].message
 
 
-@pytest.mark.skip_profile("databricks_sql_endpoint")
+@pytest.mark.skip_profile("databricks_sql_endpoint", "spark_http_odbc")
 class TestAppendOnSchemaChange(IncrementalOnSchemaChangeIgnoreFail):
     @pytest.fixture(scope="class")
     def project_config_update(self):
@@ -32,7 +32,7 @@ def project_config_update(self):
         }
 
 
-@pytest.mark.skip_profile("databricks_sql_endpoint", "spark_session")
+@pytest.mark.skip_profile("databricks_sql_endpoint", "spark_session", "spark_http_odbc")
 class TestInsertOverwriteOnSchemaChange(IncrementalOnSchemaChangeIgnoreFail):
     @pytest.fixture(scope="class")
     def project_config_update(self):
diff --git a/tests/functional/adapter/incremental_strategies/test_incremental_strategies.py b/tests/functional/adapter/incremental_strategies/test_incremental_strategies.py
index b05fcb279..eb447ee4f 100644
--- a/tests/functional/adapter/incremental_strategies/test_incremental_strategies.py
+++ b/tests/functional/adapter/incremental_strategies/test_incremental_strategies.py
@@ -55,7 +55,7 @@ def run_and_test(self, project):
         check_relations_equal(project.adapter, ["default_append", "expected_append"])
 
     @pytest.mark.skip_profile(
-        "databricks_http_cluster", "databricks_sql_endpoint", "spark_session"
+        "databricks_http_cluster", "databricks_sql_endpoint", "spark_session", "spark_http_odbc"
     )
     def test_default_append(self, project):
         self.run_and_test(project)
@@ -77,7 +77,7 @@ def run_and_test(self, project):
         check_relations_equal(project.adapter, ["insert_overwrite_partitions", "expected_upsert"])
 
     @pytest.mark.skip_profile(
-        "databricks_http_cluster", "databricks_sql_endpoint", "spark_session"
+        "databricks_http_cluster", "databricks_sql_endpoint", "spark_session", "spark_http_odbc"
     )
     def test_insert_overwrite(self, project):
         self.run_and_test(project)
@@ -103,7 +103,7 @@ def run_and_test(self, project):
         check_relations_equal(project.adapter, ["merge_update_columns", "expected_partial_upsert"])
 
     @pytest.mark.skip_profile(
-        "apache_spark", "databricks_http_cluster", "databricks_sql_endpoint", "spark_session"
+        "apache_spark", "databricks_http_cluster", "databricks_sql_endpoint", "spark_session", "spark_http_odbc"
     )
     def test_delta_strategies(self, project):
         self.run_and_test(project)
diff --git a/tests/functional/adapter/test_constraints.py b/tests/functional/adapter/test_constraints.py
index e35a13a64..0b5b80e6e 100644
--- a/tests/functional/adapter/test_constraints.py
+++ b/tests/functional/adapter/test_constraints.py
@@ -183,7 +183,7 @@ def models(self):
 
 
 @pytest.mark.skip_profile(
-    "spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster"
+    "spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster", "spark_http_odbc"
 )
 class TestSparkTableConstraintsColumnsEqualDatabricksHTTP(
     DatabricksHTTPSetup, BaseTableConstraintsColumnsEqual
@@ -198,7 +198,7 @@ def models(self):
 
 
 @pytest.mark.skip_profile(
-    "spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster"
+    "spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster", "spark_http_odbc"
 )
 class TestSparkViewConstraintsColumnsEqualDatabricksHTTP(
     DatabricksHTTPSetup, BaseViewConstraintsColumnsEqual
@@ -213,7 +213,7 @@ def models(self):
 
 
 @pytest.mark.skip_profile(
-    "spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster"
+    "spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster", "spark_http_odbc"
 )
 class TestSparkIncrementalConstraintsColumnsEqualDatabricksHTTP(
     DatabricksHTTPSetup, BaseIncrementalConstraintsColumnsEqual
diff --git a/tests/functional/adapter/test_python_model.py b/tests/functional/adapter/test_python_model.py
index cd798d1da..60125be09 100644
--- a/tests/functional/adapter/test_python_model.py
+++ b/tests/functional/adapter/test_python_model.py
@@ -8,12 +8,12 @@
 from dbt.tests.adapter.python_model.test_spark import BasePySparkTests
 
 
-@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint")
+@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc")
 class TestPythonModelSpark(BasePythonModelTests):
     pass
 
 
-@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint")
+@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc")
 class TestPySpark(BasePySparkTests):
     def test_different_dataframes(self, project):
         """
@@ -33,7 +33,7 @@ def test_different_dataframes(self, project):
         assert len(results) == 3
 
 
-@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint")
+@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc")
 class TestPythonIncrementalModelSpark(BasePythonIncrementalTests):
     @pytest.fixture(scope="class")
     def project_config_update(self):
@@ -78,7 +78,7 @@ def model(dbt, spark):
 """
 
 
-@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint")
+@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc")
 class TestChangingSchemaSpark:
     """
     Confirm that we can setup a spot instance and parse required packages into the Databricks job.
diff --git a/tests/functional/adapter/test_store_test_failures.py b/tests/functional/adapter/test_store_test_failures.py
index e78bd4f71..91f52e4b4 100644
--- a/tests/functional/adapter/test_store_test_failures.py
+++ b/tests/functional/adapter/test_store_test_failures.py
@@ -7,7 +7,7 @@
 )
 
 
-@pytest.mark.skip_profile("spark_session", "databricks_cluster", "databricks_sql_endpoint")
+@pytest.mark.skip_profile("spark_session", "databricks_cluster", "databricks_sql_endpoint", "spark_http_odbc")
 class TestSparkStoreTestFailures(StoreTestFailuresBase):
     @pytest.fixture(scope="class")
     def project_config_update(self):

From 0e91dbd46c5c5dcdc6fd76934411ad2fa2132cb2 Mon Sep 17 00:00:00 2001
From: nilan3
Date: Tue, 27 Aug 2024 12:46:37 +0100
Subject: [PATCH 2/2] clean up

---
 dbt/adapters/spark/connections.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/dbt/adapters/spark/connections.py b/dbt/adapters/spark/connections.py
index f5dcbe6a1..7a1639816 100644
--- a/dbt/adapters/spark/connections.py
+++ b/dbt/adapters/spark/connections.py
@@ -155,7 +155,7 @@ def __post_init__(self) -> None:
                 f"ImportError({e.msg})"
             ) from e
 
-        if self.method != SparkConnectionMethod.SESSION and self.host is not None:
+        if self.method != SparkConnectionMethod.SESSION:
             self.host = self.host.rstrip("/")
 
         self.server_side_parameters = {
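
For reference, a minimal sketch of a profiles.yml target that would exercise the new `connection_str_extra` field, mirroring the `spark_http_odbc` test target added in tests/conftest.py above. The profile name, driver path, host, schema, endpoint id, and token below are placeholders, not values taken from this patch:

# Hypothetical profiles.yml entry; all values are placeholders.
spark_odbc:
  target: dev
  outputs:
    dev:
      type: spark
      method: odbc
      driver: /opt/simba/spark/lib/64/libsparkodbc_sb64.so  # placeholder driver path
      host: example.cloud.databricks.com                    # placeholder host
      port: 443
      schema: analytics                                     # placeholder schema
      # Appended verbatim to the ODBC connection string after a ";", so auth and the
      # HTTP path can be supplied here instead of the token/endpoint/cluster fields.
      connection_str_extra: "UID=token;PWD=<personal-access-token>;HTTPPath=/sql/1.0/endpoints/<endpoint-id>;AuthMech=3;SparkServerType=3"

As the connections.py hunk in the first patch shows, when no `token` is configured the UID, PWD, and HTTPPath arguments are left out of the base string built by `_build_odbc_connnection_string`, so they can be passed through `connection_str_extra` instead.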