From c0b9b9239d0ee2f04e0e3078b8200d11da90d9ef Mon Sep 17 00:00:00 2001 From: Miles Yucht Date: Mon, 16 Sep 2024 08:52:51 -0400 Subject: [PATCH] [Fix] Do not specify --tenant flag when fetching managed identity access token from the CLI (#748) ## Changes Ports https://github.com/databricks/databricks-sdk-go/pull/1021 to the Python SDK. The Azure CLI's az account get-access-token command does not allow specifying --tenant flag if it is authenticated via the CLI. Fixes #742. ## Tests Unit tests ensure that all expected cases are treated as managed identities. - [ ] `make test` run locally - [ ] `make fmt` applied - [ ] relevant integration tests applied --- databricks/sdk/credentials_provider.py | 44 +++++++++++++++++++++++--- tests/test_auth_manual_tests.py | 12 +++++++ tests/testdata/az | 32 +++++++++++++++++-- tests/testdata/windows/az.ps1 | 28 ++++++++++++++++ 4 files changed, 109 insertions(+), 7 deletions(-) diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index df928020b..b64a66e08 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -412,10 +412,7 @@ def _parse_expiry(expiry: str) -> datetime: def refresh(self) -> Token: try: - is_windows = sys.platform.startswith('win') - # windows requires shell=True to be able to execute 'az login' or other commands - # cannot use shell=True all the time, as it breaks macOS - out = subprocess.run(self._cmd, capture_output=True, check=True, shell=is_windows) + out = _run_subprocess(self._cmd, capture_output=True, check=True) it = json.loads(out.stdout.decode()) expires_on = self._parse_expiry(it[self._expiry_field]) return Token(access_token=it[self._access_token_field], @@ -430,6 +427,26 @@ def refresh(self) -> Token: raise IOError(f'cannot get access token: {message}') from e +def _run_subprocess(popenargs, + input=None, + capture_output=True, + timeout=None, + check=False, + **kwargs) -> subprocess.CompletedProcess: + """Runs subprocess with given arguments. + This handles OS-specific modifications that need to be made to the invocation of subprocess.run.""" + kwargs['shell'] = sys.platform.startswith('win') + # windows requires shell=True to be able to execute 'az login' or other commands + # cannot use shell=True all the time, as it breaks macOS + logging.debug(f'Running command: {" ".join(popenargs)}') + return subprocess.run(popenargs, + input=input, + capture_output=capture_output, + timeout=timeout, + check=check, + **kwargs) + + class AzureCliTokenSource(CliTokenSource): """ Obtain the token granted by `az login` CLI command """ @@ -438,13 +455,30 @@ def __init__(self, resource: str, subscription: Optional[str] = None, tenant: Op if subscription is not None: cmd.append("--subscription") cmd.append(subscription) - if tenant: + if tenant and not self.__is_cli_using_managed_identity(): cmd.extend(["--tenant", tenant]) super().__init__(cmd=cmd, token_type_field='tokenType', access_token_field='accessToken', expiry_field='expiresOn') + @staticmethod + def __is_cli_using_managed_identity() -> bool: + """Checks whether the current CLI session is authenticated using managed identity.""" + try: + cmd = ["az", "account", "show", "--output", "json"] + out = _run_subprocess(cmd, capture_output=True, check=True) + account = json.loads(out.stdout.decode()) + user = account.get("user") + if user is None: + return False + return user.get("type") == "servicePrincipal" and user.get("name") in [ + 'systemAssignedIdentity', 'userAssignedIdentity' + ] + except subprocess.CalledProcessError as e: + logger.debug("Failed to get account information from Azure CLI", exc_info=e) + return False + def is_human_user(self) -> bool: """The UPN claim is the username of the user, but not the Service Principal. diff --git a/tests/test_auth_manual_tests.py b/tests/test_auth_manual_tests.py index 34aa3a9c2..8c58dd6bf 100644 --- a/tests/test_auth_manual_tests.py +++ b/tests/test_auth_manual_tests.py @@ -1,3 +1,5 @@ +import pytest + from databricks.sdk.core import Config from .conftest import set_az_path, set_home @@ -60,3 +62,13 @@ def test_azure_cli_with_warning_on_stderr(monkeypatch, mock_tenant): host='https://adb-123.4.azuredatabricks.net', azure_workspace_resource_id=resource_id) assert 'X-Databricks-Azure-SP-Management-Token' in cfg.authenticate() + + +@pytest.mark.parametrize('username', ['systemAssignedIdentity', 'userAssignedIdentity']) +def test_azure_cli_does_not_specify_tenant_id_with_msi(monkeypatch, username): + set_home(monkeypatch, '/testdata/azure') + set_az_path(monkeypatch) + monkeypatch.setenv('FAIL_IF_TENANT_ID_SET', 'true') + monkeypatch.setenv('AZ_USER_NAME', username) + monkeypatch.setenv('AZ_USER_TYPE', 'servicePrincipal') + cfg = Config(auth_type='azure-cli', host='https://adb-123.4.azuredatabricks.net', azure_tenant_id='abc') diff --git a/tests/testdata/az b/tests/testdata/az index 5bf43a663..7437babce 100755 --- a/tests/testdata/az +++ b/tests/testdata/az @@ -1,7 +1,20 @@ #!/bin/bash -if [ -n "$WARN" ]; then - >&2 /bin/echo "WARNING: ${WARN}" +# If the arguments are "account show", return the account details. +if [ "$1" == "account" ] && [ "$2" == "show" ]; then + /bin/echo "{ + \"environmentName\": \"AzureCloud\", + \"id\": \"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee\", + \"isDefault\": true, + \"name\": \"Pay-As-You-Go\", + \"state\": \"Enabled\", + \"tenantId\": \"aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee\", + \"user\": { + \"name\": \"${AZ_USER_NAME:-testuser@databricks.com}\", + \"type\": \"${AZ_USER_TYPE:-user}\" + } +}" + exit 0 fi if [ "yes" == "$FAIL" ]; then @@ -26,6 +39,21 @@ for arg in "$@"; do fi done +# Add character to file at $COUNT if it is defined. +if [ -n "$COUNT" ]; then + echo -n x >> "$COUNT" +fi + +# If FAIL_IF_TENANT_ID_SET is set & --tenant-id is passed, fail. +if [ -n "$FAIL_IF_TENANT_ID_SET" ]; then + for arg in "$@"; do + if [[ "$arg" == "--tenant" ]]; then + echo 1>&2 "ERROR: Tenant shouldn't be specified for managed identity account" + exit 1 + fi + done +fi + # Macos EXP="$(/bin/date -v+${EXPIRE:=10S} +'%F %T' 2>/dev/null)" if [ -z "${EXP}" ]; then diff --git a/tests/testdata/windows/az.ps1 b/tests/testdata/windows/az.ps1 index 4aa96adf5..97ecbca7c 100644 --- a/tests/testdata/windows/az.ps1 +++ b/tests/testdata/windows/az.ps1 @@ -1,5 +1,23 @@ #!/usr/bin/env pwsh +# If the arguments are "account show", return the account details. +if ($args[0] -eq "account" -and $args[1] -eq "show") { + $output = @{ + environmentName = "AzureCloud" + id = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + isDefault = $true + name = "Pay-As-You-Go" + state = "Enabled" + tenantId = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" + user = @{ + name = if ($env:AZ_USER_NAME) { $env:AZ_USER_NAME } else { "testuser@databricks.com" } + type = if ($env:AZ_USER_TYPE) { $env:AZ_USER_TYPE } else { "user" } + } + } + $output | ConvertTo-Json + exit 0 +} + if ($env:WARN) { Write-Error "WARNING: $env:WARN" } @@ -30,6 +48,16 @@ foreach ($arg in $Args) { } } +# If FAIL_IF_TENANT_ID_SET is set & --tenant-id is passed, fail. +if ($env:FAIL_IF_TENANT_ID_SET) { + foreach ($arg in $args) { + if ($arg -eq "--tenant-id" -or $arg -like "--tenant*") { + Write-Error "ERROR: Tenant shouldn't be specified for managed identity account" + exit 1 + } + } +} + try { $EXP = (Get-Date).AddSeconds($env:EXPIRE -as [int]) } catch {