Skip to content

Commit

Permalink
[Fix] Infer Azure tenant ID if not set (#638)
Browse files Browse the repository at this point in the history
## Changes
Port of databricks/databricks-sdk-go#910 to the
Python SDK.

In order to use Azure U2M or M2M authentication with the Databricks SDK,
users must request a token from the Entra ID instance that the
underlying workspace or account belongs to, as Databricks rejects
requests to workspaces with a token from a different Entra ID tenant.
However, with Azure CLI auth, it is possible that a user is logged into
multiple tenants at the same time. Currently, the SDK uses the
subscription ID from the configured Azure Resource ID for the workspace
when issuing the `az account get-access-token` command. However, when
users don't specify the resource ID, the SDK simply fetches a token for
the active subscription for the user. If the active subscription is in a
different tenant than the workspace, users will see an error such as:

```
io.jsonwebtoken.IncorrectClaimException: Expected iss claim to be: https://sts.windows.net/72f988bf-86f1-41af-91ab-2d7cd011db47/, but was: https://sts.windows.net/e3fe3f22-4b98-4c04-82cc-d8817d1b17da/
```

This PR modifies Azure CLI and Azure SP credential providers to attempt
to load the tenant ID of the workspace if not provided before
authenticating. Currently, there are no unauthenticated endpoints that
the tenant ID can be directly fetched from. However, the tenant ID is
indirectly exposed via the redirect URL used when logging into a
workspace. In this PR, we fetch the tenant ID from this endpoint and
configure it if not already set.

Here, we lazily fetch the tenant ID only in the auth methods that need
it. This prevents us from making any unnecessary requests if these Azure
credential providers are not needed.

## Tests
Unit tests check that the tenant ID is fetched automatically if not
specified for an azure workspace when authenticating with client
ID/secret or with the CLI.

- [x] `make test` run locally
- [x] `make fmt` applied
- [x] relevant integration tests applied
  • Loading branch information
mgyucht authored Jul 18, 2024
1 parent f5c5f48 commit b0750eb
Show file tree
Hide file tree
Showing 6 changed files with 112 additions and 24 deletions.
27 changes: 27 additions & 0 deletions databricks/sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,33 @@ def _fix_host_if_needed(self):

self.host = urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment))

def load_azure_tenant_id(self):
"""[Internal] Load the Azure tenant ID from the Azure Databricks login page.
If the tenant ID is already set, this method does nothing."""
if not self.is_azure or self.azure_tenant_id is not None or self.host is None:
return
login_url = f'{self.host}/aad/auth'
logger.debug(f'Loading tenant ID from {login_url}')
resp = requests.get(login_url, allow_redirects=False)
if resp.status_code // 100 != 3:
logger.debug(
f'Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status_code}')
return
entra_id_endpoint = resp.headers.get('Location')
if entra_id_endpoint is None:
logger.debug(f'No Location header in response from {login_url}')
return
# The Location header has the following form: https://login.microsoftonline.com/<tenant-id>/oauth2/authorize?...
# The domain may change depending on the Azure cloud (e.g. login.microsoftonline.us for US Government cloud).
url = urllib.parse.urlparse(entra_id_endpoint)
path_segments = url.path.split('/')
if len(path_segments) < 2:
logger.debug(f'Invalid path in Location header: {url.path}')
return
self.azure_tenant_id = path_segments[1]
logger.debug(f'Loaded tenant ID: {self.azure_tenant_id}')

def _set_inner_config(self, keyword_args: Dict[str, any]):
for attr in self.attributes():
if attr.name not in keyword_args:
Expand Down
30 changes: 15 additions & 15 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,7 @@ def _ensure_host_present(cfg: 'Config', token_source_for: Callable[[str], TokenS
cfg.host = f"https://{resp.json()['properties']['workspaceUrl']}"


@oauth_credentials_strategy('azure-client-secret',
['is_azure', 'azure_client_id', 'azure_client_secret', 'azure_tenant_id'])
@oauth_credentials_strategy('azure-client-secret', ['is_azure', 'azure_client_id', 'azure_client_secret'])
def azure_service_principal(cfg: 'Config') -> CredentialsProvider:
""" Adds refreshed Azure Active Directory (AAD) Service Principal OAuth tokens
to every request, while automatically resolving different Azure environment endpoints. """
Expand All @@ -248,6 +247,7 @@ def token_source_for(resource: str) -> TokenSource:
use_params=True)

_ensure_host_present(cfg, token_source_for)
cfg.load_azure_tenant_id()
logger.info("Configured AAD token for Service Principal (%s)", cfg.azure_client_id)
inner = token_source_for(cfg.effective_azure_login_app_id)
cloud = token_source_for(cfg.arm_environment.service_management_endpoint)
Expand Down Expand Up @@ -432,11 +432,13 @@ def refresh(self) -> Token:
class AzureCliTokenSource(CliTokenSource):
""" Obtain the token granted by `az login` CLI command """

def __init__(self, resource: str, subscription: str = ""):
def __init__(self, resource: str, subscription: Optional[str] = None, tenant: Optional[str] = None):
cmd = ["az", "account", "get-access-token", "--resource", resource, "--output", "json"]
if subscription != "":
if subscription is not None:
cmd.append("--subscription")
cmd.append(subscription)
if tenant:
cmd.extend(["--tenant", tenant])
super().__init__(cmd=cmd,
token_type_field='tokenType',
access_token_field='accessToken',
Expand Down Expand Up @@ -464,8 +466,10 @@ def is_human_user(self) -> bool:
@staticmethod
def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource':
subscription = AzureCliTokenSource.get_subscription(cfg)
if subscription != "":
token_source = AzureCliTokenSource(resource, subscription)
if subscription is not None:
token_source = AzureCliTokenSource(resource,
subscription=subscription,
tenant=cfg.azure_tenant_id)
try:
# This will fail if the user has access to the workspace, but not to the subscription
# itself.
Expand All @@ -475,25 +479,26 @@ def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource':
except OSError:
logger.warning("Failed to get token for subscription. Using resource only token.")

token_source = AzureCliTokenSource(resource)
token_source = AzureCliTokenSource(resource, subscription=None, tenant=cfg.azure_tenant_id)
token_source.token()
return token_source

@staticmethod
def get_subscription(cfg: 'Config') -> str:
def get_subscription(cfg: 'Config') -> Optional[str]:
resource = cfg.azure_workspace_resource_id
if resource is None or resource == "":
return ""
return None
components = resource.split('/')
if len(components) < 3:
logger.warning("Invalid azure workspace resource ID")
return ""
return None
return components[2]


@credentials_strategy('azure-cli', ['is_azure'])
def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
""" Adds refreshed OAuth token granted by `az login` command to every request. """
cfg.load_azure_tenant_id()
token_source = None
mgmt_token_source = None
try:
Expand All @@ -517,11 +522,6 @@ def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:

_ensure_host_present(cfg, lambda resource: AzureCliTokenSource.for_resource(cfg, resource))
logger.info("Using Azure CLI authentication with AAD tokens")
if not cfg.is_account_client and AzureCliTokenSource.get_subscription(cfg) == "":
logger.warning(
"azure_workspace_resource_id field not provided. "
"It is recommended to specify this field in the Databricks configuration to avoid authentication errors."
)

def inner() -> Dict[str, str]:
token = token_source.token()
Expand Down
13 changes: 13 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,16 @@ def set_az_path(monkeypatch):
monkeypatch.setenv('COMSPEC', 'C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\powershell.exe')
else:
monkeypatch.setenv('PATH', __tests__ + "/testdata:/bin")


@pytest.fixture
def mock_tenant(requests_mock):

def stub_tenant_request(host, tenant_id="test-tenant-id"):
mock = requests_mock.get(
f'https://{host}/aad/auth',
status_code=302,
headers={'Location': f'https://login.microsoftonline.com/{tenant_id}/oauth2/authorize'})
return mock

return stub_tenant_request
9 changes: 6 additions & 3 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,10 @@ def test_config_azure_pat():
assert cfg.is_azure


def test_config_azure_cli_host(monkeypatch):
def test_config_azure_cli_host(monkeypatch, mock_tenant):
set_home(monkeypatch, '/testdata/azure')
set_az_path(monkeypatch)
mock_tenant('adb-123.4.azuredatabricks.net')
cfg = Config(host='https://adb-123.4.azuredatabricks.net', azure_workspace_resource_id='/sub/rg/ws')

assert cfg.auth_type == 'azure-cli'
Expand Down Expand Up @@ -229,20 +230,22 @@ def test_config_azure_cli_host_pat_conflict_with_config_file_present_without_def
cfg = Config(token='x', azure_workspace_resource_id='/sub/rg/ws')


def test_config_azure_cli_host_and_resource_id(monkeypatch):
def test_config_azure_cli_host_and_resource_id(monkeypatch, mock_tenant):
set_home(monkeypatch, '/testdata')
set_az_path(monkeypatch)
mock_tenant('adb-123.4.azuredatabricks.net')
cfg = Config(host='https://adb-123.4.azuredatabricks.net', azure_workspace_resource_id='/sub/rg/ws')

assert cfg.auth_type == 'azure-cli'
assert cfg.host == 'https://adb-123.4.azuredatabricks.net'
assert cfg.is_azure


def test_config_azure_cli_host_and_resource_i_d_configuration_precedence(monkeypatch):
def test_config_azure_cli_host_and_resource_i_d_configuration_precedence(monkeypatch, mock_tenant):
monkeypatch.setenv('DATABRICKS_CONFIG_PROFILE', 'justhost')
set_home(monkeypatch, '/testdata/azure')
set_az_path(monkeypatch)
mock_tenant('adb-123.4.azuredatabricks.net')
cfg = Config(host='https://adb-123.4.azuredatabricks.net', azure_workspace_resource_id='/sub/rg/ws')

assert cfg.auth_type == 'azure-cli'
Expand Down
15 changes: 10 additions & 5 deletions tests/test_auth_manual_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from .conftest import set_az_path, set_home


def test_azure_cli_workspace_header_present(monkeypatch):
def test_azure_cli_workspace_header_present(monkeypatch, mock_tenant):
set_home(monkeypatch, '/testdata/azure')
set_az_path(monkeypatch)
mock_tenant('adb-123.4.azuredatabricks.net')
resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123'
cfg = Config(auth_type='azure-cli',
host='https://adb-123.4.azuredatabricks.net',
Expand All @@ -14,19 +15,21 @@ def test_azure_cli_workspace_header_present(monkeypatch):
assert cfg.authenticate()['X-Databricks-Azure-Workspace-Resource-Id'] == resource_id


def test_azure_cli_user_with_management_access(monkeypatch):
def test_azure_cli_user_with_management_access(monkeypatch, mock_tenant):
set_home(monkeypatch, '/testdata/azure')
set_az_path(monkeypatch)
mock_tenant('adb-123.4.azuredatabricks.net')
resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123'
cfg = Config(auth_type='azure-cli',
host='https://adb-123.4.azuredatabricks.net',
azure_workspace_resource_id=resource_id)
assert 'X-Databricks-Azure-SP-Management-Token' in cfg.authenticate()


def test_azure_cli_user_no_management_access(monkeypatch):
def test_azure_cli_user_no_management_access(monkeypatch, mock_tenant):
set_home(monkeypatch, '/testdata/azure')
set_az_path(monkeypatch)
mock_tenant('adb-123.4.azuredatabricks.net')
monkeypatch.setenv('FAIL_IF', 'https://management.core.windows.net/')
resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123'
cfg = Config(auth_type='azure-cli',
Expand All @@ -35,9 +38,10 @@ def test_azure_cli_user_no_management_access(monkeypatch):
assert 'X-Databricks-Azure-SP-Management-Token' not in cfg.authenticate()


def test_azure_cli_fallback(monkeypatch):
def test_azure_cli_fallback(monkeypatch, mock_tenant):
set_home(monkeypatch, '/testdata/azure')
set_az_path(monkeypatch)
mock_tenant('adb-123.4.azuredatabricks.net')
monkeypatch.setenv('FAIL_IF', 'subscription')
resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123'
cfg = Config(auth_type='azure-cli',
Expand All @@ -46,9 +50,10 @@ def test_azure_cli_fallback(monkeypatch):
assert 'X-Databricks-Azure-SP-Management-Token' in cfg.authenticate()


def test_azure_cli_with_warning_on_stderr(monkeypatch):
def test_azure_cli_with_warning_on_stderr(monkeypatch, mock_tenant):
set_home(monkeypatch, '/testdata/azure')
set_az_path(monkeypatch)
mock_tenant('adb-123.4.azuredatabricks.net')
monkeypatch.setenv('WARN', 'this is a warning')
resource_id = '/subscriptions/123/resourceGroups/abc/providers/Microsoft.Databricks/workspaces/abc123'
cfg = Config(auth_type='azure-cli',
Expand Down
42 changes: 41 additions & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import platform

import pytest
Expand All @@ -6,7 +7,9 @@
from databricks.sdk.config import Config, with_product, with_user_agent_extra
from databricks.sdk.version import __version__

from .conftest import noop_credentials
from .conftest import noop_credentials, set_az_path

__tests__ = os.path.dirname(__file__)


def test_config_supports_legacy_credentials_provider():
Expand Down Expand Up @@ -74,3 +77,40 @@ def test_config_copy_deep_copies_user_agent_other_info(config):
assert "blueprint/0.4.6" in config.user_agent
assert "blueprint/0.4.6" in config_copy.user_agent
useragent._reset_extra(original_extra)


def test_load_azure_tenant_id_404(requests_mock, monkeypatch):
set_az_path(monkeypatch)
mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=404)
cfg = Config(host="https://abc123.azuredatabricks.net")
assert cfg.azure_tenant_id is None
assert mock.called_once


def test_load_azure_tenant_id_no_location_header(requests_mock, monkeypatch):
set_az_path(monkeypatch)
mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=302)
cfg = Config(host="https://abc123.azuredatabricks.net")
assert cfg.azure_tenant_id is None
assert mock.called_once


def test_load_azure_tenant_id_unparsable_location_header(requests_mock, monkeypatch):
set_az_path(monkeypatch)
mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth',
status_code=302,
headers={'Location': 'https://unexpected-location'})
cfg = Config(host="https://abc123.azuredatabricks.net")
assert cfg.azure_tenant_id is None
assert mock.called_once


def test_load_azure_tenant_id_happy_path(requests_mock, monkeypatch):
set_az_path(monkeypatch)
mock = requests_mock.get(
'https://abc123.azuredatabricks.net/aad/auth',
status_code=302,
headers={'Location': 'https://login.microsoftonline.com/tenant-id/oauth2/authorize'})
cfg = Config(host="https://abc123.azuredatabricks.net")
assert cfg.azure_tenant_id == 'tenant-id'
assert mock.called_once

0 comments on commit b0750eb

Please sign in to comment.