Skip to content

Commit

Permalink
Move modules under "cohere.compass" + More refactoring (#56)
Browse files Browse the repository at this point in the history
- Move all modules under "cohere.compass".
- Remove `tqdm` from dependencies; we are not using it.
- Added `pyright` type checking to pre-commit.
- Changed pyright checking mode from `basic`, which, as the name
suggests, is basic and didn't catch some errors to `strict` to ensure we
catch as many errors as possible at coding time rather than run time.
- Updated README.md to have pyright certification at the top ;-)
- Changed Python support window to `[3.9, 4.0)`.
  • Loading branch information
corafid authored Dec 5, 2024
1 parent 46cd001 commit ec32e5b
Show file tree
Hide file tree
Showing 25 changed files with 250 additions and 183 deletions.
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
name: Formatting
name: Pre-commit Checks

on:
pull_request: {}
workflow_dispatch: {}

jobs:
build:
name: Formatting
name: Run pre-commit checks
runs-on: ubuntu-latest
strategy:
matrix:
python-version:
- 3.9
- "3.9"
- "3.10"
- "3.11"
- "3.12"
- "3.13"

steps:
- uses: actions/checkout@v4
Expand All @@ -21,10 +25,18 @@ jobs:
with:
python-version: ${{ matrix.python-version }}

- name: Upgrade pip & install requirements
- name: Install poetry
run: |
pip install poetry
- name: Install dependencies
run: |
poetry install
- name: Install pre-commit
run: |
pip install pre-commit
- name: Formatting
- name: Run pre-commit
run: |
pre-commit run --all-files
20 changes: 11 additions & 9 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ jobs:
strategy:
matrix:
python-version:
- 3.11
- "3.9"
- "3.10"
- "3.11"
- "3.12"
- "3.13"

steps:
- uses: actions/checkout@v4
Expand All @@ -47,17 +51,15 @@ jobs:
cache-dependency-path: |
poetry.lock
- name: Install dependencies (tests)
- name: Install poetry
run: |
pip install pytest pytest-asyncio pytest-mock requests-mock
pip install poetry
- name: Install dependencies
working-directory: .
- name: Install dependencies
run: |
pip install -e .
poetry install
- name: Run tests
- name: Run tests
working-directory: .
run: |
echo $COHERE_API_KEY
pytest -sv tests/test_compass_client.py
poetry run pytest -sv
35 changes: 0 additions & 35 deletions .github/workflows/typecheck.yml

This file was deleted.

4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ repos:
args: [--fix]
# Run the formatter.
- id: ruff-format
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.390
hooks:
- id: pyright
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Cohere Compass SDK

[![Checked with pyright](https://microsoft.github.io/pyright/img/pyright_badge.svg)](https://microsoft.github.io/pyright/)

The Compass SDK is a Python library that allows you to parse documents and insert them
into a Compass index.

Expand Down
4 changes: 2 additions & 2 deletions compass_sdk/__init__.py → cohere/compass/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
from pydantic import BaseModel

# Local imports
from compass_sdk.models import (
from cohere.compass.models import (
MetadataConfig,
ParserConfig,
ValidatedModel,
)

__version__ = "0.7.0"
__version__ = "0.8.0"


class ProcessFileParameters(ValidatedModel):
Expand Down
3 changes: 3 additions & 0 deletions cohere/compass/clients/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from cohere.compass.clients.compass import * # noqa: F403
from cohere.compass.clients.parser import * # noqa: F403
from cohere.compass.clients.rbac import * # noqa: F403
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
import uuid

# 3rd party imports
from joblib import Parallel, delayed
# TODO find stubs for joblib and remove "type: ignore"
from joblib import Parallel, delayed # type: ignore
from pydantic import BaseModel
from requests.exceptions import InvalidSchema
from tenacity import (
Expand All @@ -23,22 +24,22 @@
import requests

# Local imports
from compass_sdk import (
from cohere.compass import (
GroupAuthorizationInput,
)
from compass_sdk.constants import (
from cohere.compass.constants import (
DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES,
DEFAULT_MAX_CHUNKS_PER_REQUEST,
DEFAULT_MAX_ERROR_RATE,
DEFAULT_MAX_RETRIES,
DEFAULT_SLEEP_RETRY_SECONDS,
)
from compass_sdk.exceptions import (
from cohere.compass.exceptions import (
CompassAuthError,
CompassClientError,
CompassMaxErrorRateExceeded,
)
from compass_sdk.models import (
from cohere.compass.models import (
Chunk,
CompassDocument,
CompassDocumentStatus,
Expand All @@ -60,7 +61,7 @@

@dataclass
class RetryResult:
result: Optional[dict] = None
result: Optional[dict[str, Any]] = None
error: Optional[str] = None


Expand All @@ -75,9 +76,9 @@ def __init__(self, timeout: int):
self._timeout = timeout
super().__init__()

def request(self, method, url, **kwargs):
def request(self, *args: Any, **kwargs: Any):
kwargs.setdefault("timeout", self._timeout)
return super().request(method, url, **kwargs)
return super().request(*args, **kwargs)


class CompassClient:
Expand Down Expand Up @@ -231,7 +232,7 @@ def add_context(
*,
index_name: str,
doc_id: str,
context: Dict,
context: dict[str, Any],
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
) -> Optional[RetryResult]:
Expand Down Expand Up @@ -291,7 +292,7 @@ def upload_document(
context: Dict[str, Any] = {},
max_retries: int = DEFAULT_MAX_RETRIES,
sleep_retry_seconds: int = DEFAULT_SLEEP_RETRY_SECONDS,
) -> Optional[str | Dict]:
) -> Optional[Union[str, Dict[str, Any]]]:
"""
Parse and insert a document into an index in Compass
:param index_name: the name of the index
Expand Down Expand Up @@ -362,8 +363,8 @@ def insert_docs(
"""

def put_request(
request_data: List[Tuple[CompassDocument, Document]],
previous_errors: List[CompassDocument],
request_data: list[Tuple[CompassDocument, Document]],
previous_errors: list[dict[str, str]],
num_doc: int,
) -> None:
nonlocal num_succeeded, errors
Expand Down Expand Up @@ -420,11 +421,11 @@ def put_request(
f"in the last {errors_sliding_window_size} API calls. Stopping the insertion process."
)

error_window = deque(
error_window: deque[Optional[str]] = deque(
maxlen=errors_sliding_window_size
) # Keep track of the results of the last N API calls
num_succeeded = 0
errors = []
errors: list[dict[str, str]] = []
requests_iter = self._get_request_blocks(docs, max_chunks_per_request)

try:
Expand Down Expand Up @@ -556,17 +557,18 @@ def list_datasources_objects_states(
def _get_request_blocks(
docs: Iterator[CompassDocument],
max_chunks_per_request: int,
) -> Iterator:
):
"""
Create request blocks to send to the Compass API
:param docs: the documents to send
:param max_chunks_per_request: the maximum number of chunks to send in a single API request
:return: an iterator over the request blocks
"""

request_block, errors = [], []
request_block: list[tuple[CompassDocument, Document]] = []
errors: list[dict[str, str]] = []
num_chunks = 0
for num_doc, doc in enumerate(docs, 1):
for _, doc in enumerate(docs, 1):
if doc.status != CompassDocumentStatus.Success:
logger.error(f"Document {doc.metadata.doc_id} has errors: {doc.errors}")
for error in doc.errors:
Expand Down Expand Up @@ -679,7 +681,7 @@ def _send_request(
api_name: str,
max_retries: int,
sleep_retry_seconds: int,
data: Optional[Union[Dict, BaseModel]] = None,
data: Optional[Union[Dict[str, Any], BaseModel]] = None,
**url_params: str,
) -> RetryResult:
"""
Expand Down Expand Up @@ -710,11 +712,13 @@ def _send_request_with_retry():
if data:
if isinstance(data, BaseModel):
data_dict = data.model_dump(mode="json")
elif isinstance(data, Dict):
else:
data_dict = data

headers = None
auth = (self.username, self.password)
auth = None
if self.username and self.password:
auth = (self.username, self.password)
if self.bearer_token:
headers = {"Authorization": f"Bearer {self.bearer_token}"}
auth = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@
import requests

# Local imports
from compass_sdk import (
from cohere.compass import (
ProcessFileParameters,
)
from compass_sdk.constants import DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES
from compass_sdk.models import (
from cohere.compass.constants import DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES
from cohere.compass.models import (
CompassDocument,
MetadataConfig,
ParserConfig,
)
from compass_sdk.utils import imap_queued, open_document, scan_folder
from cohere.compass.utils import imap_queued, open_document, scan_folder

Fn_or_Dict = Union[Dict[str, Any], Callable[[CompassDocument], Dict[str, Any]]]

Expand Down Expand Up @@ -254,7 +254,7 @@ def process_file(
)

if res.ok:
docs = []
docs: list[CompassDocument] = []
for doc in res.json()["docs"]:
if not doc.get("errors", []):
compass_doc = CompassDocument(**doc)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from pydantic import BaseModel
from requests import HTTPError

from compass_sdk.models import (
from cohere.compass.models import (
GroupCreateRequest,
GroupCreateResponse,
GroupDeleteResponse,
Expand Down
File renamed without changes.
6 changes: 3 additions & 3 deletions compass_sdk/exceptions.py → cohere/compass/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
class CompassClientError(Exception):
"""Exception raised for all 4xx client errors in the Compass client."""

def __init__(self, message="Client error occurred."):
def __init__(self, message: str = "Client error occurred."):
self.message = message
super().__init__(self.message)

Expand All @@ -11,7 +11,7 @@ class CompassAuthError(CompassClientError):

def __init__(
self,
message=(
message: str = (
"CompassAuthError - check your bearer token or username and password."
),
):
Expand All @@ -25,7 +25,7 @@ class CompassMaxErrorRateExceeded(Exception):

def __init__(
self,
message="The maximum error rate was exceeded. Stopping the insertion process.",
message: str = "The maximum error rate was exceeded. Stopping the insertion process.",
):
self.message = message
super().__init__(self.message)
29 changes: 29 additions & 0 deletions cohere/compass/models/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from typing import Any

# import models into model package
from pydantic import BaseModel


class ValidatedModel(BaseModel):
class Config:
arbitrary_types_allowed = True
use_enum_values = True

@classmethod
def attribute_in_model(cls, attr_name: str):
return attr_name in cls.model_fields

def __init__(self, **data: dict[str, Any]):
for name, _value in data.items():
if not self.attribute_in_model(name):
raise ValueError(
f"{name} is not a valid attribute for {self.__class__.__name__}"
)
super().__init__(**data)


from cohere.compass.models.config import * # noqa: E402, F403
from cohere.compass.models.datasources import * # noqa: E402, F403
from cohere.compass.models.documents import * # noqa: E402, F403
from cohere.compass.models.rbac import * # noqa: E402, F403
from cohere.compass.models.search import * # noqa: E402, F403
Loading

0 comments on commit ec32e5b

Please sign in to comment.