-
Notifications
You must be signed in to change notification settings - Fork 131
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Use https client for authorization on request
- Loading branch information
1 parent
ed97494
commit fc39ed7
Showing
2 changed files
with
36 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,40 +1,53 @@ | ||
import httpx | ||
|
||
from databricks.sdk.service.serving import ServingEndpointsAPI | ||
|
||
|
||
class ServingEndpointsExt(ServingEndpointsAPI): | ||
|
||
def get_open_ai_client(self): | ||
auth_headers = self._api._cfg.authenticate() | ||
# Using the HTTP Client to pass in the databricks authorization | ||
# This method will be called on every invocation, so when using with model serving will always get the refreshed token | ||
def _get_authorized_http_client(self): | ||
|
||
try: | ||
token = auth_headers["Authorization"][len("Bearer "):] | ||
except Exception: | ||
raise ValueError("Unable to extract authorization token for OpenAI Client") | ||
class BearerAuth(httpx.Auth): | ||
|
||
def __init__(self, get_headers_func): | ||
self.get_headers_func = get_headers_func | ||
|
||
def auth_flow(self, request: httpx.Request) -> httpx.Request: | ||
auth_headers = self.get_headers_func() | ||
request.headers["Authorization"] = auth_headers["Authorization"] | ||
yield request | ||
|
||
databricks_token_auth = BearerAuth(self._api._cfg.authenticate) | ||
|
||
# Create an HTTP client with Bearer Token authentication | ||
http_client = httpx.Client(auth=databricks_token_auth) | ||
return http_client | ||
|
||
def get_open_ai_client(self): | ||
try: | ||
from openai import OpenAI | ||
except Exception: | ||
raise ImportError( | ||
"Open AI is not installed. Please install the Databricks SDK with the following command `pip isntall databricks-sdk[openai]`" | ||
) | ||
|
||
return OpenAI(base_url=self._api._cfg.host + "/serving-endpoints", api_key=token) | ||
return OpenAI( | ||
base_url=self._api._cfg.host + "/serving-endpoints", | ||
api_key="no-token", # Passing in a placeholder to pass validations, this will not be used | ||
http_client=self._get_authorized_http_client()) | ||
|
||
def get_langchain_chat_open_ai_client(self, model): | ||
auth_headers = self._api._cfg.authenticate() | ||
|
||
try: | ||
from langchain_openai import ChatOpenAI | ||
except Exception: | ||
raise ImportError( | ||
"Langchain Open AI is not installed. Please install the Databricks SDK with the following command `pip isntall databricks-sdk[openai]` and ensure you are using python>3.7" | ||
) | ||
|
||
try: | ||
token = auth_headers["Authorization"][len("Bearer "):] | ||
except Exception: | ||
raise ValueError("Unable to extract authorization token for Langchain OpenAI Client") | ||
|
||
return ChatOpenAI(model=model, | ||
openai_api_base=self._api._cfg.host + "/serving-endpoints", | ||
openai_api_key=token) | ||
return ChatOpenAI( | ||
model=model, | ||
openai_api_base=self._api._cfg.host + "/serving-endpoints", | ||
api_key="no-token", # Passing in a placeholder to pass validations, this will not be used | ||
http_client=self._get_authorized_http_client()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters