diff --git a/databricks/sdk/config.py b/databricks/sdk/config.py index b4efdf603..387fa65c5 100644 --- a/databricks/sdk/config.py +++ b/databricks/sdk/config.py @@ -92,15 +92,16 @@ class Config: max_connections_per_pool: int = ConfigAttribute() databricks_environment: Optional[DatabricksEnvironment] = None - def __init__(self, - *, - # Deprecated. Use credentials_strategy instead. - credentials_provider: Optional[CredentialsStrategy] = None, - credentials_strategy: Optional[CredentialsStrategy] = None, - product=None, - product_version=None, - clock: Optional[Clock] = None, - **kwargs): + def __init__( + self, + *, + # Deprecated. Use credentials_strategy instead. + credentials_provider: Optional[CredentialsStrategy] = None, + credentials_strategy: Optional[CredentialsStrategy] = None, + product=None, + product_version=None, + clock: Optional[Clock] = None, + **kwargs): self._header_factory = None self._inner = {} self._user_agent_other_info = [] diff --git a/databricks/sdk/credentials_provider.py b/databricks/sdk/credentials_provider.py index a79151b5a..e91e37af4 100644 --- a/databricks/sdk/credentials_provider.py +++ b/databricks/sdk/credentials_provider.py @@ -304,11 +304,12 @@ def github_oidc_azure(cfg: 'Config') -> Optional[CredentialsProvider]: # detect Azure AD Tenant ID if it's not specified directly token_endpoint = cfg.oidc_endpoints.token_endpoint cfg.azure_tenant_id = token_endpoint.replace(aad_endpoint, '').split('/')[0] - inner = ClientCredentials(client_id=cfg.azure_client_id, - client_secret="", # we have no (rotatable) secrets in OIDC flow - token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token", - endpoint_params=params, - use_params=True) + inner = ClientCredentials( + client_id=cfg.azure_client_id, + client_secret="", # we have no (rotatable) secrets in OIDC flow + token_url=f"{aad_endpoint}{cfg.azure_tenant_id}/oauth2/token", + endpoint_params=params, + use_params=True) def refreshed_headers() -> Dict[str, str]: token = inner.token() diff --git a/tests/integration/test_auth.py b/tests/integration/test_auth.py index 0bf7f951d..3ee271778 100644 --- a/tests/integration/test_auth.py +++ b/tests/integration/test_auth.py @@ -133,15 +133,16 @@ def _test_runtime_auth_from_jobs_inner(w, env_or_skip, random, dbr_versions, lib tasks = [] for v in dbr_versions: - t = Task(task_key=f'test_{v.key.replace(".", "_")}', - notebook_task=NotebookTask(notebook_path=notebook_path), - new_cluster=ClusterSpec( - spark_version=v.key, - num_workers=1, - instance_pool_id=instance_pool_id, - # GCP uses "custom" data security mode by default, which does not support UC. - data_security_mode=DataSecurityMode.SINGLE_USER), - libraries=[library]) + t = Task( + task_key=f'test_{v.key.replace(".", "_")}', + notebook_task=NotebookTask(notebook_path=notebook_path), + new_cluster=ClusterSpec( + spark_version=v.key, + num_workers=1, + instance_pool_id=instance_pool_id, + # GCP uses "custom" data security mode by default, which does not support UC. + data_security_mode=DataSecurityMode.SINGLE_USER), + libraries=[library]) tasks.append(t) waiter = w.jobs.submit(run_name=f'Runtime Native Auth {random(10)}', tasks=tasks) diff --git a/tests/integration/test_jobs.py b/tests/integration/test_jobs.py index 8fd5f8820..768752a75 100644 --- a/tests/integration/test_jobs.py +++ b/tests/integration/test_jobs.py @@ -17,18 +17,19 @@ def test_submitting_jobs(w, random, env_or_skip): with w.dbfs.open(py_on_dbfs, write=True, overwrite=True) as f: f.write(b'import time; time.sleep(10); print("Hello, World!")') - waiter = w.jobs.submit(run_name=f'py-sdk-{random(8)}', - tasks=[ - jobs.SubmitTask( - task_key='pi', - new_cluster=compute.ClusterSpec( - spark_version=w.clusters.select_spark_version(long_term_support=True), - # node_type_id=w.clusters.select_node_type(local_disk=True), - instance_pool_id=env_or_skip('TEST_INSTANCE_POOL_ID'), - num_workers=1), - spark_python_task=jobs.SparkPythonTask(python_file=f'dbfs:{py_on_dbfs}'), - ) - ]) + waiter = w.jobs.submit( + run_name=f'py-sdk-{random(8)}', + tasks=[ + jobs.SubmitTask( + task_key='pi', + new_cluster=compute.ClusterSpec( + spark_version=w.clusters.select_spark_version(long_term_support=True), + # node_type_id=w.clusters.select_node_type(local_disk=True), + instance_pool_id=env_or_skip('TEST_INSTANCE_POOL_ID'), + num_workers=1), + spark_python_task=jobs.SparkPythonTask(python_file=f'dbfs:{py_on_dbfs}'), + ) + ]) logging.info(f'starting to poll: {waiter.run_id}') diff --git a/tests/test_base_client.py b/tests/test_base_client.py index b55f4e7f8..1e133b8fc 100644 --- a/tests/test_base_client.py +++ b/tests/test_base_client.py @@ -280,11 +280,13 @@ def inner(h: BaseHTTPRequestHandler): assert len(requests) == 2 -@pytest.mark.parametrize('chunk_size,expected_chunks,data_size', - [(5, 20, 100), # 100 / 5 bytes per chunk = 20 chunks - (10, 10, 100), # 100 / 10 bytes per chunk = 10 chunks - (200, 1, 100), # 100 / 200 bytes per chunk = 1 chunk - ]) +@pytest.mark.parametrize( + 'chunk_size,expected_chunks,data_size', + [ + (5, 20, 100), # 100 / 5 bytes per chunk = 20 chunks + (10, 10, 100), # 100 / 10 bytes per chunk = 10 chunks + (200, 1, 100), # 100 / 200 bytes per chunk = 1 chunk + ]) def test_streaming_response_chunk_size(chunk_size, expected_chunks, data_size): rng = random.Random(42) test_data = bytes(rng.getrandbits(8) for _ in range(data_size)) diff --git a/tests/test_core.py b/tests/test_core.py index 16a4c2ad6..1cca428cb 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -370,14 +370,20 @@ def inner(h: BaseHTTPRequestHandler): assert {'Authorization': 'Taker this-is-it'} == headers -@pytest.mark.parametrize(['azure_environment', 'expected'], - [('PUBLIC', ENVIRONMENTS['PUBLIC']), ('USGOVERNMENT', ENVIRONMENTS['USGOVERNMENT']), - ('CHINA', ENVIRONMENTS['CHINA']), ('public', ENVIRONMENTS['PUBLIC']), - ('usgovernment', ENVIRONMENTS['USGOVERNMENT']), ('china', ENVIRONMENTS['CHINA']), - # Kept for historical compatibility - ('AzurePublicCloud', ENVIRONMENTS['PUBLIC']), - ('AzureUSGovernment', ENVIRONMENTS['USGOVERNMENT']), - ('AzureChinaCloud', ENVIRONMENTS['CHINA']), ]) +@pytest.mark.parametrize( + ['azure_environment', 'expected'], + [ + ('PUBLIC', ENVIRONMENTS['PUBLIC']), + ('USGOVERNMENT', ENVIRONMENTS['USGOVERNMENT']), + ('CHINA', ENVIRONMENTS['CHINA']), + ('public', ENVIRONMENTS['PUBLIC']), + ('usgovernment', ENVIRONMENTS['USGOVERNMENT']), + ('china', ENVIRONMENTS['CHINA']), + # Kept for historical compatibility + ('AzurePublicCloud', ENVIRONMENTS['PUBLIC']), + ('AzureUSGovernment', ENVIRONMENTS['USGOVERNMENT']), + ('AzureChinaCloud', ENVIRONMENTS['CHINA']), + ]) def test_azure_environment(azure_environment, expected): c = Config(credentials_strategy=noop_credentials, azure_workspace_resource_id='...', diff --git a/tests/test_model_serving_auth.py b/tests/test_model_serving_auth.py index e0e368fae..13f55668c 100644 --- a/tests/test_model_serving_auth.py +++ b/tests/test_model_serving_auth.py @@ -47,13 +47,16 @@ def test_model_serving_auth(env_values, del_env_values, oauth_file_name, monkeyp assert headers.get("Authorization") == 'Bearer databricks_sdk_unit_test_token' -@pytest.mark.parametrize("env_values, oauth_file_name", [ - ([], "invalid_file_name"), # Not in Model Serving and Invalid File Name - ([('IS_IN_DB_MODEL_SERVING_ENV', 'true')], "invalid_file_name"), # In Model Serving and Invalid File Name - ([('IS_IN_DATABRICKS_MODEL_SERVING_ENV', 'true') - ], "invalid_file_name"), # In Model Serving and Invalid File Name - ([], "tests/testdata/model-serving-test-token") # Not in Model Serving and Valid File Name -]) +@pytest.mark.parametrize( + "env_values, oauth_file_name", + [ + ([], "invalid_file_name"), # Not in Model Serving and Invalid File Name + ([('IS_IN_DB_MODEL_SERVING_ENV', 'true') + ], "invalid_file_name"), # In Model Serving and Invalid File Name + ([('IS_IN_DATABRICKS_MODEL_SERVING_ENV', 'true') + ], "invalid_file_name"), # In Model Serving and Invalid File Name + ([], "tests/testdata/model-serving-test-token") # Not in Model Serving and Valid File Name + ]) @raises(default_auth_base_error_message) def test_model_serving_auth_errors(env_values, oauth_file_name, monkeypatch): # Guarantee that the tests defaults to env variables rather than config file.