From ca4438e92a072a36120fe8c3810c91d0f41035c8 Mon Sep 17 00:00:00 2001 From: Rafid Date: Tue, 3 Dec 2024 16:00:25 -0800 Subject: [PATCH] General refactoring and code quality improvement (#54) - Introduced a new sub-module called "clients" and moved all clients to it, namely CompassClient, CompassParserClient, and CompassRootClient. - Moved RBAC-related models (previously under `types.py`) under the `models` module and renamed it accordingly. - Introduced ruff for code formatting, instead of black. Ruff is getting increased traction from the community to its features and efficiency. - Reformatted the files to ensure code and docstrings are wrapped at 88 chars as suggested by Ruff. - Removed the Logger and LoggerLevel classes and used `getLogger` instead to rely on the customer to configure their logging. - Move models from `__init__.py` to `models` module - Updated README.md file. --- .pre-commit-config.yaml | 21 +- README.md | 105 +++++- compass_sdk/__init__.py | 428 +---------------------- compass_sdk/clients/__init__.py | 3 + compass_sdk/{ => clients}/compass.py | 103 ++++-- compass_sdk/{ => clients}/parser.py | 145 +++++--- compass_sdk/{ => clients}/rbac.py | 98 ++++-- compass_sdk/exceptions.py | 4 +- compass_sdk/models/__init__.py | 28 +- compass_sdk/models/config.py | 170 +++++++++ compass_sdk/models/documents.py | 188 ++++++++++ compass_sdk/{types.py => models/rbac.py} | 2 + compass_sdk/models/search.py | 37 +- compass_sdk/utils.py | 28 +- poetry.lock | 29 +- poetry.toml | 2 + pyproject.toml | 6 +- tests/test_compass_client.py | 37 +- 18 files changed, 855 insertions(+), 579 deletions(-) create mode 100644 compass_sdk/clients/__init__.py rename compass_sdk/{ => clients}/compass.py (92%) rename compass_sdk/{ => clients}/parser.py (62%) rename compass_sdk/{ => clients}/rbac.py (60%) create mode 100644 compass_sdk/models/config.py create mode 100644 compass_sdk/models/documents.py rename compass_sdk/{types.py => models/rbac.py} (97%) create mode 100644 poetry.toml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index baeb835..91e42e8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,16 +1,9 @@ repos: - - repo: https://github.com/pycqa/isort - rev: "5.13.2" + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.8.1 hooks: - - id: isort - args: ["--profile", "black", "--filter-files", "--line-length", "120", "--py", "39"] - - repo: https://github.com/psf/black - rev: "24.4.0" - hooks: - - id: black - args: ["--line-length=120", "--target-version=py39"] - - repo: https://github.com/pycqa/autoflake - rev: "v2.3.1" - hooks: - - id: autoflake - args: ["--in-place", "--remove-all-unused-imports", "--expand-star-imports", "--ignore-init-module-imports", "-r"] + # Run the linter. + - id: ruff + args: [--fix] + # Run the formatter. + - id: ruff-format diff --git a/README.md b/README.md index 5428845..3afca57 100644 --- a/README.md +++ b/README.md @@ -1,32 +1,62 @@ -# Cohere Compass SDK +# Cohere Compass SDK -The Compass SDK is a Python library that allows you to parse documents and insert them into a Compass index. +The Compass SDK is a Python library that allows you to parse documents and insert them +into a Compass index. -In order to parse documents, the Compass SDK relies on the Compass Parser API, which is a RESTful API that -receives files and returns parsed documents. This requires a hosted Compass server. +In order to parse documents, the Compass SDK relies on the Compass Parser API, which is +a RESTful API that receives files and returns parsed documents. This requires a hosted +Compass server. -The Compass SDK provides a `CompassParserClient` that allows to interact with the parser API from your -Python code in a convenient manner. The `CompassParserClient` provides methods to parse single and multiple -files, as well as entire folders, and supports multiple file types (e.g., `pdf`, `docx`, `json`, `csv`, etc.) as well -as different file systems (e.g., local, S3, GCS, etc.). +The Compass SDK provides a `CompassParserClient` that allows to interact with the parser +API from your Python code in a convenient manner. The `CompassParserClient` provides +methods to parse single and multiple files, as well as entire folders, and supports +multiple file types (e.g., `pdf`, `docx`, `json`, `csv`, etc.) as well as different file +systems (e.g., local, S3, GCS, etc.). -To insert parsed documents into a `Compass` index, the Compass SDK provides a `CompassClient` class that -allows to interact with a Compass API server. The Compass API is also a RESTful API that allows to create, -delete and search documents in a Compass index. To install a Compass API service, please refer to the -[Compass documentation](https://github.com/cohere-ai/compass) +To insert parsed documents into a `Compass` index, the Compass SDK provides a +`CompassClient` class that allows to interact with a Compass API server. The Compass API +is also a RESTful API that allows to create, delete and search documents in a Compass +index. To install a Compass API service, please refer to the [Compass +documentation](https://github.com/cohere-ai/compass) -## Quickstart Snippet +## Table of Contents -Fill in your URL, username, password, and path to test data below for an end to end run of parsing and searching. + + +- [Getting Started](#getting-started) +- [Local Development](#local-development) + - [Create Python Virtual Environment](#create-python-virtual-environment) + - [Running Tests Locally](#running-tests-locally) + - [VSCode Users](#vscode-users) + - [Pre-commit](#pre-commit) + + + +## Getting Started + +Fill in your URL, username, password, and path to test data below for an end to end run +of parsing and searching. + +```Python +from compass_sdk.clients import CompassClient, CompassParserClient from compass_sdk import MetadataStrategy, MetadataConfig # Using cohere_web_test folder for data url = "" -username = "" +username = "" password = "" index = "test-index" @@ -49,5 +79,42 @@ compass_client.create_index(index_name=index) results = compass_client.insert_docs(index_name=index, docs=docs_to_index) results = compass_client.search(index_name=index, query="test", top_k=1) -print(f"Results preview: \n {results.result['hits'][-1]} ... \n \n ") +print(f"Results preview: \n {results.result['hits'][-1]} ... \n \n ") +``` + +## Local Development + +### Create Python Virtual Environment + +We use Poetry to manage our Python environment. To create the virtual environment use +the following command: + +``` +poetry install +``` + +### Running Tests Locally + +We use `pytest` for testing. So, you can simply run tests using the following command: + +``` +poetry run python -m pytest +``` + +#### VSCode Users + +We provide `.vscode` folder for those developers who prefer to use VSCode. You just need +to open the folder in VSCode and VSCode should pick our settings. + +### Pre-commit + +We love and appreciate Coding Standards and so we enforce them in our code base. +However, without automation, enforcing Coding Standards usually result in a lot of +frustration for developers when they publish Pull Requests and our linters complain. So, +we automate our formatting and linting with [pre-commit](https://pre-commit.com/). All +you need to do is install our `pre-commit` hook so the code gets formatted automatically +when you commit your changes locally: + +```bash +pip install pre-commit ``` diff --git a/compass_sdk/__init__.py b/compass_sdk/__init__.py index b9999b3..063012b 100644 --- a/compass_sdk/__init__.py +++ b/compass_sdk/__init__.py @@ -1,420 +1,18 @@ -import logging -import math -import uuid -from dataclasses import field -from enum import Enum, StrEnum -from os import getenv -from typing import Annotated, Any, Dict, List, Optional, Union - -from pydantic import BaseModel, ConfigDict, Field, PositiveInt, StringConstraints - -from compass_sdk.constants import ( - COHERE_API_ENV_VAR, - DEFAULT_COMMANDR_EXTRACTABLE_ATTRIBUTES, - DEFAULT_COMMANDR_PROMPT, - DEFAULT_MIN_CHARS_PER_ELEMENT, - DEFAULT_MIN_NUM_CHUNKS_IN_TITLE, - DEFAULT_MIN_NUM_TOKENS_CHUNK, - DEFAULT_NUM_TOKENS_CHUNK_OVERLAP, - DEFAULT_NUM_TOKENS_PER_CHUNK, - METADATA_HEURISTICS_ATTRIBUTES, - SKIP_INFER_TABLE_TYPES, +# Python imports +from enum import Enum +from typing import List, Optional + +# 3rd party imports +from pydantic import BaseModel + +# Local imports +from compass_sdk.models import ( + MetadataConfig, + ParserConfig, + ValidatedModel, ) - -class Logger: - def __init__(self, name: str, log_level: int = logging.INFO): - self._logger = logging.getLogger(name) - self._logger.setLevel(log_level) - - formatter = logging.Formatter(f"%(asctime)s-{name}-PID:%(process)d: %(message)s", "%d-%m-%y:%H:%M:%S") - stream_handler = logging.StreamHandler() - stream_handler.setFormatter(formatter) - self._logger.addHandler(stream_handler) - - def info(self, msg: str): - self._logger.info(msg) - - def debug(self, msg: str): - self._logger.debug(msg) - - def error(self, msg: str): - self._logger.error(msg) - - def critical(self, msg: str): - self._logger.critical(msg) - - def warning(self, msg: str): - self._logger.warning(msg) - - def flush(self): - for handler in self._logger.handlers: - handler.flush() - - def setLevel(self, level: Union[int, str]): - self._logger.setLevel(level) - - -logger = Logger(name="compass-sdk", log_level=logging.INFO) - - -class ValidatedModel(BaseModel): - class Config: - arbitrary_types_allowed = True - use_enum_values = True - - @classmethod - def attribute_in_model(cls, attr_name): - return attr_name in cls.__fields__ - - def __init__(self, **data): - for name, value in data.items(): - if not self.attribute_in_model(name): - raise ValueError(f"{name} is not a valid attribute for {self.__class__.__name__}") - super().__init__(**data) - - -class CompassDocumentMetadata(ValidatedModel): - """ - Compass document metadata - """ - - doc_id: str = "" - filename: str = "" - meta: List = field(default_factory=list) - parent_doc_id: str = "" - - -class CompassDocumentStatus(str, Enum): - """ - Compass document status - """ - - Success = "success" - ParsingErrors = "parsing-errors" - MetadataErrors = "metadata-errors" - IndexingErrors = "indexing-errors" - - -class CompassSdkStage(str, Enum): - """ - Compass SDK stages - """ - - Parsing = "parsing" - Metadata = "metadata" - Chunking = "chunking" - Indexing = "indexing" - - -class CompassDocumentChunkAsset(BaseModel): - content_type: str - asset_data: str - - -class CompassDocumentChunk(BaseModel): - chunk_id: str - sort_id: str - doc_id: str - parent_doc_id: str - content: Dict[str, Any] - origin: Optional[Dict[str, Any]] = None - assets: Optional[list[CompassDocumentChunkAsset]] = None - - def parent_doc_is_split(self): - return self.doc_id != self.parent_doc_id - - -class CompassDocument(ValidatedModel): - """ - A Compass document contains all the information required to process a document and insert it into the index - It includes: - - metadata: the document metadata (e.g., filename, title, authors, date) - - content: the document content in string format - - elements: the document's Unstructured elements (e.g., tables, images, text). Used for chunking - - chunks: the document's chunks (e.g., paragraphs, tables, images). Used for indexing - - index_fields: the fields to be indexed. Used by the indexer - """ - - filebytes: bytes = b"" - metadata: CompassDocumentMetadata = CompassDocumentMetadata() - content: Dict[str, str] = field(default_factory=dict) - content_type: Optional[str] = None - elements: List[Any] = field(default_factory=list) - chunks: List[CompassDocumentChunk] = field(default_factory=list) - index_fields: List[str] = field(default_factory=list) - errors: List[Dict[CompassSdkStage, str]] = field(default_factory=list) - ignore_metadata_errors: bool = True - markdown: Optional[str] = None - - def has_data(self) -> bool: - return len(self.filebytes) > 0 - - def has_markdown(self) -> bool: - return self.markdown is not None - - def has_filename(self) -> bool: - return len(self.metadata.filename) > 0 - - def has_metadata(self) -> bool: - return len(self.metadata.meta) > 0 - - def has_parsing_errors(self) -> bool: - return any(stage == CompassSdkStage.Parsing for error in self.errors for stage, _ in error.items()) - - def has_metadata_errors(self) -> bool: - return any(stage == CompassSdkStage.Metadata for error in self.errors for stage, _ in error.items()) - - def has_indexing_errors(self) -> bool: - return any(stage == CompassSdkStage.Indexing for error in self.errors for stage, _ in error.items()) - - @property - def status(self) -> CompassDocumentStatus: - if self.has_parsing_errors(): - return CompassDocumentStatus.ParsingErrors - - if not self.ignore_metadata_errors and self.has_metadata_errors(): - return CompassDocumentStatus.MetadataErrors - - if self.has_indexing_errors(): - return CompassDocumentStatus.IndexingErrors - - return CompassDocumentStatus.Success - - -class MetadataStrategy(str, Enum): - No_Metadata = "no_metadata" - Naive_Title = "naive_title" - KeywordSearch = "keyword_search" - Bart = "bart" - Command_R = "command_r" - Custom = "custom" - - @classmethod - def _missing_(cls, value): - return cls.No_Metadata - - -class LoggerLevel(str, Enum): - DEBUG = "DEBUG" - INFO = "INFO" - WARNING = "WARNING" - ERROR = "ERROR" - CRITICAL = "CRITICAL" - - @classmethod - def _missing_(cls, value): - return cls.INFO - - -class MetadataConfig(ValidatedModel): - """ - Configuration class for metadata detection. - :param metadata_strategy: the metadata detection strategy to use. One of: - - No_Metadata: no metadata is inferred - - Heuristics: metadata is inferred using heuristics - - Bart: metadata is inferred using the BART summarization model - - Command_R: metadata is inferred using the Command-R summarization model - :param cohere_api_key: the Cohere API key to use for metadata detection - :param commandr_model_name: the name of the Command-R model to use for metadata detection - :param commandr_prompt: the prompt to use for the Command-R model - :param commandr_extractable_attributes: the extractable attributes for the Command-R model - :param commandr_max_tokens: the maximum number of tokens to use for the Command-R model - :param keyword_search_attributes: the attributes to search for in the document when using keyword search - :param keyword_search_separator: the separator to use for nested attributes when using keyword search - :param ignore_errors: if set to True, metadata detection errors will not be raised or stop the parsing process - - """ - - metadata_strategy: MetadataStrategy = MetadataStrategy.No_Metadata - cohere_api_key: Optional[str] = getenv(COHERE_API_ENV_VAR, None) - commandr_model_name: str = "command-r" - commandr_prompt: str = DEFAULT_COMMANDR_PROMPT - commandr_max_tokens: int = 500 - commandr_extractable_attributes: List[str] = DEFAULT_COMMANDR_EXTRACTABLE_ATTRIBUTES - keyword_search_attributes: List[str] = METADATA_HEURISTICS_ATTRIBUTES - keyword_search_separator: str = "." - ignore_errors: bool = True - - -class ParsingStrategy(str, Enum): - Fast = "fast" - Hi_Res = "hi_res" - - @classmethod - def _missing_(cls, value): - return cls.Fast - - -class ParsingModel(str, Enum): - Marker = "marker" # Default model, it is actually a combination of models used by the Marker PDF parser - YoloX_Quantized = "yolox_quantized" # Only PDF parsing working option from Unstructured - - @classmethod - def _missing_(cls, value): - return cls.Marker - - -class DocumentFormat(str, Enum): - Markdown = "markdown" - Text = "text" - - @classmethod - def _missing_(cls, value): - return cls.Markdown - - -class PDFParsingStrategy(StrEnum): - QuickText = "QuickText" - ImageToMarkdown = "ImageToMarkdown" - - @classmethod - def _missing_(cls, value): - return cls.QuickText - - -class PresentationParsingStrategy(StrEnum): - Unstructured = "Unstructured" - ImageToMarkdown = "ImageToMarkdown" - - @classmethod - def _missing_(cls, value): - return cls.Unstructured - - -class ParserConfig(BaseModel): - """ - CompassParser configuration. Important parameters: - :param parsing_strategy: the parsing strategy to use: - - 'auto' (default): automatically determine the best strategy - - 'fast': leverage traditional NLP extraction techniques to quickly pull all the - text elements. “Fast” strategy is not good for image based file types. - - 'hi_res': identifies the layout of the document using detectron2. The advantage of “hi_res” - is that it uses the document layout to gain additional information about document elements. - We recommend using this strategy if your use case is highly sensitive to correct - classifications for document elements. - - 'ocr_only': leverage Optical Character Recognition to extract text from the image based files. - :param parsing_model: the parsing model to use. One of: - - yolox_quantized (default): single-stage object detection model, quantized. Runs faster than YoloX - See https://unstructured-io.github.io/unstructured/best_practices/models.html for more details. - We have temporarily removed the option to use other models because - of ongoing stability issues. - - """ - - model_config = ConfigDict( - arbitrary_types_allowed=True, - extra="ignore", - ) - - # CompassParser configuration - logger_level: LoggerLevel = LoggerLevel.INFO - parse_tables: bool = True - parse_images: bool = True - parsed_images_output_dir: Optional[str] = None - allowed_image_types: Optional[List[str]] = None - min_chars_per_element: int = DEFAULT_MIN_CHARS_PER_ELEMENT - skip_infer_table_types: List[str] = SKIP_INFER_TABLE_TYPES - parsing_strategy: ParsingStrategy = ParsingStrategy.Fast - parsing_model: ParsingModel = ParsingModel.YoloX_Quantized - - # CompassChunker configuration - num_tokens_per_chunk: int = DEFAULT_NUM_TOKENS_PER_CHUNK - num_tokens_overlap: int = DEFAULT_NUM_TOKENS_CHUNK_OVERLAP - min_chunk_tokens: int = DEFAULT_MIN_NUM_TOKENS_CHUNK - num_chunks_in_title: int = DEFAULT_MIN_NUM_CHUNKS_IN_TITLE - max_tokens_metadata: int = math.floor(num_tokens_per_chunk * 0.1) - include_tables: bool = True - - # Formatting configuration - output_format: DocumentFormat = DocumentFormat.Markdown - - # Visual elements extraction configuration - extract_visual_elements: bool = False - vertical_table_crop_margin: int = 100 - horizontal_table_crop_margin: int = 100 - - pdf_parsing_strategy: PDFParsingStrategy = PDFParsingStrategy.QuickText - presentation_parsing_strategy: PresentationParsingStrategy = PresentationParsingStrategy.Unstructured - - -### Document indexing - - -class DocumentChunkAsset(BaseModel): - content_type: str - asset_data: str - - -class Chunk(BaseModel): - chunk_id: str - sort_id: int - content: Dict[str, Any] - origin: Optional[Dict[str, Any]] = None - assets: Optional[list[DocumentChunkAsset]] = None - parent_doc_id: str - - -class Document(BaseModel): - """ - A document that can be indexed in Compass (i.e., a list of indexable chunks) - """ - - doc_id: str - path: str - parent_doc_id: str - content: Dict[str, Any] - chunks: List[Chunk] - index_fields: List[str] = field(default_factory=list) - - -class ParseableDocument(BaseModel): - """ - A document to be sent to Compass in bytes format for parsing on the Compass side - """ - - id: uuid.UUID - filename: Annotated[str, StringConstraints(min_length=1)] # Ensures the filename is a non-empty string - content_type: str - content_length_bytes: PositiveInt # File size must be a non-negative integer - content_encoded_bytes: str # Base64-encoded file contents - context: Dict[str, Any] = Field(default_factory=dict) - - -class PushDocumentsInput(BaseModel): - documents: List[ParseableDocument] - - -class SearchFilter(BaseModel): - class FilterType(str, Enum): - EQ = "$eq" - LT_EQ = "$lte" - GT_EQ = "$gte" - WORD_MATCH = "$wordMatch" - - field: str - type: FilterType - value: Any - - -class SearchInput(BaseModel): - """ - Search query input - """ - - query: str - top_k: int - filters: Optional[List[SearchFilter]] = None - - -class PutDocumentsInput(BaseModel): - """ - A Compass request to put a list of Document - """ - - docs: List[Document] - authorized_groups: Optional[List[str]] = None - merge_groups_on_conflict: bool = False +__version__ = "0.6.0" class ProcessFileParameters(ValidatedModel): diff --git a/compass_sdk/clients/__init__.py b/compass_sdk/clients/__init__.py new file mode 100644 index 0000000..bdefd9c --- /dev/null +++ b/compass_sdk/clients/__init__.py @@ -0,0 +1,3 @@ +from compass_sdk.clients.compass import * # noqa: F403 +from compass_sdk.clients.parser import * # noqa: F403 +from compass_sdk.clients.rbac import * # noqa: F403 diff --git a/compass_sdk/compass.py b/compass_sdk/clients/compass.py similarity index 92% rename from compass_sdk/compass.py rename to compass_sdk/clients/compass.py index f7f5853..7012510 100644 --- a/compass_sdk/compass.py +++ b/compass_sdk/clients/compass.py @@ -1,32 +1,30 @@ -import base64 -import os -import threading -import uuid +# Python imports from collections import deque from dataclasses import dataclass from statistics import mean from typing import Any, Dict, Iterator, List, Literal, Optional, Tuple, Union +import base64 +import logging +import os +import threading +import uuid -import requests +# 3rd party imports from joblib import Parallel, delayed from pydantic import BaseModel from requests.exceptions import InvalidSchema -from tenacity import RetryError, retry, retry_if_not_exception_type, stop_after_attempt, wait_fixed +from tenacity import ( + RetryError, + retry, + retry_if_not_exception_type, + stop_after_attempt, + wait_fixed, +) +import requests +# Local imports from compass_sdk import ( - Chunk, - CompassDocument, - CompassDocumentStatus, - CompassSdkStage, - Document, GroupAuthorizationInput, - LoggerLevel, - ParseableDocument, - PushDocumentsInput, - PutDocumentsInput, - SearchFilter, - SearchInput, - logger, ) from compass_sdk.constants import ( DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES, @@ -35,13 +33,27 @@ DEFAULT_MAX_RETRIES, DEFAULT_SLEEP_RETRY_SECONDS, ) -from compass_sdk.exceptions import CompassAuthError, CompassClientError, CompassMaxErrorRateExceeded +from compass_sdk.exceptions import ( + CompassAuthError, + CompassClientError, + CompassMaxErrorRateExceeded, +) from compass_sdk.models import ( + Chunk, + CompassDocument, + CompassDocumentStatus, + CompassSdkStage, CreateDataSource, DataSource, + Document, PaginatedList, + ParseableDocument, + PushDocumentsInput, + PutDocumentsInput, SearchChunksResponse, SearchDocumentsResponse, + SearchFilter, + SearchInput, ) @@ -54,6 +66,9 @@ class RetryResult: _DEFAULT_TIMEOUT = 30 +logger = logging.getLogger(__name__) + + class SessionWithDefaultTimeout(requests.Session): def __init__(self, timeout: int): self._timeout = timeout @@ -72,7 +87,6 @@ def __init__( username: Optional[str] = None, password: Optional[str] = None, bearer_token: Optional[str] = None, - logger_level: LoggerLevel = LoggerLevel.INFO, default_timeout: int = _DEFAULT_TIMEOUT, ): """ @@ -125,7 +139,6 @@ def __init__( "delete_datasources": "/api/v1/datasources/{datasource_id}", "get_datasource": "/api/v1/datasources/{datasource_id}", } - logger.setLevel(logger_level.value) def create_index(self, *, index_name: str): """ @@ -350,7 +363,9 @@ def put_request( ) -> None: nonlocal num_succeeded, errors errors.extend(previous_errors) - compass_docs: List[CompassDocument] = [compass_doc for compass_doc, _ in request_data] + compass_docs: List[CompassDocument] = [ + compass_doc for compass_doc, _ in request_data + ] put_docs_input = PutDocumentsInput( docs=[input_doc for _, input_doc in request_data], authorized_groups=authorized_groups, @@ -373,22 +388,36 @@ def put_request( if results.error: for doc in compass_docs: - doc.errors.append({CompassSdkStage.Indexing: f"{doc.metadata.filename}: {results.error}"}) - errors.append({doc.metadata.doc_id: f"{doc.metadata.filename}: {results.error}"}) + doc.errors.append( + { + CompassSdkStage.Indexing: f"{doc.metadata.filename}: {results.error}" + } + ) + errors.append( + { + doc.metadata.doc_id: f"{doc.metadata.filename}: {results.error}" + } + ) else: num_succeeded += len(compass_docs) # Keep track of the results of the last N API calls to calculate the error rate # If the error rate is higher than the threshold, stop the insertion process error_window.append(results.error) - error_rate = mean([1 if x else 0 for x in error_window]) if len(error_window) == error_window.maxlen else 0 + error_rate = ( + mean([1 if x else 0 for x in error_window]) + if len(error_window) == error_window.maxlen + else 0 + ) if error_rate > max_error_rate: raise CompassMaxErrorRateExceeded( f"[Thread {threading.get_native_id()}]{error_rate * 100}% of insertions failed " f"in the last {errors_sliding_window_size} API calls. Stopping the insertion process." ) - error_window = deque(maxlen=errors_sliding_window_size) # Keep track of the results of the last N API calls + error_window = deque( + maxlen=errors_sliding_window_size + ) # Keep track of the results of the last N API calls num_succeeded = 0 errors = [] requests_iter = self._get_request_blocks(docs, max_chunks_per_request) @@ -498,7 +527,11 @@ def _get_request_blocks( for error in doc.errors: errors.append({doc.metadata.doc_id: list(error.values())[0]}) else: - num_chunks += len(doc.chunks) if doc.status == CompassDocumentStatus.Success else 0 + num_chunks += ( + len(doc.chunks) + if doc.status == CompassDocumentStatus.Success + else 0 + ) if num_chunks > max_chunks_per_request: yield request_block, errors request_block, errors = [], [] @@ -580,7 +613,9 @@ def search_chunks( return SearchChunksResponse.model_validate(result.result) - def edit_group_authorization(self, *, index_name: str, group_auth_input: GroupAuthorizationInput): + def edit_group_authorization( + self, *, index_name: str, group_auth_input: GroupAuthorizationInput + ): """ Edit group authorization for an index :param index_name: the name of the index @@ -639,7 +674,9 @@ def _send_request_with_retry(): headers = {"Authorization": f"Bearer {self.bearer_token}"} auth = None - response = self.api_method[api_name](target_path, json=data_dict, auth=auth, headers=headers) + response = self.api_method[api_name]( + target_path, json=data_dict, auth=auth, headers=headers + ) if response.ok: error = None @@ -672,12 +709,16 @@ def _send_request_with_retry(): error = None try: - target_path = self.index_url + self.api_endpoint[api_name].format(**url_params) + target_path = self.index_url + self.api_endpoint[api_name].format( + **url_params + ) res = _send_request_with_retry() if res: return res else: return RetryResult(result=None, error=error) except RetryError: - logger.error(f"Failed to send request after {max_retries} attempts. Aborting.") + logger.error( + f"Failed to send request after {max_retries} attempts. Aborting." + ) return RetryResult(result=None, error=error) diff --git a/compass_sdk/parser.py b/compass_sdk/clients/parser.py similarity index 62% rename from compass_sdk/parser.py rename to compass_sdk/clients/parser.py index 76a4802..65cc8c9 100644 --- a/compass_sdk/parser.py +++ b/compass_sdk/clients/parser.py @@ -1,26 +1,42 @@ -import json -import os +# Python imports from concurrent.futures import ThreadPoolExecutor from typing import Any, Callable, Dict, Iterable, List, Optional, Union +import json +import logging +import os +# 3rd party imports import requests -from compass_sdk import CompassDocument, MetadataConfig, ParserConfig, ProcessFileParameters, logger +# Local imports +from compass_sdk import ( + ProcessFileParameters, +) from compass_sdk.constants import DEFAULT_MAX_ACCEPTED_FILE_SIZE_BYTES +from compass_sdk.models import ( + CompassDocument, + MetadataConfig, + ParserConfig, +) from compass_sdk.utils import imap_queued, open_document, scan_folder Fn_or_Dict = Union[Dict[str, Any], Callable[[CompassDocument], Dict[str, Any]]] +logger = logging.getLogger(__name__) + + class CompassParserClient: """ - Client to interact with the CompassParser API. It allows to process files using the parser and metadata - configurations specified in the parameters. The client is stateful, that is, it can be initialized with - parser and metadata configurations that will be used for all subsequent files processed by the client. - Also, independently of the default configurations, the client allows to pass specific configurations for each file - when calling the process_file or process_files methods. The client is responsible for opening the files and - sending them to the CompassParser API for processing. The resulting documents are returned as CompassDocument - objects. + Client to interact with the CompassParser API. It allows to process files using the + parser and metadata configurations specified in the parameters. The client is + stateful, that is, it can be initialized with parser and metadata configurations + that will be used for all subsequent files processed by the client. Also, + independently of the default configurations, the client allows to pass specific + configurations for each file when calling the process_file or process_files methods. + The client is responsible for opening the files and sending them to the + CompassParser API for processing. The resulting documents are returned as + CompassDocument objects. :param parser_url: URL of the CompassParser API :param parser_config: Default parser configuration to use when processing files @@ -39,18 +55,24 @@ def __init__( num_workers: int = 4, ): """ - Initializes the CompassParserClient with the specified parser_url, parser_config, and metadata_config. - The parser_config and metadata_config are optional, and if not provided, the default configurations will be used. - If the parser/metadata configs are provided, they will be used for all subsequent files processed by the client - unless specific configs are passed when calling the process_file or process_files methods. + Initializes the CompassParserClient with the specified parser_url, + parser_config, and metadata_config. The parser_config and metadata_config are + optional, and if not provided, the default configurations will be used. If the + parser/metadata configs are provided, they will be used for all subsequent files + processed by the client unless specific configs are passed when calling the + process_file or process_files methods. :param parser_url: the URL of the CompassParser API - :param parser_config: the parser configuration to use when processing files if no parser configuration - is specified in the method calls (process_file or process_files) - :param metadata_config: the metadata configuration to use when processing files if no metadata configuration - is specified in the method calls (process_file or process_files) + :param parser_config: the parser configuration to use when processing files if + no parser configuration is specified in the method calls (process_file or + process_files) + :param metadata_config: the metadata configuration to use when processing files + if no metadata configuration is specified in the method calls (process_file + or process_files) """ - self.parser_url = parser_url if not parser_url.endswith("/") else parser_url[:-1] + self.parser_url = ( + parser_url if not parser_url.endswith("/") else parser_url[:-1] + ) self.parser_config = parser_config self.username = username or os.getenv("COHERE_COMPASS_USERNAME") self.password = password or os.getenv("COHERE_COMPASS_PASSWORD") @@ -59,7 +81,9 @@ def __init__( self.num_workers = num_workers self.metadata_config = metadata_config - logger.info(f"CompassParserClient initialized with parser_url: {self.parser_url}") + logger.info( + f"CompassParserClient initialized with parser_url: {self.parser_url}" + ) def process_folder( self, @@ -72,23 +96,32 @@ def process_folder( custom_context: Optional[Fn_or_Dict] = None, ): """ - Processes all the files in the specified folder using the default parser and metadata configurations - passed when creating the client. The method iterates over all the files in the folder and processes them - using the process_file method. The resulting documents are returned as a list of CompassDocument objects. + Processes all the files in the specified folder using the default parser and + metadata configurations passed when creating the client. The method iterates + over all the files in the folder and processes them using the process_file + method. The resulting documents are returned as a list of CompassDocument + objects. :param folder_path: the folder to process :param allowed_extensions: the list of allowed extensions to process :param recursive: whether to process the folder recursively - :param parser_config: the parser configuration to use when processing files if no parser configuration - is specified in the method calls (process_file or process_files) - :param metadata_config: the metadata configuration to use when processing files if no metadata configuration - is specified in the method calls (process_file or process_files) - :param custom_context: Additional data to add to compass document. Fields will be filterable but not semantically searchable. - Can either be a dictionary or a callable that takes a CompassDocument and returns a dictionary. + :param parser_config: the parser configuration to use when processing files if + no parser configuration is specified in the method calls (process_file or + process_files) + :param metadata_config: the metadata configuration to use when processing files + if no metadata configuration is specified in the method calls (process_file + or process_files) + :param custom_context: Additional data to add to compass document. Fields will + be filterable but not semantically searchable. Can either be a dictionary + or a callable that takes a CompassDocument and returns a dictionary. :return: the list of processed documents """ - filenames = scan_folder(folder_path=folder_path, allowed_extensions=allowed_extensions, recursive=recursive) + filenames = scan_folder( + folder_path=folder_path, + allowed_extensions=allowed_extensions, + recursive=recursive, + ) return self.process_files( filenames=filenames, parser_config=parser_config, @@ -106,29 +139,33 @@ def process_files( custom_context: Optional[Fn_or_Dict] = None, ) -> Iterable[CompassDocument]: """ - Processes a list of files provided as filenames, using the specified parser and metadata configurations. + Processes a list of files provided as filenames, using the specified parser and + metadata configurations. - If the parser/metadata configs are not provided, then the default configs passed by parameter when - creating the client will be used. This makes the CompassParserClient stateful. That is, we can set the - parser/metadata configs only once when creating the parser client, and process all subsequent files + If the parser/metadata configs are not provided, then the default configs passed + by parameter when creating the client will be used. This makes the + CompassParserClient stateful. That is, we can set the parser/metadata configs + only once when creating the parser client, and process all subsequent files without having to pass the config every time. - All the documents passed as filenames and opened to obtain their bytes. Then, they are packed into a - ProcessFilesParameters object that contains a list of ProcessFileParameters, each contain a file, - its id, and the parser/metadata config + All the documents passed as filenames and opened to obtain their bytes. Then, + they are packed into a ProcessFilesParameters object that contains a list of + ProcessFileParameters, each contain a file, its id, and the parser/metadata + config :param filenames: List of filenames to process :param file_ids: List of ids for the files :param parser_config: ParserConfig object (applies the same config to all docs) - :param metadata_config: MetadataConfig object (applies the same config to all docs) - :param custom_context: Additional data to add to compass document. Fields will be filterable but not semantically searchable. - Can either be a dictionary or a callable that takes a CompassDocument and returns a dictionary. + :param metadata_config: MetadataConfig object (applies the same config to all + docs) + :param custom_context: Additional data to add to compass document. Fields will + be filterable but not semantically searchable. Can either be a dictionary + or a callable that takes a CompassDocument and returns a dictionary. :return: List of processed documents """ def process_file(i: int) -> List[CompassDocument]: - return self.process_file( filename=filenames[i], file_id=file_ids[i] if file_ids else None, @@ -146,7 +183,9 @@ def process_file(i: int) -> List[CompassDocument]: yield from results @staticmethod - def _get_metadata(doc: CompassDocument, custom_context: Optional[Fn_or_Dict] = None) -> Dict[str, Any]: + def _get_metadata( + doc: CompassDocument, custom_context: Optional[Fn_or_Dict] = None + ) -> Dict[str, Any]: if custom_context is None: return {} elif callable(custom_context): @@ -165,18 +204,22 @@ def process_file( custom_context: Optional[Fn_or_Dict] = None, ) -> List[CompassDocument]: """ - Takes in a file, its id, and the parser/metadata config. If the config is None, then it uses the - default configs passed by parameter when creating the client. This makes the CompassParserClient - stateful for convenience, that is, one can pass in the parser/metadata config only once when creating the - CompassParserClient, and process files without having to pass the config every time + Takes in a file, its id, and the parser/metadata config. If the config is None, + then it uses the default configs passed by parameter when creating the client. + This makes the CompassParserClient stateful for convenience, that is, one can + pass in the parser/metadata config only once when creating the + CompassParserClient, and process files without having to pass the config every + time :param filename: Filename to process :param file_id: Id for the file :param content_type: Content type of the file :param parser_config: ParserConfig object with the config to use for parsing the file - :param metadata_config: MetadataConfig object with the config to use for extracting metadata for each document - :param custom_context: Additional data to add to compass document. Fields will be filterable but not semantically searchable. - Can either be a dictionary or a callable that takes a CompassDocument and returns a dictionary. + :param metadata_config: MetadataConfig object with the config to use for + extracting metadata for each document + :param custom_context: Additional data to add to compass document. Fields will + be filterable but not semantically searchable. Can either be a dictionary + or a callable that takes a CompassDocument and returns a dictionary. :return: List of resulting documents """ @@ -200,7 +243,9 @@ def process_file( doc_id=file_id, content_type=content_type, ) - auth = (self.username, self.password) if self.username and self.password else None + auth = ( + (self.username, self.password) if self.username and self.password else None + ) res = self.session.post( url=f"{self.parser_url}/v1/process_file", data={"data": json.dumps(params.model_dump())}, diff --git a/compass_sdk/rbac.py b/compass_sdk/clients/rbac.py similarity index 60% rename from compass_sdk/rbac.py rename to compass_sdk/clients/rbac.py index c7529c1..2953561 100644 --- a/compass_sdk/rbac.py +++ b/compass_sdk/clients/rbac.py @@ -5,7 +5,7 @@ from pydantic import BaseModel from requests import HTTPError -from compass_sdk.types import ( +from compass_sdk.models import ( GroupCreateRequest, GroupCreateResponse, GroupDeleteResponse, @@ -29,7 +29,10 @@ class CompassRootClient: def __init__(self, compass_url: str, root_user_token: str): self.base_url = compass_url + "/api/security/admin/rbac" - self.headers = {"Authorization": f"Bearer {root_user_token}", "Content-Type": "application/json"} + self.headers = { + "Authorization": f"Bearer {root_user_token}", + "Content-Type": "application/json", + } T = TypeVar("T", bound=BaseModel) U = TypeVar("U", bound=BaseModel) @@ -42,35 +45,51 @@ def _fetch_entities(url: str, headers: Headers, entity_type: Type[T]) -> List[T] return [entity_type.model_validate(entity) for entity in response.json()] @staticmethod - def _create_entities(url: str, headers: Headers, entity_request: List[T], entity_response: Type[U]) -> List[U]: + def _create_entities( + url: str, headers: Headers, entity_request: List[T], entity_response: Type[U] + ) -> List[U]: response = requests.post( url, json=[json.loads(entity.model_dump_json()) for entity in entity_request], headers=headers, ) CompassRootClient.raise_for_status(response) - return [entity_response.model_validate(response) for response in response.json()] + return [ + entity_response.model_validate(response) for response in response.json() + ] @staticmethod - def _delete_entities(url: str, headers: Headers, names: List[str], entity_response: Type[U]) -> List[U]: + def _delete_entities( + url: str, headers: Headers, names: List[str], entity_response: Type[U] + ) -> List[U]: entities = ",".join(names) response = requests.delete(f"{url}/{entities}", headers=headers) CompassRootClient.raise_for_status(response) return [entity_response.model_validate(entity) for entity in response.json()] def fetch_users(self) -> List[UserFetchResponse]: - return self._fetch_entities(f"{self.base_url}/v1/users", self.headers, UserFetchResponse) + return self._fetch_entities( + f"{self.base_url}/v1/users", self.headers, UserFetchResponse + ) def fetch_groups(self) -> List[GroupFetchResponse]: - return self._fetch_entities(f"{self.base_url}/v1/groups", self.headers, GroupFetchResponse) + return self._fetch_entities( + f"{self.base_url}/v1/groups", self.headers, GroupFetchResponse + ) def fetch_roles(self) -> List[RoleFetchResponse]: - return self._fetch_entities(f"{self.base_url}/v1/roles", self.headers, RoleFetchResponse) + return self._fetch_entities( + f"{self.base_url}/v1/roles", self.headers, RoleFetchResponse + ) def fetch_role_mappings(self) -> List[RoleMappingResponse]: - return self._fetch_entities(f"{self.base_url}/v1/role-mappings", self.headers, RoleMappingResponse) + return self._fetch_entities( + f"{self.base_url}/v1/role-mappings", self.headers, RoleMappingResponse + ) - def create_users(self, *, users: List[UserCreateRequest]) -> List[UserCreateResponse]: + def create_users( + self, *, users: List[UserCreateRequest] + ) -> List[UserCreateResponse]: return self._create_entities( url=f"{self.base_url}/v1/users", headers=self.headers, @@ -78,7 +97,9 @@ def create_users(self, *, users: List[UserCreateRequest]) -> List[UserCreateResp entity_response=UserCreateResponse, ) - def create_groups(self, *, groups: List[GroupCreateRequest]) -> List[GroupCreateResponse]: + def create_groups( + self, *, groups: List[GroupCreateRequest] + ) -> List[GroupCreateResponse]: return self._create_entities( url=f"{self.base_url}/v1/groups", headers=self.headers, @@ -86,7 +107,9 @@ def create_groups(self, *, groups: List[GroupCreateRequest]) -> List[GroupCreate entity_response=GroupCreateResponse, ) - def create_roles(self, *, roles: List[RoleCreateRequest]) -> List[RoleCreateResponse]: + def create_roles( + self, *, roles: List[RoleCreateRequest] + ) -> List[RoleCreateResponse]: return self._create_entities( url=f"{self.base_url}/v1/roles", headers=self.headers, @@ -94,7 +117,9 @@ def create_roles(self, *, roles: List[RoleCreateRequest]) -> List[RoleCreateResp entity_response=RoleCreateResponse, ) - def create_role_mappings(self, *, role_mappings: List[RoleMappingRequest]) -> List[RoleMappingResponse]: + def create_role_mappings( + self, *, role_mappings: List[RoleMappingRequest] + ) -> List[RoleMappingResponse]: return self._create_entities( url=f"{self.base_url}/v1/role-mappings", headers=self.headers, @@ -103,27 +128,46 @@ def create_role_mappings(self, *, role_mappings: List[RoleMappingRequest]) -> Li ) def delete_users(self, *, user_names: List[str]) -> List[UserDeleteResponse]: - return self._delete_entities(f"{self.base_url}/v1/users", self.headers, user_names, UserDeleteResponse) + return self._delete_entities( + f"{self.base_url}/v1/users", self.headers, user_names, UserDeleteResponse + ) def delete_groups(self, *, group_names: List[str]) -> List[GroupDeleteResponse]: - return self._delete_entities(f"{self.base_url}/v1/groups", self.headers, group_names, GroupDeleteResponse) + return self._delete_entities( + f"{self.base_url}/v1/groups", self.headers, group_names, GroupDeleteResponse + ) def delete_roles(self, *, role_ids: List[str]) -> List[RoleDeleteResponse]: - return self._delete_entities(f"{self.base_url}/v1/roles", self.headers, role_ids, RoleDeleteResponse) + return self._delete_entities( + f"{self.base_url}/v1/roles", self.headers, role_ids, RoleDeleteResponse + ) - def delete_role_mappings(self, *, role_name: str, group_name: str) -> List[RoleMappingDeleteResponse]: + def delete_role_mappings( + self, *, role_name: str, group_name: str + ) -> List[RoleMappingDeleteResponse]: response = requests.delete( - f"{self.base_url}/v1/role-mappings/role/{role_name}/group/{group_name}", headers=self.headers + f"{self.base_url}/v1/role-mappings/role/{role_name}/group/{group_name}", + headers=self.headers, ) self.raise_for_status(response) - return [RoleMappingDeleteResponse.model_validate(role_mapping) for role_mapping in response.json()] - - def delete_user_group(self, *, group_name: str, user_name: str) -> GroupUserDeleteResponse: - response = requests.delete(f"{self.base_url}/v1/group/{group_name}/user/{user_name}", headers=self.headers) + return [ + RoleMappingDeleteResponse.model_validate(role_mapping) + for role_mapping in response.json() + ] + + def delete_user_group( + self, *, group_name: str, user_name: str + ) -> GroupUserDeleteResponse: + response = requests.delete( + f"{self.base_url}/v1/group/{group_name}/user/{user_name}", + headers=self.headers, + ) self.raise_for_status(response) return GroupUserDeleteResponse.model_validate(response.json()) - def update_role(self, *, role_name: str, policies: List[PolicyRequest]) -> RoleCreateResponse: + def update_role( + self, *, role_name: str, policies: List[PolicyRequest] + ) -> RoleCreateResponse: response = requests.put( f"{self.base_url}/v1/roles/{role_name}", json=[json.loads(policy.model_dump_json()) for policy in policies], @@ -150,10 +194,14 @@ def raise_for_status(response: requests.Response): reason = response.content if 400 <= response.status_code < 500: - http_error_msg = f"{response.status_code} Client Error: {reason} for url: {response.url}" + http_error_msg = ( + f"{response.status_code} Client Error: {reason} for url: {response.url}" + ) elif 500 <= response.status_code < 600: - http_error_msg = f"{response.status_code} Server Error: {reason} for url: {response.url}" + http_error_msg = ( + f"{response.status_code} Server Error: {reason} for url: {response.url}" + ) if http_error_msg: raise HTTPError(http_error_msg, response=response) diff --git a/compass_sdk/exceptions.py b/compass_sdk/exceptions.py index 01af1d6..ea99adb 100644 --- a/compass_sdk/exceptions.py +++ b/compass_sdk/exceptions.py @@ -11,7 +11,9 @@ class CompassAuthError(CompassClientError): def __init__( self, - message=("CompassAuthError - check your bearer token or username and password."), + message=( + "CompassAuthError - check your bearer token or username and password." + ), ): self.message = message super().__init__(self.message) diff --git a/compass_sdk/models/__init__.py b/compass_sdk/models/__init__.py index bb3ac98..ca3d589 100644 --- a/compass_sdk/models/__init__.py +++ b/compass_sdk/models/__init__.py @@ -1,3 +1,27 @@ # import models into model package -from compass_sdk.models.datasources import * # noqa: F403 -from compass_sdk.models.search import * # noqa: F403 +from pydantic import BaseModel + + +class ValidatedModel(BaseModel): + class Config: + arbitrary_types_allowed = True + use_enum_values = True + + @classmethod + def attribute_in_model(cls, attr_name): + return attr_name in cls.__fields__ + + def __init__(self, **data): + for name, value in data.items(): + if not self.attribute_in_model(name): + raise ValueError( + f"{name} is not a valid attribute for {self.__class__.__name__}" + ) + super().__init__(**data) + + +from compass_sdk.models.config import * # noqa: E402, F403 +from compass_sdk.models.datasources import * # noqa: E402, F403 +from compass_sdk.models.documents import * # noqa: E402, F403 +from compass_sdk.models.rbac import * # noqa: E402, F403 +from compass_sdk.models.search import * # noqa: E402, F403 diff --git a/compass_sdk/models/config.py b/compass_sdk/models/config.py new file mode 100644 index 0000000..2d0b2d4 --- /dev/null +++ b/compass_sdk/models/config.py @@ -0,0 +1,170 @@ +# Python imports +from enum import Enum, StrEnum +from os import getenv +from typing import List, Optional +import math + +# 3rd party imports +from pydantic import BaseModel, ConfigDict + +# Local imports +from compass_sdk.constants import ( + COHERE_API_ENV_VAR, + DEFAULT_COMMANDR_EXTRACTABLE_ATTRIBUTES, + DEFAULT_COMMANDR_PROMPT, + DEFAULT_MIN_CHARS_PER_ELEMENT, + DEFAULT_MIN_NUM_CHUNKS_IN_TITLE, + DEFAULT_MIN_NUM_TOKENS_CHUNK, + DEFAULT_NUM_TOKENS_CHUNK_OVERLAP, + DEFAULT_NUM_TOKENS_PER_CHUNK, + METADATA_HEURISTICS_ATTRIBUTES, + SKIP_INFER_TABLE_TYPES, +) +from compass_sdk.models import ValidatedModel + + +class DocumentFormat(str, Enum): + Markdown = "markdown" + Text = "text" + + @classmethod + def _missing_(cls, value): + return cls.Markdown + + +class PDFParsingStrategy(StrEnum): + QuickText = "QuickText" + ImageToMarkdown = "ImageToMarkdown" + + @classmethod + def _missing_(cls, value): + return cls.QuickText + + +class PresentationParsingStrategy(StrEnum): + Unstructured = "Unstructured" + ImageToMarkdown = "ImageToMarkdown" + + @classmethod + def _missing_(cls, value): + return cls.Unstructured + + +class ParsingStrategy(str, Enum): + Fast = "fast" + Hi_Res = "hi_res" + + @classmethod + def _missing_(cls, value): + return cls.Fast + + +class ParsingModel(str, Enum): + Marker = "marker" # Default model, it is actually a combination of models used by the Marker PDF parser + YoloX_Quantized = ( + "yolox_quantized" # Only PDF parsing working option from Unstructured + ) + + @classmethod + def _missing_(cls, value): + return cls.Marker + + +class ParserConfig(BaseModel): + """ + CompassParser configuration. Important parameters: + :param parsing_strategy: the parsing strategy to use: + - 'auto' (default): automatically determine the best strategy + - 'fast': leverage traditional NLP extraction techniques to quickly pull all the + text elements. “Fast” strategy is not good for image based file types. + - 'hi_res': identifies the layout of the document using detectron2. The advantage of “hi_res” + is that it uses the document layout to gain additional information about document elements. + We recommend using this strategy if your use case is highly sensitive to correct + classifications for document elements. + - 'ocr_only': leverage Optical Character Recognition to extract text from the image based files. + :param parsing_model: the parsing model to use. One of: + - yolox_quantized (default): single-stage object detection model, quantized. Runs faster than YoloX + See https://unstructured-io.github.io/unstructured/best_practices/models.html for more details. + We have temporarily removed the option to use other models because + of ongoing stability issues. + + """ + + model_config = ConfigDict( + arbitrary_types_allowed=True, + extra="ignore", + ) + + # CompassParser configuration + parse_tables: bool = True + parse_images: bool = True + parsed_images_output_dir: Optional[str] = None + allowed_image_types: Optional[List[str]] = None + min_chars_per_element: int = DEFAULT_MIN_CHARS_PER_ELEMENT + skip_infer_table_types: List[str] = SKIP_INFER_TABLE_TYPES + parsing_strategy: ParsingStrategy = ParsingStrategy.Fast + parsing_model: ParsingModel = ParsingModel.YoloX_Quantized + + # CompassChunker configuration + num_tokens_per_chunk: int = DEFAULT_NUM_TOKENS_PER_CHUNK + num_tokens_overlap: int = DEFAULT_NUM_TOKENS_CHUNK_OVERLAP + min_chunk_tokens: int = DEFAULT_MIN_NUM_TOKENS_CHUNK + num_chunks_in_title: int = DEFAULT_MIN_NUM_CHUNKS_IN_TITLE + max_tokens_metadata: int = math.floor(num_tokens_per_chunk * 0.1) + include_tables: bool = True + + # Formatting configuration + output_format: DocumentFormat = DocumentFormat.Markdown + + # Visual elements extraction configuration + extract_visual_elements: bool = False + vertical_table_crop_margin: int = 100 + horizontal_table_crop_margin: int = 100 + + pdf_parsing_strategy: PDFParsingStrategy = PDFParsingStrategy.QuickText + presentation_parsing_strategy: PresentationParsingStrategy = ( + PresentationParsingStrategy.Unstructured + ) + + +class MetadataStrategy(str, Enum): + No_Metadata = "no_metadata" + Naive_Title = "naive_title" + KeywordSearch = "keyword_search" + Bart = "bart" + Command_R = "command_r" + Custom = "custom" + + @classmethod + def _missing_(cls, value): + return cls.No_Metadata + + +class MetadataConfig(ValidatedModel): + """ + Configuration class for metadata detection. + :param metadata_strategy: the metadata detection strategy to use. One of: + - No_Metadata: no metadata is inferred + - Heuristics: metadata is inferred using heuristics + - Bart: metadata is inferred using the BART summarization model + - Command_R: metadata is inferred using the Command-R summarization model + :param cohere_api_key: the Cohere API key to use for metadata detection + :param commandr_model_name: the name of the Command-R model to use for metadata detection + :param commandr_prompt: the prompt to use for the Command-R model + :param commandr_extractable_attributes: the extractable attributes for the Command-R model + :param commandr_max_tokens: the maximum number of tokens to use for the Command-R model + :param keyword_search_attributes: the attributes to search for in the document when using keyword search + :param keyword_search_separator: the separator to use for nested attributes when using keyword search + :param ignore_errors: if set to True, metadata detection errors will not be raised or stop the parsing process + + """ + + metadata_strategy: MetadataStrategy = MetadataStrategy.No_Metadata + cohere_api_key: Optional[str] = getenv(COHERE_API_ENV_VAR, None) + commandr_model_name: str = "command-r" + commandr_prompt: str = DEFAULT_COMMANDR_PROMPT + commandr_max_tokens: int = 500 + commandr_extractable_attributes: List[str] = DEFAULT_COMMANDR_EXTRACTABLE_ATTRIBUTES + keyword_search_attributes: List[str] = METADATA_HEURISTICS_ATTRIBUTES + keyword_search_separator: str = "." + ignore_errors: bool = True diff --git a/compass_sdk/models/documents.py b/compass_sdk/models/documents.py new file mode 100644 index 0000000..9502266 --- /dev/null +++ b/compass_sdk/models/documents.py @@ -0,0 +1,188 @@ +# Python imports +from dataclasses import field +from enum import Enum +from typing import Annotated, Any, Dict, List, Optional +import uuid + +# 3rd party imports +from pydantic import BaseModel, Field, PositiveInt, StringConstraints + +# Local imports +from compass_sdk.models import ValidatedModel + + +class CompassDocumentMetadata(ValidatedModel): + """ + Compass document metadata + """ + + doc_id: str = "" + filename: str = "" + meta: List = field(default_factory=list) + parent_doc_id: str = "" + + +class CompassDocumentChunkAsset(BaseModel): + content_type: str + asset_data: str + + +class CompassDocumentChunk(BaseModel): + chunk_id: str + sort_id: str + doc_id: str + parent_doc_id: str + content: Dict[str, Any] + origin: Optional[Dict[str, Any]] = None + assets: Optional[list[CompassDocumentChunkAsset]] = None + + def parent_doc_is_split(self): + return self.doc_id != self.parent_doc_id + + +class CompassDocumentStatus(str, Enum): + """ + Compass document status + """ + + Success = "success" + ParsingErrors = "parsing-errors" + MetadataErrors = "metadata-errors" + IndexingErrors = "indexing-errors" + + +class CompassSdkStage(str, Enum): + """ + Compass SDK stages + """ + + Parsing = "parsing" + Metadata = "metadata" + Chunking = "chunking" + Indexing = "indexing" + + +class CompassDocument(ValidatedModel): + """ + A Compass document contains all the information required to process a document and + insert it into the index. It includes: + - metadata: the document metadata (e.g., filename, title, authors, date) + - content: the document content in string format + - elements: the document's Unstructured elements (e.g., tables, images, text). Used + for chunking + - chunks: the document's chunks (e.g., paragraphs, tables, images). Used for indexing + - index_fields: the fields to be indexed. Used by the indexer + """ + + filebytes: bytes = b"" + metadata: CompassDocumentMetadata = CompassDocumentMetadata() + content: Dict[str, str] = field(default_factory=dict) + content_type: Optional[str] = None + elements: List[Any] = field(default_factory=list) + chunks: List[CompassDocumentChunk] = field(default_factory=list) + index_fields: List[str] = field(default_factory=list) + errors: List[Dict[CompassSdkStage, str]] = field(default_factory=list) + ignore_metadata_errors: bool = True + markdown: Optional[str] = None + + def has_data(self) -> bool: + return len(self.filebytes) > 0 + + def has_markdown(self) -> bool: + return self.markdown is not None + + def has_filename(self) -> bool: + return len(self.metadata.filename) > 0 + + def has_metadata(self) -> bool: + return len(self.metadata.meta) > 0 + + def has_parsing_errors(self) -> bool: + return any( + stage == CompassSdkStage.Parsing + for error in self.errors + for stage, _ in error.items() + ) + + def has_metadata_errors(self) -> bool: + return any( + stage == CompassSdkStage.Metadata + for error in self.errors + for stage, _ in error.items() + ) + + def has_indexing_errors(self) -> bool: + return any( + stage == CompassSdkStage.Indexing + for error in self.errors + for stage, _ in error.items() + ) + + @property + def status(self) -> CompassDocumentStatus: + if self.has_parsing_errors(): + return CompassDocumentStatus.ParsingErrors + + if not self.ignore_metadata_errors and self.has_metadata_errors(): + return CompassDocumentStatus.MetadataErrors + + if self.has_indexing_errors(): + return CompassDocumentStatus.IndexingErrors + + return CompassDocumentStatus.Success + + +class DocumentChunkAsset(BaseModel): + content_type: str + asset_data: str + + +class Chunk(BaseModel): + chunk_id: str + sort_id: int + content: Dict[str, Any] + origin: Optional[Dict[str, Any]] = None + assets: Optional[list[DocumentChunkAsset]] = None + parent_doc_id: str + + +class Document(BaseModel): + """ + A document that can be indexed in Compass (i.e., a list of indexable chunks) + """ + + doc_id: str + path: str + parent_doc_id: str + content: Dict[str, Any] + chunks: List[Chunk] + index_fields: List[str] = field(default_factory=list) + + +class ParseableDocument(BaseModel): + """ + A document to be sent to Compass in bytes format for parsing on the Compass side + """ + + id: uuid.UUID + filename: Annotated[ + str, StringConstraints(min_length=1) + ] # Ensures the filename is a non-empty string + content_type: str + content_length_bytes: PositiveInt # File size must be a non-negative integer + content_encoded_bytes: str # Base64-encoded file contents + context: Dict[str, Any] = Field(default_factory=dict) + + +class PushDocumentsInput(BaseModel): + documents: List[ParseableDocument] + + +class PutDocumentsInput(BaseModel): + """ + A Compass request to put a list of Document + """ + + docs: List[Document] + authorized_groups: Optional[List[str]] = None + merge_groups_on_conflict: bool = False diff --git a/compass_sdk/types.py b/compass_sdk/models/rbac.py similarity index 97% rename from compass_sdk/types.py rename to compass_sdk/models/rbac.py index d5b1212..4375f06 100644 --- a/compass_sdk/types.py +++ b/compass_sdk/models/rbac.py @@ -1,6 +1,8 @@ +# Python imports from enum import Enum from typing import List +# 3rd party imports from pydantic import BaseModel diff --git a/compass_sdk/models/search.py b/compass_sdk/models/search.py index 7e12753..7453914 100644 --- a/compass_sdk/models/search.py +++ b/compass_sdk/models/search.py @@ -1,5 +1,8 @@ +# Python imports +from enum import Enum from typing import Any, Dict, List, Optional, TypeAlias +# 3rd party imports from pydantic import BaseModel Content: TypeAlias = Dict[str, Any] @@ -10,7 +13,7 @@ class AssetInfo(BaseModel): presigned_url: str -class Chunk(BaseModel): +class RetrievedChunk(BaseModel): chunk_id: str sort_id: int parent_doc_id: str @@ -20,26 +23,48 @@ class Chunk(BaseModel): score: float -class Document(BaseModel): +class RetrievedDocument(BaseModel): doc_id: str path: str parent_doc_id: str content: Content index_fields: Optional[List[str]] = None authorized_groups: Optional[List[str]] = None - chunks: List[Chunk] + chunks: List[RetrievedChunk] score: float -class ChunkExtended(Chunk): +class RetrieveChunkExtended(RetrievedChunk): doc_id: str path: str index_fields: Optional[List[str]] = None class SearchDocumentsResponse(BaseModel): - hits: List[Document] + hits: List[RetrievedDocument] class SearchChunksResponse(BaseModel): - hits: List[ChunkExtended] + hits: List[RetrieveChunkExtended] + + +class SearchFilter(BaseModel): + class FilterType(str, Enum): + EQ = "$eq" + LT_EQ = "$lte" + GT_EQ = "$gte" + WORD_MATCH = "$wordMatch" + + field: str + type: FilterType + value: Any + + +class SearchInput(BaseModel): + """ + Search query input + """ + + query: str + top_k: int + filters: Optional[List[SearchFilter]] = None diff --git a/compass_sdk/utils.py b/compass_sdk/utils.py index ab2e150..b11d1e2 100644 --- a/compass_sdk/utils.py +++ b/compass_sdk/utils.py @@ -9,21 +9,29 @@ import fsspec from fsspec import AbstractFileSystem -from compass_sdk import CompassDocument, CompassDocumentMetadata, CompassSdkStage from compass_sdk.constants import UUID_NAMESPACE +from compass_sdk.models import ( + CompassDocument, + CompassDocumentMetadata, + CompassSdkStage, +) T = TypeVar("T") U = TypeVar("U") -def imap_queued(executor: Executor, f: Callable[[T], U], it: Iterable[T], max_queued: int) -> Iterator[U]: +def imap_queued( + executor: Executor, f: Callable[[T], U], it: Iterable[T], max_queued: int +) -> Iterator[U]: assert max_queued >= 1 futures_set = set() for x in it: futures_set.add(executor.submit(f, x)) while len(futures_set) > max_queued: - done, futures_set = futures.wait(futures_set, return_when=futures.FIRST_COMPLETED) + done, futures_set = futures.wait( + futures_set, return_when=futures.FIRST_COMPLETED + ) for future in done: yield future.result() @@ -65,7 +73,11 @@ def open_document(document_path) -> CompassDocument: return doc -def scan_folder(folder_path: str, allowed_extensions: Optional[List[str]] = None, recursive: bool = False) -> List[str]: +def scan_folder( + folder_path: str, + allowed_extensions: Optional[List[str]] = None, + recursive: bool = False, +) -> List[str]: """ Scans a folder for files with the given extensions :param folder_path: the path to the folder @@ -75,12 +87,16 @@ def scan_folder(folder_path: str, allowed_extensions: Optional[List[str]] = None """ fs = get_fs(folder_path) all_files = [] - path_prepend = f"{folder_path.split('://')[0]}://" if folder_path.find("://") >= 0 else "" + path_prepend = ( + f"{folder_path.split('://')[0]}://" if folder_path.find("://") >= 0 else "" + ) if allowed_extensions is None: allowed_extensions = [""] else: - allowed_extensions = [f".{ext}" if not ext.startswith(".") else ext for ext in allowed_extensions] + allowed_extensions = [ + f".{ext}" if not ext.startswith(".") else ext for ext in allowed_extensions + ] for ext in allowed_extensions: rec_glob = "**/" if recursive else "" diff --git a/poetry.lock b/poetry.lock index a41aedd..6d2913a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -458,6 +458,33 @@ requests = ">=2.22,<3" [package.extras] fixture = ["fixtures"] +[[package]] +name = "ruff" +version = "0.8.1" +description = "An extremely fast Python linter and code formatter, written in Rust." +optional = false +python-versions = ">=3.7" +files = [ + {file = "ruff-0.8.1-py3-none-linux_armv6l.whl", hash = "sha256:fae0805bd514066f20309f6742f6ee7904a773eb9e6c17c45d6b1600ca65c9b5"}, + {file = "ruff-0.8.1-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:b8a4f7385c2285c30f34b200ca5511fcc865f17578383db154e098150ce0a087"}, + {file = "ruff-0.8.1-py3-none-macosx_11_0_arm64.whl", hash = "sha256:cd054486da0c53e41e0086e1730eb77d1f698154f910e0cd9e0d64274979a209"}, + {file = "ruff-0.8.1-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2029b8c22da147c50ae577e621a5bfbc5d1fed75d86af53643d7a7aee1d23871"}, + {file = "ruff-0.8.1-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:2666520828dee7dfc7e47ee4ea0d928f40de72056d929a7c5292d95071d881d1"}, + {file = "ruff-0.8.1-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:333c57013ef8c97a53892aa56042831c372e0bb1785ab7026187b7abd0135ad5"}, + {file = "ruff-0.8.1-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:288326162804f34088ac007139488dcb43de590a5ccfec3166396530b58fb89d"}, + {file = "ruff-0.8.1-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b12c39b9448632284561cbf4191aa1b005882acbc81900ffa9f9f471c8ff7e26"}, + {file = "ruff-0.8.1-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:364e6674450cbac8e998f7b30639040c99d81dfb5bbc6dfad69bc7a8f916b3d1"}, + {file = "ruff-0.8.1-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b22346f845fec132aa39cd29acb94451d030c10874408dbf776af3aaeb53284c"}, + {file = "ruff-0.8.1-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:b2f2f7a7e7648a2bfe6ead4e0a16745db956da0e3a231ad443d2a66a105c04fa"}, + {file = "ruff-0.8.1-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:adf314fc458374c25c5c4a4a9270c3e8a6a807b1bec018cfa2813d6546215540"}, + {file = "ruff-0.8.1-py3-none-musllinux_1_2_i686.whl", hash = "sha256:a885d68342a231b5ba4d30b8c6e1b1ee3a65cf37e3d29b3c74069cdf1ee1e3c9"}, + {file = "ruff-0.8.1-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:d2c16e3508c8cc73e96aa5127d0df8913d2290098f776416a4b157657bee44c5"}, + {file = "ruff-0.8.1-py3-none-win32.whl", hash = "sha256:93335cd7c0eaedb44882d75a7acb7df4b77cd7cd0d2255c93b28791716e81790"}, + {file = "ruff-0.8.1-py3-none-win_amd64.whl", hash = "sha256:2954cdbe8dfd8ab359d4a30cd971b589d335a44d444b6ca2cb3d1da21b75e4b6"}, + {file = "ruff-0.8.1-py3-none-win_arm64.whl", hash = "sha256:55873cc1a473e5ac129d15eccb3c008c096b94809d693fc7053f588b67822737"}, + {file = "ruff-0.8.1.tar.gz", hash = "sha256:3583db9a6450364ed5ca3f3b4225958b24f78178908d5c4bc0f46251ccca898f"}, +] + [[package]] name = "tenacity" version = "8.2.3" @@ -524,4 +551,4 @@ zstd = ["zstandard (>=0.18.0)"] [metadata] lock-version = "2.0" python-versions = ">=3.11,<3.12" -content-hash = "1aaa46b9193b35e99db7cfac1a63d38ed42dd46102ebbc2f6b1bd9da28d7c23a" +content-hash = "4d6d1d1a40b93e95d0019f6ee73a913f0e6cdee6df2c41bf23d48f8000a8a515" diff --git a/poetry.toml b/poetry.toml new file mode 100644 index 0000000..ab1033b --- /dev/null +++ b/poetry.toml @@ -0,0 +1,2 @@ +[virtualenvs] +in-project = true diff --git a/pyproject.toml b/pyproject.toml index 15964da..cf0d696 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "compass-sdk" -version = "0.5.1" +version = "0.6.0" authors = [] description = "Compass SDK" @@ -18,6 +18,7 @@ pytest = "^8.3.3" pytest-asyncio = "^0.24.0" pytest-mock = "^3.14.0" requests-mock = "^1.12.1" +ruff = "^0.8.1" [tool.pyright] typeCheckingMode = 'basic' @@ -26,3 +27,6 @@ reportMissingImports = false [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" + +[tool.ruff] +line-length = 88 diff --git a/tests/test_compass_client.py b/tests/test_compass_client.py index 0aef69b..22e6cc3 100644 --- a/tests/test_compass_client.py +++ b/tests/test_compass_client.py @@ -1,24 +1,34 @@ -from compass_sdk.compass import CompassClient, CompassDocument +from compass_sdk.clients import CompassClient +from compass_sdk.models import CompassDocument def test_delete_url_formatted_with_doc_and_index(requests_mock): compass = CompassClient(index_url="http://test.com") compass.delete_document(index_name="test_index", doc_id="test_id") - assert requests_mock.request_history[0].url == "http://test.com/api/v1/indexes/test_index/documents/test_id" + assert ( + requests_mock.request_history[0].url + == "http://test.com/api/v1/indexes/test_index/documents/test_id" + ) assert requests_mock.request_history[0].method == "DELETE" def test_create_index_formatted_with_index(requests_mock): compass = CompassClient(index_url="http://test.com") compass.create_index(index_name="test_index") - assert requests_mock.request_history[0].url == "http://test.com/api/v1/indexes/test_index" + assert ( + requests_mock.request_history[0].url + == "http://test.com/api/v1/indexes/test_index" + ) assert requests_mock.request_history[0].method == "PUT" def test_put_documents_payload_and_url_exist(requests_mock): compass = CompassClient(index_url="http://test.com") compass.insert_docs(index_name="test_index", docs=iter([CompassDocument()])) - assert requests_mock.request_history[0].url == "http://test.com/api/v1/indexes/test_index/documents" + assert ( + requests_mock.request_history[0].url + == "http://test.com/api/v1/indexes/test_index/documents" + ) assert requests_mock.request_history[0].method == "PUT" assert "docs" in requests_mock.request_history[0].json() @@ -26,7 +36,10 @@ def test_put_documents_payload_and_url_exist(requests_mock): def test_put_document_payload_and_url_exist(requests_mock): compass = CompassClient(index_url="http://test.com") compass.insert_doc(index_name="test_index", doc=CompassDocument()) - assert requests_mock.request_history[0].url == "http://test.com/api/v1/indexes/test_index/documents" + assert ( + requests_mock.request_history[0].url + == "http://test.com/api/v1/indexes/test_index/documents" + ) assert requests_mock.request_history[0].method == "PUT" assert "docs" in requests_mock.request_history[0].json() @@ -42,19 +55,27 @@ def test_get_documents_is_valid(requests_mock): compass = CompassClient(index_url="http://test.com") compass.get_document(index_name="test_index", doc_id="test_id") assert requests_mock.request_history[0].method == "GET" - assert requests_mock.request_history[0].url == "http://test.com/api/v1/indexes/test_index/documents/test_id" + assert ( + requests_mock.request_history[0].url + == "http://test.com/api/v1/indexes/test_index/documents/test_id" + ) def test_refresh_is_valid(requests_mock): compass = CompassClient(index_url="http://test.com") compass.refresh(index_name="test_index") assert requests_mock.request_history[0].method == "POST" - assert requests_mock.request_history[0].url == "http://test.com/api/v1/indexes/test_index/_refresh" + assert ( + requests_mock.request_history[0].url + == "http://test.com/api/v1/indexes/test_index/_refresh" + ) def test_add_context_is_valid(requests_mock): compass = CompassClient(index_url="http://test.com") - compass.add_context(index_name="test_index", doc_id="test_id", context={"fake": "context"}) + compass.add_context( + index_name="test_index", doc_id="test_id", context={"fake": "context"} + ) assert requests_mock.request_history[0].method == "POST" assert ( requests_mock.request_history[0].url