From b0750ebb718e41e9e14cfbf76ddb4569642f75ee Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Thu, 18 Jul 2024 13:39:12 +0200 Subject: [PATCH] [Fix] Infer Azure tenant ID if not set (#638) ## Changes Port of https://github.com/databricks/databricks-sdk-go/pull/910 to the Python SDK. In order to use Azure U2M or M2M authentication with the Databricks SDK, users must request a token from the Entra ID instance that the underlying workspace or account belongs to, as Databricks rejects requests to workspaces with a token from a different Entra ID tenant. However, with Azure CLI auth, it is possible that a user is logged into multiple tenants at the same time. Currently, the SDK uses the subscription ID from the configured Azure Resource ID for the workspace when issuing the `az account get-access-token` command. However, when users don't specify the resource ID, the SDK simply fetches a token for the active subscription for the user. If the active subscription is in a different tenant than the workspace, users will see an error such as: ``` io.jsonwebtoken.IncorrectClaimException: Expected iss claim to be: https://sts.windows.net/72f988bf-86f1-41af-91ab-2d7cd011db47/, but was: https://sts.windows.net/e3fe3f22-4b98-4c04-82cc-d8817d1b17da/ ``` This PR modifies Azure CLI and Azure SP credential providers to attempt to load the tenant ID of the workspace if not provided before authenticating. Currently, there are no unauthenticated endpoints that the tenant ID can be directly fetched from. However, the tenant ID is indirectly exposed via the redirect URL used when logging into a workspace. In this PR, we fetch the tenant ID from this endpoint and configure it if not already set. Here, we lazily fetch the tenant ID only in the auth methods that need it. This prevents us from making any unnecessary requests if these Azure credential providers are not needed. 
## Tests Unit tests check that the tenant ID is fetched automatically if not specified for an azure workspace when authenticating with client ID/secret or with the CLI. - [x] `make test` run locally - [x] `make fmt` applied - [x] relevant integration tests applied --- databricks/sdk/config.py | 27 +++++++++++++++++ databricks/sdk/credentials_provider.py | 30 +++++++++--------- tests/conftest.py | 13 ++++++++ tests/test_auth.py | 9 ++++-- tests/test_auth_manual_tests.py | 15 ++++++--- tests/test_config.py | 42 +++++++++++++++++++++++++- 6 files changed, 112 insertions(+), 24 deletions(-) diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index 47d0ecc44..28d57ad42 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -363,6 +363,33 @@ def _fix_host_if_needed(self): self.host = urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment)) + def load_azure_tenant_id(self): + """[Internal] Load the Azure tenant ID from the Azure Databricks login page. + + If the tenant ID is already set, this method does nothing.""" + if not self.is_azure or self.azure_tenant_id is not None or self.host is None: + return + login_url = f'{self.host}/aad/auth' + logger.debug(f'Loading tenant ID from {login_url}') + resp = requests.get(login_url, allow_redirects=False) + if resp.status_code // 100 != 3: + logger.debug( + f'Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status_code}') + return + entra_id_endpoint = resp.headers.get('Location') + if entra_id_endpoint is None: + logger.debug(f'No Location header in response from {login_url}') + return + # The Location header has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?... + # The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud). 
+ url = urllib.parse.urlparse(entra_id_endpoint) + path_segments = url.path.split('/') + if len(path_segments) < 2: + logger.debug(f'Invalid path in Location header: {url.path}') + return + self.azure_tenant_id = path_segments[1] + logger.debug(f'Loaded tenant ID: {self.azure_tenant_id}') + def _set_inner_config(self, keyword_args: Dict[str, any]): for attr in self.attributes(): if attr.name not in keyword_args: diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index 50c2eee89..cfdf80e0d 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -233,8 +233,7 @@ def _ensure_host_present(cfg: 'Config', token_source_for: Callable[[str], TokenS cfg.host = f"https://{resp.json()['properties']['workspaceUrl']}" -@oauth_credentials_strategy('azure-client-secret', - ['is_azure', 'azure_client_id', 'azure_client_secret', 'azure_tenant_id']) +@oauth_credentials_strategy('azure-client-secret', ['is_azure', 'azure_client_id', 'azure_client_secret']) def azure_service_principal(cfg: 'Config') -> CredentialsProvider: """ Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens to every request, while automatically resolving different Azure environment endpoints. 
""" @@ -248,6 +247,7 @@ def token_source_for(resource: str) -> TokenSource: use_params=True) _ensure_host_present(cfg, token_source_for) + cfg.load_azure_tenant_id() logger.info("Configured AAD token for Service Principal (%s)", cfg.azure_client_id) inner = token_source_for(cfg.effective_azure_login_app_id) cloud = token_source_for(cfg.arm_environment.service_management_endpoint) @@ -432,11 +432,13 @@ def refresh(self) -> Token: class AzureCliTokenSource(CliTokenSource): """ Obtain the token granted by `az login` CLI command """ - def __init__(self, resource: str, subscription: str = ""): + def __init__(self, resource: str, subscription: Optional[str] = None, tenant: Optional[str] = None): cmd = ["az", "account", "get-access-token", "--resource", resource, "--output", "json"] - if subscription != "": + if subscription is not None: cmd.append("--subscription") cmd.append(subscription) + if tenant: + cmd.extend(["--tenant", tenant]) super().__init__(cmd=cmd, token_type_field='tokenType', access_token_field='accessToken', @@ -464,8 +466,10 @@ def is_human_user(self) -> bool: @staticmethod def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource': subscription = AzureCliTokenSource.get_subscription(cfg) - if subscription != "": - token_source = AzureCliTokenSource(resource, subscription) + if subscription is not None: + token_source = AzureCliTokenSource(resource, + subscription=subscription, + tenant=cfg.azure_tenant_id) try: # This will fail if the user has access to the workspace, but not to the subscription # itself. @@ -475,25 +479,26 @@ def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource': except OSError: logger.warning("Failed to get token for subscription. 
Using resource only token.") - token_source = AzureCliTokenSource(resource) + token_source = AzureCliTokenSource(resource, subscription=None, tenant=cfg.azure_tenant_id) token_source.token() return token_source @staticmethod - def get_subscription(cfg: 'Config') -> str: + def get_subscription(cfg: 'Config') -> Optional[str]: resource = cfg.azure_workspace_resource_id if resource is None or resource == "": - return "" + return None components = resource.split('/') if len(components) < 3: logger.warning("Invalid azure workspace resource ID") - return "" + return None return components[2] @credentials_strategy('azure-cli', ['is_azure']) def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]: """ Adds refreshed OAuth token granted by `az login` command to every request. """ + cfg.load_azure_tenant_id() token_source = None mgmt_token_source = None try: @@ -517,11 +522,6 @@ def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]: _ensure_host_present(cfg, lambda resource: AzureCliTokenSource.for_resource(cfg, resource)) logger.info("Using Azure CLI authentication with AAD tokens") - if not cfg.is_account_client and AzureCliTokenSource.get_subscription(cfg) == "": - logger.warning( - "azure_workspace_resource_id field not provided. " - "It is recommended to specify this field in the Databricks configuration to avoid authentication errors." 
- ) def inner() -> Dict[str, str]: token = token_source.token() diff --git a/tests/conftest.py b/tests/conftest.py index a7e520dc9..0f415ecf1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -77,3 +77,16 @@ def set_az_path(monkeypatch): monkeypatch.setenv('COMSPEC', 'C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe') else: monkeypatch.setenv('PATH', __tests__ + "/testdata:/bin") + + +@pytest.fixture +def mock_tenant(requests_mock): + + def stub_tenant_request(host, tenant_id="test-tenant-id"): + mock = requests_mock.get( + f'https://{host}/aad/auth', + status_code=302, + headers={'Location': f'https://login.microsoftonline.com/{tenant_id}/oauth2/authorize'}) + return mock + + return stub_tenant_request diff --git a/tests/test_auth.py b/tests/test_auth.py index fd73378b2..cd8f3cfc1 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -193,9 +193,10 @@ def test_config_azure_pat(): assert cfg.is_azure -def test_config_azure_cli_host(monkeypatch): +def test_config_azure_cli_host(monkeypatch, mock_tenant): set_home(monkeypatch, '/testdata/azure') set_az_path(monkeypatch) + mock_tenant('adb-123.4.azuredatabricks.net') cfg = Config(host='https://adb-123.4.azuredatabricks.net', azure_workspace_resource_id='/sub/rg/ws') assert cfg.auth_type == 'azure-cli' @@ -229,9 +230,10 @@ def test_config_azure_cli_host_pat_conflict_with_config_file_present_without_def cfg = Config(token='x', azure_workspace_resource_id='/sub/rg/ws') -def test_config_azure_cli_host_and_resource_id(monkeypatch): +def test_config_azure_cli_host_and_resource_id(monkeypatch, mock_tenant): set_home(monkeypatch, '/testdata') set_az_path(monkeypatch) + mock_tenant('adb-123.4.azuredatabricks.net') cfg = Config(host='https://adb-123.4.azuredatabricks.net', azure_workspace_resource_id='/sub/rg/ws') assert cfg.auth_type == 'azure-cli' @@ -239,10 +241,11 @@ def test_config_azure_cli_host_and_resource_id(monkeypatch): assert cfg.is_azure -def 
test_config_azure_cli_host_and_resource_i_d_configuration_precedence(monkeypatch): +def test_config_azure_cli_host_and_resource_i_d_configuration_precedence(monkeypatch, mock_tenant): monkeypatch.setenv('DATABRICKS_CONFIG_PROFILE', 'justhost') set_home(monkeypatch, '/testdata/azure') set_az_path(monkeypatch) + mock_tenant('adb-123.4.azuredatabricks.net') cfg = Config(host='https://adb-123.4.azuredatabricks.net', azure_workspace_resource_id='/sub/rg/ws') assert cfg.auth_type == 'azure-cli' diff --git a/tests/test_auth_manual_tests.py b/tests/test_auth_manual_tests.py index e2874c427..34aa3a9c2 100644 --- a/tests/test_auth_manual_tests.py +++ b/tests/test_auth_manual_tests.py @@ -3,9 +3,10 @@ from .conftest import set_az_path, set_home -def test_azure_cli_workspace_header_present(monkeypatch): +def test_azure_cli_workspace_header_present(monkeypatch, mock_tenant): set_home(monkeypatch, '/testdata/azure') set_az_path(monkeypatch) + mock_tenant('adb-123.4.azuredatabricks.net') resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123' cfg = Config(auth_type='azure-cli', host='https://adb-123.4.azuredatabricks.net', @@ -14,9 +15,10 @@ def test_azure_cli_workspace_header_present(monkeypatch): assert cfg.authenticate()['X-Databricks-Azure-Workspace-Resource-Id'] == resource_id -def test_azure_cli_user_with_management_access(monkeypatch): +def test_azure_cli_user_with_management_access(monkeypatch, mock_tenant): set_home(monkeypatch, '/testdata/azure') set_az_path(monkeypatch) + mock_tenant('adb-123.4.azuredatabricks.net') resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123' cfg = Config(auth_type='azure-cli', host='https://adb-123.4.azuredatabricks.net', @@ -24,9 +26,10 @@ def test_azure_cli_user_with_management_access(monkeypatch): assert 'X-Databricks-Azure-SP-Management-Token' in cfg.authenticate() -def test_azure_cli_user_no_management_access(monkeypatch): +def 
test_azure_cli_user_no_management_access(monkeypatch, mock_tenant): set_home(monkeypatch, '/testdata/azure') set_az_path(monkeypatch) + mock_tenant('adb-123.4.azuredatabricks.net') monkeypatch.setenv('FAIL_IF', 'https://management.core.windows.net/') resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123' cfg = Config(auth_type='azure-cli', @@ -35,9 +38,10 @@ def test_azure_cli_user_no_management_access(monkeypatch): assert 'X-Databricks-Azure-SP-Management-Token' not in cfg.authenticate() -def test_azure_cli_fallback(monkeypatch): +def test_azure_cli_fallback(monkeypatch, mock_tenant): set_home(monkeypatch, '/testdata/azure') set_az_path(monkeypatch) + mock_tenant('adb-123.4.azuredatabricks.net') monkeypatch.setenv('FAIL_IF', 'subscription') resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123' cfg = Config(auth_type='azure-cli', @@ -46,9 +50,10 @@ def test_azure_cli_fallback(monkeypatch): assert 'X-Databricks-Azure-SP-Management-Token' in cfg.authenticate() -def test_azure_cli_with_warning_on_stderr(monkeypatch): +def test_azure_cli_with_warning_on_stderr(monkeypatch, mock_tenant): set_home(monkeypatch, '/testdata/azure') set_az_path(monkeypatch) + mock_tenant('adb-123.4.azuredatabricks.net') monkeypatch.setenv('WARN', 'this is a warning') resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123' cfg = Config(auth_type='azure-cli', diff --git a/tests/test_config.py b/tests/test_config.py index 4d3a0ebef..4bab85cf1 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,3 +1,4 @@ +import os import platform import pytest @@ -6,7 +7,9 @@ from databricks.sdk.config import Config, with_product, with_user_agent_extra from databricks.sdk.version import __version__ -from .conftest import noop_credentials +from .conftest import noop_credentials, set_az_path + +__tests__ = os.path.dirname(__file__) def 
test_config_supports_legacy_credentials_provider(): @@ -74,3 +77,40 @@ def test_config_copy_deep_copies_user_agent_other_info(config): assert "blueprint/0.4.6" in config.user_agent assert "blueprint/0.4.6" in config_copy.user_agent useragent._reset_extra(original_extra) + + +def test_load_azure_tenant_id_404(requests_mock, monkeypatch): + set_az_path(monkeypatch) + mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=404) + cfg = Config(host="https://abc123.azuredatabricks.net") + assert cfg.azure_tenant_id is None + assert mock.called_once + + +def test_load_azure_tenant_id_no_location_header(requests_mock, monkeypatch): + set_az_path(monkeypatch) + mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=302) + cfg = Config(host="https://abc123.azuredatabricks.net") + assert cfg.azure_tenant_id is None + assert mock.called_once + + +def test_load_azure_tenant_id_unparsable_location_header(requests_mock, monkeypatch): + set_az_path(monkeypatch) + mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', + status_code=302, + headers={'Location': 'https://unexpected-location'}) + cfg = Config(host="https://abc123.azuredatabricks.net") + assert cfg.azure_tenant_id is None + assert mock.called_once + + +def test_load_azure_tenant_id_happy_path(requests_mock, monkeypatch): + set_az_path(monkeypatch) + mock = requests_mock.get( + 'https://abc123.azuredatabricks.net/aad/auth', + status_code=302, + headers={'Location': 'https://login.microsoftonline.com/tenant-id/oauth2/authorize'}) + cfg = Config(host="https://abc123.azuredatabricks.net") + assert cfg.azure_tenant_id == 'tenant-id' + assert mock.called_once