From f0b23a5a08a68731edcf731a5e865e4b807236c4 Mon Sep 17 00:00:00 2001
From: Miles Yucht
Date: Mon, 22 Jan 2024 11:23:08 +0100
Subject: [PATCH] Support dev and staging workspaces (#514)

## Changes
Copies the DatabricksEnvironment abstraction from the Go SDK to the Python
SDK. This enables using the SDK with Azure development and staging
workspaces, which use an alternate Azure login application ID. Additionally,
this PR changes `azure_environment` to accept the same values as the Go SDK
and the Databricks and Azure Terraform providers: `PUBLIC`, `USGOVERNMENT`,
and `CHINA`. For example, constructing
`Config(azure_workspace_resource_id=..., azure_environment='USGOVERNMENT')`
now resolves endpoints for the Azure US Government cloud.

The first commit of this PR is a refactor that moves the config and
credentials-provider code into separate files, as core.py had grown quite
large and difficult to maintain. The second commit is much smaller and
contains just the net-new bits.

## Tests
- [x] Unit test for the new environment lookup code.
- [x] Manually tested Azure CLI auth against a staging Azure workspace.

---
 .codegen/__init__.py.tmpl              |    5 +-
 databricks/sdk/__init__.py             |    5 +-
 databricks/sdk/azure.py                |   10 +-
 databricks/sdk/config.py               |  452 +++++++++++
 databricks/sdk/core.py                 | 1028 +-----------------------
 databricks/sdk/credentials_provider.py |  617 ++++++++++++++
 databricks/sdk/environments.py         |   72 ++
 databricks/sdk/oauth.py                |    3 +-
 tests/conftest.py                      |    3 +-
 tests/test_auth.py                     |    2 +-
 tests/test_core.py                     |   43 +-
 tests/test_environments.py             |   19 +
 tests/test_metadata_service_auth.py    |    3 +-
 tests/test_oauth.py                    |    4 +-
 14 files changed, 1216 insertions(+), 1050 deletions(-)
 create mode 100644 databricks/sdk/config.py
 create mode 100644 databricks/sdk/credentials_provider.py
 create mode 100644 databricks/sdk/environments.py
 create mode 100644 tests/test_environments.py

diff --git a/.codegen/__init__.py.tmpl b/.codegen/__init__.py.tmpl
index 40d06293e..df854970a 100644
--- a/.codegen/__init__.py.tmpl
+++ b/.codegen/__init__.py.tmpl
@@ -1,5 +1,6 @@
 import databricks.sdk.core as client
 import databricks.sdk.dbutils as dbutils
+from databricks.sdk.credentials_provider import CredentialsProvider
 from databricks.sdk.mixins.files import DbfsExt
 from databricks.sdk.mixins.compute import ClustersExt
@@ -43,7 +44,7 @@ class WorkspaceClient:
                  debug_headers: bool = None,
                  product="unknown",
                  product_version="0.0.0",
-                 credentials_provider: client.CredentialsProvider = None,
+                 credentials_provider: CredentialsProvider = None,
                  config: client.Config = None):
         if not config:
             config = client.Config({{range $args}}{{.}}={{.}}, {{end}}
@@ -91,7 +92,7 @@ class AccountClient:
                  debug_headers: bool = None,
                  product="unknown",
                  product_version="0.0.0",
-                 credentials_provider: client.CredentialsProvider = None,
+                 credentials_provider: CredentialsProvider = None,
                  config: client.Config = None):
         if not config:
             config = client.Config({{range $args}}{{.}}={{.}}, {{end}}

diff --git a/databricks/sdk/__init__.py b/databricks/sdk/__init__.py
index cecf6d62c..5d345d381 100755
--- a/databricks/sdk/__init__.py
+++ b/databricks/sdk/__init__.py
@@ -1,5 +1,6 @@
 import databricks.sdk.core as client
 import databricks.sdk.dbutils as dbutils
+from databricks.sdk.credentials_provider import CredentialsProvider
 from databricks.sdk.mixins.compute import ClustersExt
 from databricks.sdk.mixins.files import DbfsExt
 from databricks.sdk.mixins.workspace import WorkspaceExt
@@ -114,7 +115,7 @@ def __init__(self,
                  debug_headers: bool = None,
                  product="unknown",
                  product_version="0.0.0",
-                 credentials_provider: client.CredentialsProvider = None,
+                 credentials_provider: CredentialsProvider = None,
                  config: client.Config = None):
         if not config:
             config = client.Config(host=host,
@@ -585,7 +586,7 @@ def __init__(self,
                  debug_headers: bool = None,
                  product="unknown",
                  product_version="0.0.0",
-                 credentials_provider: client.CredentialsProvider = None,
+                 credentials_provider: CredentialsProvider = None,
                  config: client.Config = None):
         if not config:
             config = client.Config(host=host,

diff --git a/databricks/sdk/azure.py b/databricks/sdk/azure.py
index e079a8c94..3e008d8c2 100644
--- a/databricks/sdk/azure.py
+++ b/databricks/sdk/azure.py
@@ -15,19 +15,15 @@ class AzureEnvironment:
 ARM_DATABRICKS_RESOURCE_ID = "2ff814a6-3304-4ab8-85cb-cd0e6f879c1d"
 
 ENVIRONMENTS = dict(
-    PUBLIC=AzureEnvironment(name="AzurePublicCloud",
+    PUBLIC=AzureEnvironment(name="PUBLIC",
                             service_management_endpoint="https://management.core.windows.net/",
                             resource_manager_endpoint="https://management.azure.com/",
                             active_directory_endpoint="https://login.microsoftonline.com/"),
-    GERMAN=AzureEnvironment(name="AzureGermanCloud",
-                            service_management_endpoint="https://management.core.cloudapi.de/",
-                            resource_manager_endpoint="https://management.microsoftazure.de/",
-                            active_directory_endpoint="https://login.microsoftonline.de/"),
-    USGOVERNMENT=AzureEnvironment(name="AzureUSGovernmentCloud",
+    USGOVERNMENT=AzureEnvironment(name="USGOVERNMENT",
                             service_management_endpoint="https://management.core.usgovcloudapi.net/",
                             resource_manager_endpoint="https://management.usgovcloudapi.net/",
                             active_directory_endpoint="https://login.microsoftonline.us/"),
-    CHINA=AzureEnvironment(name="AzureChinaCloud",
+    CHINA=AzureEnvironment(name="CHINA",
                             service_management_endpoint="https://management.core.chinacloudapi.cn/",
                             resource_manager_endpoint="https://management.chinacloudapi.cn/",
                             active_directory_endpoint="https://login.chinacloudapi.cn/"),

diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py
new file mode 100644
index 000000000..0dadfc927
--- /dev/null
+++ b/databricks/sdk/config.py
@@ -0,0 +1,452 @@
+import configparser
+import copy
+import logging
+import os
+import pathlib
+import platform
+import sys
+import urllib.parse
+from typing import Dict, Iterable, Optional
+
+import requests
+
+from .azure import AzureEnvironment
+from .credentials_provider import CredentialsProvider, DefaultCredentials
+from .environments import (ALL_ENVS, DEFAULT_ENVIRONMENT, Cloud,
+                           DatabricksEnvironment)
+from .oauth import OidcEndpoints
+from .version import __version__
+
+logger = logging.getLogger('databricks.sdk')
+
+
+class ConfigAttribute:
+    """ Configuration attribute metadata and descriptor protocols.
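+
+    For example, the declaration
+
+        host: str = ConfigAttribute(env='DATABRICKS_HOST')
+
+    stores assigned values in the owning Config's `_inner` dict (see __set__),
+    coerces them with the type annotation discovered by Config.attributes(),
+    and lets _load_from_env() pick the value up from DATABRICKS_HOST.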
""" + + # name and transform are discovered from Config.__new__ + name: str = None + transform: type = str + + def __init__(self, env: str = None, auth: str = None, sensitive: bool = False): + self.env = env + self.auth = auth + self.sensitive = sensitive + + def __get__(self, cfg: 'Config', owner): + if not cfg: + return None + return cfg._inner.get(self.name, None) + + def __set__(self, cfg: 'Config', value: any): + cfg._inner[self.name] = self.transform(value) + + def __repr__(self) -> str: + return f"" + + +class Config: + host: str = ConfigAttribute(env='DATABRICKS_HOST') + account_id: str = ConfigAttribute(env='DATABRICKS_ACCOUNT_ID') + token: str = ConfigAttribute(env='DATABRICKS_TOKEN', auth='pat', sensitive=True) + username: str = ConfigAttribute(env='DATABRICKS_USERNAME', auth='basic') + password: str = ConfigAttribute(env='DATABRICKS_PASSWORD', auth='basic', sensitive=True) + client_id: str = ConfigAttribute(env='DATABRICKS_CLIENT_ID', auth='oauth') + client_secret: str = ConfigAttribute(env='DATABRICKS_CLIENT_SECRET', auth='oauth', sensitive=True) + profile: str = ConfigAttribute(env='DATABRICKS_CONFIG_PROFILE') + config_file: str = ConfigAttribute(env='DATABRICKS_CONFIG_FILE') + google_service_account: str = ConfigAttribute(env='DATABRICKS_GOOGLE_SERVICE_ACCOUNT', auth='google') + google_credentials: str = ConfigAttribute(env='GOOGLE_CREDENTIALS', auth='google', sensitive=True) + azure_workspace_resource_id: str = ConfigAttribute(env='DATABRICKS_AZURE_RESOURCE_ID', auth='azure') + azure_use_msi: bool = ConfigAttribute(env='ARM_USE_MSI', auth='azure') + azure_client_secret: str = ConfigAttribute(env='ARM_CLIENT_SECRET', auth='azure', sensitive=True) + azure_client_id: str = ConfigAttribute(env='ARM_CLIENT_ID', auth='azure') + azure_tenant_id: str = ConfigAttribute(env='ARM_TENANT_ID', auth='azure') + azure_environment: str = ConfigAttribute(env='ARM_ENVIRONMENT') + databricks_cli_path: str = ConfigAttribute(env='DATABRICKS_CLI_PATH') + auth_type: str = ConfigAttribute(env='DATABRICKS_AUTH_TYPE') + cluster_id: str = ConfigAttribute(env='DATABRICKS_CLUSTER_ID') + warehouse_id: str = ConfigAttribute(env='DATABRICKS_WAREHOUSE_ID') + skip_verify: bool = ConfigAttribute() + http_timeout_seconds: float = ConfigAttribute() + debug_truncate_bytes: int = ConfigAttribute(env='DATABRICKS_DEBUG_TRUNCATE_BYTES') + debug_headers: bool = ConfigAttribute(env='DATABRICKS_DEBUG_HEADERS') + rate_limit: int = ConfigAttribute(env='DATABRICKS_RATE_LIMIT') + retry_timeout_seconds: int = ConfigAttribute() + metadata_service_url = ConfigAttribute(env='DATABRICKS_METADATA_SERVICE_URL', + auth='metadata-service', + sensitive=True) + max_connection_pools: int = ConfigAttribute() + max_connections_per_pool: int = ConfigAttribute() + databricks_environment: Optional[DatabricksEnvironment] = None + + def __init__(self, + *, + credentials_provider: CredentialsProvider = None, + product="unknown", + product_version="0.0.0", + **kwargs): + self._inner = {} + self._user_agent_other_info = [] + self._credentials_provider = credentials_provider if credentials_provider else DefaultCredentials() + if 'databricks_environment' in kwargs: + self.databricks_environment = kwargs['databricks_environment'] + del kwargs['databricks_environment'] + try: + self._set_inner_config(kwargs) + self._load_from_env() + self._known_file_config_loader() + self._fix_host_if_needed() + self._validate() + self._init_auth() + self._product = product + self._product_version = product_version + except ValueError as e: + message = 
+            message = self.wrap_debug_info(str(e))
+            raise ValueError(message) from e
+
+    def wrap_debug_info(self, message: str) -> str:
+        debug_string = self.debug_string()
+        if debug_string:
+            message = f'{message.rstrip(".")}. {debug_string}'
+        return message
+
+    @staticmethod
+    def parse_dsn(dsn: str) -> 'Config':
+        uri = urllib.parse.urlparse(dsn)
+        if uri.scheme != 'databricks':
+            raise ValueError(f'Expected databricks:// scheme, got {uri.scheme}://')
+        kwargs = {'host': f'https://{uri.hostname}'}
+        if uri.username:
+            kwargs['username'] = uri.username
+        if uri.password:
+            kwargs['password'] = uri.password
+        query = dict(urllib.parse.parse_qsl(uri.query))
+        for attr in Config.attributes():
+            if attr.name not in query:
+                continue
+            kwargs[attr.name] = query[attr.name]
+        return Config(**kwargs)
+
+    def authenticate(self) -> Dict[str, str]:
+        """ Returns a dictionary of fresh authentication headers """
+        return self._header_factory()
+
+    def as_dict(self) -> dict:
+        return self._inner
+
+    def _get_azure_environment_name(self) -> str:
+        if not self.azure_environment:
+            return "PUBLIC"
+        env = self.azure_environment.upper()
+        # Compatibility with older versions of the SDK that allowed users to
+        # specify AzurePublicCloud or AzureChinaCloud
+        if env.startswith("AZURE"):
+            env = env[len("AZURE"):]
+        if env.endswith("CLOUD"):
+            env = env[:-len("CLOUD")]
+        return env
+
+    @property
+    def environment(self) -> DatabricksEnvironment:
+        """Returns the environment based on configuration."""
+        if self.databricks_environment:
+            return self.databricks_environment
+        if self.host:
+            for environment in ALL_ENVS:
+                if self.host.endswith(environment.dns_zone):
+                    return environment
+        if self.azure_workspace_resource_id:
+            azure_env = self._get_azure_environment_name()
+            for environment in ALL_ENVS:
+                if environment.cloud != Cloud.AZURE:
+                    continue
+                if environment.azure_environment.name != azure_env:
+                    continue
+                if environment.dns_zone.startswith(".dev") or environment.dns_zone.startswith(".staging"):
+                    continue
+                return environment
+        return DEFAULT_ENVIRONMENT
+
+    @property
+    def is_azure(self) -> bool:
+        return self.environment.cloud == Cloud.AZURE
+
+    @property
+    def is_gcp(self) -> bool:
+        return self.environment.cloud == Cloud.GCP
+
+    @property
+    def is_aws(self) -> bool:
+        return self.environment.cloud == Cloud.AWS
+
+    @property
+    def is_account_client(self) -> bool:
+        if not self.host:
+            return False
+        return self.host.startswith("https://accounts.") or self.host.startswith("https://accounts-dod.")
+
+    @property
+    def arm_environment(self) -> AzureEnvironment:
+        return self.environment.azure_environment
+
+    @property
+    def effective_azure_login_app_id(self):
+        return self.environment.azure_application_id
+
+    @property
+    def hostname(self) -> str:
+        url = urllib.parse.urlparse(self.host)
+        return url.netloc
+
+    @property
+    def is_any_auth_configured(self) -> bool:
+        for attr in Config.attributes():
+            if not attr.auth:
+                continue
+            value = self._inner.get(attr.name, None)
+            if value:
+                return True
+        return False
+
+    @property
+    def user_agent(self):
+        """ Returns User-Agent header used by this SDK """
+        py_version = platform.python_version()
+        os_name = platform.uname().system.lower()
+
+        ua = [
+            f"{self._product}/{self._product_version}", f"databricks-sdk-py/{__version__}",
+            f"python/{py_version}", f"os/{os_name}", f"auth/{self.auth_type}",
+        ]
+        if len(self._user_agent_other_info) > 0:
+            ua.append(' '.join(self._user_agent_other_info))
+        if len(self._upstream_user_agent) > 0:
+            ua.append(self._upstream_user_agent)
+        if
'DATABRICKS_RUNTIME_VERSION' in os.environ: + runtime_version = os.environ['DATABRICKS_RUNTIME_VERSION'] + if runtime_version != '': + runtime_version = self._sanitize_header_value(runtime_version) + ua.append(f'runtime/{runtime_version}') + + return ' '.join(ua) + + @staticmethod + def _sanitize_header_value(value: str) -> str: + value = value.replace(' ', '-') + value = value.replace('/', '-') + return value + + @property + def _upstream_user_agent(self) -> str: + product = os.environ.get('DATABRICKS_SDK_UPSTREAM', None) + product_version = os.environ.get('DATABRICKS_SDK_UPSTREAM_VERSION', None) + if product is not None and product_version is not None: + return f"upstream/{product} upstream-version/{product_version}" + return "" + + def with_user_agent_extra(self, key: str, value: str) -> 'Config': + self._user_agent_other_info.append(f"{key}/{value}") + return self + + @property + def oidc_endpoints(self) -> Optional[OidcEndpoints]: + self._fix_host_if_needed() + if not self.host: + return None + if self.is_azure and self.azure_client_id: + # Retrieve authorize endpoint to retrieve token endpoint after + res = requests.get(f'{self.host}/oidc/oauth2/v2.0/authorize', allow_redirects=False) + real_auth_url = res.headers.get('location') + if not real_auth_url: + return None + return OidcEndpoints(authorization_endpoint=real_auth_url, + token_endpoint=real_auth_url.replace('/authorize', '/token')) + if self.is_account_client and self.account_id: + prefix = f'{self.host}/oidc/accounts/{self.account_id}' + return OidcEndpoints(authorization_endpoint=f'{prefix}/v1/authorize', + token_endpoint=f'{prefix}/v1/token') + oidc = f'{self.host}/oidc/.well-known/oauth-authorization-server' + res = requests.get(oidc) + if res.status_code != 200: + return None + auth_metadata = res.json() + return OidcEndpoints(authorization_endpoint=auth_metadata.get('authorization_endpoint'), + token_endpoint=auth_metadata.get('token_endpoint')) + + def debug_string(self) -> str: + """ Returns log-friendly representation of configured attributes """ + buf = [] + attrs_used = [] + envs_used = [] + for attr in Config.attributes(): + if attr.env and os.environ.get(attr.env): + envs_used.append(attr.env) + value = getattr(self, attr.name) + if not value: + continue + safe = '***' if attr.sensitive else f'{value}' + attrs_used.append(f'{attr.name}={safe}') + if attrs_used: + buf.append(f'Config: {", ".join(attrs_used)}') + if envs_used: + buf.append(f'Env: {", ".join(envs_used)}') + return '. '.join(buf) + + def to_dict(self) -> Dict[str, any]: + return self._inner + + @property + def sql_http_path(self) -> Optional[str]: + """(Experimental) Return HTTP path for SQL Drivers. + + If `cluster_id` or `warehouse_id` are configured, return a valid HTTP Path argument + used in construction of JDBC/ODBC DSN string. 
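+
+        For example (hypothetical IDs): with `cluster_id` set, the returned path
+        looks like `sql/protocolv1/o/<workspace-id>/0123-456789-abcdefg0`; with
+        `warehouse_id` set, it looks like `/sql/1.0/warehouses/abcdef0123456789`.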
+ + See https://docs.databricks.com/integrations/jdbc-odbc-bi.html + """ + if (not self.cluster_id) and (not self.warehouse_id): + return None + if self.cluster_id and self.warehouse_id: + raise ValueError('cannot have both cluster_id and warehouse_id') + headers = self.authenticate() + headers['User-Agent'] = f'{self.user_agent} sdk-feature/sql-http-path' + if self.cluster_id: + response = requests.get(f"{self.host}/api/2.0/preview/scim/v2/Me", headers=headers) + # get workspace ID from the response header + workspace_id = response.headers.get('x-databricks-org-id') + return f'sql/protocolv1/o/{workspace_id}/{self.cluster_id}' + if self.warehouse_id: + return f'/sql/1.0/warehouses/{self.warehouse_id}' + + @classmethod + def attributes(cls) -> Iterable[ConfigAttribute]: + """ Returns a list of Databricks SDK configuration metadata """ + if hasattr(cls, '_attributes'): + return cls._attributes + if sys.version_info[1] >= 10: + import inspect + anno = inspect.get_annotations(cls) + else: + # Python 3.7 compatibility: getting type hints require extra hop, as described in + # "Accessing The Annotations Dict Of An Object In Python 3.9 And Older" section of + # https://docs.python.org/3/howto/annotations.html + anno = cls.__dict__['__annotations__'] + attrs = [] + for name, v in cls.__dict__.items(): + if type(v) != ConfigAttribute: + continue + v.name = name + v.transform = anno.get(name, str) + attrs.append(v) + cls._attributes = attrs + return cls._attributes + + def _fix_host_if_needed(self): + if not self.host: + return + # fix url to remove trailing slash + o = urllib.parse.urlparse(self.host) + if not o.hostname: + # only hostname is specified + self.host = f"https://{self.host}" + else: + self.host = f"{o.scheme}://{o.netloc}" + + def _set_inner_config(self, keyword_args: Dict[str, any]): + for attr in self.attributes(): + if attr.name not in keyword_args: + continue + if keyword_args.get(attr.name, None) is None: + continue + self.__setattr__(attr.name, keyword_args[attr.name]) + + def _load_from_env(self): + found = False + for attr in self.attributes(): + if not attr.env: + continue + if attr.name in self._inner: + continue + value = os.environ.get(attr.env) + if not value: + continue + self.__setattr__(attr.name, value) + found = True + if found: + logger.debug('Loaded from environment') + + def _known_file_config_loader(self): + if not self.profile and (self.is_any_auth_configured or self.host + or self.azure_workspace_resource_id): + # skip loading configuration file if there's any auth configured + # directly as part of the Config() constructor. + return + config_file = self.config_file + if not config_file: + config_file = "~/.databrickscfg" + config_path = pathlib.Path(config_file).expanduser() + if not config_path.exists(): + logger.debug("%s does not exist", config_path) + return + ini_file = configparser.ConfigParser() + ini_file.read(config_path) + profile = self.profile + has_explicit_profile = self.profile is not None + # In Go SDK, we skip merging the profile with DEFAULT section, though Python's ConfigParser.items() + # is returning profile key-value pairs _including those from DEFAULT_. This is not what we expect + # from Unified Auth test suite at the moment. Hence, the private variable access. 
+ # See: https://docs.python.org/3/library/configparser.html#mapping-protocol-access + if not has_explicit_profile and not ini_file.defaults(): + logger.debug(f'{config_path} has no DEFAULT profile configured') + return + if not has_explicit_profile: + profile = "DEFAULT" + profiles = ini_file._sections + if ini_file.defaults(): + profiles['DEFAULT'] = ini_file.defaults() + if profile not in profiles: + raise ValueError(f'resolve: {config_path} has no {profile} profile configured') + raw_config = profiles[profile] + logger.info(f'loading {profile} profile from {config_file}: {", ".join(raw_config.keys())}') + for k, v in raw_config.items(): + if k in self._inner: + # don't overwrite a value previously set + continue + self.__setattr__(k, v) + + def _validate(self): + auths_used = set() + for attr in Config.attributes(): + if attr.name not in self._inner: + continue + if not attr.auth: + continue + auths_used.add(attr.auth) + if len(auths_used) <= 1: + return + if self.auth_type: + # client has auth preference set + return + names = " and ".join(sorted(auths_used)) + raise ValueError(f'validate: more than one authorization method configured: {names}') + + def _init_auth(self): + try: + self._header_factory = self._credentials_provider(self) + self.auth_type = self._credentials_provider.auth_type() + if not self._header_factory: + raise ValueError('not configured') + except ValueError as e: + raise ValueError(f'{self._credentials_provider.auth_type()} auth: {e}') from e + + def __repr__(self): + return f'<{self.debug_string()}>' + + def copy(self): + """Creates a copy of the config object. + All the copies share most of their internal state (ie, shared reference to fields such as credential_provider). + Copies have their own instances of the following fields + - `_user_agent_other_info` + """ + cpy: Config = copy.copy(self) + cpy._user_agent_other_info = copy.deepcopy(self._user_agent_other_info) + return cpy diff --git a/databricks/sdk/core.py b/databricks/sdk/core.py index fa7ede0f1..2b7442708 100644 --- a/databricks/sdk/core.py +++ b/databricks/sdk/core.py @@ -1,1040 +1,22 @@ -import abc -import base64 -import configparser -import copy -import functools -import io -import json -import logging -import os -import pathlib -import platform import re -import subprocess -import sys import urllib.parse -from datetime import datetime, timedelta +from datetime import timedelta from json import JSONDecodeError from types import TracebackType -from typing import (Any, BinaryIO, Callable, Dict, Iterable, Iterator, List, - Optional, Type, Union) +from typing import Any, BinaryIO, Iterator, Type -import google.auth -import requests -from google.auth import impersonated_credentials -from google.auth.transport.requests import Request -from google.oauth2 import service_account from requests.adapters import HTTPAdapter -from .azure import (ARM_DATABRICKS_RESOURCE_ID, ENVIRONMENTS, AzureEnvironment, - add_sp_management_token, add_workspace_id_header) +from .config import * +# To preserve backwards compatibility (as these definitions were previously in this module) +from .credentials_provider import * from .errors import DatabricksError, error_mapper -from .oauth import (ClientCredentials, OAuthClient, OidcEndpoints, Refreshable, - Token, TokenCache, TokenSource) from .retries import retried -from .version import __version__ __all__ = ['Config', 'DatabricksError'] logger = logging.getLogger('databricks.sdk') -HeaderFactory = Callable[[], Dict[str, str]] - -GcpScopes = 
["https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/compute"] - - -class CredentialsProvider(abc.ABC): - """ CredentialsProvider is the protocol (call-side interface) - for authenticating requests to Databricks REST APIs""" - - @abc.abstractmethod - def auth_type(self) -> str: - ... - - @abc.abstractmethod - def __call__(self, cfg: 'Config') -> HeaderFactory: - ... - - -def credentials_provider(name: str, require: List[str]): - """ Given the function that receives a Config and returns RequestVisitor, - create CredentialsProvider with a given name and required configuration - attribute names to be present for this function to be called. """ - - def inner(func: Callable[['Config'], HeaderFactory]) -> CredentialsProvider: - - @functools.wraps(func) - def wrapper(cfg: 'Config') -> Optional[HeaderFactory]: - for attr in require: - if not getattr(cfg, attr): - return None - return func(cfg) - - wrapper.auth_type = lambda: name - return wrapper - - return inner - - -@credentials_provider('basic', ['host', 'username', 'password']) -def basic_auth(cfg: 'Config') -> HeaderFactory: - """ Given username and password, add base64-encoded Basic credentials """ - encoded = base64.b64encode(f'{cfg.username}:{cfg.password}'.encode()).decode() - static_credentials = {'Authorization': f'Basic {encoded}'} - - def inner() -> Dict[str, str]: - return static_credentials - - return inner - - -@credentials_provider('pat', ['host', 'token']) -def pat_auth(cfg: 'Config') -> HeaderFactory: - """ Adds Databricks Personal Access Token to every request """ - static_credentials = {'Authorization': f'Bearer {cfg.token}'} - - def inner() -> Dict[str, str]: - return static_credentials - - return inner - - -@credentials_provider('runtime', []) -def runtime_native_auth(cfg: 'Config') -> Optional[HeaderFactory]: - if 'DATABRICKS_RUNTIME_VERSION' not in os.environ: - return None - - # This import MUST be after the "DATABRICKS_RUNTIME_VERSION" check - # above, so that we are not throwing import errors when not in - # runtime and no config variables are set. - from databricks.sdk.runtime import (init_runtime_legacy_auth, - init_runtime_native_auth, - init_runtime_repl_auth) - for init in [init_runtime_native_auth, init_runtime_repl_auth, init_runtime_legacy_auth]: - if init is None: - continue - host, inner = init() - if host is None: - logger.debug(f'[{init.__name__}] no host detected') - continue - cfg.host = host - logger.debug(f'[{init.__name__}] runtime native auth configured') - return inner - return None - - -@credentials_provider('oauth-m2m', ['host', 'client_id', 'client_secret']) -def oauth_service_principal(cfg: 'Config') -> Optional[HeaderFactory]: - """ Adds refreshed Databricks machine-to-machine OAuth Bearer token to every request, - if /oidc/.well-known/oauth-authorization-server is available on the given host. 
""" - oidc = cfg.oidc_endpoints - if oidc is None: - return None - token_source = ClientCredentials(client_id=cfg.client_id, - client_secret=cfg.client_secret, - token_url=oidc.token_endpoint, - scopes=["all-apis"], - use_header=True) - - def inner() -> Dict[str, str]: - token = token_source.token() - return {'Authorization': f'{token.token_type} {token.access_token}'} - - return inner - - -@credentials_provider('external-browser', ['host', 'auth_type']) -def external_browser(cfg: 'Config') -> Optional[HeaderFactory]: - if cfg.auth_type != 'external-browser': - return None - if cfg.client_id: - client_id = cfg.client_id - elif cfg.is_aws: - client_id = 'databricks-cli' - elif cfg.is_azure: - # Use Azure AD app for cases when Azure CLI is not available on the machine. - # App has to be registered as Single-page multi-tenant to support PKCE - # TODO: temporary app ID, change it later. - client_id = '6128a518-99a9-425b-8333-4cc94f04cacd' - else: - raise ValueError(f'local browser SSO is not supported') - oauth_client = OAuthClient(host=cfg.host, - client_id=client_id, - redirect_url='http://localhost:8020', - client_secret=cfg.client_secret) - - # Load cached credentials from disk if they exist. - # Note that these are local to the Python SDK and not reused by other SDKs. - token_cache = TokenCache(oauth_client) - credentials = token_cache.load() - if credentials: - # Force a refresh in case the loaded credentials are expired. - credentials.token() - else: - consent = oauth_client.initiate_consent() - if not consent: - return None - credentials = consent.launch_external_browser() - token_cache.save(credentials) - return credentials(cfg) - - -def _ensure_host_present(cfg: 'Config', token_source_for: Callable[[str], TokenSource]): - """ Resolves Azure Databricks workspace URL from ARM Resource ID """ - if cfg.host: - return - if not cfg.azure_workspace_resource_id: - return - arm = cfg.arm_environment.resource_manager_endpoint - token = token_source_for(arm).token() - resp = requests.get(f"{arm}{cfg.azure_workspace_resource_id}?api-version=2018-04-01", - headers={"Authorization": f"Bearer {token.access_token}"}) - if not resp.ok: - raise ValueError(f"Cannot resolve Azure Databricks workspace: {resp.content}") - cfg.host = f"https://{resp.json()['properties']['workspaceUrl']}" - - -@credentials_provider('azure-client-secret', - ['is_azure', 'azure_client_id', 'azure_client_secret', 'azure_tenant_id']) -def azure_service_principal(cfg: 'Config') -> HeaderFactory: - """ Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens - to every request, while automatically resolving different Azure environment endpoints. 
""" - - def token_source_for(resource: str) -> TokenSource: - aad_endpoint = cfg.arm_environment.active_directory_endpoint - return ClientCredentials(client_id=cfg.azure_client_id, - client_secret=cfg.azure_client_secret, - token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token", - endpoint_params={"resource": resource}, - use_params=True) - - _ensure_host_present(cfg, token_source_for) - logger.info("Configured AAD token for Service Principal (%s)", cfg.azure_client_id) - inner = token_source_for(cfg.effective_azure_login_app_id) - cloud = token_source_for(cfg.arm_environment.service_management_endpoint) - - def refreshed_headers() -> Dict[str, str]: - headers = {'Authorization': f"Bearer {inner.token().access_token}", } - add_workspace_id_header(cfg, headers) - add_sp_management_token(cloud, headers) - return headers - - return refreshed_headers - - -@credentials_provider('github-oidc-azure', ['host', 'azure_client_id']) -def github_oidc_azure(cfg: 'Config') -> Optional[HeaderFactory]: - if 'ACTIONS_ID_TOKEN_REQUEST_TOKEN' not in os.environ: - # not in GitHub actions - return None - - # Client ID is the minimal thing we need, as otherwise we get AADSTS700016: Application with - # identifier 'https://token.actions.githubusercontent.com' was not found in the directory '...'. - if not cfg.is_azure: - return None - - # See https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers - headers = {'Authorization': f"Bearer {os.environ['ACTIONS_ID_TOKEN_REQUEST_TOKEN']}"} - endpoint = f"{os.environ['ACTIONS_ID_TOKEN_REQUEST_URL']}&audience=api://AzureADTokenExchange" - response = requests.get(endpoint, headers=headers) - if not response.ok: - return None - - # get the ID Token with aud=api://AzureADTokenExchange sub=repo:org/repo:environment:name - response_json = response.json() - if 'value' not in response_json: - return None - - logger.info("Configured AAD token for GitHub Actions OIDC (%s)", cfg.azure_client_id) - params = { - 'client_assertion_type': 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer', - 'resource': cfg.effective_azure_login_app_id, - 'client_assertion': response_json['value'], - } - aad_endpoint = cfg.arm_environment.active_directory_endpoint - if not cfg.azure_tenant_id: - # detect Azure AD Tenant ID if it's not specified directly - token_endpoint = cfg.oidc_endpoints.token_endpoint - cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, '').split('/')[0] - inner = ClientCredentials(client_id=cfg.azure_client_id, - client_secret="", # we have no (rotatable) secrets in OIDC flow - token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token", - endpoint_params=params, - use_params=True) - - def refreshed_headers() -> Dict[str, str]: - token = inner.token() - return {'Authorization': f'{token.token_type} {token.access_token}'} - - return refreshed_headers - - -@credentials_provider('google-credentials', ['host', 'google_credentials']) -def google_credentials(cfg: 'Config') -> Optional[HeaderFactory]: - if not cfg.is_gcp: - return None - # Reads credentials as JSON. Credentials can be either a path to JSON file, or actual JSON string. - # Obtain the id token by providing the json file path and target audience. - if (os.path.isfile(cfg.google_credentials)): - with io.open(cfg.google_credentials, "r", encoding="utf-8") as json_file: - account_info = json.load(json_file) - else: - # If the file doesn't exist, assume that the config is the actual JSON content. 
- account_info = json.loads(cfg.google_credentials) - - credentials = service_account.IDTokenCredentials.from_service_account_info(info=account_info, - target_audience=cfg.host) - - request = Request() - - gcp_credentials = service_account.Credentials.from_service_account_info(info=account_info, - scopes=GcpScopes) - - def refreshed_headers() -> Dict[str, str]: - credentials.refresh(request) - headers = {'Authorization': f'Bearer {credentials.token}'} - if cfg.is_account_client: - gcp_credentials.refresh(request) - headers["X-Databricks-GCP-SA-Access-Token"] = gcp_credentials.token - return headers - - return refreshed_headers - - -@credentials_provider('google-id', ['host', 'google_service_account']) -def google_id(cfg: 'Config') -> Optional[HeaderFactory]: - if not cfg.is_gcp: - return None - credentials, _project_id = google.auth.default() - - # Create the impersonated credential. - target_credentials = impersonated_credentials.Credentials(source_credentials=credentials, - target_principal=cfg.google_service_account, - target_scopes=[]) - - # Set the impersonated credential, target audience and token options. - id_creds = impersonated_credentials.IDTokenCredentials(target_credentials, - target_audience=cfg.host, - include_email=True) - - gcp_impersonated_credentials = impersonated_credentials.Credentials( - source_credentials=credentials, target_principal=cfg.google_service_account, target_scopes=GcpScopes) - - request = Request() - - def refreshed_headers() -> Dict[str, str]: - id_creds.refresh(request) - headers = {'Authorization': f'Bearer {id_creds.token}'} - if cfg.is_account_client: - gcp_impersonated_credentials.refresh(request) - headers["X-Databricks-GCP-SA-Access-Token"] = gcp_impersonated_credentials.token - return headers - - return refreshed_headers - - -class CliTokenSource(Refreshable): - - def __init__(self, cmd: List[str], token_type_field: str, access_token_field: str, expiry_field: str): - super().__init__() - self._cmd = cmd - self._token_type_field = token_type_field - self._access_token_field = access_token_field - self._expiry_field = expiry_field - - @staticmethod - def _parse_expiry(expiry: str) -> datetime: - expiry = expiry.rstrip("Z").split(".")[0] - for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S"): - try: - return datetime.strptime(expiry, fmt) - except ValueError as e: - last_e = e - if last_e: - raise last_e - - def refresh(self) -> Token: - try: - is_windows = sys.platform.startswith('win') - # windows requires shell=True to be able to execute 'az login' or other commands - # cannot use shell=True all the time, as it breaks macOS - out = subprocess.run(self._cmd, capture_output=True, check=True, shell=is_windows) - it = json.loads(out.stdout.decode()) - expires_on = self._parse_expiry(it[self._expiry_field]) - return Token(access_token=it[self._access_token_field], - token_type=it[self._token_type_field], - expiry=expires_on) - except ValueError as e: - raise ValueError(f"cannot unmarshal CLI result: {e}") - except subprocess.CalledProcessError as e: - stdout = e.stdout.decode().strip() - stderr = e.stderr.decode().strip() - message = stdout or stderr - raise IOError(f'cannot get access token: {message}') from e - - -class AzureCliTokenSource(CliTokenSource): - """ Obtain the token granted by `az login` CLI command """ - - def __init__(self, resource: str, subscription: str = ""): - cmd = ["az", "account", "get-access-token", "--resource", resource, "--output", "json"] - if subscription != "": - cmd.append("--subscription") - cmd.append(subscription) 
- super().__init__(cmd=cmd, - token_type_field='tokenType', - access_token_field='accessToken', - expiry_field='expiresOn') - - def is_human_user(self) -> bool: - """The UPN claim is the username of the user, but not the Service Principal. - - Azure CLI can be authenticated by both human users (`az login`) and service principals. In case of service - principals, it can be either OIDC from GitHub or login with a password: - - ~ $ az login --service-principal --user $clientID --password $clientSecret --tenant $tenantID - - Human users get more claims: - - 'amr' - how the subject of the token was authenticated - - 'name', 'family_name', 'given_name' - human-readable values that identifies the subject of the token - - 'scp' with `user_impersonation` value, that shows the set of scopes exposed by your application for which - the client application has requested (and received) consent - - 'unique_name' - a human-readable value that identifies the subject of the token. This value is not - guaranteed to be unique within a tenant and should be used only for display purposes. - - 'upn' - The username of the user. - """ - return 'upn' in self.token().jwt_claims() - - @staticmethod - def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource': - subscription = AzureCliTokenSource.get_subscription(cfg) - if subscription != "": - token_source = AzureCliTokenSource(resource, subscription) - try: - # This will fail if the user has access to the workspace, but not to the subscription - # itself. - # In such case, we fall back to not using the subscription. - token_source.token() - return token_source - except OSError: - logger.warning("Failed to get token for subscription. Using resource only token.") - - token_source = AzureCliTokenSource(resource) - token_source.token() - return token_source - - @staticmethod - def get_subscription(cfg: 'Config') -> str: - resource = cfg.azure_workspace_resource_id - if resource is None or resource == "": - return "" - components = resource.split('/') - if len(components) < 3: - logger.warning("Invalid azure workspace resource ID") - return "" - return components[2] - - -@credentials_provider('azure-cli', ['is_azure']) -def azure_cli(cfg: 'Config') -> Optional[HeaderFactory]: - """ Adds refreshed OAuth token granted by `az login` command to every request. """ - token_source = None - mgmt_token_source = None - try: - token_source = AzureCliTokenSource.for_resource(cfg, cfg.effective_azure_login_app_id) - except FileNotFoundError: - doc = 'https://docs.microsoft.com/en-us/cli/azure/?view=azure-cli-latest' - logger.debug(f'Most likely Azure CLI is not installed. See {doc} for details') - return None - if not token_source.is_human_user(): - try: - management_endpoint = cfg.arm_environment.service_management_endpoint - mgmt_token_source = AzureCliTokenSource.for_resource(cfg, management_endpoint) - except Exception as e: - logger.debug(f'Not including service management token in headers', exc_info=e) - mgmt_token_source = None - - _ensure_host_present(cfg, lambda resource: AzureCliTokenSource.for_resource(cfg, resource)) - logger.info("Using Azure CLI authentication with AAD tokens") - if not cfg.is_account_client and AzureCliTokenSource.get_subscription(cfg) == "": - logger.warning( - "azure_workspace_resource_id field not provided. " - "It is recommended to specify this field in the Databricks configuration to avoid authentication errors." 
- ) - - def inner() -> Dict[str, str]: - token = token_source.token() - headers = {'Authorization': f'{token.token_type} {token.access_token}'} - add_workspace_id_header(cfg, headers) - if mgmt_token_source: - add_sp_management_token(mgmt_token_source, headers) - return headers - - return inner - - -class DatabricksCliTokenSource(CliTokenSource): - """ Obtain the token granted by `databricks auth login` CLI command """ - - def __init__(self, cfg: 'Config'): - args = ['auth', 'token', '--host', cfg.host] - if cfg.is_account_client: - args += ['--account-id', cfg.account_id] - - cli_path = cfg.databricks_cli_path - if not cli_path: - cli_path = 'databricks' - - # If the path is unqualified, look it up in PATH. - if cli_path.count("/") == 0: - cli_path = self.__class__._find_executable(cli_path) - - super().__init__(cmd=[cli_path, *args], - token_type_field='token_type', - access_token_field='access_token', - expiry_field='expiry') - - @staticmethod - def _find_executable(name) -> str: - err = FileNotFoundError("Most likely the Databricks CLI is not installed") - for dir in os.getenv("PATH", default="").split(os.path.pathsep): - path = pathlib.Path(dir).joinpath(name).resolve() - if not path.is_file(): - continue - - # The new Databricks CLI is a single binary with size > 1MB. - # We use the size as a signal to determine which Databricks CLI is installed. - stat = path.stat() - if stat.st_size < (1024 * 1024): - err = FileNotFoundError("Databricks CLI version <0.100.0 detected") - continue - - return str(path) - - raise err - - -@credentials_provider('databricks-cli', ['host', 'is_aws']) -def databricks_cli(cfg: 'Config') -> Optional[HeaderFactory]: - try: - token_source = DatabricksCliTokenSource(cfg) - except FileNotFoundError as e: - logger.debug(e) - return None - - try: - token_source.token() - except IOError as e: - if 'databricks OAuth is not' in str(e): - logger.debug(f'OAuth not configured or not available: {e}') - return None - raise e - - logger.info("Using Databricks CLI authentication") - - def inner() -> Dict[str, str]: - token = token_source.token() - return {'Authorization': f'{token.token_type} {token.access_token}'} - - return inner - - -class MetadataServiceTokenSource(Refreshable): - """ Obtain the token granted by Databricks Metadata Service """ - METADATA_SERVICE_VERSION = "1" - METADATA_SERVICE_VERSION_HEADER = "X-Databricks-Metadata-Version" - METADATA_SERVICE_HOST_HEADER = "X-Databricks-Host" - _metadata_service_timeout = 10 # seconds - - def __init__(self, cfg: 'Config'): - super().__init__() - self.url = cfg.metadata_service_url - self.host = cfg.host - - def refresh(self) -> Token: - resp = requests.get(self.url, - timeout=self._metadata_service_timeout, - headers={ - self.METADATA_SERVICE_VERSION_HEADER: self.METADATA_SERVICE_VERSION, - self.METADATA_SERVICE_HOST_HEADER: self.host - }) - json_resp: dict[str, Union[str, float]] = resp.json() - access_token = json_resp.get("access_token", None) - if access_token is None: - raise ValueError("Metadata Service returned empty token") - token_type = json_resp.get("token_type", None) - if token_type is None: - raise ValueError("Metadata Service returned empty token type") - if json_resp["expires_on"] in ["", None]: - raise ValueError("Metadata Service returned invalid expiry") - try: - expiry = datetime.fromtimestamp(json_resp["expires_on"]) - except: - raise ValueError("Metadata Service returned invalid expiry") - - return Token(access_token=access_token, token_type=token_type, expiry=expiry) - - 
-@credentials_provider('metadata-service', ['host', 'metadata_service_url']) -def metadata_service(cfg: 'Config') -> Optional[HeaderFactory]: - """ Adds refreshed token granted by Databricks Metadata Service to every request. """ - - token_source = MetadataServiceTokenSource(cfg) - token_source.token() - logger.info("Using Databricks Metadata Service authentication") - - def inner() -> Dict[str, str]: - token = token_source.token() - return {'Authorization': f'{token.token_type} {token.access_token}'} - - return inner - - -class DefaultCredentials: - """ Select the first applicable credential provider from the chain """ - - def __init__(self) -> None: - self._auth_type = 'default' - - def auth_type(self) -> str: - return self._auth_type - - def __call__(self, cfg: 'Config') -> HeaderFactory: - auth_providers = [ - pat_auth, basic_auth, metadata_service, oauth_service_principal, azure_service_principal, - github_oidc_azure, azure_cli, external_browser, databricks_cli, runtime_native_auth, - google_credentials, google_id - ] - for provider in auth_providers: - auth_type = provider.auth_type() - if cfg.auth_type and auth_type != cfg.auth_type: - # ignore other auth types if one is explicitly enforced - logger.debug(f"Ignoring {auth_type} auth, because {cfg.auth_type} is preferred") - continue - logger.debug(f'Attempting to configure auth: {auth_type}') - try: - header_factory = provider(cfg) - if not header_factory: - continue - self._auth_type = auth_type - return header_factory - except Exception as e: - raise ValueError(f'{auth_type}: {e}') from e - auth_flow_url = "https://docs.databricks.com/en/dev-tools/auth.html#databricks-client-unified-authentication" - raise ValueError( - f'cannot configure default credentials, please check {auth_flow_url} to configure credentials for your preferred authentication method.' - ) - - -class ConfigAttribute: - """ Configuration attribute metadata and descriptor protocols. 
""" - - # name and transform are discovered from Config.__new__ - name: str = None - transform: type = str - - def __init__(self, env: str = None, auth: str = None, sensitive: bool = False): - self.env = env - self.auth = auth - self.sensitive = sensitive - - def __get__(self, cfg: 'Config', owner): - if not cfg: - return None - return cfg._inner.get(self.name, None) - - def __set__(self, cfg: 'Config', value: any): - cfg._inner[self.name] = self.transform(value) - - def __repr__(self) -> str: - return f"" - - -class Config: - host: str = ConfigAttribute(env='DATABRICKS_HOST') - account_id: str = ConfigAttribute(env='DATABRICKS_ACCOUNT_ID') - token: str = ConfigAttribute(env='DATABRICKS_TOKEN', auth='pat', sensitive=True) - username: str = ConfigAttribute(env='DATABRICKS_USERNAME', auth='basic') - password: str = ConfigAttribute(env='DATABRICKS_PASSWORD', auth='basic', sensitive=True) - client_id: str = ConfigAttribute(env='DATABRICKS_CLIENT_ID', auth='oauth') - client_secret: str = ConfigAttribute(env='DATABRICKS_CLIENT_SECRET', auth='oauth', sensitive=True) - profile: str = ConfigAttribute(env='DATABRICKS_CONFIG_PROFILE') - config_file: str = ConfigAttribute(env='DATABRICKS_CONFIG_FILE') - google_service_account: str = ConfigAttribute(env='DATABRICKS_GOOGLE_SERVICE_ACCOUNT', auth='google') - google_credentials: str = ConfigAttribute(env='GOOGLE_CREDENTIALS', auth='google', sensitive=True) - azure_workspace_resource_id: str = ConfigAttribute(env='DATABRICKS_AZURE_RESOURCE_ID', auth='azure') - azure_use_msi: bool = ConfigAttribute(env='ARM_USE_MSI', auth='azure') - azure_client_secret: str = ConfigAttribute(env='ARM_CLIENT_SECRET', auth='azure', sensitive=True) - azure_client_id: str = ConfigAttribute(env='ARM_CLIENT_ID', auth='azure') - azure_tenant_id: str = ConfigAttribute(env='ARM_TENANT_ID', auth='azure') - azure_environment: str = ConfigAttribute(env='ARM_ENVIRONMENT') - azure_login_app_id: str = ConfigAttribute(env='DATABRICKS_AZURE_LOGIN_APP_ID', auth='azure') - databricks_cli_path: str = ConfigAttribute(env='DATABRICKS_CLI_PATH') - auth_type: str = ConfigAttribute(env='DATABRICKS_AUTH_TYPE') - cluster_id: str = ConfigAttribute(env='DATABRICKS_CLUSTER_ID') - warehouse_id: str = ConfigAttribute(env='DATABRICKS_WAREHOUSE_ID') - skip_verify: bool = ConfigAttribute() - http_timeout_seconds: float = ConfigAttribute() - debug_truncate_bytes: int = ConfigAttribute(env='DATABRICKS_DEBUG_TRUNCATE_BYTES') - debug_headers: bool = ConfigAttribute(env='DATABRICKS_DEBUG_HEADERS') - rate_limit: int = ConfigAttribute(env='DATABRICKS_RATE_LIMIT') - retry_timeout_seconds: int = ConfigAttribute() - metadata_service_url = ConfigAttribute(env='DATABRICKS_METADATA_SERVICE_URL', - auth='metadata-service', - sensitive=True) - max_connection_pools: int = ConfigAttribute() - max_connections_per_pool: int = ConfigAttribute() - - def __init__(self, - *, - credentials_provider: CredentialsProvider = None, - product="unknown", - product_version="0.0.0", - **kwargs): - self._inner = {} - self._user_agent_other_info = [] - self._credentials_provider = credentials_provider if credentials_provider else DefaultCredentials() - try: - self._set_inner_config(kwargs) - self._load_from_env() - self._known_file_config_loader() - self._fix_host_if_needed() - self._validate() - self._init_auth() - self._product = product - self._product_version = product_version - except ValueError as e: - message = self.wrap_debug_info(str(e)) - raise ValueError(message) from e - - def wrap_debug_info(self, message: str) -> str: - 
debug_string = self.debug_string() - if debug_string: - message = f'{message.rstrip(".")}. {debug_string}' - return message - - @staticmethod - def parse_dsn(dsn: str) -> 'Config': - uri = urllib.parse.urlparse(dsn) - if uri.scheme != 'databricks': - raise ValueError(f'Expected databricks:// scheme, got {uri.scheme}://') - kwargs = {'host': f'https://{uri.hostname}'} - if uri.username: - kwargs['username'] = uri.username - if uri.password: - kwargs['password'] = uri.password - query = dict(urllib.parse.parse_qsl(uri.query)) - for attr in Config.attributes(): - if attr.name not in query: - continue - kwargs[attr.name] = query[attr.name] - return Config(**kwargs) - - def authenticate(self) -> Dict[str, str]: - """ Returns a list of fresh authentication headers """ - return self._header_factory() - - def as_dict(self) -> dict: - return self._inner - - @property - def is_azure(self) -> bool: - has_resource_id = self.azure_workspace_resource_id is not None - has_host = self.host is not None - is_public_cloud = has_host and ".azuredatabricks.net" in self.host - is_china_cloud = has_host and ".databricks.azure.cn" in self.host - is_gov_cloud = has_host and ".databricks.azure.us" in self.host - is_valid_cloud = is_public_cloud or is_china_cloud or is_gov_cloud - return has_resource_id or (has_host and is_valid_cloud) - - @property - def is_gcp(self) -> bool: - return self.host and ".gcp.databricks.com" in self.host - - @property - def is_aws(self) -> bool: - return not self.is_azure and not self.is_gcp - - @property - def is_account_client(self) -> bool: - if not self.host: - return False - return self.host.startswith("https://accounts.") or self.host.startswith("https://accounts-dod.") - - @property - def arm_environment(self) -> AzureEnvironment: - env = self.azure_environment if self.azure_environment else "PUBLIC" - try: - return ENVIRONMENTS[env] - except KeyError: - raise ValueError(f"Cannot find Azure {env} Environment") - - @property - def effective_azure_login_app_id(self): - app_id = self.azure_login_app_id - if app_id: - return app_id - return ARM_DATABRICKS_RESOURCE_ID - - @property - def hostname(self) -> str: - url = urllib.parse.urlparse(self.host) - return url.netloc - - @property - def is_any_auth_configured(self) -> bool: - for attr in Config.attributes(): - if not attr.auth: - continue - value = self._inner.get(attr.name, None) - if value: - return True - return False - - @property - def user_agent(self): - """ Returns User-Agent header used by this SDK """ - py_version = platform.python_version() - os_name = platform.uname().system.lower() - - ua = [ - f"{self._product}/{self._product_version}", f"databricks-sdk-py/{__version__}", - f"python/{py_version}", f"os/{os_name}", f"auth/{self.auth_type}", - ] - if len(self._user_agent_other_info) > 0: - ua.append(' '.join(self._user_agent_other_info)) - if len(self._upstream_user_agent) > 0: - ua.append(self._upstream_user_agent) - if 'DATABRICKS_RUNTIME_VERSION' in os.environ: - runtime_version = os.environ['DATABRICKS_RUNTIME_VERSION'] - if runtime_version != '': - runtime_version = self._sanitize_header_value(runtime_version) - ua.append(f'runtime/{runtime_version}') - - return ' '.join(ua) - - @staticmethod - def _sanitize_header_value(value: str) -> str: - value = value.replace(' ', '-') - value = value.replace('/', '-') - return value - - @property - def _upstream_user_agent(self) -> str: - product = os.environ.get('DATABRICKS_SDK_UPSTREAM', None) - product_version = os.environ.get('DATABRICKS_SDK_UPSTREAM_VERSION', None) - if 
product is not None and product_version is not None: - return f"upstream/{product} upstream-version/{product_version}" - return "" - - def with_user_agent_extra(self, key: str, value: str) -> 'Config': - self._user_agent_other_info.append(f"{key}/{value}") - return self - - @property - def oidc_endpoints(self) -> Optional[OidcEndpoints]: - self._fix_host_if_needed() - if not self.host: - return None - if self.is_azure and self.azure_client_id: - # Retrieve authorize endpoint to retrieve token endpoint after - res = requests.get(f'{self.host}/oidc/oauth2/v2.0/authorize', allow_redirects=False) - real_auth_url = res.headers.get('location') - if not real_auth_url: - return None - return OidcEndpoints(authorization_endpoint=real_auth_url, - token_endpoint=real_auth_url.replace('/authorize', '/token')) - if self.is_account_client and self.account_id: - prefix = f'{self.host}/oidc/accounts/{self.account_id}' - return OidcEndpoints(authorization_endpoint=f'{prefix}/v1/authorize', - token_endpoint=f'{prefix}/v1/token') - oidc = f'{self.host}/oidc/.well-known/oauth-authorization-server' - res = requests.get(oidc) - if res.status_code != 200: - return None - auth_metadata = res.json() - return OidcEndpoints(authorization_endpoint=auth_metadata.get('authorization_endpoint'), - token_endpoint=auth_metadata.get('token_endpoint')) - - def debug_string(self) -> str: - """ Returns log-friendly representation of configured attributes """ - buf = [] - attrs_used = [] - envs_used = [] - for attr in Config.attributes(): - if attr.env and os.environ.get(attr.env): - envs_used.append(attr.env) - value = getattr(self, attr.name) - if not value: - continue - safe = '***' if attr.sensitive else f'{value}' - attrs_used.append(f'{attr.name}={safe}') - if attrs_used: - buf.append(f'Config: {", ".join(attrs_used)}') - if envs_used: - buf.append(f'Env: {", ".join(envs_used)}') - return '. '.join(buf) - - def to_dict(self) -> Dict[str, any]: - return self._inner - - @property - def sql_http_path(self) -> Optional[str]: - """(Experimental) Return HTTP path for SQL Drivers. - - If `cluster_id` or `warehouse_id` are configured, return a valid HTTP Path argument - used in construction of JDBC/ODBC DSN string. 
- - See https://docs.databricks.com/integrations/jdbc-odbc-bi.html - """ - if (not self.cluster_id) and (not self.warehouse_id): - return None - if self.cluster_id and self.warehouse_id: - raise ValueError('cannot have both cluster_id and warehouse_id') - headers = self.authenticate() - headers['User-Agent'] = f'{self.user_agent} sdk-feature/sql-http-path' - if self.cluster_id: - response = requests.get(f"{self.host}/api/2.0/preview/scim/v2/Me", headers=headers) - # get workspace ID from the response header - workspace_id = response.headers.get('x-databricks-org-id') - return f'sql/protocolv1/o/{workspace_id}/{self.cluster_id}' - if self.warehouse_id: - return f'/sql/1.0/warehouses/{self.warehouse_id}' - - @classmethod - def attributes(cls) -> Iterable[ConfigAttribute]: - """ Returns a list of Databricks SDK configuration metadata """ - if hasattr(cls, '_attributes'): - return cls._attributes - if sys.version_info[1] >= 10: - import inspect - anno = inspect.get_annotations(cls) - else: - # Python 3.7 compatibility: getting type hints require extra hop, as described in - # "Accessing The Annotations Dict Of An Object In Python 3.9 And Older" section of - # https://docs.python.org/3/howto/annotations.html - anno = cls.__dict__['__annotations__'] - attrs = [] - for name, v in cls.__dict__.items(): - if type(v) != ConfigAttribute: - continue - v.name = name - v.transform = anno.get(name, str) - attrs.append(v) - cls._attributes = attrs - return cls._attributes - - def _fix_host_if_needed(self): - if not self.host: - return - # fix url to remove trailing slash - o = urllib.parse.urlparse(self.host) - if not o.hostname: - # only hostname is specified - self.host = f"https://{self.host}" - else: - self.host = f"{o.scheme}://{o.netloc}" - - def _set_inner_config(self, keyword_args: Dict[str, any]): - for attr in self.attributes(): - if attr.name not in keyword_args: - continue - if keyword_args.get(attr.name, None) is None: - continue - self.__setattr__(attr.name, keyword_args[attr.name]) - - def _load_from_env(self): - found = False - for attr in self.attributes(): - if not attr.env: - continue - if attr.name in self._inner: - continue - value = os.environ.get(attr.env) - if not value: - continue - self.__setattr__(attr.name, value) - found = True - if found: - logger.debug('Loaded from environment') - - def _known_file_config_loader(self): - if not self.profile and (self.is_any_auth_configured or self.host - or self.azure_workspace_resource_id): - # skip loading configuration file if there's any auth configured - # directly as part of the Config() constructor. - return - config_file = self.config_file - if not config_file: - config_file = "~/.databrickscfg" - config_path = pathlib.Path(config_file).expanduser() - if not config_path.exists(): - logger.debug("%s does not exist", config_path) - return - ini_file = configparser.ConfigParser() - ini_file.read(config_path) - profile = self.profile - has_explicit_profile = self.profile is not None - # In Go SDK, we skip merging the profile with DEFAULT section, though Python's ConfigParser.items() - # is returning profile key-value pairs _including those from DEFAULT_. This is not what we expect - # from Unified Auth test suite at the moment. Hence, the private variable access. 
- # See: https://docs.python.org/3/library/configparser.html#mapping-protocol-access - if not has_explicit_profile and not ini_file.defaults(): - logger.debug(f'{config_path} has no DEFAULT profile configured') - return - if not has_explicit_profile: - profile = "DEFAULT" - profiles = ini_file._sections - if ini_file.defaults(): - profiles['DEFAULT'] = ini_file.defaults() - if profile not in profiles: - raise ValueError(f'resolve: {config_path} has no {profile} profile configured') - raw_config = profiles[profile] - logger.info(f'loading {profile} profile from {config_file}: {", ".join(raw_config.keys())}') - for k, v in raw_config.items(): - if k in self._inner: - # don't overwrite a value previously set - continue - self.__setattr__(k, v) - - def _validate(self): - auths_used = set() - for attr in Config.attributes(): - if attr.name not in self._inner: - continue - if not attr.auth: - continue - auths_used.add(attr.auth) - if len(auths_used) <= 1: - return - if self.auth_type: - # client has auth preference set - return - names = " and ".join(sorted(auths_used)) - raise ValueError(f'validate: more than one authorization method configured: {names}') - - def _init_auth(self): - try: - self._header_factory = self._credentials_provider(self) - self.auth_type = self._credentials_provider.auth_type() - if not self._header_factory: - raise ValueError('not configured') - except ValueError as e: - raise ValueError(f'{self._credentials_provider.auth_type()} auth: {e}') from e - - def __repr__(self): - return f'<{self.debug_string()}>' - - def copy(self): - """Creates a copy of the config object. - All the copies share most of their internal state (ie, shared reference to fields such as credential_provider). - Copies have their own instances of the following fields - - `_user_agent_other_info` - """ - cpy: Config = copy.copy(self) - cpy._user_agent_other_info = copy.deepcopy(self._user_agent_other_info) - return cpy - class ApiClient: _cfg: Config diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py new file mode 100644 index 000000000..2c30ea143 --- /dev/null +++ b/databricks/sdk/credentials_provider.py @@ -0,0 +1,617 @@ +import abc +import base64 +import functools +import io +import json +import logging +import os +import pathlib +import subprocess +import sys +from datetime import datetime +from typing import Callable, Dict, List, Optional, Union + +import google.auth +import requests +from google.auth import impersonated_credentials +from google.auth.transport.requests import Request +from google.oauth2 import service_account + +from .azure import add_sp_management_token, add_workspace_id_header +from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token, + TokenCache, TokenSource) + +HeaderFactory = Callable[[], Dict[str, str]] + +logger = logging.getLogger('databricks.sdk') + + +class CredentialsProvider(abc.ABC): + """ CredentialsProvider is the protocol (call-side interface) + for authenticating requests to Databricks REST APIs""" + + @abc.abstractmethod + def auth_type(self) -> str: + ... + + @abc.abstractmethod + def __call__(self, cfg: 'Config') -> HeaderFactory: + ... + + +def credentials_provider(name: str, require: List[str]): + """ Given the function that receives a Config and returns RequestVisitor, + create CredentialsProvider with a given name and required configuration + attribute names to be present for this function to be called. 
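+
+    If any of the required attributes are missing or empty on the incoming Config,
+    the wrapped function is skipped (the wrapper returns None) and the next
+    provider in the chain gets a chance to authenticate.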
""" + + def inner(func: Callable[['Config'], HeaderFactory]) -> CredentialsProvider: + + @functools.wraps(func) + def wrapper(cfg: 'Config') -> Optional[HeaderFactory]: + for attr in require: + if not getattr(cfg, attr): + return None + return func(cfg) + + wrapper.auth_type = lambda: name + return wrapper + + return inner + + +@credentials_provider('basic', ['host', 'username', 'password']) +def basic_auth(cfg: 'Config') -> HeaderFactory: + """ Given username and password, add base64-encoded Basic credentials """ + encoded = base64.b64encode(f'{cfg.username}:{cfg.password}'.encode()).decode() + static_credentials = {'Authorization': f'Basic {encoded}'} + + def inner() -> Dict[str, str]: + return static_credentials + + return inner + + +@credentials_provider('pat', ['host', 'token']) +def pat_auth(cfg: 'Config') -> HeaderFactory: + """ Adds Databricks Personal Access Token to every request """ + static_credentials = {'Authorization': f'Bearer {cfg.token}'} + + def inner() -> Dict[str, str]: + return static_credentials + + return inner + + +@credentials_provider('runtime', []) +def runtime_native_auth(cfg: 'Config') -> Optional[HeaderFactory]: + if 'DATABRICKS_RUNTIME_VERSION' not in os.environ: + return None + + # This import MUST be after the "DATABRICKS_RUNTIME_VERSION" check + # above, so that we are not throwing import errors when not in + # runtime and no config variables are set. + from databricks.sdk.runtime import (init_runtime_legacy_auth, + init_runtime_native_auth, + init_runtime_repl_auth) + for init in [init_runtime_native_auth, init_runtime_repl_auth, init_runtime_legacy_auth]: + if init is None: + continue + host, inner = init() + if host is None: + logger.debug(f'[{init.__name__}] no host detected') + continue + cfg.host = host + logger.debug(f'[{init.__name__}] runtime native auth configured') + return inner + return None + + +@credentials_provider('oauth-m2m', ['host', 'client_id', 'client_secret']) +def oauth_service_principal(cfg: 'Config') -> Optional[HeaderFactory]: + """ Adds refreshed Databricks machine-to-machine OAuth Bearer token to every request, + if /oidc/.well-known/oauth-authorization-server is available on the given host. """ + oidc = cfg.oidc_endpoints + if oidc is None: + return None + token_source = ClientCredentials(client_id=cfg.client_id, + client_secret=cfg.client_secret, + token_url=oidc.token_endpoint, + scopes=["all-apis"], + use_header=True) + + def inner() -> Dict[str, str]: + token = token_source.token() + return {'Authorization': f'{token.token_type} {token.access_token}'} + + return inner + + +@credentials_provider('external-browser', ['host', 'auth_type']) +def external_browser(cfg: 'Config') -> Optional[HeaderFactory]: + if cfg.auth_type != 'external-browser': + return None + if cfg.client_id: + client_id = cfg.client_id + elif cfg.is_aws: + client_id = 'databricks-cli' + elif cfg.is_azure: + # Use Azure AD app for cases when Azure CLI is not available on the machine. + # App has to be registered as Single-page multi-tenant to support PKCE + # TODO: temporary app ID, change it later. + client_id = '6128a518-99a9-425b-8333-4cc94f04cacd' + else: + raise ValueError(f'local browser SSO is not supported') + oauth_client = OAuthClient(host=cfg.host, + client_id=client_id, + redirect_url='http://localhost:8020', + client_secret=cfg.client_secret) + + # Load cached credentials from disk if they exist. + # Note that these are local to the Python SDK and not reused by other SDKs. 
+ token_cache = TokenCache(oauth_client) + credentials = token_cache.load() + if credentials: + # Force a refresh in case the loaded credentials are expired. + credentials.token() + else: + consent = oauth_client.initiate_consent() + if not consent: + return None + credentials = consent.launch_external_browser() + token_cache.save(credentials) + return credentials(cfg) + + +def _ensure_host_present(cfg: 'Config', token_source_for: Callable[[str], TokenSource]): + """ Resolves Azure Databricks workspace URL from ARM Resource ID """ + if cfg.host: + return + if not cfg.azure_workspace_resource_id: + return + arm = cfg.arm_environment.resource_manager_endpoint + token = token_source_for(arm).token() + resp = requests.get(f"{arm}{cfg.azure_workspace_resource_id}?api-version=2018-04-01", + headers={"Authorization": f"Bearer {token.access_token}"}) + if not resp.ok: + raise ValueError(f"Cannot resolve Azure Databricks workspace: {resp.content}") + cfg.host = f"https://{resp.json()['properties']['workspaceUrl']}" + + +@credentials_provider('azure-client-secret', + ['is_azure', 'azure_client_id', 'azure_client_secret', 'azure_tenant_id']) +def azure_service_principal(cfg: 'Config') -> HeaderFactory: + """ Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens + to every request, while automatically resolving different Azure environment endpoints. """ + + def token_source_for(resource: str) -> TokenSource: + aad_endpoint = cfg.arm_environment.active_directory_endpoint + return ClientCredentials(client_id=cfg.azure_client_id, + client_secret=cfg.azure_client_secret, + token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token", + endpoint_params={"resource": resource}, + use_params=True) + + _ensure_host_present(cfg, token_source_for) + logger.info("Configured AAD token for Service Principal (%s)", cfg.azure_client_id) + inner = token_source_for(cfg.effective_azure_login_app_id) + cloud = token_source_for(cfg.arm_environment.service_management_endpoint) + + def refreshed_headers() -> Dict[str, str]: + headers = {'Authorization': f"Bearer {inner.token().access_token}", } + add_workspace_id_header(cfg, headers) + add_sp_management_token(cloud, headers) + return headers + + return refreshed_headers + + +@credentials_provider('github-oidc-azure', ['host', 'azure_client_id']) +def github_oidc_azure(cfg: 'Config') -> Optional[HeaderFactory]: + if 'ACTIONS_ID_TOKEN_REQUEST_TOKEN' not in os.environ: + # not in GitHub actions + return None + + # Client ID is the minimal thing we need, as otherwise we get AADSTS700016: Application with + # identifier 'https://token.actions.githubusercontent.com' was not found in the directory '...'. 
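+    # Like the other Azure providers, skip non-Azure hosts so that the
+    # remaining providers in the default chain can be tried.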
+ if not cfg.is_azure: + return None + + # See https://docs.github.com/en/actions/deployment/security-hardening-your-deployments/configuring-openid-connect-in-cloud-providers + headers = {'Authorization': f"Bearer {os.environ['ACTIONS_ID_TOKEN_REQUEST_TOKEN']}"} + endpoint = f"{os.environ['ACTIONS_ID_TOKEN_REQUEST_URL']}&audience=api://AzureADTokenExchange" + response = requests.get(endpoint, headers=headers) + if not response.ok: + return None + + # get the ID Token with aud=api://AzureADTokenExchange sub=repo:org/repo:environment:name + response_json = response.json() + if 'value' not in response_json: + return None + + logger.info("Configured AAD token for GitHub Actions OIDC (%s)", cfg.azure_client_id) + params = { + 'client_assertion_type': 'urn:ietf:params:oauth:client-assertion-type:jwt-bearer', + 'resource': cfg.effective_azure_login_app_id, + 'client_assertion': response_json['value'], + } + aad_endpoint = cfg.arm_environment.active_directory_endpoint + if not cfg.azure_tenant_id: + # detect Azure AD Tenant ID if it's not specified directly + token_endpoint = cfg.oidc_endpoints.token_endpoint + cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, '').split('/')[0] + inner = ClientCredentials(client_id=cfg.azure_client_id, + client_secret="", # we have no (rotatable) secrets in OIDC flow + token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token", + endpoint_params=params, + use_params=True) + + def refreshed_headers() -> Dict[str, str]: + token = inner.token() + return {'Authorization': f'{token.token_type} {token.access_token}'} + + return refreshed_headers + + +GcpScopes = ["https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/compute"] + + +@credentials_provider('google-credentials', ['host', 'google_credentials']) +def google_credentials(cfg: 'Config') -> Optional[HeaderFactory]: + if not cfg.is_gcp: + return None + # Reads credentials as JSON. Credentials can be either a path to JSON file, or actual JSON string. + # Obtain the id token by providing the json file path and target audience. + if (os.path.isfile(cfg.google_credentials)): + with io.open(cfg.google_credentials, "r", encoding="utf-8") as json_file: + account_info = json.load(json_file) + else: + # If the file doesn't exist, assume that the config is the actual JSON content. + account_info = json.loads(cfg.google_credentials) + + credentials = service_account.IDTokenCredentials.from_service_account_info(info=account_info, + target_audience=cfg.host) + + request = Request() + + gcp_credentials = service_account.Credentials.from_service_account_info(info=account_info, + scopes=GcpScopes) + + def refreshed_headers() -> Dict[str, str]: + credentials.refresh(request) + headers = {'Authorization': f'Bearer {credentials.token}'} + if cfg.is_account_client: + gcp_credentials.refresh(request) + headers["X-Databricks-GCP-SA-Access-Token"] = gcp_credentials.token + return headers + + return refreshed_headers + + +@credentials_provider('google-id', ['host', 'google_service_account']) +def google_id(cfg: 'Config') -> Optional[HeaderFactory]: + if not cfg.is_gcp: + return None + credentials, _project_id = google.auth.default() + + # Create the impersonated credential. + target_credentials = impersonated_credentials.Credentials(source_credentials=credentials, + target_principal=cfg.google_service_account, + target_scopes=[]) + + # Set the impersonated credential, target audience and token options. 
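+    # The ID token (audience = workspace host) authenticates to Databricks itself;
+    # the access-token credentials created below are refreshed only for
+    # account-level clients, which also expect a Google SA access token in a
+    # dedicated header.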
+    id_creds = impersonated_credentials.IDTokenCredentials(target_credentials,
+                                                           target_audience=cfg.host,
+                                                           include_email=True)
+
+    gcp_impersonated_credentials = impersonated_credentials.Credentials(
+        source_credentials=credentials, target_principal=cfg.google_service_account, target_scopes=GcpScopes)
+
+    request = Request()
+
+    def refreshed_headers() -> Dict[str, str]:
+        id_creds.refresh(request)
+        headers = {'Authorization': f'Bearer {id_creds.token}'}
+        if cfg.is_account_client:
+            gcp_impersonated_credentials.refresh(request)
+            headers["X-Databricks-GCP-SA-Access-Token"] = gcp_impersonated_credentials.token
+        return headers
+
+    return refreshed_headers
+
+
+class CliTokenSource(Refreshable):
+
+    def __init__(self, cmd: List[str], token_type_field: str, access_token_field: str, expiry_field: str):
+        super().__init__()
+        self._cmd = cmd
+        self._token_type_field = token_type_field
+        self._access_token_field = access_token_field
+        self._expiry_field = expiry_field
+
+    @staticmethod
+    def _parse_expiry(expiry: str) -> datetime:
+        expiry = expiry.rstrip("Z").split(".")[0]
+        last_e: Optional[ValueError] = None
+        for fmt in ("%Y-%m-%d %H:%M:%S", "%Y-%m-%dT%H:%M:%S"):
+            try:
+                return datetime.strptime(expiry, fmt)
+            except ValueError as e:
+                last_e = e
+        if last_e:
+            raise last_e
+
+    def refresh(self) -> Token:
+        try:
+            is_windows = sys.platform.startswith('win')
+            # Windows requires shell=True to be able to execute 'az login' or other commands;
+            # we cannot use shell=True everywhere, as it breaks macOS.
+            out = subprocess.run(self._cmd, capture_output=True, check=True, shell=is_windows)
+            it = json.loads(out.stdout.decode())
+            expires_on = self._parse_expiry(it[self._expiry_field])
+            return Token(access_token=it[self._access_token_field],
+                         token_type=it[self._token_type_field],
+                         expiry=expires_on)
+        except ValueError as e:
+            raise ValueError(f"cannot unmarshal CLI result: {e}") from e
+        except subprocess.CalledProcessError as e:
+            stdout = e.stdout.decode().strip()
+            stderr = e.stderr.decode().strip()
+            message = stdout or stderr
+            raise IOError(f'cannot get access token: {message}') from e
+
+
+class AzureCliTokenSource(CliTokenSource):
+    """ Obtain the token granted by `az login` CLI command """
+
+    def __init__(self, resource: str, subscription: str = ""):
+        cmd = ["az", "account", "get-access-token", "--resource", resource, "--output", "json"]
+        if subscription != "":
+            cmd.append("--subscription")
+            cmd.append(subscription)
+        super().__init__(cmd=cmd,
+                         token_type_field='tokenType',
+                         access_token_field='accessToken',
+                         expiry_field='expiresOn')
+
+    def is_human_user(self) -> bool:
+        """The 'upn' claim is set to the username for human users, but not for Service Principals.
+
+        Azure CLI can be authenticated by both human users (`az login`) and service principals,
+        which log in either via GitHub OIDC or with a password:
+
+            ~ $ az login --service-principal --user $clientID --password $clientSecret --tenant $tenantID
+
+        Human users get more claims:
+        - 'amr' - how the subject of the token was authenticated
+        - 'name', 'family_name', 'given_name' - human-readable values that identify the subject of the token
+        - 'scp' with the `user_impersonation` value, which shows the set of scopes exposed by your application
+          for which the client application has requested (and received) consent
+        - 'unique_name' - a human-readable value that identifies the subject of the token. This value is not
+          guaranteed to be unique within a tenant and should be used only for display purposes.
+        - 'upn' - the username of the user.
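+
+        The presence of the 'upn' claim is therefore used to tell a human user
+        apart from a service principal, which lacks it.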
+ """ + return 'upn' in self.token().jwt_claims() + + @staticmethod + def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource': + subscription = AzureCliTokenSource.get_subscription(cfg) + if subscription != "": + token_source = AzureCliTokenSource(resource, subscription) + try: + # This will fail if the user has access to the workspace, but not to the subscription + # itself. + # In such case, we fall back to not using the subscription. + token_source.token() + return token_source + except OSError: + logger.warning("Failed to get token for subscription. Using resource only token.") + + token_source = AzureCliTokenSource(resource) + token_source.token() + return token_source + + @staticmethod + def get_subscription(cfg: 'Config') -> str: + resource = cfg.azure_workspace_resource_id + if resource is None or resource == "": + return "" + components = resource.split('/') + if len(components) < 3: + logger.warning("Invalid azure workspace resource ID") + return "" + return components[2] + + +@credentials_provider('azure-cli', ['is_azure']) +def azure_cli(cfg: 'Config') -> Optional[HeaderFactory]: + """ Adds refreshed OAuth token granted by `az login` command to every request. """ + token_source = None + mgmt_token_source = None + try: + token_source = AzureCliTokenSource.for_resource(cfg, cfg.effective_azure_login_app_id) + except FileNotFoundError: + doc = 'https://docs.microsoft.com/en-us/cli/azure/?view=azure-cli-latest' + logger.debug(f'Most likely Azure CLI is not installed. See {doc} for details') + return None + except OSError as e: + logger.debug('skipping Azure CLI auth', exc_info=e) + logger.debug('This may happen if you are attempting to login to a dev or staging workspace') + return None + + if not token_source.is_human_user(): + try: + management_endpoint = cfg.arm_environment.service_management_endpoint + mgmt_token_source = AzureCliTokenSource.for_resource(cfg, management_endpoint) + except Exception as e: + logger.debug(f'Not including service management token in headers', exc_info=e) + mgmt_token_source = None + + _ensure_host_present(cfg, lambda resource: AzureCliTokenSource.for_resource(cfg, resource)) + logger.info("Using Azure CLI authentication with AAD tokens") + if not cfg.is_account_client and AzureCliTokenSource.get_subscription(cfg) == "": + logger.warning( + "azure_workspace_resource_id field not provided. " + "It is recommended to specify this field in the Databricks configuration to avoid authentication errors." + ) + + def inner() -> Dict[str, str]: + token = token_source.token() + headers = {'Authorization': f'{token.token_type} {token.access_token}'} + add_workspace_id_header(cfg, headers) + if mgmt_token_source: + add_sp_management_token(mgmt_token_source, headers) + return headers + + return inner + + +class DatabricksCliTokenSource(CliTokenSource): + """ Obtain the token granted by `databricks auth login` CLI command """ + + def __init__(self, cfg: 'Config'): + args = ['auth', 'token', '--host', cfg.host] + if cfg.is_account_client: + args += ['--account-id', cfg.account_id] + + cli_path = cfg.databricks_cli_path + if not cli_path: + cli_path = 'databricks' + + # If the path is unqualified, look it up in PATH. 
+        if cli_path.count("/") == 0:
+            cli_path = self.__class__._find_executable(cli_path)
+
+        super().__init__(cmd=[cli_path, *args],
+                         token_type_field='token_type',
+                         access_token_field='access_token',
+                         expiry_field='expiry')
+
+    @staticmethod
+    def _find_executable(name) -> str:
+        err = FileNotFoundError("Most likely the Databricks CLI is not installed")
+        for dir in os.getenv("PATH", default="").split(os.path.pathsep):
+            path = pathlib.Path(dir).joinpath(name).resolve()
+            if not path.is_file():
+                continue
+
+            # The new Databricks CLI is a single binary with size > 1MB.
+            # We use the size as a signal to determine which Databricks CLI is installed.
+            stat = path.stat()
+            if stat.st_size < (1024 * 1024):
+                err = FileNotFoundError("Databricks CLI version <0.100.0 detected")
+                continue
+
+            return str(path)
+
+        raise err
+
+
+@credentials_provider('databricks-cli', ['host', 'is_aws'])
+def databricks_cli(cfg: 'Config') -> Optional[HeaderFactory]:
+    try:
+        token_source = DatabricksCliTokenSource(cfg)
+    except FileNotFoundError as e:
+        logger.debug(e)
+        return None
+
+    try:
+        token_source.token()
+    except IOError as e:
+        if 'databricks OAuth is not' in str(e):
+            logger.debug(f'OAuth not configured or not available: {e}')
+            return None
+        raise
+
+    logger.info("Using Databricks CLI authentication")
+
+    def inner() -> Dict[str, str]:
+        token = token_source.token()
+        return {'Authorization': f'{token.token_type} {token.access_token}'}
+
+    return inner
+
+
+class MetadataServiceTokenSource(Refreshable):
+    """ Obtain the token granted by Databricks Metadata Service """
+    METADATA_SERVICE_VERSION = "1"
+    METADATA_SERVICE_VERSION_HEADER = "X-Databricks-Metadata-Version"
+    METADATA_SERVICE_HOST_HEADER = "X-Databricks-Host"
+    _metadata_service_timeout = 10  # seconds
+
+    def __init__(self, cfg: 'Config'):
+        super().__init__()
+        self.url = cfg.metadata_service_url
+        self.host = cfg.host
+
+    def refresh(self) -> Token:
+        resp = requests.get(self.url,
+                            timeout=self._metadata_service_timeout,
+                            headers={
+                                self.METADATA_SERVICE_VERSION_HEADER: self.METADATA_SERVICE_VERSION,
+                                self.METADATA_SERVICE_HOST_HEADER: self.host
+                            })
+        json_resp: Dict[str, Union[str, float]] = resp.json()
+        access_token = json_resp.get("access_token", None)
+        if access_token is None:
+            raise ValueError("Metadata Service returned empty token")
+        token_type = json_resp.get("token_type", None)
+        if token_type is None:
+            raise ValueError("Metadata Service returned empty token type")
+        expires_on = json_resp.get("expires_on", None)
+        if expires_on in ["", None]:
+            raise ValueError("Metadata Service returned invalid expiry")
+        try:
+            expiry = datetime.fromtimestamp(expires_on)
+        except (TypeError, ValueError, OverflowError, OSError):
+            raise ValueError("Metadata Service returned invalid expiry")
+
+        return Token(access_token=access_token, token_type=token_type, expiry=expiry)
+
+
+@credentials_provider('metadata-service', ['host', 'metadata_service_url'])
+def metadata_service(cfg: 'Config') -> Optional[HeaderFactory]:
+    """ Adds refreshed token granted by Databricks Metadata Service to every request.
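+
+    The service is expected to respond with a JSON object carrying 'access_token',
+    'token_type' and an 'expires_on' timestamp, as parsed by
+    MetadataServiceTokenSource.refresh above.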
""" + + token_source = MetadataServiceTokenSource(cfg) + token_source.token() + logger.info("Using Databricks Metadata Service authentication") + + def inner() -> Dict[str, str]: + token = token_source.token() + return {'Authorization': f'{token.token_type} {token.access_token}'} + + return inner + + +class DefaultCredentials: + """ Select the first applicable credential provider from the chain """ + + def __init__(self) -> None: + self._auth_type = 'default' + + def auth_type(self) -> str: + return self._auth_type + + def __call__(self, cfg: 'Config') -> HeaderFactory: + auth_providers = [ + pat_auth, basic_auth, metadata_service, oauth_service_principal, azure_service_principal, + github_oidc_azure, azure_cli, external_browser, databricks_cli, runtime_native_auth, + google_credentials, google_id + ] + for provider in auth_providers: + auth_type = provider.auth_type() + if cfg.auth_type and auth_type != cfg.auth_type: + # ignore other auth types if one is explicitly enforced + logger.debug(f"Ignoring {auth_type} auth, because {cfg.auth_type} is preferred") + continue + logger.debug(f'Attempting to configure auth: {auth_type}') + try: + header_factory = provider(cfg) + if not header_factory: + continue + self._auth_type = auth_type + return header_factory + except Exception as e: + raise ValueError(f'{auth_type}: {e}') from e + auth_flow_url = "https://docs.databricks.com/en/dev-tools/auth.html#databricks-client-unified-authentication" + raise ValueError( + f'cannot configure default credentials, please check {auth_flow_url} to configure credentials for your preferred authentication method.' + ) diff --git a/databricks/sdk/environments.py b/databricks/sdk/environments.py new file mode 100644 index 000000000..ee41f0f96 --- /dev/null +++ b/databricks/sdk/environments.py @@ -0,0 +1,72 @@ +from dataclasses import dataclass +from enum import Enum +from typing import Optional + +from .azure import ARM_DATABRICKS_RESOURCE_ID, ENVIRONMENTS, AzureEnvironment + + +class Cloud(Enum): + AWS = "AWS" + AZURE = "AZURE" + GCP = "GCP" + + +@dataclass +class DatabricksEnvironment: + cloud: Cloud + dns_zone: str + azure_application_id: Optional[str] = None + azure_environment: Optional[AzureEnvironment] = None + + def deployment_url(self, name: str) -> str: + return f"https://{name}.{self.dns_zone}" + + @property + def azure_service_management_endpoint(self) -> Optional[str]: + if self.azure_environment is None: + return None + return self.azure_environment.service_management_endpoint + + @property + def azure_resource_manager_endpoint(self) -> Optional[str]: + if self.azure_environment is None: + return None + return self.azure_environment.resource_manager_endpoint + + @property + def azure_active_directory_endpoint(self) -> Optional[str]: + if self.azure_environment is None: + return None + return self.azure_environment.active_directory_endpoint + + +DEFAULT_ENVIRONMENT = DatabricksEnvironment(Cloud.AWS, ".cloud.databricks.com") + +ALL_ENVS = [ + DatabricksEnvironment(Cloud.AWS, ".dev.databricks.com"), + DatabricksEnvironment(Cloud.AWS, ".staging.cloud.databricks.com"), + DatabricksEnvironment(Cloud.AWS, ".cloud.databricks.us"), DEFAULT_ENVIRONMENT, + DatabricksEnvironment(Cloud.AZURE, + ".dev.azuredatabricks.net", + azure_application_id="62a912ac-b58e-4c1d-89ea-b2dbfc7358fc", + azure_environment=ENVIRONMENTS["PUBLIC"]), + DatabricksEnvironment(Cloud.AZURE, + ".staging.azuredatabricks.net", + azure_application_id="4a67d088-db5c-48f1-9ff2-0aace800ae68", + azure_environment=ENVIRONMENTS["PUBLIC"]), + 
DatabricksEnvironment(Cloud.AZURE, + ".azuredatabricks.net", + azure_application_id=ARM_DATABRICKS_RESOURCE_ID, + azure_environment=ENVIRONMENTS["PUBLIC"]), + DatabricksEnvironment(Cloud.AZURE, + ".databricks.azure.us", + azure_application_id=ARM_DATABRICKS_RESOURCE_ID, + azure_environment=ENVIRONMENTS["USGOVERNMENT"]), + DatabricksEnvironment(Cloud.AZURE, + ".databricks.azure.cn", + azure_application_id=ARM_DATABRICKS_RESOURCE_ID, + azure_environment=ENVIRONMENTS["CHINA"]), + DatabricksEnvironment(Cloud.GCP, ".dev.gcp.databricks.com"), + DatabricksEnvironment(Cloud.GCP, ".staging.gcp.databricks.com"), + DatabricksEnvironment(Cloud.GCP, ".gcp.databricks.com") +] diff --git a/databricks/sdk/oauth.py b/databricks/sdk/oauth.py index 9a3061e1d..68b88f003 100644 --- a/databricks/sdk/oauth.py +++ b/databricks/sdk/oauth.py @@ -357,7 +357,8 @@ def __init__(self, scopes: List[str] = None, client_secret: str = None): # TODO: is it a circular dependency?.. - from .core import Config, credentials_provider + from .core import Config + from .credentials_provider import credentials_provider @credentials_provider('noop', []) def noop_credentials(_: any): diff --git a/tests/conftest.py b/tests/conftest.py index 748bd6794..80753ae95 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,8 @@ import pytest as pytest from pyfakefs.fake_filesystem_unittest import Patcher -from databricks.sdk.core import Config, credentials_provider +from databricks.sdk.core import Config +from databricks.sdk.credentials_provider import credentials_provider @credentials_provider('noop', []) diff --git a/tests/test_auth.py b/tests/test_auth.py index f52c66390..504e14439 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -204,7 +204,7 @@ def test_config_azure_cli_host(monkeypatch): @raises( - "default auth: azure-cli: cannot get access token: This is just a failing script. Config: azure_workspace_resource_id=/sub/rg/ws" + "default auth: cannot configure default credentials, please check https://docs.databricks.com/en/dev-tools/auth.html#databricks-client-unified-authentication to configure credentials for your preferred authentication method. 
Config: azure_workspace_resource_id=/sub/rg/ws" ) def test_config_azure_cli_host_fail(monkeypatch): monkeypatch.setenv('FAIL', 'yes') diff --git a/tests/test_core.py b/tests/test_core.py index d7e2c8f41..ca2eaac31 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -15,10 +15,13 @@ from databricks.sdk import WorkspaceClient from databricks.sdk.azure import ENVIRONMENTS, AzureEnvironment -from databricks.sdk.core import (ApiClient, CliTokenSource, Config, - CredentialsProvider, DatabricksCliTokenSource, - DatabricksError, HeaderFactory, - StreamingResponse, databricks_cli) +from databricks.sdk.core import (ApiClient, Config, DatabricksError, + StreamingResponse) +from databricks.sdk.credentials_provider import (CliTokenSource, + CredentialsProvider, + DatabricksCliTokenSource, + HeaderFactory, databricks_cli) +from databricks.sdk.environments import Cloud, DatabricksEnvironment from databricks.sdk.service.catalog import PermissionsChange from databricks.sdk.service.iam import AccessControlRequest from databricks.sdk.version import __version__ @@ -282,9 +285,10 @@ class DummyResponse(requests.Response): _closed: bool = False def __init__(self, content: List[bytes]) -> None: + super().__init__() self._content = iter(content) - def iter_content(self, chunk_size: int = 1) -> Iterator[bytes]: + def iter_content(self, chunk_size: int = 1, decode_unicode=False) -> Iterator[bytes]: return self._content def close(self): @@ -546,14 +550,33 @@ def inner(h: BaseHTTPRequestHandler): with http_fixture_server(inner) as host: monkeypatch.setenv('ACTIONS_ID_TOKEN_REQUEST_URL', f'{host}/oidc') monkeypatch.setenv('ACTIONS_ID_TOKEN_REQUEST_TOKEN', 'gh-actions-token') - ENVIRONMENTS[host] = AzureEnvironment(name=host, - service_management_endpoint=host + '/', - resource_manager_endpoint=host + '/', - active_directory_endpoint=host + '/') + azure_environment = AzureEnvironment(name=host, + service_management_endpoint=host + '/', + resource_manager_endpoint=host + '/', + active_directory_endpoint=host + '/') + databricks_environment = DatabricksEnvironment(Cloud.AZURE, + '...', + azure_environment=azure_environment) cfg = Config(host=host, azure_workspace_resource_id=..., azure_client_id='test', - azure_environment=host) + azure_environment=host, + databricks_environment=databricks_environment) headers = cfg.authenticate() assert {'Authorization': 'Taker this-is-it'} == headers + + +@pytest.mark.parametrize(['azure_environment', 'expected'], + [('PUBLIC', ENVIRONMENTS['PUBLIC']), ('USGOVERNMENT', ENVIRONMENTS['USGOVERNMENT']), + ('CHINA', ENVIRONMENTS['CHINA']), ('public', ENVIRONMENTS['PUBLIC']), + ('usgovernment', ENVIRONMENTS['USGOVERNMENT']), ('china', ENVIRONMENTS['CHINA']), + # Kept for historical compatibility + ('AzurePublicCloud', ENVIRONMENTS['PUBLIC']), + ('AzureUSGovernment', ENVIRONMENTS['USGOVERNMENT']), + ('AzureChinaCloud', ENVIRONMENTS['CHINA']), ]) +def test_azure_environment(azure_environment, expected): + c = Config(credentials_provider=noop_credentials, + azure_workspace_resource_id='...', + azure_environment=azure_environment) + assert c.arm_environment == expected diff --git a/tests/test_environments.py b/tests/test_environments.py new file mode 100644 index 000000000..c14426f0d --- /dev/null +++ b/tests/test_environments.py @@ -0,0 +1,19 @@ +from databricks.sdk.core import Config +from databricks.sdk.environments import ALL_ENVS, Cloud + + +def test_environment_aws(): + c = Config(host="https://test.cloud.databricks.com", token="token") + assert c.environment.cloud == Cloud.AWS + 
assert c.environment.dns_zone == ".cloud.databricks.com" + + +def test_environment_azure(): + c = Config(host="https://test.dev.azuredatabricks.net", token="token") + assert c.environment.cloud == Cloud.AZURE + assert c.environment.dns_zone == ".dev.azuredatabricks.net" + + +def test_default_environment_can_be_overridden(): + c = Config(host="https://test.cloud.databricks.com", token="token", databricks_environment=ALL_ENVS[1]) + assert c.environment == ALL_ENVS[1] diff --git a/tests/test_metadata_service_auth.py b/tests/test_metadata_service_auth.py index 753d96f0a..f2c052006 100644 --- a/tests/test_metadata_service_auth.py +++ b/tests/test_metadata_service_auth.py @@ -3,7 +3,8 @@ import requests -from databricks.sdk.core import Config, MetadataServiceTokenSource +from databricks.sdk.core import Config +from databricks.sdk.credentials_provider import MetadataServiceTokenSource def get_test_server(host: str, token: str, expires_after: int): diff --git a/tests/test_oauth.py b/tests/test_oauth.py index 49b194384..ce2d514ff 100644 --- a/tests/test_oauth.py +++ b/tests/test_oauth.py @@ -1,5 +1,5 @@ -from databricks.sdk.core import Config, OidcEndpoints -from databricks.sdk.oauth import OAuthClient, TokenCache +from databricks.sdk.core import Config +from databricks.sdk.oauth import OAuthClient, OidcEndpoints, TokenCache def test_token_cache_unique_filename_by_host(mocker):