From 230ed009f90f9cdb497736a1e81ad611c9ec790b Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 7 Oct 2024 11:55:55 +0200 Subject: [PATCH] first work --- databricks/sdk/_base_client.py | 343 +++++++++++++++++++++++++ databricks/sdk/config.py | 49 +--- databricks/sdk/core.py | 318 ++--------------------- databricks/sdk/credentials_provider.py | 15 +- databricks/sdk/oauth.py | 213 ++++++++++----- examples/flask_app_with_oauth.py | 46 ++-- tests/fixture_server.py | 33 +++ tests/test_base_client.py | 282 ++++++++++++++++++++ tests/test_core.py | 279 +------------------- tests/test_oauth.py | 131 +++++++--- 10 files changed, 988 insertions(+), 721 deletions(-) create mode 100644 databricks/sdk/_base_client.py create mode 100644 tests/fixture_server.py create mode 100644 tests/test_base_client.py diff --git a/databricks/sdk/_base_client.py b/databricks/sdk/_base_client.py new file mode 100644 index 000000000..7734824be --- /dev/null +++ b/databricks/sdk/_base_client.py @@ -0,0 +1,343 @@ +import logging +from datetime import timedelta +from types import TracebackType +from typing import (Any, BinaryIO, Callable, Dict, Iterable, Iterator, List, + Optional, Type, Union) +import urllib.parse + +import requests +import requests.adapters + +from . import useragent +from .casing import Casing +from .clock import Clock, RealClock +from .errors import DatabricksError, _ErrorCustomizer, _Parser +from .logger import RoundTrip +from .retries import retried + +logger = logging.getLogger('databricks.sdk') + + +def fix_host_if_needed(host: Optional[str]) -> Optional[str]: + if not host: + return host + + # Add a default scheme if it's missing + if '://' not in host: + host = 'https://' + host + + o = urllib.parse.urlparse(host) + # remove trailing slash + path = o.path.rstrip('/') + # remove port if 443 + netloc = o.netloc + if o.port == 443: + netloc = netloc.split(':')[0] + + return urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment)) + + +class _BaseClient: + + def __init__(self, + debug_truncate_bytes: int = None, + retry_timeout_seconds: int = None, + user_agent_base: str = None, + header_factory: Callable[[], dict] = None, + max_connection_pools: int = None, + max_connections_per_pool: int = None, + pool_block: bool = True, + http_timeout_seconds: float = None, + extra_error_customizers: List[_ErrorCustomizer] = None, + debug_headers: bool = False, + clock: Clock = None): + """ + :param debug_truncate_bytes: + :param retry_timeout_seconds: + :param user_agent_base: + :param header_factory: A function that returns a dictionary of headers to include in the request. + :param max_connection_pools: Number of urllib3 connection pools to cache before discarding the least + recently used pool. Python requests default value is 10. + :param max_connections_per_pool: The maximum number of connections to save in the pool. Improves performance + in multithreaded situations. For now, we're setting it to the same value as connection_pool_size. + :param pool_block: If pool_block is False, then more connections will are created, but not saved after the + first use. Blocks when no free connections are available. urllib3 ensures that no more than + pool_maxsize connections are used at a time. Prevents platform from flooding. By default, requests library + doesn't block. + :param http_timeout_seconds: + :param extra_error_customizers: + :param debug_headers: Whether to include debug headers in the request log. + :param clock: Clock object to use for time-related operations. + """ + + self._debug_truncate_bytes = debug_truncate_bytes or 96 + self._debug_headers = debug_headers + self._retry_timeout_seconds = retry_timeout_seconds or 300 + self._user_agent_base = user_agent_base or useragent.to_string() + self._header_factory = header_factory + self._clock = clock or RealClock() + self._session = requests.Session() + self._session.auth = self._authenticate + + # We don't use `max_retries` from HTTPAdapter to align with a more production-ready + # retry strategy established in the Databricks SDK for Go. See _is_retryable and + # @retried for more details. + http_adapter = requests.adapters.HTTPAdapter(pool_connections=max_connections_per_pool or 20, + pool_maxsize=max_connection_pools or 20, + pool_block=pool_block) + self._session.mount("https://", http_adapter) + + # Default to 60 seconds + self._http_timeout_seconds = http_timeout_seconds or 60 + + self._error_parser = _Parser(extra_error_customizers=extra_error_customizers) + + def _authenticate(self, r: requests.PreparedRequest) -> requests.PreparedRequest: + if self._header_factory: + headers = self._header_factory() + for k, v in headers.items(): + r.headers[k] = v + return r + + @staticmethod + def _fix_query_string(query: Optional[dict] = None) -> Optional[dict]: + # Convert True -> "true" for Databricks APIs to understand booleans. + # See: https://github.com/databricks/databricks-sdk-py/issues/142 + if query is None: + return None + with_fixed_bools = {k: v if type(v) != bool else ('true' if v else 'false') for k, v in query.items()} + + # Query parameters may be nested, e.g. + # {'filter_by': {'user_ids': [123, 456]}} + # The HTTP-compatible representation of this is + # filter_by.user_ids=123&filter_by.user_ids=456 + # To achieve this, we convert the above dictionary to + # {'filter_by.user_ids': [123, 456]} + # See the following for more information: + # https://cloud.google.com/endpoints/docs/grpc-service-config/reference/rpc/google.api#google.api.HttpRule + def flatten_dict(d: Dict[str, Any]) -> Dict[str, Any]: + for k1, v1 in d.items(): + if isinstance(v1, dict): + v1 = dict(flatten_dict(v1)) + for k2, v2 in v1.items(): + yield f"{k1}.{k2}", v2 + else: + yield k1, v1 + + flattened = dict(flatten_dict(with_fixed_bools)) + return flattened + + def do(self, + method: str, + url: str, + query: dict = None, + headers: dict = None, + body: dict = None, + raw: bool = False, + files=None, + data=None, + auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None, + response_headers: List[str] = None) -> Union[dict, list, BinaryIO]: + if headers is None: + headers = {} + headers['User-Agent'] = self._user_agent_base + retryable = retried(timeout=timedelta(seconds=self._retry_timeout_seconds), + is_retryable=self._is_retryable, + clock=self._clock) + response = retryable(self._perform)(method, + url, + query=query, + headers=headers, + body=body, + raw=raw, + files=files, + data=data, + auth=auth) + + resp = dict() + for header in response_headers if response_headers else []: + resp[header] = response.headers.get(Casing.to_header_case(header)) + if raw: + resp["contents"] = _StreamingResponse(response) + return resp + if not len(response.content): + return resp + + json_response = response.json() + if json_response is None: + return resp + + if isinstance(json_response, list): + return json_response + + return {**resp, **json_response} + + @staticmethod + def _is_retryable(err: BaseException) -> Optional[str]: + # this method is Databricks-specific port of urllib3 retries + # (see https://github.com/urllib3/urllib3/blob/main/src/urllib3/util/retry.py) + # and Databricks SDK for Go retries + # (see https://github.com/databricks/databricks-sdk-go/blob/main/apierr/errors.go) + from urllib3.exceptions import ProxyError + if isinstance(err, ProxyError): + err = err.original_error + if isinstance(err, requests.ConnectionError): + # corresponds to `connection reset by peer` and `connection refused` errors from Go, + # which are generally related to the temporary glitches in the networking stack, + # also caused by endpoint protection software, like ZScaler, to drop connections while + # not yet authenticated. + # + # return a simple string for debug log readability, as `raise TimeoutError(...) from err` + # will bubble up the original exception in case we reach max retries. + return f'cannot connect' + if isinstance(err, requests.Timeout): + # corresponds to `TLS handshake timeout` and `i/o timeout` in Go. + # + # return a simple string for debug log readability, as `raise TimeoutError(...) from err` + # will bubble up the original exception in case we reach max retries. + return f'timeout' + if isinstance(err, DatabricksError): + message = str(err) + transient_error_string_matches = [ + "com.databricks.backend.manager.util.UnknownWorkerEnvironmentException", + "does not have any associated worker environments", "There is no worker environment with id", + "Unknown worker environment", "ClusterNotReadyException", "Unexpected error", + "Please try again later or try a faster operation.", + "RPC token bucket limit has been exceeded", + ] + for substring in transient_error_string_matches: + if substring not in message: + continue + return f'matched {substring}' + return None + + def _perform(self, + method: str, + url: str, + query: dict = None, + headers: dict = None, + body: dict = None, + raw: bool = False, + files=None, + data=None, + auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None): + response = self._session.request(method, + url, + params=self._fix_query_string(query), + json=body, + headers=headers, + files=files, + data=data, + auth=auth, + stream=raw, + timeout=self._http_timeout_seconds) + self._record_request_log(response, raw=raw or data is not None or files is not None) + error = self._error_parser.get_api_error(response) + if error is not None: + raise error from None + return response + + def _record_request_log(self, response: requests.Response, raw: bool = False) -> None: + if not logger.isEnabledFor(logging.DEBUG): + return + logger.debug(RoundTrip(response, self._debug_headers, self._debug_truncate_bytes, raw).generate()) + + +class _StreamingResponse(BinaryIO): + _response: requests.Response + _buffer: bytes + _content: Union[Iterator[bytes], None] + _chunk_size: Union[int, None] + _closed: bool = False + + def fileno(self) -> int: + pass + + def flush(self) -> int: + pass + + def __init__(self, response: requests.Response, chunk_size: Union[int, None] = None): + self._response = response + self._buffer = b'' + self._content = None + self._chunk_size = chunk_size + + def _open(self) -> None: + if self._closed: + raise ValueError("I/O operation on closed file") + if not self._content: + self._content = self._response.iter_content(chunk_size=self._chunk_size) + + def __enter__(self) -> BinaryIO: + self._open() + return self + + def set_chunk_size(self, chunk_size: Union[int, None]) -> None: + self._chunk_size = chunk_size + + def close(self) -> None: + self._response.close() + self._closed = True + + def isatty(self) -> bool: + return False + + def read(self, n: int = -1) -> bytes: + self._open() + read_everything = n < 0 + remaining_bytes = n + res = b'' + while remaining_bytes > 0 or read_everything: + if len(self._buffer) == 0: + try: + self._buffer = next(self._content) + except StopIteration: + break + bytes_available = len(self._buffer) + to_read = bytes_available if read_everything else min(remaining_bytes, bytes_available) + res += self._buffer[:to_read] + self._buffer = self._buffer[to_read:] + remaining_bytes -= to_read + return res + + def readable(self) -> bool: + return self._content is not None + + def readline(self, __limit: int = ...) -> bytes: + raise NotImplementedError() + + def readlines(self, __hint: int = ...) -> List[bytes]: + raise NotImplementedError() + + def seek(self, __offset: int, __whence: int = ...) -> int: + raise NotImplementedError() + + def seekable(self) -> bool: + return False + + def tell(self) -> int: + raise NotImplementedError() + + def truncate(self, __size: Union[int, None] = ...) -> int: + raise NotImplementedError() + + def writable(self) -> bool: + return False + + def write(self, s: Union[bytes, bytearray]) -> int: + raise NotImplementedError() + + def writelines(self, lines: Iterable[bytes]) -> None: + raise NotImplementedError() + + def __next__(self) -> bytes: + return self.read(1) + + def __iter__(self) -> Iterator[bytes]: + return self._content + + def __exit__(self, t: Union[Type[BaseException], None], value: Union[BaseException, None], + traceback: Union[TracebackType, None]) -> None: + self._content = None + self._buffer = b'' + self.close() diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index 5cae1b2b4..65bf3225e 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -14,7 +14,10 @@ from .credentials_provider import CredentialsStrategy, DefaultCredentials from .environments import (ALL_ENVS, AzureEnvironment, Cloud, DatabricksEnvironment, get_environment_for_hostname) -from .oauth import OidcEndpoints, Token +from .oauth import (OidcEndpoints, Token, get_account_endpoints, + get_azure_entra_id_workspace_endpoints, + get_workspace_endpoints) +from ._base_client import fix_host_if_needed logger = logging.getLogger('databricks.sdk') @@ -118,7 +121,9 @@ def __init__(self, self._set_inner_config(kwargs) self._load_from_env() self._known_file_config_loader() - self._fix_host_if_needed() + updated_host = fix_host_if_needed(self.host) + if updated_host: + self.host = updated_host self._validate() self.init_auth() self._init_product(product, product_version) @@ -250,28 +255,14 @@ def with_user_agent_extra(self, key: str, value: str) -> 'Config': @property def oidc_endpoints(self) -> Optional[OidcEndpoints]: - self._fix_host_if_needed() + self.host = fix_host_if_needed(self.host) if not self.host: return None if self.is_azure and self.azure_client_id: - # Retrieve authorize endpoint to retrieve token endpoint after - res = requests.get(f'{self.host}/oidc/oauth2/v2.0/authorize', allow_redirects=False) - real_auth_url = res.headers.get('location') - if not real_auth_url: - return None - return OidcEndpoints(authorization_endpoint=real_auth_url, - token_endpoint=real_auth_url.replace('/authorize', '/token')) + return get_azure_entra_id_workspace_endpoints(self.host) if self.is_account_client and self.account_id: - prefix = f'{self.host}/oidc/accounts/{self.account_id}' - return OidcEndpoints(authorization_endpoint=f'{prefix}/v1/authorize', - token_endpoint=f'{prefix}/v1/token') - oidc = f'{self.host}/oidc/.well-known/oauth-authorization-server' - res = requests.get(oidc) - if res.status_code != 200: - return None - auth_metadata = res.json() - return OidcEndpoints(authorization_endpoint=auth_metadata.get('authorization_endpoint'), - token_endpoint=auth_metadata.get('token_endpoint')) + return get_account_endpoints(self.host, self.account_id) + return get_workspace_endpoints(self.host) def debug_string(self) -> str: """ Returns log-friendly representation of configured attributes """ @@ -345,24 +336,6 @@ def attributes(cls) -> Iterable[ConfigAttribute]: cls._attributes = attrs return cls._attributes - def _fix_host_if_needed(self): - if not self.host: - return - - # Add a default scheme if it's missing - if '://' not in self.host: - self.host = 'https://' + self.host - - o = urllib.parse.urlparse(self.host) - # remove trailing slash - path = o.path.rstrip('/') - # remove port if 443 - netloc = o.netloc - if o.port == 443: - netloc = netloc.split(':')[0] - - self.host = urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment)) - def load_azure_tenant_id(self): """[Internal] Load the Azure tenant ID from the Azure Databricks login page. diff --git a/databricks/sdk/core.py b/databricks/sdk/core.py index 77e8c9aac..c9e49dc81 100644 --- a/databricks/sdk/core.py +++ b/databricks/sdk/core.py @@ -1,19 +1,13 @@ import re -from datetime import timedelta -from types import TracebackType -from typing import Any, BinaryIO, Iterator, Type +from typing import BinaryIO from urllib.parse import urlencode -from requests.adapters import HTTPAdapter - -from .casing import Casing +from ._base_client import _BaseClient from .config import * # To preserve backwards compatibility (as these definitions were previously in this module) from .credentials_provider import * -from .errors import DatabricksError, _ErrorCustomizer, _Parser -from .logger import RoundTrip +from .errors import DatabricksError, _ErrorCustomizer from .oauth import retrieve_token -from .retries import retried __all__ = ['Config', 'DatabricksError'] @@ -24,54 +18,21 @@ OIDC_TOKEN_PATH = "/oidc/v1/token" -class ApiClient: - _cfg: Config - _RETRY_AFTER_DEFAULT: int = 1 - - def __init__(self, cfg: Config = None): - if cfg is None: - cfg = Config() +class ApiClient: + def __init__(self, cfg: Config): self._cfg = cfg - # See https://github.com/databricks/databricks-sdk-go/blob/main/client/client.go#L34-L35 - self._debug_truncate_bytes = cfg.debug_truncate_bytes if cfg.debug_truncate_bytes else 96 - self._retry_timeout_seconds = cfg.retry_timeout_seconds if cfg.retry_timeout_seconds else 300 - self._user_agent_base = cfg.user_agent - self._session = requests.Session() - self._session.auth = self._authenticate - - # Number of urllib3 connection pools to cache before discarding the least - # recently used pool. Python requests default value is 10. - pool_connections = cfg.max_connection_pools - if pool_connections is None: - pool_connections = 20 - - # The maximum number of connections to save in the pool. Improves performance - # in multithreaded situations. For now, we're setting it to the same value - # as connection_pool_size. - pool_maxsize = cfg.max_connections_per_pool - if cfg.max_connections_per_pool is None: - pool_maxsize = pool_connections - - # If pool_block is False, then more connections will are created, - # but not saved after the first use. Blocks when no free connections are available. - # urllib3 ensures that no more than pool_maxsize connections are used at a time. - # Prevents platform from flooding. By default, requests library doesn't block. - pool_block = True - - # We don't use `max_retries` from HTTPAdapter to align with a more production-ready - # retry strategy established in the Databricks SDK for Go. See _is_retryable and - # @retried for more details. - http_adapter = HTTPAdapter(pool_connections=pool_connections, - pool_maxsize=pool_maxsize, - pool_block=pool_block) - self._session.mount("https://", http_adapter) - - # Default to 60 seconds - self._http_timeout_seconds = cfg.http_timeout_seconds if cfg.http_timeout_seconds else 60 - - self._error_parser = _Parser(extra_error_customizers=[_AddDebugErrorCustomizer(cfg)]) + self._api_client = _BaseClient(debug_truncate_bytes=cfg.debug_truncate_bytes, + retry_timeout_seconds=cfg.retry_timeout_seconds, + user_agent_base=cfg.user_agent, + header_factory=cfg.authenticate, + max_connection_pools=cfg.max_connection_pools, + max_connections_per_pool=cfg.max_connections_per_pool, + pool_block=True, + http_timeout_seconds=cfg.http_timeout_seconds, + extra_error_customizers=[_AddDebugErrorCustomizer(cfg)], + clock=cfg.clock) @property def account_id(self) -> str: @@ -81,40 +42,6 @@ def account_id(self) -> str: def is_account_client(self) -> bool: return self._cfg.is_account_client - def _authenticate(self, r: requests.PreparedRequest) -> requests.PreparedRequest: - headers = self._cfg.authenticate() - for k, v in headers.items(): - r.headers[k] = v - return r - - @staticmethod - def _fix_query_string(query: Optional[dict] = None) -> Optional[dict]: - # Convert True -> "true" for Databricks APIs to understand booleans. - # See: https://github.com/databricks/databricks-sdk-py/issues/142 - if query is None: - return None - with_fixed_bools = {k: v if type(v) != bool else ('true' if v else 'false') for k, v in query.items()} - - # Query parameters may be nested, e.g. - # {'filter_by': {'user_ids': [123, 456]}} - # The HTTP-compatible representation of this is - # filter_by.user_ids=123&filter_by.user_ids=456 - # To achieve this, we convert the above dictionary to - # {'filter_by.user_ids': [123, 456]} - # See the following for more information: - # https://cloud.google.com/endpoints/docs/grpc-service-config/reference/rpc/google.api#google.api.HttpRule - def flatten_dict(d: Dict[str, Any]) -> Dict[str, Any]: - for k1, v1 in d.items(): - if isinstance(v1, dict): - v1 = dict(flatten_dict(v1)) - for k2, v2 in v1.items(): - yield f"{k1}.{k2}", v2 - else: - yield k1, v1 - - flattened = dict(flatten_dict(with_fixed_bools)) - return flattened - def get_oauth_token(self, auth_details: str) -> Token: if not self._cfg.auth_type: self._cfg.authenticate() @@ -142,115 +69,22 @@ def do(self, files=None, data=None, auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None, - response_headers: List[str] = None) -> Union[dict, BinaryIO]: - if headers is None: - headers = {} + response_headers: List[str] = None) -> Union[dict, list, BinaryIO]: if url is None: # Remove extra `/` from path for Files API # Once we've fixed the OpenAPI spec, we can remove this path = re.sub('^/api/2.0/fs/files//', '/api/2.0/fs/files/', path) url = f"{self._cfg.host}{path}" - headers['User-Agent'] = self._user_agent_base - retryable = retried(timeout=timedelta(seconds=self._retry_timeout_seconds), - is_retryable=self._is_retryable, - clock=self._cfg.clock) - response = retryable(self._perform)(method, - url, - query=query, - headers=headers, - body=body, - raw=raw, - files=files, - data=data, - auth=auth) - - resp = dict() - for header in response_headers if response_headers else []: - resp[header] = response.headers.get(Casing.to_header_case(header)) - if raw: - resp["contents"] = StreamingResponse(response) - return resp - if not len(response.content): - return resp - - jsonResponse = response.json() - if jsonResponse is None: - return resp - - if isinstance(jsonResponse, list): - return jsonResponse - - return {**resp, **jsonResponse} - - @staticmethod - def _is_retryable(err: BaseException) -> Optional[str]: - # this method is Databricks-specific port of urllib3 retries - # (see https://github.com/urllib3/urllib3/blob/main/src/urllib3/util/retry.py) - # and Databricks SDK for Go retries - # (see https://github.com/databricks/databricks-sdk-go/blob/main/apierr/errors.go) - from urllib3.exceptions import ProxyError - if isinstance(err, ProxyError): - err = err.original_error - if isinstance(err, requests.ConnectionError): - # corresponds to `connection reset by peer` and `connection refused` errors from Go, - # which are generally related to the temporary glitches in the networking stack, - # also caused by endpoint protection software, like ZScaler, to drop connections while - # not yet authenticated. - # - # return a simple string for debug log readability, as `raise TimeoutError(...) from err` - # will bubble up the original exception in case we reach max retries. - return f'cannot connect' - if isinstance(err, requests.Timeout): - # corresponds to `TLS handshake timeout` and `i/o timeout` in Go. - # - # return a simple string for debug log readability, as `raise TimeoutError(...) from err` - # will bubble up the original exception in case we reach max retries. - return f'timeout' - if isinstance(err, DatabricksError): - message = str(err) - transient_error_string_matches = [ - "com.databricks.backend.manager.util.UnknownWorkerEnvironmentException", - "does not have any associated worker environments", "There is no worker environment with id", - "Unknown worker environment", "ClusterNotReadyException", "Unexpected error", - "Please try again later or try a faster operation.", - "RPC token bucket limit has been exceeded", - ] - for substring in transient_error_string_matches: - if substring not in message: - continue - return f'matched {substring}' - return None - - def _perform(self, - method: str, - url: str, - query: dict = None, - headers: dict = None, - body: dict = None, - raw: bool = False, - files=None, - data=None, - auth: Callable[[requests.PreparedRequest], requests.PreparedRequest] = None): - response = self._session.request(method, - url, - params=self._fix_query_string(query), - json=body, - headers=headers, - files=files, - data=data, - auth=auth, - stream=raw, - timeout=self._http_timeout_seconds) - self._record_request_log(response, raw=raw or data is not None or files is not None) - error = self._error_parser.get_api_error(response) - if error is not None: - raise error from None - return response - - def _record_request_log(self, response: requests.Response, raw: bool = False) -> None: - if not logger.isEnabledFor(logging.DEBUG): - return - logger.debug(RoundTrip(response, self._cfg.debug_headers, self._debug_truncate_bytes, raw).generate()) + return self._api_client.do(method=method, + url=url, + query=query, + headers=headers, + body=body, + raw=raw, + files=files, + data=data, + auth=auth, + response_headers=response_headers) class _AddDebugErrorCustomizer(_ErrorCustomizer): @@ -264,103 +98,3 @@ def customize_error(self, response: requests.Response, kwargs: dict): if response.status_code in (401, 403): message = kwargs.get('message', 'request failed') kwargs['message'] = self._cfg.wrap_debug_info(message) - - -class StreamingResponse(BinaryIO): - _response: requests.Response - _buffer: bytes - _content: Union[Iterator[bytes], None] - _chunk_size: Union[int, None] - _closed: bool = False - - def fileno(self) -> int: - pass - - def flush(self) -> int: - pass - - def __init__(self, response: requests.Response, chunk_size: Union[int, None] = None): - self._response = response - self._buffer = b'' - self._content = None - self._chunk_size = chunk_size - - def _open(self) -> None: - if self._closed: - raise ValueError("I/O operation on closed file") - if not self._content: - self._content = self._response.iter_content(chunk_size=self._chunk_size) - - def __enter__(self) -> BinaryIO: - self._open() - return self - - def set_chunk_size(self, chunk_size: Union[int, None]) -> None: - self._chunk_size = chunk_size - - def close(self) -> None: - self._response.close() - self._closed = True - - def isatty(self) -> bool: - return False - - def read(self, n: int = -1) -> bytes: - self._open() - read_everything = n < 0 - remaining_bytes = n - res = b'' - while remaining_bytes > 0 or read_everything: - if len(self._buffer) == 0: - try: - self._buffer = next(self._content) - except StopIteration: - break - bytes_available = len(self._buffer) - to_read = bytes_available if read_everything else min(remaining_bytes, bytes_available) - res += self._buffer[:to_read] - self._buffer = self._buffer[to_read:] - remaining_bytes -= to_read - return res - - def readable(self) -> bool: - return self._content is not None - - def readline(self, __limit: int = ...) -> bytes: - raise NotImplementedError() - - def readlines(self, __hint: int = ...) -> List[bytes]: - raise NotImplementedError() - - def seek(self, __offset: int, __whence: int = ...) -> int: - raise NotImplementedError() - - def seekable(self) -> bool: - return False - - def tell(self) -> int: - raise NotImplementedError() - - def truncate(self, __size: Union[int, None] = ...) -> int: - raise NotImplementedError() - - def writable(self) -> bool: - return False - - def write(self, s: Union[bytes, bytearray]) -> int: - raise NotImplementedError() - - def writelines(self, lines: Iterable[bytes]) -> None: - raise NotImplementedError() - - def __next__(self) -> bytes: - return self.read(1) - - def __iter__(self) -> Iterator[bytes]: - return self._content - - def __exit__(self, t: Union[Type[BaseException], None], value: Union[BaseException, None], - traceback: Union[TracebackType, None]) -> None: - self._content = None - self._buffer = b'' - self.close() diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 8c1655af1..ef4f48adc 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -197,19 +197,24 @@ def external_browser(cfg: 'Config') -> Optional[CredentialsProvider]: client_id = '6128a518-99a9-425b-8333-4cc94f04cacd' else: raise ValueError(f'local browser SSO is not supported') - oauth_client = OAuthClient(host=cfg.host, - client_id=client_id, - redirect_url='http://localhost:8020', - client_secret=cfg.client_secret) # Load cached credentials from disk if they exist. # Note that these are local to the Python SDK and not reused by other SDKs. - token_cache = TokenCache(oauth_client) + oidc_endpoints = cfg.oidc_endpoints + token_cache = TokenCache(host=cfg.host, + oidc_endpoints=oidc_endpoints, + client_id=client_id, + client_secret=cfg.client_secret, + redirect_url='http://localhost:8020') credentials = token_cache.load() if credentials: # Force a refresh in case the loaded credentials are expired. credentials.token() else: + oauth_client = OAuthClient(oidc_endpoints=oidc_endpoints, + client_id=client_id, + redirect_url='http://localhost:8020', + client_secret=cfg.client_secret) consent = oauth_client.initiate_consent() if not consent: return None diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index e9a3afb90..c4279dea0 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -17,6 +17,8 @@ import requests import requests.auth +from ._base_client import _BaseClient, fix_host_if_needed + # Error code for PKCE flow in Azure Active Directory, that gets additional retry. # See https://stackoverflow.com/a/75466778/277035 for more info NO_ORIGIN_FOR_SPA_CLIENT_ERROR = 'AADSTS9002327' @@ -46,8 +48,24 @@ def __call__(self, r): @dataclass class OidcEndpoints: + """ + The endpoints used for OAuth-based authentication in Databricks. + """ + authorization_endpoint: str # ../v1/authorize + """The authorization endpoint for the OAuth flow. The user-agent should be directed to this endpoint in order for + the user to login and authorize the client for user-to-machine (U2M) flows.""" + token_endpoint: str # ../v1/token + """The token endpoint for the OAuth flow.""" + + @staticmethod + def from_dict(d: dict) -> 'OidcEndpoints': + return OidcEndpoints(authorization_endpoint=d.get('authorization_endpoint'), + token_endpoint=d.get('token_endpoint')) + + def as_dict(self) -> dict: + return {'authorization_endpoint': self.authorization_endpoint, 'token_endpoint': self.token_endpoint} @dataclass @@ -220,18 +238,76 @@ def do_GET(self): self.wfile.write(b'You can close this tab.') +def get_account_endpoints(host: str, account_id: str, client: _BaseClient = _BaseClient()) -> OidcEndpoints: + """ + Get the OIDC endpoints for a given account. + :param host: The Databricks account host. + :param account_id: The account ID. + :return: The account's OIDC endpoints. + """ + host = fix_host_if_needed(host) + oidc = f'{host}/oidc/accounts/{account_id}/.well-known/oauth-authorization-server' + resp = client.do('GET', oidc) + return OidcEndpoints.from_dict(resp) + + +def get_workspace_endpoints(host: str, client: _BaseClient = _BaseClient()) -> OidcEndpoints: + """ + Get the OIDC endpoints for a given workspace. + :param host: The Databricks workspace host. + :return: The workspace's OIDC endpoints. + """ + host = fix_host_if_needed(host) + oidc = f'{host}/.well-known/oauth-authorization-server' + resp = client.do('GET', oidc) + return OidcEndpoints.from_dict(resp) + + +def get_azure_entra_id_workspace_endpoints(host: str) -> Optional[OidcEndpoints]: + """ + Get the Azure Entra ID endpoints for a given workspace. Can only be used when authenticating to Azure Databricks + using an application registered in Azure Entra ID. + :param host: The Databricks workspace host. + :return: The OIDC endpoints for the workspace's Azure Entra ID tenant. + """ + # In Azure, this workspace endpoint redirects to the Entra ID authorization endpoint + host = fix_host_if_needed(host) + res = requests.get(f'{host}/oidc/oauth2/v2.0/authorize', allow_redirects=False) + real_auth_url = res.headers.get('location') + if not real_auth_url: + return None + return OidcEndpoints(authorization_endpoint=real_auth_url, + token_endpoint=real_auth_url.replace('/authorize', '/token')) + + class SessionCredentials(Refreshable): - def __init__(self, client: 'OAuthClient', token: Token): - self._client = client + def __init__(self, + token: Token, + oidc_endpoints: OidcEndpoints, + client_id: str, + client_secret: str = None, + redirect_url: str = None): + self._oidc_endpoints = oidc_endpoints + self._client_id = client_id + self._client_secret = client_secret + self._redirect_url = redirect_url super().__init__(token) def as_dict(self) -> dict: return {'token': self._token.as_dict()} @staticmethod - def from_dict(client: 'OAuthClient', raw: dict) -> 'SessionCredentials': - return SessionCredentials(client=client, token=Token.from_dict(raw['token'])) + def from_dict(raw: dict, + oidc_endpoints: OidcEndpoints, + client_id: str, + client_secret: str = None, + redirect_url: str = None) -> 'SessionCredentials': + return SessionCredentials(token=Token.from_dict(raw['token']), + oidc_endpoints=oidc_endpoints, + client_id=client_id, + client_secret=client_secret, + redirect_url=redirect_url) def auth_type(self): """Implementing CredentialsProvider protocol""" @@ -252,13 +328,13 @@ def refresh(self) -> Token: raise ValueError('oauth2: token expired and refresh token is not set') params = {'grant_type': 'refresh_token', 'refresh_token': refresh_token} headers = {} - if 'microsoft' in self._client.token_url: + if 'microsoft' in self._oidc_endpoints.token_endpoint: # Tokens issued for the 'Single-Page Application' client-type may # only be redeemed via cross-origin requests - headers = {'Origin': self._client.redirect_url} - return retrieve_token(client_id=self._client.client_id, - client_secret=self._client.client_secret, - token_url=self._client.token_url, + headers = {'Origin': self._redirect_url} + return retrieve_token(client_id=self._client_id, + client_secret=self._client_secret, + token_url=self._oidc_endpoints.token_endpoint, params=params, use_params=True, headers=headers) @@ -266,27 +342,45 @@ def refresh(self) -> Token: class Consent: - def __init__(self, client: 'OAuthClient', state: str, verifier: str, auth_url: str = None) -> None: - self.auth_url = auth_url - + def __init__(self, + state: str, + verifier: str, + oidc_endpoints: OidcEndpoints, + redirect_url: str, + client_id: str, + client_secret: str = None) -> None: self._verifier = verifier self._state = state - self._client = client + self._oidc_endpoints = oidc_endpoints + self._redirect_url = redirect_url + self._client_id = client_id + self._client_secret = client_secret def as_dict(self) -> dict: - return {'state': self._state, 'verifier': self._verifier} + return { + 'state': self._state, + 'verifier': self._verifier, + 'redirect_url': self._redirect_url, + 'oidc_endpoints': self._oidc_endpoints.as_dict(), + 'client_id': self._client_id, + } @staticmethod - def from_dict(client: 'OAuthClient', raw: dict) -> 'Consent': - return Consent(client, raw['state'], raw['verifier']) + def from_dict(raw: dict, client_secret: str = None) -> 'Consent': + return Consent(raw['state'], + raw['verifier'], + oidc_endpoints=OidcEndpoints.from_dict(raw['oidc_endpoints']), + redirect_url=raw['redirect_url'], + client_id=raw['client_id'], + client_secret=client_secret) def launch_external_browser(self) -> SessionCredentials: - redirect_url = urllib.parse.urlparse(self._client.redirect_url) + redirect_url = urllib.parse.urlparse(self._redirect_url) if redirect_url.hostname not in ('localhost', '127.0.0.1'): raise ValueError(f'cannot listen on {redirect_url.hostname}') feedback = [] - logger.info(f'Opening {self.auth_url} in a browser') - webbrowser.open_new(self.auth_url) + logger.info(f'Opening {self._oidc_endpoints.authorization_endpoint} in a browser') + webbrowser.open_new(self._oidc_endpoints.authorization_endpoint) port = redirect_url.port handler_factory = functools.partial(_OAuthCallback, feedback) with HTTPServer(("localhost", port), handler_factory) as httpd: @@ -308,7 +402,7 @@ def exchange(self, code: str, state: str) -> SessionCredentials: if self._state != state: raise ValueError('state mismatch') params = { - 'redirect_uri': self._client.redirect_url, + 'redirect_uri': self._redirect_url, 'grant_type': 'authorization_code', 'code_verifier': self._verifier, 'code': code @@ -316,19 +410,20 @@ def exchange(self, code: str, state: str) -> SessionCredentials: headers = {} while True: try: - token = retrieve_token(client_id=self._client.client_id, - client_secret=self._client.client_secret, - token_url=self._client.token_url, + token = retrieve_token(client_id=self._client_id, + client_secret=self._client_secret, + token_url=self._oidc_endpoints.token_endpoint, params=params, headers=headers, use_params=True) - return SessionCredentials(self._client, token) + return SessionCredentials(token, self._oidc_endpoints, self._client_id, self._client_secret, + self._redirect_url) except ValueError as e: if NO_ORIGIN_FOR_SPA_CLIENT_ERROR in str(e): # Retry in cases of 'Single-Page Application' client-type with # 'Origin' header equal to client's redirect URL. - headers['Origin'] = self._client.redirect_url - msg = f'Retrying OAuth token exchange with {self._client.redirect_url} origin' + headers['Origin'] = self._redirect_url + msg = f'Retrying OAuth token exchange with {self._redirect_url} origin' logger.debug(msg) continue raise e @@ -354,37 +449,19 @@ class OAuthClient: """ def __init__(self, - host: str, - client_id: str, + oidc_endpoints: OidcEndpoints, redirect_url: str, - *, + client_id: str, scopes: List[str] = None, client_secret: str = None): - # TODO: is it a circular dependency?.. - from .core import Config - from .credentials_provider import credentials_strategy - @credentials_strategy('noop', []) - def noop_credentials(_: any): - return lambda: {} - - config = Config(host=host, credentials_strategy=noop_credentials) if not scopes: scopes = ['all-apis'] - oidc = config.oidc_endpoints - if not oidc: - raise ValueError(f'{host} does not support OAuth') - self.host = host self.redirect_url = redirect_url - self.client_id = client_id - self.client_secret = client_secret - self.token_url = oidc.token_endpoint - self.is_aws = config.is_aws - self.is_azure = config.is_azure - self.is_gcp = config.is_gcp - - self._auth_url = oidc.authorization_endpoint + self._client_id = client_id + self._client_secret = client_secret + self._oidc_endpoints = oidc_endpoints self._scopes = scopes def initiate_consent(self) -> Consent: @@ -397,18 +474,23 @@ def initiate_consent(self) -> Consent: params = { 'response_type': 'code', - 'client_id': self.client_id, + 'client_id': self._client_id, 'redirect_uri': self.redirect_url, 'scope': ' '.join(self._scopes), 'state': state, 'code_challenge': challenge, 'code_challenge_method': 'S256' } - url = f'{self._auth_url}?{urllib.parse.urlencode(params)}' - return Consent(self, state, verifier, auth_url=url) + f'{self._oidc_endpoints.authorization_endpoint}?{urllib.parse.urlencode(params)}' + return Consent(state, + verifier, + oidc_endpoints=self._oidc_endpoints, + redirect_url=self.redirect_url, + client_id=self._client_id, + client_secret=self._client_secret) def __repr__(self) -> str: - return f'' + return f'' @dataclass @@ -448,17 +530,28 @@ def refresh(self) -> Token: use_header=self.use_header) -class TokenCache(): +class TokenCache: BASE_PATH = "~/.config/databricks-sdk-py/oauth" - def __init__(self, client: OAuthClient) -> None: - self.client = client + def __init__(self, + host: str, + oidc_endpoints: OidcEndpoints, + client_id: str, + redirect_url: str = None, + client_secret: str = None, + scopes: list[str] = None) -> None: + self._host = host + self._client_id = client_id + self._oidc_endpoints = oidc_endpoints + self._redirect_url = redirect_url + self._client_secret = client_secret + self._scopes = scopes or [] @property def filename(self) -> str: # Include host, client_id, and scopes in the cache filename to make it unique. hash = hashlib.sha256() - for chunk in [self.client.host, self.client.client_id, ",".join(self.client._scopes), ]: + for chunk in [self._host, self._client_id, ",".join(self._scopes), ]: hash.update(chunk.encode('utf-8')) return os.path.expanduser(os.path.join(self.__class__.BASE_PATH, hash.hexdigest() + ".json")) @@ -472,7 +565,11 @@ def load(self) -> Optional[SessionCredentials]: try: with open(self.filename, 'r') as f: raw = json.load(f) - return SessionCredentials.from_dict(self.client, raw) + return SessionCredentials.from_dict(raw, + oidc_endpoints=self._oidc_endpoints, + client_id=self._client_id, + client_secret=self._client_secret, + redirect_url=self._redirect_url) except Exception: return None diff --git a/examples/flask_app_with_oauth.py b/examples/flask_app_with_oauth.py index 4128de5ca..bd4d7d7e5 100755 --- a/examples/flask_app_with_oauth.py +++ b/examples/flask_app_with_oauth.py @@ -31,7 +31,7 @@ import logging import sys -from databricks.sdk.oauth import OAuthClient +from databricks.sdk.oauth import OAuthClient, OidcEndpoints, get_workspace_endpoints APP_NAME = "flask-demo" all_clusters_template = """
    @@ -44,7 +44,7 @@
""" -def create_flask_app(oauth_client: OAuthClient): +def create_flask_app(host: str, oidc_endpoints: OidcEndpoints, client_id: str, client_secret: str, redirect_url: str): """The create_flask_app function creates a Flask app that is enabled with OAuth. It initializes the app and web session secret keys with a randomly generated token. It defines two routes for @@ -64,7 +64,7 @@ def callback(): the callback parameters, and redirects the user to the index page.""" from databricks.sdk.oauth import Consent - consent = Consent.from_dict(oauth_client, session["consent"]) + consent = Consent.from_dict(session["consent"], client_secret=client_secret) session["creds"] = consent.exchange_callback_parameters(request.args).as_dict() return redirect(url_for("index")) @@ -73,17 +73,25 @@ def index(): """The index page checks if the user has already authenticated and retrieves the user's credentials using the Databricks SDK WorkspaceClient. It then renders the template with the clusters' list.""" if "creds" not in session: + oauth_client = OAuthClient(oidc_endpoints=oidc_endpoints, + client_id=client_id, + client_secret=client_secret, + redirect_url=redirect_url) consent = oauth_client.initiate_consent() session["consent"] = consent.as_dict() - return redirect(consent.auth_url) + return redirect(oidc_endpoints.authorization_endpoint) from databricks.sdk import WorkspaceClient from databricks.sdk.oauth import SessionCredentials - credentials_provider = SessionCredentials.from_dict(oauth_client, session["creds"]) - workspace_client = WorkspaceClient(host=oauth_client.host, + credentials_strategy = SessionCredentials.from_dict(session["creds"], + oidc_endpoints=oidc_endpoints, + client_id=client_id, + client_secret=client_secret, + redirect_url=redirect_url) + workspace_client = WorkspaceClient(host=host, product=APP_NAME, - credentials_provider=credentials_provider, + credentials_strategy=credentials_strategy, ) return render_template_string(all_clusters_template, w=workspace_client) @@ -110,22 +118,6 @@ def register_custom_app(args: argparse.Namespace) -> tuple[str, str]: return custom_app.client_id, custom_app.client_secret -def init_oauth_config(args) -> OAuthClient: - """Creates Databricks SDK configuration for OAuth""" - oauth_client = OAuthClient(host=args.host, - client_id=args.client_id, - client_secret=args.client_secret, - redirect_url=f"http://localhost:{args.port}/callback", - scopes=["all-apis"], - ) - if not oauth_client.client_id: - client_id, client_secret = register_custom_app(args) - oauth_client.client_id = client_id - oauth_client.client_secret = client_secret - - return oauth_client - - def parse_arguments() -> argparse.Namespace: """Parses arguments for this demo""" parser = argparse.ArgumentParser(prog=APP_NAME, description=__doc__.strip()) @@ -145,8 +137,12 @@ def parse_arguments() -> argparse.Namespace: logging.getLogger("databricks.sdk").setLevel(logging.DEBUG) args = parse_arguments() - oauth_cfg = init_oauth_config(args) - app = create_flask_app(oauth_cfg) + oidc_endpoints = get_workspace_endpoints(args.host) + client_id, client_secret = args.client_id, args.client_secret + if not client_id: + client_id, client_secret = register_custom_app(args) + redirect_url=f"http://localhost:{args.port}/callback" + app = create_flask_app(args.host, oidc_endpoints, client_id, client_secret, redirect_url) app.run( host="localhost", diff --git a/tests/fixture_server.py b/tests/fixture_server.py new file mode 100644 index 000000000..041904144 --- /dev/null +++ b/tests/fixture_server.py @@ -0,0 +1,33 @@ +import contextlib +import functools +import typing +from http.server import BaseHTTPRequestHandler + + +@contextlib.contextmanager +def http_fixture_server(handler: typing.Callable[[BaseHTTPRequestHandler], None]): + from http.server import HTTPServer + from threading import Thread + + class _handler(BaseHTTPRequestHandler): + + def __init__(self, handler: typing.Callable[[BaseHTTPRequestHandler], None], *args): + self._handler = handler + super().__init__(*args) + + def __getattr__(self, item): + if 'do_' != item[0:3]: + raise AttributeError(f'method {item} not found') + return functools.partial(self._handler, self) + + handler_factory = functools.partial(_handler, handler) + srv = HTTPServer(('localhost', 0), handler_factory) + t = Thread(target=srv.serve_forever) + try: + t.daemon = True + t.start() + yield 'http://{0}:{1}'.format(*srv.server_address) + finally: + srv.shutdown() + + diff --git a/tests/test_base_client.py b/tests/test_base_client.py new file mode 100644 index 000000000..4cba10dbe --- /dev/null +++ b/tests/test_base_client.py @@ -0,0 +1,282 @@ +from http.server import BaseHTTPRequestHandler +from typing import List, Iterator + +import pytest +import requests + +from databricks.sdk._base_client import _BaseClient, _StreamingResponse +from databricks.sdk import errors, useragent +from databricks.sdk.core import DatabricksError + +from .clock import FakeClock +from .fixture_server import http_fixture_server + + +class DummyResponse(requests.Response): + _content: Iterator[bytes] + _closed: bool = False + + def __init__(self, content: List[bytes]) -> None: + super().__init__() + self._content = iter(content) + + def iter_content(self, chunk_size: int = 1, decode_unicode=False) -> Iterator[bytes]: + return self._content + + def close(self): + self._closed = True + + def isClosed(self): + return self._closed + + +def test_streaming_response_read(config): + content = b"some initial binary data: \x00\x01" + response = _StreamingResponse(DummyResponse([content])) + assert response.read() == content + + +def test_streaming_response_read_partial(config): + content = b"some initial binary data: \x00\x01" + response = _StreamingResponse(DummyResponse([content])) + assert response.read(8) == b"some ini" + + +def test_streaming_response_read_full(config): + content = b"some initial binary data: \x00\x01" + response = _StreamingResponse(DummyResponse([content, content])) + assert response.read() == content + content + + +def test_streaming_response_read_closes(config): + content = b"some initial binary data: \x00\x01" + dummy_response = DummyResponse([content]) + with _StreamingResponse(dummy_response) as response: + assert response.read() == content + assert dummy_response.isClosed() + + +@pytest.mark.parametrize('status_code,headers,body,expected_error', [ + (400, {}, { + "message": + "errorMessage", + "details": [{ + "type": DatabricksError._error_info_type, + "reason": "error reason", + "domain": "error domain", + "metadata": { + "etag": "error etag" + }, + }, { + "type": "wrong type", + "reason": "wrong reason", + "domain": "wrong domain", + "metadata": { + "etag": "wrong etag" + } + }], + }, + errors.BadRequest('errorMessage', + details=[{ + 'type': DatabricksError._error_info_type, + 'reason': 'error reason', + 'domain': 'error domain', + 'metadata': { + 'etag': 'error etag' + }, + }])), + (401, {}, { + 'error_code': 'UNAUTHORIZED', + 'message': 'errorMessage', + }, + errors.Unauthenticated('errorMessage', error_code='UNAUTHORIZED')), + (403, {}, { + 'error_code': 'FORBIDDEN', + 'message': 'errorMessage', + }, + errors.PermissionDenied('errorMessage', error_code='FORBIDDEN')), + (429, {}, { + 'error_code': 'TOO_MANY_REQUESTS', + 'message': 'errorMessage', + }, errors.TooManyRequests('errorMessage', error_code='TOO_MANY_REQUESTS', retry_after_secs=1)), + (429, { + 'Retry-After': '100' + }, { + 'error_code': 'TOO_MANY_REQUESTS', + 'message': 'errorMessage', + }, errors.TooManyRequests('errorMessage', error_code='TOO_MANY_REQUESTS', retry_after_secs=100)), + (503, {}, { + 'error_code': 'TEMPORARILY_UNAVAILABLE', + 'message': 'errorMessage', + }, errors.TemporarilyUnavailable('errorMessage', error_code='TEMPORARILY_UNAVAILABLE', + retry_after_secs=1)), + (503, { + 'Retry-After': '100' + }, { + 'error_code': 'TEMPORARILY_UNAVAILABLE', + 'message': 'errorMessage', + }, + errors.TemporarilyUnavailable('errorMessage', error_code='TEMPORARILY_UNAVAILABLE', + retry_after_secs=100)), + (404, {}, { + 'scimType': 'scim type', + 'detail': 'detail', + 'status': 'status', + }, errors.NotFound('scim type detail', error_code='SCIM_status')), +]) +def test_error(requests_mock, status_code, headers, body, expected_error): + client = _BaseClient(clock=FakeClock()) + requests_mock.get("/test", json=body, status_code=status_code, headers=headers) + with pytest.raises(DatabricksError) as raised: + client._perform("GET", "https://localhost/test", headers={"test": "test"}) + actual = raised.value + assert isinstance(actual, type(expected_error)) + assert str(actual) == str(expected_error) + assert actual.error_code == expected_error.error_code + assert actual.retry_after_secs == expected_error.retry_after_secs + expected_error_infos, actual_error_infos = expected_error.get_error_info(), actual.get_error_info() + assert len(expected_error_infos) == len(actual_error_infos) + for expected, actual in zip(expected_error_infos, actual_error_infos): + assert expected.type == actual.type + assert expected.reason == actual.reason + assert expected.domain == actual.domain + assert expected.metadata == actual.metadata + + +def test_api_client_do_custom_headers(requests_mock): + client = _BaseClient() + requests_mock.get("/test", + json={"well": "done"}, + request_headers={ + "test": "test", + "User-Agent": useragent.to_string() + }) + res = client.do("GET", "https://localhost/test", headers={"test": "test"}) + assert res == {"well": "done"} + + +@pytest.mark.parametrize('status_code,include_retry_after', + ((429, False), (429, True), (503, False), (503, True))) +def test_http_retry_after(status_code, include_retry_after): + requests = [] + + def inner(h: BaseHTTPRequestHandler): + if len(requests) == 0: + h.send_response(status_code) + if include_retry_after: + h.send_header('Retry-After', '1') + h.send_header('Content-Type', 'application/json') + h.end_headers() + else: + h.send_response(200) + h.send_header('Content-Type', 'application/json') + h.end_headers() + h.wfile.write(b'{"foo": 1}') + requests.append(h.requestline) + + with http_fixture_server(inner) as host: + api_client = _BaseClient(clock=FakeClock()) + res = api_client.do('GET', f'{host}/foo') + assert 'foo' in res + + assert len(requests) == 2 + + +def test_http_retry_after_wrong_format(): + requests = [] + + def inner(h: BaseHTTPRequestHandler): + if len(requests) == 0: + h.send_response(429) + h.send_header('Retry-After', '1.58') + h.end_headers() + else: + h.send_response(200) + h.send_header('Content-Type', 'application/json') + h.end_headers() + h.wfile.write(b'{"foo": 1}') + requests.append(h.requestline) + + with http_fixture_server(inner) as host: + api_client = _BaseClient(clock=FakeClock()) + res = api_client.do('GET', f'{host}/foo') + assert 'foo' in res + + assert len(requests) == 2 + + +def test_http_retried_exceed_limit(): + requests = [] + + def inner(h: BaseHTTPRequestHandler): + h.send_response(429) + h.send_header('Retry-After', '1') + h.end_headers() + requests.append(h.requestline) + + with http_fixture_server(inner) as host: + api_client = _BaseClient(retry_timeout_seconds=1, clock=FakeClock()) + with pytest.raises(TimeoutError): + res = api_client.do('GET', f'{host}/foo') + + assert len(requests) == 1 + + +def test_http_retried_on_match(): + requests = [] + + def inner(h: BaseHTTPRequestHandler): + if len(requests) == 0: + h.send_response(400) + h.end_headers() + h.wfile.write(b'{"error_code": "abc", "message": "... ClusterNotReadyException ..."}') + else: + h.send_response(200) + h.end_headers() + h.wfile.write(b'{"foo": 1}') + requests.append(h.requestline) + + with http_fixture_server(inner) as host: + api_client = _BaseClient(clock=FakeClock()) + res = api_client.do('GET', f'{host}/foo') + assert 'foo' in res + + assert len(requests) == 2 + + +def test_http_not_retried_on_normal_errors(): + requests = [] + + def inner(h: BaseHTTPRequestHandler): + if len(requests) == 0: + h.send_response(400) + h.end_headers() + h.wfile.write(b'{"error_code": "abc", "message": "something not found"}') + requests.append(h.requestline) + + with http_fixture_server(inner) as host: + api_client = _BaseClient(clock=FakeClock()) + with pytest.raises(DatabricksError): + api_client.do('GET', f'{host}/foo') + + assert len(requests) == 1 + + +def test_http_retried_on_connection_error(): + requests = [] + + def inner(h: BaseHTTPRequestHandler): + if len(requests) > 0: + h.send_response(200) + h.end_headers() + h.wfile.write(b'{"foo": 1}') + requests.append(h.requestline) + + with http_fixture_server(inner) as host: + api_client = _BaseClient(clock=FakeClock()) + res = api_client.do('GET', f'{host}/foo') + assert 'foo' in res + + assert len(requests) == 2 + + diff --git a/tests/test_core.py b/tests/test_core.py index d54563d4e..b61cfa015 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -8,14 +8,11 @@ import typing from datetime import datetime from http.server import BaseHTTPRequestHandler -from typing import Iterator, List import pytest -import requests from databricks.sdk import WorkspaceClient, errors -from databricks.sdk.core import (ApiClient, Config, DatabricksError, - StreamingResponse) +from databricks.sdk.core import ApiClient, Config, DatabricksError from databricks.sdk.credentials_provider import (CliTokenSource, CredentialsProvider, CredentialsStrategy, @@ -28,7 +25,7 @@ from databricks.sdk.service.iam import AccessControlRequest from databricks.sdk.version import __version__ -from .clock import FakeClock +from .fixture_server import http_fixture_server from .conftest import noop_credentials @@ -80,32 +77,6 @@ def write_small_dummy_executable(path: pathlib.Path): return cli -def test_streaming_response_read(config): - content = b"some initial binary data: \x00\x01" - response = StreamingResponse(DummyResponse([content])) - assert response.read() == content - - -def test_streaming_response_read_partial(config): - content = b"some initial binary data: \x00\x01" - response = StreamingResponse(DummyResponse([content])) - assert response.read(8) == b"some ini" - - -def test_streaming_response_read_full(config): - content = b"some initial binary data: \x00\x01" - response = StreamingResponse(DummyResponse([content, content])) - assert response.read() == content + content - - -def test_streaming_response_read_closes(config): - content = b"some initial binary data: \x00\x01" - dummy_response = DummyResponse([content]) - with StreamingResponse(dummy_response) as response: - assert response.read() == content - assert dummy_response.isClosed() - - def write_large_dummy_executable(path: pathlib.Path): cli = path.joinpath('databricks') @@ -290,36 +261,6 @@ def test_config_parsing_non_string_env_vars(monkeypatch): assert c.debug_truncate_bytes == 100 -class DummyResponse(requests.Response): - _content: Iterator[bytes] - _closed: bool = False - - def __init__(self, content: List[bytes]) -> None: - super().__init__() - self._content = iter(content) - - def iter_content(self, chunk_size: int = 1, decode_unicode=False) -> Iterator[bytes]: - return self._content - - def close(self): - self._closed = True - - def isClosed(self): - return self._closed - - -def test_api_client_do_custom_headers(config, requests_mock): - client = ApiClient(config) - requests_mock.get("/test", - json={"well": "done"}, - request_headers={ - "test": "test", - "User-Agent": config.user_agent - }) - res = client.do("GET", "/test", headers={"test": "test"}) - assert res == {"well": "done"} - - def test_access_control_list(config, requests_mock): requests_mock.post("http://localhost/api/2.1/jobs/create", request_headers={"User-Agent": config.user_agent}) @@ -360,80 +301,22 @@ def test_deletes(config, requests_mock): @pytest.mark.parametrize('status_code,headers,body,expected_error', [ - (400, {}, { - "message": - "errorMessage", - "details": [{ - "type": DatabricksError._error_info_type, - "reason": "error reason", - "domain": "error domain", - "metadata": { - "etag": "error etag" - }, - }, { - "type": "wrong type", - "reason": "wrong reason", - "domain": "wrong domain", - "metadata": { - "etag": "wrong etag" - } - }], - }, - errors.BadRequest('errorMessage', - details=[{ - 'type': DatabricksError._error_info_type, - 'reason': 'error reason', - 'domain': 'error domain', - 'metadata': { - 'etag': 'error etag' - }, - }])), (401, {}, { 'error_code': 'UNAUTHORIZED', 'message': 'errorMessage', }, - errors.Unauthenticated('errorMessage. Config: host=http://localhost, auth_type=noop', - error_code='UNAUTHORIZED')), + errors.Unauthenticated('errorMessage. Config: host=http://localhost, auth_type=noop', error_code='UNAUTHORIZED')), (403, {}, { 'error_code': 'FORBIDDEN', 'message': 'errorMessage', }, - errors.PermissionDenied('errorMessage. Config: host=http://localhost, auth_type=noop', - error_code='FORBIDDEN')), - (429, {}, { - 'error_code': 'TOO_MANY_REQUESTS', - 'message': 'errorMessage', - }, errors.TooManyRequests('errorMessage', error_code='TOO_MANY_REQUESTS', retry_after_secs=1)), - (429, { - 'Retry-After': '100' - }, { - 'error_code': 'TOO_MANY_REQUESTS', - 'message': 'errorMessage', - }, errors.TooManyRequests('errorMessage', error_code='TOO_MANY_REQUESTS', retry_after_secs=100)), - (503, {}, { - 'error_code': 'TEMPORARILY_UNAVAILABLE', - 'message': 'errorMessage', - }, errors.TemporarilyUnavailable('errorMessage', error_code='TEMPORARILY_UNAVAILABLE', - retry_after_secs=1)), - (503, { - 'Retry-After': '100' - }, { - 'error_code': 'TEMPORARILY_UNAVAILABLE', - 'message': 'errorMessage', - }, - errors.TemporarilyUnavailable('errorMessage', error_code='TEMPORARILY_UNAVAILABLE', - retry_after_secs=100)), - (404, {}, { - 'scimType': 'scim type', - 'detail': 'detail', - 'status': 'status', - }, errors.NotFound('scim type detail', error_code='SCIM_status')), + errors.PermissionDenied('errorMessage. Config: host=http://localhost, auth_type=noop', error_code='FORBIDDEN')), ]) def test_error(config, requests_mock, status_code, headers, body, expected_error): client = ApiClient(config) requests_mock.get("/test", json=body, status_code=status_code, headers=headers) with pytest.raises(DatabricksError) as raised: - client._perform("GET", "http://localhost/test", headers={"test": "test"}) + client.do("GET", "/test", headers={"test": "test"}) actual = raised.value assert isinstance(actual, type(expected_error)) assert str(actual) == str(expected_error) @@ -448,158 +331,6 @@ def test_error(config, requests_mock, status_code, headers, body, expected_error assert expected.metadata == actual.metadata -@contextlib.contextmanager -def http_fixture_server(handler: typing.Callable[[BaseHTTPRequestHandler], None]): - from http.server import HTTPServer - from threading import Thread - - class _handler(BaseHTTPRequestHandler): - - def __init__(self, handler: typing.Callable[[BaseHTTPRequestHandler], None], *args): - self._handler = handler - super().__init__(*args) - - def __getattr__(self, item): - if 'do_' != item[0:3]: - raise AttributeError(f'method {item} not found') - return functools.partial(self._handler, self) - - handler_factory = functools.partial(_handler, handler) - srv = HTTPServer(('localhost', 0), handler_factory) - t = Thread(target=srv.serve_forever) - try: - t.daemon = True - t.start() - yield 'http://{0}:{1}'.format(*srv.server_address) - finally: - srv.shutdown() - - -@pytest.mark.parametrize('status_code,include_retry_after', - ((429, False), (429, True), (503, False), (503, True))) -def test_http_retry_after(status_code, include_retry_after): - requests = [] - - def inner(h: BaseHTTPRequestHandler): - if len(requests) == 0: - h.send_response(status_code) - if include_retry_after: - h.send_header('Retry-After', '1') - h.send_header('Content-Type', 'application/json') - h.end_headers() - else: - h.send_response(200) - h.send_header('Content-Type', 'application/json') - h.end_headers() - h.wfile.write(b'{"foo": 1}') - requests.append(h.requestline) - - with http_fixture_server(inner) as host: - api_client = ApiClient(Config(host=host, token='_', clock=FakeClock())) - res = api_client.do('GET', '/foo') - assert 'foo' in res - - assert len(requests) == 2 - - -def test_http_retry_after_wrong_format(): - requests = [] - - def inner(h: BaseHTTPRequestHandler): - if len(requests) == 0: - h.send_response(429) - h.send_header('Retry-After', '1.58') - h.end_headers() - else: - h.send_response(200) - h.send_header('Content-Type', 'application/json') - h.end_headers() - h.wfile.write(b'{"foo": 1}') - requests.append(h.requestline) - - with http_fixture_server(inner) as host: - api_client = ApiClient(Config(host=host, token='_', clock=FakeClock())) - res = api_client.do('GET', '/foo') - assert 'foo' in res - - assert len(requests) == 2 - - -def test_http_retried_exceed_limit(): - requests = [] - - def inner(h: BaseHTTPRequestHandler): - h.send_response(429) - h.send_header('Retry-After', '1') - h.end_headers() - requests.append(h.requestline) - - with http_fixture_server(inner) as host: - api_client = ApiClient(Config(host=host, token='_', retry_timeout_seconds=1, clock=FakeClock())) - with pytest.raises(TimeoutError): - api_client.do('GET', '/foo') - - assert len(requests) == 1 - - -def test_http_retried_on_match(): - requests = [] - - def inner(h: BaseHTTPRequestHandler): - if len(requests) == 0: - h.send_response(400) - h.end_headers() - h.wfile.write(b'{"error_code": "abc", "message": "... ClusterNotReadyException ..."}') - else: - h.send_response(200) - h.end_headers() - h.wfile.write(b'{"foo": 1}') - requests.append(h.requestline) - - with http_fixture_server(inner) as host: - api_client = ApiClient(Config(host=host, token='_', clock=FakeClock())) - res = api_client.do('GET', '/foo') - assert 'foo' in res - - assert len(requests) == 2 - - -def test_http_not_retried_on_normal_errors(): - requests = [] - - def inner(h: BaseHTTPRequestHandler): - if len(requests) == 0: - h.send_response(400) - h.end_headers() - h.wfile.write(b'{"error_code": "abc", "message": "something not found"}') - requests.append(h.requestline) - - with http_fixture_server(inner) as host: - api_client = ApiClient(Config(host=host, token='_', clock=FakeClock())) - with pytest.raises(DatabricksError): - api_client.do('GET', '/foo') - - assert len(requests) == 1 - - -def test_http_retried_on_connection_error(): - requests = [] - - def inner(h: BaseHTTPRequestHandler): - if len(requests) > 0: - h.send_response(200) - h.end_headers() - h.wfile.write(b'{"foo": 1}') - requests.append(h.requestline) - - with http_fixture_server(inner) as host: - api_client = ApiClient(Config(host=host, token='_', clock=FakeClock())) - res = api_client.do('GET', '/foo') - assert 'foo' in res - - assert len(requests) == 2 - - def test_github_oidc_flow_works_with_azure(monkeypatch): def inner(h: BaseHTTPRequestHandler): diff --git a/tests/test_oauth.py b/tests/test_oauth.py index ce2d514ff..2f4ba2238 100644 --- a/tests/test_oauth.py +++ b/tests/test_oauth.py @@ -1,29 +1,102 @@ -from databricks.sdk.core import Config -from databricks.sdk.oauth import OAuthClient, OidcEndpoints, TokenCache - - -def test_token_cache_unique_filename_by_host(mocker): - mocker.patch.object(Config, "oidc_endpoints", - OidcEndpoints("http://localhost:1234", "http://localhost:1234")) - common_args = dict(client_id="abc", redirect_url="http://localhost:8020") - c1 = OAuthClient(host="http://localhost:", **common_args) - c2 = OAuthClient(host="https://bar.cloud.databricks.com", **common_args) - assert TokenCache(c1).filename != TokenCache(c2).filename - - -def test_token_cache_unique_filename_by_client_id(mocker): - mocker.patch.object(Config, "oidc_endpoints", - OidcEndpoints("http://localhost:1234", "http://localhost:1234")) - common_args = dict(host="http://localhost:", redirect_url="http://localhost:8020") - c1 = OAuthClient(client_id="abc", **common_args) - c2 = OAuthClient(client_id="def", **common_args) - assert TokenCache(c1).filename != TokenCache(c2).filename - - -def test_token_cache_unique_filename_by_scopes(mocker): - mocker.patch.object(Config, "oidc_endpoints", - OidcEndpoints("http://localhost:1234", "http://localhost:1234")) - common_args = dict(host="http://localhost:", client_id="abc", redirect_url="http://localhost:8020") - c1 = OAuthClient(scopes=["foo"], **common_args) - c2 = OAuthClient(scopes=["bar"], **common_args) - assert TokenCache(c1).filename != TokenCache(c2).filename +from databricks.sdk.oauth import OidcEndpoints, TokenCache, get_workspace_endpoints, get_azure_entra_id_workspace_endpoints, get_account_endpoints +from databricks.sdk._base_client import _BaseClient +from .clock import FakeClock + + + +def test_token_cache_unique_filename_by_host(): + common_args = dict( + client_id="abc", + redirect_url="http://localhost:8020", + oidc_endpoints=OidcEndpoints("http://localhost:1234", "http://localhost:1234")) + assert TokenCache(host="http://localhost:", **common_args).filename != TokenCache("https://bar.cloud.databricks.com", **common_args).filename + + +def test_token_cache_unique_filename_by_client_id(): + common_args = dict( + host="http://localhost:", + redirect_url="http://localhost:8020", + oidc_endpoints=OidcEndpoints("http://localhost:1234", "http://localhost:1234")) + assert TokenCache(client_id="abc", **common_args).filename != TokenCache(client_id="def", **common_args).filename + + +def test_token_cache_unique_filename_by_scopes(): + common_args = dict( + host="http://localhost:", + client_id="abc", + redirect_url="http://localhost:8020", + oidc_endpoints=OidcEndpoints("http://localhost:1234", "http://localhost:1234")) + assert TokenCache(scopes=["foo"], **common_args).filename != TokenCache(scopes=["bar"], **common_args).filename + + +def test_account_oidc_endpoints(requests_mock): + requests_mock.get("https://accounts.cloud.databricks.com/oidc/accounts/abc-123/.well-known/oauth-authorization-server", + json={"authorization_endpoint": "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/authorize", + "token_endpoint": "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/token"}) + client = _BaseClient(clock=FakeClock()) + endpoints = get_account_endpoints("accounts.cloud.databricks.com", "abc-123", client=client) + assert endpoints == OidcEndpoints( + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/authorize", + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/token") + +def test_account_oidc_endpoints_retry_on_429(requests_mock): + request_count = 0 + + def nth_request(n): + def observe_request(_request): + nonlocal request_count + is_match = request_count == n + if is_match: + request_count += 1 + return is_match + return observe_request + + requests_mock.get("https://accounts.cloud.databricks.com/oidc/accounts/abc-123/.well-known/oauth-authorization-server", + additional_matcher=nth_request(0), + status_code=429) + requests_mock.get("https://accounts.cloud.databricks.com/oidc/accounts/abc-123/.well-known/oauth-authorization-server", + additional_matcher=nth_request(1), + json={"authorization_endpoint": "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/authorize", + "token_endpoint": "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/token"}) + client = _BaseClient(clock=FakeClock()) + endpoints = get_account_endpoints("accounts.cloud.databricks.com", "abc-123", client=client) + assert endpoints == OidcEndpoints( + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/authorize", + "https://accounts.cloud.databricks.com/oidc/accounts/abc-123/oauth/token") + + +def test_workspace_oidc_endpoints(requests_mock): + requests_mock.get("https://my-workspace.cloud.databricks.com/.well-known/oauth-authorization-server", + json={"authorization_endpoint": "https://my-workspace.cloud.databricks.com/oidc/oauth/authorize", + "token_endpoint": "https://my-workspace.cloud.databricks.com/oidc/oauth/token"}) + client = _BaseClient(clock=FakeClock()) + endpoints = get_workspace_endpoints("my-workspace.cloud.databricks.com", client=client) + assert endpoints == OidcEndpoints( + "https://my-workspace.cloud.databricks.com/oidc/oauth/authorize", + "https://my-workspace.cloud.databricks.com/oidc/oauth/token") + + +def test_workspace_oidc_endpoints_retry_on_429(requests_mock): + request_count = 0 + + def nth_request(n): + def observe_request(_request): + nonlocal request_count + is_match = request_count == n + if is_match: + request_count += 1 + return is_match + return observe_request + + requests_mock.get("https://my-workspace.cloud.databricks.com/.well-known/oauth-authorization-server", + additional_matcher=nth_request(0), + status_code=429) + requests_mock.get("https://my-workspace.cloud.databricks.com/.well-known/oauth-authorization-server", + additional_matcher=nth_request(1), + json={"authorization_endpoint": "https://my-workspace.cloud.databricks.com/oidc/oauth/authorize", + "token_endpoint": "https://my-workspace.cloud.databricks.com/oidc/oauth/token"}) + client = _BaseClient(clock=FakeClock()) + endpoints = get_workspace_endpoints("my-workspace.cloud.databricks.com", client=client) + assert endpoints == OidcEndpoints( + "https://my-workspace.cloud.databricks.com/oidc/oauth/authorize", + "https://my-workspace.cloud.databricks.com/oidc/oauth/token")