diff --git a/CHANGELOG.md b/CHANGELOG.md index 5476e30ec..c8b267071 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,25 +10,7 @@ ## dbt-databricks 1.9.2 (TBD) -### Under the Hood - -- Refactor global state reading ([888](https://github.com/databricks/dbt-databricks/pull/888)) -- Switch to relation.render() for string interpolation ([903](https://github.com/databricks/dbt-databricks/pull/903)) -- Ensure retry defaults for PySQL ([907](https://github.com/databricks/dbt-databricks/pull/907)) - -## dbt-databricks 1.9.1 (December 16, 2024) - -### Features - -- Merge strategy now supports the `update set ...` action with the explicit list of updates for `when not matched by source` ([866](https://github.com/databricks/dbt-databricks/pull/866)) (thanks @mi-volodin). - -### Under the Hood - -- Removed pins for pandas and pydantic to ease user burdens ([874](https://github.com/databricks/dbt-databricks/pull/874)) -- Add more relation types to make codegen happy ([875](https://github.com/databricks/dbt-databricks/pull/875)) -- add UP ruleset ([865](https://github.com/databricks/dbt-databricks/pull/865)) - -## dbt-databricks 1.9.0 (December 9, 2024) +## dbt-databricks 1.9.0 (TBD) ### Features @@ -53,7 +35,6 @@ - Replace array indexing with 'get' in split_part so as not to raise exception when indexing beyond bounds ([839](https://github.com/databricks/dbt-databricks/pull/839)) - Set queue enabled for Python notebook jobs ([856](https://github.com/databricks/dbt-databricks/pull/856)) -- Ensure columns that are added get backticked ([859](https://github.com/databricks/dbt-databricks/pull/859)) ### Under the Hood @@ -64,7 +45,11 @@ - Prepare for python typing deprecations ([837](https://github.com/databricks/dbt-databricks/pull/837)) - Fix behavior flag use in init of DatabricksAdapter (thanks @VersusFacit!) 
([836](https://github.com/databricks/dbt-databricks/pull/836)) - Restrict pydantic to V1 per dbt Labs' request ([843](https://github.com/databricks/dbt-databricks/pull/843)) +<<<<<<< HEAD - Switching to Ruff for formatting and linting ([847](https://github.com/databricks/dbt-databricks/pull/847) +======= +- Switching to Ruff for formatting and linting ([847](https://github.com/databricks/dbt-databricks/pull/847)) +>>>>>>> parent of f1065d27 (resolve conflict) - Switching to Hatch and pyproject.toml for project config ([853](https://github.com/databricks/dbt-databricks/pull/853)) ## dbt-databricks 1.8.7 (October 10, 2024) diff --git a/dbt/adapters/databricks/api_client.py b/dbt/adapters/databricks/api_client.py index 235e5f52a..ca9de081b 100644 --- a/dbt/adapters/databricks/api_client.py +++ b/dbt/adapters/databricks/api_client.py @@ -245,7 +245,7 @@ def _poll_api( @dataclass(frozen=True, eq=True, unsafe_hash=True) -class CommandExecution: +class CommandExecution(object): command_id: str context_id: str cluster_id: str @@ -459,60 +459,6 @@ def run(self, job_id: str, enable_queueing: bool = True) -> str: return response_json["run_id"] -class DltPipelineApi(PollableApi): - def __init__(self, session: Session, host: str, polling_interval: int): - super().__init__(session, host, "/api/2.0/pipelines", polling_interval, 60 * 60) - - def poll_for_completion(self, pipeline_id: str) -> None: - self._poll_api( - url=f"/{pipeline_id}", - params={}, - get_state_func=lambda response: response.json()["state"], - terminal_states={"IDLE", "FAILED", "DELETED"}, - expected_end_state="IDLE", - unexpected_end_state_func=self._get_exception, - ) - - def _get_exception(self, response: Response) -> None: - response_json = response.json() - cause = response_json.get("cause") - if cause: - raise DbtRuntimeError(f"Pipeline {response_json.get('pipeline_id')} failed: {cause}") - else: - latest_update = response_json.get("latest_updates")[0] - last_error = self.get_update_error(response_json.get("pipeline_id"), latest_update) - raise DbtRuntimeError( - f"Pipeline {response_json.get('pipeline_id')} failed: {last_error}" - ) - - def get_update_error(self, pipeline_id: str, update_id: str) -> str: - response = self.session.get(f"/{pipeline_id}/events") - if response.status_code != 200: - raise DbtRuntimeError( - f"Error getting pipeline event info for {pipeline_id}: {response.text}" - ) - - events = response.json().get("events", []) - update_events = [ - e - for e in events - if e.get("event_type", "") == "update_progress" - and e.get("origin", {}).get("update_id") == update_id - ] - - error_events = [ - e - for e in update_events - if e.get("details", {}).get("update_progress", {}).get("state", "") == "FAILED" - ] - - msg = "" - if error_events: - msg = error_events[0].get("message", "") - - return msg - - class DatabricksApiClient: def __init__( self, @@ -534,7 +480,6 @@ def __init__( self.job_runs = JobRunsApi(session, host, polling_interval, timeout) self.workflows = WorkflowJobApi(session, host) self.workflow_permissions = JobPermissionsApi(session, host) - self.dlt_pipelines = DltPipelineApi(session, host, polling_interval) @staticmethod def create( diff --git a/dbt/adapters/databricks/column.py b/dbt/adapters/databricks/column.py index 66bad7416..0cb94c188 100644 --- a/dbt/adapters/databricks/column.py +++ b/dbt/adapters/databricks/column.py @@ -1,7 +1,6 @@ from dataclasses import dataclass -from typing import Any, ClassVar, Optional +from typing import ClassVar, Optional -from dbt.adapters.databricks.utils import 
quote from dbt.adapters.spark.column import SparkColumn @@ -50,17 +49,4 @@ def render_for_create(self) -> str: return column_str def __repr__(self) -> str: - return f"<DatabricksColumn {self.name} ({self.data_type})>" - - @staticmethod - def get_name(column: dict[str, Any]) -> str: - name = column["name"] - return quote(name) if column.get("quote", False) else name - - @staticmethod - def format_remove_column_list(columns: list["DatabricksColumn"]) -> str: - return ", ".join([quote(c.name) for c in columns]) - - @staticmethod - def format_add_column_list(columns: list["DatabricksColumn"]) -> str: - return ", ".join([f"{quote(c.name)} {c.data_type}" for c in columns]) + return "<DatabricksColumn {} ({})>".format(self.name, self.data_type) diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 151921857..4a4f21023 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -1,20 +1,23 @@ import decimal +import os import re import sys import time import uuid import warnings -from collections.abc import Callable, Hashable, Iterator, Sequence +from collections.abc import Callable, Iterator, Sequence from contextlib import contextmanager from dataclasses import dataclass from multiprocessing.context import SpawnContext from numbers import Number -from typing import TYPE_CHECKING, Any, Optional, cast +from threading import get_ident +from typing import TYPE_CHECKING, Any, Hashable, Optional, cast from dbt_common.events.contextvars import get_node_info from dbt_common.events.functions import fire_event from dbt_common.exceptions import DbtDatabaseError, DbtInternalError, DbtRuntimeError from dbt_common.utils import cast_to_str +from requests import Session import databricks.sql as dbsql from databricks.sql.client import Connection as DatabricksSQLConnection @@ -61,6 +64,7 @@ CursorCreate, ) from dbt.adapters.databricks.events.other_events import QueryError +from dbt.adapters.databricks.events.pipeline_events import PipelineRefresh, PipelineRefreshError from dbt.adapters.databricks.logging import logger from dbt.adapters.databricks.python_models.run_tracking import PythonRunTracker from dbt.adapters.databricks.utils import redact_credentials @@ -88,6 +92,9 @@ DBR_VERSION_REGEX = re.compile(r"([1-9][0-9]*)\.(x|0|[1-9][0-9]*)") +# toggle for session management that minimizes the number of sessions opened/closed +USE_LONG_SESSIONS = os.getenv("DBT_DATABRICKS_LONG_SESSIONS", "True").upper() == "TRUE" + # Number of idle seconds before a connection is automatically closed. Only applicable if # USE_LONG_SESSIONS is true.
# Updated when idle times of 180s were causing errors @@ -223,6 +230,97 @@ def execute(self, sql: str, bindings: Optional[Sequence[Any]] = None) -> None: bindings = [self._fix_binding(binding) for binding in bindings] self._cursor.execute(sql, bindings) + def poll_refresh_pipeline(self, pipeline_id: str) -> None: + # interval in seconds + polling_interval = 10 + + # timeout in seconds + timeout = 60 * 60 + + stopped_states = ("COMPLETED", "FAILED", "CANCELED") + host: str = self._creds.host or "" + headers = ( + self._cursor.connection.thrift_backend._auth_provider._header_factory # type: ignore + ) + session = Session() + session.auth = BearerAuth(headers) + session.headers = {"User-Agent": self._user_agent} + pipeline = _get_pipeline_state(session, host, pipeline_id) + # get the most recently created update for the pipeline + latest_update = _find_update(pipeline) + if not latest_update: + raise DbtRuntimeError(f"No update created for pipeline: {pipeline_id}") + + state = latest_update.get("state") + # we use update_id to retrieve the update in the polling loop + update_id = latest_update.get("update_id", "") + prev_state = state + + logger.info(PipelineRefresh(pipeline_id, update_id, str(state))) + + start = time.time() + exceeded_timeout = False + while state not in stopped_states: + if time.time() - start > timeout: + exceeded_timeout = True + break + + # should we do exponential backoff? + time.sleep(polling_interval) + + pipeline = _get_pipeline_state(session, host, pipeline_id) + # get the update we are currently polling + update = _find_update(pipeline, update_id) + if not update: + raise DbtRuntimeError( + f"Error getting pipeline update info: {pipeline_id}, update: {update_id}" + ) + + state = update.get("state") + if state != prev_state: + logger.info(PipelineRefresh(pipeline_id, update_id, str(state))) + prev_state = state + + if state == "FAILED": + logger.error( + PipelineRefreshError( + pipeline_id, + update_id, + _get_update_error_msg(session, host, pipeline_id, update_id), + ) + ) + + # another update may have been created due to retry_on_fail settings + # get the latest update and see if it is a new one + latest_update = _find_update(pipeline) + if not latest_update: + raise DbtRuntimeError(f"No update created for pipeline: {pipeline_id}") + + latest_update_id = latest_update.get("update_id", "") + if latest_update_id != update_id: + update_id = latest_update_id + state = None + + if exceeded_timeout: + raise DbtRuntimeError("timed out waiting for materialized view refresh") + + if state == "FAILED": + msg = _get_update_error_msg(session, host, pipeline_id, update_id) + raise DbtRuntimeError(f"Error refreshing pipeline {pipeline_id} {msg}") + + if state == "CANCELED": + raise DbtRuntimeError(f"Refreshing pipeline {pipeline_id} cancelled") + + return + + @classmethod + def findUpdate(cls, updates: list, id: str) -> Optional[dict]: + matches = [x for x in updates if x.get("update_id") == id] + if matches: + return matches[0] + + return None + @property def hex_query_id(self) -> str: """Return the hex GUID for this query @@ -380,18 +478,12 @@ class DatabricksConnectionManager(SparkConnectionManager): credentials_manager: Optional[DatabricksCredentialManager] = None _user_agent = f"dbt-databricks/{__version__}" - def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext): - super().__init__(profile, mp_context) - creds = cast(DatabricksCredentials, self.profile.credentials) - self.api_client = DatabricksApiClient.create(creds, 15 * 60) - 
self.threads_compute_connections: dict[ - Hashable, dict[Hashable, DatabricksDBTConnection] - ] = {} - def cancel_open(self) -> list[str]: cancelled = super().cancel_open() + creds = cast(DatabricksCredentials, self.profile.credentials) + api_client = DatabricksApiClient.create(creds, 15 * 60) logger.info("Cancelling open python jobs") - PythonRunTracker.cancel_runs(self.api_client) + PythonRunTracker.cancel_runs(api_client) return cancelled def compare_dbr_version(self, major: int, minor: int) -> int: @@ -435,19 +527,39 @@ def set_connection_name( 'connection_named', called by 'connection_for(node)'. Creates a connection for this thread if one doesn't already exist, and will rename an existing connection.""" - self._cleanup_idle_connections() conn_name: str = "master" if name is None else name # Get a connection for this thread - conn = self._get_if_exists_compute_connection(_get_compute_name(query_header_context) or "") + conn = self.get_if_exists() + + if conn and conn.name == conn_name and conn.state == ConnectionState.OPEN: + # Found a connection and nothing to do, so just return it + return conn if conn is None: - conn = self._create_compute_connection(conn_name, query_header_context) + # Create a new connection + conn = DatabricksDBTConnection( + type=Identifier(self.TYPE), + name=conn_name, + state=ConnectionState.INIT, + transaction_open=False, + handle=None, + credentials=self.profile.credentials, + ) + conn.handle = LazyHandle(self.get_open_for_context(query_header_context)) + # Add the connection to thread_connections for this thread + self.set_thread_connection(conn) + fire_event( + NewConnection(conn_name=conn_name, conn_type=self.TYPE, node_info=get_node_info()) + ) else: # existing connection either wasn't open or didn't have the right name - conn = self._update_compute_connection(conn, conn_name) - - conn._acquire(query_header_context) + if conn.state != ConnectionState.OPEN: + conn.handle = LazyHandle(self.get_open_for_context(query_header_context)) + if conn.name != conn_name: + orig_conn_name: str = conn.name or "" + conn.name = conn_name + fire_event(ConnectionReused(orig_conn_name=orig_conn_name, conn_name=conn_name)) return conn @@ -457,8 +569,6 @@ def add_query( auto_begin: bool = True, bindings: Optional[Any] = None, abridge_sql_log: bool = False, - retryable_exceptions: tuple[type[Exception], ...] = tuple(), - retry_limit: int = 1, *, close_cursor: bool = False, ) -> tuple[Connection, Any]: @@ -472,7 +582,7 @@ def add_query( try: log_sql = redact_credentials(sql) if abridge_sql_log: - log_sql = f"{log_sql[:512]}..." + log_sql = "{}...".format(log_sql[:512]) fire_event( SQLQuery( @@ -731,6 +841,16 @@ def release(self) -> None: conn._release() + # override + @classmethod + def close(cls, connection: Connection) -> Connection: + try: + return super().close(connection) + except Exception as e: + logger.warning(f"ignoring error when closing connection: {e}") + connection.state = ConnectionState.CLOSED + return connection + # override def cleanup_all(self) -> None: with self.lock: @@ -750,99 +870,146 @@ def cleanup_all(self) -> None: self.thread_connections.clear() self.threads_compute_connections.clear() - @classmethod - def get_open_for_context( - cls, query_header_context: Any = None - ) -> Callable[[Connection], Connection]: - # If there is no node we can simply return the exsting class method open. - # If there is a node create a closure that will call cls._open with the node. 
- if not query_header_context: - return cls.open + def _update_compute_connection( + self, conn: DatabricksDBTConnection, new_name: str + ) -> DatabricksDBTConnection: + if conn.name == new_name and conn.state == ConnectionState.OPEN: + # Found a connection and nothing to do, so just return it + return conn - def open_for_model(connection: Connection) -> Connection: - return cls._open(connection, query_header_context) + orig_conn_name: str = conn.name or "" - return open_for_model + if conn.state != ConnectionState.OPEN: + conn.handle = LazyHandle(self.open) + if conn.name != new_name: + conn.name = new_name + fire_event(ConnectionReused(orig_conn_name=orig_conn_name, conn_name=new_name)) - @classmethod - def open(cls, connection: Connection) -> Connection: - databricks_connection = cast(DatabricksDBTConnection, connection) + current_thread_conn = cast(Optional[DatabricksDBTConnection], self.get_if_exists()) + if current_thread_conn and current_thread_conn.compute_name != conn.compute_name: + self.clear_thread_connection() + self.set_thread_connection(conn) - if connection.state == ConnectionState.OPEN: - return connection + logger.debug(ConnectionReuse(str(conn), orig_conn_name)) - creds: DatabricksCredentials = connection.credentials - timeout = creds.connect_timeout + return conn - # gotta keep this so we don't prompt users many times - cls.credentials_manager = creds.authenticate() + def _add_compute_connection(self, conn: DatabricksDBTConnection) -> None: + """Add a new connection to the map of connection per thread per compute.""" - invocation_env = creds.get_invocation_env() - user_agent_entry = cls._user_agent - if invocation_env: - user_agent_entry = f"{cls._user_agent}; {invocation_env}" + with self.lock: + thread_map = self._get_compute_connections() + if conn.compute_name in thread_map: + raise DbtInternalError( + f"In set_thread_compute_connection, connection exists for `{conn.compute_name}`" + ) + thread_map[conn.compute_name] = conn - connection_parameters = creds.connection_parameters.copy() # type: ignore[union-attr] + def _get_compute_connections( + self, + ) -> dict[Hashable, DatabricksDBTConnection]: + """Retrieve a map of compute name to connection for the current thread.""" - http_headers: list[tuple[str, str]] = list( - creds.get_all_http_headers(connection_parameters.pop("http_headers", {})).items() - ) + thread_id = self.get_thread_identifier() + with self.lock: + thread_map = self.threads_compute_connections.get(thread_id) + if not thread_map: + thread_map = {} + self.threads_compute_connections[thread_id] = thread_map + return thread_map - # If a model specifies a compute resource the http path - # may be different than the http_path property of creds. - http_path = databricks_connection.http_path + def _get_if_exists_compute_connection( + self, compute_name: str + ) -> Optional[DatabricksDBTConnection]: + """Get the connection for the current thread and named compute, if it exists.""" - def connect() -> DatabricksSQLConnectionWrapper: - assert cls.credentials_manager is not None - try: - # TODO: what is the error when a user specifies a catalog they don't have access to - conn = dbsql.connect( - server_hostname=creds.host, - http_path=http_path, - credentials_provider=cls.credentials_manager.credentials_provider, - http_headers=http_headers if http_headers else None, - session_configuration=creds.session_properties, - catalog=creds.database, - use_inline_params="silent", - # schema=creds.schema, # TODO: Explicitly set once DBR 7.3LTS is EOL. 
- _user_agent_entry=user_agent_entry, - **connection_parameters, - ) + with self.lock: + threads_map = self._get_compute_connections() + return threads_map.get(compute_name) - if conn: - databricks_connection.session_id = conn.get_session_id_hex() - databricks_connection.last_used_time = time.time() - logger.debug(ConnectionCreated(str(databricks_connection))) + def _cleanup_idle_connections(self) -> None: + with self.lock: + # Get all connections associated with this thread. There can be multiple connections + # if different models use different compute resources + thread_conns = self._get_compute_connections() + for conn in thread_conns.values(): + logger.debug(ConnectionIdleCheck(str(conn))) - return DatabricksSQLConnectionWrapper( - conn, - is_cluster=creds.cluster_id is not None, - creds=creds, - user_agent=user_agent_entry, - ) - except Error as exc: - logger.error(ConnectionCreateError(exc)) - raise + # Generally speaking we only want to close/refresh the connection if the + # acquire_release_count is zero, i.e. the connection is not currently in use. + # However python models acquire a connection then run the python model, which + # doesn't actually use the connection. If the python model takes long enough to + # run, the connection can be idle long enough to time out on the back end. + # If additional sql needs to be run after the python model, but before the + # connection is released, the connection needs to be refreshed or there will + # be a failure. Making an exception when language is 'python' allows the + # call to _cleanup_idle_connections from get_thread_connection to refresh the + # connection in this scenario. + if ( + conn.acquire_release_count == 0 or conn.language == "python" + ) and conn._idle_too_long(): + logger.debug(ConnectionIdleClose(str(conn))) + self.close(conn) + conn._reset_handle(self._open) - def exponential_backoff(attempt: int) -> int: - return attempt * attempt + def _create_compute_connection( + self, conn_name: str, query_header_context: Any = None + ) -> DatabricksDBTConnection: + """Create a new connection for the combination of current thread and compute associated + with the given node.""" - retryable_exceptions = [] - # this option is for backwards compatibility - if creds.retry_all: - retryable_exceptions = [Error] + # Create a new connection + compute_name = _get_compute_name(query_header_context) or "" - return cls.retry_connection( - connection, - connect=connect, - logger=logger, - retryable_exceptions=retryable_exceptions, - retry_limit=creds.connect_retries, - retry_timeout=(timeout if timeout is not None else exponential_backoff), + conn = DatabricksDBTConnection( + type=Identifier(self.TYPE), + name=conn_name, + state=ConnectionState.INIT, + transaction_open=False, + handle=None, + credentials=self.profile.credentials, ) + conn.compute_name = compute_name + creds = cast(DatabricksCredentials, self.profile.credentials) + conn.http_path = _get_http_path(query_header_context, creds=creds) or "" + conn.thread_identifier = cast(tuple[int, int], self.get_thread_identifier()) + conn.max_idle_time = _get_max_idle_time(query_header_context, creds=creds) + + conn.handle = LazyHandle(self.open) + + logger.debug(ConnectionCreate(str(conn))) + + # Add this connection to the thread/compute connection pool. + self._add_compute_connection(conn) + # Remove the connection currently in use by this thread from the thread connection pool. + self.clear_thread_connection() + # Add the connection to thread connection pool.
+ self.set_thread_connection(conn) + + fire_event( + NewConnection(conn_name=conn_name, conn_type=self.TYPE, node_info=get_node_info()) + ) + + return conn + + def get_thread_connection(self) -> Connection: + conn = super().get_thread_connection() + self._cleanup_idle_connections() + dbr_conn = cast(DatabricksDBTConnection, conn) + logger.debug(ConnectionRetrieve(str(dbr_conn))) + + return conn @classmethod - def _open(cls, connection: Connection, query_header_context: Any = None) -> Connection: + def open(cls, connection: Connection) -> Connection: + # Once long session management is no longer under the USE_LONG_SESSIONS toggle + # this should be renamed and replace the _open class method. + assert ( + USE_LONG_SESSIONS + ), "This path, '_open2', should only be reachable with USE_LONG_SESSIONS" + + databricks_connection = cast(DatabricksDBTConnection, connection) + if connection.state == ConnectionState.OPEN: return connection @@ -850,7 +1017,7 @@ def _open(cls, connection: Connection, query_header_context: Any = None) -> Conn timeout = creds.connect_timeout # gotta keep this so we don't prompt users many times - cls.credentials_provider = creds.authenticate(cls.credentials_provider) + cls.credentials_manager = creds.authenticate() invocation_env = creds.get_invocation_env() user_agent_entry = cls._user_agent @@ -865,15 +1032,16 @@ def _open(cls, connection: Connection, query_header_context: Any = None) -> Conn # If a model specifies a compute resource the http path # may be different than the http_path property of creds. - http_path = _get_http_path(query_header_context, creds) + http_path = databricks_connection.http_path def connect() -> DatabricksSQLConnectionWrapper: + assert cls.credentials_manager is not None try: # TODO: what is the error when a user specifies a catalog they don't have access to - conn: DatabricksSQLConnection = dbsql.connect( + conn = dbsql.connect( server_hostname=creds.host, http_path=http_path, - credentials_provider=cls.credentials_provider, + credentials_provider=cls.credentials_manager.credentials_provider, http_headers=http_headers if http_headers else None, session_configuration=creds.session_properties, catalog=creds.database, @@ -882,7 +1050,11 @@ def connect() -> DatabricksSQLConnectionWrapper: _user_agent_entry=user_agent_entry, **connection_parameters, ) - logger.debug(ConnectionCreated(str(conn))) + + if conn: + databricks_connection.session_id = conn.get_session_id_hex() + databricks_connection.last_used_time = time.time() + logger.debug(ConnectionCreated(str(databricks_connection))) return DatabricksSQLConnectionWrapper( conn, @@ -911,156 +1083,59 @@ def exponential_backoff(attempt: int) -> int: retry_timeout=(timeout if timeout is not None else exponential_backoff), ) - # override - @classmethod - def close(cls, connection: Connection) -> Connection: - try: - return super().close(connection) - except Exception as e: - logger.warning(f"ignoring error when closing connection: {e}") - connection.state = ConnectionState.CLOSED - return connection - @classmethod - def get_response(cls, cursor: DatabricksSQLCursorWrapper) -> DatabricksAdapterResponse: - _query_id = getattr(cursor, "hex_query_id", None) - if cursor is None: - logger.debug("No cursor was provided. 
Query ID not available.") - query_id = "N/A" - else: - query_id = _query_id - message = "OK" - return DatabricksAdapterResponse(_message=message, query_id=query_id) # type: ignore - - def get_thread_connection(self) -> Connection: - conn = super().get_thread_connection() - self._cleanup_idle_connections() - dbr_conn = cast(DatabricksDBTConnection, conn) - logger.debug(ConnectionRetrieve(str(dbr_conn))) - - return conn - - def _add_compute_connection(self, conn: DatabricksDBTConnection) -> None: - """Add a new connection to the map of connection per thread per compute.""" +def _get_pipeline_state(session: Session, host: str, pipeline_id: str) -> dict: + pipeline_url = f"https://{host}/api/2.0/pipelines/{pipeline_id}" - with self.lock: - thread_map = self._get_compute_connections() - if conn.compute_name in thread_map: - raise DbtInternalError( - f"In set_thread_compute_connection, connection exists for `{conn.compute_name}`" - ) - thread_map[conn.compute_name] = conn - - def _cleanup_idle_connections(self) -> None: - with self.lock: - # Get all connections associated with this thread. There can be multiple connections - # if different models use different compute resources - thread_conns = self._get_compute_connections() - for conn in thread_conns.values(): - logger.debug(ConnectionIdleCheck(str(conn))) + response = session.get(pipeline_url) + if response.status_code != 200: + raise DbtRuntimeError(f"Error getting pipeline info for {pipeline_id}: {response.text}") - # Generally speaking we only want to close/refresh the connection if the - # acquire_release_count is zero. i.e. the connection is not currently in use. - # However python models acquire a connection then run the pyton model, which - # doesn't actually use the connection. If the python model takes lone enought to - # run the connection can be idle long enough to timeout on the back end. - # If additional sql needs to be run after the python model, but before the - # connection is released, the connection needs to be refreshed or there will - # be a failure. Making an exception when language is 'python' allows the - # the call to _cleanup_idle_connections from get_thread_connection to refresh the - # connection in this scenario. 
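The pipeline helpers being added here, together with poll_refresh_pipeline above, boil down to a poll-until-terminal-state loop with a fixed interval and an overall timeout. A condensed, self-contained sketch of that loop, with the HTTP call abstracted behind a callable (the interval, timeout, and terminal states mirror the values used above):

import time
from typing import Callable

def poll_until_stopped(
    get_state: Callable[[], str],
    stopped_states: tuple[str, ...] = ("COMPLETED", "FAILED", "CANCELED"),
    polling_interval: float = 10,
    timeout: float = 60 * 60,
) -> str:
    # Re-read the state until a terminal state is reached or the timeout expires.
    start = time.time()
    state = get_state()
    while state not in stopped_states:
        if time.time() - start > timeout:
            raise TimeoutError("timed out waiting for materialized view refresh")
        time.sleep(polling_interval)
        state = get_state()
    return state

The real loop adds two wrinkles on top of this: it logs each state transition, and when an update reports FAILED it re-reads latest_updates to see whether retry_on_fail started a replacement update before giving up.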
- if ( - conn.acquire_release_count == 0 or conn.language == "python" - ) and conn._idle_too_long(): - logger.debug(ConnectionIdleClose(str(conn))) - self.close(conn) - conn._reset_handle(self._open) + return response.json() - def _create_compute_connection( - self, conn_name: str, query_header_context: Any = None - ) -> DatabricksDBTConnection: - """Create anew connection for the combination of current thread and compute associated - with the given node.""" - # Create a new connection - compute_name = _get_compute_name(query_header_context) or "" +def _find_update(pipeline: dict, id: str = "") -> Optional[dict]: + updates = pipeline.get("latest_updates", []) + if not updates: + raise DbtRuntimeError(f"No updates for pipeline: {pipeline.get('pipeline_id', '')}") - conn = DatabricksDBTConnection( - type=Identifier(self.TYPE), - name=conn_name, - state=ConnectionState.INIT, - transaction_open=False, - handle=None, - credentials=self.profile.credentials, - ) - conn.compute_name = compute_name - creds = cast(DatabricksCredentials, self.profile.credentials) - conn.http_path = _get_http_path(query_header_context, creds=creds) or "" - conn.thread_identifier = cast(tuple[int, int], self.get_thread_identifier()) - conn.max_idle_time = _get_max_idle_time(query_header_context, creds=creds) + if not id: + return updates[0] - conn.handle = LazyHandle(self.open) + matches = [x for x in updates if x.get("update_id") == id] + if matches: + return matches[0] - logger.debug(ConnectionCreate(str(conn))) + return None - # Add this connection to the thread/compute connection pool. - self._add_compute_connection(conn) - # Remove the connection currently in use by this thread from the thread connection pool. - self.clear_thread_connection() - # Add the connection to thread connection pool. 
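A quick usage sketch of the _find_update helper defined above, against a made-up payload (the field names follow the pipelines API shape the code expects; the values are purely illustrative):

pipeline = {
    "pipeline_id": "example-pipeline",
    "latest_updates": [
        {"update_id": "u2", "state": "RUNNING"},  # most recently created update comes first
        {"update_id": "u1", "state": "COMPLETED"},
    ],
}

assert _find_update(pipeline)["update_id"] == "u2"  # no id given: newest update is returned
assert _find_update(pipeline, "u1")["state"] == "COMPLETED"  # explicit id: exact match
assert _find_update(pipeline, "missing") is None  # unknown id: nothing found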
- self.set_thread_connection(conn) - fire_event( - NewConnection(conn_name=conn_name, conn_type=self.TYPE, node_info=get_node_info()) +def _get_update_error_msg(session: Session, host: str, pipeline_id: str, update_id: str) -> str: + events_url = f"https://{host}/api/2.0/pipelines/{pipeline_id}/events" + response = session.get(events_url) + if response.status_code != 200: + raise DbtRuntimeError( + f"Error getting pipeline event info for {pipeline_id}: {response.text}" ) - return conn + events = response.json().get("events", []) + update_events = [ + e + for e in events + if e.get("event_type", "") == "update_progress" + and e.get("origin", {}).get("update_id") == update_id + ] - def _get_if_exists_compute_connection( - self, compute_name: str - ) -> Optional[DatabricksDBTConnection]: - """Get the connection for the current thread and named compute, if it exists.""" + error_events = [ + e + for e in update_events + if e.get("details", {}).get("update_progress", {}).get("state", "") == "FAILED" + ] - with self.lock: - threads_map = self._get_compute_connections() - return threads_map.get(compute_name) + msg = "" + if error_events: + msg = error_events[0].get("message", "") - def _get_compute_connections( - self, - ) -> dict[Hashable, DatabricksDBTConnection]: - """Retrieve a map of compute name to connection for the current thread.""" - - thread_id = self.get_thread_identifier() - with self.lock: - thread_map = self.threads_compute_connections.get(thread_id) - if not thread_map: - thread_map = {} - self.threads_compute_connections[thread_id] = thread_map - return thread_map - - def _update_compute_connection( - self, conn: DatabricksDBTConnection, new_name: str - ) -> DatabricksDBTConnection: - if conn.name == new_name and conn.state == ConnectionState.OPEN: - # Found a connection and nothing to do, so just return it - return conn - - orig_conn_name: str = conn.name or "" - - if conn.state != ConnectionState.OPEN: - conn.handle = LazyHandle(self.open) - if conn.name != new_name: - conn.name = new_name - fire_event(ConnectionReused(orig_conn_name=orig_conn_name, conn_name=new_name)) - - current_thread_conn = cast(Optional[DatabricksDBTConnection], self.get_if_exists()) - if current_thread_conn and current_thread_conn.compute_name != conn.compute_name: - self.clear_thread_connection() - self.set_thread_connection(conn) - - logger.debug(ConnectionReuse(str(conn), orig_conn_name)) - - return conn + return msg def _get_compute_name(query_header_context: Any) -> Optional[str]: @@ -1080,18 +1155,24 @@ def _get_http_path(query_header_context: Any, creds: DatabricksCredentials) -> O """Get the http_path for the compute specified for the node. If none is specified default will be used.""" + thread_id = (os.getpid(), get_ident()) + # ResultNode *should* have relation_name attr, but we work around a core # issue by checking. relation_name = getattr(query_header_context, "relation_name", "[unknown]") # If there is no node we return the http_path for the default compute. if not query_header_context: + if not USE_LONG_SESSIONS: + logger.debug(f"Thread {thread_id}: using default compute resource.") return creds.http_path # Get the name of the compute resource specified in the node's config. # If none is specified return the http_path for the default compute. 
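The branching in _get_http_path reduces to a small lookup: use the profile's default http_path unless the node names a compute resource, and fail if that named compute does not define its own http_path. A self-contained sketch of that resolution, where the 'databricks_compute' config key and the compute_map argument are illustrative stand-ins for the node config and profile structures the adapter actually reads:

from typing import Any, Optional

def resolve_http_path(
    model_config: dict[str, Any],
    default_http_path: str,
    compute_map: dict[str, dict[str, str]],
) -> str:
    # No per-model compute override: fall back to the default warehouse/cluster path.
    compute_name: Optional[str] = model_config.get("databricks_compute")
    if not compute_name:
        return default_http_path

    # A named compute must carry its own http_path, mirroring the runtime error raised below.
    http_path = compute_map.get(compute_name, {}).get("http_path")
    if not http_path:
        raise ValueError(f"Compute resource '{compute_name}' does not specify http_path")
    return http_path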
compute_name = _get_compute_name(query_header_context) if not compute_name: + if not USE_LONG_SESSIONS: + logger.debug(f"On thread {thread_id}: {relation_name} using default compute resource.") return creds.http_path # Get the http_path for the named compute. @@ -1106,6 +1187,11 @@ def _get_http_path(query_header_context: Any, creds: DatabricksCredentials) -> O f"does not specify http_path, relation: {relation_name}" ) + if not USE_LONG_SESSIONS: + logger.debug( + f"On thread {thread_id}: {relation_name} using compute resource '{compute_name}'." + ) + return http_path diff --git a/dbt/adapters/databricks/credentials.py b/dbt/adapters/databricks/credentials.py index bc5054392..e30ab2f0c 100644 --- a/dbt/adapters/databricks/credentials.py +++ b/dbt/adapters/databricks/credentials.py @@ -15,16 +15,10 @@ from databricks.sdk import WorkspaceClient from databricks.sdk.core import Config, CredentialsProvider from dbt.adapters.contracts.connection import Credentials -from dbt.adapters.databricks.auth import m2m_auth, token_auth -from dbt.adapters.databricks.events.credential_events import ( - CredentialLoadError, - CredentialSaveError, - CredentialShardEvent, -) -from dbt.adapters.databricks.global_state import GlobalState from dbt.adapters.databricks.logging import logger CATALOG_KEY_IN_SESSION_PROPERTIES = "databricks.catalog" +DBT_DATABRICKS_INVOCATION_ENV = "DBT_DATABRICKS_INVOCATION_ENV" DBT_DATABRICKS_INVOCATION_ENV_REGEX = re.compile("^[A-z0-9\\-]+$") EXTRACT_CLUSTER_ID_FROM_HTTP_PATH_REGEX = re.compile(r"/?sql/protocolv1/o/\d+/(.*)") DBT_DATABRICKS_HTTP_SESSION_HEADERS = "DBT_DATABRICKS_HTTP_SESSION_HEADERS" @@ -75,10 +69,8 @@ class DatabricksCredentials(Credentials): @classmethod def __pre_deserialize__(cls, data: dict[Any, Any]) -> dict[Any, Any]: data = super().__pre_deserialize__(data) - data.setdefault("database", None) - data.setdefault("connection_parameters", {}) - data["connection_parameters"].setdefault("_retry_stop_after_attempts_count", 30) - data["connection_parameters"].setdefault("_retry_delay_max", 60) + if "database" not in data: + data["database"] = None return data def __post_init__(self) -> None: @@ -140,16 +132,21 @@ def __post_init__(self) -> None: def validate_creds(self) -> None: for key in ["host", "http_path"]: if not getattr(self, key): - raise DbtConfigError(f"The config '{key}' is required to connect to Databricks") + raise DbtConfigError( + "The config '{}' is required to connect to Databricks".format(key) + ) + if not self.token and self.auth_type != "oauth": raise DbtConfigError( - "The config `auth_type: oauth` is required when not using access token" + ("The config `auth_type: oauth` is required when not using access token") ) if not self.client_id and self.client_secret: raise DbtConfigError( - "The config 'client_id' is required to connect " - "to Databricks when 'client_secret' is present" + ( + "The config 'client_id' is required to connect " + "to Databricks when 'client_secret' is present" + ) ) if (not self.azure_client_id and self.azure_client_secret) or ( @@ -164,7 +161,7 @@ def validate_creds(self) -> None: @classmethod def get_invocation_env(cls) -> Optional[str]: - invocation_env = GlobalState.get_invocation_env() + invocation_env = os.environ.get(DBT_DATABRICKS_INVOCATION_ENV) if invocation_env: # Thrift doesn't allow nested () so we need to ensure # that the passed user agent is valid. 
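get_invocation_env above is a plain read-and-validate of an environment variable whose value is appended to the driver user agent (Thrift rejects nested parentheses, hence the narrow character set). A minimal sketch of that flow under the same regex; the base string and the ValueError are illustrative stand-ins for the adapter's _user_agent and its own validation error:

import os
import re
from typing import Optional

INVOCATION_ENV_REGEX = re.compile("^[A-z0-9\\-]+$")  # same pattern as DBT_DATABRICKS_INVOCATION_ENV_REGEX

def build_user_agent(base: str = "dbt-databricks/<version>") -> str:
    invocation_env: Optional[str] = os.environ.get("DBT_DATABRICKS_INVOCATION_ENV")
    if not invocation_env:
        return base
    if not INVOCATION_ENV_REGEX.match(invocation_env):
        raise ValueError(f"Invalid invocation environment: {invocation_env}")
    return f"{base}; {invocation_env}"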
@@ -174,7 +171,9 @@ def get_invocation_env(cls) -> Optional[str]: @classmethod def get_all_http_headers(cls, user_http_session_headers: dict[str, str]) -> dict[str, str]: - http_session_headers_str = GlobalState.get_http_session_headers() + http_session_headers_str: Optional[str] = os.environ.get( + DBT_DATABRICKS_HTTP_SESSION_HEADERS + ) http_session_headers_dict: dict[str, str] = ( { diff --git a/dbt/adapters/databricks/global_state.py b/dbt/adapters/databricks/global_state.py deleted file mode 100644 index cdc5df986..000000000 --- a/dbt/adapters/databricks/global_state.py +++ /dev/null @@ -1,48 +0,0 @@ -import os -from typing import ClassVar, Optional - - -class GlobalState: - """Global state is a bad idea, but since we don't control instantiation, better to have it in a - single place than scattered throughout the codebase. - """ - - __invocation_env: ClassVar[Optional[str]] = None - __invocation_env_set: ClassVar[bool] = False - - @classmethod - def get_invocation_env(cls) -> Optional[str]: - if not cls.__invocation_env_set: - cls.__invocation_env = os.getenv("DBT_DATABRICKS_INVOCATION_ENV") - cls.__invocation_env_set = True - return cls.__invocation_env - - __session_headers: ClassVar[Optional[str]] = None - __session_headers_set: ClassVar[bool] = False - - @classmethod - def get_http_session_headers(cls) -> Optional[str]: - if not cls.__session_headers_set: - cls.__session_headers = os.getenv("DBT_DATABRICKS_HTTP_SESSION_HEADERS") - cls.__session_headers_set = True - return cls.__session_headers - - __describe_char_bypass: ClassVar[Optional[bool]] = None - - @classmethod - def get_char_limit_bypass(cls) -> bool: - if cls.__describe_char_bypass is None: - cls.__describe_char_bypass = ( - os.getenv("DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS", "False").upper() == "TRUE" - ) - return cls.__describe_char_bypass - - __connector_log_level: ClassVar[Optional[str]] = None - - @classmethod - def get_connector_log_level(cls) -> str: - if cls.__connector_log_level is None: - cls.__connector_log_level = os.getenv( - "DBT_DATABRICKS_CONNECTOR_LOG_LEVEL", "WARN" - ).upper() - return cls.__connector_log_level diff --git a/dbt/adapters/databricks/impl.py b/dbt/adapters/databricks/impl.py index 262973d83..d8c2a9162 100644 --- a/dbt/adapters/databricks/impl.py +++ b/dbt/adapters/databricks/impl.py @@ -32,8 +32,13 @@ GetColumnsByInformationSchema, ) from dbt.adapters.databricks.column import DatabricksColumn -from dbt.adapters.databricks.connections import DatabricksConnectionManager -from dbt.adapters.databricks.global_state import GlobalState +from dbt.adapters.databricks.connections import ( + USE_LONG_SESSIONS, + DatabricksConnectionManager, + DatabricksDBTConnection, + DatabricksSQLConnectionWrapper, + ExtendedSessionConnectionManager, +) from dbt.adapters.databricks.python_models.python_submissions import ( AllPurposeClusterPythonJobHelper, JobClusterPythonJobHelper, @@ -149,7 +154,7 @@ def get_identifier_list_string(table_names: set[str]) -> str: """ _identifier = "|".join(table_names) - bypass_2048_char_limit = GlobalState.get_char_limit_bypass() + bypass_2048_char_limit = os.environ.get("DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS", "false") if bypass_2048_char_limit == "true": _identifier = _identifier if len(_identifier) < 2048 else "*" return _identifier @@ -161,7 +166,10 @@ class DatabricksAdapter(SparkAdapter): Relation = DatabricksRelation Column = DatabricksColumn - ConnectionManager = DatabricksConnectionManager + if USE_LONG_SESSIONS: + ConnectionManager: type[DatabricksConnectionManager] = 
ExtendedSessionConnectionManager + else: + ConnectionManager = DatabricksConnectionManager connections: DatabricksConnectionManager @@ -545,7 +553,7 @@ def _get_catalog_for_relation_map( used_schemas: frozenset[tuple[str, str]], ) -> tuple["Table", list[Exception]]: with executor(self.config) as tpe: - futures: list[Future[Table]] = [] + futures: list[Future["Table"]] = [] for schema, relations in relation_map.items(): if schema in used_schemas: identifier = get_identifier_list_string(relations) @@ -835,14 +843,20 @@ def get_from_relation( ) -> DatabricksRelationConfig: """Get the relation config from the relation.""" - relation_config = super().get_from_relation(adapter, relation) + relation_config = super(DeltaLiveTableAPIBase, cls).get_from_relation(adapter, relation) + connection = cast(DatabricksDBTConnection, adapter.connections.get_thread_connection()) + wrapper: DatabricksSQLConnectionWrapper = connection.handle # Ensure any current refreshes are completed before returning the relation config tblproperties = cast(TblPropertiesConfig, relation_config.config["tblproperties"]) if tblproperties.pipeline_id: - adapter.connections.api_client.dlt_pipelines.poll_for_completion( - tblproperties.pipeline_id - ) + # TODO fix this path so that it doesn't need a cursor + # It just calls APIs to poll the pipeline status + cursor = wrapper.cursor() + try: + cursor.poll_refresh_pipeline(tblproperties.pipeline_id) + finally: + cursor.close() return relation_config diff --git a/dbt/adapters/databricks/logging.py b/dbt/adapters/databricks/logging.py index 81e7449e1..d0f1d42ba 100644 --- a/dbt/adapters/databricks/logging.py +++ b/dbt/adapters/databricks/logging.py @@ -1,7 +1,7 @@ +import os from logging import Handler, LogRecord, getLogger from typing import Union -from dbt.adapters.databricks.global_state import GlobalState from dbt.adapters.events.logging import AdapterLogger logger = AdapterLogger("Databricks") @@ -22,7 +22,7 @@ def emit(self, record: LogRecord) -> None: dbt_adapter_logger = AdapterLogger("databricks-sql-connector") pysql_logger = getLogger("databricks.sql") -pysql_logger_level = GlobalState.get_connector_log_level() +pysql_logger_level = os.environ.get("DBT_DATABRICKS_CONNECTOR_LOG_LEVEL", "WARN").upper() pysql_logger.setLevel(pysql_logger_level) pysql_handler = DbtCoreHandler(dbt_logger=dbt_adapter_logger, level=pysql_logger_level) diff --git a/dbt/adapters/databricks/python_models/run_tracking.py b/dbt/adapters/databricks/python_models/run_tracking.py index 4b4fea419..e48dae7d4 100644 --- a/dbt/adapters/databricks/python_models/run_tracking.py +++ b/dbt/adapters/databricks/python_models/run_tracking.py @@ -6,7 +6,7 @@ from dbt.adapters.databricks.logging import logger -class PythonRunTracker: +class PythonRunTracker(object): _run_ids: set[str] = set() _commands: set[CommandExecution] = set() _lock = threading.Lock() diff --git a/dbt/adapters/databricks/relation.py b/dbt/adapters/databricks/relation.py index 4afffa2d4..efa091a4a 100644 --- a/dbt/adapters/databricks/relation.py +++ b/dbt/adapters/databricks/relation.py @@ -1,6 +1,6 @@ from collections.abc import Iterable from dataclasses import dataclass, field -from typing import Any, Optional, Type # noqa +from typing import Any, Optional, Type from dbt_common.contracts.constraints import ConstraintType from dbt_common.dataclass_schema import StrEnum @@ -41,8 +41,6 @@ class DatabricksRelationType(StrEnum): Foreign = "foreign" StreamingTable = "streaming_table" External = "external" - ManagedShallowClone = "managed_shallow_clone" 
- ExternalShallowClone = "external_shallow_clone" Unknown = "unknown" @@ -136,7 +134,7 @@ def matches( return match @classproperty - def get_relation_type(cls) -> Type[DatabricksRelationType]: # noqa + def get_relation_type(cls) -> Type[DatabricksRelationType]: return DatabricksRelationType def information_schema(self, view_name: Optional[str] = None) -> InformationSchema: diff --git a/dbt/adapters/databricks/relation_configs/tblproperties.py b/dbt/adapters/databricks/relation_configs/tblproperties.py index d708c7f83..41d18f3fd 100644 --- a/dbt/adapters/databricks/relation_configs/tblproperties.py +++ b/dbt/adapters/databricks/relation_configs/tblproperties.py @@ -38,7 +38,6 @@ class TblPropertiesConfig(DatabricksComponentConfig): "delta.feature.rowTracking", "delta.rowTracking.materializedRowCommitVersionColumnName", "delta.rowTracking.materializedRowIdColumnName", - "spark.internal.pipelines.top_level_entry.user_specified_name", ] def __eq__(self, __value: Any) -> bool: diff --git a/dbt/adapters/databricks/utils.py b/dbt/adapters/databricks/utils.py index 3dfd4096f..616e5b368 100644 --- a/dbt/adapters/databricks/utils.py +++ b/dbt/adapters/databricks/utils.py @@ -73,7 +73,3 @@ def handle_missing_objects(exec: Callable[[], T], default: T) -> T: if check_not_found_error(errmsg): return default raise e - - -def quote(name: str) -> str: - return f"`{name}`" diff --git a/dbt/include/databricks/macros/adapters/columns.sql b/dbt/include/databricks/macros/adapters/columns.sql index d9b041ccd..e1fc1d116 100644 --- a/dbt/include/databricks/macros/adapters/columns.sql +++ b/dbt/include/databricks/macros/adapters/columns.sql @@ -25,20 +25,3 @@ {% do return(load_result('get_columns_comments_via_information_schema').table) %} {% endmacro %} - -{% macro databricks__alter_relation_add_remove_columns(relation, add_columns, remove_columns) %} - {% if remove_columns %} - {% if not relation.is_delta %} - {{ exceptions.raise_compiler_error('Delta format required for dropping columns from tables') }} - {% endif %} - {%- call statement('alter_relation_remove_columns') -%} - ALTER TABLE {{ relation.render() }} DROP COLUMNS ({{ api.Column.format_remove_column_list(remove_columns) }}) - {%- endcall -%} - {% endif %} - - {% if add_columns %} - {%- call statement('alter_relation_add_columns') -%} - ALTER TABLE {{ relation.render() }} ADD COLUMNS ({{ api.Column.format_add_column_list(add_columns) }}) - {%- endcall -%} - {% endif %} -{% endmacro %} \ No newline at end of file diff --git a/dbt/include/databricks/macros/adapters/persist_docs.sql b/dbt/include/databricks/macros/adapters/persist_docs.sql index a8ad48bab..8e959a9f9 100644 --- a/dbt/include/databricks/macros/adapters/persist_docs.sql +++ b/dbt/include/databricks/macros/adapters/persist_docs.sql @@ -1,10 +1,12 @@ {% macro databricks__alter_column_comment(relation, column_dict) %} {% if config.get('file_format', default='delta') in ['delta', 'hudi'] %} - {% for column in column_dict.values() %} - {% set comment = column['description'] %} + {% for column_name in column_dict %} + {% set comment = column_dict[column_name]['description'] %} {% set escaped_comment = comment | replace('\'', '\\\'') %} {% set comment_query %} - alter table {{ relation.render()|lower }} change column {{ api.Column.get_name(column) }} comment '{{ escaped_comment }}'; + alter table {{ relation }} change column + {{ adapter.quote(column_name) if column_dict[column_name]['quote'] else column_name }} + comment '{{ escaped_comment }}'; {% endset %} {% do run_query(comment_query) %} {% 
endfor %} @@ -13,7 +15,7 @@ {% macro alter_table_comment(relation, model) %} {% set comment_query %} - comment on table {{ relation.render()|lower }} is '{{ model.description | replace("'", "\\'") }}' + comment on table {{ relation|lower }} is '{{ model.description | replace("'", "\\'") }}' {% endset %} {% do run_query(comment_query) %} {% endmacro %} @@ -28,3 +30,18 @@ {% do alter_column_comment(relation, columns_to_persist_docs) %} {% endif %} {% endmacro %} + +{% macro get_column_comment_sql(column_name, column_dict) -%} + {% if column_name in column_dict and column_dict[column_name]["description"] -%} + {% set escaped_description = column_dict[column_name]["description"] | replace("'", "\\'") %} + {% set column_comment_clause = "comment '" ~ escaped_description ~ "'" %} + {%- endif -%} + {{ adapter.quote(column_name) }} {{ column_comment_clause }} +{% endmacro %} + +{% macro get_persist_docs_column_list(model_columns, query_columns) %} + {% for column_name in query_columns %} + {{ get_column_comment_sql(column_name, model_columns) }} + {{- ", " if not loop.last else "" }} + {% endfor %} +{% endmacro %} diff --git a/dbt/include/databricks/macros/materializations/incremental/strategies.sql b/dbt/include/databricks/macros/materializations/incremental/strategies.sql index 311c2f405..57db5496a 100644 --- a/dbt/include/databricks/macros/materializations/incremental/strategies.sql +++ b/dbt/include/databricks/macros/materializations/incremental/strategies.sql @@ -91,15 +91,8 @@ select {{source_cols_csv}} from {{ source_relation }} {%- set not_matched_by_source_action = config.get('not_matched_by_source_action') -%} {%- set not_matched_by_source_condition = config.get('not_matched_by_source_condition') -%} - - {%- set not_matched_by_source_action_trimmed = not_matched_by_source_action | lower | trim(' \n\t') %} - {%- set not_matched_by_source_action_is_set = ( - not_matched_by_source_action_trimmed == 'delete' - or not_matched_by_source_action_trimmed.startswith('update') - ) - %} - + {% if unique_key %} {% if unique_key is sequence and unique_key is not mapping and unique_key is not string %} {% for key in unique_key %} @@ -144,12 +137,12 @@ select {{source_cols_csv}} from {{ source_relation }} then insert {{ get_merge_insert(on_schema_change, source_columns, source_alias) }} {%- endif %} - {%- if not_matched_by_source_action_is_set %} + {%- if not_matched_by_source_action == 'delete' %} when not matched by source {%- if not_matched_by_source_condition %} and ({{ not_matched_by_source_condition }}) {%- endif %} - then {{ not_matched_by_source_action }} + then delete {%- endif %} {% endmacro %} diff --git a/dbt/include/databricks/macros/materializations/seeds/helpers.sql b/dbt/include/databricks/macros/materializations/seeds/helpers.sql index d1ddc997e..d9bffc748 100644 --- a/dbt/include/databricks/macros/materializations/seeds/helpers.sql +++ b/dbt/include/databricks/macros/materializations/seeds/helpers.sql @@ -6,7 +6,7 @@ {% set batch_size = get_batch_size() %} {% set column_override = model['config'].get('column_types', {}) %} - {% set must_cast = model['config'].get('file_format', 'delta') == 'parquet' %} + {% set must_cast = model['config'].get("file_format", "delta") == "parquet" %} {% set statements = [] %} diff --git a/dbt/include/databricks/macros/materializations/snapshot.sql b/dbt/include/databricks/macros/materializations/snapshot.sql index 3a513a24d..3d1236a15 100644 --- a/dbt/include/databricks/macros/materializations/snapshot.sql +++ 
b/dbt/include/databricks/macros/materializations/snapshot.sql @@ -1,4 +1,27 @@ +{% macro databricks_build_snapshot_staging_table(strategy, sql, target_relation) %} + {% set tmp_identifier = target_relation.identifier ~ '__dbt_tmp' %} + + {%- set tmp_relation = api.Relation.create(identifier=tmp_identifier, + schema=target_relation.schema, + database=target_relation.database, + type='view') -%} + + {% set select = snapshot_staging_table(strategy, sql, target_relation) %} + + {# needs to be a non-temp view so that its columns can be ascertained via `describe` #} + {% call statement('build_snapshot_staging_relation') %} + create or replace view {{ tmp_relation }} + as + {{ select }} + {% endcall %} + + {% do return(tmp_relation) %} +{% endmacro %} + + {% materialization snapshot, adapter='databricks' %} + {%- set config = model['config'] -%} + {%- set target_table = model.get('alias', model.get('name')) -%} {%- set strategy_name = config.get('strategy') -%} @@ -39,43 +62,47 @@ {{ run_hooks(pre_hooks, inside_transaction=True) }} {% set strategy_macro = strategy_dispatch(strategy_name) %} - {% set strategy = strategy_macro(model, "snapshotted_data", "source_data", model['config'], target_relation_exists) %} + {% set strategy = strategy_macro(model, "snapshotted_data", "source_data", config, target_relation_exists) %} {% if not target_relation_exists %} {% set build_sql = build_snapshot_table(strategy, model['compiled_code']) %} - {% set build_or_select_sql = build_sql %} {% set final_sql = create_table_as(False, target_relation, build_sql) %} - {% else %} + {% call statement('main') %} + {{ final_sql }} + {% endcall %} + + {% do persist_docs(target_relation, model, for_relation=False) %} - {% set columns = config.get("snapshot_table_column_names") or get_snapshot_table_column_names() %} + {% else %} - {{ adapter.assert_valid_snapshot_target_given_strategy(target_relation, columns, strategy) }} + {{ adapter.valid_snapshot_target(target_relation) }} - {% set build_or_select_sql = snapshot_staging_table(strategy, sql, target_relation) %} - {% set staging_table = build_snapshot_staging_table(strategy, sql, target_relation) %} + {% if target_relation.database is none %} + {% set staging_table = spark_build_snapshot_staging_table(strategy, sql, target_relation) %} + {% else %} + {% set staging_table = databricks_build_snapshot_staging_table(strategy, sql, target_relation) %} + {% endif %} -- this may no-op if the database does not require column expansion {% do adapter.expand_target_column_types(from_relation=staging_table, to_relation=target_relation) %} - {% set remove_columns = ['dbt_change_type', 'DBT_CHANGE_TYPE', 'dbt_unique_key', 'DBT_UNIQUE_KEY'] %} - {% if unique_key | is_list %} - {% for key in strategy.unique_key %} - {{ remove_columns.append('dbt_unique_key_' + loop.index|string) }} - {{ remove_columns.append('DBT_UNIQUE_KEY_' + loop.index|string) }} - {% endfor %} - {% endif %} - {% set missing_columns = adapter.get_missing_columns(staging_table, target_relation) - | rejectattr('name', 'in', remove_columns) + | rejectattr('name', 'equalto', 'dbt_change_type') + | rejectattr('name', 'equalto', 'DBT_CHANGE_TYPE') + | rejectattr('name', 'equalto', 'dbt_unique_key') + | rejectattr('name', 'equalto', 'DBT_UNIQUE_KEY') | list %} {% do create_columns(target_relation, missing_columns) %} {% set source_columns = adapter.get_columns_in_relation(staging_table) - | rejectattr('name', 'in', remove_columns) + | rejectattr('name', 'equalto', 'dbt_change_type') + | rejectattr('name', 'equalto', 
'DBT_CHANGE_TYPE') + | rejectattr('name', 'equalto', 'dbt_unique_key') + | rejectattr('name', 'equalto', 'DBT_UNIQUE_KEY') | list %} {% set quoted_source_columns = [] %} @@ -90,34 +117,23 @@ ) %} - {% endif %} - + {% call statement_with_staging_table('main', staging_table) %} + {{ final_sql }} + {% endcall %} - {{ check_time_data_types(build_or_select_sql) }} + {% do persist_docs(target_relation, model, for_relation=True) %} - {% call statement('main') %} - {{ final_sql }} - {% endcall %} - - {% set should_revoke = should_revoke(target_relation_exists, full_refresh_mode=False) %} - {% do apply_grants(target_relation, grant_config, should_revoke=should_revoke) %} + {% endif %} - {% do persist_docs(target_relation, model) %} + {% set should_revoke = should_revoke(target_relation_exists, full_refresh_mode) %} + {% do apply_grants(target_relation, grant_config, should_revoke) %} - {% if not target_relation_exists %} - {% do create_indexes(target_relation) %} - {% endif %} + {% do persist_constraints(target_relation, model) %} {{ run_hooks(post_hooks, inside_transaction=True) }} {{ adapter.commit() }} - {% if staging_table is defined %} - {% do post_snapshot(staging_table) %} - {% endif %} - - {% do persist_constraints(target_relation, model) %} - {{ run_hooks(post_hooks, inside_transaction=False) }} {{ return({'relations': [target_relation]}) }} diff --git a/dbt/include/databricks/macros/relations/constraints.sql b/dbt/include/databricks/macros/relations/constraints.sql index 6d999823a..bb77145ff 100644 --- a/dbt/include/databricks/macros/relations/constraints.sql +++ b/dbt/include/databricks/macros/relations/constraints.sql @@ -106,19 +106,19 @@ {% macro get_constraint_sql(relation, constraint, model, column={}) %} {% set statements = [] %} - {% set type = constraint.get('type', '') %} + {% set type = constraint.get("type", "") %} {% if type == 'check' %} - {% set expression = constraint.get('expression', '') %} + {% set expression = constraint.get("expression", "") %} {% if not expression %} {{ exceptions.raise_compiler_error('Invalid check constraint expression') }} {% endif %} - {% set name = constraint.get('name') %} + {% set name = constraint.get("name") %} {% if not name %} {% if local_md5 %} {{ exceptions.warn("Constraint of type " ~ type ~ " with no `name` provided. 
Generating hash instead for relation " ~ relation.identifier) }} - {%- set name = local_md5 (relation.identifier ~ ";" ~ column.get('name', '') ~ ";" ~ expression ~ ";") -%} + {%- set name = local_md5 (relation.identifier ~ ";" ~ column.get("name", "") ~ ";" ~ expression ~ ";") -%} {% else %} {{ exceptions.raise_compiler_error("Constraint of type " ~ type ~ " with no `name` provided, and no md5 utility.") }} {% endif %} @@ -126,15 +126,15 @@ {% set stmt = "alter table " ~ relation ~ " add constraint " ~ name ~ " check (" ~ expression ~ ");" %} {% do statements.append(stmt) %} {% elif type == 'not_null' %} - {% set column_names = constraint.get('columns', []) %} + {% set column_names = constraint.get("columns", []) %} {% if column and not column_names %} {% set column_names = [column['name']] %} {% endif %} {% for column_name in column_names %} {% set column = model.get('columns', {}).get(column_name) %} {% if column %} - {% set quoted_name = api.Column.get_name(column) %} - {% set stmt = "alter table " ~ relation.render() ~ " change column " ~ quoted_name ~ " set not null " ~ (constraint.expression or "") ~ ";" %} + {% set quoted_name = adapter.quote(column['name']) if column['quote'] else column['name'] %} + {% set stmt = "alter table " ~ relation ~ " change column " ~ quoted_name ~ " set not null " ~ (constraint.expression or "") ~ ";" %} {% do statements.append(stmt) %} {% else %} {{ exceptions.warn('not_null constraint on invalid column: ' ~ column_name) }} @@ -144,7 +144,7 @@ {% if constraint.get('warn_unenforced') %} {{ exceptions.warn("unenforced constraint type: " ~ type)}} {% endif %} - {% set column_names = constraint.get('columns', []) %} + {% set column_names = constraint.get("columns", []) %} {% if column and not column_names %} {% set column_names = [column['name']] %} {% endif %} @@ -154,14 +154,14 @@ {% if not column %} {{ exceptions.warn('Invalid primary key column: ' ~ column_name) }} {% else %} - {% set quoted_name = api.Column.get_name(column) %} + {% set quoted_name = adapter.quote(column['name']) if column['quote'] else column['name'] %} {% do quoted_names.append(quoted_name) %} {% endif %} {% endfor %} {% set joined_names = quoted_names|join(", ") %} - {% set name = constraint.get('name') %} + {% set name = constraint.get("name") %} {% if not name %} {% if local_md5 %} {{ exceptions.warn("Constraint of type " ~ type ~ " with no `name` provided. 
Generating hash instead for relation " ~ relation.identifier) }} @@ -170,7 +170,7 @@ {{ exceptions.raise_compiler_error("Constraint of type " ~ type ~ " with no `name` provided, and no md5 utility.") }} {% endif %} {% endif %} - {% set stmt = "alter table " ~ relation.render() ~ " add constraint " ~ name ~ " primary key(" ~ joined_names ~ ");" %} + {% set stmt = "alter table " ~ relation ~ " add constraint " ~ name ~ " primary key(" ~ joined_names ~ ");" %} {% do statements.append(stmt) %} {% elif type == 'foreign_key' %} @@ -178,7 +178,7 @@ {{ exceptions.warn("unenforced constraint type: " ~ constraint.type)}} {% endif %} - {% set name = constraint.get('name') %} + {% set name = constraint.get("name") %} {% if constraint.get('expression') %} @@ -191,9 +191,9 @@ {% endif %} {% endif %} - {% set stmt = "alter table " ~ relation.render() ~ " add constraint " ~ name ~ " foreign key" ~ constraint.get('expression') %} + {% set stmt = "alter table " ~ relation ~ " add constraint " ~ name ~ " foreign key" ~ constraint.get('expression') %} {% else %} - {% set column_names = constraint.get('columns', []) %} + {% set column_names = constraint.get("columns", []) %} {% if column and not column_names %} {% set column_names = [column['name']] %} {% endif %} @@ -203,14 +203,14 @@ {% if not column %} {{ exceptions.warn('Invalid foreign key column: ' ~ column_name) }} {% else %} - {% set quoted_name = api.Column.get_name(column) %} + {% set quoted_name = adapter.quote(column['name']) if column['quote'] else column['name'] %} {% do quoted_names.append(quoted_name) %} {% endif %} {% endfor %} {% set joined_names = quoted_names|join(", ") %} - {% set parent = constraint.get('to') %} + {% set parent = constraint.get("to") %} {% if not parent %} {{ exceptions.raise_compiler_error('No parent table defined for foreign key: ' ~ expression) }} {% endif %} @@ -227,8 +227,8 @@ {% endif %} {% endif %} - {% set stmt = "alter table " ~ relation.render() ~ " add constraint " ~ name ~ " foreign key(" ~ joined_names ~ ") references " ~ parent %} - {% set parent_columns = constraint.get('to_columns') %} + {% set stmt = "alter table " ~ relation ~ " add constraint " ~ name ~ " foreign key(" ~ joined_names ~ ") references " ~ parent %} + {% set parent_columns = constraint.get("to_columns") %} {% if parent_columns %} {% set stmt = stmt ~ "(" ~ parent_columns|join(", ") ~ ")"%} {% endif %} @@ -236,13 +236,13 @@ {% set stmt = stmt ~ ";" %} {% do statements.append(stmt) %} {% elif type == 'custom' %} - {% set expression = constraint.get('expression', '') %} + {% set expression = constraint.get("expression", "") %} {% if not expression %} {{ exceptions.raise_compiler_error('Missing custom constraint expression') }} {% endif %} - {% set name = constraint.get('name') %} - {% set expression = constraint.get('expression') %} + {% set name = constraint.get("name") %} + {% set expression = constraint.get("expression") %} {% if not name %} {% if local_md5 %} {{ exceptions.warn("Constraint of type " ~ type ~ " with no `name` provided. 
Generating hash instead for relation " ~ relation.identifier) }} @@ -251,7 +251,7 @@ {{ exceptions.raise_compiler_error("Constraint of type " ~ type ~ " with no `name` provided, and no md5 utility.") }} {% endif %} {% endif %} - {% set stmt = "alter table " ~ relation.render() ~ " add constraint " ~ name ~ " " ~ expression ~ ";" %} + {% set stmt = "alter table " ~ relation ~ " add constraint " ~ name ~ " " ~ expression ~ ";" %} {% do statements.append(stmt) %} {% elif constraint.get('warn_unsupported') %} {{ exceptions.warn("unsupported constraint type: " ~ constraint.type)}} @@ -264,15 +264,15 @@ {# convert constraints defined using the original databricks format #} {% set dbt_constraints = [] %} {% for constraint in constraints %} - {% if constraint.get and constraint.get('type') %} + {% if constraint.get and constraint.get("type") %} {# already in model contract format #} {% do dbt_constraints.append(constraint) %} {% else %} {% if column %} {% if constraint == "not_null" %} - {% do dbt_constraints.append({"type": "not_null", "columns": [column.get('name')]}) %} + {% do dbt_constraints.append({"type": "not_null", "columns": [column.get("name")]}) %} {% else %} - {{ exceptions.raise_compiler_error('Invalid constraint for column ' ~ column.get('name', "") ~ '. Only `not_null` is supported.') }} + {{ exceptions.raise_compiler_error('Invalid constraint for column ' ~ column.get("name", "") ~ '. Only `not_null` is supported.') }} {% endif %} {% else %} {% set name = constraint['name'] %} diff --git a/dbt/include/databricks/macros/relations/liquid_clustering.sql b/dbt/include/databricks/macros/relations/liquid_clustering.sql index b30269fd9..3cf810488 100644 --- a/dbt/include/databricks/macros/relations/liquid_clustering.sql +++ b/dbt/include/databricks/macros/relations/liquid_clustering.sql @@ -15,7 +15,7 @@ {%- set cols = [cols] -%} {%- endif -%} {%- call statement('set_cluster_by_columns') -%} - ALTER {{ target_relation.type }} {{ target_relation.render() }} CLUSTER BY ({{ cols | join(', ') }}) + ALTER {{ target_relation.type }} {{ target_relation }} CLUSTER BY ({{ cols | join(', ') }}) {%- endcall -%} {%- endif %} {%- endmacro -%} \ No newline at end of file diff --git a/dbt/include/databricks/macros/relations/materialized_view/alter.sql b/dbt/include/databricks/macros/relations/materialized_view/alter.sql index d406508d2..41d9bed06 100644 --- a/dbt/include/databricks/macros/relations/materialized_view/alter.sql +++ b/dbt/include/databricks/macros/relations/materialized_view/alter.sql @@ -46,6 +46,6 @@ {% macro get_alter_mv_internal(relation, configuration_changes) %} {%- set refresh = configuration_changes.changes["refresh"] -%} -- Currently only schedule can be altered - ALTER MATERIALIZED VIEW {{ relation.render() }} + ALTER MATERIALIZED VIEW {{ relation }} {{ get_alter_sql_refresh_schedule(refresh.cron, refresh.time_zone_value, refresh.is_altered) -}} {% endmacro %} diff --git a/dbt/include/databricks/macros/relations/materialized_view/drop.sql b/dbt/include/databricks/macros/relations/materialized_view/drop.sql index 4def47441..f3774119d 100644 --- a/dbt/include/databricks/macros/relations/materialized_view/drop.sql +++ b/dbt/include/databricks/macros/relations/materialized_view/drop.sql @@ -1,3 +1,3 @@ {% macro databricks__drop_materialized_view(relation) -%} - drop materialized view if exists {{ relation.render() }} + drop materialized view if exists {{ relation }} {%- endmacro %} diff --git a/dbt/include/databricks/macros/relations/materialized_view/refresh.sql 
b/dbt/include/databricks/macros/relations/materialized_view/refresh.sql index 9967eb21f..10a8346be 100644 --- a/dbt/include/databricks/macros/relations/materialized_view/refresh.sql +++ b/dbt/include/databricks/macros/relations/materialized_view/refresh.sql @@ -1,3 +1,3 @@ {% macro databricks__refresh_materialized_view(relation) -%} - refresh materialized view {{ relation.render() }} + refresh materialized view {{ relation }} {% endmacro %} diff --git a/dbt/include/databricks/macros/relations/streaming_table/drop.sql b/dbt/include/databricks/macros/relations/streaming_table/drop.sql index 1cfc246a8..c8e0cd839 100644 --- a/dbt/include/databricks/macros/relations/streaming_table/drop.sql +++ b/dbt/include/databricks/macros/relations/streaming_table/drop.sql @@ -3,5 +3,5 @@ {%- endmacro %} {% macro default__drop_streaming_table(relation) -%} - drop table if exists {{ relation.render() }} + drop table if exists {{ relation }} {%- endmacro %} diff --git a/dbt/include/databricks/macros/relations/streaming_table/refresh.sql b/dbt/include/databricks/macros/relations/streaming_table/refresh.sql index 94c96d5cc..66b86f1f4 100644 --- a/dbt/include/databricks/macros/relations/streaming_table/refresh.sql +++ b/dbt/include/databricks/macros/relations/streaming_table/refresh.sql @@ -3,7 +3,7 @@ {%- endmacro %} {% macro databricks__refresh_streaming_table(relation, sql) -%} - create or refresh streaming table {{ relation.render() }} + create or refresh streaming table {{ relation }} as {{ sql }} {% endmacro %} diff --git a/dbt/include/databricks/macros/relations/table/create.sql b/dbt/include/databricks/macros/relations/table/create.sql index b2aba2fec..9e74d57d6 100644 --- a/dbt/include/databricks/macros/relations/table/create.sql +++ b/dbt/include/databricks/macros/relations/table/create.sql @@ -5,9 +5,9 @@ {%- else -%} {%- set file_format = config.get('file_format', default='delta') -%} {% if file_format == 'delta' %} - create or replace table {{ relation.render() }} + create or replace table {{ relation }} {% else %} - create table {{ relation.render() }} + create table {{ relation }} {% endif %} {%- set contract_config = config.get('contract') -%} {% if contract_config and contract_config.enforced %} diff --git a/dbt/include/databricks/macros/relations/table/drop.sql b/dbt/include/databricks/macros/relations/table/drop.sql index 7bce7cf46..3a7d0ced0 100644 --- a/dbt/include/databricks/macros/relations/table/drop.sql +++ b/dbt/include/databricks/macros/relations/table/drop.sql @@ -1,3 +1,3 @@ {% macro databricks__drop_table(relation) -%} - drop table if exists {{ relation.render() }} + drop table if exists {{ relation }} {%- endmacro %} diff --git a/dbt/include/databricks/macros/relations/tags.sql b/dbt/include/databricks/macros/relations/tags.sql index fb39c3785..3467631df 100644 --- a/dbt/include/databricks/macros/relations/tags.sql +++ b/dbt/include/databricks/macros/relations/tags.sql @@ -33,7 +33,7 @@ {%- endmacro -%} {% macro alter_set_tags(relation, tags) -%} - ALTER {{ relation.type }} {{ relation.render() }} SET TAGS ( + ALTER {{ relation.type }} {{ relation }} SET TAGS ( {% for tag in tags -%} '{{ tag }}' = '{{ tags[tag] }}' {%- if not loop.last %}, {% endif -%} {%- endfor %} @@ -41,7 +41,7 @@ {%- endmacro -%} {% macro alter_unset_tags(relation, tags) -%} - ALTER {{ relation.type }} {{ relation.render() }} UNSET TAGS ( + ALTER {{ relation.type }} {{ relation }} UNSET TAGS ( {% for tag in tags -%} '{{ tag }}' {%- if not loop.last %}, {%- endif %} {%- endfor %} diff --git 
a/dbt/include/databricks/macros/relations/tblproperties.sql b/dbt/include/databricks/macros/relations/tblproperties.sql index b11fd7b5c..34b6488f7 100644 --- a/dbt/include/databricks/macros/relations/tblproperties.sql +++ b/dbt/include/databricks/macros/relations/tblproperties.sql @@ -17,7 +17,7 @@ {% set tblproperty_statment = databricks__tblproperties_clause(tblproperties) %} {% if tblproperty_statment %} {%- call statement('apply_tblproperties') -%} - ALTER {{ relation.type }} {{ relation.render() }} SET {{ tblproperty_statment}} + ALTER {{ relation.type }} {{ relation }} SET {{ tblproperty_statment}} {%- endcall -%} {% endif %} {%- endmacro -%} diff --git a/dbt/include/databricks/macros/relations/view/create.sql b/dbt/include/databricks/macros/relations/view/create.sql index 5399b4ef5..096e12de4 100644 --- a/dbt/include/databricks/macros/relations/view/create.sql +++ b/dbt/include/databricks/macros/relations/view/create.sql @@ -1,5 +1,5 @@ {% macro databricks__create_view_as(relation, sql) -%} - create or replace view {{ relation.render() }} + create or replace view {{ relation }} {% if config.persist_column_docs() -%} {% set model_columns = model.columns %} {% set query_columns = get_columns_in_query(sql) %} diff --git a/dbt/include/databricks/macros/relations/view/drop.sql b/dbt/include/databricks/macros/relations/view/drop.sql index 9098c925f..aa199d760 100644 --- a/dbt/include/databricks/macros/relations/view/drop.sql +++ b/dbt/include/databricks/macros/relations/view/drop.sql @@ -1,3 +1,3 @@ {% macro databricks__drop_view(relation) -%} - drop view if exists {{ relation.render() }} + drop view if exists {{ relation }} {%- endmacro %} diff --git a/docs/databricks-merge.md b/docs/databricks-merge.md index 4034b8f6a..caa003367 100644 --- a/docs/databricks-merge.md +++ b/docs/databricks-merge.md @@ -18,11 +18,7 @@ From v.1.9 onwards `merge` behavior can be tuned by setting the additional param - `skip_matched_step`: if set to `true`, dbt will completely skip the `matched` clause of the merge statement. - `skip_not_matched_step`: similarly if `true` the `not matched` clause will be skipped. - - `not_matched_by_source_action`: can be set to an action for the case the record does not exist in a source dataset. - - if set to `delete` the corresponding `when not matched by source ... then delete` clause will be added to the merge statement. - - if the action starts with `update` then the format `update set ` is assumed, which will run update statement syntactically as provided. - Can be multiline formatted. - - in other cases by default no action is taken and now error raised. + - `not_matched_by_source_action`: if set to `delete` the corresponding `when not matched by source ... then delete` clause will be added to the merge statement. - `merge_with_schema_evolution`: when set to `true` dbt generates the merge statement with `WITH SCHEMA EVOLUTION` clause. - Step conditions that are expressed with an explicit SQL predicates allow to execute corresponding action only in case the conditions are met in addition to matching by the `unique_key`. 
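
As a quick illustration of the step-skipping flags and the `delete` action described in the bullets above, here is a minimal sketch of an incremental model config; it is not taken from the docs example that follows, and the model body and column names (`id`, `attr1`, `tech_change_ts`, `my_source_model`) are hypothetical placeholders:

```sql
{{ config(
    materialized = 'incremental',
    unique_key = 'id',
    incremental_strategy = 'merge',
    target_alias = 't',
    source_alias = 's',
    skip_matched_step = true,  -- no `when matched` clause is generated at all
    not_matched_condition = 's.attr1 IS NOT NULL',
    not_matched_by_source_condition = 't.tech_change_ts < current_timestamp()',
    not_matched_by_source_action = 'delete'
) }}

-- illustrative source query; replace with the real upstream model
select id, attr1, tech_change_ts
from {{ ref('my_source_model') }}
```

With `skip_matched_step = true` the generated merge statement omits the `when matched` branch entirely, leaving only the conditional `when not matched ... then insert` and `when not matched by source ... then delete` clauses.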
@@ -44,11 +40,7 @@ Example below illustrates how these parameters affect the merge statement genera matched_condition='t.tech_change_ts < s.tech_change_ts', not_matched_condition='s.attr1 IS NOT NULL', not_matched_by_source_condition='t.tech_change_ts < current_timestamp()', - not_matched_by_source_action=''' - update set - t.attr1 = 'deleted', - t.tech_change_ts = current_timestamp() - ''', + not_matched_by_source_action='delete', merge_with_schema_evolution=true ) }} @@ -101,7 +93,5 @@ when not matched when not matched by source and t.tech_change_ts < current_timestamp() - then update set - t.attr1 = 'deleted', - t.tech_change_ts = current_timestamp() + then delete ``` diff --git a/pyproject.toml b/pyproject.toml index 52a8e7a0d..c5f1f97c7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,7 +29,8 @@ dependencies = [ "dbt-core>=1.9.0rc2, <2.0", "dbt-spark>=1.9.0b1, <2.0", "keyring>=23.13.0", - "pydantic>=1.10.0", + "pandas<2.2.0", + "pydantic>=1.10.0, <2", ] [project.urls] @@ -75,10 +76,8 @@ dependencies = [ "freezegun", "mypy", "pre-commit", - "ruff", "types-requests", "debugpy", - "pydantic>=1.10.0, <2", ] path = ".hatch" python = "3.9" @@ -102,7 +101,7 @@ line-length = 100 target-version = 'py39' [tool.ruff.lint] -select = ["E", "W", "F", "I", "UP"] +select = ["E", "W", "F", "I"] ignore = ["E203"] [tool.pytest.ini_options] diff --git a/tests/functional/adapter/concurrency/test_concurrency.py b/tests/functional/adapter/concurrency/test_concurrency.py index 8feeb7a7b..b1b8aded5 100644 --- a/tests/functional/adapter/concurrency/test_concurrency.py +++ b/tests/functional/adapter/concurrency/test_concurrency.py @@ -27,4 +27,4 @@ def test_concurrency(self, project): util.check_table_does_not_exist(project.adapter, "invalid") util.check_table_does_not_exist(project.adapter, "skip") - assert "PASS=5 WARN=0 ERROR=1 SKIP=1" in output + assert "PASS=5 WARN=0 ERROR=1 SKIP=1 TOTAL=7" in output diff --git a/tests/functional/adapter/ephemeral/test_ephemeral.py b/tests/functional/adapter/ephemeral/test_ephemeral.py index 52efdc0fc..c00585b81 100644 --- a/tests/functional/adapter/ephemeral/test_ephemeral.py +++ b/tests/functional/adapter/ephemeral/test_ephemeral.py @@ -33,7 +33,7 @@ def test_ephemeral_nested(self, project): results = util.run_dbt(["run"]) assert len(results) == 2 assert os.path.exists("./target/run/test/models/root_view.sql") - with open("./target/run/test/models/root_view.sql") as fp: + with open("./target/run/test/models/root_view.sql", "r") as fp: sql_file = fp.read() sql_file = re.sub(r"\d+", "", sql_file) diff --git a/tests/functional/adapter/hooks/test_model_hooks.py b/tests/functional/adapter/hooks/test_model_hooks.py index bf1a4e6c6..9a3ba61ed 100644 --- a/tests/functional/adapter/hooks/test_model_hooks.py +++ b/tests/functional/adapter/hooks/test_model_hooks.py @@ -49,7 +49,7 @@ def get_ctx_vars(self, state, count, project): "invocation_id", "thread_id", ] - field_list = ", ".join([f"{f}" for f in fields]) + field_list = ", ".join(["{}".format(f) for f in fields]) query = ( f"select {field_list} from {project.test_schema}.on_model_hook" f" where test_state = '{state}'" diff --git a/tests/functional/adapter/hooks/test_run_hooks.py b/tests/functional/adapter/hooks/test_run_hooks.py index 1f133d86a..5c7dd5c2d 100644 --- a/tests/functional/adapter/hooks/test_run_hooks.py +++ b/tests/functional/adapter/hooks/test_run_hooks.py @@ -65,7 +65,7 @@ def get_ctx_vars(self, state, project): "invocation_id", "thread_id", ] - field_list = ", ".join([f"{f}" for f in fields]) + 
field_list = ", ".join(["{}".format(f) for f in fields]) query = ( f"select {field_list} from {project.test_schema}.on_run_hook where test_state = " f"'{state}'" diff --git a/tests/functional/adapter/incremental/fixtures.py b/tests/functional/adapter/incremental/fixtures.py index 18fcee40e..9d0f29133 100644 --- a/tests/functional/adapter/incremental/fixtures.py +++ b/tests/functional/adapter/incremental/fixtures.py @@ -264,14 +264,7 @@ 4,Baron,Harkonnen,1 """ -not_matched_by_source_then_del_expected = """id,first,second,V -2,Paul,Atreides,0 -3,Dunkan,Aidaho,1 -4,Baron,Harkonnen,1 -""" - -not_matched_by_source_then_upd_expected = """id,first,second,V -1,--,--,-1 +not_matched_by_source_expected = """id,first,second,V 2,Paul,Atreides,0 3,Dunkan,Aidaho,1 4,Baron,Harkonnen,1 @@ -418,7 +411,7 @@ {% endif %} """ -not_matched_by_source_then_delete_model = """ +not_matched_by_source_model = """ {{ config( materialized = 'incremental', unique_key = 'id', @@ -453,46 +446,6 @@ {% endif %} """ -not_matched_by_source_then_update_model = """ -{{ config( - materialized = 'incremental', - unique_key = 'id', - incremental_strategy='merge', - target_alias='t', - source_alias='s', - skip_matched_step=true, - not_matched_by_source_condition='t.V > 0', - not_matched_by_source_action=''' - update set - t.first = \\\'--\\\', - t.second = \\\'--\\\', - t.V = -1 - ''', -) }} - -{% if not is_incremental() %} - --- data for first invocation of model - -select 1 as id, 'Vasya' as first, 'Pupkin' as second, 1 as V -union all -select 2 as id, 'Paul' as first, 'Atreides' as second, 0 as V -union all -select 3 as id, 'Dunkan' as first, 'Aidaho' as second, 1 as V - -{% else %} - --- data for subsequent incremental update - --- id = 1 should be updated with --- id = 2 should be kept as condition doesn't hold (t.V = 0) -select 3 as id, 'Dunkan' as first, 'Aidaho' as second, 2 as V -- No update, skipped -union all -select 4 as id, 'Baron' as first, 'Harkonnen' as second, 1 as V -- should append - -{% endif %} -""" - merge_schema_evolution_model = """ {{ config( materialized = 'incremental', diff --git a/tests/functional/adapter/incremental/test_incremental_strategies.py b/tests/functional/adapter/incremental/test_incremental_strategies.py index 45a24362f..88db60eed 100644 --- a/tests/functional/adapter/incremental/test_incremental_strategies.py +++ b/tests/functional/adapter/incremental/test_incremental_strategies.py @@ -288,38 +288,17 @@ def test_merge(self, project): ) -class TestNotMatchedBySourceAndConditionThenDelete(IncrementalBase): +class TestNotMatchedBySourceAndCondition(IncrementalBase): @pytest.fixture(scope="class") def seeds(self): return { - "not_matched_by_source_expected.csv": fixtures.not_matched_by_source_then_del_expected, + "not_matched_by_source_expected.csv": fixtures.not_matched_by_source_expected, } @pytest.fixture(scope="class") def models(self): return { - "not_matched_by_source.sql": fixtures.not_matched_by_source_then_delete_model, - } - - def test_merge(self, project): - self.seed_and_run_twice() - util.check_relations_equal( - project.adapter, - ["not_matched_by_source", "not_matched_by_source_expected"], - ) - - -class TestNotMatchedBySourceAndConditionThenUpdate(IncrementalBase): - @pytest.fixture(scope="class") - def seeds(self): - return { - "not_matched_by_source_expected.csv": fixtures.not_matched_by_source_then_upd_expected, - } - - @pytest.fixture(scope="class") - def models(self): - return { - "not_matched_by_source.sql": fixtures.not_matched_by_source_then_update_model, + 
"not_matched_by_source.sql": fixtures.not_matched_by_source_model, } def test_merge(self, project): diff --git a/tests/functional/adapter/materialized_view_tests/test_changes.py b/tests/functional/adapter/materialized_view_tests/test_changes.py index 0a470d545..1c1fa4c71 100644 --- a/tests/functional/adapter/materialized_view_tests/test_changes.py +++ b/tests/functional/adapter/materialized_view_tests/test_changes.py @@ -19,7 +19,7 @@ def _check_tblproperties(tblproperties: TblPropertiesConfig, expected: dict): final_tblproperties = { - k: v for k, v in tblproperties.tblproperties.items() if k not in tblproperties.ignore_list + k: v for k, v in tblproperties.tblproperties.items() if not k.startswith("pipeline") } assert final_tblproperties == expected diff --git a/tests/functional/adapter/python_model/test_python_model.py b/tests/functional/adapter/python_model/test_python_model.py index 726791dfa..858214b79 100644 --- a/tests/functional/adapter/python_model/test_python_model.py +++ b/tests/functional/adapter/python_model/test_python_model.py @@ -58,7 +58,7 @@ def test_changing_schema_with_log_validation(self, project, logs_dir): ) util.run_dbt(["run"]) log_file = os.path.join(logs_dir, "dbt.log") - with open(log_file) as f: + with open(log_file, "r") as f: log = f.read() # validate #5510 log_code_execution works assert "On model.test.simple_python_model:" in log diff --git a/tests/functional/adapter/simple_snapshot/test_new_record_mode.py b/tests/functional/adapter/simple_snapshot/test_new_record_mode.py deleted file mode 100644 index 6b436a311..000000000 --- a/tests/functional/adapter/simple_snapshot/test_new_record_mode.py +++ /dev/null @@ -1,74 +0,0 @@ -import pytest - -from dbt.tests.adapter.simple_snapshot.new_record_mode import ( - _delete_sql, - _invalidate_sql, - _ref_snapshot_sql, - _seed_new_record_mode, - _snapshot_actual_sql, - _snapshots_yml, - _update_sql, -) -from dbt.tests.util import check_relations_equal, run_dbt - - -class TestDatabricksSnapshotNewRecordMode: - @pytest.fixture(scope="class") - def snapshots(self): - return {"snapshot.sql": _snapshot_actual_sql} - - @pytest.fixture(scope="class") - def models(self): - return { - "snapshots.yml": _snapshots_yml, - "ref_snapshot.sql": _ref_snapshot_sql, - } - - @pytest.fixture(scope="class") - def seed_new_record_mode(self): - return _seed_new_record_mode - - @pytest.fixture(scope="class") - def invalidate_sql_1(self): - return _invalidate_sql.split(";", 1)[0].replace("BEGIN", "") - - @pytest.fixture(scope="class") - def invalidate_sql_2(self): - return _invalidate_sql.split(";", 1)[1].replace("END", "").replace(";", "") - - @pytest.fixture(scope="class") - def update_sql(self): - return _update_sql.replace("text", "string") - - @pytest.fixture(scope="class") - def delete_sql(self): - return _delete_sql - - def test_snapshot_new_record_mode( - self, project, seed_new_record_mode, invalidate_sql_1, invalidate_sql_2, update_sql - ): - for sql in ( - seed_new_record_mode.replace("text", "string") - .replace("TEXT", "STRING") - .replace("BEGIN", "") - .replace("END;", "") - .replace(" WITHOUT TIME ZONE", "") - .split(";") - ): - project.run_sql(sql) - results = run_dbt(["snapshot"]) - assert len(results) == 1 - - project.run_sql(invalidate_sql_1) - project.run_sql(invalidate_sql_2) - project.run_sql(update_sql) - - results = run_dbt(["snapshot"]) - assert len(results) == 1 - - check_relations_equal(project.adapter, ["snapshot_actual", "snapshot_expected"]) - - project.run_sql(_delete_sql) - - results = run_dbt(["snapshot"]) - 
assert len(results) == 1 diff --git a/tests/functional/adapter/simple_snapshot/test_various_configs.py b/tests/functional/adapter/simple_snapshot/test_various_configs.py deleted file mode 100644 index 18b82de00..000000000 --- a/tests/functional/adapter/simple_snapshot/test_various_configs.py +++ /dev/null @@ -1,345 +0,0 @@ -import datetime - -import pytest -from agate import Table - -from dbt.tests.adapter.simple_snapshot.fixtures import ( - create_multi_key_seed_sql, - create_multi_key_snapshot_expected_sql, - create_seed_sql, - create_snapshot_expected_sql, - model_seed_sql, - populate_multi_key_snapshot_expected_sql, - populate_snapshot_expected_sql, - populate_snapshot_expected_valid_to_current_sql, - ref_snapshot_sql, - seed_insert_sql, - seed_multi_key_insert_sql, - snapshot_actual_sql, - snapshots_multi_key_yml, - snapshots_no_column_names_yml, - snapshots_valid_to_current_yml, - snapshots_yml, - update_multi_key_sql, - update_sql, - update_with_current_sql, -) -from dbt.tests.util import ( - check_relations_equal, - get_manifest, - run_dbt, - run_dbt_and_capture, - run_sql_with_adapter, - update_config_file, -) - - -def text_replace(input: str) -> str: - return input.replace("TEXT", "STRING").replace("text", "string") - - -create_snapshot_expected_sql = text_replace(create_snapshot_expected_sql) -populate_snapshot_expected_sql = text_replace(populate_snapshot_expected_sql) -populate_snapshot_expected_valid_to_current_sql = text_replace( - populate_snapshot_expected_valid_to_current_sql -) -update_with_current_sql = text_replace(update_with_current_sql) -create_multi_key_snapshot_expected_sql = text_replace(create_multi_key_snapshot_expected_sql) -populate_multi_key_snapshot_expected_sql = text_replace(populate_multi_key_snapshot_expected_sql) -update_sql = text_replace(update_sql) -update_multi_key_sql = text_replace(update_multi_key_sql) - -invalidate_sql_1 = """ --- update records 11 - 21. Change email and updated_at field -update {schema}.seed set - updated_at = updated_at + interval '1 hour', - email = case when id = 20 then 'pfoxj@creativecommons.org' else 'new_' || email end -where id >= 10 and id <= 20 -""" - -invalidate_sql_2 = """ --- invalidate records 11 - 21 -update {schema}.snapshot_expected set - test_valid_to = updated_at + interval '1 hour' -where id >= 10 and id <= 20; -""" - -invalidate_multi_key_sql_1 = """ --- update records 11 - 21. 
Change email and updated_at field -update {schema}.seed set - updated_at = updated_at + interval '1 hour', - email = case when id1 = 20 then 'pfoxj@creativecommons.org' else 'new_' || email end -where id1 >= 10 and id1 <= 20; -""" - -invalidate_multi_key_sql_2 = """ --- invalidate records 11 - 21 -update {schema}.snapshot_expected set - test_valid_to = updated_at + interval '1 hour' -where id1 >= 10 and id1 <= 20; -""" - - -class BaseSnapshotColumnNames: - @pytest.fixture(scope="class") - def snapshots(self): - return {"snapshot.sql": snapshot_actual_sql} - - @pytest.fixture(scope="class") - def models(self): - return { - "snapshots.yml": snapshots_yml, - "ref_snapshot.sql": ref_snapshot_sql, - } - - def test_snapshot_column_names(self, project): - project.run_sql(create_seed_sql) - project.run_sql(create_snapshot_expected_sql) - project.run_sql(seed_insert_sql) - project.run_sql(populate_snapshot_expected_sql) - - results = run_dbt(["snapshot"]) - assert len(results) == 1 - - project.run_sql(invalidate_sql_1) - project.run_sql(invalidate_sql_2) - project.run_sql(update_sql) - - results = run_dbt(["snapshot"]) - assert len(results) == 1 - - check_relations_equal(project.adapter, ["snapshot_actual", "snapshot_expected"]) - - -class BaseSnapshotColumnNamesFromDbtProject: - @pytest.fixture(scope="class") - def snapshots(self): - return {"snapshot.sql": snapshot_actual_sql} - - @pytest.fixture(scope="class") - def models(self): - return { - "snapshots.yml": snapshots_no_column_names_yml, - "ref_snapshot.sql": ref_snapshot_sql, - } - - @pytest.fixture(scope="class") - def project_config_update(self): - return { - "snapshots": { - "test": { - "+snapshot_meta_column_names": { - "dbt_valid_to": "test_valid_to", - "dbt_valid_from": "test_valid_from", - "dbt_scd_id": "test_scd_id", - "dbt_updated_at": "test_updated_at", - } - } - } - } - - def test_snapshot_column_names_from_project(self, project): - project.run_sql(create_seed_sql) - project.run_sql(create_snapshot_expected_sql) - project.run_sql(seed_insert_sql) - project.run_sql(populate_snapshot_expected_sql) - - results = run_dbt(["snapshot"]) - assert len(results) == 1 - - project.run_sql(invalidate_sql_1) - project.run_sql(invalidate_sql_2) - project.run_sql(update_sql) - - results = run_dbt(["snapshot"]) - assert len(results) == 1 - - check_relations_equal(project.adapter, ["snapshot_actual", "snapshot_expected"]) - - -class BaseSnapshotInvalidColumnNames: - @pytest.fixture(scope="class") - def snapshots(self): - return {"snapshot.sql": snapshot_actual_sql} - - @pytest.fixture(scope="class") - def models(self): - return { - "snapshots.yml": snapshots_no_column_names_yml, - "ref_snapshot.sql": ref_snapshot_sql, - } - - @pytest.fixture(scope="class") - def project_config_update(self): - return { - "snapshots": { - "test": { - "+snapshot_meta_column_names": { - "dbt_valid_to": "test_valid_to", - "dbt_valid_from": "test_valid_from", - "dbt_scd_id": "test_scd_id", - "dbt_updated_at": "test_updated_at", - } - } - } - } - - def test_snapshot_invalid_column_names(self, project): - project.run_sql(create_seed_sql) - project.run_sql(create_snapshot_expected_sql) - project.run_sql(seed_insert_sql) - project.run_sql(populate_snapshot_expected_sql) - - results = run_dbt(["snapshot"]) - assert len(results) == 1 - manifest = get_manifest(project.project_root) - snapshot_node = manifest.nodes["snapshot.test.snapshot_actual"] - snapshot_node.config.snapshot_meta_column_names == { - "dbt_valid_to": "test_valid_to", - "dbt_valid_from": "test_valid_from", - 
"dbt_scd_id": "test_scd_id", - "dbt_updated_at": "test_updated_at", - } - - project.run_sql(invalidate_sql_1) - project.run_sql(invalidate_sql_2) - project.run_sql(update_sql) - - # Change snapshot_meta_columns and look for an error - different_columns = { - "snapshots": { - "test": { - "+snapshot_meta_column_names": { - "dbt_valid_to": "test_valid_to", - "dbt_updated_at": "test_updated_at", - } - } - } - } - update_config_file(different_columns, "dbt_project.yml") - - results, log_output = run_dbt_and_capture(["snapshot"], expect_pass=False) - assert len(results) == 1 - assert "Compilation Error in snapshot snapshot_actual" in log_output - assert "Snapshot target is missing configured columns" in log_output - - -class BaseSnapshotDbtValidToCurrent: - @pytest.fixture(scope="class") - def snapshots(self): - return {"snapshot.sql": snapshot_actual_sql} - - @pytest.fixture(scope="class") - def models(self): - return { - "snapshots.yml": snapshots_valid_to_current_yml, - "ref_snapshot.sql": ref_snapshot_sql, - } - - def test_valid_to_current(self, project): - project.run_sql(create_seed_sql) - project.run_sql(create_snapshot_expected_sql) - project.run_sql(seed_insert_sql) - project.run_sql(populate_snapshot_expected_valid_to_current_sql) - - results = run_dbt(["snapshot"]) - assert len(results) == 1 - - original_snapshot: Table = run_sql_with_adapter( - project.adapter, - "select id, test_scd_id, test_valid_to from {schema}.snapshot_actual", - "all", - ) - assert original_snapshot[0][2] == datetime.datetime( - 2099, 12, 31, 0, 0, tzinfo=datetime.timezone.utc - ) - original_row = list( - filter(lambda x: x[1] == "61ecd07d17b8a4acb57d115eebb0e2c9", original_snapshot) - ) - assert original_row[0][2] == datetime.datetime( - 2099, 12, 31, 0, 0, tzinfo=datetime.timezone.utc - ) - - project.run_sql(invalidate_sql_1) - project.run_sql(invalidate_sql_2) - project.run_sql(update_with_current_sql) - - results = run_dbt(["snapshot"]) - assert len(results) == 1 - - updated_snapshot: Table = run_sql_with_adapter( - project.adapter, - "select id, test_scd_id, test_valid_to from {schema}.snapshot_actual", - "all", - ) - print(updated_snapshot) - assert updated_snapshot[0][2] == datetime.datetime( - 2099, 12, 31, 0, 0, tzinfo=datetime.timezone.utc - ) - # Original row that was updated now has a non-current (2099/12/31) date - original_row = list( - filter(lambda x: x[1] == "61ecd07d17b8a4acb57d115eebb0e2c9", updated_snapshot) - ) - assert original_row[0][2] == datetime.datetime( - 2016, 8, 20, 16, 44, 49, tzinfo=datetime.timezone.utc - ) - updated_row = list( - filter(lambda x: x[1] == "af1f803f2179869aeacb1bfe2b23c1df", updated_snapshot) - ) - - # Updated row has a current date - assert updated_row[0][2] == datetime.datetime( - 2099, 12, 31, 0, 0, tzinfo=datetime.timezone.utc - ) - - check_relations_equal(project.adapter, ["snapshot_actual", "snapshot_expected"]) - - -# This uses snapshot_meta_column_names, yaml-only snapshot def, -# and multiple keys -class BaseSnapshotMultiUniqueKey: - @pytest.fixture(scope="class") - def models(self): - return { - "seed.sql": model_seed_sql, - "snapshots.yml": snapshots_multi_key_yml, - "ref_snapshot.sql": ref_snapshot_sql, - } - - def test_multi_column_unique_key(self, project): - project.run_sql(create_multi_key_seed_sql) - project.run_sql(create_multi_key_snapshot_expected_sql) - project.run_sql(seed_multi_key_insert_sql) - project.run_sql(populate_multi_key_snapshot_expected_sql) - - results = run_dbt(["snapshot"]) - assert len(results) == 1 - - 
project.run_sql(invalidate_multi_key_sql_1) - project.run_sql(invalidate_multi_key_sql_2) - project.run_sql(update_multi_key_sql) - - results = run_dbt(["snapshot"]) - assert len(results) == 1 - - check_relations_equal(project.adapter, ["snapshot_actual", "snapshot_expected"]) - - -class TestDatabricksSnapshotColumnNames(BaseSnapshotColumnNames): - pass - - -class TestDatabricksSnapshotColumnNamesFromDbtProject(BaseSnapshotColumnNamesFromDbtProject): - pass - - -class TestDatabricksSnapshotInvalidColumnNames(BaseSnapshotInvalidColumnNames): - pass - - -class TestDatabricksSnapshotDbtValidToCurrent(BaseSnapshotDbtValidToCurrent): - pass - - -class TestDatabricksSnapshotMultiUniqueKey(BaseSnapshotMultiUniqueKey): - pass diff --git a/tests/unit/api_client/test_dlt_pipeline_api.py b/tests/unit/api_client/test_dlt_pipeline_api.py deleted file mode 100644 index 7dd1418e2..000000000 --- a/tests/unit/api_client/test_dlt_pipeline_api.py +++ /dev/null @@ -1,71 +0,0 @@ -import pytest -from dbt_common.exceptions import DbtRuntimeError - -from dbt.adapters.databricks.api_client import DltPipelineApi -from tests.unit.api_client.api_test_base import ApiTestBase - - -class TestDltPipelineApi(ApiTestBase): - @pytest.fixture - def api(self, session, host): - return DltPipelineApi(session, host, 1) - - @pytest.fixture - def pipeline_id(self): - return "pipeline_id" - - @pytest.fixture - def update_id(self): - return "update_id" - - def test_get_update_error__non_200(self, api, session, pipeline_id, update_id): - session.get.return_value.status_code = 500 - with pytest.raises(DbtRuntimeError): - api.get_update_error(pipeline_id, update_id) - - def test_get_update_error__200_no_events(self, api, session, pipeline_id, update_id): - session.get.return_value.status_code = 200 - session.get.return_value.json.return_value = {"events": []} - assert api.get_update_error(pipeline_id, update_id) == "" - - def test_get_update_error__200_no_error_events(self, api, session, pipeline_id, update_id): - session.get.return_value.status_code = 200 - session.get.return_value.json.return_value = { - "events": [{"event_type": "update_progress", "origin": {"update_id": update_id}}] - } - assert api.get_update_error(pipeline_id, update_id) == "" - - def test_get_update_error__200_error_events(self, api, session, pipeline_id, update_id): - session.get.return_value.status_code = 200 - session.get.return_value.json.return_value = { - "events": [ - { - "message": "I failed", - "details": {"update_progress": {"state": "FAILED"}}, - "event_type": "update_progress", - "origin": {"update_id": update_id}, - } - ] - } - assert api.get_update_error(pipeline_id, update_id) == "I failed" - - def test_poll_for_completion__non_200(self, api, session, pipeline_id): - self.assert_non_200_raises_error(lambda: api.poll_for_completion(pipeline_id), session) - - def test_poll_for_completion__200(self, api, session, host, pipeline_id): - session.get.return_value.status_code = 200 - session.get.return_value.json.return_value = {"state": "IDLE"} - api.poll_for_completion(pipeline_id) - session.get.assert_called_once_with( - f"https://{host}/api/2.0/pipelines/{pipeline_id}", json=None, params={} - ) - - def test_poll_for_completion__failed_with_cause(self, api, session, pipeline_id): - session.get.return_value.status_code = 200 - session.get.return_value.json.return_value = { - "state": "FAILED", - "pipeline_id": pipeline_id, - "cause": "I failed", - } - with pytest.raises(DbtRuntimeError, match=f"Pipeline {pipeline_id} failed: I failed"): - 
api.poll_for_completion(pipeline_id) diff --git a/tests/unit/api_client/test_workspace_api.py b/tests/unit/api_client/test_workspace_api.py index 208e3324e..322a91722 100644 --- a/tests/unit/api_client/test_workspace_api.py +++ b/tests/unit/api_client/test_workspace_api.py @@ -36,7 +36,7 @@ def test_upload_notebook__non_200(self, api, session): def test_upload_notebook__200(self, api, session, host): session.post.return_value.status_code = 200 - encoded = base64.b64encode(b"code").decode() + encoded = base64.b64encode("code".encode()).decode() api.upload_notebook("path", "code") session.post.assert_called_once_with( f"https://{host}/api/2.0/workspace/import", diff --git a/tests/unit/macros/relations/test_constraint_macros.py b/tests/unit/macros/relations/test_constraint_macros.py index feac1797b..351ca8cbb 100644 --- a/tests/unit/macros/relations/test_constraint_macros.py +++ b/tests/unit/macros/relations/test_constraint_macros.py @@ -1,8 +1,5 @@ -from unittest.mock import Mock - import pytest -from dbt.adapters.databricks.column import DatabricksColumn from tests.unit.macros.base import MacroTestBase @@ -19,7 +16,6 @@ def macro_folders_to_load(self) -> list: def modify_context(self, default_context) -> None: # Mock local_md5 default_context["local_md5"] = lambda s: f"hash({s})" - default_context["api"] = Mock(Column=DatabricksColumn) def render_constraints(self, template, *args): return self.run_macro(template, "databricks_constraints_to_dbt", *args) diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 3a094c4a4..83688523b 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -11,6 +11,8 @@ from dbt.adapters.databricks.column import DatabricksColumn from dbt.adapters.databricks.credentials import ( CATALOG_KEY_IN_SESSION_PROPERTIES, + DBT_DATABRICKS_HTTP_SESSION_HEADERS, + DBT_DATABRICKS_INVOCATION_ENV, ) from dbt.adapters.databricks.impl import get_identifier_list_string from dbt.adapters.databricks.relation import DatabricksRelation, DatabricksRelationType @@ -112,10 +114,7 @@ def test_invalid_custom_user_agent(self): with pytest.raises(DbtValidationError) as excinfo: config = self._get_config() adapter = DatabricksAdapter(config, get_context("spawn")) - with patch( - "dbt.adapters.databricks.global_state.GlobalState.get_invocation_env", - return_value="(Some-thing)", - ): + with patch.dict("os.environ", **{DBT_DATABRICKS_INVOCATION_ENV: "(Some-thing)"}): connection = adapter.acquire_connection("dummy") connection.handle # trigger lazy-load @@ -129,9 +128,8 @@ def test_custom_user_agent(self): "dbt.adapters.databricks.connections.dbsql.connect", new=self._connect_func(expected_invocation_env="databricks-workflows"), ): - with patch( - "dbt.adapters.databricks.global_state.GlobalState.get_invocation_env", - return_value="databricks-workflows", + with patch.dict( + "os.environ", **{DBT_DATABRICKS_INVOCATION_ENV: "databricks-workflows"} ): connection = adapter.acquire_connection("dummy") connection.handle # trigger lazy-load @@ -192,9 +190,9 @@ def _test_environment_http_headers( "dbt.adapters.databricks.connections.dbsql.connect", new=self._connect_func(expected_http_headers=expected_http_headers), ): - with patch( - "dbt.adapters.databricks.global_state.GlobalState.get_http_session_headers", - return_value=http_headers_str, + with patch.dict( + "os.environ", + **{DBT_DATABRICKS_HTTP_SESSION_HEADERS: http_headers_str}, ): connection = adapter.acquire_connection("dummy") connection.handle # trigger lazy-load @@ -345,15 +343,13 @@ def 
_test_databricks_sql_connector_http_header_connection(self, http_headers, co assert connection.credentials.token == "dapiXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX" assert connection.credentials.schema == "analytics" - @patch("dbt.adapters.databricks.api_client.DatabricksApiClient.create") - def test_list_relations_without_caching__no_relations(self, _): + def test_list_relations_without_caching__no_relations(self): with patch.object(DatabricksAdapter, "get_relations_without_caching") as mocked: mocked.return_value = [] adapter = DatabricksAdapter(Mock(flags={}), get_context("spawn")) assert adapter.list_relations("database", "schema") == [] - @patch("dbt.adapters.databricks.api_client.DatabricksApiClient.create") - def test_list_relations_without_caching__some_relations(self, _): + def test_list_relations_without_caching__some_relations(self): with patch.object(DatabricksAdapter, "get_relations_without_caching") as mocked: mocked.return_value = [("name", "table", "hudi", "owner")] adapter = DatabricksAdapter(Mock(flags={}), get_context("spawn")) @@ -367,8 +363,7 @@ def test_list_relations_without_caching__some_relations(self, _): assert relation.owner == "owner" assert relation.is_hudi - @patch("dbt.adapters.databricks.api_client.DatabricksApiClient.create") - def test_list_relations_without_caching__hive_relation(self, _): + def test_list_relations_without_caching__hive_relation(self): with patch.object(DatabricksAdapter, "get_relations_without_caching") as mocked: mocked.return_value = [("name", "table", None, None)] adapter = DatabricksAdapter(Mock(flags={}), get_context("spawn")) @@ -381,8 +376,7 @@ def test_list_relations_without_caching__hive_relation(self, _): assert relation.type == DatabricksRelationType.Table assert not relation.has_information() - @patch("dbt.adapters.databricks.api_client.DatabricksApiClient.create") - def test_get_schema_for_catalog__no_columns(self, _): + def test_get_schema_for_catalog__no_columns(self): with patch.object(DatabricksAdapter, "_list_relations_with_information") as list_info: list_info.return_value = [(Mock(), "info")] with patch.object(DatabricksAdapter, "_get_columns_for_catalog") as get_columns: @@ -391,8 +385,7 @@ def test_get_schema_for_catalog__no_columns(self, _): table = adapter._get_schema_for_catalog("database", "schema", "name") assert len(table.rows) == 0 - @patch("dbt.adapters.databricks.api_client.DatabricksApiClient.create") - def test_get_schema_for_catalog__some_columns(self, _): + def test_get_schema_for_catalog__some_columns(self): with patch.object(DatabricksAdapter, "_list_relations_with_information") as list_info: list_info.return_value = [(Mock(), "info")] with patch.object(DatabricksAdapter, "_get_columns_for_catalog") as get_columns: @@ -888,10 +881,7 @@ def test_describe_table_extended_2048_char_limit(self): assert get_identifier_list_string(table_names) == "|".join(table_names) # If environment variable is set, then limit the number of characters - with patch( - "dbt.adapters.databricks.global_state.GlobalState.get_char_limit_bypass", - return_value="true", - ): + with patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}): # Long list of table names is capped assert get_identifier_list_string(table_names) == "*" @@ -920,10 +910,7 @@ def test_describe_table_extended_should_limit(self): table_names = set([f"customers_{i}" for i in range(200)]) # If environment variable is set, then limit the number of characters - with patch( - "dbt.adapters.databricks.global_state.GlobalState.get_char_limit_bypass", - 
return_value="true", - ): + with patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}): # Long list of table names is capped assert get_identifier_list_string(table_names) == "*" @@ -936,10 +923,7 @@ def test_describe_table_extended_may_limit(self): table_names = set([f"customers_{i}" for i in range(200)]) # If environment variable is set, then we may limit the number of characters - with patch( - "dbt.adapters.databricks.global_state.GlobalState.get_char_limit_bypass", - return_value="true", - ): + with patch.dict("os.environ", **{"DBT_DESCRIBE_TABLE_2048_CHAR_BYPASS": "true"}): # But a short list of table names is not capped assert get_identifier_list_string(list(table_names)[:5]) == "|".join( list(table_names)[:5] diff --git a/tests/unit/test_auth.py b/tests/unit/test_auth.py index 90ba0f594..6571c9cb2 100644 --- a/tests/unit/test_auth.py +++ b/tests/unit/test_auth.py @@ -165,7 +165,7 @@ def get_password(self, servicename, username): if not os.path.exists(file_path): return None - with open(file_path) as file: + with open(file_path, "r") as file: password = file.read() return password diff --git a/tests/unit/test_column.py b/tests/unit/test_column.py index 95fbbee5c..0519b6a45 100644 --- a/tests/unit/test_column.py +++ b/tests/unit/test_column.py @@ -1,5 +1,3 @@ -import pytest - from dbt.adapters.databricks.column import DatabricksColumn diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 3e692a00a..81e772113 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,4 +1,4 @@ -from dbt.adapters.databricks.utils import quote, redact_credentials, remove_ansi +from dbt.adapters.databricks.utils import redact_credentials, remove_ansi class TestDatabricksUtils: @@ -64,6 +64,3 @@ def test_remove_ansi(self): 72 # how to execute python model in notebook """ assert remove_ansi(test_string) == expected_string - - def test_quote(self): - assert quote("table") == "`table`"