Feature/http odbc conn extra #1092

Closed · wants to merge 2 commits
1 change: 1 addition & 0 deletions .github/workflows/integration.yml
@@ -76,6 +76,7 @@ jobs:
test:
- "apache_spark"
- "spark_session"
- "spark_http_odbc"
- "databricks_sql_endpoint"
- "databricks_cluster"
- "databricks_http_cluster"
1 change: 1 addition & 0 deletions .github/workflows/release-internal.yml
@@ -79,6 +79,7 @@ jobs:
test:
- "apache_spark"
- "spark_session"
- "spark_http_odbc"
- "databricks_sql_endpoint"
- "databricks_cluster"
- "databricks_http_cluster"
1 change: 1 addition & 0 deletions .github/workflows/release-prep.yml
@@ -482,6 +482,7 @@ jobs:
test:
- "apache_spark"
- "spark_session"
- "spark_http_odbc"
- "databricks_sql_endpoint"
- "databricks_cluster"
- "databricks_http_cluster"
2 changes: 1 addition & 1 deletion dagger/run_dbt_spark_tests.py
@@ -137,7 +137,7 @@ async def test_spark(test_args):
spark_ctr, spark_host = get_spark_container(client)
tst_container = tst_container.with_service_binding(alias=spark_host, service=spark_ctr)

elif test_profile in ["databricks_cluster", "databricks_sql_endpoint"]:
elif test_profile in ["databricks_cluster", "databricks_sql_endpoint", "spark_http_odbc"]:
tst_container = (
tst_container.with_workdir("/")
.with_exec(["./scripts/configure_odbc.sh"])
54 changes: 34 additions & 20 deletions dbt/adapters/spark/connections.py
@@ -78,6 +78,7 @@ class SparkCredentials(Credentials):
auth: Optional[str] = None
kerberos_service_name: Optional[str] = None
organization: str = "0"
connection_str_extra: Optional[str] = None
connect_retries: int = 0
connect_timeout: int = 10
use_ssl: bool = False
@@ -483,38 +484,51 @@ def open(cls, connection: Connection) -> Connection:
http_path = cls.SPARK_SQL_ENDPOINT_HTTP_PATH.format(
endpoint=creds.endpoint
)
elif creds.connection_str_extra is not None:
required_fields = ["driver", "host", "port", "connection_str_extra"]
else:
raise DbtConfigError(
"Either `cluster` or `endpoint` must set when"
"Either `cluster`, `endpoint`, `connection_str_extra` must set when"
" using the odbc method to connect to Spark"
)

cls.validate_creds(creds, required_fields)

dbt_spark_version = __version__.version
user_agent_entry = (
f"dbt-labs-dbt-spark/{dbt_spark_version} (Databricks)" # noqa
)

# http://simba.wpengine.com/products/Spark/doc/ODBC_InstallGuide/unix/content/odbc/hi/configuring/serverside.htm
ssp = {f"SSP_{k}": f"{{{v}}}" for k, v in creds.server_side_parameters.items()}

# https://www.simba.com/products/Spark/doc/v2/ODBC_InstallGuide/unix/content/odbc/options/driver.htm
connection_str = _build_odbc_connnection_string(
DRIVER=creds.driver,
HOST=creds.host,
PORT=creds.port,
UID="token",
PWD=creds.token,
HTTPPath=http_path,
AuthMech=3,
SparkServerType=3,
ThriftTransport=2,
SSL=1,
UserAgentEntry=user_agent_entry,
LCaseSspKeyName=0 if ssp else 1,
**ssp,
)
if creds.token is not None:
# https://www.simba.com/products/Spark/doc/v2/ODBC_InstallGuide/unix/content/odbc/options/driver.htm
connection_str = _build_odbc_connnection_string(
DRIVER=creds.driver,
HOST=creds.host,
PORT=creds.port,
UID="token",
PWD=creds.token,
HTTPPath=http_path,
AuthMech=3,
SparkServerType=3,
ThriftTransport=2,
SSL=1,
UserAgentEntry=user_agent_entry,
LCaseSspKeyName=0 if ssp else 1,
**ssp,
)
else:
connection_str = _build_odbc_connnection_string(
DRIVER=creds.driver,
HOST=creds.host,
PORT=creds.port,
ThriftTransport=2,
SSL=1,
UserAgentEntry=user_agent_entry,
LCaseSspKeyName=0 if ssp else 1,
**ssp,
)
if creds.connection_str_extra is not None:
connection_str = connection_str + ";" + creds.connection_str_extra

conn = pyodbc.connect(connection_str, autocommit=True)
handle = PyodbcConnectionWrapper(conn)
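
For clarity, here is a minimal sketch of what the new branch produces when `connection_str_extra` is set. The `build_connection_string` helper is only a stand-in for the adapter's `_build_odbc_connnection_string` (assumed to join `key=value` pairs with semicolons), and the driver, host, and endpoint values are illustrative placeholders rather than values from this PR.

import os

# Stand-in for _build_odbc_connnection_string: join key=value pairs with ";".
def build_connection_string(**kwargs) -> str:
    return ";".join(f"{key}={value}" for key, value in kwargs.items())

# Token-less base string, mirroring the new `else` branch above.
base = build_connection_string(
    DRIVER="Simba Spark ODBC Driver",     # illustrative driver name
    HOST="example.cloud.databricks.com",  # illustrative host
    PORT=443,
    ThriftTransport=2,
    SSL=1,
)

# Auth details now arrive via connection_str_extra, as in the test profile below.
extra = (
    f"UID=token;PWD={os.getenv('DBT_DATABRICKS_TOKEN', '<token>')};"
    "HTTPPath=/sql/1.0/endpoints/<endpoint-id>;AuthMech=3;SparkServerType=3"
)

connection_str = base + ";" + extra
print(connection_str)
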
15 changes: 15 additions & 0 deletions tests/conftest.py
@@ -30,6 +30,8 @@ def dbt_profile_target(request):
target = databricks_http_cluster_target()
elif profile_type == "spark_session":
target = spark_session_target()
elif profile_type == "spark_http_odbc":
target = spark_http_odbc_target()
else:
raise ValueError(f"Invalid profile type '{profile_type}'")
return target
@@ -101,6 +103,19 @@ def spark_session_target():
"method": "session",
}

def spark_http_odbc_target():
return {
"type": "spark",
"method": "odbc",
"host": os.getenv("DBT_DATABRICKS_HOST_NAME"),
"port": 443,
"driver": os.getenv("ODBC_DRIVER"),
"connection_str_extra": f'UID=token;PWD={os.getenv("DBT_DATABRICKS_TOKEN")};HTTPPath=/sql/1.0/endpoints/{os.getenv("DBT_DATABRICKS_ENDPOINT")};AuthMech=3;SparkServerType=3',
"connect_retries": 3,
"connect_timeout": 5,
"retry_all": True,
}


@pytest.fixture(autouse=True)
def skip_by_profile_type(request):
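
A quick sanity sketch of how the fixture above lines up with the new ODBC branch: when `connection_str_extra` is present, `connections.py` validates `driver`, `host`, `port`, and `connection_str_extra`, so the target deliberately omits `cluster`, `endpoint`, and `token`. The dict below mirrors the fixture with secrets elided; the check itself is illustrative only.

# Required fields enforced by the new connection_str_extra branch in connections.py.
required_fields = ["driver", "host", "port", "connection_str_extra"]

# Shape of the spark_http_odbc target above, with secrets elided.
target = {
    "type": "spark",
    "method": "odbc",
    "host": "<databricks-host>",
    "port": 443,
    "driver": "<odbc-driver>",
    "connection_str_extra": "UID=token;PWD=<token>;HTTPPath=/sql/1.0/endpoints/<id>;AuthMech=3;SparkServerType=3",
}

missing = [field for field in required_fields if not target.get(field)]
assert not missing, f"odbc profile is missing required fields: {missing}"
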
@@ -21,7 +21,7 @@ def test_run_incremental_fail_on_schema_change(self, project):
assert "Compilation Error" in results_two[1].message


@pytest.mark.skip_profile("databricks_sql_endpoint")
@pytest.mark.skip_profile("databricks_sql_endpoint", "spark_http_odbc")
class TestAppendOnSchemaChange(IncrementalOnSchemaChangeIgnoreFail):
@pytest.fixture(scope="class")
def project_config_update(self):
@@ -32,7 +32,7 @@ def project_config_update(self):
}


@pytest.mark.skip_profile("databricks_sql_endpoint", "spark_session")
@pytest.mark.skip_profile("databricks_sql_endpoint", "spark_session", "spark_http_odbc"
class TestInsertOverwriteOnSchemaChange(IncrementalOnSchemaChangeIgnoreFail):
@pytest.fixture(scope="class")
def project_config_update(self):
@@ -55,7 +55,7 @@ def run_and_test(self, project):
check_relations_equal(project.adapter, ["default_append", "expected_append"])

@pytest.mark.skip_profile(
"databricks_http_cluster", "databricks_sql_endpoint", "spark_session"
"databricks_http_cluster", "databricks_sql_endpoint", "spark_session", "spark_http_odbc"
)
def test_default_append(self, project):
self.run_and_test(project)
@@ -77,7 +77,7 @@ def run_and_test(self, project):
check_relations_equal(project.adapter, ["insert_overwrite_partitions", "expected_upsert"])

@pytest.mark.skip_profile(
"databricks_http_cluster", "databricks_sql_endpoint", "spark_session"
"databricks_http_cluster", "databricks_sql_endpoint", "spark_session", "spark_http_odbc"
)
def test_insert_overwrite(self, project):
self.run_and_test(project)
@@ -103,7 +103,7 @@ def run_and_test(self, project):
check_relations_equal(project.adapter, ["merge_update_columns", "expected_partial_upsert"])

@pytest.mark.skip_profile(
"apache_spark", "databricks_http_cluster", "databricks_sql_endpoint", "spark_session"
"apache_spark", "databricks_http_cluster", "databricks_sql_endpoint", "spark_session", "spark_http_odbc"
)
def test_delta_strategies(self, project):
self.run_and_test(project)
6 changes: 3 additions & 3 deletions tests/functional/adapter/test_constraints.py
@@ -183,7 +183,7 @@ def models(self):


@pytest.mark.skip_profile(
"spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster"
"spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster", "spark_http_odbc"
)
class TestSparkTableConstraintsColumnsEqualDatabricksHTTP(
DatabricksHTTPSetup, BaseTableConstraintsColumnsEqual
@@ -198,7 +198,7 @@ def models(self):


@pytest.mark.skip_profile(
"spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster"
"spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster", "spark_http_odbc"
)
class TestSparkViewConstraintsColumnsEqualDatabricksHTTP(
DatabricksHTTPSetup, BaseViewConstraintsColumnsEqual
@@ -213,7 +213,7 @@ def models(self):


@pytest.mark.skip_profile(
"spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster"
"spark_session", "apache_spark", "databricks_sql_endpoint", "databricks_cluster", "spark_http_odbc"
)
class TestSparkIncrementalConstraintsColumnsEqualDatabricksHTTP(
DatabricksHTTPSetup, BaseIncrementalConstraintsColumnsEqual
8 changes: 4 additions & 4 deletions tests/functional/adapter/test_python_model.py
@@ -8,12 +8,12 @@
from dbt.tests.adapter.python_model.test_spark import BasePySparkTests


@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint")
@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc")
class TestPythonModelSpark(BasePythonModelTests):
pass


@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint")
@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc")
class TestPySpark(BasePySparkTests):
def test_different_dataframes(self, project):
"""
Expand All @@ -33,7 +33,7 @@ def test_different_dataframes(self, project):
assert len(results) == 3


@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint")
@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc")
class TestPythonIncrementalModelSpark(BasePythonIncrementalTests):
@pytest.fixture(scope="class")
def project_config_update(self):
@@ -78,7 +78,7 @@ def model(dbt, spark):
"""


@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint")
@pytest.mark.skip_profile("apache_spark", "spark_session", "databricks_sql_endpoint", "spark_http_odbc")
class TestChangingSchemaSpark:
"""
Confirm that we can setup a spot instance and parse required packages into the Databricks job.
2 changes: 1 addition & 1 deletion tests/functional/adapter/test_store_test_failures.py
@@ -7,7 +7,7 @@
)


@pytest.mark.skip_profile("spark_session", "databricks_cluster", "databricks_sql_endpoint")
@pytest.mark.skip_profile("spark_session", "databricks_cluster", "databricks_sql_endpoint", "spark_http_odbc")
class TestSparkStoreTestFailures(StoreTestFailuresBase):
@pytest.fixture(scope="class")
def project_config_update(self):