diff --git a/dspy/retrieve/databricks_rm.py b/dspy/retrieve/databricks_rm.py index 115aa79eb..d092a10b1 100644 --- a/dspy/retrieve/databricks_rm.py +++ b/dspy/retrieve/databricks_rm.py @@ -9,8 +9,12 @@ import dspy from dspy.primitives.prediction import Prediction -_databricks_sdk_installed = find_spec("databricks.sdk") is not None +_databricks_sdk_installed = False +try: + _databricks_sdk_installed = find_spec("databricks.sdk") is not None +except ModuleNotFoundError: + _databricks_sdk_installed = False @dataclass class Document: @@ -137,11 +141,10 @@ def __init__( databricks_client_id if databricks_client_id is not None else os.environ.get("DATABRICKS_CLIENT_ID") ) self.databricks_client_secret = ( - databricks_client_secret - if databricks_client_secret is not None - else os.environ.get("DATABRICKS_CLIENT_SECRET") + databricks_client_secret if databricks_client_secret is not None else os.environ.get("DATABRICKS_CLIENT_SECRET") ) - if not _databricks_sdk_installed and (self.databricks_token, self.databricks_endpoint).count(None) > 0: + if not _databricks_sdk_installed and ((self.databricks_token, self.databricks_endpoint).count(None) > 0 + and (self.databricks_client_id, self.databricks_client_secret).count(None) > 0): raise ValueError( "To retrieve documents with Databricks Vector Search, you must install the" " databricks-sdk Python library, supply the databricks_token and" @@ -183,6 +186,7 @@ def _extract_doc_ids(self, item: Dict[str, Any]) -> str: if self.docs_id_column_name == "metadata": docs_dict = json.loads(item["metadata"]) return docs_dict["document_id"] + return item[self.docs_id_column_name] def _get_extra_columns(self, item: Dict[str, Any]) -> Dict[str, Any]: @@ -198,11 +202,13 @@ def _get_extra_columns(self, item: Dict[str, Any]) -> Dict[str, Any]: for k, v in item.items() if k not in [self.docs_id_column_name, self.text_column_name, self.docs_uri_column_name] } + if self.docs_id_column_name == "metadata": extra_columns = { **extra_columns, 
**{"metadata": {k: v for k, v in json.loads(item["metadata"]).items() if k != "document_id"}}, } + return extra_columns def forward( @@ -251,6 +257,7 @@ def forward( raise ValueError("Query must be a string or a list of floats.") if _databricks_sdk_installed: + print("Using the Databricks SDK to query the Vector Search Index.") results = self._query_via_databricks_sdk( index_name=self.databricks_index_name, k=self.k, @@ -265,12 +272,15 @@ def forward( filters_json=filters_json or self.filters_json, ) else: + print("Using the REST API to query the Vector Search Index.") results = self._query_via_requests( index_name=self.databricks_index_name, k=self.k, columns=self.columns, databricks_token=self.databricks_token, databricks_endpoint=self.databricks_endpoint, + databricks_client_id=self.databricks_client_id, + databricks_client_secret=self.databricks_client_secret, query_type=query_type, query_text=query_text, query_vector=query_vector, @@ -313,6 +323,7 @@ def forward( ).to_dict() for doc in sorted_docs ] + else: # Returning the prediction return Prediction( @@ -351,8 +362,10 @@ def _query_via_databricks_sdk( filters_json (Optional[str]): JSON string representing additional query filters. databricks_token (str): Databricks authentication token. If not specified, the token is resolved from the current environment. - databricks_endpoint (str): Databricks index endpoint url. If not specified, - the endpoint is resolved from the current environment. + databricks_endpoint (Optional[str]): The URL of the Databricks Workspace containing + the Vector Search Index. Defaults to the value of the ``DATABRICKS_HOST`` + environment variable. If unspecified, the Databricks SDK is used to identify the + endpoint based on the current environment. databricks_client_id (str): Databricks service principal id. If not specified, the token is resolved from the current environment (DATABRICKS_CLIENT_ID). databricks_client_secret (str): Databricks service principal secret. 
If not specified, @@ -400,6 +413,8 @@ def _query_via_requests( columns: List[str], databricks_token: str, databricks_endpoint: str, + databricks_client_id: Optional[str], + databricks_client_secret: Optional[str], query_type: str, query_text: Optional[str], query_vector: Optional[List[float]], @@ -413,7 +428,14 @@ def _query_via_requests( k (int): Number of relevant documents to retrieve. columns (List[str]): Column names to include in response. databricks_token (str): Databricks authentication token. - databricks_endpoint (str): Databricks index endpoint url. + databricks_endpoint (Optional[str]): The URL of the Databricks Workspace containing + the Vector Search Index. Defaults to the value of the ``DATABRICKS_HOST`` + environment variable. If unspecified, the Databricks SDK is used to identify the + endpoint based on the current environment. + databricks_client_id (str): Databricks service principal id. If not specified, + the id is resolved from the current environment (DATABRICKS_CLIENT_ID). + databricks_client_secret (str): Databricks service principal secret. If not specified, + the secret is resolved from the current environment (DATABRICKS_CLIENT_SECRET). query_text (Optional[str]): Text query for which to find relevant documents. Exactly one of query_text or query_vector must be specified. query_vector (Optional[List[float]]): Numeric query vector for which to find relevant @@ -423,30 +445,100 @@ def _query_via_requests( Returns: Dict[str, Any]: Parsed JSON response from the Databricks Vector Search Index query.
""" + if (query_text, query_vector).count(None) != 1: raise ValueError("Exactly one of query_text or query_vector must be specified.") + if databricks_client_id and databricks_client_secret: + try: + print("Retrieving OAuth token using service principal authentication.") + databricks_token = _get_oauth_token( + index_name, databricks_endpoint, databricks_client_id, databricks_client_secret + ) + except Exception as e: + raise ValueError( + f"Failed to retrieve OAuth token. Please check your Databricks client ID and secret. \n" + f"Error: {e} \n \n" + f"NOTE: If you are experiencing a 401 error, be sure to check the permissions on the index that " + f"you are trying to query. The service principal must have the select permission on the index." + ) + headers = { "Authorization": f"Bearer {databricks_token}", "Content-Type": "application/json", } + payload = { "columns": columns, "num_results": k, "query_type": query_type, } + if filters_json is not None: payload["filters_json"] = filters_json if query_text is not None: payload["query_text"] = query_text elif query_vector is not None: payload["query_vector"] = query_vector + response = requests.post( f"{databricks_endpoint}/api/2.0/vector-search/indexes/{index_name}/query", json=payload, headers=headers, ) + results = response.json() if "error_code" in results: raise Exception(f"ERROR: {results['error_code']} -- {results['message']}") return results + +def _get_oauth_token( + index_name: str, + databricks_endpoint: str, + databricks_client_id: str, + databricks_client_secret: str, +) -> str: + """ + Get OAuth token for Databricks service principal authentication. + + Args: + index_name (str): Name of the Databricks vector search index to query + databricks_endpoint (Optional[str]): The URL of the Databricks Workspace containing + the Vector Search Index. Defaults to the value of the ``DATABRICKS_HOST`` + environment variable. 
If unspecified, the Databricks SDK is used to identify the + endpoint based on the current environment. + databricks_client_id (str): Databricks service principal id. If not specified, + the id is resolved from the current environment (DATABRICKS_CLIENT_ID). + databricks_client_secret (str): Databricks service principal secret. If not specified, + the secret is resolved from the current environment (DATABRICKS_CLIENT_SECRET). + + Returns: + str: OAuth token. + """ + + authorization_details = { + "type": "unity_catalog_permission", + "securable_type": "table", + "securable_object_name": index_name, + "operation": "ReadVectorIndex" + } + + authorization_details_list = [authorization_details] + + token_url = f"{databricks_endpoint}/oidc/v1/token" + + data = { + 'grant_type': 'client_credentials', + 'scope': 'all-apis', + 'authorization_details': json.dumps(authorization_details_list) + } + + response = requests.post( + token_url, + auth=(databricks_client_id, databricks_client_secret), + data=data, + headers={'Content-Type': 'application/x-www-form-urlencoded'} + ) + + response.raise_for_status() + return response.json()['access_token'] \ No newline at end of file