diff --git a/.codegen/__init__.py.tmpl b/.codegen/__init__.py.tmpl index 5ca160685..bc68f5654 100644 --- a/.codegen/__init__.py.tmpl +++ b/.codegen/__init__.py.tmpl @@ -5,6 +5,7 @@ from databricks.sdk.credentials_provider import CredentialsStrategy from databricks.sdk.mixins.files import DbfsExt from databricks.sdk.mixins.compute import ClustersExt from databricks.sdk.mixins.workspace import WorkspaceExt +from databricks.sdk.mixins.open_ai_client import ServingEndpointsExt {{- range .Services}} from databricks.sdk.service.{{.Package.Name}} import {{.PascalName}}API{{end}} from databricks.sdk.service.provisioning import Workspace @@ -17,7 +18,7 @@ from typing import Optional "google_credentials" "google_service_account" }} {{- define "api" -}} - {{- $mixins := dict "ClustersAPI" "ClustersExt" "DbfsAPI" "DbfsExt" "WorkspaceAPI" "WorkspaceExt" -}} + {{- $mixins := dict "ClustersAPI" "ClustersExt" "DbfsAPI" "DbfsExt" "WorkspaceAPI" "WorkspaceExt" "ServingEndpointsExt" "ServingEndpointsApi" -}} {{- $genApi := concat .PascalName "API" -}} {{- getOrDefault $mixins $genApi $genApi -}} {{- end -}} diff --git a/NOTICE b/NOTICE index 2a353a6c8..c05cdd318 100644 --- a/NOTICE +++ b/NOTICE @@ -12,8 +12,22 @@ googleapis/google-auth-library-python - https://github.com/googleapis/google-aut Copyright google-auth-library-python authors License - https://github.com/googleapis/google-auth-library-python/blob/main/LICENSE +openai/openai-python - https://github.com/openai/openai-python +Copyright 2024 OpenAI +License - https://github.com/openai/openai-python/blob/main/LICENSE + This software contains code from the following open source projects, licensed under the BSD (3-clause) license. x/oauth2 - https://cs.opensource.google/go/x/oauth2/+/master:oauth2.go Copyright 2014 The Go Authors. All rights reserved. License - https://cs.opensource.google/go/x/oauth2/+/master:LICENSE + +encode/httpx - https://github.com/encode/httpx +Copyright 2019, Encode OSS Ltd +License - https://github.com/encode/httpx/blob/master/LICENSE.md + +This software contains code from the following open source projects, licensed under the MIT license: + +langchain-ai/langchain - https://github.com/langchain-ai/langchain/blob/master/libs/partners/openai +Copyright 2023 LangChain, Inc. +License - https://github.com/langchain-ai/langchain/blob/master/libs/partners/openai/LICENSE diff --git a/databricks/sdk/__init__.py b/databricks/sdk/__init__.py index 848272198..a4058ec51 100755 --- a/databricks/sdk/__init__.py +++ b/databricks/sdk/__init__.py @@ -6,6 +6,7 @@ from databricks.sdk.credentials_provider import CredentialsStrategy from databricks.sdk.mixins.compute import ClustersExt from databricks.sdk.mixins.files import DbfsExt +from databricks.sdk.mixins.open_ai_client import ServingEndpointsExt from databricks.sdk.mixins.workspace import WorkspaceExt from databricks.sdk.service.apps import AppsAPI from databricks.sdk.service.billing import (BillableUsageAPI, BudgetsAPI, @@ -175,7 +176,7 @@ def __init__(self, self._config = config.copy() self._dbutils = _make_dbutils(self._config) self._api_client = client.ApiClient(self._config) - serving_endpoints = ServingEndpointsAPI(self._api_client) + serving_endpoints = ServingEndpointsExt(self._api_client) self._account_access_control_proxy = AccountAccessControlProxyAPI(self._api_client) self._alerts = AlertsAPI(self._api_client) self._alerts_legacy = AlertsLegacyAPI(self._api_client) diff --git a/databricks/sdk/mixins/open_ai_client.py b/databricks/sdk/mixins/open_ai_client.py new file mode 100644 index 000000000..f7a8af02d --- /dev/null +++ b/databricks/sdk/mixins/open_ai_client.py @@ -0,0 +1,52 @@ +from databricks.sdk.service.serving import ServingEndpointsAPI + + +class ServingEndpointsExt(ServingEndpointsAPI): + + # Using the HTTP Client to pass in the databricks authorization + # This method will be called on every invocation, so when using with model serving will always get the refreshed token + def _get_authorized_http_client(self): + import httpx + + class BearerAuth(httpx.Auth): + + def __init__(self, get_headers_func): + self.get_headers_func = get_headers_func + + def auth_flow(self, request: httpx.Request) -> httpx.Request: + auth_headers = self.get_headers_func() + request.headers["Authorization"] = auth_headers["Authorization"] + yield request + + databricks_token_auth = BearerAuth(self._api._cfg.authenticate) + + # Create an HTTP client with Bearer Token authentication + http_client = httpx.Client(auth=databricks_token_auth) + return http_client + + def get_open_ai_client(self): + try: + from openai import OpenAI + except Exception: + raise ImportError( + "Open AI is not installed. Please install the Databricks SDK with the following command `pip isntall databricks-sdk[openai]`" + ) + + return OpenAI( + base_url=self._api._cfg.host + "/serving-endpoints", + api_key="no-token", # Passing in a placeholder to pass validations, this will not be used + http_client=self._get_authorized_http_client()) + + def get_langchain_chat_open_ai_client(self, model): + try: + from langchain_openai import ChatOpenAI + except Exception: + raise ImportError( + "Langchain Open AI is not installed. Please install the Databricks SDK with the following command `pip isntall databricks-sdk[openai]` and ensure you are using python>3.7" + ) + + return ChatOpenAI( + model=model, + openai_api_base=self._api._cfg.host + "/serving-endpoints", + api_key="no-token", # Passing in a placeholder to pass validations, this will not be used + http_client=self._get_authorized_http_client()) diff --git a/setup.py b/setup.py index 9cfe38d09..b756e6d0d 100644 --- a/setup.py +++ b/setup.py @@ -17,8 +17,10 @@ extras_require={"dev": ["pytest", "pytest-cov", "pytest-xdist", "pytest-mock", "yapf", "pycodestyle", "autoflake", "isort", "wheel", "ipython", "ipywidgets", "requests-mock", "pyfakefs", - "databricks-connect", "pytest-rerunfailures"], - "notebook": ["ipython>=8,<9", "ipywidgets>=8,<9"]}, + "databricks-connect", "pytest-rerunfailures", "openai", + 'langchain-openai; python_version > "3.7"', "httpx"], + "notebook": ["ipython>=8,<9", "ipywidgets>=8,<9"], + "openai": ["openai", 'langchain-openai; python_version > "3.7"', "httpx"]}, author="Serge Smertin", author_email="serge.smertin@databricks.com", description="Databricks SDK for Python (Beta)", diff --git a/tests/test_open_ai_mixin.py b/tests/test_open_ai_mixin.py new file mode 100644 index 000000000..1858c66cb --- /dev/null +++ b/tests/test_open_ai_mixin.py @@ -0,0 +1,30 @@ +import sys + +import pytest + +from databricks.sdk.core import Config + + +def test_open_ai_client(monkeypatch): + from databricks.sdk import WorkspaceClient + + monkeypatch.setenv('DATABRICKS_HOST', 'test_host') + monkeypatch.setenv('DATABRICKS_TOKEN', 'test_token') + w = WorkspaceClient(config=Config()) + client = w.serving_endpoints.get_open_ai_client() + + assert client.base_url == "https://test_host/serving-endpoints/" + assert client.api_key == "no-token" + + +@pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python > 3.7") +def test_langchain_open_ai_client(monkeypatch): + from databricks.sdk import WorkspaceClient + + monkeypatch.setenv('DATABRICKS_HOST', 'test_host') + monkeypatch.setenv('DATABRICKS_TOKEN', 'test_token') + w = WorkspaceClient(config=Config()) + client = w.serving_endpoints.get_langchain_chat_open_ai_client("databricks-meta-llama-3-1-70b-instruct") + + assert client.openai_api_base == "https://test_host/serving-endpoints" + assert client.model_name == "databricks-meta-llama-3-1-70b-instruct"