diff --git a/cohere/compass/__init__.py b/cohere/compass/__init__.py index 9f1de0e..70175e6 100644 --- a/cohere/compass/__init__.py +++ b/cohere/compass/__init__.py @@ -12,7 +12,7 @@ ValidatedModel, ) -__version__ = "0.10.2" +__version__ = "0.11.0" class ProcessFileParameters(ValidatedModel): @@ -33,14 +33,14 @@ class ProcessFilesParameters(ValidatedModel): class GroupAuthorizationActions(str, Enum): - """Enum for use with the edit_group_authorization API to specify the edit type.""" + """Enum for use with the update_group_authorization API to specify the edit type.""" ADD = "add" REMOVE = "remove" class GroupAuthorizationInput(BaseModel): - """Model for use with the edit_group_authorization API.""" + """Model for use with the update_group_authorization API.""" document_ids: list[str] authorized_groups: list[str] diff --git a/cohere/compass/clients/compass.py b/cohere/compass/clients/compass.py index b0891e7..df8428b 100644 --- a/cohere/compass/clients/compass.py +++ b/cohere/compass/clients/compass.py @@ -39,6 +39,7 @@ from cohere.compass.exceptions import ( CompassAuthError, CompassClientError, + CompassError, CompassMaxErrorRateExceeded, ) from cohere.compass.models import ( @@ -59,6 +60,7 @@ UploadDocumentsInput, ) from cohere.compass.models.datasources import PaginatedList +from cohere.compass.models.documents import DocumentAttributes, PutDocumentsResponse @dataclass @@ -117,7 +119,7 @@ def __init__( "add_attributes": self.session.post, "refresh": self.session.post, "upload_documents": self.session.post, - "edit_group_authorization": self.session.post, + "update_group_authorization": self.session.post, # Data Sources APIs "create_datasource": self.session.post, "list_datasources": self.session.get, @@ -138,7 +140,7 @@ def __init__( "add_attributes": "/api/v1/indexes/{index_name}/documents/{document_id}/_add_attributes", # noqa: E501 "refresh": "/api/v1/indexes/{index_name}/_refresh", "upload_documents": "/api/v1/indexes/{index_name}/documents/_upload", - "edit_group_authorization": "/api/v1/indexes/{index_name}/group_authorization", # noqa: E501 + "update_group_authorization": "/api/v1/indexes/{index_name}/group_authorization", # noqa: E501 # Data Sources APIs "create_datasource": "/api/v1/datasources", "list_datasources": "/api/v1/datasources", @@ -162,7 +164,7 @@ def create_index(self, *, index_name: str): index_name=index_name, ) - def refresh(self, *, index_name: str): + def refresh_index(self, *, index_name: str): """ Refresh index. @@ -242,7 +244,7 @@ def add_attributes( *, index_name: str, document_id: str, - context: dict[str, Any], + attributes: DocumentAttributes, max_retries: int = DEFAULT_MAX_RETRIES, sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS, ) -> Optional[RetryResult]: @@ -251,8 +253,7 @@ def add_attributes( :param index_name: the name of the index :param document_id: the document to modify - :param context: A dictionary of key-value pairs to insert into the content field - of a document + :param attributes: the attributes to add to the document :param max_retries: the maximum number of times to retry a doc insertion :param sleep_retry_seconds: number of seconds to go to sleep before retrying a doc insertion @@ -260,7 +261,7 @@ def add_attributes( return self._send_request( api_name="add_attributes", document_id=document_id, - data=context, + data=attributes, max_retries=max_retries, sleep_retry_seconds=sleep_retry_seconds, index_name=index_name, @@ -301,7 +302,7 @@ def upload_document( filebytes: bytes, content_type: str, document_id: uuid.UUID, - attributes: dict[str, Any] = {}, + attributes: DocumentAttributes = DocumentAttributes(), max_retries: int = DEFAULT_MAX_RETRIES, sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS, ) -> Optional[Union[str, dict[str, Any]]]: @@ -744,29 +745,32 @@ def search_chunks( return SearchChunksResponse.model_validate(result.result) - def edit_group_authorization( + def update_group_authorization( self, *, index_name: str, group_auth_input: GroupAuthorizationInput - ): + ) -> PutDocumentsResponse: """ Edit group authorization for an index. :param index_name: the name of the index :param group_auth_input: the group authorization input """ - return self._send_request( - api_name="edit_group_authorization", + result = self._send_request( + api_name="update_group_authorization", index_name=index_name, data=group_auth_input, max_retries=DEFAULT_MAX_RETRIES, sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS, ) + if result.error: + raise CompassError(result.error) + return PutDocumentsResponse.model_validate(result.result) def _send_request( self, api_name: str, max_retries: int, sleep_retry_seconds: int, - data: Optional[Union[dict[str, Any], BaseModel]] = None, + data: Optional[BaseModel] = None, **url_params: str, ) -> RetryResult: """ @@ -794,12 +798,7 @@ def _send_request_with_retry(): nonlocal error try: - data_dict = None - if data: - if isinstance(data, BaseModel): - data_dict = data.model_dump(mode="json") - else: - data_dict = data + data_dict = data.model_dump(mode="json") if data else None headers = None auth = None diff --git a/cohere/compass/exceptions.py b/cohere/compass/exceptions.py index c074f1b..47c1440 100644 --- a/cohere/compass/exceptions.py +++ b/cohere/compass/exceptions.py @@ -1,4 +1,10 @@ -class CompassClientError(Exception): +class CompassError(Exception): + """Base class for all exceptions raised by the Compass client.""" + + pass + + +class CompassClientError(CompassError): """Exception raised for all 4xx client errors in the Compass client.""" def __init__( # noqa: D107 diff --git a/cohere/compass/models/documents.py b/cohere/compass/models/documents.py index e9e7ace..ff9b5f7 100644 --- a/cohere/compass/models/documents.py +++ b/cohere/compass/models/documents.py @@ -5,7 +5,7 @@ from typing import Annotated, Any, Optional # 3rd party imports -from pydantic import BaseModel, Field, PositiveInt, StringConstraints +from pydantic import BaseModel, ConfigDict, PositiveInt, StringConstraints # Local imports from cohere.compass.models import ValidatedModel @@ -195,6 +195,19 @@ class Document(BaseModel): authorized_groups: Optional[list[str]] = None +class DocumentAttributes(BaseModel): + """Model class for document attributes.""" + + model_config = ConfigDict(extra="allow") + + # Had to add this to please the linter, because BaseModel only defines __setattr__ + # if TYPE_CHECKING is not set, i.e. at runtime, resulting in the type checking pass + # done by the linter failing to find the __setattr__ method. See: + # https://github.com/pydantic/pydantic/blob/main/pydantic/main.py#L878-L920 + def __setattr__(self, name: str, value: Any): # noqa: D105 + return super().__setattr__(name, value) + + class ParseableDocument(BaseModel): """A document to be sent to Compass for parsing.""" @@ -205,7 +218,7 @@ class ParseableDocument(BaseModel): content_type: str content_length_bytes: PositiveInt # File size must be a non-negative integer content_encoded_bytes: str # Base64-encoded file contents - attributes: dict[str, Any] = Field(default_factory=dict) + attributes: DocumentAttributes class UploadDocumentsInput(BaseModel): @@ -220,3 +233,20 @@ class PutDocumentsInput(BaseModel): documents: list[Document] authorized_groups: Optional[list[str]] = None merge_groups_on_conflict: bool = False + + +class PutDocumentResult(BaseModel): + """ + A model for the response of put_document. + + This model is also used by the put_documents and edit_group_authorization APIs. + """ + + document_id: str + error: Optional[str] + + +class PutDocumentsResponse(BaseModel): + """A model for the response of put_documents and edit_group_authorization APIs.""" + + results: list[PutDocumentResult] diff --git a/pyproject.toml b/pyproject.toml index 26b5158..4a606b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "compass-sdk" -version = "0.10.2" +version = "0.11.0" authors = [] description = "Compass SDK" readme = "README.md" diff --git a/tests/test_compass_client.py b/tests/test_compass_client.py index ae99dbd..584d85e 100644 --- a/tests/test_compass_client.py +++ b/tests/test_compass_client.py @@ -2,6 +2,7 @@ from cohere.compass.clients import CompassClient from cohere.compass.models import CompassDocument +from cohere.compass.models.documents import DocumentAttributes def test_delete_url_formatted_with_doc_and_index(requests_mock: Mocker): @@ -65,7 +66,7 @@ def test_get_documents_is_valid(requests_mock: Mocker): def test_refresh_is_valid(requests_mock: Mocker): compass = CompassClient(index_url="http://test.com") - compass.refresh(index_name="test_index") + compass.refresh_index(index_name="test_index") assert requests_mock.request_history[0].method == "POST" assert ( requests_mock.request_history[0].url @@ -74,9 +75,13 @@ def test_refresh_is_valid(requests_mock: Mocker): def test_add_attributes_is_valid(requests_mock: Mocker): + attrs = DocumentAttributes() + attrs.fake = "context" compass = CompassClient(index_url="http://test.com") compass.add_attributes( - index_name="test_index", document_id="test_id", context={"fake": "context"} + index_name="test_index", + document_id="test_id", + attributes=attrs, ) assert requests_mock.request_history[0].method == "POST" assert (