diff --git a/dbt/adapters/databricks/connections.py b/dbt/adapters/databricks/connections.py index 887e31ad..b536fda6 100644 --- a/dbt/adapters/databricks/connections.py +++ b/dbt/adapters/databricks/connections.py @@ -428,11 +428,13 @@ def connect() -> DatabricksHandle: conn = DatabricksHandle.from_connection_args( conn_args, creds.cluster_id is not None ) - if conn.open: + if conn: databricks_connection.session_id = conn.session_id databricks_connection.last_used_time = time.time() - return conn + return conn + else: + raise DbtDatabaseError("Failed to create connection") except Error as exc: logger.error(ConnectionCreateError(exc)) raise diff --git a/dbt/adapters/databricks/handle.py b/dbt/adapters/databricks/handle.py index 8fcdc273..1b7d30db 100644 --- a/dbt/adapters/databricks/handle.py +++ b/dbt/adapters/databricks/handle.py @@ -2,7 +2,8 @@ import re import sys from collections.abc import Callable, Sequence -from typing import TYPE_CHECKING, Any, Optional +from types import TracebackType +from typing import TYPE_CHECKING, Any, Optional, TypeVar from dbt_common.exceptions import DbtRuntimeError @@ -69,23 +70,42 @@ def _cleanup( utils.handle_exceptions_as_warning(lambda: cleanup(self._cursor), failLog) def fetchall(self) -> Sequence[tuple]: - return self._cursor.fetchall() + return self._safe_execute(lambda cursor: cursor.fetchall()) def fetchone(self) -> Optional[tuple]: - return self._cursor.fetchone() + return self._safe_execute(lambda cursor: cursor.fetchone()) def fetchmany(self, size: int) -> Sequence[tuple]: - return self._cursor.fetchmany(size) + return self._safe_execute(lambda cursor: cursor.fetchmany(size)) def get_response(self) -> AdapterResponse: return AdapterResponse(_message="OK", query_id=self._cursor.query_id or "N/A") + T = TypeVar("T") + + def _safe_execute(self, f: Callable[[Cursor], T]) -> T: + if not self.open: + raise DbtRuntimeError("Attempting to execute on a closed cursor") + return f(self._cursor) + def __str__(self) -> str: return ( f"Cursor(session-id={self._cursor.connection.get_session_id_hex()}, " f"command-id={self._cursor.query_id})" ) + def __enter__(self) -> "CursorWrapper": + return self + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> bool: + self.close() + return exc_val is None + class DatabricksHandle: """ @@ -96,11 +116,11 @@ class DatabricksHandle: def __init__( self, - conn: Optional[Connection], + conn: Connection, is_cluster: bool, ): self._conn = conn - self.open = self._conn is not None + self.open = True self._cursor: Optional[CursorWrapper] = None self._dbr_version: Optional[tuple[int, int]] = None self._is_cluster = is_cluster @@ -128,9 +148,7 @@ def dbr_version(self) -> tuple[int, int]: @property def session_id(self) -> str: - if self._conn: - return self._conn.get_session_id_hex() - return "N/A" + return self._conn.get_session_id_hex() def execute(self, sql: str, bindings: Optional[Sequence[Any]] = None) -> CursorWrapper: """ @@ -160,7 +178,6 @@ def cancel(self) -> None: Cancel in progress query, if any, then close connection and cursor. """ self._cleanup( - lambda conn: conn.close() if conn else None, lambda cursor: cursor.cancel(), lambda: f"{self} - Cancelling", lambda ex: f"{self} - Exception while cancelling: {ex}", @@ -172,7 +189,6 @@ def close(self) -> None: """ self._cleanup( - lambda conn: conn.close() if conn else None, lambda cursor: cursor.close(), lambda: f"{self} - Closing", lambda ex: f"{self} - Exception while closing: {ex}", @@ -185,7 +201,9 @@ def rollback(self) -> None: logger.debug("NotImplemented: rollback") @staticmethod - def from_connection_args(conn_args: dict[str, Any], is_cluster: bool) -> "DatabricksHandle": + def from_connection_args( + conn_args: dict[str, Any], is_cluster: bool + ) -> Optional["DatabricksHandle"]: """ Create a new DatabricksHandle from the given connection arguments. """ @@ -193,16 +211,14 @@ def from_connection_args(conn_args: dict[str, Any], is_cluster: bool) -> "Databr conn = dbsql.connect(**conn_args) if not conn: logger.warning(f"Failed to create connection for {conn_args.get('http_path')}") - + return None connection = DatabricksHandle(conn, is_cluster=is_cluster) - logger.debug(f"{connection} - Created") return connection def _cleanup( self, - connect_op: ConnectionOp, cursor_op: CursorWrapperOp, startLog: LogOp, failLog: FailLogOp, @@ -217,7 +233,7 @@ def _cleanup( if self._cursor: cursor_op(self._cursor) - utils.handle_exceptions_as_warning(lambda: connect_op(self._conn), failLog) + utils.handle_exceptions_as_warning(lambda: self._conn.close(), failLog) def _safe_execute(self, f: CursorExecOp) -> CursorWrapper: """ diff --git a/dbt/adapters/databricks/utils.py b/dbt/adapters/databricks/utils.py index e50f3f1f..dccdd16c 100644 --- a/dbt/adapters/databricks/utils.py +++ b/dbt/adapters/databricks/utils.py @@ -85,6 +85,6 @@ def quote(name: str) -> str: def handle_exceptions_as_warning(op: Callable[[], None], log_gen: ExceptionToStrOp) -> None: try: - return op() + op() except Exception as e: logger.warning(log_gen(e)) diff --git a/tests/unit/test_adapter.py b/tests/unit/test_adapter.py index 61f8f29c..f8c3b4a1 100644 --- a/tests/unit/test_adapter.py +++ b/tests/unit/test_adapter.py @@ -267,6 +267,7 @@ def connect( assert http_headers is None else: assert http_headers == expected_http_headers + return Mock() return connect diff --git a/tests/unit/test_handle.py b/tests/unit/test_handle.py index f0878f0e..01326063 100644 --- a/tests/unit/test_handle.py +++ b/tests/unit/test_handle.py @@ -108,6 +108,20 @@ def test_get_response__with_query_id(self, cursor): wrapper = CursorWrapper(cursor) assert wrapper.get_response() == AdapterResponse("OK", query_id="id") + def test_with__no_exception(self, cursor): + with CursorWrapper(cursor) as c: + c.fetchone() + cursor.fetchone.assert_called_once() + cursor.close.assert_called_once() + + def test_with__exception(self, cursor): + cursor.fetchone.side_effect = Exception("foo") + with pytest.raises(Exception, match="foo"): + with CursorWrapper(cursor) as c: + c.fetchone() + cursor.fetchone.assert_called_once() + cursor.close.assert_called_once() + class TestDatabricksHandle: @pytest.fixture