-
Notifications
You must be signed in to change notification settings - Fork 157
[Feature] Open AI Client Mixin #779
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
16 commits
Select commit
Hold shift + click to select a range
675c45b
[ML-45784]Add open ai client mixin
aravind-segu 2822f29
[ML-45784]Add setup file
aravind-segu 2e7d6d4
[ML-45784]Format fixes
aravind-segu 8029bba
Add open ai to dev
aravind-segu deda35d
Add Langchain Open AI Client
aravind-segu 2d3ec7d
Skip langchain test for less than 3.7
aravind-segu 295a5c3
Update setup.py for dev
aravind-segu 59269c1
remove unneccessary files
aravind-segu c92a0e9
handle python versions in tests
aravind-segu b11d127
Skip if less than 3.8
aravind-segu 64631ae
Use https client for authorization on request
aravind-segu dec6de5
add print statement to verify
aravind-segu 83618ec
Update tests
aravind-segu f7b4a18
Add httpx to setup
aravind-segu 41f197b
Update Notice to include the new packages
aravind-segu 5582d58
Undo file deletions
aravind-segu File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
from databricks.sdk.service.serving import ServingEndpointsAPI | ||
|
||
|
||
class ServingEndpointsExt(ServingEndpointsAPI): | ||
|
||
# Using the HTTP Client to pass in the databricks authorization | ||
# This method will be called on every invocation, so when using with model serving will always get the refreshed token | ||
def _get_authorized_http_client(self): | ||
import httpx | ||
|
||
class BearerAuth(httpx.Auth): | ||
|
||
def __init__(self, get_headers_func): | ||
self.get_headers_func = get_headers_func | ||
|
||
def auth_flow(self, request: httpx.Request) -> httpx.Request: | ||
auth_headers = self.get_headers_func() | ||
request.headers["Authorization"] = auth_headers["Authorization"] | ||
yield request | ||
|
||
databricks_token_auth = BearerAuth(self._api._cfg.authenticate) | ||
|
||
# Create an HTTP client with Bearer Token authentication | ||
http_client = httpx.Client(auth=databricks_token_auth) | ||
return http_client | ||
|
||
def get_open_ai_client(self): | ||
try: | ||
from openai import OpenAI | ||
except Exception: | ||
raise ImportError( | ||
"Open AI is not installed. Please install the Databricks SDK with the following command `pip isntall databricks-sdk[openai]`" | ||
) | ||
|
||
return OpenAI( | ||
base_url=self._api._cfg.host + "/serving-endpoints", | ||
api_key="no-token", # Passing in a placeholder to pass validations, this will not be used | ||
http_client=self._get_authorized_http_client()) | ||
|
||
def get_langchain_chat_open_ai_client(self, model): | ||
try: | ||
from langchain_openai import ChatOpenAI | ||
except Exception: | ||
raise ImportError( | ||
"Langchain Open AI is not installed. Please install the Databricks SDK with the following command `pip isntall databricks-sdk[openai]` and ensure you are using python>3.7" | ||
) | ||
|
||
return ChatOpenAI( | ||
model=model, | ||
openai_api_base=self._api._cfg.host + "/serving-endpoints", | ||
api_key="no-token", # Passing in a placeholder to pass validations, this will not be used | ||
http_client=self._get_authorized_http_client()) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,8 +17,10 @@ | |
extras_require={"dev": ["pytest", "pytest-cov", "pytest-xdist", "pytest-mock", | ||
"yapf", "pycodestyle", "autoflake", "isort", "wheel", | ||
"ipython", "ipywidgets", "requests-mock", "pyfakefs", | ||
"databricks-connect", "pytest-rerunfailures"], | ||
"notebook": ["ipython>=8,<9", "ipywidgets>=8,<9"]}, | ||
"databricks-connect", "pytest-rerunfailures", "openai", | ||
'langchain-openai; python_version > "3.7"', "httpx"], | ||
"notebook": ["ipython>=8,<9", "ipywidgets>=8,<9"], | ||
"openai": ["openai", 'langchain-openai; python_version > "3.7"', "httpx"]}, | ||
author="Serge Smertin", | ||
author_email="[email protected]", | ||
description="Databricks SDK for Python (Beta)", | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
import sys | ||
|
||
import pytest | ||
|
||
from databricks.sdk.core import Config | ||
|
||
|
||
def test_open_ai_client(monkeypatch): | ||
from databricks.sdk import WorkspaceClient | ||
|
||
monkeypatch.setenv('DATABRICKS_HOST', 'test_host') | ||
monkeypatch.setenv('DATABRICKS_TOKEN', 'test_token') | ||
w = WorkspaceClient(config=Config()) | ||
client = w.serving_endpoints.get_open_ai_client() | ||
|
||
assert client.base_url == "https://test_host/serving-endpoints/" | ||
assert client.api_key == "no-token" | ||
|
||
|
||
@pytest.mark.skipif(sys.version_info < (3, 8), reason="Requires Python > 3.7") | ||
def test_langchain_open_ai_client(monkeypatch): | ||
from databricks.sdk import WorkspaceClient | ||
|
||
monkeypatch.setenv('DATABRICKS_HOST', 'test_host') | ||
monkeypatch.setenv('DATABRICKS_TOKEN', 'test_token') | ||
w = WorkspaceClient(config=Config()) | ||
client = w.serving_endpoints.get_langchain_chat_open_ai_client("databricks-meta-llama-3-1-70b-instruct") | ||
|
||
assert client.openai_api_base == "https://test_host/serving-endpoints" | ||
assert client.model_name == "databricks-meta-llama-3-1-70b-instruct" |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.