diff --git a/tests/test_open_ai_mixin.py b/tests/test_open_ai_mixin.py index 2d184267..287947b0 100644 --- a/tests/test_open_ai_mixin.py +++ b/tests/test_open_ai_mixin.py @@ -1,7 +1,4 @@ import sys - -import pytest - from databricks.sdk.core import Config @@ -17,14 +14,19 @@ def test_open_ai_client(monkeypatch): assert client.api_key == "test_token" -@pytest.mark.skipif(sys.version_info <= (3, 7), reason="Requires Python > 3.7") def test_langchain_open_ai_client(monkeypatch): from databricks.sdk import WorkspaceClient - - monkeypatch.setenv('DATABRICKS_HOST', 'test_host') - monkeypatch.setenv('DATABRICKS_TOKEN', 'test_token') - w = WorkspaceClient(config=Config()) - client = w.serving_endpoints.get_langchain_chat_open_ai_client("databricks-meta-llama-3-1-70b-instruct") - - assert client.openai_api_base == "https://test_host/serving-endpoints" - assert client.model_name == "databricks-meta-llama-3-1-70b-instruct" + print(sys.version_info) + print(sys.version_info <= (3,7)) + if sys.version_info <= (3, 7): + with pytest.raises(ImportError): + w = WorkspaceClient(config=Config()) + client = w.serving_endpoints.get_langchain_chat_open_ai_client("databricks-meta-llama-3-1-70b-instruct") + else: + monkeypatch.setenv('DATABRICKS_HOST', 'test_host') + monkeypatch.setenv('DATABRICKS_TOKEN', 'test_token') + w = WorkspaceClient(config=Config()) + client = w.serving_endpoints.get_langchain_chat_open_ai_client("databricks-meta-llama-3-1-70b-instruct") + + assert client.openai_api_base == "https://test_host/serving-endpoints" + assert client.model_name == "databricks-meta-llama-3-1-70b-instruct"