Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add type support to update_group_authorization #70

Merged
merged 3 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions cohere/compass/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
ValidatedModel,
)

__version__ = "0.10.2"
__version__ = "0.11.0"


class ProcessFileParameters(ValidatedModel):
Expand All @@ -33,14 +33,14 @@ class ProcessFilesParameters(ValidatedModel):


class GroupAuthorizationActions(str, Enum):
"""Enum for use with the edit_group_authorization API to specify the edit type."""
"""Enum for use with the update_group_authorization API to specify the edit type."""

ADD = "add"
REMOVE = "remove"


class GroupAuthorizationInput(BaseModel):
"""Model for use with the edit_group_authorization API."""
"""Model for use with the update_group_authorization API."""

document_ids: list[str]
authorized_groups: list[str]
Expand Down
37 changes: 18 additions & 19 deletions cohere/compass/clients/compass.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from cohere.compass.exceptions import (
CompassAuthError,
CompassClientError,
CompassError,
CompassMaxErrorRateExceeded,
)
from cohere.compass.models import (
Expand All @@ -59,6 +60,7 @@
UploadDocumentsInput,
)
from cohere.compass.models.datasources import PaginatedList
from cohere.compass.models.documents import DocumentAttributes, PutDocumentsResponse


@dataclass
Expand Down Expand Up @@ -117,7 +119,7 @@ def __init__(
"add_attributes": self.session.post,
"refresh": self.session.post,
"upload_documents": self.session.post,
"edit_group_authorization": self.session.post,
"update_group_authorization": self.session.post,
# Data Sources APIs
"create_datasource": self.session.post,
"list_datasources": self.session.get,
Expand All @@ -138,7 +140,7 @@ def __init__(
"add_attributes": "/api/v1/indexes/{index_name}/documents/{document_id}/_add_attributes", # noqa: E501
"refresh": "/api/v1/indexes/{index_name}/_refresh",
"upload_documents": "/api/v1/indexes/{index_name}/documents/_upload",
"edit_group_authorization": "/api/v1/indexes/{index_name}/group_authorization", # noqa: E501
"update_group_authorization": "/api/v1/indexes/{index_name}/group_authorization", # noqa: E501
# Data Sources APIs
"create_datasource": "/api/v1/datasources",
"list_datasources": "/api/v1/datasources",
Expand All @@ -162,7 +164,7 @@ def create_index(self, *, index_name: str):
index_name=index_name,
)

def refresh(self, *, index_name: str):
def refresh_index(self, *, index_name: str):
"""
Refresh index.

Expand Down Expand Up @@ -242,7 +244,7 @@ def add_attributes(
*,
index_name: str,
document_id: str,
context: dict[str, Any],
attributes: DocumentAttributes,
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
) -> Optional[RetryResult]:
Expand All @@ -251,16 +253,15 @@ def add_attributes(

:param index_name: the name of the index
:param document_id: the document to modify
:param context: A dictionary of key-value pairs to insert into the content field
of a document
:param attributes: the attributes to add to the document
:param max_retries: the maximum number of times to retry a doc insertion
:param sleep_retry_seconds: number of seconds to go to sleep before retrying a
doc insertion
"""
return self._send_request(
api_name="add_attributes",
document_id=document_id,
data=context,
data=attributes,
max_retries=max_retries,
sleep_retry_seconds=sleep_retry_seconds,
index_name=index_name,
Expand Down Expand Up @@ -301,7 +302,7 @@ def upload_document(
filebytes: bytes,
content_type: str,
document_id: uuid.UUID,
attributes: dict[str, Any] = {},
attributes: DocumentAttributes = DocumentAttributes(),
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
) -> Optional[Union[str, dict[str, Any]]]:
Expand Down Expand Up @@ -744,29 +745,32 @@ def search_chunks(

return SearchChunksResponse.model_validate(result.result)

def edit_group_authorization(
def update_group_authorization(
self, *, index_name: str, group_auth_input: GroupAuthorizationInput
):
) -> PutDocumentsResponse:
"""
Edit group authorization for an index.

:param index_name: the name of the index
:param group_auth_input: the group authorization input
"""
return self._send_request(
api_name="edit_group_authorization",
result = self._send_request(
api_name="update_group_authorization",
index_name=index_name,
data=group_auth_input,
max_retries=DEFAULT_MAX_RETRIES,
sleep_retry_seconds=DEFAULT_SLEEP_RETRY_SECONDS,
)
if result.error:
raise CompassError(result.error)
return PutDocumentsResponse.model_validate(result.result)

def _send_request(
self,
api_name: str,
max_retries: int,
sleep_retry_seconds: int,
data: Optional[Union[dict[str, Any], BaseModel]] = None,
data: Optional[BaseModel] = None,
**url_params: str,
) -> RetryResult:
"""
Expand Down Expand Up @@ -794,12 +798,7 @@ def _send_request_with_retry():
nonlocal error

try:
data_dict = None
if data:
if isinstance(data, BaseModel):
data_dict = data.model_dump(mode="json")
else:
data_dict = data
data_dict = data.model_dump(mode="json") if data else None

headers = None
auth = None
Expand Down
8 changes: 7 additions & 1 deletion cohere/compass/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
class CompassClientError(Exception):
class CompassError(Exception):
"""Base class for all exceptions raised by the Compass client."""

pass


class CompassClientError(CompassError):
"""Exception raised for all 4xx client errors in the Compass client."""

def __init__( # noqa: D107
Expand Down
34 changes: 32 additions & 2 deletions cohere/compass/models/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Annotated, Any, Optional

# 3rd party imports
from pydantic import BaseModel, Field, PositiveInt, StringConstraints
from pydantic import BaseModel, ConfigDict, PositiveInt, StringConstraints

# Local imports
from cohere.compass.models import ValidatedModel
Expand Down Expand Up @@ -195,6 +195,19 @@ class Document(BaseModel):
authorized_groups: Optional[list[str]] = None


class DocumentAttributes(BaseModel):
"""Model class for document attributes."""

model_config = ConfigDict(extra="allow")

# Had to add this to please the linter, because BaseModel only defines __setattr__
# if TYPE_CHECKING is not set, i.e. at runtime, resulting in the type checking pass
# done by the linter failing to find the __setattr__ method. See:
# https://github.com/pydantic/pydantic/blob/main/pydantic/main.py#L878-L920
def __setattr__(self, name: str, value: Any): # noqa: D105
return super().__setattr__(name, value)


class ParseableDocument(BaseModel):
"""A document to be sent to Compass for parsing."""

Expand All @@ -205,7 +218,7 @@ class ParseableDocument(BaseModel):
content_type: str
content_length_bytes: PositiveInt # File size must be a non-negative integer
content_encoded_bytes: str # Base64-encoded file contents
attributes: dict[str, Any] = Field(default_factory=dict)
attributes: DocumentAttributes


class UploadDocumentsInput(BaseModel):
Expand All @@ -220,3 +233,20 @@ class PutDocumentsInput(BaseModel):
documents: list[Document]
authorized_groups: Optional[list[str]] = None
merge_groups_on_conflict: bool = False


class PutDocumentResult(BaseModel):
"""
A model for the response of put_document.

This model is also used by the put_documents and edit_group_authorization APIs.
"""

document_id: str
error: Optional[str]


class PutDocumentsResponse(BaseModel):
"""A model for the response of put_documents and edit_group_authorization APIs."""

results: list[PutDocumentResult]
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "compass-sdk"
version = "0.10.2"
version = "0.11.0"
authors = []
description = "Compass SDK"
readme = "README.md"
Expand Down
9 changes: 7 additions & 2 deletions tests/test_compass_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from cohere.compass.clients import CompassClient
from cohere.compass.models import CompassDocument
from cohere.compass.models.documents import DocumentAttributes


def test_delete_url_formatted_with_doc_and_index(requests_mock: Mocker):
Expand Down Expand Up @@ -65,7 +66,7 @@ def test_get_documents_is_valid(requests_mock: Mocker):

def test_refresh_is_valid(requests_mock: Mocker):
compass = CompassClient(index_url="http://test.com")
compass.refresh(index_name="test_index")
compass.refresh_index(index_name="test_index")
assert requests_mock.request_history[0].method == "POST"
assert (
requests_mock.request_history[0].url
Expand All @@ -74,9 +75,13 @@ def test_refresh_is_valid(requests_mock: Mocker):


def test_add_attributes_is_valid(requests_mock: Mocker):
attrs = DocumentAttributes()
attrs.fake = "context"
compass = CompassClient(index_url="http://test.com")
compass.add_attributes(
index_name="test_index", document_id="test_id", context={"fake": "context"}
index_name="test_index",
document_id="test_id",
attributes=attrs,
)
assert requests_mock.request_history[0].method == "POST"
assert (
Expand Down
Loading