Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Fix DatabricksConfig.copy when authenticated with OAuth #723

Merged
merged 6 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion databricks/sdk/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,10 @@ def inner() -> Dict[str, str]:
token = token_source.token()
return {'Authorization': f'{token.token_type} {token.access_token}'}

return OAuthCredentialsProvider(inner, token_source.token)
def token() -> Token:
return token_source.token()

return OAuthCredentialsProvider(inner, token)


class MetadataServiceTokenSource(Refreshable):
Expand Down
39 changes: 39 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import os
import pathlib
import platform
import random
import string
from datetime import datetime

import pytest

from databricks.sdk import useragent
from databricks.sdk.config import Config, with_product, with_user_agent_extra
from databricks.sdk.credentials_provider import Token
from databricks.sdk.version import __version__

from .conftest import noop_credentials, set_az_path
Expand Down Expand Up @@ -79,6 +84,40 @@ def test_config_copy_deep_copies_user_agent_other_info(config):
useragent._reset_extra(original_extra)


def test_config_deep_copy(monkeypatch, mocker, tmp_path):
mocker.patch('databricks.sdk.credentials_provider.CliTokenSource.refresh',
return_value=Token(access_token='token',
token_type='Bearer',
expiry=datetime(2023, 5, 22, 0, 0, 0)))

write_large_dummy_executable(tmp_path)
monkeypatch.setenv('PATH', tmp_path.as_posix())

config = Config(host="https://abc123.azuredatabricks.net", auth_type="databricks-cli")
config_copy = config.deep_copy()
assert config_copy.host == config.host


def write_large_dummy_executable(path: pathlib.Path):
cli = path.joinpath('databricks')

# Generate a long random string to inflate the file size.
random_string = ''.join(random.choice(string.ascii_letters) for i in range(1024 * 1024))
cli.write_text("""#!/bin/sh
cat <<EOF
{
"access_token": "...",
"token_type": "Bearer",
"expiry": "2023-05-22T00:00:00.000000+00:00"
}
EOF
exit 0
""" + random_string)
cli.chmod(0o755)
assert cli.stat().st_size >= (1024 * 1024)
return cli


def test_load_azure_tenant_id_404(requests_mock, monkeypatch):
set_az_path(monkeypatch)
mock = requests_mock.get('https://abc123.azuredatabricks.net/aad/auth', status_code=404)
Expand Down
Loading