Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Testbranch2 #917

Closed
wants to merge 20 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions dbt/adapters/databricks/api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@

from dbt.adapters.databricks import utils
from dbt.adapters.databricks.__version__ import version
from dbt.adapters.databricks.auth import BearerAuth
from dbt.adapters.databricks.credentials import DatabricksCredentials
from dbt.adapters.databricks.credentials import BearerAuth, DatabricksCredentials
from dbt.adapters.databricks.logging import logger

DEFAULT_POLLING_INTERVAL = 10
Expand Down Expand Up @@ -557,8 +556,7 @@ def create(
http_headers = credentials.get_all_http_headers(
connection_parameters.pop("http_headers", {})
)
credentials_provider = credentials.authenticate(None)
header_factory = credentials_provider(None) # type: ignore
header_factory = credentials.authenticate().credentials_provider() # type: ignore
session.auth = BearerAuth(header_factory)

session.headers.update({"User-Agent": user_agent, **http_headers})
Expand Down
100 changes: 0 additions & 100 deletions dbt/adapters/databricks/auth.py

This file was deleted.

150 changes: 146 additions & 4 deletions dbt/adapters/databricks/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@
)
from dbt.adapters.databricks.__version__ import version as __version__
from dbt.adapters.databricks.api_client import DatabricksApiClient
from dbt.adapters.databricks.credentials import DatabricksCredentials, TCredentialProvider
from dbt.adapters.databricks.credentials import (
BearerAuth,
DatabricksCredentialManager,
DatabricksCredentials,
)
from dbt.adapters.databricks.events.connection_events import (
ConnectionAcquire,
ConnectionCancel,
Expand Down Expand Up @@ -373,7 +377,7 @@ def _reset_handle(self, open: Callable[[Connection], Connection]) -> None:

class DatabricksConnectionManager(SparkConnectionManager):
TYPE: str = "databricks"
credentials_provider: Optional[TCredentialProvider] = None
credentials_manager: Optional[DatabricksCredentialManager] = None
_user_agent = f"dbt-databricks/{__version__}"

def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext):
Expand Down Expand Up @@ -581,6 +585,143 @@ def list_tables(self, database: str, schema: str, identifier: Optional[str] = No
),
)

@classmethod
def get_open_for_context(
    cls, query_header_context: Any = None
) -> Callable[[Connection], Connection]:
    """Return the callable used to open a connection for ``query_header_context``.

    When the context is falsy, the existing classmethod ``open`` is returned
    unchanged. Otherwise a closure is returned that forwards the context to
    ``cls._open`` alongside the connection.
    """
    if query_header_context:

        def open_for_model(connection: Connection) -> Connection:
            # Bind the node/context so callers only need to supply the connection.
            return cls._open(connection, query_header_context)

        return open_for_model

    # No node was supplied: the existing classmethod is sufficient.
    return cls.open

@classmethod
def open(cls, connection: Connection) -> Connection:
    """Open ``connection`` by delegating to ``cls._open`` without a query header context.

    This overrides an inherited method, so its signature cannot grow a
    context/ResultNode parameter; context-aware callers go through ``_open``.
    """
    return cls._open(connection)

@classmethod
def _open(cls, connection: Connection, query_header_context: Any = None) -> Connection:
    """Open ``connection`` via ``dbsql.connect``, or return it untouched if already open.

    ``query_header_context`` is passed to ``_get_http_path`` to pick the HTTP path,
    which may differ from the ``http_path`` on the credentials when a model
    specifies a compute resource. The connect attempt is run through
    ``cls.retry_connection``; ``Error`` is treated as retryable only when
    ``creds.retry_all`` is set.
    """
    if connection.state == ConnectionState.OPEN:
        return connection

    creds: DatabricksCredentials = connection.credentials
    timeout = creds.connect_timeout

    # Stored on the class so users are not prompted many times.
    cls.credentials_manager = creds.authenticate()

    invocation_env = creds.get_invocation_env()
    user_agent_entry = (
        f"{cls._user_agent}; {invocation_env}" if invocation_env else cls._user_agent
    )

    # Work on a copy so popping "http_headers" does not mutate the credentials.
    extra_params = creds.connection_parameters.copy()  # type: ignore[union-attr]
    raw_headers = extra_params.pop("http_headers", {})
    http_headers: list[tuple[str, str]] = list(creds.get_all_http_headers(raw_headers).items())

    # A model may specify a compute resource, in which case the http path
    # can differ from the http_path property of creds.
    http_path = _get_http_path(query_header_context, creds)

    def connect() -> DatabricksSQLConnectionWrapper:
        assert cls.credentials_manager is not None
        try:
            # TODO: what is the error when a user specifies a catalog they don't have access to
            conn: DatabricksSQLConnection = dbsql.connect(
                server_hostname=creds.host,
                http_path=http_path,
                credentials_provider=cls.credentials_manager.credentials_provider,
                http_headers=http_headers or None,
                session_configuration=creds.session_properties,
                catalog=creds.database,
                use_inline_params="silent",
                # schema=creds.schema,  # TODO: Explicitly set once DBR 7.3LTS is EOL.
                _user_agent_entry=user_agent_entry,
                **extra_params,
            )
            logger.debug(ConnectionCreated(str(conn)))
            wrapper = DatabricksSQLConnectionWrapper(
                conn,
                is_cluster=creds.cluster_id is not None,
                creds=creds,
                user_agent=user_agent_entry,
            )
            return wrapper
        except Error as exc:
            logger.error(ConnectionCreateError(exc))
            raise

    def exponential_backoff(attempt: int) -> int:
        # Despite the name, the delay grows with the square of the attempt number.
        return attempt * attempt

    # retry_all exists for backwards compatibility; without it nothing is retryable.
    retryable_exceptions = [Error] if creds.retry_all else []

    # An explicit connect_timeout wins; otherwise fall back to the backoff function.
    retry_timeout = exponential_backoff if timeout is None else timeout

    return cls.retry_connection(
        connection,
        connect=connect,
        logger=logger,
        retryable_exceptions=retryable_exceptions,
        retry_limit=creds.connect_retries,
        retry_timeout=retry_timeout,
    )

@classmethod
def get_response(cls, cursor: DatabricksSQLCursorWrapper) -> DatabricksAdapterResponse:
    """Build the adapter response for ``cursor``.

    The message is always "OK". The query id is the cursor's ``hex_query_id``
    attribute (``None`` if the attribute is absent), or "N/A" when no cursor
    was provided.
    """
    if cursor is None:
        logger.debug("No cursor was provided. Query ID not available.")
        query_id = "N/A"
    else:
        query_id = getattr(cursor, "hex_query_id", None)
    return DatabricksAdapterResponse(_message="OK", query_id=query_id)  # type: ignore


class ExtendedSessionConnectionManager(DatabricksConnectionManager):
def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext) -> None:
    """Initialise the manager and its empty registry of per-thread compute connections.

    Asserts that ``GlobalState.get_use_long_sessions()`` is truthy before
    delegating to the parent initialiser.
    """
    assert (
        GlobalState.get_use_long_sessions()
    ), "This connection manager should only be used when USE_LONG_SESSIONS is enabled"
    super().__init__(profile, mp_context)
    # Nested mapping of hashable keys to connections; starts out empty.
    self.threads_compute_connections: dict[
        Hashable, dict[Hashable, DatabricksDBTConnection]
    ] = {}

def set_connection_name(
    self, name: Optional[str] = None, query_header_context: Any = None
) -> Connection:
    """Return this thread's connection for the context's compute, creating or renaming it.

    Called by 'acquire_connection' in DatabricksAdapter, which is called by
    'connection_named', called by 'connection_for(node)'.
    Idle connections are cleaned up first. A missing connection is created;
    an existing one is updated with the requested name. The connection is
    acquired for ``query_header_context`` before being returned.
    """
    self._cleanup_idle_connections()

    # A nameless request maps to the "master" connection.
    conn_name: str = name if name is not None else "master"

    # Look up the connection keyed by compute name ("" when none is given).
    compute_name = _get_compute_name(query_header_context) or ""
    conn = self._get_if_exists_compute_connection(compute_name)

    if conn is not None:
        # existing connection either wasn't open or didn't have the right name
        conn = self._update_compute_connection(conn, conn_name)
    else:
        conn = self._create_compute_connection(conn_name, query_header_context)

    conn._acquire(query_header_context)
    return conn

# override
def release(self) -> None:
with self.lock:
Expand Down Expand Up @@ -634,7 +775,7 @@ def open(cls, connection: Connection) -> Connection:
timeout = creds.connect_timeout

# gotta keep this so we don't prompt users many times
cls.credentials_provider = creds.authenticate(cls.credentials_provider)
cls.credentials_manager = creds.authenticate()

invocation_env = creds.get_invocation_env()
user_agent_entry = cls._user_agent
Expand All @@ -652,12 +793,13 @@ def open(cls, connection: Connection) -> Connection:
http_path = databricks_connection.http_path

def connect() -> DatabricksSQLConnectionWrapper:
assert cls.credentials_manager is not None
try:
# TODO: what is the error when a user specifies a catalog they don't have access to
conn = dbsql.connect(
server_hostname=creds.host,
http_path=http_path,
credentials_provider=cls.credentials_provider,
credentials_provider=cls.credentials_manager.credentials_provider,
http_headers=http_headers if http_headers else None,
session_configuration=creds.session_properties,
catalog=creds.database,
Expand Down
Loading
Loading