diff --git a/.changes/unreleased/Under the Hood-20240111-114806.yaml b/.changes/unreleased/Under the Hood-20240111-114806.yaml new file mode 100644 index 000000000..31705f468 --- /dev/null +++ b/.changes/unreleased/Under the Hood-20240111-114806.yaml @@ -0,0 +1,6 @@ +kind: Under the Hood +body: Update import paths and list_relations to support decoupling adapters/core +time: 2024-01-11T11:48:06.120111-08:00 +custom: + Author: colin-rogers-dbt + Issue: "972" diff --git a/dagger/requirements.txt b/dagger/requirements.txt index df36543c2..b50c448d3 100644 --- a/dagger/requirements.txt +++ b/dagger/requirements.txt @@ -1,2 +1,2 @@ -dagger-io~=0.8.0 +dagger-io~=0.9.7 python-dotenv diff --git a/dagger/run_dbt_spark_tests.py b/dagger/run_dbt_spark_tests.py index 91b9988c9..436cb1e92 100644 --- a/dagger/run_dbt_spark_tests.py +++ b/dagger/run_dbt_spark_tests.py @@ -2,6 +2,7 @@ import argparse import sys +from typing import Dict import anyio as anyio import dagger as dagger @@ -19,7 +20,7 @@ TESTING_ENV_VARS.update({"ODBC_DRIVER": "/opt/simba/spark/lib/64/libsparkodbc_sb64.so"}) -def env_variables(envs: dict[str, str]): +def env_variables(envs: Dict[str, str]): def env_variables_inner(ctr: dagger.Container): for key, value in envs.items(): ctr = ctr.with_env_variable(key, value) @@ -28,18 +29,19 @@ def env_variables_inner(ctr: dagger.Container): return env_variables_inner -async def get_postgres_container(client: dagger.Client) -> (dagger.Container, str): - ctr = await ( +def get_postgres_container(client: dagger.Client) -> (dagger.Service, str): + ctr = ( client.container() .from_("postgres:13") .with_env_variable("POSTGRES_PASSWORD", "postgres") .with_exposed_port(PG_PORT) + .as_service() ) return ctr, "postgres_db" -async def get_spark_container(client: dagger.Client) -> (dagger.Container, str): +def get_spark_container(client: dagger.Client) -> (dagger.Service, str): spark_dir = client.host().directory("./dagger/spark-container") spark_ctr_base = ( client.container() @@ -63,7 +65,7 @@ async def get_spark_container(client: dagger.Client) -> (dagger.Container, str): ) # postgres is the metastore here - pg_ctr, pg_host = await get_postgres_container(client) + pg_ctr, pg_host = get_postgres_container(client) spark_ctr = ( spark_ctr_base.with_service_binding(alias=pg_host, service=pg_ctr) @@ -77,6 +79,7 @@ async def get_spark_container(client: dagger.Client) -> (dagger.Container, str): ] ) .with_exposed_port(10000) + .as_service() ) return spark_ctr, "spark_db" @@ -85,29 +88,49 @@ async def get_spark_container(client: dagger.Client) -> (dagger.Container, str): async def test_spark(test_args): async with dagger.Connection(dagger.Config(log_output=sys.stderr)) as client: test_profile = test_args.profile - req_files = client.host().directory("./", include=["*.txt", "*.env", "*.ini"]) + + # create cache volumes; these persist between runs, saving time when developing locally + os_reqs_cache = client.cache_volume("os_reqs") + pip_cache = client.cache_volume("pip") + + # set up directories so we don't copy the whole repo into the container + req_files = client.host().directory( + "./", include=["*.txt", "*.env", "*.ini", "*.md", "setup.py"] + ) dbt_spark_dir = client.host().directory("./dbt") test_dir = client.host().directory("./tests") scripts = client.host().directory("./dagger/scripts") + platform = dagger.Platform("linux/amd64") tst_container = ( client.container(platform=platform) .from_("python:3.8-slim") - .with_directory("/.", req_files) - .with_directory("/dbt", dbt_spark_dir) -
.with_directory("/tests", test_dir) + .with_mounted_cache("/var/cache/apt/archives", os_reqs_cache) + .with_mounted_cache("/root/.cache/pip", pip_cache) + # install OS deps first so any local changes don't invalidate the cache .with_directory("/scripts", scripts) - .with_exec("./scripts/install_os_reqs.sh") + .with_exec(["./scripts/install_os_reqs.sh"]) + # install dbt-spark + python deps + .with_directory("/src", req_files) + .with_directory("src/dbt", dbt_spark_dir) + .with_directory("src/tests", test_dir) + .with_workdir("/src") + .with_exec(["pip", "install", "-U", "pip"]) .with_exec(["pip", "install", "-r", "requirements.txt"]) .with_exec(["pip", "install", "-r", "dev-requirements.txt"]) + .with_exec(["pip", "install", "-e", "."]) ) if test_profile == "apache_spark": - spark_ctr, spark_host = await get_spark_container(client) + spark_ctr, spark_host = get_spark_container(client) tst_container = tst_container.with_service_binding(alias=spark_host, service=spark_ctr) elif test_profile in ["databricks_cluster", "databricks_sql_endpoint"]: - tst_container = tst_container.with_exec("./scripts/configure_odbc.sh") + tst_container = ( + tst_container.with_workdir("/") + .with_exec(["./scripts/configure_odbc.sh"]) + .with_workdir("/src") + ) elif test_profile == "spark_session": tst_container = tst_container.with_exec(["pip", "install", "pyspark"]) diff --git a/dbt/adapters/spark/column.py b/dbt/adapters/spark/column.py index a57fa0565..39f6f529e 100644 --- a/dbt/adapters/spark/column.py +++ b/dbt/adapters/spark/column.py @@ -2,7 +2,7 @@ from typing import Any, Dict, Optional, TypeVar, Union from dbt.adapters.base.column import Column -from dbt.dataclass_schema import dbtClassMixin +from dbt_common.dataclass_schema import dbtClassMixin Self = TypeVar("Self", bound="SparkColumn") diff --git a/dbt/adapters/spark/connections.py b/dbt/adapters/spark/connections.py index 966f5584e..83048f921 100644 --- a/dbt/adapters/spark/connections.py +++ b/dbt/adapters/spark/connections.py @@ -1,11 +1,17 @@ from contextlib import contextmanager -import dbt.exceptions -from dbt.adapters.base import Credentials +from dbt.adapters.contracts.connection import ( + AdapterResponse, + ConnectionState, + Connection, + Credentials, +) +from dbt.adapters.events.logging import AdapterLogger +from dbt.adapters.exceptions import FailedToConnectError from dbt.adapters.sql import SQLConnectionManager -from dbt.contracts.connection import ConnectionState, AdapterResponse -from dbt.events import AdapterLogger -from dbt.utils import DECIMALS +from dbt_common.exceptions import DbtConfigError, DbtRuntimeError, DbtDatabaseError + +from dbt_common.utils.encoding import DECIMALS from dbt.adapters.spark import __version__ try: @@ -22,8 +28,7 @@ pyodbc = None from datetime import datetime import sqlparams -from dbt.contracts.connection import Connection -from dbt.dataclass_schema import StrEnum +from dbt_common.dataclass_schema import StrEnum from dataclasses import dataclass, field from typing import Any, Dict, Optional, Union, Tuple, List, Generator, Iterable, Sequence @@ -92,15 +97,15 @@ def cluster_id(self) -> Optional[str]: def __post_init__(self) -> None: if self.method is None: - raise dbt.exceptions.DbtRuntimeError("Must specify `method` in profile") + raise DbtRuntimeError("Must specify `method` in profile") if self.host is None: - raise dbt.exceptions.DbtRuntimeError("Must specify `host` in profile") + raise DbtRuntimeError("Must specify `host` in profile") if self.schema is None: - raise dbt.exceptions.DbtRuntimeError("Must 
specify `schema` in profile") + raise DbtRuntimeError("Must specify `schema` in profile") # spark classifies database and schema as the same thing if self.database is not None and self.database != self.schema: - raise dbt.exceptions.DbtRuntimeError( + raise DbtRuntimeError( f" schema: {self.schema} \n" f" database: {self.database} \n" f"On Spark, database must be omitted or have the same value as" @@ -112,7 +117,7 @@ def __post_init__(self) -> None: try: import pyodbc # noqa: F401 except ImportError as e: - raise dbt.exceptions.DbtRuntimeError( + raise DbtRuntimeError( f"{self.method} connection method requires " "additional dependencies. \n" "Install the additional required dependencies with " @@ -121,7 +126,7 @@ def __post_init__(self) -> None: ) from e if self.method == SparkConnectionMethod.ODBC and self.cluster and self.endpoint: - raise dbt.exceptions.DbtRuntimeError( + raise DbtRuntimeError( "`cluster` and `endpoint` cannot both be set when" f" using {self.method} method to connect to Spark" ) @@ -130,7 +135,7 @@ def __post_init__(self) -> None: self.method == SparkConnectionMethod.HTTP or self.method == SparkConnectionMethod.THRIFT ) and not (ThriftState and THttpClient and hive): - raise dbt.exceptions.DbtRuntimeError( + raise DbtRuntimeError( f"{self.method} connection method requires " "additional dependencies. \n" "Install the additional required dependencies with " @@ -141,7 +146,7 @@ def __post_init__(self) -> None: try: import pyspark # noqa: F401 except ImportError as e: - raise dbt.exceptions.DbtRuntimeError( + raise DbtRuntimeError( f"{self.method} connection method requires " "additional dependencies. \n" "Install the additional required dependencies with " @@ -291,13 +296,11 @@ def execute(self, sql: str, bindings: Optional[List[Any]] = None) -> None: if poll_state.errorMessage: logger.debug("Poll response: {}".format(poll_state)) logger.debug("Poll status: {}".format(state)) - raise dbt.exceptions.DbtDatabaseError(poll_state.errorMessage) + raise DbtDatabaseError(poll_state.errorMessage) elif state not in STATE_SUCCESS: status_type = ThriftState._VALUES_TO_NAMES.get(state, "Unknown<{!r}>".format(state)) - raise dbt.exceptions.DbtDatabaseError( - "Query failed with status: {}".format(status_type) - ) + raise DbtDatabaseError("Query failed with status: {}".format(status_type)) logger.debug("Poll status: {}, query complete".format(state)) @@ -358,9 +361,9 @@ def exception_handler(self, sql: str) -> Generator[None, None, None]: thrift_resp = exc.args[0] if hasattr(thrift_resp, "status"): msg = thrift_resp.status.errorMessage - raise dbt.exceptions.DbtRuntimeError(msg) + raise DbtRuntimeError(msg) else: - raise dbt.exceptions.DbtRuntimeError(str(exc)) + raise DbtRuntimeError(str(exc)) def cancel(self, connection: Connection) -> None: connection.handle.cancel() @@ -390,7 +393,7 @@ def validate_creds(cls, creds: Any, required: Iterable[str]) -> None: for key in required: if not hasattr(creds, key): - raise dbt.exceptions.DbtProfileError( + raise DbtConfigError( "The config '{}' is required when using the {} method" " to connect to Spark".format(key, method) ) @@ -481,7 +484,7 @@ def open(cls, connection: Connection) -> Connection: endpoint=creds.endpoint ) else: - raise dbt.exceptions.DbtProfileError( + raise DbtConfigError( "Either `cluster` or `endpoint` must set when" " using the odbc method to connect to Spark" ) @@ -525,9 +528,7 @@ def open(cls, connection: Connection) -> Connection: Connection(server_side_parameters=creds.server_side_parameters) ) else: - raise 
dbt.exceptions.DbtProfileError( - f"invalid credential method: {creds.method}" - ) + raise DbtConfigError(f"invalid credential method: {creds.method}") break except Exception as e: exc = e @@ -537,7 +538,7 @@ def open(cls, connection: Connection) -> Connection: msg = "Failed to connect" if creds.token is not None: msg += ", is your token valid?" - raise dbt.exceptions.FailedToConnectError(msg) from e + raise FailedToConnectError(msg) from e retryable_message = _is_retryable_error(e) if retryable_message and creds.connect_retries > 0: msg = ( @@ -558,7 +559,7 @@ def open(cls, connection: Connection) -> Connection: logger.warning(msg) time.sleep(creds.connect_timeout) else: - raise dbt.exceptions.FailedToConnectError("failed to connect") from e + raise FailedToConnectError("failed to connect") from e else: raise exc # type: ignore diff --git a/dbt/adapters/spark/impl.py b/dbt/adapters/spark/impl.py index 16c3a3cb7..9a1a7ec06 100644 --- a/dbt/adapters/spark/impl.py +++ b/dbt/adapters/spark/impl.py @@ -1,18 +1,31 @@ +import os import re from concurrent.futures import Future from dataclasses import dataclass -from typing import Any, Dict, Iterable, List, Optional, Union, Type, Tuple, Callable, Set +from typing import ( + Any, + Dict, + Iterable, + List, + Optional, + Union, + Type, + Tuple, + Callable, + Set, + FrozenSet, +) from dbt.adapters.base.relation import InformationSchema -from dbt.contracts.graph.manifest import Manifest +from dbt.adapters.contracts.connection import AdapterResponse +from dbt.adapters.events.logging import AdapterLogger +from dbt_common.exceptions import DbtRuntimeError, CompilationError +from dbt_common.utils import AttrDict, executor from typing_extensions import TypeAlias import agate -import dbt -import dbt.exceptions - from dbt.adapters.base import AdapterConfig, PythonJobHelper from dbt.adapters.base.impl import catch_as_completed, ConstraintSupport from dbt.adapters.sql import SQLAdapter @@ -24,14 +37,16 @@ AllPurposeClusterPythonJobHelper, ) from dbt.adapters.base import BaseRelation -from dbt.clients.agate_helper import DEFAULT_TYPE_TESTER -from dbt.contracts.connection import AdapterResponse -from dbt.contracts.graph.nodes import ConstraintType -from dbt.contracts.relation import RelationType -from dbt.events import AdapterLogger -from dbt.utils import executor, AttrDict +from dbt.adapters.contracts.relation import RelationType, RelationConfig +from dbt_common.clients.agate_helper import DEFAULT_TYPE_TESTER +from dbt_common.contracts.constraints import ConstraintType logger = AdapterLogger("Spark") +packages = ["pyhive.hive", "thrift.transport", "thrift.protocol"] +log_level = os.getenv("DBT_SPARK_LOG_LEVEL", "ERROR") +for package in packages: + logger.debug(f"Setting {package} logging to {log_level}") + logger.set_adapter_dependency_log_level(package, log_level) GET_COLUMNS_IN_RELATION_RAW_MACRO_NAME = "get_columns_in_relation_raw" LIST_SCHEMAS_MACRO_NAME = "list_schemas" @@ -144,7 +159,7 @@ def _get_relation_information(self, row: agate.Row) -> RelationInfo: try: _schema, name, _, information = row except ValueError: - raise dbt.exceptions.DbtRuntimeError( + raise DbtRuntimeError( f'Invalid value from "show tables extended ...", got {len(row)} values, expected 4' ) @@ -155,7 +170,7 @@ def _get_relation_information_using_describe(self, row: agate.Row) -> RelationIn try: _schema, name, _ = row except ValueError: - raise dbt.exceptions.DbtRuntimeError( + raise DbtRuntimeError( f'Invalid value from "show tables ...", got {len(row)} values, expected 3' ) @@ 
-164,7 +179,7 @@ def _get_relation_information_using_describe(self, row: agate.Row) -> RelationIn table_results = self.execute_macro( DESCRIBE_TABLE_EXTENDED_MACRO_NAME, kwargs={"table_name": table_name} ) - except dbt.exceptions.DbtRuntimeError as e: + except DbtRuntimeError as e: logger.debug(f"Error while retrieving information about {table_name}: {e.msg}") table_results = AttrDict() @@ -219,7 +234,7 @@ def list_relations_without_caching(self, schema_relation: BaseRelation) -> List[ row_list=show_table_extended_rows, relation_info_func=self._get_relation_information, ) - except dbt.exceptions.DbtRuntimeError as e: + except DbtRuntimeError as e: errmsg = getattr(e, "msg", "") if f"Database '{schema_relation}' not found" in errmsg: return [] @@ -236,7 +251,7 @@ def list_relations_without_caching(self, schema_relation: BaseRelation) -> List[ row_list=show_table_rows, relation_info_func=self._get_relation_information_using_describe, ) - except dbt.exceptions.DbtRuntimeError as e: + except DbtRuntimeError as e: description = "Error while retrieving information about" logger.debug(f"{description} {schema_relation}: {e.msg}") return [] @@ -298,7 +313,7 @@ def get_columns_in_relation(self, relation: BaseRelation) -> List[SparkColumn]: GET_COLUMNS_IN_RELATION_RAW_MACRO_NAME, kwargs={"relation": relation} ) columns = self.parse_describe_extended(relation, rows) - except dbt.exceptions.DbtRuntimeError as e: + except DbtRuntimeError as e: # spark would throw error when table doesn't exist, where other # CDW would just return and empty list, normalizing the behavior here errmsg = getattr(e, "msg", "") @@ -352,11 +367,13 @@ def _get_columns_for_catalog(self, relation: BaseRelation) -> Iterable[Dict[str, yield as_dict def get_catalog( - self, manifest: Manifest, selected_nodes: Optional[Set] = None + self, + relation_configs: Iterable[RelationConfig], + used_schemas: FrozenSet[Tuple[str, str]], ) -> Tuple[agate.Table, List[Exception]]: - schema_map = self._get_catalog_schemas(manifest) + schema_map = self._get_catalog_schemas(relation_configs) if len(schema_map) > 1: - raise dbt.exceptions.CompilationError( + raise CompilationError( f"Expected only one database in get_catalog, found " f"{list(schema_map)}" ) @@ -371,7 +388,7 @@ def get_catalog( self._get_one_catalog, info, [schema], - manifest, + relation_configs, ) ) catalogs, exceptions = catch_as_completed(futures) @@ -381,10 +398,10 @@ def _get_one_catalog( self, information_schema: InformationSchema, schemas: Set[str], - manifest: Manifest, + used_schemas: FrozenSet[Tuple[str, str]], ) -> agate.Table: if len(schemas) != 1: - raise dbt.exceptions.CompilationError( + raise CompilationError( f"Expected only one schema in spark _get_one_catalog, found " f"{schemas}" ) diff --git a/dbt/adapters/spark/python_submissions.py b/dbt/adapters/spark/python_submissions.py index 89831ca7f..e3e7cb370 100644 --- a/dbt/adapters/spark/python_submissions.py +++ b/dbt/adapters/spark/python_submissions.py @@ -4,8 +4,9 @@ from typing import Any, Dict, Callable, Iterable import uuid -import dbt.exceptions from dbt.adapters.base import PythonJobHelper +from dbt_common.exceptions import DbtRuntimeError + from dbt.adapters.spark import SparkCredentials from dbt.adapters.spark import __version__ @@ -53,7 +54,7 @@ def _create_work_dir(self, path: str) -> None: }, ) if response.status_code != 200: - raise dbt.exceptions.DbtRuntimeError( + raise DbtRuntimeError( f"Error creating work_dir for python notebooks\n {response.content!r}" ) @@ -71,9 +72,7 @@ def 
_upload_notebook(self, path: str, compiled_code: str) -> None: }, ) if response.status_code != 200: - raise dbt.exceptions.DbtRuntimeError( - f"Error creating python notebook.\n {response.content!r}" - ) + raise DbtRuntimeError(f"Error creating python notebook.\n {response.content!r}") def _submit_job(self, path: str, cluster_spec: dict) -> str: job_spec = { @@ -99,9 +98,7 @@ def _submit_job(self, path: str, cluster_spec: dict) -> str: json=job_spec, ) if submit_response.status_code != 200: - raise dbt.exceptions.DbtRuntimeError( - f"Error creating python run.\n {submit_response.content!r}" - ) + raise DbtRuntimeError(f"Error creating python run.\n {submit_response.content!r}") return submit_response.json()["run_id"] def _submit_through_notebook(self, compiled_code: str, cluster_spec: dict) -> None: @@ -135,7 +132,7 @@ def _submit_through_notebook(self, compiled_code: str, cluster_spec: dict) -> No json_run_output = run_output.json() result_state = json_run_output["metadata"]["state"]["result_state"] if result_state != "SUCCESS": - raise dbt.exceptions.DbtRuntimeError( + raise DbtRuntimeError( "Python model failed with traceback as:\n" "(Note that the line number here does not " "match the line number in your code due to dbt templating)\n" @@ -169,9 +166,9 @@ def polling( response = status_func(**status_func_kwargs) state = get_state_func(response) if exceeded_timeout: - raise dbt.exceptions.DbtRuntimeError("python model run timed out") + raise DbtRuntimeError("python model run timed out") if state != expected_end_state: - raise dbt.exceptions.DbtRuntimeError( + raise DbtRuntimeError( "python model run ended in state" f"{state} with state_message\n{get_state_msg_func(response)}" ) @@ -205,9 +202,7 @@ def create(self) -> str: }, ) if response.status_code != 200: - raise dbt.exceptions.DbtRuntimeError( - f"Error creating an execution context.\n {response.content!r}" - ) + raise DbtRuntimeError(f"Error creating an execution context.\n {response.content!r}") return response.json()["id"] def destroy(self, context_id: str) -> str: @@ -221,9 +216,7 @@ def destroy(self, context_id: str) -> str: }, ) if response.status_code != 200: - raise dbt.exceptions.DbtRuntimeError( - f"Error deleting an execution context.\n {response.content!r}" - ) + raise DbtRuntimeError(f"Error deleting an execution context.\n {response.content!r}") return response.json()["id"] @@ -246,9 +239,7 @@ def execute(self, context_id: str, command: str) -> str: }, ) if response.status_code != 200: - raise dbt.exceptions.DbtRuntimeError( - f"Error creating a command.\n {response.content!r}" - ) + raise DbtRuntimeError(f"Error creating a command.\n {response.content!r}") return response.json()["id"] def status(self, context_id: str, command_id: str) -> Dict[str, Any]: @@ -263,9 +254,7 @@ def status(self, context_id: str, command_id: str) -> Dict[str, Any]: }, ) if response.status_code != 200: - raise dbt.exceptions.DbtRuntimeError( - f"Error getting status of command.\n {response.content!r}" - ) + raise DbtRuntimeError(f"Error getting status of command.\n {response.content!r}") return response.json() @@ -298,7 +287,7 @@ def submit(self, compiled_code: str) -> None: get_state_msg_func=lambda response: response.json()["results"]["data"], ) if response["results"]["resultType"] == "error": - raise dbt.exceptions.DbtRuntimeError( + raise DbtRuntimeError( f"Python model failed with traceback as:\n" f"{response['results']['cause']}" ) diff --git a/dbt/adapters/spark/relation.py b/dbt/adapters/spark/relation.py index e80f2623f..860935cbd 
100644 --- a/dbt/adapters/spark/relation.py +++ b/dbt/adapters/spark/relation.py @@ -2,9 +2,9 @@ from dataclasses import dataclass, field from dbt.adapters.base.relation import BaseRelation, Policy +from dbt.adapters.events.logging import AdapterLogger -from dbt.exceptions import DbtRuntimeError -from dbt.events import AdapterLogger +from dbt_common.exceptions import DbtRuntimeError logger = AdapterLogger("Spark") diff --git a/dbt/adapters/spark/session.py b/dbt/adapters/spark/session.py index b5b2bebdb..7a6982e50 100644 --- a/dbt/adapters/spark/session.py +++ b/dbt/adapters/spark/session.py @@ -7,9 +7,9 @@ from typing import Any, Dict, List, Optional, Tuple, Union, Sequence from dbt.adapters.spark.connections import SparkConnectionWrapper -from dbt.events import AdapterLogger -from dbt.utils import DECIMALS -from dbt.exceptions import DbtRuntimeError +from dbt.adapters.events.logging import AdapterLogger +from dbt_common.utils.encoding import DECIMALS +from dbt_common.exceptions import DbtRuntimeError from pyspark.sql import DataFrame, Row, SparkSession from pyspark.sql.utils import AnalysisException diff --git a/dbt/include/spark/macros/adapters.sql b/dbt/include/spark/macros/adapters.sql index bfc1f198d..bf9f63cf9 100644 --- a/dbt/include/spark/macros/adapters.sql +++ b/dbt/include/spark/macros/adapters.sql @@ -294,7 +294,7 @@ {% macro spark__list_relations_without_caching(relation) %} {% call statement('list_relations_without_caching', fetch_result=True) -%} - show table extended in {{ relation }} like '*' + show table extended in {{ relation.schema }} like '*' {% endcall %} {% do return(load_result('list_relations_without_caching').table) %} @@ -305,7 +305,7 @@ {#-- V2 iceberg tables #} {#-- https://issues.apache.org/jira/browse/SPARK-33393 #} {% call statement('list_relations_without_caching_show_tables', fetch_result=True) -%} - show tables in {{ schema_relation }} like '*' + show tables in {{ schema_relation.schema }} like '*' {% endcall %} {% do return(load_result('list_relations_without_caching_show_tables').table) %} diff --git a/dev-requirements.txt b/dev-requirements.txt index 765482e25..28a626fc3 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,7 +1,6 @@ # install latest changes in dbt-core # TODO: how to automate switching from develop to version branches? 
-git+https://github.com/dbt-labs/dbt-core.git@c2bc2f009bbeeb46b3c69d082ab4d485597898af#egg=dbt-core&subdirectory=core -git+https://github.com/dbt-labs/dbt-core.git@c2bc2f009bbeeb46b3c69d082ab4d485597898af#egg=dbt-tests-adapter&subdirectory=tests/adapter +git+https://github.com/dbt-labs/dbt-adapters.git#subdirectory=dbt-tests-adapter # if version 1.x or greater -> pin to major version # if version 0.x -> pin to minor diff --git a/setup.py b/setup.py index 301b4a41f..2d6e00e53 100644 --- a/setup.py +++ b/setup.py @@ -20,7 +20,6 @@ print('Please upgrade setuptools with "pip install --upgrade setuptools" ' "and try again") sys.exit(1) - # pull long description from README this_directory = os.path.abspath(os.path.dirname(__file__)) with open(os.path.join(this_directory, "README.md"), "r", encoding="utf8") as f: @@ -40,17 +39,8 @@ def _get_plugin_version_dict(): return match.groupdict() -# require a compatible minor version (~=), prerelease if this is a prerelease -def _get_dbt_core_version(): - parts = _get_plugin_version_dict() - minor = "{major}.{minor}.0".format(**parts) - pre = parts["prekind"] + "1" if parts["prekind"] else "" - return f"{minor}{pre}" - - package_name = "dbt-spark" package_version = "1.8.0a1" -dbt_core_version = _get_dbt_core_version() description = """The Apache Spark adapter plugin for dbt""" odbc_extras = ["pyodbc~=4.0.39"] @@ -73,8 +63,9 @@ def _get_dbt_core_version(): packages=find_namespace_packages(include=["dbt", "dbt.*"]), include_package_data=True, install_requires=[ - "dbt-core~={}".format(dbt_core_version), "sqlparams>=3.0.0", + "dbt-common<1.0", + "dbt-adapters~=0.1.0a1", ], extras_require={ "ODBC": odbc_extras, diff --git a/tests/conftest.py b/tests/conftest.py index 700ade4d3..efba41a5f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -42,9 +42,9 @@ def apache_spark_target(): "user": "dbt", "method": "thrift", "port": 10000, - "connect_retries": 3, - "connect_timeout": 5, - "retry_all": True, + "connect_retries": 2, + "connect_timeout": 3, + "retry_all": False, } @@ -59,7 +59,7 @@ def databricks_cluster_target(): "port": 443, "connect_retries": 3, "connect_timeout": 5, - "retry_all": True, + "retry_all": False, "user": os.getenv("DBT_DATABRICKS_USER"), } @@ -87,11 +87,9 @@ def databricks_http_cluster_target(): "token": os.getenv("DBT_DATABRICKS_TOKEN"), "method": "http", "port": 443, - # more retries + longer timout to handle unavailability while cluster is restarting - # return failures quickly in dev, retry all failures in CI (up to 5 min) - "connect_retries": 5, - "connect_timeout": 60, - "retry_all": bool(os.getenv("DBT_DATABRICKS_RETRY_ALL", False)), + "connect_retries": 3, + "connect_timeout": 5, + "retry_all": False, "user": os.getenv("DBT_DATABRICKS_USER"), } diff --git a/tests/functional/conftest.py b/tests/functional/conftest.py new file mode 100644 index 000000000..c1a0397bd --- /dev/null +++ b/tests/functional/conftest.py @@ -0,0 +1,19 @@ +from multiprocessing import Lock + +import pytest + +_db_start_lock = Lock() +_DB_CLUSTER_STARTED = False + + +# Running this should prevent tests from needing to be retried because the Databricks cluster isn't available +@pytest.fixture(scope="class", autouse=True) +def start_databricks_cluster(project, request): + global _DB_CLUSTER_STARTED + profile_type = request.config.getoption("--profile") + with _db_start_lock: + if "databricks" in profile_type and not _DB_CLUSTER_STARTED: + print("Starting Databricks cluster") + project.run_sql("SELECT 1") + + _DB_CLUSTER_STARTED = True diff --git 
a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index a7da63301..54e9f0158 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -1,4 +1,5 @@ import unittest +from multiprocessing import get_context from unittest import mock import dbt.flags as flags @@ -146,7 +147,7 @@ def _get_target_odbc_sql_endpoint(self, project): def test_http_connection(self): config = self._get_target_http(self.project_cfg) - adapter = SparkAdapter(config) + adapter = SparkAdapter(config, get_context("spawn")) def hive_http_connect(thrift_transport, configuration): self.assertEqual(thrift_transport.scheme, "https") @@ -171,7 +172,7 @@ def hive_http_connect(thrift_transport, configuration): def test_thrift_connection(self): config = self._get_target_thrift(self.project_cfg) - adapter = SparkAdapter(config) + adapter = SparkAdapter(config, get_context("spawn")) def hive_thrift_connect( host, port, username, auth, kerberos_service_name, password, configuration @@ -195,7 +196,7 @@ def hive_thrift_connect( def test_thrift_ssl_connection(self): config = self._get_target_use_ssl_thrift(self.project_cfg) - adapter = SparkAdapter(config) + adapter = SparkAdapter(config, get_context("spawn")) def hive_thrift_connect(thrift_transport, configuration): self.assertIsNotNone(thrift_transport) @@ -215,7 +216,7 @@ def hive_thrift_connect(thrift_transport, configuration): def test_thrift_connection_kerberos(self): config = self._get_target_thrift_kerberos(self.project_cfg) - adapter = SparkAdapter(config) + adapter = SparkAdapter(config, get_context("spawn")) def hive_thrift_connect( host, port, username, auth, kerberos_service_name, password, configuration @@ -239,7 +240,7 @@ def hive_thrift_connect( def test_odbc_cluster_connection(self): config = self._get_target_odbc_cluster(self.project_cfg) - adapter = SparkAdapter(config) + adapter = SparkAdapter(config, get_context("spawn")) def pyodbc_connect(connection_str, autocommit): self.assertTrue(autocommit) @@ -266,7 +267,7 @@ def pyodbc_connect(connection_str, autocommit): def test_odbc_endpoint_connection(self): config = self._get_target_odbc_sql_endpoint(self.project_cfg) - adapter = SparkAdapter(config) + adapter = SparkAdapter(config, get_context("spawn")) def pyodbc_connect(connection_str, autocommit): self.assertTrue(autocommit) @@ -329,7 +330,9 @@ def test_parse_relation(self): input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows] config = self._get_target_http(self.project_cfg) - rows = SparkAdapter(config).parse_describe_extended(relation, input_cols) + rows = SparkAdapter(config, get_context("spawn")).parse_describe_extended( + relation, input_cols + ) self.assertEqual(len(rows), 4) self.assertEqual( rows[0].to_column_dict(omit_none=False), @@ -418,7 +421,9 @@ def test_parse_relation_with_integer_owner(self): input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows] config = self._get_target_http(self.project_cfg) - rows = SparkAdapter(config).parse_describe_extended(relation, input_cols) + rows = SparkAdapter(config, get_context("spawn")).parse_describe_extended( + relation, input_cols + ) self.assertEqual(rows[0].to_column_dict().get("table_owner"), "1234") @@ -454,7 +459,9 @@ def test_parse_relation_with_statistics(self): input_cols = [Row(keys=["col_name", "data_type"], values=r) for r in plain_rows] config = self._get_target_http(self.project_cfg) - rows = SparkAdapter(config).parse_describe_extended(relation, input_cols) + rows = SparkAdapter(config, 
get_context("spawn")).parse_describe_extended( + relation, input_cols + ) self.assertEqual(len(rows), 1) self.assertEqual( rows[0].to_column_dict(omit_none=False), @@ -483,7 +490,7 @@ def test_parse_relation_with_statistics(self): def test_relation_with_database(self): config = self._get_target_http(self.project_cfg) - adapter = SparkAdapter(config) + adapter = SparkAdapter(config, get_context("spawn")) # fine adapter.Relation.create(schema="different", identifier="table") with self.assertRaises(DbtRuntimeError): @@ -564,7 +571,9 @@ def test_parse_columns_from_information_with_table_type_and_delta_provider(self) ) config = self._get_target_http(self.project_cfg) - columns = SparkAdapter(config).parse_columns_from_information(relation) + columns = SparkAdapter(config, get_context("spawn")).parse_columns_from_information( + relation + ) self.assertEqual(len(columns), 4) self.assertEqual( columns[0].to_column_dict(omit_none=False), @@ -649,7 +658,9 @@ def test_parse_columns_from_information_with_view_type(self): ) config = self._get_target_http(self.project_cfg) - columns = SparkAdapter(config).parse_columns_from_information(relation) + columns = SparkAdapter(config, get_context("spawn")).parse_columns_from_information( + relation + ) self.assertEqual(len(columns), 4) self.assertEqual( columns[1].to_column_dict(omit_none=False), @@ -715,7 +726,9 @@ def test_parse_columns_from_information_with_table_type_and_parquet_provider(sel ) config = self._get_target_http(self.project_cfg) - columns = SparkAdapter(config).parse_columns_from_information(relation) + columns = SparkAdapter(config, get_context("spawn")).parse_columns_from_information( + relation + ) self.assertEqual(len(columns), 4) self.assertEqual( diff --git a/tests/unit/utils.py b/tests/unit/utils.py index ac8c62244..17cd3ee78 100644 --- a/tests/unit/utils.py +++ b/tests/unit/utils.py @@ -9,7 +9,7 @@ import agate import pytest -from dbt.dataclass_schema import ValidationError +from dbt_common.dataclass_schema import ValidationError from dbt.config.project import PartialProject
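Reviewer note (sketch, not part of the patch): the import moves above follow one pattern — generic exceptions, dataclass helpers, and utilities now come from the dbt-common package, while adapter-facing contracts, events, and connection errors come from dbt-adapters under the `dbt.adapters` namespace. A minimal sketch of the new-style imports, assuming dbt-common and dbt-adapters are installed:

```python
# Post-decoupling import layout used throughout this patch.
# Generic exceptions, schema helpers, and utilities come from dbt-common:
from dbt_common.dataclass_schema import StrEnum, dbtClassMixin
from dbt_common.exceptions import DbtConfigError, DbtDatabaseError, DbtRuntimeError
from dbt_common.utils.encoding import DECIMALS

# Adapter-facing contracts, logging, and connection errors come from dbt-adapters:
from dbt.adapters.contracts.connection import AdapterResponse, Connection, Credentials
from dbt.adapters.contracts.relation import RelationConfig, RelationType
from dbt.adapters.events.logging import AdapterLogger
from dbt.adapters.exceptions import FailedToConnectError

logger = AdapterLogger("Spark")


def require_schema(schema):
    # Same error type the patch raises in SparkCredentials.__post_init__.
    if schema is None:
        raise DbtRuntimeError("Must specify `schema` in profile")
    return schema
```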
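The catalog path no longer receives dbt-core's Manifest: `get_catalog` now takes an iterable of `RelationConfig` objects plus the frozen set of (database, schema) pairs in use. A signature-only sketch of the reshaped entry points (bodies elided; the real implementations are in dbt/adapters/spark/impl.py above):

```python
from typing import FrozenSet, Iterable, List, Tuple

import agate
from dbt.adapters.contracts.relation import RelationConfig


class CatalogMethodsSketch:
    # Signature-only illustration of the reshaped catalog API in this patch.
    def get_catalog(
        self,
        relation_configs: Iterable[RelationConfig],
        used_schemas: FrozenSet[Tuple[str, str]],
    ) -> Tuple[agate.Table, List[Exception]]:
        raise NotImplementedError("see SparkAdapter.get_catalog in impl.py")

    def _get_one_catalog(
        self,
        information_schema,
        schemas,
        used_schemas: FrozenSet[Tuple[str, str]],
    ) -> agate.Table:
        raise NotImplementedError("see SparkAdapter._get_one_catalog in impl.py")
```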
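The unit tests now pass a multiprocessing context as a second constructor argument, reflecting the updated base-adapter constructor in dbt-adapters. A small sketch of the construction pattern the tests use (the `config` object here stands in for the adapter config built by dbt's test fixtures):

```python
from multiprocessing import get_context

from dbt.adapters.spark.impl import SparkAdapter


def build_spark_adapter(config):
    # The base adapter now takes an explicit multiprocessing context;
    # the tests in this patch use the "spawn" context.
    return SparkAdapter(config, get_context("spawn"))
```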
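The dagger-io pin moves from ~=0.8.0 to ~=0.9.7, where a long-running container must be converted into a service with `Container.as_service()` before it can be attached via `with_service_binding`, and cache volumes can be mounted to keep apt/pip artifacts between runs. A standalone sketch of that pattern under those assumptions (the image, alias, and paths here are illustrative, not taken from the CI script):

```python
import sys

import anyio
import dagger


async def main() -> None:
    async with dagger.Connection(dagger.Config(log_output=sys.stderr)) as client:
        # In dagger 0.9, a container is exposed to other containers as a Service.
        pg_service = (
            client.container()
            .from_("postgres:13")
            .with_env_variable("POSTGRES_PASSWORD", "postgres")
            .with_exposed_port(5432)
            .as_service()
        )

        # Cache volumes persist between runs, so pip downloads are reused locally.
        pip_cache = client.cache_volume("pip")

        out = await (
            client.container()
            .from_("python:3.8-slim")
            .with_mounted_cache("/root/.cache/pip", pip_cache)
            .with_service_binding(alias="postgres_db", service=pg_service)
            .with_exec(["python", "-c", "print('service bound as postgres_db')"])
            .stdout()
        )
        print(out)


if __name__ == "__main__":
    anyio.run(main)
```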