Skip to content

Commit

Permalink
take two
Browse files Browse the repository at this point in the history
  • Loading branch information
mgyucht committed Jul 18, 2024
1 parent 7cee99c commit 8a59c08
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 48 deletions.
17 changes: 0 additions & 17 deletions databricks/sdk/azure.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
from typing import Dict
from urllib import parse

import requests

from .oauth import TokenSource
from .service.provisioning import Workspace
Expand All @@ -28,17 +25,3 @@ def get_azure_resource_id(workspace: Workspace):
return (f'/subscriptions/{workspace.azure_workspace_info.subscription_id}'
f'/resourceGroups/{workspace.azure_workspace_info.resource_group}'
f'/providers/Microsoft.Databricks/workspaces/{workspace.workspace_name}')


def _load_azure_tenant_id(cfg: 'Config'):
if not cfg.is_azure or cfg.azure_tenant_id is not None or cfg.host is None:
return
logging.debug(f'Loading tenant ID from {cfg.host}/aad/auth')
resp = requests.get(f'{cfg.host}/aad/auth', allow_redirects=False)
entra_id_endpoint = resp.headers.get('Location')
if entra_id_endpoint is None:
logging.debug(f'No Location header in response from {cfg.host}/aad/auth')
return
url = parse.urlparse(entra_id_endpoint)
cfg.azure_tenant_id = url.path.split('/')[1]
logging.debug(f'Loaded tenant ID: {cfg.azure_tenant_id}')
22 changes: 22 additions & 0 deletions databricks/sdk/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def __init__(self,
self._load_from_env()
self._known_file_config_loader()
self._fix_host_if_needed()
self._load_azure_tenant_id()
self._validate()
self.init_auth()
self._init_product(product, product_version)
Expand Down Expand Up @@ -363,6 +364,27 @@ def _fix_host_if_needed(self):

self.host = urllib.parse.urlunparse((o.scheme, netloc, path, o.params, o.query, o.fragment))

def _load_azure_tenant_id(self):
if not self.is_azure or self.azure_tenant_id is not None or self.host is None:
return
login_url = f'{self.host}/aad/auth'
logger.debug(f'Loading tenant ID from {login_url}')
resp = requests.get(login_url, allow_redirects=False)
if resp.status_code // 100 != 3:
logger.debug(f'Failed to get tenant ID from {login_url}: expected status code 3xx, got {resp.status_code}')
return
entra_id_endpoint = resp.headers.get('Location')
if entra_id_endpoint is None:
logger.debug(f'No Location header in response from {login_url}')
return
url = urllib.parse.urlparse(entra_id_endpoint)
path_segments = url.path.split('/')
if len(path_segments) < 2:
logger.debug(f'Invalid path in Location header: {url.path}')
return
self.azure_tenant_id = path_segments[1]
logger.debug(f'Loaded tenant ID: {self.azure_tenant_id}')

def _set_inner_config(self, keyword_args: Dict[str, any]):
for attr in self.attributes():
if attr.name not in keyword_args:
Expand Down
20 changes: 9 additions & 11 deletions databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from google.auth.transport.requests import Request
from google.oauth2 import service_account

from .azure import add_sp_management_token, add_workspace_id_header, _load_azure_tenant_id
from .azure import add_sp_management_token, add_workspace_id_header
from .oauth import (ClientCredentials, OAuthClient, Refreshable, Token,
TokenCache, TokenSource)

Expand Down Expand Up @@ -246,7 +246,6 @@ def token_source_for(resource: str) -> TokenSource:
endpoint_params={"resource": resource},
use_params=True)

_load_azure_tenant_id(cfg)
_ensure_host_present(cfg, token_source_for)
logger.info("Configured AAD token for Service Principal (%s)", cfg.azure_client_id)
inner = token_source_for(cfg.effective_azure_login_app_id)
Expand Down Expand Up @@ -432,9 +431,9 @@ def refresh(self) -> Token:
class AzureCliTokenSource(CliTokenSource):
""" Obtain the token granted by `az login` CLI command """

def __init__(self, resource: str, subscription: str = "", tenant: str = None):
def __init__(self, resource: str, subscription: Optional[str] = None, tenant: Optional[str] = None):
cmd = ["az", "account", "get-access-token", "--resource", resource, "--output", "json"]
if subscription != "":
if subscription is not None:
cmd.append("--subscription")
cmd.append(subscription)
if tenant:
Expand Down Expand Up @@ -466,8 +465,8 @@ def is_human_user(self) -> bool:
@staticmethod
def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource':
subscription = AzureCliTokenSource.get_subscription(cfg)
if cfg.azure_tenant_id == "" and subscription != "":
token_source = AzureCliTokenSource(resource, subscription)
if subscription is not None:
token_source = AzureCliTokenSource(resource, subscription=subscription, tenant=cfg.azure_tenant_id)
try:
# This will fail if the user has access to the workspace, but not to the subscription
# itself.
Expand All @@ -477,26 +476,25 @@ def for_resource(cfg: 'Config', resource: str) -> 'AzureCliTokenSource':
except OSError:
logger.warning("Failed to get token for subscription. Using resource only token.")

token_source = AzureCliTokenSource(resource, cfg.azure_tenant_id)
token_source = AzureCliTokenSource(resource, subscription=None, tenant=cfg.azure_tenant_id)
token_source.token()
return token_source

@staticmethod
def get_subscription(cfg: 'Config') -> str:
def get_subscription(cfg: 'Config') -> Optional[str]:
resource = cfg.azure_workspace_resource_id
if resource is None or resource == "":
return ""
return None
components = resource.split('/')
if len(components) < 3:
logger.warning("Invalid azure workspace resource ID")
return ""
return None
return components[2]


@credentials_strategy('azure-cli', ['is_azure'])
def azure_cli(cfg: 'Config') -> Optional[CredentialsProvider]:
""" Adds refreshed OAuth token granted by `az login` command to every request. """
_load_azure_tenant_id(cfg)
token_source = None
mgmt_token_source = None
try:
Expand Down
20 changes: 0 additions & 20 deletions tests/test_azure.py

This file was deleted.

35 changes: 35 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@

from .conftest import noop_credentials

import os

__tests__ = os.path.dirname(__file__)

def test_config_supports_legacy_credentials_provider():
c = Config(credentials_provider=noop_credentials, product='foo', product_version='1.2.3')
Expand Down Expand Up @@ -74,3 +77,35 @@ def test_config_copy_deep_copies_user_agent_other_info(config):
assert "blueprint/0.4.6" in config.user_agent
assert "blueprint/0.4.6" in config_copy.user_agent
useragent._reset_extra(original_extra)


def test_load_azure_tenant_id_404(requests_mock, monkeypatch):
monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin')
mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=404)
cfg = Config(host="https://abc123.azuredatabricks.net")
assert cfg.azure_tenant_id is None
assert mock.called_once


def test_load_azure_tenant_id_no_location_header(requests_mock, monkeypatch):
monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin')
mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=302)
cfg = Config(host="https://abc123.azuredatabricks.net")
assert cfg.azure_tenant_id is None
assert mock.called_once


def test_load_azure_tenant_id_unparsable_location_header(requests_mock, monkeypatch):
monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin')
mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=302, headers={'Location': 'https://unexpected-location'})
cfg = Config(host="https://abc123.azuredatabricks.net")
assert cfg.azure_tenant_id is None
assert mock.called_once


def test_load_azure_tenant_id_happy_path(requests_mock, monkeypatch):
monkeypatch.setenv('PATH', __tests__ + '/testdata:/bin')
mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=302, headers={'Location': 'https://login.microsoftonline.com/tenant-id/oauth2/authorize'})
cfg = Config(host="https://abc123.azuredatabricks.net")
assert cfg.azure_tenant_id == 'tenant-id'
assert mock.called_once

0 comments on commit 8a59c08

Please sign in to comment.