From ec32e5b5c855b63a97aefc9bedefedbe79a9805d Mon Sep 17 00:00:00 2001 From: Rafid Date: Wed, 4 Dec 2024 16:47:01 -0800 Subject: [PATCH] Move modules under "cohere.compass" + More refactoring (#56) - Move all modules under "cohere.compass". - Remove `tqdm` from dependencies; we are not using it. - Added `pyright` type checking to pre-commit. - Changed pyright checking mode from `basic`, which, as the name suggests, is basic and didn't catch some errors to `strict` to ensure we catch as many errors as possible at coding time rather than run time. - Updated README.md to have pyright certification at the top ;-) - Changed Python support window to `[3.9, 4.0)`. --- .../{formatting.yml => pre-commit-checks.yml} | 22 +++- .github/workflows/test.yml | 20 +-- .github/workflows/typecheck.yml | 35 ----- .pre-commit-config.yaml | 4 + README.md | 2 + {compass_sdk => cohere/compass}/__init__.py | 4 +- cohere/compass/clients/__init__.py | 3 + .../compass}/clients/compass.py | 44 ++++--- .../compass}/clients/parser.py | 10 +- .../compass}/clients/rbac.py | 2 +- {compass_sdk => cohere/compass}/constants.py | 0 {compass_sdk => cohere/compass}/exceptions.py | 6 +- cohere/compass/models/__init__.py | 29 +++++ .../compass}/models/config.py | 24 ++-- .../compass}/models/datasources.py | 2 - .../compass}/models/documents.py | 4 +- .../compass}/models/rbac.py | 0 .../compass}/models/search.py | 8 +- {compass_sdk => cohere/compass}/utils.py | 26 ++-- compass_sdk/clients/__init__.py | 3 - compass_sdk/models/__init__.py | 27 ---- poetry.lock | 120 ++++++++++++++---- pyproject.toml | 14 +- tests/test_compass_client.py | 22 ++-- tests/test_utils.py | 2 +- 25 files changed, 250 insertions(+), 183 deletions(-) rename .github/workflows/{formatting.yml => pre-commit-checks.yml} (53%) delete mode 100644 .github/workflows/typecheck.yml rename {compass_sdk => cohere/compass}/__init__.py (92%) create mode 100644 cohere/compass/clients/__init__.py rename {compass_sdk => cohere/compass}/clients/compass.py (96%) rename {compass_sdk => cohere/compass}/clients/parser.py (97%) rename {compass_sdk => cohere/compass}/clients/rbac.py (99%) rename {compass_sdk => cohere/compass}/constants.py (100%) rename {compass_sdk => cohere/compass}/exceptions.py (80%) create mode 100644 cohere/compass/models/__init__.py rename {compass_sdk => cohere/compass}/models/config.py (92%) rename {compass_sdk => cohere/compass}/models/datasources.py (96%) rename {compass_sdk => cohere/compass}/models/documents.py (98%) rename {compass_sdk => cohere/compass}/models/rbac.py (100%) rename {compass_sdk => cohere/compass}/models/search.py (90%) rename {compass_sdk => cohere/compass}/utils.py (82%) delete mode 100644 compass_sdk/clients/__init__.py delete mode 100644 compass_sdk/models/__init__.py diff --git a/.github/workflows/formatting.yml b/.github/workflows/pre-commit-checks.yml similarity index 53% rename from .github/workflows/formatting.yml rename to .github/workflows/pre-commit-checks.yml index e751cd6..150ffea 100644 --- a/.github/workflows/formatting.yml +++ b/.github/workflows/pre-commit-checks.yml @@ -1,4 +1,4 @@ -name: Formatting +name: Pre-commit Checks on: pull_request: {} @@ -6,12 +6,16 @@ on: jobs: build: - name: Formatting + name: Run pre-commit checks runs-on: ubuntu-latest strategy: matrix: python-version: - - 3.9 + - "3.9" + - "3.10" + - "3.11" + - "3.12" + - "3.13" steps: - uses: actions/checkout@v4 @@ -21,10 +25,18 @@ jobs: with: python-version: ${{ matrix.python-version }} - - name: Upgrade pip & install requirements + - name: Install poetry + run: | + pip install poetry + + - name: Install dependencies + run: | + poetry install + + - name: Install pre-commit run: | pip install pre-commit - - name: Formatting + - name: Run pre-commit run: | pre-commit run --all-files diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 517ea08..e83f90b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -34,7 +34,11 @@ jobs: strategy: matrix: python-version: - - 3.11 + - "3.9" + - "3.10" + - "3.11" + - "3.12" + - "3.13" steps: - uses: actions/checkout@v4 @@ -47,17 +51,15 @@ jobs: cache-dependency-path: | poetry.lock - - name: Install dependencies (tests) + - name: Install poetry run: | - pip install pytest pytest-asyncio pytest-mock requests-mock + pip install poetry - - name: Install dependencies - working-directory: . + - name: Install dependencies run: | - pip install -e . + poetry install - - name: Run tests + - name: Run tests working-directory: . run: | - echo $COHERE_API_KEY - pytest -sv tests/test_compass_client.py \ No newline at end of file + poetry run pytest -sv diff --git a/.github/workflows/typecheck.yml b/.github/workflows/typecheck.yml deleted file mode 100644 index 7b2643a..0000000 --- a/.github/workflows/typecheck.yml +++ /dev/null @@ -1,35 +0,0 @@ -name: Typecheck - -on: - pull_request: {} - workflow_dispatch: {} - -jobs: - typecheck: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: - - 3.11 - package: - - . - fail-fast: false - - steps: - - uses: actions/checkout@v4 - - - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - cache: "pip" - cache-dependency-path: | - ${{ matrix.package }}/poetry.lock - - - name: Install dependencies - working-directory: ${{ matrix.package }} - run: | - pip install -e . - - - uses: jakebailey/pyright-action@v2 - with: - working-directory: ${{ matrix.package }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 91e42e8..9ef149b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,3 +7,7 @@ repos: args: [--fix] # Run the formatter. - id: ruff-format + - repo: https://github.com/RobertCraigie/pyright-python + rev: v1.1.390 + hooks: + - id: pyright diff --git a/README.md b/README.md index 3afca57..897ca5e 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # Cohere Compass SDK +[![Checked with pyright](https://microsoft.github.io/pyright/img/pyright_badge.svg)](https://microsoft.github.io/pyright/) + The Compass SDK is a Python library that allows you to parse documents and insert them into a Compass index. diff --git a/compass_sdk/__init__.py b/cohere/compass/__init__.py similarity index 92% rename from compass_sdk/__init__.py rename to cohere/compass/__init__.py index 819720a..c048f5d 100644 --- a/compass_sdk/__init__.py +++ b/cohere/compass/__init__.py @@ -6,13 +6,13 @@ from pydantic import BaseModel # Local imports -from compass_sdk.models import ( +from cohere.compass.models import ( MetadataConfig, ParserConfig, ValidatedModel, ) -__version__ = "0.7.0" +__version__ = "0.8.0" class ProcessFileParameters(ValidatedModel): diff --git a/cohere/compass/clients/__init__.py b/cohere/compass/clients/__init__.py new file mode 100644 index 0000000..d7b6681 --- /dev/null +++ b/cohere/compass/clients/__init__.py @@ -0,0 +1,3 @@ +from cohere.compass.clients.compass import * # noqa: F403 +from cohere.compass.clients.parser import * # noqa: F403 +from cohere.compass.clients.rbac import * # noqa: F403 diff --git a/compass_sdk/clients/compass.py b/cohere/compass/clients/compass.py similarity index 96% rename from compass_sdk/clients/compass.py rename to cohere/compass/clients/compass.py index e340d23..43afaf3 100644 --- a/compass_sdk/clients/compass.py +++ b/cohere/compass/clients/compass.py @@ -10,7 +10,8 @@ import uuid # 3rd party imports -from joblib import Parallel, delayed +# TODO find stubs for joblib and remove "type: ignore" +from joblib import Parallel, delayed # type: ignore from pydantic import BaseModel from requests.exceptions import InvalidSchema from tenacity import ( @@ -23,22 +24,22 @@ import requests # Local imports -from compass_sdk import ( +from cohere.compass import ( GroupAuthorizationInput, ) -from compass_sdk.constants import ( +from cohere.compass.constants import ( DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES, DEFAULT_MAX_CHUNKS_PER_REQUEST, DEFAULT_MAX_ERROR_RATE, DEFAULT_MAX_RETRIES, DEFAULT_SLEEP_RETRY_SECONDS, ) -from compass_sdk.exceptions import ( +from cohere.compass.exceptions import ( CompassAuthError, CompassClientError, CompassMaxErrorRateExceeded, ) -from compass_sdk.models import ( +from cohere.compass.models import ( Chunk, CompassDocument, CompassDocumentStatus, @@ -60,7 +61,7 @@ @dataclass class RetryResult: - result: Optional[dict] = None + result: Optional[dict[str, Any]] = None error: Optional[str] = None @@ -75,9 +76,9 @@ def __init__(self, timeout: int): self._timeout = timeout super().__init__() - def request(self, method, url, **kwargs): + def request(self, *args: Any, **kwargs: Any): kwargs.setdefault("timeout", self._timeout) - return super().request(method, url, **kwargs) + return super().request(*args, **kwargs) class CompassClient: @@ -231,7 +232,7 @@ def add_context( *, index_name: str, doc_id: str, - context: Dict, + context: dict[str, Any], max_retries: int = DEFAULT_MAX_RETRIES, sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS, ) -> Optional[RetryResult]: @@ -291,7 +292,7 @@ def upload_document( context: Dict[str, Any] = {}, max_retries: int = DEFAULT_MAX_RETRIES, sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS, - ) -> Optional[str | Dict]: + ) -> Optional[Union[str, Dict[str, Any]]]: """ Parse and insert a document into an index in Compass :param index_name: the name of the index @@ -362,8 +363,8 @@ def insert_docs( """ def put_request( - request_data: List[Tuple[CompassDocument, Document]], - previous_errors: List[CompassDocument], + request_data: list[Tuple[CompassDocument, Document]], + previous_errors: list[dict[str, str]], num_doc: int, ) -> None: nonlocal num_succeeded, errors @@ -420,11 +421,11 @@ def put_request( f"in the last {errors_sliding_window_size} API calls. Stopping the insertion process." ) - error_window = deque( + error_window: deque[Optional[str]] = deque( maxlen=errors_sliding_window_size ) # Keep track of the results of the last N API calls num_succeeded = 0 - errors = [] + errors: list[dict[str, str]] = [] requests_iter = self._get_request_blocks(docs, max_chunks_per_request) try: @@ -556,7 +557,7 @@ def list_datasources_objects_states( def _get_request_blocks( docs: Iterator[CompassDocument], max_chunks_per_request: int, - ) -> Iterator: + ): """ Create request blocks to send to the Compass API :param docs: the documents to send @@ -564,9 +565,10 @@ def _get_request_blocks( :return: an iterator over the request blocks """ - request_block, errors = [], [] + request_block: list[tuple[CompassDocument, Document]] = [] + errors: list[dict[str, str]] = [] num_chunks = 0 - for num_doc, doc in enumerate(docs, 1): + for _, doc in enumerate(docs, 1): if doc.status != CompassDocumentStatus.Success: logger.error(f"Document {doc.metadata.doc_id} has errors: {doc.errors}") for error in doc.errors: @@ -679,7 +681,7 @@ def _send_request( api_name: str, max_retries: int, sleep_retry_seconds: int, - data: Optional[Union[Dict, BaseModel]] = None, + data: Optional[Union[Dict[str, Any], BaseModel]] = None, **url_params: str, ) -> RetryResult: """ @@ -710,11 +712,13 @@ def _send_request_with_retry(): if data: if isinstance(data, BaseModel): data_dict = data.model_dump(mode="json") - elif isinstance(data, Dict): + else: data_dict = data headers = None - auth = (self.username, self.password) + auth = None + if self.username and self.password: + auth = (self.username, self.password) if self.bearer_token: headers = {"Authorization": f"Bearer {self.bearer_token}"} auth = None diff --git a/compass_sdk/clients/parser.py b/cohere/compass/clients/parser.py similarity index 97% rename from compass_sdk/clients/parser.py rename to cohere/compass/clients/parser.py index 65cc8c9..97770b3 100644 --- a/compass_sdk/clients/parser.py +++ b/cohere/compass/clients/parser.py @@ -9,16 +9,16 @@ import requests # Local imports -from compass_sdk import ( +from cohere.compass import ( ProcessFileParameters, ) -from compass_sdk.constants import DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES -from compass_sdk.models import ( +from cohere.compass.constants import DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES +from cohere.compass.models import ( CompassDocument, MetadataConfig, ParserConfig, ) -from compass_sdk.utils import imap_queued, open_document, scan_folder +from cohere.compass.utils import imap_queued, open_document, scan_folder Fn_or_Dict = Union[Dict[str, Any], Callable[[CompassDocument], Dict[str, Any]]] @@ -254,7 +254,7 @@ def process_file( ) if res.ok: - docs = [] + docs: list[CompassDocument] = [] for doc in res.json()["docs"]: if not doc.get("errors", []): compass_doc = CompassDocument(**doc) diff --git a/compass_sdk/clients/rbac.py b/cohere/compass/clients/rbac.py similarity index 99% rename from compass_sdk/clients/rbac.py rename to cohere/compass/clients/rbac.py index 2953561..1d7cfad 100644 --- a/compass_sdk/clients/rbac.py +++ b/cohere/compass/clients/rbac.py @@ -5,7 +5,7 @@ from pydantic import BaseModel from requests import HTTPError -from compass_sdk.models import ( +from cohere.compass.models import ( GroupCreateRequest, GroupCreateResponse, GroupDeleteResponse, diff --git a/compass_sdk/constants.py b/cohere/compass/constants.py similarity index 100% rename from compass_sdk/constants.py rename to cohere/compass/constants.py diff --git a/compass_sdk/exceptions.py b/cohere/compass/exceptions.py similarity index 80% rename from compass_sdk/exceptions.py rename to cohere/compass/exceptions.py index ea99adb..d5401a2 100644 --- a/compass_sdk/exceptions.py +++ b/cohere/compass/exceptions.py @@ -1,7 +1,7 @@ class CompassClientError(Exception): """Exception raised for all 4xx client errors in the Compass client.""" - def __init__(self, message="Client error occurred."): + def __init__(self, message: str = "Client error occurred."): self.message = message super().__init__(self.message) @@ -11,7 +11,7 @@ class CompassAuthError(CompassClientError): def __init__( self, - message=( + message: str = ( "CompassAuthError - check your bearer token or username and password." ), ): @@ -25,7 +25,7 @@ class CompassMaxErrorRateExceeded(Exception): def __init__( self, - message="The maximum error rate was exceeded. Stopping the insertion process.", + message: str = "The maximum error rate was exceeded. Stopping the insertion process.", ): self.message = message super().__init__(self.message) diff --git a/cohere/compass/models/__init__.py b/cohere/compass/models/__init__.py new file mode 100644 index 0000000..9ec4f09 --- /dev/null +++ b/cohere/compass/models/__init__.py @@ -0,0 +1,29 @@ +from typing import Any + +# import models into model package +from pydantic import BaseModel + + +class ValidatedModel(BaseModel): + class Config: + arbitrary_types_allowed = True + use_enum_values = True + + @classmethod + def attribute_in_model(cls, attr_name: str): + return attr_name in cls.model_fields + + def __init__(self, **data: dict[str, Any]): + for name, _value in data.items(): + if not self.attribute_in_model(name): + raise ValueError( + f"{name} is not a valid attribute for {self.__class__.__name__}" + ) + super().__init__(**data) + + +from cohere.compass.models.config import * # noqa: E402, F403 +from cohere.compass.models.datasources import * # noqa: E402, F403 +from cohere.compass.models.documents import * # noqa: E402, F403 +from cohere.compass.models.rbac import * # noqa: E402, F403 +from cohere.compass.models.search import * # noqa: E402, F403 diff --git a/compass_sdk/models/config.py b/cohere/compass/models/config.py similarity index 92% rename from compass_sdk/models/config.py rename to cohere/compass/models/config.py index 2d0b2d4..4dbd427 100644 --- a/compass_sdk/models/config.py +++ b/cohere/compass/models/config.py @@ -1,14 +1,14 @@ # Python imports -from enum import Enum, StrEnum +from enum import Enum from os import getenv -from typing import List, Optional +from typing import Any, List, Optional import math # 3rd party imports from pydantic import BaseModel, ConfigDict # Local imports -from compass_sdk.constants import ( +from cohere.compass.constants import ( COHERE_API_ENV_VAR, DEFAULT_COMMANDR_EXTRACTABLE_ATTRIBUTES, DEFAULT_COMMANDR_PROMPT, @@ -20,7 +20,7 @@ METADATA_HEURISTICS_ATTRIBUTES, SKIP_INFER_TABLE_TYPES, ) -from compass_sdk.models import ValidatedModel +from cohere.compass.models import ValidatedModel class DocumentFormat(str, Enum): @@ -28,25 +28,25 @@ class DocumentFormat(str, Enum): Text = "text" @classmethod - def _missing_(cls, value): + def _missing_(cls, value: Any): return cls.Markdown -class PDFParsingStrategy(StrEnum): +class PDFParsingStrategy(str, Enum): QuickText = "QuickText" ImageToMarkdown = "ImageToMarkdown" @classmethod - def _missing_(cls, value): + def _missing_(cls, value: Any): return cls.QuickText -class PresentationParsingStrategy(StrEnum): +class PresentationParsingStrategy(str, Enum): Unstructured = "Unstructured" ImageToMarkdown = "ImageToMarkdown" @classmethod - def _missing_(cls, value): + def _missing_(cls, value: Any): return cls.Unstructured @@ -55,7 +55,7 @@ class ParsingStrategy(str, Enum): Hi_Res = "hi_res" @classmethod - def _missing_(cls, value): + def _missing_(cls, value: Any): return cls.Fast @@ -66,7 +66,7 @@ class ParsingModel(str, Enum): ) @classmethod - def _missing_(cls, value): + def _missing_(cls, value: Any): return cls.Marker @@ -136,7 +136,7 @@ class MetadataStrategy(str, Enum): Custom = "custom" @classmethod - def _missing_(cls, value): + def _missing_(cls, value: Any): return cls.No_Metadata diff --git a/compass_sdk/models/datasources.py b/cohere/compass/models/datasources.py similarity index 96% rename from compass_sdk/models/datasources.py rename to cohere/compass/models/datasources.py index b0ae053..9df8c19 100644 --- a/compass_sdk/models/datasources.py +++ b/cohere/compass/models/datasources.py @@ -9,8 +9,6 @@ T = typing.TypeVar("T") -Content: typing.TypeAlias = typing.Dict[str, typing.Any] - class PaginatedList(pydantic.BaseModel, typing.Generic[T]): value: typing.List[T] diff --git a/compass_sdk/models/documents.py b/cohere/compass/models/documents.py similarity index 98% rename from compass_sdk/models/documents.py rename to cohere/compass/models/documents.py index 9502266..70f3aed 100644 --- a/compass_sdk/models/documents.py +++ b/cohere/compass/models/documents.py @@ -8,7 +8,7 @@ from pydantic import BaseModel, Field, PositiveInt, StringConstraints # Local imports -from compass_sdk.models import ValidatedModel +from cohere.compass.models import ValidatedModel class CompassDocumentMetadata(ValidatedModel): @@ -18,7 +18,7 @@ class CompassDocumentMetadata(ValidatedModel): doc_id: str = "" filename: str = "" - meta: List = field(default_factory=list) + meta: list[Any] = field(default_factory=list) parent_doc_id: str = "" diff --git a/compass_sdk/models/rbac.py b/cohere/compass/models/rbac.py similarity index 100% rename from compass_sdk/models/rbac.py rename to cohere/compass/models/rbac.py diff --git a/compass_sdk/models/search.py b/cohere/compass/models/search.py similarity index 90% rename from compass_sdk/models/search.py rename to cohere/compass/models/search.py index 7453914..7ee42c8 100644 --- a/compass_sdk/models/search.py +++ b/cohere/compass/models/search.py @@ -1,12 +1,10 @@ # Python imports from enum import Enum -from typing import Any, Dict, List, Optional, TypeAlias +from typing import Any, Dict, List, Optional # 3rd party imports from pydantic import BaseModel -Content: TypeAlias = Dict[str, Any] - class AssetInfo(BaseModel): content_type: str @@ -17,7 +15,7 @@ class RetrievedChunk(BaseModel): chunk_id: str sort_id: int parent_doc_id: str - content: Content + content: Dict[str, Any] origin: Optional[Dict[str, Any]] = None assets_info: Optional[list[AssetInfo]] = None score: float @@ -27,7 +25,7 @@ class RetrievedDocument(BaseModel): doc_id: str path: str parent_doc_id: str - content: Content + content: Dict[str, Any] index_fields: Optional[List[str]] = None authorized_groups: Optional[List[str]] = None chunks: List[RetrievedChunk] diff --git a/compass_sdk/utils.py b/cohere/compass/utils.py similarity index 82% rename from compass_sdk/utils.py rename to cohere/compass/utils.py index b11d1e2..9d8d9d0 100644 --- a/compass_sdk/utils.py +++ b/cohere/compass/utils.py @@ -6,11 +6,11 @@ from concurrent.futures import Executor from typing import Callable, Iterable, Iterator, List, Optional, TypeVar -import fsspec -from fsspec import AbstractFileSystem +import fsspec # type: ignore +from fsspec import AbstractFileSystem # type: ignore -from compass_sdk.constants import UUID_NAMESPACE -from compass_sdk.models import ( +from cohere.compass.constants import UUID_NAMESPACE +from cohere.compass.models import ( CompassDocument, CompassDocumentMetadata, CompassSdkStage, @@ -24,7 +24,7 @@ def imap_queued( executor: Executor, f: Callable[[T], U], it: Iterable[T], max_queued: int ) -> Iterator[U]: assert max_queued >= 1 - futures_set = set() + futures_set: set[futures.Future[U]] = set() for x in it: futures_set.add(executor.submit(f, x)) @@ -47,13 +47,13 @@ def get_fs(document_path: str) -> AbstractFileSystem: """ if document_path.find("://") >= 0: file_system = document_path.split("://")[0] - fs = fsspec.filesystem(file_system) + fs = fsspec.filesystem(file_system) # type: ignore else: - fs = fsspec.filesystem("local") + fs = fsspec.filesystem("local") # type: ignore return fs -def open_document(document_path) -> CompassDocument: +def open_document(document_path: str) -> CompassDocument: """ Opens a document regardless of the file system (local, GCS, S3, etc.) and returns a file-like object :param document_path: the path to the document @@ -62,9 +62,9 @@ def open_document(document_path) -> CompassDocument: doc = CompassDocument(metadata=CompassDocumentMetadata(filename=document_path)) try: fs = get_fs(document_path) - with fs.open(document_path, "rb") as f: + with fs.open(document_path, "rb") as f: # type: ignore val = f.read() - if val is not None and isinstance(val, bytes): + if isinstance(val, bytes): doc.filebytes = val else: raise Exception(f"Expected bytes, got {type(val)}") @@ -86,7 +86,7 @@ def scan_folder( :return: a list of file paths """ fs = get_fs(folder_path) - all_files = [] + all_files: list[str] = [] path_prepend = ( f"{folder_path.split('://')[0]}://" if folder_path.find("://") >= 0 else "" ) @@ -101,8 +101,8 @@ def scan_folder( for ext in allowed_extensions: rec_glob = "**/" if recursive else "" pattern = os.path.join(glob.escape(folder_path), f"{rec_glob}*{ext}") - scanned_files = fs.glob(pattern, recursive=recursive) - all_files.extend([f"{path_prepend}{f}" for f in scanned_files]) + scanned_files = fs.glob(pattern, recursive=recursive) # type: ignore + all_files.extend([f"{path_prepend}{f}" for f in scanned_files]) # type: ignore return all_files diff --git a/compass_sdk/clients/__init__.py b/compass_sdk/clients/__init__.py deleted file mode 100644 index bdefd9c..0000000 --- a/compass_sdk/clients/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from compass_sdk.clients.compass import * # noqa: F403 -from compass_sdk.clients.parser import * # noqa: F403 -from compass_sdk.clients.rbac import * # noqa: F403 diff --git a/compass_sdk/models/__init__.py b/compass_sdk/models/__init__.py deleted file mode 100644 index ca3d589..0000000 --- a/compass_sdk/models/__init__.py +++ /dev/null @@ -1,27 +0,0 @@ -# import models into model package -from pydantic import BaseModel - - -class ValidatedModel(BaseModel): - class Config: - arbitrary_types_allowed = True - use_enum_values = True - - @classmethod - def attribute_in_model(cls, attr_name): - return attr_name in cls.__fields__ - - def __init__(self, **data): - for name, value in data.items(): - if not self.attribute_in_model(name): - raise ValueError( - f"{name} is not a valid attribute for {self.__class__.__name__}" - ) - super().__init__(**data) - - -from compass_sdk.models.config import * # noqa: E402, F403 -from compass_sdk.models.datasources import * # noqa: E402, F403 -from compass_sdk.models.documents import * # noqa: E402, F403 -from compass_sdk.models.rbac import * # noqa: E402, F403 -from compass_sdk.models.search import * # noqa: E402, F403 diff --git a/poetry.lock b/poetry.lock index 6d2913a..ebe7cd0 100644 --- a/poetry.lock +++ b/poetry.lock @@ -147,15 +147,29 @@ files = [ {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, ] +[[package]] +name = "exceptiongroup" +version = "1.2.2" +description = "Backport of PEP 654 (exception groups)" +optional = false +python-versions = ">=3.7" +files = [ + {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, + {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, +] + +[package.extras] +test = ["pytest (>=6)"] + [[package]] name = "fsspec" -version = "2024.2.0" +version = "2024.10.0" description = "File-system specification" optional = false python-versions = ">=3.8" files = [ - {file = "fsspec-2024.2.0-py3-none-any.whl", hash = "sha256:817f969556fa5916bc682e02ca2045f96ff7f586d45110fcb76022063ad2c7d8"}, - {file = "fsspec-2024.2.0.tar.gz", hash = "sha256:b6ad1a679f760dda52b1168c859d01b7b80648ea6f7f7c7f5a8a91dc3f3ecb84"}, + {file = "fsspec-2024.10.0-py3-none-any.whl", hash = "sha256:03b9a6785766a4de40368b88906366755e2819e758b83705c88cd7cb5fe81871"}, + {file = "fsspec-2024.10.0.tar.gz", hash = "sha256:eda2d8a4116d4f2429db8550f2457da57279247dd930bb12f821b58391359493"}, ] [package.extras] @@ -163,7 +177,8 @@ abfs = ["adlfs"] adl = ["adlfs"] arrow = ["pyarrow (>=1)"] dask = ["dask", "distributed"] -devel = ["pytest", "pytest-cov"] +dev = ["pre-commit", "ruff"] +doc = ["numpydoc", "sphinx", "sphinx-design", "sphinx-rtd-theme", "yarl"] dropbox = ["dropbox", "dropboxdrivefs", "requests"] full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "dask", "distributed", "dropbox", "dropboxdrivefs", "fusepy", "gcsfs", "libarchive-c", "ocifs", "panel", "paramiko", "pyarrow (>=1)", "pygit2", "requests", "s3fs", "smbprotocol", "tqdm"] fuse = ["fusepy"] @@ -180,6 +195,9 @@ s3 = ["s3fs"] sftp = ["paramiko"] smb = ["smbprotocol"] ssh = ["paramiko"] +test = ["aiohttp (!=4.0.0a0,!=4.0.0a1)", "numpy", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "requests"] +test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask-expr", "dask[dataframe,test]", "moto[server] (>4,<5)", "pytest-timeout", "xarray"] +test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"] tqdm = ["tqdm"] [[package]] @@ -218,6 +236,17 @@ files = [ {file = "joblib-1.4.2.tar.gz", hash = "sha256:2382c5816b2636fbd20a09e0f4e9dad4736765fdfb7dca582943b9c1366b3f0e"}, ] +[[package]] +name = "nodeenv" +version = "1.9.1" +description = "Node.js virtual environment builder" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" +files = [ + {file = "nodeenv-1.9.1-py2.py3-none-any.whl", hash = "sha256:ba11c9782d29c27c70ffbdda2d7415098754709be8a7056d79a737cd901155c9"}, + {file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"}, +] + [[package]] name = "packaging" version = "24.2" @@ -258,7 +287,10 @@ files = [ [package.dependencies] annotated-types = ">=0.6.0" pydantic-core = "2.23.4" -typing-extensions = {version = ">=4.6.1", markers = "python_version < \"3.13\""} +typing-extensions = [ + {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, + {version = ">=4.6.1", markers = "python_version < \"3.13\""}, +] [package.extras] email = ["email-validator (>=2.0.0)"] @@ -365,6 +397,26 @@ files = [ [package.dependencies] typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" +[[package]] +name = "pyright" +version = "1.1.390" +description = "Command line wrapper for pyright" +optional = false +python-versions = ">=3.7" +files = [ + {file = "pyright-1.1.390-py3-none-any.whl", hash = "sha256:ecebfba5b6b50af7c1a44c2ba144ba2ab542c227eb49bc1f16984ff714e0e110"}, + {file = "pyright-1.1.390.tar.gz", hash = "sha256:aad7f160c49e0fbf8209507a15e17b781f63a86a1facb69ca877c71ef2e9538d"}, +] + +[package.dependencies] +nodeenv = ">=1.6.0" +typing-extensions = ">=4.1" + +[package.extras] +all = ["nodejs-wheel-binaries", "twine (>=3.4.1)"] +dev = ["twine (>=3.4.1)"] +nodejs = ["nodejs-wheel-binaries"] + [[package]] name = "pytest" version = "8.3.3" @@ -378,9 +430,11 @@ files = [ [package.dependencies] colorama = {version = "*", markers = "sys_platform == \"win32\""} +exceptiongroup = {version = ">=1.0.0rc8", markers = "python_version < \"3.11\""} iniconfig = "*" packaging = "*" pluggy = ">=1.5,<2" +tomli = {version = ">=1", markers = "python_version < \"3.11\""} [package.extras] dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments (>=2.7.2)", "requests", "setuptools", "xmlschema"] @@ -500,26 +554,46 @@ files = [ doc = ["reno", "sphinx", "tornado (>=4.5)"] [[package]] -name = "tqdm" -version = "4.67.0" -description = "Fast, Extensible Progress Meter" +name = "tomli" +version = "2.2.1" +description = "A lil' TOML parser" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "tqdm-4.67.0-py3-none-any.whl", hash = "sha256:0cd8af9d56911acab92182e88d763100d4788bdf421d251616040cc4d44863be"}, - {file = "tqdm-4.67.0.tar.gz", hash = "sha256:fe5a6f95e6fe0b9755e9469b77b9c3cf850048224ecaa8293d7d2d31f97d869a"}, + {file = "tomli-2.2.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:678e4fa69e4575eb77d103de3df8a895e1591b48e740211bd1067378c69e8249"}, + {file = "tomli-2.2.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:023aa114dd824ade0100497eb2318602af309e5a55595f76b626d6d9f3b7b0a6"}, + {file = "tomli-2.2.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ece47d672db52ac607a3d9599a9d48dcb2f2f735c6c2d1f34130085bb12b112a"}, + {file = "tomli-2.2.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6972ca9c9cc9f0acaa56a8ca1ff51e7af152a9f87fb64623e31d5c83700080ee"}, + {file = "tomli-2.2.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c954d2250168d28797dd4e3ac5cf812a406cd5a92674ee4c8f123c889786aa8e"}, + {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:8dd28b3e155b80f4d54beb40a441d366adcfe740969820caf156c019fb5c7ec4"}, + {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:e59e304978767a54663af13c07b3d1af22ddee3bb2fb0618ca1593e4f593a106"}, + {file = "tomli-2.2.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:33580bccab0338d00994d7f16f4c4ec25b776af3ffaac1ed74e0b3fc95e885a8"}, + {file = "tomli-2.2.1-cp311-cp311-win32.whl", hash = "sha256:465af0e0875402f1d226519c9904f37254b3045fc5084697cefb9bdde1ff99ff"}, + {file = "tomli-2.2.1-cp311-cp311-win_amd64.whl", hash = "sha256:2d0f2fdd22b02c6d81637a3c95f8cd77f995846af7414c5c4b8d0545afa1bc4b"}, + {file = "tomli-2.2.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:4a8f6e44de52d5e6c657c9fe83b562f5f4256d8ebbfe4ff922c495620a7f6cea"}, + {file = "tomli-2.2.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:8d57ca8095a641b8237d5b079147646153d22552f1c637fd3ba7f4b0b29167a8"}, + {file = "tomli-2.2.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4e340144ad7ae1533cb897d406382b4b6fede8890a03738ff1683af800d54192"}, + {file = "tomli-2.2.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:db2b95f9de79181805df90bedc5a5ab4c165e6ec3fe99f970d0e302f384ad222"}, + {file = "tomli-2.2.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:40741994320b232529c802f8bc86da4e1aa9f413db394617b9a256ae0f9a7f77"}, + {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:400e720fe168c0f8521520190686ef8ef033fb19fc493da09779e592861b78c6"}, + {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:02abe224de6ae62c19f090f68da4e27b10af2b93213d36cf44e6e1c5abd19fdd"}, + {file = "tomli-2.2.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:b82ebccc8c8a36f2094e969560a1b836758481f3dc360ce9a3277c65f374285e"}, + {file = "tomli-2.2.1-cp312-cp312-win32.whl", hash = "sha256:889f80ef92701b9dbb224e49ec87c645ce5df3fa2cc548664eb8a25e03127a98"}, + {file = "tomli-2.2.1-cp312-cp312-win_amd64.whl", hash = "sha256:7fc04e92e1d624a4a63c76474610238576942d6b8950a2d7f908a340494e67e4"}, + {file = "tomli-2.2.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f4039b9cbc3048b2416cc57ab3bda989a6fcf9b36cf8937f01a6e731b64f80d7"}, + {file = "tomli-2.2.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:286f0ca2ffeeb5b9bd4fcc8d6c330534323ec51b2f52da063b11c502da16f30c"}, + {file = "tomli-2.2.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a92ef1a44547e894e2a17d24e7557a5e85a9e1d0048b0b5e7541f76c5032cb13"}, + {file = "tomli-2.2.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9316dc65bed1684c9a98ee68759ceaed29d229e985297003e494aa825ebb0281"}, + {file = "tomli-2.2.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e85e99945e688e32d5a35c1ff38ed0b3f41f43fad8df0bdf79f72b2ba7bc5272"}, + {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ac065718db92ca818f8d6141b5f66369833d4a80a9d74435a268c52bdfa73140"}, + {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:d920f33822747519673ee656a4b6ac33e382eca9d331c87770faa3eef562aeb2"}, + {file = "tomli-2.2.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:a198f10c4d1b1375d7687bc25294306e551bf1abfa4eace6650070a5c1ae2744"}, + {file = "tomli-2.2.1-cp313-cp313-win32.whl", hash = "sha256:d3f5614314d758649ab2ab3a62d4f2004c825922f9e370b29416484086b264ec"}, + {file = "tomli-2.2.1-cp313-cp313-win_amd64.whl", hash = "sha256:a38aa0308e754b0e3c67e344754dff64999ff9b513e691d0e786265c93583c69"}, + {file = "tomli-2.2.1-py3-none-any.whl", hash = "sha256:cb55c73c5f4408779d0cf3eef9f762b9c9f147a77de7b258bef0a5628adc85cc"}, + {file = "tomli-2.2.1.tar.gz", hash = "sha256:cd45e1dc79c835ce60f7404ec8119f2eb06d38b1deba146f07ced3bbc44505ff"}, ] -[package.dependencies] -colorama = {version = "*", markers = "platform_system == \"Windows\""} - -[package.extras] -dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"] -discord = ["requests"] -notebook = ["ipywidgets (>=6)"] -slack = ["slack-sdk"] -telegram = ["requests"] - [[package]] name = "typing-extensions" version = "4.12.2" @@ -550,5 +624,5 @@ zstd = ["zstandard (>=0.18.0)"] [metadata] lock-version = "2.0" -python-versions = ">=3.11,<3.12" -content-hash = "4d6d1d1a40b93e95d0019f6ee73a913f0e6cdee6df2c41bf23d48f8000a8a515" +python-versions = ">=3.9,<4.0" +content-hash = "1e01950e92aed31912006f9f27462f6f5f4810254789d4bef83b851466cb7b33" diff --git a/pyproject.toml b/pyproject.toml index ee62619..01aae03 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,19 +1,21 @@ [tool.poetry] name = "compass-sdk" -version = "0.7.0" +version = "0.8.0" authors = [] description = "Compass SDK" +readme = "README.md" +packages = [{include = "cohere"}] [tool.poetry.dependencies] -fsspec = "2024.2.0" +fsspec = "^2024.10.0" joblib = "1.4.2" pydantic = ">=2.6.3" -python = ">=3.11,<3.12" +python = ">=3.9,<4.0" requests = ">=2.25.0,<3.0.0" tenacity = "8.2.3" -tqdm = ">=4.42.1" [tool.poetry.group.dev.dependencies] +pyright = "^1.1.390" pytest = "^8.3.3" pytest-asyncio = "^0.24.0" pytest-mock = "^3.14.0" @@ -21,8 +23,10 @@ requests-mock = "^1.12.1" ruff = "^0.8.1" [tool.pyright] -typeCheckingMode = 'basic' reportMissingImports = false +typeCheckingMode = "strict" +venv = ".venv" +venvPath = "." [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/tests/test_compass_client.py b/tests/test_compass_client.py index 22e6cc3..c39941a 100644 --- a/tests/test_compass_client.py +++ b/tests/test_compass_client.py @@ -1,8 +1,10 @@ -from compass_sdk.clients import CompassClient -from compass_sdk.models import CompassDocument +from requests_mock import Mocker +from cohere.compass.clients import CompassClient +from cohere.compass.models import CompassDocument -def test_delete_url_formatted_with_doc_and_index(requests_mock): + +def test_delete_url_formatted_with_doc_and_index(requests_mock: Mocker): compass = CompassClient(index_url="http://test.com") compass.delete_document(index_name="test_index", doc_id="test_id") assert ( @@ -12,7 +14,7 @@ def test_delete_url_formatted_with_doc_and_index(requests_mock): assert requests_mock.request_history[0].method == "DELETE" -def test_create_index_formatted_with_index(requests_mock): +def test_create_index_formatted_with_index(requests_mock: Mocker): compass = CompassClient(index_url="http://test.com") compass.create_index(index_name="test_index") assert ( @@ -22,7 +24,7 @@ def test_create_index_formatted_with_index(requests_mock): assert requests_mock.request_history[0].method == "PUT" -def test_put_documents_payload_and_url_exist(requests_mock): +def test_put_documents_payload_and_url_exist(requests_mock: Mocker): compass = CompassClient(index_url="http://test.com") compass.insert_docs(index_name="test_index", docs=iter([CompassDocument()])) assert ( @@ -33,7 +35,7 @@ def test_put_documents_payload_and_url_exist(requests_mock): assert "docs" in requests_mock.request_history[0].json() -def test_put_document_payload_and_url_exist(requests_mock): +def test_put_document_payload_and_url_exist(requests_mock: Mocker): compass = CompassClient(index_url="http://test.com") compass.insert_doc(index_name="test_index", doc=CompassDocument()) assert ( @@ -44,14 +46,14 @@ def test_put_document_payload_and_url_exist(requests_mock): assert "docs" in requests_mock.request_history[0].json() -def test_list_indices_is_valid(requests_mock): +def test_list_indices_is_valid(requests_mock: Mocker): compass = CompassClient(index_url="http://test.com") compass.list_indexes() assert requests_mock.request_history[0].method == "GET" assert requests_mock.request_history[0].url == "http://test.com/api/v1/indexes" -def test_get_documents_is_valid(requests_mock): +def test_get_documents_is_valid(requests_mock: Mocker): compass = CompassClient(index_url="http://test.com") compass.get_document(index_name="test_index", doc_id="test_id") assert requests_mock.request_history[0].method == "GET" @@ -61,7 +63,7 @@ def test_get_documents_is_valid(requests_mock): ) -def test_refresh_is_valid(requests_mock): +def test_refresh_is_valid(requests_mock: Mocker): compass = CompassClient(index_url="http://test.com") compass.refresh(index_name="test_index") assert requests_mock.request_history[0].method == "POST" @@ -71,7 +73,7 @@ def test_refresh_is_valid(requests_mock): ) -def test_add_context_is_valid(requests_mock): +def test_add_context_is_valid(requests_mock: Mocker): compass = CompassClient(index_url="http://test.com") compass.add_context( index_name="test_index", doc_id="test_id", context={"fake": "context"} diff --git a/tests/test_utils.py b/tests/test_utils.py index bf723a2..4fa9d20 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,6 +1,6 @@ from concurrent.futures import ThreadPoolExecutor -from compass_sdk.utils import imap_queued +from cohere.compass.utils import imap_queued def test_imap_queued():