Skip to content

Updated the DatabricksRM class to use Databricks service principals with REST API path #8327

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
afcbc52
Updated the DatabricksRM class to use Databricks service principals.
willsmithDB May 27, 2025
784354a
Updated comments for the usage of service principals for the Databric…
willsmithDB May 29, 2025
f06c2ed
Updated print statements to acknowledge auth method.
willsmithDB May 29, 2025
6c3512d
Merge branch 'stanfordnlp:main' into dbx_service_principal_functionality
willsmithDB May 29, 2025
4b00022
Updated print statements to acknowledge auth method.
willsmithDB May 29, 2025
f9b2b9d
Merge remote-tracking branch 'origin/dbx_service_principal_functional…
willsmithDB May 29, 2025
b47fe02
format
chenmoneygithub May 29, 2025
9cf562a
Updated for Oauth support via REST API
willsmithDB Jun 3, 2025
b3a57fa
Merge remote-tracking branch 'origin/dbx_service_principal_functional…
willsmithDB Jun 3, 2025
34fa87c
Updated for Oauth support via REST API
willsmithDB Jun 3, 2025
aedce91
Updated for Oauth support via REST API
willsmithDB Jun 3, 2025
594fb04
Updated for Oauth support via REST API
willsmithDB Jun 3, 2025
ff8a2e1
Merge remote-tracking branch 'origin/dbx_service_principal_functional…
willsmithDB Jun 3, 2025
554f368
Merge branch 'stanfordnlp:main' into dbx_service_principal_functionality
willsmithDB Jun 3, 2025
0e55841
Updated for Databricks Oauth support via REST API
willsmithDB Jun 3, 2025
49c299f
Merge remote-tracking branch 'origin/dbx_service_principal_functional…
willsmithDB Jun 3, 2025
7faec6e
Merge branch 'stanfordnlp:main' into dbx_service_principal_functionality
willsmithDB Jun 4, 2025
e9a68ca
Merge pull request #1 from willsmithDB/dbx_service_principal_function…
willsmithDB Jun 4, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 100 additions & 8 deletions dspy/retrieve/databricks_rm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,12 @@
import dspy
from dspy.primitives.prediction import Prediction

_databricks_sdk_installed = find_spec("databricks.sdk") is not None
_databricks_sdk_installed = False

try:
_databricks_sdk_installed = find_spec("databricks.sdk") is not None
except ModuleNotFoundError:
_databricks_sdk_installed = False

@dataclass
class Document:
Expand Down Expand Up @@ -137,11 +141,10 @@ def __init__(
databricks_client_id if databricks_client_id is not None else os.environ.get("DATABRICKS_CLIENT_ID")
)
self.databricks_client_secret = (
databricks_client_secret
if databricks_client_secret is not None
else os.environ.get("DATABRICKS_CLIENT_SECRET")
databricks_client_secret if databricks_client_secret is not None else os.environ.get("DATABRICKS_CLIENT_SECRET")
)
if not _databricks_sdk_installed and (self.databricks_token, self.databricks_endpoint).count(None) > 0:
if not _databricks_sdk_installed and ((self.databricks_token, self.databricks_endpoint).count(None) > 0
and (self.databricks_client_id, self.databricks_client_secret).count(None) > 0):
raise ValueError(
"To retrieve documents with Databricks Vector Search, you must install the"
" databricks-sdk Python library, supply the databricks_token and"
Expand Down Expand Up @@ -183,6 +186,7 @@ def _extract_doc_ids(self, item: Dict[str, Any]) -> str:
if self.docs_id_column_name == "metadata":
docs_dict = json.loads(item["metadata"])
return docs_dict["document_id"]

return item[self.docs_id_column_name]

def _get_extra_columns(self, item: Dict[str, Any]) -> Dict[str, Any]:
Expand All @@ -198,11 +202,13 @@ def _get_extra_columns(self, item: Dict[str, Any]) -> Dict[str, Any]:
for k, v in item.items()
if k not in [self.docs_id_column_name, self.text_column_name, self.docs_uri_column_name]
}

if self.docs_id_column_name == "metadata":
extra_columns = {
**extra_columns,
**{"metadata": {k: v for k, v in json.loads(item["metadata"]).items() if k != "document_id"}},
}

return extra_columns

def forward(
Expand Down Expand Up @@ -251,6 +257,7 @@ def forward(
raise ValueError("Query must be a string or a list of floats.")

if _databricks_sdk_installed:
print("Using the Databricks SDK to query the Vector Search Index.")
results = self._query_via_databricks_sdk(
index_name=self.databricks_index_name,
k=self.k,
Expand All @@ -265,12 +272,15 @@ def forward(
filters_json=filters_json or self.filters_json,
)
else:
print("Using the REST API to query the Vector Search Index.")
results = self._query_via_requests(
index_name=self.databricks_index_name,
k=self.k,
columns=self.columns,
databricks_token=self.databricks_token,
databricks_endpoint=self.databricks_endpoint,
databricks_client_id=self.databricks_client_id,
databricks_client_secret=self.databricks_client_secret,
query_type=query_type,
query_text=query_text,
query_vector=query_vector,
Expand Down Expand Up @@ -313,6 +323,7 @@ def forward(
).to_dict()
for doc in sorted_docs
]

else:
# Returning the prediction
return Prediction(
Expand Down Expand Up @@ -351,8 +362,10 @@ def _query_via_databricks_sdk(
filters_json (Optional[str]): JSON string representing additional query filters.
databricks_token (str): Databricks authentication token. If not specified,
the token is resolved from the current environment.
databricks_endpoint (str): Databricks index endpoint url. If not specified,
the endpoint is resolved from the current environment.
databricks_endpoint (Optional[str]): The URL of the Databricks Workspace containing
the Vector Search Index. Defaults to the value of the ``DATABRICKS_HOST``
environment variable. If unspecified, the Databricks SDK is used to identify the
endpoint based on the current environment.
databricks_client_id (str): Databricks service principal id. If not specified,
the token is resolved from the current environment (DATABRICKS_CLIENT_ID).
databricks_client_secret (str): Databricks service principal secret. If not specified,
Expand Down Expand Up @@ -400,6 +413,8 @@ def _query_via_requests(
columns: List[str],
databricks_token: str,
databricks_endpoint: str,
databricks_client_id: Optional[str],
databricks_client_secret: Optional[str],
query_type: str,
query_text: Optional[str],
query_vector: Optional[List[float]],
Expand All @@ -413,7 +428,14 @@ def _query_via_requests(
k (int): Number of relevant documents to retrieve.
columns (List[str]): Column names to include in response.
databricks_token (str): Databricks authentication token.
databricks_endpoint (str): Databricks index endpoint url.
databricks_endpoint (Optional[str]): The URL of the Databricks Workspace containing
the Vector Search Index. Defaults to the value of the ``DATABRICKS_HOST``
environment variable. If unspecified, the Databricks SDK is used to identify the
endpoint based on the current environment.
databricks_client_id (str): Databricks service principal id. If not specified,
the token is resolved from the current environment (DATABRICKS_CLIENT_ID).
databricks_client_secret (str): Databricks service principal secret. If not specified,
the endpoint is resolved from the current environment (DATABRICKS_CLIENT_SECRET).
query_text (Optional[str]): Text query for which to find relevant documents. Exactly
one of query_text or query_vector must be specified.
query_vector (Optional[List[float]]): Numeric query vector for which to find relevant
Expand All @@ -423,30 +445,100 @@ def _query_via_requests(
Returns:
Dict[str, Any]: Parsed JSON response from the Databricks Vector Search Index query.
"""

if (query_text, query_vector).count(None) != 1:
raise ValueError("Exactly one of query_text or query_vector must be specified.")

if databricks_client_id and databricks_client_secret:
try:
print("Retrieving OAuth token using service principal authentication.")
databricks_token = _get_oauth_token(
index_name, databricks_endpoint, databricks_client_id, databricks_client_secret
)
except Exception as e:
raise ValueError(
f"Failed to retrieve OAuth token. Please check your Databricks client ID and secret. \n"
f"Error: {e} \n \n"
f"NOTE: If you are experiencing a 401 error, be sure to check the permissions on the index that "
f"you are trying to query. The service principal must have the select permission on the index."
)

headers = {
"Authorization": f"Bearer {databricks_token}",
"Content-Type": "application/json",
}

payload = {
"columns": columns,
"num_results": k,
"query_type": query_type,
}

if filters_json is not None:
payload["filters_json"] = filters_json
if query_text is not None:
payload["query_text"] = query_text
elif query_vector is not None:
payload["query_vector"] = query_vector

response = requests.post(
f"{databricks_endpoint}/api/2.0/vector-search/indexes/{index_name}/query",
json=payload,
headers=headers,
)

results = response.json()
if "error_code" in results:
raise Exception(f"ERROR: {results['error_code']} -- {results['message']}")
return results

def _get_oauth_token(
    index_name: str,
    databricks_endpoint: str,
    databricks_client_id: str,
    databricks_client_secret: str,
) -> str:
    """Obtain an OAuth access token for a Databricks service principal.

    Performs the OAuth 2.0 client-credentials flow against the workspace's
    ``/oidc/v1/token`` endpoint, requesting a token whose authorization is
    scoped down (via Unity Catalog fine-grained authorization details) to the
    single ``ReadVectorIndex`` operation on the given index.

    Args:
        index_name (str): Fully qualified name of the vector search index the
            token should grant read access to.
        databricks_endpoint (str): Base URL of the Databricks workspace,
            e.g. ``https://<workspace>.cloud.databricks.com``. Required —
            there is no SDK fallback on this REST code path.
        databricks_client_id (str): Service principal application (client) ID.
        databricks_client_secret (str): Service principal OAuth secret.

    Returns:
        str: The OAuth access token to pass as a Bearer token.

    Raises:
        ValueError: If ``databricks_endpoint`` is falsy.
        requests.HTTPError: If the token endpoint returns a non-2xx response.
    """
    if not databricks_endpoint:
        # Without a workspace URL we would build a nonsensical token URL such
        # as "None/oidc/v1/token"; fail fast with a clear message instead.
        raise ValueError(
            "A Databricks workspace endpoint (databricks_endpoint or the "
            "DATABRICKS_HOST environment variable) is required for service "
            "principal OAuth authentication."
        )

    # Restrict the token to reading this one index (least privilege).
    authorization_details_list = [
        {
            "type": "unity_catalog_permission",
            "securable_type": "table",
            "securable_object_name": index_name,
            "operation": "ReadVectorIndex",
        }
    ]

    token_url = f"{databricks_endpoint}/oidc/v1/token"

    data = {
        "grant_type": "client_credentials",
        "scope": "all-apis",
        "authorization_details": json.dumps(authorization_details_list),
    }

    response = requests.post(
        token_url,
        auth=(databricks_client_id, databricks_client_secret),
        data=data,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        # Bound the wait so an unreachable token endpoint cannot hang the
        # retrieval pipeline indefinitely (requests has no default timeout).
        timeout=30,
    )

    response.raise_for_status()
    return response.json()["access_token"]