From 8a59c084055fdf3bc623802c70033f9f80322438 Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Thu, 18 Jul 2024 09:47:54 +0200 Subject: [PATCH] take two --- databricks/sdk/azure.py | 17 ------------- databricks/sdk/config.py | 22 ++++++++++++++++ databricks/sdk/credentials_provider.py | 20 +++++++-------- tests/test_azure.py | 20 --------------- tests/test_config.py | 35 ++++++++++++++++++++++++++ 5 files changed, 66 insertions(+), 48 deletions(-) delete mode 100644 tests/test_azure.py diff --git a/databricks/sdk/azure.py b/databricks/sdk/azure.py index f06477f88..372669d61 100644 --- a/databricks/sdk/azure.py +++ b/databricks/sdk/azure.py @@ -1,7 +1,4 @@ from typing import Dict -from urllib import parse - -import requests from .oauth import TokenSource from .service.provisioning import Workspace @@ -28,17 +25,3 @@ def get_azure_resource_id(workspace: Workspace): return (f'/subscriptions/{workspace.azure_workspace_info.subscription_id}' f'/resourceGroups/{workspace.azure_workspace_info.resource_group}' f'/providers/Microsoft.Databricks/workspaces/{workspace.workspace_name}') - - -def _load_azure_tenant_id(cfg: 'Config'): - if not cfg.is_azure or cfg.azure_tenant_id is not None or cfg.host is None: - return - logging.debug(f'Loading tenant ID from {cfg.host}/aad/auth') - resp = requests.get(f'{cfg.host}/aad/auth', allow_redirects=False) - entra_id_endpoint = resp.headers.get('Location') - if entra_id_endpoint is None: - logging.debug(f'No Location header in response from {cfg.host}/aad/auth') - return - url = parse.urlparse(entra_id_endpoint) - cfg.azure_tenant_id = url.path.split('/')[1] - logging.debug(f'Loaded tenant ID: {cfg.azure_tenant_id}') diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index 47d0ecc44..0d9823231 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -119,6 +119,7 @@ def __init__(self, self._load_from_env() self._known_file_config_loader() self._fix_host_if_needed() + self._load_azure_tenant_id() self._validate() self.init_auth() self._init_product(product, product_version) @@ -363,6 +364,27 @@ def _fix_host_if_needed(self): self.host = urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment)) + def _load_azure_tenant_id(self): + if not self.is_azure or self.azure_tenant_id is not None or self.host is None: + return + login_url = f'{self.host}/aad/auth' + logger.debug(f'Loading tenant ID from {login_url}') + resp = requests.get(login_url, allow_redirects=False) + if resp.status_code // 100 != 3: + logger.debug(f'Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status_code}') + return + entra_id_endpoint = resp.headers.get('Location') + if entra_id_endpoint is None: + logger.debug(f'No Location header in response from {login_url}') + return + url = urllib.parse.urlparse(entra_id_endpoint) + path_segments = url.path.split('/') + if len(path_segments) < 2: + logger.debug(f'Invalid path in Location header: {url.path}') + return + self.azure_tenant_id = path_segments[1] + logger.debug(f'Loaded tenant ID: {self.azure_tenant_id}') + def _set_inner_config(self, keyword_args: Dict[str, any]): for attr in self.attributes(): if attr.name not in keyword_args: diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 8738d5116..27237376a 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -18,7 +18,7 @@ from google.auth.transport.requests import Request from google.oauth2 import service_account -from .azure import add_sp_management_token, add_workspace_id_header, _load_azure_tenant_id +from .azure import add_sp_management_token, add_workspace_id_header from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token, TokenCache, TokenSource) @@ -246,7 +246,6 @@ def token_source_for(resource: str) -> TokenSource: endpoint_params={"resource": resource}, use_params=True) - _load_azure_tenant_id(cfg) _ensure_host_present(cfg, token_source_for) logger.info("Configured AAD token for Service Principal (%s)", cfg.azure_client_id) inner = token_source_for(cfg.effective_azure_login_app_id) @@ -432,9 +431,9 @@ def refresh(self) -> Token: class AzureCliTokenSource(CliTokenSource): """ Obtain the token granted by `az login` CLI command """ - def __init__(self, resource: str, subscription: str = "", tenant: str = None): + def __init__(self, resource: str, subscription: Optional[str] = None, tenant: Optional[str] = None): cmd = ["az", "account", "get-access-token", "--resource", resource, "--output", "json"] - if subscription != "": + if subscription is not None: cmd.append("--subscription") cmd.append(subscription) if tenant: @@ -466,8 +465,8 @@ def is_human_user(self) -> bool: @staticmethod def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource': subscription = AzureCliTokenSource.get_subscription(cfg) - if cfg.azure_tenant_id == "" and subscription != "": - token_source = AzureCliTokenSource(resource, subscription) + if subscription is not None: + token_source = AzureCliTokenSource(resource, subscription=subscription, tenant=cfg.azure_tenant_id) try: # This will fail if the user has access to the workspace, but not to the subscription # itself. @@ -477,26 +476,25 @@ def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource': except OSError: logger.warning("Failed to get token for subscription. Using resource only token.") - token_source = AzureCliTokenSource(resource, cfg.azure_tenant_id) + token_source = AzureCliTokenSource(resource, subscription=None, tenant=cfg.azure_tenant_id) token_source.token() return token_source @staticmethod - def get_subscription(cfg: 'Config') -> str: + def get_subscription(cfg: 'Config') -> Optional[str]: resource = cfg.azure_workspace_resource_id if resource is None or resource == "": - return "" + return None components = resource.split('/') if len(components) < 3: logger.warning("Invalid azure workspace resource ID") - return "" + return None return components[2] @credentials_strategy('azure-cli', ['is_azure']) def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]: """ Adds refreshed OAuth token granted by `az login` command to every request. """ - _load_azure_tenant_id(cfg) token_source = None mgmt_token_source = None try: diff --git a/tests/test_azure.py b/tests/test_azure.py deleted file mode 100644 index 9d1b1d2fb..000000000 --- a/tests/test_azure.py +++ /dev/null @@ -1,20 +0,0 @@ -from databricks.sdk.config import Config -import os - -__tests__ = os.path.dirname(__file__) - - -def test_load_azure_tenant_id(requests_mock, monkeypatch): - monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin') - mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=302, headers={'Location': 'https://login.microsoftonline.com/abc123xyz/oauth2/authorize'}) - cfg = Config(host="https://abc123.azuredatabricks.net") - assert cfg.azure_tenant_id == 'abc123xyz' - assert mock.called_once - - -def test_load_azure_tenant_id_tenant_id_set(requests_mock, monkeypatch): - monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin') - mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=302, headers={'Location': 'https://login.microsoftonline.com/abc123xyz/oauth2/authorize'}) - cfg = Config(host="https://abc123.azuredatabricks.net", azure_tenant_id="123456789") - assert cfg.azure_tenant_id == '123456789' - assert mock.call_count == 0 diff --git a/tests/test_config.py b/tests/test_config.py index 4d3a0ebef..701333e38 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -8,6 +8,9 @@ from .conftest import noop_credentials +import os + +__tests__ = os.path.dirname(__file__) def test_config_supports_legacy_credentials_provider(): c = Config(credentials_provider=noop_credentials, product='foo', product_version='1.2.3') @@ -74,3 +77,35 @@ def test_config_copy_deep_copies_user_agent_other_info(config): assert "blueprint/0.4.6" in config.user_agent assert "blueprint/0.4.6" in config_copy.user_agent useragent._reset_extra(original_extra) + + +def test_load_azure_tenant_id_404(requests_mock, monkeypatch): + monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin') + mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=404) + cfg = Config(host="https://abc123.azuredatabricks.net") + assert cfg.azure_tenant_id is None + assert mock.called_once + + +def test_load_azure_tenant_id_no_location_header(requests_mock, monkeypatch): + monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin') + mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=302) + cfg = Config(host="https://abc123.azuredatabricks.net") + assert cfg.azure_tenant_id is None + assert mock.called_once + + +def test_load_azure_tenant_id_unparsable_location_header(requests_mock, monkeypatch): + monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin') + mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=302, headers={'Location': 'https://unexpected-location'}) + cfg = Config(host="https://abc123.azuredatabricks.net") + assert cfg.azure_tenant_id is None + assert mock.called_once + + +def test_load_azure_tenant_id_happy_path(requests_mock, monkeypatch): + monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin') + mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=302, headers={'Location': 'https://login.microsoftonline.com/tenant-id/oauth2/authorize'}) + cfg = Config(host="https://abc123.azuredatabricks.net") + assert cfg.azure_tenant_id == 'tenant-id' + assert mock.called_once