Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Open AI Client Mixin #779

Merged
merged 16 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .codegen/__init__.py.tmpl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ from databricks.sdk.credentials_provider import CredentialsStrategy
from databricks.sdk.mixins.files import DbfsExt
from databricks.sdk.mixins.compute import ClustersExt
from databricks.sdk.mixins.workspace import WorkspaceExt
from databricks.sdk.mixins.open_ai_client import ServingEndpointsExt
{{- range .Services}}
from databricks.sdk.service.{{.Package.Name}} import {{.PascalName}}API{{end}}
from databricks.sdk.service.provisioning import Workspace
Expand All @@ -17,7 +18,7 @@ from typing import Optional
"google_credentials" "google_service_account" }}

{{- define "api" -}}
{{- $mixins := dict "ClustersAPI" "ClustersExt" "DbfsAPI" "DbfsExt" "WorkspaceAPI" "WorkspaceExt" -}}
{{- $mixins := dict "ClustersAPI" "ClustersExt" "DbfsAPI" "DbfsExt" "WorkspaceAPI" "WorkspaceExt" "ServingEndpointsExt" "ServingEndpointsApi" -}}
{{- $genApi := concat .PascalName "API" -}}
{{- getOrDefault $mixins $genApi $genApi -}}
{{- end -}}
Expand Down
14 changes: 14 additions & 0 deletions NOTICE
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,22 @@ googleapis/google-auth-library-python - https://github.com/googleapis/google-aut
Copyright google-auth-library-python authors
License - https://github.com/googleapis/google-auth-library-python/blob/main/LICENSE

openai/openai-python - https://github.com/openai/openai-python
Copyright 2024 OpenAI
License - https://github.com/openai/openai-python/blob/main/LICENSE

This software contains code from the following open source projects, licensed under the BSD (3-clause) license.

x/oauth2 - https://cs.opensource.google/go/x/oauth2/+/master:oauth2.go
Copyright 2014 The Go Authors. All rights reserved.
License - https://cs.opensource.google/go/x/oauth2/+/master:LICENSE

encode/httpx - https://github.com/encode/httpx
Copyright 2019, Encode OSS Ltd
License - https://github.com/encode/httpx/blob/master/LICENSE.md

This software contains code from the following open source projects, licensed under the MIT license:

langchain-ai/langchain - https://github.com/langchain-ai/langchain/blob/master/libs/partners/openai
Copyright 2023 LangChain, Inc.
License - https://github.com/langchain-ai/langchain/blob/master/libs/partners/openai/LICENSE
3 changes: 2 additions & 1 deletion databricks/sdk/__init__.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

52 changes: 52 additions & 0 deletions databricks/sdk/mixins/open_ai_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from databricks.sdk.service.serving import ServingEndpointsAPI


class ServingEndpointsExt(ServingEndpointsAPI):

# Using the HTTP Client to pass in the databricks authorization
# This method will be called on every invocation, so when using with model serving will always get the refreshed token
def _get_authorized_http_client(self):
import httpx

class BearerAuth(httpx.Auth):

def __init__(self, get_headers_func):
self.get_headers_func = get_headers_func

def auth_flow(self, request: httpx.Request) -> httpx.Request:
auth_headers = self.get_headers_func()
request.headers["Authorization"] = auth_headers["Authorization"]
yield request

databricks_token_auth = BearerAuth(self._api._cfg.authenticate)

# Create an HTTP client with Bearer Token authentication
http_client = httpx.Client(auth=databricks_token_auth)
return http_client

def get_open_ai_client(self):
try:
from openai import OpenAI
except Exception:
raise ImportError(
"Open AI is not installed. Please install the Databricks SDK with the following command `pip isntall databricks-sdk[openai]`"
)

return OpenAI(
aravind-segu marked this conversation as resolved.
Show resolved Hide resolved
base_url=self._api._cfg.host + "/serving-endpoints",
api_key="no-token", # Passing in a placeholder to pass validations, this will not be used
http_client=self._get_authorized_http_client())

def get_langchain_chat_open_ai_client(self, model):
try:
from langchain_openai import ChatOpenAI
except Exception:
raise ImportError(
"Langchain Open AI is not installed. Please install the Databricks SDK with the following command `pip isntall databricks-sdk[openai]` and ensure you are using python>3.7"
)

return ChatOpenAI(
model=model,
openai_api_base=self._api._cfg.host + "/serving-endpoints",
api_key="no-token", # Passing in a placeholder to pass validations, this will not be used
http_client=self._get_authorized_http_client())
6 changes: 4 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@
extras_require={"dev": ["pytest", "pytest-cov", "pytest-xdist", "pytest-mock",
"yapf", "pycodestyle", "autoflake", "isort", "wheel",
"ipython", "ipywidgets", "requests-mock", "pyfakefs",
"databricks-connect", "pytest-rerunfailures"],
"notebook": ["ipython>=8,<9", "ipywidgets>=8,<9"]},
"databricks-connect", "pytest-rerunfailures", "openai",
'langchain-openai; python_version > "3.7"', "httpx"],
"notebook": ["ipython>=8,<9", "ipywidgets>=8,<9"],
"openai": ["openai", 'langchain-openai; python_version > "3.7"', "httpx"]},
author="Serge Smertin",
author_email="[email protected]",
description="Databricks SDK for Python (Beta)",
Expand Down
30 changes: 30 additions & 0 deletions tests/test_open_ai_mixin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import sys

import pytest

from databricks.sdk.core import Config


def test_open_ai_client(monkeypatch):
from databricks.sdk import WorkspaceClient

monkeypatch.setenv('DATABRICKS_HOST', 'test_host')
monkeypatch.setenv('DATABRICKS_TOKEN', 'test_token')
w = WorkspaceClient(config=Config())
client = w.serving_endpoints.get_open_ai_client()

assert client.base_url == "https://test_host/serving-endpoints/"
assert client.api_key == "no-token"


@pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python > 3.7")
def test_langchain_open_ai_client(monkeypatch):
from databricks.sdk import WorkspaceClient

monkeypatch.setenv('DATABRICKS_HOST', 'test_host')
monkeypatch.setenv('DATABRICKS_TOKEN', 'test_token')
w = WorkspaceClient(config=Config())
client = w.serving_endpoints.get_langchain_chat_open_ai_client("databricks-meta-llama-3-1-70b-instruct")

assert client.openai_api_base == "https://test_host/serving-endpoints"
assert client.model_name == "databricks-meta-llama-3-1-70b-instruct"
Loading