Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
benc-db committed Feb 5, 2025
1 parent 601e877 commit 44a9da8
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 19 deletions.
6 changes: 4 additions & 2 deletions dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,11 +428,13 @@ def connect() -> DatabricksHandle:
conn = DatabricksHandle.from_connection_args(
conn_args, creds.cluster_id is not None
)
if conn.open:
if conn:
databricks_connection.session_id = conn.session_id
databricks_connection.last_used_time = time.time()

return conn
return conn
else:
raise DbtDatabaseError("Failed to create connection")
except Error as exc:
logger.error(ConnectionCreateError(exc))
raise
Expand Down
48 changes: 32 additions & 16 deletions dbt/adapters/databricks/handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import re
import sys
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, Optional
from types import TracebackType
from typing import TYPE_CHECKING, Any, Optional, TypeVar

from dbt_common.exceptions import DbtRuntimeError

Expand Down Expand Up @@ -69,23 +70,42 @@ def _cleanup(
utils.handle_exceptions_as_warning(lambda: cleanup(self._cursor), failLog)

def fetchall(self) -> Sequence[tuple]:
return self._cursor.fetchall()
return self._safe_execute(lambda cursor: cursor.fetchall())

def fetchone(self) -> Optional[tuple]:
return self._cursor.fetchone()
return self._safe_execute(lambda cursor: cursor.fetchone())

def fetchmany(self, size: int) -> Sequence[tuple]:
return self._cursor.fetchmany(size)
return self._safe_execute(lambda cursor: cursor.fetchmany(size))

def get_response(self) -> AdapterResponse:
return AdapterResponse(_message="OK", query_id=self._cursor.query_id or "N/A")

T = TypeVar("T")

def _safe_execute(self, f: Callable[[Cursor], T]) -> T:
if not self.open:
raise DbtRuntimeError("Attempting to execute on a closed cursor")
return f(self._cursor)

def __str__(self) -> str:
return (
f"Cursor(session-id={self._cursor.connection.get_session_id_hex()}, "
f"command-id={self._cursor.query_id})"
)

def __enter__(self) -> "CursorWrapper":
return self

def __exit__(
self,
exc_type: Optional[type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> bool:
self.close()
return exc_val is None


class DatabricksHandle:
"""
Expand All @@ -96,11 +116,11 @@ class DatabricksHandle:

def __init__(
self,
conn: Optional[Connection],
conn: Connection,
is_cluster: bool,
):
self._conn = conn
self.open = self._conn is not None
self.open = True
self._cursor: Optional[CursorWrapper] = None
self._dbr_version: Optional[tuple[int, int]] = None
self._is_cluster = is_cluster
Expand Down Expand Up @@ -128,9 +148,7 @@ def dbr_version(self) -> tuple[int, int]:

@property
def session_id(self) -> str:
if self._conn:
return self._conn.get_session_id_hex()
return "N/A"
return self._conn.get_session_id_hex()

def execute(self, sql: str, bindings: Optional[Sequence[Any]] = None) -> CursorWrapper:
"""
Expand Down Expand Up @@ -160,7 +178,6 @@ def cancel(self) -> None:
Cancel in progress query, if any, then close connection and cursor.
"""
self._cleanup(
lambda conn: conn.close() if conn else None,
lambda cursor: cursor.cancel(),
lambda: f"{self} - Cancelling",
lambda ex: f"{self} - Exception while cancelling: {ex}",
Expand All @@ -172,7 +189,6 @@ def close(self) -> None:
"""

self._cleanup(
lambda conn: conn.close() if conn else None,
lambda cursor: cursor.close(),
lambda: f"{self} - Closing",
lambda ex: f"{self} - Exception while closing: {ex}",
Expand All @@ -185,24 +201,24 @@ def rollback(self) -> None:
logger.debug("NotImplemented: rollback")

@staticmethod
def from_connection_args(conn_args: dict[str, Any], is_cluster: bool) -> "DatabricksHandle":
def from_connection_args(
conn_args: dict[str, Any], is_cluster: bool
) -> Optional["DatabricksHandle"]:
"""
Create a new DatabricksHandle from the given connection arguments.
"""

conn = dbsql.connect(**conn_args)
if not conn:
logger.warning(f"Failed to create connection for {conn_args.get('http_path')}")

return None
connection = DatabricksHandle(conn, is_cluster=is_cluster)

logger.debug(f"{connection} - Created")

return connection

def _cleanup(
self,
connect_op: ConnectionOp,
cursor_op: CursorWrapperOp,
startLog: LogOp,
failLog: FailLogOp,
Expand All @@ -217,7 +233,7 @@ def _cleanup(
if self._cursor:
cursor_op(self._cursor)

utils.handle_exceptions_as_warning(lambda: connect_op(self._conn), failLog)
utils.handle_exceptions_as_warning(lambda: self._conn.close(), failLog)

def _safe_execute(self, f: CursorExecOp) -> CursorWrapper:
"""
Expand Down
2 changes: 1 addition & 1 deletion dbt/adapters/databricks/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,6 @@ def quote(name: str) -> str:

def handle_exceptions_as_warning(op: Callable[[], None], log_gen: ExceptionToStrOp) -> None:
try:
return op()
op()
except Exception as e:
logger.warning(log_gen(e))
1 change: 1 addition & 0 deletions tests/unit/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ def connect(
assert http_headers is None
else:
assert http_headers == expected_http_headers
return Mock()

return connect

Expand Down
14 changes: 14 additions & 0 deletions tests/unit/test_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,20 @@ def test_get_response__with_query_id(self, cursor):
wrapper = CursorWrapper(cursor)
assert wrapper.get_response() == AdapterResponse("OK", query_id="id")

def test_with__no_exception(self, cursor):
with CursorWrapper(cursor) as c:
c.fetchone()
cursor.fetchone.assert_called_once()
cursor.close.assert_called_once()

def test_with__exception(self, cursor):
cursor.fetchone.side_effect = Exception("foo")
with pytest.raises(Exception, match="foo"):
with CursorWrapper(cursor) as c:
c.fetchone()
cursor.fetchone.assert_called_once()
cursor.close.assert_called_once()


class TestDatabricksHandle:
@pytest.fixture
Expand Down

0 comments on commit 44a9da8

Please sign in to comment.