From 1ad8faa01dd088183be74d724084b3ebff2e9c75 Mon Sep 17 00:00:00 2001 From: Abhinav Kumar <32016468+abhinavk454@users.noreply.github.com> Date: Mon, 17 Jun 2024 09:41:33 +0000 Subject: [PATCH 01/16] Port pydantic v1 models to pydantic v2 --- .../query_controllers/example/types.py | 8 +++-- .../query_controllers/multimodal/types.py | 8 +++-- backend/requirements.txt | 2 +- backend/settings.py | 2 +- backend/types.py | 29 +++++++++---------- 5 files changed, 26 insertions(+), 23 deletions(-) diff --git a/backend/modules/query_controllers/example/types.py b/backend/modules/query_controllers/example/types.py index f922c471..9ba4051a 100644 --- a/backend/modules/query_controllers/example/types.py +++ b/backend/modules/query_controllers/example/types.py @@ -1,6 +1,6 @@ from typing import Any, ClassVar, Collection, Dict, Literal, Optional -from pydantic import BaseModel, Field, root_validator, validator +from pydantic import field_validator, model_validator, BaseModel, Field, root_validator from qdrant_client.models import Filter as QdrantFilter from backend.types import LLMConfig @@ -87,7 +87,8 @@ class ContextualCompressionRetrieverConfig(VectorStoreRetrieverConfig): allowed_compressor_model_providers: ClassVar[Collection[str]] = ("mixbread-ai",) - @validator("compressor_model_provider") + @field_validator("compressor_model_provider") + @classmethod def validate_retriever_type(cls, value) -> Dict: assert ( value in cls.allowed_compressor_model_providers @@ -142,7 +143,8 @@ class ExampleQueryInput(BaseModel): stream: Optional[bool] = Field(title="Stream the results", default=False) - @root_validator() + @model_validator() + @classmethod def validate_retriever_type(cls, values: Dict) -> Dict: retriever_name = values.get("retriever_name") diff --git a/backend/modules/query_controllers/multimodal/types.py b/backend/modules/query_controllers/multimodal/types.py index f922c471..9ba4051a 100644 --- a/backend/modules/query_controllers/multimodal/types.py +++ b/backend/modules/query_controllers/multimodal/types.py @@ -1,6 +1,6 @@ from typing import Any, ClassVar, Collection, Dict, Literal, Optional -from pydantic import BaseModel, Field, root_validator, validator +from pydantic import field_validator, model_validator, BaseModel, Field, root_validator from qdrant_client.models import Filter as QdrantFilter from backend.types import LLMConfig @@ -87,7 +87,8 @@ class ContextualCompressionRetrieverConfig(VectorStoreRetrieverConfig): allowed_compressor_model_providers: ClassVar[Collection[str]] = ("mixbread-ai",) - @validator("compressor_model_provider") + @field_validator("compressor_model_provider") + @classmethod def validate_retriever_type(cls, value) -> Dict: assert ( value in cls.allowed_compressor_model_providers @@ -142,7 +143,8 @@ class ExampleQueryInput(BaseModel): stream: Optional[bool] = Field(title="Stream the results", default=False) - @root_validator() + @model_validator() + @classmethod def validate_retriever_type(cls, values: Dict) -> Dict: retriever_name = values.get("retriever_name") diff --git a/backend/requirements.txt b/backend/requirements.txt index 4c62a9ea..6a3b35c3 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -9,7 +9,7 @@ uvicorn==0.23.2 fastapi==0.109.1 qdrant-client==1.9.0 python-dotenv==1.0.1 -pydantic==1.10.13 +pydantic==2.7.4 orjson==3.9.15 PyMuPDF==1.23.6 redis==5.0.1 diff --git a/backend/settings.py b/backend/settings.py index 2d4ec227..8046b74c 100644 --- a/backend/settings.py +++ b/backend/settings.py @@ -2,9 +2,9 @@ from typing import Optional 
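# NOTE: in Pydantic v2, BaseSettings lives in the separate `pydantic-settings`
# package rather than in `pydantic` itself, which is what the import swap in
# this hunk is doing. A minimal sketch of the v2-style import and usage
# (illustrative field name, not taken from this repository):
#
#     from pydantic_settings import BaseSettings
#
#     class MySettings(BaseSettings):
#         tfy_host: str = ""  # still resolved from the TFY_HOST environment variable
#
#     settings = MySettings()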
import orjson -from pydantic import BaseSettings from backend.types import EmbeddingCacheConfig, MetadataStoreConfig, VectorDBConfig +from pydantic_settings import BaseSettings class Settings(BaseSettings): diff --git a/backend/types.py b/backend/types.py index a14ddf32..3a810b10 100644 --- a/backend/types.py +++ b/backend/types.py @@ -3,9 +3,10 @@ from enum import Enum from typing import Any, Dict, List, Literal, Optional -from pydantic import BaseModel, Field, constr, root_validator +from pydantic import StringConstraints, ConfigDict, BaseModel, Field, root_validator from backend.constants import FQN_SEPARATOR +from typing_extensions import Annotated class DataIngestionMode(str, Enum): @@ -42,7 +43,7 @@ class DataPoint(BaseModel): ) metadata: Optional[Dict[str, str]] = Field( - title="Additional metadata for the data point", + None, title="Additional metadata for the data point", ) @property @@ -83,10 +84,10 @@ class LoadedDataPoint(DataPoint): title="Local file path of the loaded data point", ) file_extension: Optional[str] = Field( - title="File extension of the loaded data point", + None, title="File extension of the loaded data point", ) local_metadata_file_path: Optional[str] = Field( - title="Local file path of the metadata file", + None, title="Local file path of the metadata file", ) @@ -142,9 +143,7 @@ class QdrantClientConfig(BaseModel): """ Qdrant extra configuration """ - - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") port: Optional[int] = None grpc_port: int = 6334 @@ -276,7 +275,7 @@ class DataIngestionRun(BaseDataIngestionRun): title="Name of the data ingestion run", ) status: Optional[DataIngestionRunStatus] = Field( - title="Status of the data ingestion run", + None, title="Status of the data ingestion run", ) @@ -292,7 +291,7 @@ class BaseDataSource(BaseModel): title="A unique identifier for the data source", ) metadata: Optional[Dict[str, Any]] = Field( - title="Additional config for your data source" + None, title="Additional config for your data source" ) @property @@ -325,7 +324,7 @@ class AssociatedDataSources(BaseModel): title="Parser configuration for the data transformation", default_factory=dict ) data_source: Optional[DataSource] = Field( - title="Data source associated with the collection" + None, title="Data source associated with the collection" ) @@ -339,7 +338,7 @@ class IngestDataToCollectionDto(BaseModel): ) data_source_fqn: Optional[str] = Field( - title="Fully qualified name of the data source", + None, title="Fully qualified name of the data source", ) data_ingestion_mode: DataIngestionMode = Field( @@ -410,12 +409,12 @@ class BaseCollection(BaseModel): Base collection configuration """ - name: constr(regex=r"^[a-z][a-z0-9]*$") = Field( # type: ignore + name: Annotated[str, StringConstraints(pattern=r"^[a-z][a-z0-9]*$")] = Field( # type: ignore title="a unique name to your collection", description="Should only contain lowercase alphanumeric character", ) description: Optional[str] = Field( - title="a description for your collection", + None, title="a description for your collection", ) embedder_config: EmbedderConfig = Field( title="Embedder configuration", default_factory=dict @@ -434,7 +433,7 @@ class Collection(BaseCollection): class CreateCollectionDto(CreateCollection): associated_data_sources: Optional[List[AssociateDataSourceWithCollection]] = Field( - title="Data sources associated with the collection" + None, title="Data sources associated with the collection" ) @@ -443,7 +442,7 @@ class 
UploadToDataDirectoryDto(BaseModel): # allow only small case alphanumeric and hyphen, should contain atleast one alphabet and begin with alphabet upload_name: str = Field( title="Name of the upload", - regex=r"^[a-z][a-z0-9-]*$", + pattern=r"^[a-z][a-z0-9-]*$", default=str(uuid.uuid4()), ) From d2398a5b16a68d65ef92a0dd98548bd1dfa60636 Mon Sep 17 00:00:00 2001 From: Abhinav Kumar <32016468+abhinavk454@users.noreply.github.com> Date: Mon, 17 Jun 2024 11:08:45 +0000 Subject: [PATCH 02/16] fad --- backend/modules/query_controllers/example/types.py | 6 +++--- backend/modules/query_controllers/multimodal/types.py | 6 +++--- backend/requirements.txt | 1 + backend/settings.py | 4 ++-- backend/types.py | 6 +++--- 5 files changed, 12 insertions(+), 11 deletions(-) diff --git a/backend/modules/query_controllers/example/types.py b/backend/modules/query_controllers/example/types.py index 9ba4051a..7f242ebc 100644 --- a/backend/modules/query_controllers/example/types.py +++ b/backend/modules/query_controllers/example/types.py @@ -1,6 +1,5 @@ from typing import Any, ClassVar, Collection, Dict, Literal, Optional - -from pydantic import field_validator, model_validator, BaseModel, Field, root_validator +from pydantic import field_validator, model_validator, BaseModel, Field from qdrant_client.models import Filter as QdrantFilter from backend.types import LLMConfig @@ -35,7 +34,8 @@ class VectorStoreRetrieverConfig(BaseModel): "mmr", ) - @root_validator + @model_validator(mode="before") + @classmethod def validate_search_type(cls, values: Dict) -> Dict: """Validate search type.""" search_type = values.get("search_type") diff --git a/backend/modules/query_controllers/multimodal/types.py b/backend/modules/query_controllers/multimodal/types.py index 9ba4051a..7a456eaa 100644 --- a/backend/modules/query_controllers/multimodal/types.py +++ b/backend/modules/query_controllers/multimodal/types.py @@ -1,6 +1,5 @@ from typing import Any, ClassVar, Collection, Dict, Literal, Optional - -from pydantic import field_validator, model_validator, BaseModel, Field, root_validator +from pydantic import field_validator, model_validator, BaseModel, Field from qdrant_client.models import Filter as QdrantFilter from backend.types import LLMConfig @@ -35,7 +34,8 @@ class VectorStoreRetrieverConfig(BaseModel): "mmr", ) - @root_validator + @model_validator() + @classmethod def validate_search_type(cls, values: Dict) -> Dict: """Validate search type.""" search_type = values.get("search_type") diff --git a/backend/requirements.txt b/backend/requirements.txt index 6a3b35c3..64126f4f 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -10,6 +10,7 @@ fastapi==0.109.1 qdrant-client==1.9.0 python-dotenv==1.0.1 pydantic==2.7.4 +pydantic-settings==2.3.3 orjson==3.9.15 PyMuPDF==1.23.6 redis==5.0.1 diff --git a/backend/settings.py b/backend/settings.py index 8046b74c..2965e291 100644 --- a/backend/settings.py +++ b/backend/settings.py @@ -55,11 +55,11 @@ class Settings(BaseSettings): TFY_LLM_GATEWAY_URL = f"{TFY_HOST}/api/llm" try: - VECTOR_DB_CONFIG = VectorDBConfig.parse_obj(orjson.loads(VECTOR_DB_CONFIG)) + VECTOR_DB_CONFIG = VectorDBConfig.model_validate(orjson.loads(VECTOR_DB_CONFIG)) except Exception as e: raise ValueError(f"VECTOR_DB_CONFIG is invalid: {e}") try: - METADATA_STORE_CONFIG = MetadataStoreConfig.parse_obj( + METADATA_STORE_CONFIG = MetadataStoreConfig.model_validate( orjson.loads(METADATA_STORE_CONFIG) ) except Exception as e: diff --git a/backend/types.py b/backend/types.py index 3a810b10..978593be 
100644 --- a/backend/types.py +++ b/backend/types.py @@ -2,8 +2,7 @@ import uuid from enum import Enum from typing import Any, Dict, List, Literal, Optional - -from pydantic import StringConstraints, ConfigDict, BaseModel, Field, root_validator +from pydantic import StringConstraints, ConfigDict, BaseModel, Field, model_validator from backend.constants import FQN_SEPARATOR from typing_extensions import Annotated @@ -298,7 +297,8 @@ class BaseDataSource(BaseModel): def fqn(self): return f"{FQN_SEPARATOR}".join([self.type, self.uri]) - @root_validator + @model_validator(mode="before") + @classmethod def validate_fqn(cls, values: Dict) -> Dict: values["fqn"] = f"{FQN_SEPARATOR}".join([values["type"], values["uri"]]) return values From 23b6924fa7ad9e98c3e2f0a78bfbfa219c4751c4 Mon Sep 17 00:00:00 2001 From: Abhinav Kumar <32016468+abhinavk454@users.noreply.github.com> Date: Mon, 17 Jun 2024 11:25:08 +0000 Subject: [PATCH 03/16] Port pydantic v1 models to pydantic v2 --- backend/settings.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/backend/settings.py b/backend/settings.py index 2965e291..6c2e91f9 100644 --- a/backend/settings.py +++ b/backend/settings.py @@ -1,5 +1,5 @@ import os -from typing import Optional +from typing import Optional,ClassVar import orjson @@ -26,14 +26,14 @@ class Settings(BaseSettings): VECTOR_DB_CONFIG = os.getenv("VECTOR_DB_CONFIG", "") METADATA_STORE_CONFIG = os.getenv("METADATA_STORE_CONFIG", "") TFY_SERVICE_ROOT_PATH = os.getenv("TFY_SERVICE_ROOT_PATH", "") - JOB_FQN = os.getenv("JOB_FQN", "") - JOB_COMPONENT_NAME = os.getenv("JOB_COMPONENT_NAME", "") + JOB_FQN:ClassVar[str] = os.getenv("JOB_FQN", "") + JOB_COMPONENT_NAME:ClassVar[str] = os.getenv("JOB_COMPONENT_NAME", "") TFY_API_KEY = os.getenv("TFY_API_KEY", "") OPENAI_API_KEY = os.getenv("OPENAI_API_KEY", "") TFY_HOST = os.getenv("TFY_HOST", "") TFY_LLM_GATEWAY_URL = os.getenv("TFY_LLM_GATEWAY_URL", "") EMBEDDING_CACHE_CONFIG = ( - EmbeddingCacheConfig.parse_obj( + EmbeddingCacheConfig.model_validate( orjson.loads(os.getenv("EMBEDDING_CACHE_CONFIG")) ) if os.getenv("EMBEDDING_CACHE_CONFIG", None) From d7af9837df311535c3c56ebf9026de1655a45003 Mon Sep 17 00:00:00 2001 From: Abhinav Kumar <32016468+abhinavk454@users.noreply.github.com> Date: Mon, 17 Jun 2024 17:15:10 +0000 Subject: [PATCH 04/16] Port pydantic v1 models to pydantic v2 --- backend/modules/query_controllers/example/types.py | 2 +- backend/modules/query_controllers/multimodal/types.py | 4 ++-- backend/server/decorators.py | 5 ++--- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/backend/modules/query_controllers/example/types.py b/backend/modules/query_controllers/example/types.py index 7f242ebc..8be5b234 100644 --- a/backend/modules/query_controllers/example/types.py +++ b/backend/modules/query_controllers/example/types.py @@ -143,7 +143,7 @@ class ExampleQueryInput(BaseModel): stream: Optional[bool] = Field(title="Stream the results", default=False) - @model_validator() + @model_validator(mode="before") @classmethod def validate_retriever_type(cls, values: Dict) -> Dict: retriever_name = values.get("retriever_name") diff --git a/backend/modules/query_controllers/multimodal/types.py b/backend/modules/query_controllers/multimodal/types.py index 7a456eaa..8be5b234 100644 --- a/backend/modules/query_controllers/multimodal/types.py +++ b/backend/modules/query_controllers/multimodal/types.py @@ -34,7 +34,7 @@ class VectorStoreRetrieverConfig(BaseModel): "mmr", ) - @model_validator() + 
@model_validator(mode="before") @classmethod def validate_search_type(cls, values: Dict) -> Dict: """Validate search type.""" @@ -143,7 +143,7 @@ class ExampleQueryInput(BaseModel): stream: Optional[bool] = Field(title="Stream the results", default=False) - @model_validator() + @model_validator(mode="before") @classmethod def validate_retriever_type(cls, values: Dict) -> Dict: retriever_name = values.get("retriever_name") diff --git a/backend/server/decorators.py b/backend/server/decorators.py index 2d8ca129..6b162930 100644 --- a/backend/server/decorators.py +++ b/backend/server/decorators.py @@ -4,10 +4,9 @@ """ import inspect -from typing import Any, Callable, List, Type, TypeVar, Union, get_type_hints +from typing import Any, Callable, List, Type, TypeVar, Union, get_type_hints, ClassVar from fastapi import APIRouter, Depends -from pydantic.typing import is_classvar from starlette.routing import Route, WebSocketRoute T = TypeVar("T") @@ -59,7 +58,7 @@ def _init_cbv(cls: Type[Any]) -> None: ] dependency_names: List[str] = [] for name, hint in get_type_hints(cls).items(): - if is_classvar(hint): + if getattr(hint, "__origin__", None) is ClassVar: continue parameter_kwargs = {"default": getattr(cls, name, Ellipsis)} dependency_names.append(name) From 5c2406c2aa91206dfc888c5a7bf65b8764aff228 Mon Sep 17 00:00:00 2001 From: Chirag Jain Date: Tue, 25 Jun 2024 04:15:50 +0530 Subject: [PATCH 05/16] update default values, type annotations, validators --- backend/indexer/indexer.py | 2 +- backend/migration/qdrant_migration.py | 2 +- backend/modules/metadata_store/prismastore.py | 16 +- backend/modules/metadata_store/truefoundry.py | 32 ++-- .../modules/model_gateway/model_gateway.py | 2 +- .../query_controllers/example/controller.py | 4 +- .../query_controllers/example/types.py | 140 ++++++++++-------- .../multimodal/controller.py | 2 +- .../query_controllers/multimodal/types.py | 140 ++++++++++-------- backend/modules/vector_db/qdrant.py | 6 +- backend/server/decorators.py | 9 +- backend/server/routers/collection.py | 14 +- backend/server/routers/data_source.py | 4 +- backend/server/routers/internal.py | 4 +- backend/types.py | 48 +++--- 15 files changed, 233 insertions(+), 192 deletions(-) diff --git a/backend/indexer/indexer.py b/backend/indexer/indexer.py index 0af0c357..12b9563c 100644 --- a/backend/indexer/indexer.py +++ b/backend/indexer/indexer.py @@ -310,7 +310,7 @@ async def ingest_data(request: IngestDataToCollectionDto): # convert to pydantic model if not already -> For prisma models if not isinstance(collection, Collection): - collection = Collection(**collection.dict()) + collection = Collection(**collection.model_dump()) if not collection: logger.error( diff --git a/backend/migration/qdrant_migration.py b/backend/migration/qdrant_migration.py index 1f535ea2..0b653a83 100644 --- a/backend/migration/qdrant_migration.py +++ b/backend/migration/qdrant_migration.py @@ -90,7 +90,7 @@ def migrate_collection( "associated_data_sources" ).items() ], - ).dict() + ).model_dump() logger.debug( f"Creating '{dest_collection.get('name')}' collection at destination" diff --git a/backend/modules/metadata_store/prismastore.py b/backend/modules/metadata_store/prismastore.py index f5305eb2..cb5b039d 100644 --- a/backend/modules/metadata_store/prismastore.py +++ b/backend/modules/metadata_store/prismastore.py @@ -61,8 +61,8 @@ async def acreate_collection(self, collection: CreateCollection) -> Collection: ) try: - logger.info(f"Creating collection: {collection.dict()}") - collection_data = 
collection.dict() + logger.info(f"Creating collection: {collection.model_dump()}") + collection_data = collection.model_dump() collection_data["embedder_config"] = json.dumps( collection_data["embedder_config"] ) @@ -143,7 +143,7 @@ async def acreate_data_source(self, data_source: CreateDataSource) -> DataSource ) try: - data = data_source.dict() + data = data_source.model_dump() data["metadata"] = json.dumps(data["metadata"]) data_source = await self.db.datasource.create(data) logger.info(f"Created data source: {data_source}") @@ -235,7 +235,7 @@ async def aassociate_data_source_with_collection( data_source_fqn, data_source, ) in existing_collection_associated_data_sources.items(): - associated_data_sources[data_source_fqn] = data_source.dict() + associated_data_sources[data_source_fqn] = data_source.model_dump() updated_collection = await self.db.collection.update( where={"name": collection_name}, @@ -316,7 +316,7 @@ async def alist_data_sources( ) -> List[dict[str, str]]: try: data_sources = await self.aget_data_sources() - return [data_source.dict() for data_source in data_sources] + return [data_source.model_dump() for data_source in data_sources] except Exception as e: logger.error(f"Failed to list data sources: {e}") raise HTTPException(status_code=500, detail="Failed to list data sources") @@ -409,10 +409,10 @@ async def acreate_data_ingestion_run( ) try: - run_data = created_data_ingestion_run.dict() + run_data = created_data_ingestion_run.model_dump() run_data["parser_config"] = json.dumps(run_data["parser_config"]) data_ingestion_run = await self.db.ingestionruns.create(data=run_data) - return DataIngestionRun(**data_ingestion_run.dict()) + return DataIngestionRun(**data_ingestion_run.model_dump()) except Exception as e: logger.error(f"Failed to create data ingestion run: {e}") raise HTTPException( @@ -428,7 +428,7 @@ async def aget_data_ingestion_run( ) logger.info(f"Data ingestion run: {data_ingestion_run}") if data_ingestion_run: - return DataIngestionRun(**data_ingestion_run.dict()) + return DataIngestionRun(**data_ingestion_run.model_dump()) return None except Exception as e: logger.error(f"Failed to get data ingestion run: {e}") diff --git a/backend/modules/metadata_store/truefoundry.py b/backend/modules/metadata_store/truefoundry.py index e08acd40..a57eda88 100644 --- a/backend/modules/metadata_store/truefoundry.py +++ b/backend/modules/metadata_store/truefoundry.py @@ -110,7 +110,7 @@ def create_collection(self, collection: CreateCollection) -> Collection: embedder_config=collection.embedder_config, ) self._save_entity_to_run( - run=run, metadata=created_collection.dict(), params=params + run=run, metadata=created_collection.model_dump(), params=params ) run.end() logger.debug(f"[Metadata Store] Collection Saved") @@ -184,7 +184,7 @@ def get_collection_by_name( ) return None collection = self._populate_collection( - Collection.parse_obj(self._get_entity_from_run(run=ml_run)) + Collection.model_validate(self._get_entity_from_run(run=ml_run)) ) logger.debug(f"[Metadata Store] Fetched collection with name {collection_name}") return collection @@ -198,7 +198,9 @@ def get_collections(self) -> List[Collection]: ) collections = [] for ml_run in ml_runs: - collection = Collection.parse_obj(self._get_entity_from_run(run=ml_run)) + collection = Collection.model_validate( + self._get_entity_from_run(run=ml_run) + ) collections.append(self._populate_collection(collection)) logger.debug(f"[Metadata Store] Listed {len(collections)} collections") return collections @@ -244,7 +246,9 @@ 
def associate_data_source_with_collection( f"data source with fqn {data_source_association.data_source_fqn} not found", ) # Always do this to avoid race conditions - collection = Collection.parse_obj(self._get_entity_from_run(run=collection_run)) + collection = Collection.model_validate( + self._get_entity_from_run(run=collection_run) + ) associated_data_source = AssociatedDataSources( data_source_fqn=data_source_association.data_source_fqn, parser_config=data_source_association.parser_config, @@ -253,7 +257,7 @@ def associate_data_source_with_collection( data_source_association.data_source_fqn ] = associated_data_source - self._update_entity_in_run(run=collection_run, metadata=collection.dict()) + self._update_entity_in_run(run=collection_run, metadata=collection.model_dump()) logger.debug( f"[Metadata Store] Associated data_source {data_source_association.data_source_fqn} " f"to collection {collection_name}" @@ -278,9 +282,11 @@ def unassociate_data_source_with_collection( f"Collection {collection_name} not found.", ) # Always do this to avoid run conditions - collection = Collection.parse_obj(self._get_entity_from_run(run=collection_run)) + collection = Collection.model_validate( + self._get_entity_from_run(run=collection_run) + ) collection.associated_data_sources.pop(data_source_fqn) - self._update_entity_in_run(run=collection_run, metadata=collection.dict()) + self._update_entity_in_run(run=collection_run, metadata=collection.model_dump()) logger.debug( f"[Metadata Store] Unassociated data_source {data_source_fqn} to collection {collection_name}" ) @@ -313,7 +319,7 @@ def create_data_source(self, data_source: CreateDataSource) -> DataSource: metadata=data_source.metadata, ) self._save_entity_to_run( - run=run, metadata=created_data_source.dict(), params=params + run=run, metadata=created_data_source.model_dump(), params=params ) run.end() logger.debug( @@ -328,7 +334,7 @@ def get_data_source_from_fqn(self, fqn: str) -> DataSource | None: filter_string=f"params.entity_type = '{MLRunTypes.DATA_SOURCE.value}' and params.data_source_fqn = '{fqn}'", ) for run in runs: - data_source = DataSource.parse_obj(self._get_entity_from_run(run=run)) + data_source = DataSource.model_validate(self._get_entity_from_run(run=run)) logger.debug(f"[Metadata Store] Fetched Data Source with fqn {fqn}") return data_source logger.debug(f"[Metadata Store] Data Source with fqn {fqn} not found") @@ -342,7 +348,7 @@ def get_data_sources(self) -> List[DataSource]: ) data_sources: List[DataSource] = [] for run in runs: - data_source = DataSource.parse_obj(self._get_entity_from_run(run=run)) + data_source = DataSource.model_validate(self._get_entity_from_run(run=run)) data_sources.append(data_source) logger.debug(f"[Metadata Store] Listed {len(data_sources)} data sources") return data_sources @@ -377,7 +383,7 @@ def create_data_ingestion_run( status=DataIngestionRunStatus.INITIALIZED, ) self._save_entity_to_run( - run=run, metadata=created_data_ingestion_run.dict(), params=params + run=run, metadata=created_data_ingestion_run.model_dump(), params=params ) run.end() logger.debug( @@ -398,7 +404,7 @@ def get_data_ingestion_run( f"[Metadata Store] Ingestion run with name {data_ingestion_run_name} not found" ) return None - data_ingestion_run = DataIngestionRun.parse_obj( + data_ingestion_run = DataIngestionRun.model_validate( self._get_entity_from_run(run=run) ) run_tags = run.get_tags() @@ -429,7 +435,7 @@ def get_data_ingestion_runs( ) data_ingestion_runs: List[DataIngestionRun] = [] for run in runs: - 
data_ingestion_run = DataIngestionRun.parse_obj( + data_ingestion_run = DataIngestionRun.model_validate( self._get_entity_from_run(run=run) ) run_tags = run.get_tags() diff --git a/backend/modules/model_gateway/model_gateway.py b/backend/modules/model_gateway/model_gateway.py index 65ec9717..b26b2b5b 100644 --- a/backend/modules/model_gateway/model_gateway.py +++ b/backend/modules/model_gateway/model_gateway.py @@ -24,7 +24,7 @@ def __init__(self): _providers = data.get("model_providers") or [] # parse the json data into a list of ModelProviderConfig objects self.provider_configs = [ - ModelProviderConfig.parse_obj(item) for item in _providers + ModelProviderConfig.model_validate(item) for item in _providers ] # load llm models diff --git a/backend/modules/query_controllers/example/controller.py b/backend/modules/query_controllers/example/controller.py index 5ef2b1df..7077bfed 100644 --- a/backend/modules/query_controllers/example/controller.py +++ b/backend/modules/query_controllers/example/controller.py @@ -87,7 +87,7 @@ async def _get_vector_store(self, collection_name: str): raise HTTPException(status_code=404, detail="Collection not found") if not isinstance(collection, Collection): - collection = Collection(**collection.dict()) + collection = Collection(**collection.model_dump()) return VECTOR_STORE_CLIENT.get_vector_store( collection_name=collection.name, @@ -324,7 +324,7 @@ async def answer( # "stream": True # } -# data = ExampleQueryInput(**payload).dict() +# data = ExampleQueryInput(**payload).model_dump() # ENDPOINT_URL = 'http://localhost:8000/retrievers/example-app/answer' diff --git a/backend/modules/query_controllers/example/types.py b/backend/modules/query_controllers/example/types.py index 2d6e4fc8..d73c8b03 100644 --- a/backend/modules/query_controllers/example/types.py +++ b/backend/modules/query_controllers/example/types.py @@ -1,8 +1,9 @@ -from typing import Any, ClassVar, Collection, Dict, Optional +from typing import Any, ClassVar, Optional, Sequence from pydantic import BaseModel, Field, field_validator, model_validator from qdrant_client.models import Filter as QdrantFilter +from backend.logger import logger from backend.types import ModelConfig GENERATION_TIMEOUT_SEC = 60.0 * 10 @@ -29,7 +30,7 @@ class VectorStoreRetrieverConfig(BaseModel): title="""Filter by document metadata""", ) - allowed_search_types: ClassVar[Collection[str]] = ( + allowed_search_types: ClassVar[Sequence[str]] = ( "similarity", "similarity_score_threshold", "mmr", @@ -37,33 +38,42 @@ class VectorStoreRetrieverConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_search_type(cls, values: Dict) -> Dict: + def validate_search_type(cls, values: Any) -> Any: """Validate search type.""" - search_type = values.get("search_type") + if isinstance(values, dict): + # TODO (chiragjn): Convert all asserts + search_type = values.get("search_type") - assert ( - search_type in cls.allowed_search_types - ), f"search_type of {search_type} not allowed. 
Valid values are: {cls.allowed_search_types}" - - search_kwargs = values.get("search_kwargs") - - if search_type == "similarity": - assert "k" in search_kwargs, "k is required for similarity search" - - elif search_type == "mmr": - assert "k" in search_kwargs, "k is required in search_kwargs for mmr search" - assert ( - "fetch_k" in search_kwargs - ), "fetch_k is required in search_kwargs for mmr search" - - elif search_type == "similarity_score_threshold": assert ( - "score_threshold" in search_kwargs - ), "score_threshold with a float value(0~1) is required in search_kwargs for similarity_score_threshold search" - - filters = values.get("filter") - if filters: - search_kwargs["filter"] = QdrantFilter.parse_obj(filters) + search_type in cls.allowed_search_types + ), f"search_type of {search_type} not allowed. Valid values are: {cls.allowed_search_types}" + + search_kwargs = values.get("search_kwargs") + + if search_type == "similarity": + assert "k" in search_kwargs, "k is required for similarity search" + + elif search_type == "mmr": + assert ( + "k" in search_kwargs + ), "k is required in search_kwargs for mmr search" + assert ( + "fetch_k" in search_kwargs + ), "fetch_k is required in search_kwargs for mmr search" + + elif search_type == "similarity_score_threshold": + assert ( + "score_threshold" in search_kwargs + ), "score_threshold with a float value(0~1) is required in search_kwargs for similarity_score_threshold search" + + filters = values.get("filter") + if filters: + search_kwargs["filter"] = QdrantFilter.model_validate(filters) + else: + logger.warning( + f"[Validation Skipped] Pydantic v2 validator received " + f"non dict values of type {type(values)}" + ) return values @@ -86,11 +96,11 @@ class ContextualCompressionRetrieverConfig(VectorStoreRetrieverConfig): title="Top K docs to collect post compression", ) - allowed_compressor_model_providers: ClassVar[Collection[str]] = ("mixedbread-ai",) + allowed_compressor_model_providers: ClassVar[Sequence[str]] = ("mixedbread-ai",) @field_validator("compressor_model_provider") @classmethod - def validate_retriever_type(cls, value) -> Dict: + def validate_retriever_type(cls, value: str) -> str: assert ( value in cls.allowed_compressor_model_providers ), f"Compressor model of {value} not allowed. Valid values are: {cls.allowed_compressor_model_providers}" @@ -120,6 +130,7 @@ class ExampleQueryInput(BaseModel): query: str = Field(title="Question to search for") + # TODO (chiragjn): Pydantic v2 does not like fields that begin with model_* model_configuration: ModelConfig prompt_template: str = Field( @@ -130,11 +141,11 @@ class ExampleQueryInput(BaseModel): title="Retriever name", ) - retriever_config: Dict[str, Any] = Field( + retriever_config: dict[str, Any] = Field( title="Retriever configuration", ) - allowed_retriever_types: ClassVar[Collection[str]] = ( + allowed_retriever_types: ClassVar[Sequence[str]] = ( "vectorstore", "multi-query", "contextual-compression", @@ -142,40 +153,47 @@ class ExampleQueryInput(BaseModel): "lord-of-the-retrievers", ) - stream: Optional[bool] = Field(title="Stream the results", default=False) + stream: bool = Field(default=False, title="Stream the results") @model_validator(mode="before") @classmethod - def validate_retriever_type(cls, values: Dict) -> Dict: - retriever_name = values.get("retriever_name") - - assert ( - retriever_name in cls.allowed_retriever_types - ), f"retriever of {retriever_name} not allowed. 
Valid values are: {cls.allowed_retriever_types}" + def validate_retriever_type(cls, values: Any) -> Any: + if isinstance(values, dict): + retriever_name = values.get("retriever_name") - if retriever_name == "vectorstore": - values["retriever_config"] = VectorStoreRetrieverConfig( - **values.get("retriever_config") - ) - - elif retriever_name == "multi-query": - values["retriever_config"] = MultiQueryRetrieverConfig( - **values.get("retriever_config") - ) - - elif retriever_name == "contextual-compression": - values["retriever_config"] = ContextualCompressionRetrieverConfig( - **values.get("retriever_config") - ) - - elif retriever_name == "contextual-compression-multi-query": - values["retriever_config"] = ContextualCompressionMultiQueryRetrieverConfig( - **values.get("retriever_config") - ) - - elif retriever_name == "lord-of-the-retrievers": - values["retriever_config"] = LordOfRetrievers( - **values.get("retriever_config") + assert ( + retriever_name in cls.allowed_retriever_types + ), f"retriever of {retriever_name} not allowed. Valid values are: {cls.allowed_retriever_types}" + + if retriever_name == "vectorstore": + values["retriever_config"] = VectorStoreRetrieverConfig( + **values.get("retriever_config") + ) + + elif retriever_name == "multi-query": + values["retriever_config"] = MultiQueryRetrieverConfig( + **values.get("retriever_config") + ) + + elif retriever_name == "contextual-compression": + values["retriever_config"] = ContextualCompressionRetrieverConfig( + **values.get("retriever_config") + ) + + elif retriever_name == "contextual-compression-multi-query": + values[ + "retriever_config" + ] = ContextualCompressionMultiQueryRetrieverConfig( + **values.get("retriever_config") + ) + + elif retriever_name == "lord-of-the-retrievers": + values["retriever_config"] = LordOfRetrievers( + **values.get("retriever_config") + ) + else: + logger.warning( + f"[Validation Skipped] Pydantic v2 validator received " + f"non dict values of type {type(values)}" ) - return values diff --git a/backend/modules/query_controllers/multimodal/controller.py b/backend/modules/query_controllers/multimodal/controller.py index b57bff65..b403c32e 100644 --- a/backend/modules/query_controllers/multimodal/controller.py +++ b/backend/modules/query_controllers/multimodal/controller.py @@ -100,7 +100,7 @@ async def _get_vector_store(self, collection_name: str): raise HTTPException(status_code=404, detail="Collection not found") if not isinstance(collection, Collection): - collection = Collection(**collection.dict()) + collection = Collection(**collection.model_dump()) return VECTOR_STORE_CLIENT.get_vector_store( collection_name=collection.name, diff --git a/backend/modules/query_controllers/multimodal/types.py b/backend/modules/query_controllers/multimodal/types.py index 2d6e4fc8..d73c8b03 100644 --- a/backend/modules/query_controllers/multimodal/types.py +++ b/backend/modules/query_controllers/multimodal/types.py @@ -1,8 +1,9 @@ -from typing import Any, ClassVar, Collection, Dict, Optional +from typing import Any, ClassVar, Optional, Sequence from pydantic import BaseModel, Field, field_validator, model_validator from qdrant_client.models import Filter as QdrantFilter +from backend.logger import logger from backend.types import ModelConfig GENERATION_TIMEOUT_SEC = 60.0 * 10 @@ -29,7 +30,7 @@ class VectorStoreRetrieverConfig(BaseModel): title="""Filter by document metadata""", ) - allowed_search_types: ClassVar[Collection[str]] = ( + allowed_search_types: ClassVar[Sequence[str]] = ( "similarity", 
"similarity_score_threshold", "mmr", @@ -37,33 +38,42 @@ class VectorStoreRetrieverConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_search_type(cls, values: Dict) -> Dict: + def validate_search_type(cls, values: Any) -> Any: """Validate search type.""" - search_type = values.get("search_type") + if isinstance(values, dict): + # TODO (chiragjn): Convert all asserts + search_type = values.get("search_type") - assert ( - search_type in cls.allowed_search_types - ), f"search_type of {search_type} not allowed. Valid values are: {cls.allowed_search_types}" - - search_kwargs = values.get("search_kwargs") - - if search_type == "similarity": - assert "k" in search_kwargs, "k is required for similarity search" - - elif search_type == "mmr": - assert "k" in search_kwargs, "k is required in search_kwargs for mmr search" - assert ( - "fetch_k" in search_kwargs - ), "fetch_k is required in search_kwargs for mmr search" - - elif search_type == "similarity_score_threshold": assert ( - "score_threshold" in search_kwargs - ), "score_threshold with a float value(0~1) is required in search_kwargs for similarity_score_threshold search" - - filters = values.get("filter") - if filters: - search_kwargs["filter"] = QdrantFilter.parse_obj(filters) + search_type in cls.allowed_search_types + ), f"search_type of {search_type} not allowed. Valid values are: {cls.allowed_search_types}" + + search_kwargs = values.get("search_kwargs") + + if search_type == "similarity": + assert "k" in search_kwargs, "k is required for similarity search" + + elif search_type == "mmr": + assert ( + "k" in search_kwargs + ), "k is required in search_kwargs for mmr search" + assert ( + "fetch_k" in search_kwargs + ), "fetch_k is required in search_kwargs for mmr search" + + elif search_type == "similarity_score_threshold": + assert ( + "score_threshold" in search_kwargs + ), "score_threshold with a float value(0~1) is required in search_kwargs for similarity_score_threshold search" + + filters = values.get("filter") + if filters: + search_kwargs["filter"] = QdrantFilter.model_validate(filters) + else: + logger.warning( + f"[Validation Skipped] Pydantic v2 validator received " + f"non dict values of type {type(values)}" + ) return values @@ -86,11 +96,11 @@ class ContextualCompressionRetrieverConfig(VectorStoreRetrieverConfig): title="Top K docs to collect post compression", ) - allowed_compressor_model_providers: ClassVar[Collection[str]] = ("mixedbread-ai",) + allowed_compressor_model_providers: ClassVar[Sequence[str]] = ("mixedbread-ai",) @field_validator("compressor_model_provider") @classmethod - def validate_retriever_type(cls, value) -> Dict: + def validate_retriever_type(cls, value: str) -> str: assert ( value in cls.allowed_compressor_model_providers ), f"Compressor model of {value} not allowed. 
Valid values are: {cls.allowed_compressor_model_providers}" @@ -120,6 +130,7 @@ class ExampleQueryInput(BaseModel): query: str = Field(title="Question to search for") + # TODO (chiragjn): Pydantic v2 does not like fields that begin with model_* model_configuration: ModelConfig prompt_template: str = Field( @@ -130,11 +141,11 @@ class ExampleQueryInput(BaseModel): title="Retriever name", ) - retriever_config: Dict[str, Any] = Field( + retriever_config: dict[str, Any] = Field( title="Retriever configuration", ) - allowed_retriever_types: ClassVar[Collection[str]] = ( + allowed_retriever_types: ClassVar[Sequence[str]] = ( "vectorstore", "multi-query", "contextual-compression", @@ -142,40 +153,47 @@ class ExampleQueryInput(BaseModel): "lord-of-the-retrievers", ) - stream: Optional[bool] = Field(title="Stream the results", default=False) + stream: bool = Field(default=False, title="Stream the results") @model_validator(mode="before") @classmethod - def validate_retriever_type(cls, values: Dict) -> Dict: - retriever_name = values.get("retriever_name") - - assert ( - retriever_name in cls.allowed_retriever_types - ), f"retriever of {retriever_name} not allowed. Valid values are: {cls.allowed_retriever_types}" + def validate_retriever_type(cls, values: Any) -> Any: + if isinstance(values, dict): + retriever_name = values.get("retriever_name") - if retriever_name == "vectorstore": - values["retriever_config"] = VectorStoreRetrieverConfig( - **values.get("retriever_config") - ) - - elif retriever_name == "multi-query": - values["retriever_config"] = MultiQueryRetrieverConfig( - **values.get("retriever_config") - ) - - elif retriever_name == "contextual-compression": - values["retriever_config"] = ContextualCompressionRetrieverConfig( - **values.get("retriever_config") - ) - - elif retriever_name == "contextual-compression-multi-query": - values["retriever_config"] = ContextualCompressionMultiQueryRetrieverConfig( - **values.get("retriever_config") - ) - - elif retriever_name == "lord-of-the-retrievers": - values["retriever_config"] = LordOfRetrievers( - **values.get("retriever_config") + assert ( + retriever_name in cls.allowed_retriever_types + ), f"retriever of {retriever_name} not allowed. 
Valid values are: {cls.allowed_retriever_types}" + + if retriever_name == "vectorstore": + values["retriever_config"] = VectorStoreRetrieverConfig( + **values.get("retriever_config") + ) + + elif retriever_name == "multi-query": + values["retriever_config"] = MultiQueryRetrieverConfig( + **values.get("retriever_config") + ) + + elif retriever_name == "contextual-compression": + values["retriever_config"] = ContextualCompressionRetrieverConfig( + **values.get("retriever_config") + ) + + elif retriever_name == "contextual-compression-multi-query": + values[ + "retriever_config" + ] = ContextualCompressionMultiQueryRetrieverConfig( + **values.get("retriever_config") + ) + + elif retriever_name == "lord-of-the-retrievers": + values["retriever_config"] = LordOfRetrievers( + **values.get("retriever_config") + ) + else: + logger.warning( + f"[Validation Skipped] Pydantic v2 validator received " + f"non dict values of type {type(values)}" ) - return values diff --git a/backend/modules/vector_db/qdrant.py b/backend/modules/vector_db/qdrant.py index 5ac47d04..341493e3 100644 --- a/backend/modules/vector_db/qdrant.py +++ b/backend/modules/vector_db/qdrant.py @@ -17,7 +17,7 @@ class QdrantVectorDB(BaseVectorDB): def __init__(self, config: VectorDBConfig): - logger.debug(f"Connecting to qdrant using config: {config.dict()}") + logger.debug(f"Connecting to qdrant using config: {config.model_dump()}") if config.local is True: # TODO: make this path customizable self.qdrant_client = QdrantClient( @@ -28,7 +28,7 @@ def __init__(self, config: VectorDBConfig): api_key = config.api_key if not api_key: api_key = None - qdrant_kwargs = QdrantClientConfig.parse_obj(config.config or {}) + qdrant_kwargs = QdrantClientConfig.model_validate(config.config or {}) if url.startswith("http://") or url.startswith("https://"): if qdrant_kwargs.port is None: parsed_port = urlparse(url).port @@ -37,7 +37,7 @@ def __init__(self, config: VectorDBConfig): else: qdrant_kwargs.port = 443 if url.startswith("https://") else 6333 self.qdrant_client = QdrantClient( - url=url, api_key=api_key, **qdrant_kwargs.dict() + url=url, api_key=api_key, **qdrant_kwargs.model_dump() ) def create_collection(self, collection_name: str, embeddings: Embeddings): diff --git a/backend/server/decorators.py b/backend/server/decorators.py index cef2eeaf..8f12dd81 100644 --- a/backend/server/decorators.py +++ b/backend/server/decorators.py @@ -58,6 +58,7 @@ def _init_cbv(cls: Type[Any]) -> None: ] dependency_names: List[str] = [] for name, hint in get_type_hints(cls).items(): + # TODO (chiragjn): Verify this if getattr(hint, "__origin__", None) is ClassVar: continue parameter_kwargs = {"default": getattr(cls, name, Ellipsis)} @@ -126,13 +127,13 @@ def wrapper(cls) -> ClassBasedView: for name, method in cls.__dict__.items(): if callable(method) and hasattr(method, "method"): # Check if method is decorated with an HTTP method decorator - assert ( - hasattr(method, "__path__") and method.__path__ - ), f"Missing path for method {name}" + if not hasattr(method, "__path__") or not method.__path__: + raise ValueError(f"Missing path for method {name}") http_method = method.method # Ensure that the method is a valid HTTP method - assert http_method in http_method_names, f"Invalid method {http_method}" + if http_method not in http_method_names: + raise ValueError(f"Invalid method {http_method}") if prefix: method.__path__ = prefix + method.__path__ if not method.__path__.startswith("/"): diff --git a/backend/server/routers/collection.py 
b/backend/server/routers/collection.py index 3d278755..3cd8684a 100644 --- a/backend/server/routers/collection.py +++ b/backend/server/routers/collection.py @@ -29,7 +29,7 @@ async def get_collections(): if collections is None: return JSONResponse(content={"collections": []}) return JSONResponse( - content={"collections": [obj.dict() for obj in collections]} + content={"collections": [obj.model_dump() for obj in collections]} ) except Exception as exp: logger.exception("Failed to get collection") @@ -55,7 +55,7 @@ async def get_collection_by_name(collection_name: str = Path(title="Collection n collection = await client.aget_collection_by_name(collection_name) if collection is None: return JSONResponse(content={"collection": []}) - return JSONResponse(content={"collection": collection.dict()}) + return JSONResponse(content={"collection": collection.model_dump()}) except HTTPException as exp: raise exp except Exception as exp: @@ -98,7 +98,7 @@ async def create_collection(collection: CreateCollectionDto): collection_name=created_collection.name ) return JSONResponse( - content={"collection": created_collection.dict()}, status_code=201 + content={"collection": created_collection.model_dump()}, status_code=201 ) except HTTPException as exp: raise exp @@ -121,7 +121,7 @@ async def associate_data_source_to_collection( parser_config=request.parser_config, ), ) - return JSONResponse(content={"collection": collection.dict()}) + return JSONResponse(content={"collection": collection.model_dump()}) except HTTPException as exp: raise exp except Exception as exp: @@ -140,7 +140,7 @@ async def unassociate_data_source_from_collection( collection_name=request.collection_name, data_source_fqn=request.data_source_fqn, ) - return JSONResponse(content={"collection": collection.dict()}) + return JSONResponse(content={"collection": collection.model_dump()}) except HTTPException as exp: raise exp except Exception as exp: @@ -182,7 +182,9 @@ async def list_data_ingestion_runs(request: ListDataIngestionRunsDto): request.collection_name, request.data_source_fqn ) return JSONResponse( - content={"data_ingestion_runs": [obj.dict() for obj in data_ingestion_runs]} + content={ + "data_ingestion_runs": [obj.model_dump() for obj in data_ingestion_runs] + } ) diff --git a/backend/server/routers/data_source.py b/backend/server/routers/data_source.py index ff40d686..ef821b5d 100644 --- a/backend/server/routers/data_source.py +++ b/backend/server/routers/data_source.py @@ -17,7 +17,7 @@ async def get_data_source(): client = await get_client() data_sources = await client.aget_data_sources() return JSONResponse( - content={"data_sources": [obj.dict() for obj in data_sources]} + content={"data_sources": [obj.model_dump() for obj in data_sources]} ) except Exception as exp: logger.exception("Failed to get data source") @@ -45,7 +45,7 @@ async def add_data_source( client = await get_client() created_data_source = await client.acreate_data_source(data_source=data_source) return JSONResponse( - content={"data_source": created_data_source.dict()}, status_code=201 + content={"data_source": created_data_source.model_dump()}, status_code=201 ) except HTTPException as exp: raise exp diff --git a/backend/server/routers/internal.py b/backend/server/routers/internal.py index 1a9534e5..e6d0dda0 100644 --- a/backend/server/routers/internal.py +++ b/backend/server/routers/internal.py @@ -88,7 +88,7 @@ async def upload_to_data_directory(req: UploadToDataDirectoryDto): paths=req.filepaths, ) - data = [url.dict() for url in urls] + data = 
[url.model_dump() for url in urls] return JSONResponse( content={"data": data, "data_directory_fqn": dataset.fqn}, ) @@ -111,7 +111,7 @@ def get_enabled_models( ) # Serialized models - serialized_models = [model.dict() for model in enabled_models] + serialized_models = [model.model_dump() for model in enabled_models] return JSONResponse( content={"models": serialized_models}, ) diff --git a/backend/types.py b/backend/types.py index 342b5e4d..97a228f4 100644 --- a/backend/types.py +++ b/backend/types.py @@ -1,9 +1,9 @@ import enum import uuid from enum import Enum -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Literal, Optional -from pydantic import BaseModel, ConfigDict, Field, StringConstraints, model_validator +from pydantic import BaseModel, ConfigDict, Field, StringConstraints from typing_extensions import Annotated from backend.constants import FQN_SEPARATOR @@ -27,7 +27,7 @@ class DataPoint(BaseModel): data_point_fqn (str): Fully qualified name for the data point with respect to the data source data_point_uri (str): URI for the data point for given data source. It could be url, file path or any other identifier data_point_hash (str): Hash of the data point for the given data source that is guaranteed to be updated for any update in data point at source - metadata (Optional[Dict[str, str]]): Additional metadata for the data point + metadata (Optional[dict[str, str]]): Additional metadata for the data point """ data_source_fqn: str = Field( @@ -42,7 +42,7 @@ class DataPoint(BaseModel): title="Hash of the data point for the given data source that is guaranteed to be updated for any update in data point at source", ) - metadata: Optional[Dict[str, str]] = Field( + metadata: Optional[dict[str, Any]] = Field( None, title="Additional metadata for the data point", ) @@ -105,17 +105,17 @@ class ModelType(str, Enum): class ModelConfig(BaseModel): name: str - type: Optional[ModelType] - parameters: Optional[Dict[str, Any]] = None + type: ModelType + parameters: dict[str, Any] = Field(default_factory=dict) class ModelProviderConfig(BaseModel): provider_name: str api_format: str - llm_model_ids: List[str] - embedding_model_ids: List[str] - api_key_env_var: str base_url: Optional[str] = None + api_key_env_var: str + llm_model_ids: list[str] = Field(default_factory=list) + embedding_model_ids: list[str] = Field(default_factory=list) class EmbedderConfig(BaseModel): @@ -123,8 +123,9 @@ class EmbedderConfig(BaseModel): Embedder configuration """ + # TODO (chiragjn): Pydantic v2 does not like fields that begin with model_* model_config: ModelConfig - config: Optional[Dict[str, Any]] = Field( + config: Optional[dict[str, Any]] = Field( title="Configuration for the embedder", default_factory=dict ) @@ -136,14 +137,14 @@ class ParserConfig(BaseModel): chunk_size: int = Field(title="Chunk Size for data parsing", ge=1, default=1000) chunk_overlap: int = Field(title="Chunk Overlap for indexing", ge=0, default=20) - parser_map: Dict[str, str] = Field( + parser_map: dict[str, str] = Field( title="Mapping of file extensions to parsers", default={ ".md": "MarkdownParser", ".pdf": "PdfParserFast", }, ) - additional_config: Optional[Dict[str, Any]] = Field( + additional_config: Optional[dict[str, Any]] = Field( title="Additional optional configuration for the parser", default_factory=dict, ) @@ -158,7 +159,7 @@ class VectorDBConfig(BaseModel): local: bool = False url: Optional[str] = None api_key: Optional[str] = None - config: Optional[dict] = None + config: Optional[dict] = 
Field(default_factory=dict) class QdrantClientConfig(BaseModel): @@ -191,7 +192,8 @@ class RetrieverConfig(BaseModel): search_type: Literal["mmr", "similarity"] = Field( default="similarity", - title="""Defines the type of search that the Retriever should perform. Can be "similarity" (default), "mmr", or "similarity_score_threshold".""", + title="""Defines the type of search that the Retriever should perform. \ + Can be "similarity" (default), "mmr", or "similarity_score_threshold".""", ) k: int = Field( default=4, @@ -260,7 +262,7 @@ class BaseDataIngestionRun(BaseModel): title="Data ingestion mode for the data ingestion", ) - raise_error_on_failure: Optional[bool] = Field( + raise_error_on_failure: bool = Field( title="Flag to configure weather to raise error on failure or not. Default is True", default=True, ) @@ -291,7 +293,7 @@ class BaseDataSource(BaseModel): uri: str = Field( title="A unique identifier for the data source", ) - metadata: Optional[Dict[str, Any]] = Field( + metadata: Optional[dict[str, Any]] = Field( None, title="Additional config for your data source" ) @@ -299,12 +301,6 @@ class BaseDataSource(BaseModel): def fqn(self): return f"{FQN_SEPARATOR}".join([self.type, self.uri]) - @model_validator(mode="before") - @classmethod - def validate_fqn(cls, values: Dict) -> Dict: - values["fqn"] = f"{FQN_SEPARATOR}".join([values["type"], values["uri"]]) - return values - class CreateDataSource(BaseDataSource): pass @@ -349,7 +345,7 @@ class IngestDataToCollectionDto(BaseModel): title="Data ingestion mode for the data ingestion", ) - raise_error_on_failure: Optional[bool] = Field( + raise_error_on_failure: bool = Field( title="Flag to configure weather to raise error on failure or not. Default is True", default=True, ) @@ -430,19 +426,19 @@ class CreateCollection(BaseCollection): class Collection(BaseCollection): - associated_data_sources: Dict[str, AssociatedDataSources] = Field( + associated_data_sources: dict[str, AssociatedDataSources] = Field( title="Data sources associated with the collection", default_factory=dict ) class CreateCollectionDto(CreateCollection): - associated_data_sources: Optional[List[AssociateDataSourceWithCollection]] = Field( + associated_data_sources: Optional[list[AssociateDataSourceWithCollection]] = Field( None, title="Data sources associated with the collection" ) class UploadToDataDirectoryDto(BaseModel): - filepaths: List[str] + filepaths: list[str] # allow only small case alphanumeric and hyphen, should contain at least one alphabet and begin with alphabet upload_name: str = Field( title="Name of the upload", From 563b20ade1d26335178cfe19b95a0b8a90caeeb9 Mon Sep 17 00:00:00 2001 From: Chirag Jain Date: Wed, 31 Jul 2024 02:41:55 +0530 Subject: [PATCH 06/16] Fixes on top of merging main --- backend/migration/utils.py | 7 +++---- backend/modules/metadata_store/base.py | 2 +- backend/modules/metadata_store/prismastore.py | 3 +-- backend/modules/model_gateway/model_gateway.py | 1 - backend/modules/model_gateway/reranker_svc.py | 3 +-- backend/modules/parsers/parser.py | 2 +- backend/modules/parsers/unstructured_io.py | 1 + backend/modules/query_controllers/example/types.py | 2 ++ backend/modules/query_controllers/multimodal/types.py | 2 ++ backend/requirements.txt | 4 ++-- backend/server/routers/rag_apps.py | 2 +- 11 files changed, 15 insertions(+), 14 deletions(-) diff --git a/backend/migration/utils.py b/backend/migration/utils.py index d0198915..afc15732 100644 --- a/backend/migration/utils.py +++ b/backend/migration/utils.py @@ -2,7 +2,6 @@ from 
typing import Dict import requests -from qdrant_client._pydantic_compat import to_dict from qdrant_client.client_base import QdrantBase from qdrant_client.http import models from tqdm import tqdm @@ -104,11 +103,11 @@ def _recreate_collection( replication_factor=src_config.params.replication_factor, write_consistency_factor=src_config.params.write_consistency_factor, on_disk_payload=src_config.params.on_disk_payload, - hnsw_config=models.HnswConfigDiff(**to_dict(src_config.hnsw_config)), + hnsw_config=models.HnswConfigDiff(**src_config.hnsw_config.model_dump()), optimizers_config=models.OptimizersConfigDiff( - **to_dict(src_config.optimizer_config) + **src_config.optimizer_config.model_dump() ), - wal_config=models.WalConfigDiff(**to_dict(src_config.wal_config)), + wal_config=models.WalConfigDiff(**src_config.wal_config.model_dump()), quantization_config=src_config.quantization_config, timeout=300, ) diff --git a/backend/modules/metadata_store/base.py b/backend/modules/metadata_store/base.py index 222e5207..6c4e34ef 100644 --- a/backend/modules/metadata_store/base.py +++ b/backend/modules/metadata_store/base.py @@ -18,7 +18,7 @@ from backend.utils import run_in_executor -# TODO(chiragjn): Ideal would be we make `async def a*` abstract methods and drop sync ones +# TODO (chiragjn): Ideal would be we make `async def a*` abstract methods and drop sync ones # Implementations can then opt to call their sync versions using run_in_executor class BaseMetadataStore(ABC): def __init__(self, *args, **kwargs): diff --git a/backend/modules/metadata_store/prismastore.py b/backend/modules/metadata_store/prismastore.py index 97a64375..ff323449 100644 --- a/backend/modules/metadata_store/prismastore.py +++ b/backend/modules/metadata_store/prismastore.py @@ -28,7 +28,7 @@ # TODO (chiragjn): # Either we make everything async or add sync method to this -# Some methods are using json.dumps - not sure if this is the right way to do it +# Some methods are using json.dumps - not sure if this is the right way to send data via prisma client # primsa generates its own DB entity classes - ideally we should be using those instead of call .model_dump() on the pydantic objects @@ -499,7 +499,6 @@ async def alog_errors_for_data_ingestion_run( # RAG APPLICATION APIS ###### - # TODO (prathamesh): Implement these methods async def acreate_rag_app(self, app: RagApplication) -> RagApplicationDto: """Create a RAG application in the metadata store""" try: diff --git a/backend/modules/model_gateway/model_gateway.py b/backend/modules/model_gateway/model_gateway.py index d72e4176..33f75e53 100644 --- a/backend/modules/model_gateway/model_gateway.py +++ b/backend/modules/model_gateway/model_gateway.py @@ -1,4 +1,3 @@ -import json import os from typing import List diff --git a/backend/modules/model_gateway/reranker_svc.py b/backend/modules/model_gateway/reranker_svc.py index 370677c8..32a85043 100644 --- a/backend/modules/model_gateway/reranker_svc.py +++ b/backend/modules/model_gateway/reranker_svc.py @@ -6,14 +6,13 @@ from langchain.retrievers.document_compressors.base import BaseDocumentCompressor from backend.logger import logger -from backend.settings import settings # Reranking Service using Infinity API class InfinityRerankerSvc(BaseDocumentCompressor): """ Reranker Service that uses Infinity API - Github: https://github.com/michaelfeil/infinity + GitHub: https://github.com/michaelfeil/infinity """ model: str diff --git a/backend/modules/parsers/parser.py b/backend/modules/parsers/parser.py index 1f6c46a6..a4d0adbe 100644 --- 
a/backend/modules/parsers/parser.py +++ b/backend/modules/parsers/parser.py @@ -57,7 +57,7 @@ async def get_chunks( def get_parser_for_extension( file_extension, parsers_map, *args, **kwargs -) -> BaseParser: +) -> Optional[BaseParser]: """ During the indexing phase, given the file_extension and parsers mapping, return the appropriate mapper. If no mapping is given, use the default registry. diff --git a/backend/modules/parsers/unstructured_io.py b/backend/modules/parsers/unstructured_io.py index 65e69eea..2307844b 100644 --- a/backend/modules/parsers/unstructured_io.py +++ b/backend/modules/parsers/unstructured_io.py @@ -51,6 +51,7 @@ def __init__(self, max_chunk_size: int = 2000, *args, **kwargs): self.adapter = HTTPAdapter(max_retries=self.retry_strategy) self.session.mount("https://", self.adapter) self.session.mount("http://", self.adapter) + super().__init__(*args, **kwargs) async def get_chunks(self, filepath: str, metadata: dict, *args, **kwargs): """ diff --git a/backend/modules/query_controllers/example/types.py b/backend/modules/query_controllers/example/types.py index 98ff0059..98791629 100644 --- a/backend/modules/query_controllers/example/types.py +++ b/backend/modules/query_controllers/example/types.py @@ -36,6 +36,7 @@ class VectorStoreRetrieverConfig(BaseModel): ) @model_validator(mode="before") + @classmethod def validate_search_type(cls, values: Dict) -> Dict: """Validate search type.""" search_type = values.get("search_type") @@ -127,6 +128,7 @@ class ExampleQueryInput(BaseModel): stream: Optional[bool] = Field(title="Stream the results", default=False) @model_validator(mode="before") + @classmethod def validate_retriever_type(cls, values: Dict) -> Dict: retriever_name = values.get("retriever_name") diff --git a/backend/modules/query_controllers/multimodal/types.py b/backend/modules/query_controllers/multimodal/types.py index 98ff0059..98791629 100644 --- a/backend/modules/query_controllers/multimodal/types.py +++ b/backend/modules/query_controllers/multimodal/types.py @@ -36,6 +36,7 @@ class VectorStoreRetrieverConfig(BaseModel): ) @model_validator(mode="before") + @classmethod def validate_search_type(cls, values: Dict) -> Dict: """Validate search type.""" search_type = values.get("search_type") @@ -127,6 +128,7 @@ class ExampleQueryInput(BaseModel): stream: Optional[bool] = Field(title="Stream the results", default=False) @model_validator(mode="before") + @classmethod def validate_retriever_type(cls, values: Dict) -> Dict: retriever_name = values.get("retriever_name") diff --git a/backend/requirements.txt b/backend/requirements.txt index 61c8cf85..1317f818 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -5,8 +5,8 @@ langchain-openai==0.1.7 langchain-core==0.1.46 openai==1.35.3 tiktoken==0.7.0 -uvicorn==0.23.2 -fastapi==0.109.1 +uvicorn[standard]==0.23.2 +fastapi==0.111.1 qdrant-client==1.9.0 python-dotenv==1.0.1 pydantic==2.7.4 diff --git a/backend/server/routers/rag_apps.py b/backend/server/routers/rag_apps.py index e558a038..48978b82 100644 --- a/backend/server/routers/rag_apps.py +++ b/backend/server/routers/rag_apps.py @@ -3,7 +3,7 @@ from backend.logger import logger from backend.modules.metadata_store.client import get_client -from backend.types import CreateRagApplication, RagApplicationDto +from backend.types import CreateRagApplication router = APIRouter(prefix="/v1/apps", tags=["apps"]) From 094ea4fc2b66fff3033a3f19815aae4aac651896 Mon Sep 17 00:00:00 2001 From: Chirag Jain Date: Wed, 31 Jul 2024 02:59:52 +0530 Subject: [PATCH 07/16] 
More typing refactor for consistency --- backend/modules/metadata_store/base.py | 8 ++--- backend/modules/metadata_store/prismastore.py | 21 ++++++------ backend/modules/metadata_store/truefoundry.py | 22 ++++++------- .../query_controllers/example/types.py | 16 +++++----- .../query_controllers/multimodal/types.py | 16 +++++----- backend/types.py | 32 +++++++++---------- 6 files changed, 58 insertions(+), 57 deletions(-) diff --git a/backend/modules/metadata_store/base.py b/backend/modules/metadata_store/base.py index 6c4e34ef..5e568b64 100644 --- a/backend/modules/metadata_store/base.py +++ b/backend/modules/metadata_store/base.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from backend.constants import DATA_POINT_FQN_METADATA_KEY, FQN_SEPARATOR from backend.types import ( @@ -300,7 +300,7 @@ async def aupdate_data_ingestion_run_status( def log_metrics_for_data_ingestion_run( self, data_ingestion_run_name: str, - metric_dict: dict[str, int | float], + metric_dict: Dict[str, Union[int, float]], step: int = 0, ): """ @@ -311,7 +311,7 @@ def log_metrics_for_data_ingestion_run( async def alog_metrics_for_data_ingestion_run( self, data_ingestion_run_name: str, - metric_dict: dict[str, int | float], + metric_dict: Dict[str, Union[int, float]], step: int = 0, ): """ @@ -410,7 +410,7 @@ async def acreate_rag_app(self, app: RagApplication) -> RagApplicationDto: """ return await run_in_executor(None, self.create_rag_app, app=app) - def get_rag_app(self, app_name: str) -> RagApplicationDto | None: + def get_rag_app(self, app_name: str) -> Optional[RagApplicationDto]: """ Get a RAG application from the metadata store by name """ diff --git a/backend/modules/metadata_store/prismastore.py b/backend/modules/metadata_store/prismastore.py index ff323449..0682b324 100644 --- a/backend/modules/metadata_store/prismastore.py +++ b/backend/modules/metadata_store/prismastore.py @@ -4,7 +4,7 @@ import random import shutil import string -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Union from fastapi import HTTPException from prisma import Prisma @@ -27,9 +27,10 @@ ) # TODO (chiragjn): -# Either we make everything async or add sync method to this -# Some methods are using json.dumps - not sure if this is the right way to send data via prisma client -# primsa generates its own DB entity classes - ideally we should be using those instead of call .model_dump() on the pydantic objects +# 1. Either we make everything async or add sync method to this +# 2. Some methods are using json.dumps - not sure if this is the right way to send data via prisma client +# 3. 
primsa generates its own DB entity classes - ideally we should be using those instead of call +# .model_dump() on the pydantic objects # TODO (chiragjn): Either we make everything async or add sync method to this @@ -81,7 +82,7 @@ async def acreate_collection(self, collection: CreateCollection) -> Collection: async def aget_collection_by_name( self, collection_name: str, no_cache: bool = True - ) -> Collection | None: + ) -> Optional[Collection]: try: collection = await self.db.collection.find_first( where={"name": collection_name} @@ -97,7 +98,7 @@ async def aget_collection_by_name( async def aget_retrieve_collection_by_name( self, collection_name: str, no_cache: bool = True - ) -> Collection | None: + ) -> Optional[Collection]: return await self.aget_collection_by_name(collection_name, no_cache) async def aget_collections(self) -> List[Collection]: @@ -164,7 +165,7 @@ async def acreate_data_source(self, data_source: CreateDataSource) -> DataSource logger.exception(f"Failed to create data source: {e}") raise HTTPException(status_code=500, detail=f"Error: {e}") - async def aget_data_source_from_fqn(self, fqn: str) -> DataSource | None: + async def aget_data_source_from_fqn(self, fqn: str) -> Optional[DataSource]: try: data_source = await self.db.datasource.find_first(where={"fqn": fqn}) if data_source: @@ -433,7 +434,7 @@ async def acreate_data_ingestion_run( async def aget_data_ingestion_run( self, data_ingestion_run_name: str, no_cache: bool = False - ) -> DataIngestionRun | None: + ) -> Optional[DataIngestionRun]: try: data_ingestion_run = await self.db.ingestionruns.find_first( where={"name": data_ingestion_run_name} @@ -475,7 +476,7 @@ async def aupdate_data_ingestion_run_status( async def alog_metrics_for_data_ingestion_run( self, data_ingestion_run_name: str, - metric_dict: dict[str, int | float], + metric_dict: Dict[str, Union[int, float]], step: int = 0, ): pass @@ -524,7 +525,7 @@ async def acreate_rag_app(self, app: RagApplication) -> RagApplicationDto: logger.exception(f"Error: {e}") raise HTTPException(status_code=500, detail=f"Error: {e}") - async def aget_rag_app(self, app_name: str) -> RagApplicationDto | None: + async def aget_rag_app(self, app_name: str) -> Optional[RagApplicationDto]: """Get a RAG application from the metadata store""" try: rag_app = await self.db.ragapps.find_first(where={"name": app_name}) diff --git a/backend/modules/metadata_store/truefoundry.py b/backend/modules/metadata_store/truefoundry.py index 1629afed..b47b0fb1 100644 --- a/backend/modules/metadata_store/truefoundry.py +++ b/backend/modules/metadata_store/truefoundry.py @@ -3,7 +3,7 @@ import os import tempfile import warnings -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Union import mlflow from fastapi import HTTPException @@ -38,7 +38,7 @@ class MLRunTypes(str, enum.Enum): class TrueFoundry(BaseMetadataStore): - ml_runs: dict[str, ml.MlFoundryRun] = {} + ml_runs: Dict[str, ml.MlFoundryRun] = {} CONSTANT_DATA_SOURCE_RUN_NAME = "tfy-datasource" def __init__(self, *args, ml_repo_name: str, **kwargs): @@ -58,7 +58,7 @@ def __init__(self, *args, ml_repo_name: str, **kwargs): def _get_run_by_name( self, run_name: str, no_cache: bool = False - ) -> ml.MlFoundryRun | None: + ) -> Optional[ml.MlFoundryRun]: """ Cache the runs to avoid too many requests to the backend. 
""" @@ -134,7 +134,7 @@ def _get_entity_from_run( def _get_artifact_metadata_ml_run( self, run: ml.MlFoundryRun - ) -> ml.ArtifactVersion | None: + ) -> Optional[ml.ArtifactVersion]: params = run.get_params() metadata_artifact_fqn = params.get("metadata_artifact_fqn") if not metadata_artifact_fqn: @@ -177,7 +177,7 @@ def _update_entity_in_run( def get_collection_by_name( self, collection_name: str, no_cache: bool = True - ) -> Collection | None: + ) -> Optional[Collection]: """Get collection from given collection name.""" logger.debug(f"[Metadata Store] Getting collection with name {collection_name}") ml_run = self._get_run_by_name(run_name=collection_name, no_cache=no_cache) @@ -194,7 +194,7 @@ def get_collection_by_name( def get_retrieve_collection_by_name( self, collection_name: str, no_cache: bool = True - ) -> Collection | None: + ) -> Optional[Collection]: """Get collection from given collection name. Used during retrieval""" logger.debug(f"[Metadata Store] Getting collection with name {collection_name}") ml_run = self._get_run_by_name(run_name=collection_name, no_cache=no_cache) @@ -345,7 +345,7 @@ def create_data_source(self, data_source: CreateDataSource) -> DataSource: ) return created_data_source - def get_data_source_from_fqn(self, fqn: str) -> DataSource | None: + def get_data_source_from_fqn(self, fqn: str) -> Optional[DataSource]: logger.debug(f"[Metadata Store] Getting data_source by fqn {fqn}") runs = self.client.search_runs( ml_repo=self.ml_repo_name, @@ -412,7 +412,7 @@ def create_data_ingestion_run( def get_data_ingestion_run( self, data_ingestion_run_name: str, no_cache: bool = False - ) -> DataIngestionRun | None: + ) -> Optional[DataIngestionRun]: logger.debug( f"[Metadata Store] Getting ingestion run {data_ingestion_run_name}" ) @@ -518,7 +518,7 @@ def update_data_ingestion_run_status( def log_metrics_for_data_ingestion_run( self, data_ingestion_run_name: str, - metric_dict: dict[str, int | float], + metric_dict: Dict[str, Union[int, float]], step: int = 0, ): try: @@ -572,7 +572,7 @@ def list_collections(self) -> List[str]: ) return [run.run_name for run in ml_runs] - def list_data_sources(self) -> List[dict[str, str]]: + def list_data_sources(self) -> List[Dict[str, str]]: logger.info(f"[Metadata Store] Listing all data sources") ml_runs = self.client.search_runs( ml_repo=self.ml_repo_name, @@ -637,7 +637,7 @@ def create_rag_app(self, app: RagApplication) -> RagApplicationDto: logger.debug(f"[Metadata Store] RAG Application Saved") return created_app - def get_rag_app(self, app_name: str) -> RagApplicationDto | None: + def get_rag_app(self, app_name: str) -> Optional[RagApplicationDto]: """ Get a RAG application from the metadata store """ diff --git a/backend/modules/query_controllers/example/types.py b/backend/modules/query_controllers/example/types.py index 98791629..b232ccc7 100644 --- a/backend/modules/query_controllers/example/types.py +++ b/backend/modules/query_controllers/example/types.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Collection, Dict, Optional +from typing import Any, ClassVar, Dict, Optional, Sequence from pydantic import BaseModel, Field, model_validator from qdrant_client.models import Filter as QdrantFilter @@ -29,7 +29,7 @@ class VectorStoreRetrieverConfig(BaseModel): title="""Filter by document metadata""", ) - allowed_search_types: ClassVar[Collection[str]] = ( + allowed_search_types: ClassVar[Sequence[str]] = ( "similarity", "similarity_score_threshold", "mmr", @@ -82,7 +82,7 @@ class 
ContextualCompressionRetrieverConfig(VectorStoreRetrieverConfig): title="Top K docs to collect post compression", ) - allowed_compressor_model_providers: ClassVar[Collection[str]] + allowed_compressor_model_providers: ClassVar[Sequence[str]] class ContextualCompressionMultiQueryRetrieverConfig( @@ -94,12 +94,12 @@ class ContextualCompressionMultiQueryRetrieverConfig( class ExampleQueryInput(BaseModel): """ Model for Query input. - Requires a collection name, retriever configuration, query, LLM configuration and prompt template. + Requires a Sequence name, retriever configuration, query, LLM configuration and prompt template. """ - collection_name: str = Field( + Sequence_name: str = Field( default=None, - title="Collection name on which to search", + title="Sequence name on which to search", ) query: str = Field(title="Question to search for") @@ -118,14 +118,14 @@ class ExampleQueryInput(BaseModel): title="Retriever configuration", ) - allowed_retriever_types: ClassVar[Collection[str]] = ( + allowed_retriever_types: ClassVar[Sequence[str]] = ( "vectorstore", "multi-query", "contextual-compression", "contextual-compression-multi-query", ) - stream: Optional[bool] = Field(title="Stream the results", default=False) + stream: bool = Field(title="Stream the results", default=False) @model_validator(mode="before") @classmethod diff --git a/backend/modules/query_controllers/multimodal/types.py b/backend/modules/query_controllers/multimodal/types.py index 98791629..b232ccc7 100644 --- a/backend/modules/query_controllers/multimodal/types.py +++ b/backend/modules/query_controllers/multimodal/types.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Collection, Dict, Optional +from typing import Any, ClassVar, Dict, Optional, Sequence from pydantic import BaseModel, Field, model_validator from qdrant_client.models import Filter as QdrantFilter @@ -29,7 +29,7 @@ class VectorStoreRetrieverConfig(BaseModel): title="""Filter by document metadata""", ) - allowed_search_types: ClassVar[Collection[str]] = ( + allowed_search_types: ClassVar[Sequence[str]] = ( "similarity", "similarity_score_threshold", "mmr", @@ -82,7 +82,7 @@ class ContextualCompressionRetrieverConfig(VectorStoreRetrieverConfig): title="Top K docs to collect post compression", ) - allowed_compressor_model_providers: ClassVar[Collection[str]] + allowed_compressor_model_providers: ClassVar[Sequence[str]] class ContextualCompressionMultiQueryRetrieverConfig( @@ -94,12 +94,12 @@ class ContextualCompressionMultiQueryRetrieverConfig( class ExampleQueryInput(BaseModel): """ Model for Query input. - Requires a collection name, retriever configuration, query, LLM configuration and prompt template. + Requires a Sequence name, retriever configuration, query, LLM configuration and prompt template. 
""" - collection_name: str = Field( + Sequence_name: str = Field( default=None, - title="Collection name on which to search", + title="Sequence name on which to search", ) query: str = Field(title="Question to search for") @@ -118,14 +118,14 @@ class ExampleQueryInput(BaseModel): title="Retriever configuration", ) - allowed_retriever_types: ClassVar[Collection[str]] = ( + allowed_retriever_types: ClassVar[Sequence[str]] = ( "vectorstore", "multi-query", "contextual-compression", "contextual-compression-multi-query", ) - stream: Optional[bool] = Field(title="Stream the results", default=False) + stream: bool = Field(title="Stream the results", default=False) @model_validator(mode="before") @classmethod diff --git a/backend/types.py b/backend/types.py index adbecdb3..a341a66c 100644 --- a/backend/types.py +++ b/backend/types.py @@ -1,7 +1,7 @@ import enum import uuid from enum import Enum -from typing import Any, Literal, Optional +from typing import Any, Dict, List, Literal, Optional from pydantic import BaseModel, ConfigDict, Field, StringConstraints from typing_extensions import Annotated @@ -27,7 +27,7 @@ class DataPoint(BaseModel): data_point_fqn (str): Fully qualified name for the data point with respect to the data source data_point_uri (str): URI for the data point for given data source. It could be url, file path or any other identifier data_point_hash (str): Hash of the data point for the given data source that is guaranteed to be updated for any update in data point at source - metadata (Optional[dict[str, str]]): Additional metadata for the data point + metadata (Optional[Dict[str, str]]): Additional metadata for the data point """ data_source_fqn: str = Field( @@ -42,7 +42,7 @@ class DataPoint(BaseModel): title="Hash of the data point for the given data source that is guaranteed to be updated for any update in data point at source", ) - metadata: Optional[dict[str, Any]] = Field( + metadata: Optional[Dict[str, Any]] = Field( None, title="Additional metadata for the data point", ) @@ -107,7 +107,7 @@ class ModelType(str, Enum): class ModelConfig(BaseModel): name: str type: ModelType - parameters: dict[str, Any] = Field(default_factory=dict) + parameters: Dict[str, Any] = Field(default_factory=dict) class ModelProviderConfig(BaseModel): @@ -115,10 +115,10 @@ class ModelProviderConfig(BaseModel): api_format: str base_url: Optional[str] = None api_key_env_var: str - default_headers: dict[str, str] = Field(default_factory=dict) - llm_model_ids: list[str] = Field(default_factory=list) - embedding_model_ids: list[str] = Field(default_factory=list) - reranking_model_ids: list[str] = Field(default_factory=list) + default_headers: Dict[str, str] = Field(default_factory=dict) + llm_model_ids: List[str] = Field(default_factory=list) + embedding_model_ids: List[str] = Field(default_factory=list) + reranking_model_ids: List[str] = Field(default_factory=list) class EmbedderConfig(BaseModel): @@ -128,7 +128,7 @@ class EmbedderConfig(BaseModel): # TODO (chiragjn): Pydantic v2 does not like fields that begin with model_* model_config: ModelConfig - config: Optional[dict[str, Any]] = Field( + config: Optional[Dict[str, Any]] = Field( title="Configuration for the embedder", default_factory=dict ) @@ -140,11 +140,11 @@ class ParserConfig(BaseModel): chunk_size: int = Field(title="Chunk Size for data parsing", ge=1, default=1000) chunk_overlap: int = Field(title="Chunk Overlap for indexing", ge=0, default=20) - parser_map: dict[str, str] = Field( + parser_map: Dict[str, str] = Field( title="Mapping 
of file extensions to parsers", default_factory=dict, ) - additional_config: Optional[dict[str, Any]] = Field( + additional_config: Optional[Dict[str, Any]] = Field( title="Additional optional configuration for the parser", default_factory=dict, ) @@ -293,7 +293,7 @@ class BaseDataSource(BaseModel): uri: str = Field( title="A unique identifier for the data source", ) - metadata: Optional[dict[str, Any]] = Field( + metadata: Optional[Dict[str, Any]] = Field( None, title="Additional config for your data source" ) @@ -426,19 +426,19 @@ class CreateCollection(BaseCollection): class Collection(BaseCollection): - associated_data_sources: dict[str, AssociatedDataSources] = Field( + associated_data_sources: Dict[str, AssociatedDataSources] = Field( title="Data sources associated with the collection", default_factory=dict ) class CreateCollectionDto(CreateCollection): - associated_data_sources: Optional[list[AssociateDataSourceWithCollection]] = Field( + associated_data_sources: Optional[List[AssociateDataSourceWithCollection]] = Field( None, title="Data sources associated with the collection" ) class UploadToDataDirectoryDto(BaseModel): - filepaths: list[str] + filepaths: List[str] # allow only small case alphanumeric and hyphen, should contain at least one alphabet and begin with alphabet upload_name: str = Field( title="Name of the upload", @@ -461,7 +461,7 @@ class RagApplication(BaseModel): title="Name of the rag app", regex=r"^[a-z][a-z0-9-]*$", # allow only small case alphanumeric and hyphen, should contain at least one alphabet and begin with alphabet ) - config: dict[str, Any] = Field( + config: Dict[str, Any] = Field( title="Configuration for the rag app", ) From b870dd60e524cc521698b10e54af6ac31740e456 Mon Sep 17 00:00:00 2001 From: Chirag Jain Date: Wed, 31 Jul 2024 03:02:38 +0530 Subject: [PATCH 08/16] Add todo for fields to fix --- backend/modules/dataloaders/loader.py | 2 +- backend/modules/parsers/multimodalparser.py | 4 ++-- backend/modules/parsers/parser.py | 9 ++++----- backend/modules/query_controllers/example/types.py | 3 ++- backend/modules/query_controllers/multimodal/types.py | 3 ++- backend/modules/query_controllers/query_controller.py | 2 +- backend/modules/vector_db/singlestore.py | 6 +++--- backend/modules/vector_db/weaviate.py | 4 ++-- backend/types.py | 10 ++++++---- 9 files changed, 23 insertions(+), 20 deletions(-) diff --git a/backend/modules/dataloaders/loader.py b/backend/modules/dataloaders/loader.py index 656c4ed0..336ebf8f 100644 --- a/backend/modules/dataloaders/loader.py +++ b/backend/modules/dataloaders/loader.py @@ -128,7 +128,7 @@ def list_dataloaders(): Returns a list of all the registered loaders. Returns: - List[dict]: A list of all the registered loaders. + List[Dict]: A list of all the registered loaders. 
""" global LOADER_REGISTRY return [ diff --git a/backend/modules/parsers/multimodalparser.py b/backend/modules/parsers/multimodalparser.py index bacc0a4e..f36834a5 100644 --- a/backend/modules/parsers/multimodalparser.py +++ b/backend/modules/parsers/multimodalparser.py @@ -3,7 +3,7 @@ import io import os from itertools import islice -from typing import Optional +from typing import Any, Dict, Optional import cv2 import fitz @@ -137,7 +137,7 @@ async def call_vlm_agent( return {"error": f"Error in page: {page_number}"} async def get_chunks( - self, filepath: str, metadata: Optional[dict] = None, *args, **kwargs + self, filepath: str, metadata: Optional[Dict[Any, Any]] = None, *args, **kwargs ): """ Asynchronously extracts text from a PDF file and returns it in chunks. diff --git a/backend/modules/parsers/parser.py b/backend/modules/parsers/parser.py index a4d0adbe..b81468e4 100644 --- a/backend/modules/parsers/parser.py +++ b/backend/modules/parsers/parser.py @@ -1,7 +1,6 @@ -import typing from abc import ABC, abstractmethod from collections import defaultdict -from typing import Optional +from typing import Any, Dict, List, Optional from langchain.docstore.document import Document @@ -39,10 +38,10 @@ def __init__(self, *args, **kwargs): async def get_chunks( self, filepath: str, - metadata: Optional[dict], + metadata: Optional[Dict[Any, Any]], *args, **kwargs, - ) -> typing.List[Document]: + ) -> List[Document]: """ Abstract method. This should asynchronously read a file and return its content in chunks. @@ -99,7 +98,7 @@ def list_parsers(): Returns a list of all the registered parsers. Returns: - List[dict]: A list of all the registered parsers. + List[Dict]: A list of all the registered parsers. """ global PARSER_REGISTRY return [ diff --git a/backend/modules/query_controllers/example/types.py b/backend/modules/query_controllers/example/types.py index b232ccc7..06c39159 100644 --- a/backend/modules/query_controllers/example/types.py +++ b/backend/modules/query_controllers/example/types.py @@ -24,7 +24,7 @@ class VectorStoreRetrieverConfig(BaseModel): search_kwargs: dict = Field(default_factory=dict) - filter: Optional[dict] = Field( + filter: Optional[Dict[Any, Any]] = Field( default_factory=dict, title="""Filter by document metadata""", ) @@ -104,6 +104,7 @@ class ExampleQueryInput(BaseModel): query: str = Field(title="Question to search for") + # TODO (chiragjn): pydantic v2 does not like fields that start with model_ model_configuration: ModelConfig prompt_template: str = Field( diff --git a/backend/modules/query_controllers/multimodal/types.py b/backend/modules/query_controllers/multimodal/types.py index b232ccc7..06c39159 100644 --- a/backend/modules/query_controllers/multimodal/types.py +++ b/backend/modules/query_controllers/multimodal/types.py @@ -24,7 +24,7 @@ class VectorStoreRetrieverConfig(BaseModel): search_kwargs: dict = Field(default_factory=dict) - filter: Optional[dict] = Field( + filter: Optional[Dict[Any, Any]] = Field( default_factory=dict, title="""Filter by document metadata""", ) @@ -104,6 +104,7 @@ class ExampleQueryInput(BaseModel): query: str = Field(title="Question to search for") + # TODO (chiragjn): pydantic v2 does not like fields that start with model_ model_configuration: ModelConfig prompt_template: str = Field( diff --git a/backend/modules/query_controllers/query_controller.py b/backend/modules/query_controllers/query_controller.py index fd21aa7f..e81d8a1e 100644 --- a/backend/modules/query_controllers/query_controller.py +++ 
b/backend/modules/query_controllers/query_controller.py @@ -18,7 +18,7 @@ def list_query_controllers(): Returns a list of all the registered query controllers. Returns: - List[dict]: A list of all the registered query controllers. + List[Dict]: A list of all the registered query controllers. """ global QUERY_CONTROLLER_REGISTRY return [ diff --git a/backend/modules/vector_db/singlestore.py b/backend/modules/vector_db/singlestore.py index fc1fba87..2b95c3b4 100644 --- a/backend/modules/vector_db/singlestore.py +++ b/backend/modules/vector_db/singlestore.py @@ -1,5 +1,5 @@ import json -from typing import Any, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional import singlestoredb as s2 from langchain.docstore.document import Document @@ -63,7 +63,7 @@ def _create_table(self: SingleStoreDB) -> None: def add_texts( self, texts: Iterable[str], - metadatas: Optional[List[dict]] = None, + metadatas: Optional[List[Dict[Any, Any]]] = None, embeddings: Optional[List[List[float]]] = None, **kwargs: Any, ) -> List[str]: @@ -71,7 +71,7 @@ def add_texts( Args: texts (Iterable[str]): Iterable of strings/text to add to the vectorstore. - metadatas (Optional[List[dict]], optional): Optional list of metadatas. + metadatas (Optional[List[Dict]], optional): Optional list of metadatas. Defaults to None. embeddings (Optional[List[List[float]]], optional): Optional pre-generated embeddings. Defaults to None. diff --git a/backend/modules/vector_db/weaviate.py b/backend/modules/vector_db/weaviate.py index 5b3986de..2aad43e5 100644 --- a/backend/modules/vector_db/weaviate.py +++ b/backend/modules/vector_db/weaviate.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Any, Dict, List import weaviate from langchain.embeddings.base import Embeddings @@ -99,7 +99,7 @@ def list_documents_in_collection( .with_fields("groupedBy { value }") .do() ) - groups: List[dict] = ( + groups: List[Dict[Any, Any]] = ( response.get("data", {}) .get("Aggregate", {}) .get(collection_name.capitalize(), []) diff --git a/backend/types.py b/backend/types.py index a341a66c..2846b010 100644 --- a/backend/types.py +++ b/backend/types.py @@ -8,6 +8,8 @@ from backend.constants import FQN_SEPARATOR +# TODO (chiragjn): Remove Optional from Dict and List type fields. 
Instead just use a default_factory + class DataIngestionMode(str, Enum): """ @@ -126,7 +128,7 @@ class EmbedderConfig(BaseModel): Embedder configuration """ - # TODO (chiragjn): Pydantic v2 does not like fields that begin with model_* + # TODO (chiragjn): pydantic v2 does not like fields that start with model_ model_config: ModelConfig config: Optional[Dict[str, Any]] = Field( title="Configuration for the embedder", default_factory=dict @@ -159,7 +161,7 @@ class VectorDBConfig(BaseModel): local: bool = False url: Optional[str] = None api_key: Optional[str] = None - config: Optional[dict] = Field(default_factory=dict) + config: Optional[Dict[str, Any]] = Field(default_factory=dict) class QdrantClientConfig(BaseModel): @@ -182,7 +184,7 @@ class MetadataStoreConfig(BaseModel): """ provider: str - config: Optional[dict] = Field(default_factory=dict) + config: Optional[Dict[str, Any]] = Field(default_factory=dict) class RetrieverConfig(BaseModel): @@ -203,7 +205,7 @@ class RetrieverConfig(BaseModel): default=20, title="""Amount of documents to pass to MMR algorithm (Default: 20)""", ) - filter: Optional[dict] = Field( + filter: Optional[Dict[Any, Any]] = Field( default=None, title="""Filter by document metadata""", ) From 1f5d0b73838006739ad710cff3a4503da5375d37 Mon Sep 17 00:00:00 2001 From: Chirag Jain Date: Wed, 31 Jul 2024 03:31:30 +0530 Subject: [PATCH 09/16] Add exception if pydantic v2 sends us obj instead of dict when validating --- .../query_controllers/example/types.py | 16 ++++++- .../query_controllers/multimodal/types.py | 16 ++++++- backend/server/routers/internal.py | 2 +- backend/settings.py | 48 +++++++++---------- 4 files changed, 53 insertions(+), 29 deletions(-) diff --git a/backend/modules/query_controllers/example/types.py b/backend/modules/query_controllers/example/types.py index 06c39159..11850a22 100644 --- a/backend/modules/query_controllers/example/types.py +++ b/backend/modules/query_controllers/example/types.py @@ -7,6 +7,8 @@ GENERATION_TIMEOUT_SEC = 60.0 * 10 +# TODO (chiragjn): Remove all asserts and replace them with proper raises + class VectorStoreRetrieverConfig(BaseModel): """ @@ -37,8 +39,13 @@ class VectorStoreRetrieverConfig(BaseModel): @model_validator(mode="before") @classmethod - def validate_search_type(cls, values: Dict) -> Dict: + def validate_search_type(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Validate search type.""" + if not isinstance(values, dict): + raise ValueError( + f"Unexpected Pydantic v2 Validation: values are of type {type(values)}" + ) + search_type = values.get("search_type") assert ( @@ -130,7 +137,12 @@ class ExampleQueryInput(BaseModel): @model_validator(mode="before") @classmethod - def validate_retriever_type(cls, values: Dict) -> Dict: + def validate_retriever_type(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if not isinstance(values, dict): + raise ValueError( + f"Unexpected Pydantic v2 Validation: values are of type {type(values)}" + ) + retriever_name = values.get("retriever_name") assert ( diff --git a/backend/modules/query_controllers/multimodal/types.py b/backend/modules/query_controllers/multimodal/types.py index 06c39159..11850a22 100644 --- a/backend/modules/query_controllers/multimodal/types.py +++ b/backend/modules/query_controllers/multimodal/types.py @@ -7,6 +7,8 @@ GENERATION_TIMEOUT_SEC = 60.0 * 10 +# TODO (chiragjn): Remove all asserts and replace them with proper raises + class VectorStoreRetrieverConfig(BaseModel): """ @@ -37,8 +39,13 @@ class VectorStoreRetrieverConfig(BaseModel): 
@model_validator(mode="before") @classmethod - def validate_search_type(cls, values: Dict) -> Dict: + def validate_search_type(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Validate search type.""" + if not isinstance(values, dict): + raise ValueError( + f"Unexpected Pydantic v2 Validation: values are of type {type(values)}" + ) + search_type = values.get("search_type") assert ( @@ -130,7 +137,12 @@ class ExampleQueryInput(BaseModel): @model_validator(mode="before") @classmethod - def validate_retriever_type(cls, values: Dict) -> Dict: + def validate_retriever_type(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if not isinstance(values, dict): + raise ValueError( + f"Unexpected Pydantic v2 Validation: values are of type {type(values)}" + ) + retriever_name = values.get("retriever_name") assert ( diff --git a/backend/server/routers/internal.py b/backend/server/routers/internal.py index 6c555f97..6d152c42 100644 --- a/backend/server/routers/internal.py +++ b/backend/server/routers/internal.py @@ -93,7 +93,7 @@ async def upload_to_data_directory(req: UploadToDataDirectoryDto): content={"data": data, "data_directory_fqn": dataset.fqn}, ) except Exception as ex: - raise Exception(f"Error uploading files to data directory: {ex}") + raise Exception(f"Error uploading files to data directory: {ex}") from ex @router.get("/models") diff --git a/backend/settings.py b/backend/settings.py index c707bc4f..a9943bee 100644 --- a/backend/settings.py +++ b/backend/settings.py @@ -1,5 +1,5 @@ import os -from typing import Any +from typing import Any, Dict from pydantic import ConfigDict, model_validator from pydantic_settings import BaseSettings @@ -33,33 +33,33 @@ class Settings(BaseSettings): @model_validator(mode="before") @classmethod - def _validate_values(cls, values: Any) -> Any: - if isinstance(values, dict): - models_config_path = values.get("MODELS_CONFIG_PATH") - if not os.path.isabs(models_config_path): - this_dir = os.path.abspath(os.path.dirname(__file__)) - root_dir = os.path.dirname(this_dir) - models_config_path = os.path.join(root_dir, models_config_path) - - if not models_config_path: - raise Exception( - f"{models_config_path} does not exist. " - f"You can copy models_config.sample.yaml to {settings.MODELS_CONFIG_PATH} to bootstrap config" - ) + def _validate_values(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Validate search type.""" + if not isinstance(values, dict): + raise ValueError( + f"Unexpected Pydantic v2 Validation: values are of type {type(values)}" + ) - values["MODELS_CONFIG_PATH"] = models_config_path + models_config_path = values.get("MODELS_CONFIG_PATH") + if not os.path.isabs(models_config_path): + this_dir = os.path.abspath(os.path.dirname(__file__)) + root_dir = os.path.dirname(this_dir) + models_config_path = os.path.join(root_dir, models_config_path) - tfy_host = values.get("TFY_HOST") - tfy_llm_gateway_url = values.get("TFY_LLM_GATEWAY_URL") - if tfy_host and not tfy_llm_gateway_url: - tfy_llm_gateway_url = f"{tfy_host.rstrip('/')}/api/llm" - values["TFY_LLM_GATEWAY_URL"] = tfy_llm_gateway_url - else: - logger.warning( - f"[Validation Skipped] Pydantic v2 validator received " - f"non dict values of type {type(values)}" + if not models_config_path: + raise ValueError( + f"{models_config_path} does not exist. 
" + f"You can copy models_config.sample.yaml to {settings.MODELS_CONFIG_PATH} to bootstrap config" ) + values["MODELS_CONFIG_PATH"] = models_config_path + + tfy_host = values.get("TFY_HOST") + tfy_llm_gateway_url = values.get("TFY_LLM_GATEWAY_URL") + if tfy_host and not tfy_llm_gateway_url: + tfy_llm_gateway_url = f"{tfy_host.rstrip('/')}/api/llm" + values["TFY_LLM_GATEWAY_URL"] = tfy_llm_gateway_url + return values From 02269c9e035be80f1b0891c4cab033306b04f0ae Mon Sep 17 00:00:00 2001 From: Chirag Jain Date: Wed, 31 Jul 2024 04:06:11 +0530 Subject: [PATCH 10/16] Fix regex type constraint and rename model_config to embedding_model_config --- backend/indexer/indexer.py | 2 +- .../query_controllers/example/controller.py | 2 +- .../multimodal/controller.py | 2 +- backend/server/routers/collection.py | 2 +- backend/settings.py | 1 - backend/types.py | 28 +++++++++++-------- 6 files changed, 21 insertions(+), 16 deletions(-) diff --git a/backend/indexer/indexer.py b/backend/indexer/indexer.py index 104a2f6a..2effb2ef 100644 --- a/backend/indexer/indexer.py +++ b/backend/indexer/indexer.py @@ -219,7 +219,7 @@ async def ingest_data_points( """ embeddings = model_gateway.get_embedder_from_model_config( - model_name=inputs.embedder_config.model_config.name + model_name=inputs.embedder_config.embedding_model_config.name ) documents_to_be_upserted = [] logger.info( diff --git a/backend/modules/query_controllers/example/controller.py b/backend/modules/query_controllers/example/controller.py index 23594109..1fed4f3e 100644 --- a/backend/modules/query_controllers/example/controller.py +++ b/backend/modules/query_controllers/example/controller.py @@ -83,7 +83,7 @@ async def _get_vector_store(self, collection_name: str): return VECTOR_STORE_CLIENT.get_vector_store( collection_name=collection.name, embeddings=model_gateway.get_embedder_from_model_config( - model_name=collection.embedder_config.model_config.name + model_name=collection.embedder_config.embedding_model_config.name ), ) diff --git a/backend/modules/query_controllers/multimodal/controller.py b/backend/modules/query_controllers/multimodal/controller.py index 3f2239a2..01fae4dc 100644 --- a/backend/modules/query_controllers/multimodal/controller.py +++ b/backend/modules/query_controllers/multimodal/controller.py @@ -82,7 +82,7 @@ async def _get_vector_store(self, collection_name: str): return VECTOR_STORE_CLIENT.get_vector_store( collection_name=collection.name, embeddings=model_gateway.get_embedder_from_model_config( - model_name=collection.embedder_config.model_config.name + model_name=collection.embedder_config.embedding_model_config.name ), ) diff --git a/backend/server/routers/collection.py b/backend/server/routers/collection.py index 3cd8684a..dfd0a852 100644 --- a/backend/server/routers/collection.py +++ b/backend/server/routers/collection.py @@ -80,7 +80,7 @@ async def create_collection(collection: CreateCollectionDto): VECTOR_STORE_CLIENT.create_collection( collection_name=collection.name, embeddings=model_gateway.get_embedder_from_model_config( - model_name=collection.embedder_config.model_config.name + model_name=collection.embedder_config.embedding_model_config.name ), ) logger.info(f"Created collection... 
{created_collection}") diff --git a/backend/settings.py b/backend/settings.py index a9943bee..fbbd1fb9 100644 --- a/backend/settings.py +++ b/backend/settings.py @@ -4,7 +4,6 @@ from pydantic import ConfigDict, model_validator from pydantic_settings import BaseSettings -from backend.logger import logger from backend.types import MetadataStoreConfig, VectorDBConfig diff --git a/backend/types.py b/backend/types.py index 2846b010..5db777c9 100644 --- a/backend/types.py +++ b/backend/types.py @@ -128,8 +128,10 @@ class EmbedderConfig(BaseModel): Embedder configuration """ - # TODO (chiragjn): pydantic v2 does not like fields that start with model_ - model_config: ModelConfig + # This field will probably be removed soon or refactored + embedding_model_config: ModelConfig = Field( + validation_alias="model_config", serialization_alias="model_config" + ) config: Optional[Dict[str, Any]] = Field( title="Configuration for the embedder", default_factory=dict ) @@ -218,11 +220,12 @@ def get_search_type(self) -> str: @property def get_search_kwargs(self) -> dict: # Check at langchain.schema.vectorstore.VectorStore.as_retriever - match self.search_type: - case "similarity": - return {"k": self.k, "filter": self.filter} - case "mmr": - return {"k": self.k, "fetch_k": self.fetch_k, "filter": self.filter} + if self.search_type == "similarity": + return {"k": self.k, "filter": self.filter} + elif self.search_type == "mmr": + return {"k": self.k, "fetch_k": self.fetch_k, "filter": self.filter} + else: + raise ValueError(f"Search type {self.search_type} is not supported") class DataIngestionRunStatus(str, enum.Enum): @@ -442,9 +445,10 @@ class CreateCollectionDto(CreateCollection): class UploadToDataDirectoryDto(BaseModel): filepaths: List[str] # allow only small case alphanumeric and hyphen, should contain at least one alphabet and begin with alphabet - upload_name: str = Field( + upload_name: Annotated[ + str, StringConstraints(pattern=r"^[a-z][a-z0-9-]*$") + ] = Field( # type:ignore title="Name of the upload", - pattern=r"^[a-z][a-z0-9-]*$", default=str(uuid.uuid4()), ) @@ -459,9 +463,11 @@ class ListDataIngestionRunsDto(BaseModel): class RagApplication(BaseModel): - name: str = Field( + # allow only small case alphanumeric and hyphen, should contain at least one alphabet and begin with alphabet + name: Annotated[ + str, StringConstraints(pattern=r"^[a-z][a-z0-9-]*$") + ] = Field( # type:ignore title="Name of the rag app", - regex=r"^[a-z][a-z0-9-]*$", # allow only small case alphanumeric and hyphen, should contain at least one alphabet and begin with alphabet ) config: Dict[str, Any] = Field( title="Configuration for the rag app", From c7b532856fb8eed691d6542bf11ba9dfc1994190 Mon Sep 17 00:00:00 2001 From: Chirag Jain Date: Wed, 31 Jul 2024 17:35:46 +0530 Subject: [PATCH 11/16] Refactor primsa store to work with pydantic v2 --- backend/database/schema.prisma | 1 + backend/modules/metadata_store/base.py | 11 +- backend/modules/metadata_store/prismastore.py | 256 +++++++++++------- .../query_controllers/example/types.py | 35 ++- .../query_controllers/multimodal/types.py | 24 +- backend/types.py | 35 ++- .../dashboard/docsqa/NewCollection.tsx | 2 + 7 files changed, 242 insertions(+), 122 deletions(-) diff --git a/backend/database/schema.prisma b/backend/database/schema.prisma index 3556cbe1..08ec4618 100644 --- a/backend/database/schema.prisma +++ b/backend/database/schema.prisma @@ -24,6 +24,7 @@ model Collection { description String? 
embedder_config Json // Collection can have multiple data sources + // TODO (chiragjn): Why does this have to be Nullable? Default to {} associated_data_sources Json? @@map("collections") diff --git a/backend/modules/metadata_store/base.py b/backend/modules/metadata_store/base.py index 5e568b64..534e36ab 100644 --- a/backend/modules/metadata_store/base.py +++ b/backend/modules/metadata_store/base.py @@ -346,6 +346,8 @@ async def alog_errors_for_data_ingestion_run( errors=errors, ) + # TODO (chiragjn): What is the difference between get_collections and this? + # TODO (chiragjn): Return complete entities, why return only str? def list_collections( self, ) -> List[str]: @@ -354,6 +356,7 @@ def list_collections( """ raise NotImplementedError() + # TODO (chiragjn): Return complete entities, why return only str? async def alist_collections( self, ) -> List[str]: @@ -365,17 +368,19 @@ async def alist_collections( self.list_collections, ) + # TODO (chiragjn): Return complete entities, why return dict? def list_data_sources( self, - ) -> List[str]: + ) -> List[Dict[str, str]]: """ List all data source names from metadata store """ raise NotImplementedError() + # TODO (chiragjn): Return complete entities, why return dict? async def alist_data_sources( self, - ) -> List[str]: + ) -> List[Dict[str, str]]: """ List all data source names from metadata store """ @@ -422,12 +427,14 @@ async def aget_rag_app(self, app_name: str) -> Optional[RagApplicationDto]: """ return await run_in_executor(None, self.get_rag_app, app_name=app_name) + # TODO (chiragjn): Return complete entities, why return only str? def list_rag_apps(self) -> List[str]: """ List all RAG application names from metadata store """ raise NotImplementedError() + # TODO (chiragjn): Return complete entities, why return only str? async def alist_rag_apps(self) -> List[str]: """ List all RAG application names from metadata store diff --git a/backend/modules/metadata_store/prismastore.py b/backend/modules/metadata_store/prismastore.py index 0682b324..5727e999 100644 --- a/backend/modules/metadata_store/prismastore.py +++ b/backend/modules/metadata_store/prismastore.py @@ -4,7 +4,7 @@ import random import shutil import string -from typing import Any, Dict, List, Optional, Union +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from fastapi import HTTPException from prisma import Prisma @@ -23,17 +23,26 @@ DataIngestionRunStatus, DataSource, RagApplication, - RagApplicationDto, ) +if TYPE_CHECKING: + # TODO (chiragjn): Can we import these safely even if the prisma client might not be generated yet? + from prisma.models import Collection as PrismaCollection + from prisma.models import DataSource as PrismaDataSource + from prisma.models import IngestionRuns as PrismaDataIngestionRun + from prisma.models import RagApps as PrismaRagApplication + # TODO (chiragjn): -# 1. Either we make everything async or add sync method to this -# 2. Some methods are using json.dumps - not sure if this is the right way to send data via prisma client -# 3. primsa generates its own DB entity classes - ideally we should be using those instead of call -# .model_dump() on the pydantic objects +# - Use transactions! +# - Some methods are using json.dumps - not sure if this is the right way to send data via prisma client +# - primsa generates its own DB entity classes - ideally we should be using those instead of call +# .model_dump() on the pydantic objects. 
See prisma.models and prisma.actions +# # TODO (chiragjn): Either we make everything async or add sync method to this + + class PrismaStore(BaseMetadataStore): def __init__(self, *args, db, **kwargs) -> None: self.db = db @@ -54,6 +63,22 @@ async def aconnect(cls, **kwargs): # COLLECTIONS APIS ###### + async def aget_collection_by_name( + self, collection_name: str, no_cache: bool = True + ) -> Optional[Collection]: + try: + collection: Optional[ + "PrismaCollection" + ] = await self.db.collection.find_first(where={"name": collection_name}) + if collection: + return Collection.model_validate(collection.model_dump()) + return None + except Exception as e: + logger.exception(f"Failed to get collection by name: {e}") + raise HTTPException( + status_code=500, detail="Failed to get collection by name" + ) + async def acreate_collection(self, collection: CreateCollection) -> Collection: try: existing_collection = await self.aget_collection_by_name(collection.name) @@ -74,37 +99,26 @@ async def acreate_collection(self, collection: CreateCollection) -> Collection: collection_data["embedder_config"] = json.dumps( collection_data["embedder_config"] ) - collection = await self.db.collection.create(data=collection_data) - return collection + collection: "PrismaCollection" = await self.db.collection.create( + data=collection_data + ) + return Collection.model_validate(collection.model_dump()) except Exception as e: logger.exception(f"Error: {e}") raise HTTPException(status_code=500, detail=f"Error: {e}") - async def aget_collection_by_name( - self, collection_name: str, no_cache: bool = True - ) -> Optional[Collection]: - try: - collection = await self.db.collection.find_first( - where={"name": collection_name} - ) - if collection: - return collection - return None - except Exception as e: - logger.exception(f"Failed to get collection by name: {e}") - raise HTTPException( - status_code=500, detail="Failed to get collection by name" - ) - async def aget_retrieve_collection_by_name( self, collection_name: str, no_cache: bool = True ) -> Optional[Collection]: - return await self.aget_collection_by_name(collection_name, no_cache) + collection: "PrismaCollection" = await self.aget_collection_by_name( + collection_name, no_cache + ) + return Collection.model_validate(collection.model_dump()) async def aget_collections(self) -> List[Collection]: try: - collections = await self.db.collection.find_many() - return collections + collections: List["PrismaCollection"] = await self.db.collection.find_many() + return [Collection.model_validate(c.model_dump()) for c in collections] except Exception as e: logger.exception(f"Failed to get collections: {e}") raise HTTPException(status_code=500, detail="Failed to get collections") @@ -119,17 +133,17 @@ async def alist_collections(self) -> List[str]: async def adelete_collection(self, collection_name: str, include_runs=False): try: - collection = await self.aget_collection_by_name(collection_name) - if not collection: - logger.debug(f"Collection with name {collection_name} does not exist") - except Exception as e: - logger.exception(e) - - try: - await self.db.collection.delete(where={"name": collection_name}) + deleted_collection: Optional[ + "PrismaCollection" + ] = await self.db.collection.delete(where={"name": collection_name}) + if not deleted_collection: + raise HTTPException( + status_code=404, + detail=f"Failed to delete collection {collection_name!r}. 
No such record found", + ) if include_runs: try: - await self.db.ingestionruns.delete_many( + _deleted_count = await self.db.ingestionruns.delete_many( where={"collection_name": collection_name} ) except Exception as e: @@ -141,6 +155,18 @@ async def adelete_collection(self, collection_name: str, include_runs=False): ###### # DATA SOURCE APIS ###### + async def aget_data_source_from_fqn(self, fqn: str) -> Optional[DataSource]: + try: + data_source: Optional[ + "PrismaDataSource" + ] = await self.db.datasource.find_first(where={"fqn": fqn}) + if data_source: + return DataSource.model_validate(data_source.model_dump()) + return None + except Exception as e: + logger.exception(f"Error: {e}") + raise HTTPException(status_code=500, detail=f"Error: {e}") + async def acreate_data_source(self, data_source: CreateDataSource) -> DataSource: try: existing_data_source = await self.aget_data_source_from_fqn(data_source.fqn) @@ -158,27 +184,19 @@ async def acreate_data_source(self, data_source: CreateDataSource) -> DataSource try: data = data_source.model_dump() data["metadata"] = json.dumps(data["metadata"]) - data_source = await self.db.datasource.create(data) + data_source: "PrismaDataSource" = await self.db.datasource.create(data) logger.info(f"Created data source: {data_source}") - return data_source + return DataSource.model_validate(data_source.model_dump()) except Exception as e: logger.exception(f"Failed to create data source: {e}") raise HTTPException(status_code=500, detail=f"Error: {e}") - async def aget_data_source_from_fqn(self, fqn: str) -> Optional[DataSource]: - try: - data_source = await self.db.datasource.find_first(where={"fqn": fqn}) - if data_source: - return data_source - return None - except Exception as e: - logger.exception(f"Error: {e}") - raise HTTPException(status_code=500, detail=f"Error: {e}") - async def aget_data_sources(self) -> List[DataSource]: try: - data_sources = await self.db.datasource.find_many() - return data_sources + data_sources: List[ + "PrismaDataSource" + ] = await self.db.datasource.find_many() + return [DataSource.model_validate(ds.model_dump()) for ds in data_sources] except Exception as e: logger.exception(f"Error: {e}") raise HTTPException(status_code=500, detail=f"Error: {e}") @@ -250,11 +268,19 @@ async def aassociate_data_source_with_collection( ) in existing_collection_associated_data_sources.items(): associated_data_sources[data_source_fqn] = data_source.model_dump() - updated_collection = await self.db.collection.update( + updated_collection: Optional[ + "PrismaCollection" + ] = await self.db.collection.update( where={"name": collection_name}, data={"associated_data_sources": json.dumps(associated_data_sources)}, ) - return updated_collection + if not updated_collection: + raise HTTPException( + status_code=404, + detail=f"Failed to associate data source with collection {collection_name!r}. 
" + f"No such record found", + ) + return Collection.model_validate(updated_collection.model_dump()) except Exception as e: logger.exception(f"Error: {e}") @@ -312,11 +338,19 @@ async def aunassociate_data_source_with_collection( associated_data_sources.pop(data_source_fqn, None) try: - updated_collection = await self.db.collection.update( + updated_collection: Optional[ + "PrismaCollection" + ] = await self.db.collection.update( where={"name": collection_name}, data={"associated_data_sources": json.dumps(associated_data_sources)}, ) - return updated_collection + if not updated_collection: + raise HTTPException( + status_code=404, + detail=f"Failed to unassociate data source from collection. " + f"No collection found with name {collection_name}", + ) + return Collection.model_validate(updated_collection.model_dump()) except Exception as e: logger.exception(f"Failed to unassociate data source with collection: {e}") raise HTTPException( @@ -326,7 +360,7 @@ async def aunassociate_data_source_with_collection( async def alist_data_sources( self, - ) -> List[dict[str, str]]: + ) -> List[Dict[str, str]]: try: data_sources = await self.aget_data_sources() return [data_source.model_dump() for data_source in data_sources] @@ -334,7 +368,7 @@ async def alist_data_sources( logger.exception(f"Failed to list data sources: {e}") raise HTTPException(status_code=500, detail="Failed to list data sources") - async def adelete_data_source(self, data_source_fqn: str): + async def adelete_data_source(self, data_source_fqn: str) -> None: if not settings.LOCAL: logger.error(f"Data source deletion is not allowed in local mode") raise HTTPException( @@ -379,7 +413,14 @@ async def adelete_data_source(self, data_source_fqn: str): # Delete the data source try: logger.info(f"Data source to delete: {data_source}") - await self.db.datasource.delete(where={"fqn": data_source.fqn}) + deleted_datasource: Optional[ + PrismaDataSource + ] = await self.db.datasource.delete(where={"fqn": data_source.fqn}) + if not deleted_datasource: + raise HTTPException( + status_code=404, + detail=f"Failed to delete data source {data_source.fqn!r}. 
No such record found", + ) # Delete the data from `/users_data` directory if data source is of type `localdir` if data_source.type == "localdir": data_source_uri = data_source.uri @@ -424,8 +465,10 @@ async def acreate_data_ingestion_run( try: run_data = created_data_ingestion_run.model_dump() run_data["parser_config"] = json.dumps(run_data["parser_config"]) - data_ingestion_run = await self.db.ingestionruns.create(data=run_data) - return DataIngestionRun(**data_ingestion_run.model_dump()) + data_ingestion_run: "PrismaDataIngestionRun" = ( + await self.db.ingestionruns.create(data=run_data) + ) + return DataIngestionRun.model_validate(data_ingestion_run.model_dump()) except Exception as e: logger.exception(f"Failed to create data ingestion run: {e}") raise HTTPException( @@ -436,12 +479,14 @@ async def aget_data_ingestion_run( self, data_ingestion_run_name: str, no_cache: bool = False ) -> Optional[DataIngestionRun]: try: - data_ingestion_run = await self.db.ingestionruns.find_first( + data_ingestion_run: Optional[ + "PrismaDataIngestionRun" + ] = await self.db.ingestionruns.find_first( where={"name": data_ingestion_run_name} ) logger.info(f"Data ingestion run: {data_ingestion_run}") if data_ingestion_run: - return DataIngestionRun(**data_ingestion_run.model_dump()) + return DataIngestionRun.model_validate(data_ingestion_run.model_dump()) return None except Exception as e: logger.exception(f"Failed to get data ingestion run: {e}") @@ -452,10 +497,15 @@ async def aget_data_ingestion_runs( ) -> List[DataIngestionRun]: """Get all data ingestion runs for a collection""" try: - data_ingestion_runs = await self.db.ingestionruns.find_many( + data_ingestion_runs: List[ + "PrismaDataIngestionRun" + ] = await self.db.ingestionruns.find_many( where={"collection_name": collection_name} ) - return data_ingestion_runs + return [ + DataIngestionRun.model_validate(data_ir.model_dump()) + for data_ir in data_ingestion_runs + ] except Exception as e: logger.exception(f"Failed to get data ingestion runs: {e}") raise HTTPException(status_code=500, detail=f"{e}") @@ -465,10 +515,20 @@ async def aupdate_data_ingestion_run_status( ) -> DataIngestionRun: """Update the status of a data ingestion run""" try: - updated_data_ingestion_run = await self.db.ingestionruns.update( + updated_data_ingestion_run: Optional[ + "PrismaDataIngestionRun" + ] = await self.db.ingestionruns.update( where={"name": data_ingestion_run_name}, data={"status": status} ) - return updated_data_ingestion_run + if not updated_data_ingestion_run: + raise HTTPException( + status_code=404, + detail=f"Failed to update ingestion run {data_ingestion_run_name!r}. 
No such record found", + ) + + return DataIngestionRun.model_validate( + updated_data_ingestion_run.model_dump() + ) except Exception as e: logger.exception(f"Failed to update data ingestion run status: {e}") raise HTTPException(status_code=500, detail=f"{e}") @@ -479,17 +539,24 @@ async def alog_metrics_for_data_ingestion_run( metric_dict: Dict[str, Union[int, float]], step: int = 0, ): - pass + raise NotImplementedError() async def alog_errors_for_data_ingestion_run( self, data_ingestion_run_name: str, errors: Dict[str, Any] - ): + ) -> None: """Log errors for the given data ingestion run""" try: - await self.db.ingestionruns.update( + updated_data_ingestion_run: Optional[ + "PrismaDataIngestionRun" + ] = await self.db.ingestionruns.update( where={"name": data_ingestion_run_name}, data={"errors": json.dumps(errors)}, ) + if not updated_data_ingestion_run: + raise HTTPException( + status_code=404, + detail=f"Failed to update ingestion run {data_ingestion_run_name!r}. No such record found", + ) except Exception as e: logger.exception( f"Failed to log errors data ingestion run {data_ingestion_run_name}: {e}" @@ -499,8 +566,22 @@ async def alog_errors_for_data_ingestion_run( ###### # RAG APPLICATION APIS ###### + async def aget_rag_app(self, app_name: str) -> Optional[RagApplication]: + """Get a RAG application from the metadata store""" + try: + rag_app: Optional[ + "PrismaRagApplication" + ] = await self.db.ragapps.find_first(where={"name": app_name}) + if rag_app: + return RagApplication.model_validate(rag_app.model_dump()) + return None + except Exception as e: + logger.exception(f"Failed to get RAG application by name: {e}") + raise HTTPException( + status_code=500, detail="Failed to get RAG application by name" + ) - async def acreate_rag_app(self, app: RagApplication) -> RagApplicationDto: + async def acreate_rag_app(self, app: RagApplication) -> RagApplication: """Create a RAG application in the metadata store""" try: existing_app = await self.aget_rag_app(app.name) @@ -519,29 +600,18 @@ async def acreate_rag_app(self, app: RagApplication) -> RagApplicationDto: logger.info(f"Creating RAG application: {app.model_dump()}") rag_app_data = app.model_dump() rag_app_data["config"] = json.dumps(rag_app_data["config"]) - rag_app = await self.db.ragapps.create(data=rag_app_data) - return rag_app + rag_app: "PrismaRagApplication" = await self.db.ragapps.create( + data=rag_app_data + ) + return RagApplication.model_validate(rag_app.model_dump()) except Exception as e: logger.exception(f"Error: {e}") raise HTTPException(status_code=500, detail=f"Error: {e}") - async def aget_rag_app(self, app_name: str) -> Optional[RagApplicationDto]: - """Get a RAG application from the metadata store""" - try: - rag_app = await self.db.ragapps.find_first(where={"name": app_name}) - if rag_app: - return rag_app - return None - except Exception as e: - logger.exception(f"Failed to get RAG application by name: {e}") - raise HTTPException( - status_code=500, detail="Failed to get RAG application by name" - ) - async def alist_rag_apps(self) -> List[str]: """List all RAG applications from the metadata store""" try: - rag_apps = await self.db.ragapps.find_many() + rag_apps: List["PrismaRagApplication"] = await self.db.ragapps.find_many() return [rag_app.name for rag_app in rag_apps] except Exception as e: logger.exception(f"Failed to list RAG applications: {e}") @@ -552,14 +622,14 @@ async def alist_rag_apps(self) -> List[str]: async def adelete_rag_app(self, app_name: str): """Delete a RAG application from the 
metadata store""" try: - rag_app = await self.aget_rag_app(app_name) - if not rag_app: - logger.debug(f"RAG application with name {app_name} does not exist") - except Exception as e: - logger.exception(e) - - try: - await self.db.ragapps.delete(where={"name": app_name}) + deleted_rag_app: Optional[ + "PrismaRagApplication" + ] = await self.db.ragapps.delete(where={"name": app_name}) + if not deleted_rag_app: + raise HTTPException( + status_code=404, + detail=f"Failed to delete RAG application {app_name!r}. No such record found", + ) except Exception as e: logger.exception(f"Failed to delete RAG application: {e}") raise HTTPException( diff --git a/backend/modules/query_controllers/example/types.py b/backend/modules/query_controllers/example/types.py index 11850a22..fb878308 100644 --- a/backend/modules/query_controllers/example/types.py +++ b/backend/modules/query_controllers/example/types.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Dict, Optional, Sequence +from typing import Any, ClassVar, Dict, Optional, Sequence, Union from pydantic import BaseModel, Field, model_validator from qdrant_client.models import Filter as QdrantFilter @@ -7,7 +7,7 @@ GENERATION_TIMEOUT_SEC = 60.0 * 10 -# TODO (chiragjn): Remove all asserts and replace them with proper raises +# TODO (chiragjn): Remove all asserts and replace them with proper pydantic validations or raises class VectorStoreRetrieverConfig(BaseModel): @@ -17,11 +17,12 @@ class VectorStoreRetrieverConfig(BaseModel): search_type: str = Field( default="similarity", - title="""Defines the type of search that the Retriever should perform. Can be 'similarity' (default), 'mmr', or 'similarity_score_threshold'. - - "similarity": Retrieve the top k most similar documents to the query., - - "mmr": Retrieve the top k most similar documents to the query and then rerank them using Maximal Marginal Relevance (MMR)., - - "similarity_score_threshold": Retrieve all documents with similarity score greater than a threshold. - """, + title="""Defines the type of search that the Retriever should perform. +Can be 'similarity' (default), 'mmr', or 'similarity_score_threshold'. + - "similarity": Retrieve the top k most similar documents to the query., + - "mmr": Retrieve the top k most similar documents to the query and then rerank them using Maximal Marginal Relevance (MMR)., + - "similarity_score_threshold": Retrieve all documents with similarity score greater than a threshold. +""", ) search_kwargs: dict = Field(default_factory=dict) @@ -104,9 +105,9 @@ class ExampleQueryInput(BaseModel): Requires a Sequence name, retriever configuration, query, LLM configuration and prompt template. """ - Sequence_name: str = Field( + collection_name: str = Field( default=None, - title="Sequence name on which to search", + title="Collection name on which to search", ) query: str = Field(title="Question to search for") @@ -122,7 +123,12 @@ class ExampleQueryInput(BaseModel): title="Retriever name", ) - retriever_config: Dict[str, Any] = Field( + retriever_config: Union[ + VectorStoreRetrieverConfig, + MultiQueryRetrieverConfig, + ContextualCompressionRetrieverConfig, + ContextualCompressionMultiQueryRetrieverConfig, + ] = Field( title="Retriever configuration", ) @@ -145,10 +151,6 @@ def validate_retriever_type(cls, values: Dict[str, Any]) -> Dict[str, Any]: retriever_name = values.get("retriever_name") - assert ( - retriever_name in cls.allowed_retriever_types - ), f"retriever of {retriever_name} not allowed. 
Valid values are: {cls.allowed_retriever_types}" - if retriever_name == "vectorstore": values["retriever_config"] = VectorStoreRetrieverConfig( **values.get("retriever_config") @@ -168,5 +170,10 @@ def validate_retriever_type(cls, values: Dict[str, Any]) -> Dict[str, Any]: values["retriever_config"] = ContextualCompressionMultiQueryRetrieverConfig( **values.get("retriever_config") ) + else: + raise ValueError( + f"Unexpected retriever name: {retriever_name}. " + f"Valid values are: {cls.allowed_retriever_types}" + ) return values diff --git a/backend/modules/query_controllers/multimodal/types.py b/backend/modules/query_controllers/multimodal/types.py index 11850a22..a86924d3 100644 --- a/backend/modules/query_controllers/multimodal/types.py +++ b/backend/modules/query_controllers/multimodal/types.py @@ -1,4 +1,4 @@ -from typing import Any, ClassVar, Dict, Optional, Sequence +from typing import Any, ClassVar, Dict, Optional, Sequence, Union from pydantic import BaseModel, Field, model_validator from qdrant_client.models import Filter as QdrantFilter @@ -7,7 +7,7 @@ GENERATION_TIMEOUT_SEC = 60.0 * 10 -# TODO (chiragjn): Remove all asserts and replace them with proper raises +# TODO (chiragjn): Remove all asserts and replace them with proper pydantic validations or raises class VectorStoreRetrieverConfig(BaseModel): @@ -104,7 +104,7 @@ class ExampleQueryInput(BaseModel): Requires a Sequence name, retriever configuration, query, LLM configuration and prompt template. """ - Sequence_name: str = Field( + collection_name: str = Field( default=None, title="Sequence name on which to search", ) @@ -118,11 +118,17 @@ class ExampleQueryInput(BaseModel): title="Prompt Template to use for generating answer to the question using the context", ) + # TODO (chiragjn): Move retriever name inside configuration to let pydantic disciminate between different retrievers retriever_name: str = Field( title="Retriever name", ) - retriever_config: Dict[str, Any] = Field( + retriever_config: Union[ + VectorStoreRetrieverConfig, + MultiQueryRetrieverConfig, + ContextualCompressionRetrieverConfig, + ContextualCompressionMultiQueryRetrieverConfig, + ] = Field( title="Retriever configuration", ) @@ -145,10 +151,6 @@ def validate_retriever_type(cls, values: Dict[str, Any]) -> Dict[str, Any]: retriever_name = values.get("retriever_name") - assert ( - retriever_name in cls.allowed_retriever_types - ), f"retriever of {retriever_name} not allowed. Valid values are: {cls.allowed_retriever_types}" - if retriever_name == "vectorstore": values["retriever_config"] = VectorStoreRetrieverConfig( **values.get("retriever_config") @@ -163,10 +165,14 @@ def validate_retriever_type(cls, values: Dict[str, Any]) -> Dict[str, Any]: values["retriever_config"] = ContextualCompressionRetrieverConfig( **values.get("retriever_config") ) - elif retriever_name == "contextual-compression-multi-query": values["retriever_config"] = ContextualCompressionMultiQueryRetrieverConfig( **values.get("retriever_config") ) + else: + raise ValueError( + f"Unexpected retriever name: {retriever_name}. 
" + f"Valid values are: {cls.allowed_retriever_types}" + ) return values diff --git a/backend/types.py b/backend/types.py index 5db777c9..f8547fdc 100644 --- a/backend/types.py +++ b/backend/types.py @@ -3,7 +3,15 @@ from enum import Enum from typing import Any, Dict, List, Literal, Optional -from pydantic import BaseModel, ConfigDict, Field, StringConstraints +from pydantic import ( + BaseModel, + ConfigDict, + Field, + StringConstraints, + computed_field, + model_serializer, + model_validator, +) from typing_extensions import Annotated from backend.constants import FQN_SEPARATOR @@ -108,7 +116,9 @@ class ModelType(str, Enum): class ModelConfig(BaseModel): name: str - type: ModelType + # TODO (chiragjn): This should not be Optional! Changing might break backward compatibility + # Problem is we have shared these entities between DTO layers and Service / DB layers + type: Optional[ModelType] = None parameters: Dict[str, Any] = Field(default_factory=dict) @@ -130,12 +140,19 @@ class EmbedderConfig(BaseModel): # This field will probably be removed soon or refactored embedding_model_config: ModelConfig = Field( - validation_alias="model_config", serialization_alias="model_config" + alias="model_config", ) config: Optional[Dict[str, Any]] = Field( title="Configuration for the embedder", default_factory=dict ) + @model_serializer + def serialize(self): + return { + "model_config": self.embedding_model_config, + "config": self.config, + } + class ParserConfig(BaseModel): """ @@ -302,8 +319,9 @@ class BaseDataSource(BaseModel): None, title="Additional config for your data source" ) + @computed_field @property - def fqn(self): + def fqn(self) -> str: return f"{FQN_SEPARATOR}".join([self.type, self.uri]) @@ -435,6 +453,15 @@ class Collection(BaseCollection): title="Data sources associated with the collection", default_factory=dict ) + @model_validator(mode="before") + @classmethod + def ensure_associated_data_sources_not_none( + cls, values: Dict[str, Any] + ) -> Dict[str, Any]: + if values.get("associated_data_sources") is None: + values["associated_data_sources"] = {} + return values + class CreateCollectionDto(CreateCollection): associated_data_sources: Optional[List[AssociateDataSourceWithCollection]] = Field( diff --git a/frontend/src/screens/dashboard/docsqa/NewCollection.tsx b/frontend/src/screens/dashboard/docsqa/NewCollection.tsx index 7e99f79a..96d0eabe 100644 --- a/frontend/src/screens/dashboard/docsqa/NewCollection.tsx +++ b/frontend/src/screens/dashboard/docsqa/NewCollection.tsx @@ -74,6 +74,8 @@ const NewCollection = ({ open, onClose, onSuccess }: NewCollectionProps) => { embedder_config: { model_config: { name: embeddingModel.name, + type: embeddingModel.type, + // TODO: pass parameters? 
}, }, chunk_size: chunkSize, From d7916c671949eb0fcfae286727b8900649a9305e Mon Sep 17 00:00:00 2001 From: Prathamesh Date: Fri, 9 Aug 2024 16:46:11 +0530 Subject: [PATCH 12/16] resolved conflicts --- backend/indexer/indexer.py | 2 +- backend/migration/qdrant_migration.py | 2 +- backend/migration/utils.py | 7 +- backend/modules/dataloaders/loader.py | 2 +- backend/modules/metadata_store/base.py | 21 +- backend/modules/metadata_store/prismastore.py | 279 +++++++++++------- backend/modules/metadata_store/truefoundry.py | 64 ++-- .../modules/model_gateway/model_gateway.py | 3 +- backend/modules/model_gateway/reranker_svc.py | 3 +- backend/modules/parsers/multimodalparser.py | 10 +- backend/modules/parsers/parser.py | 14 +- backend/modules/parsers/unstructured_io.py | 1 + .../query_controllers/example/controller.py | 4 +- .../query_controllers/example/types.py | 69 +++-- .../multimodal/controller.py | 2 +- .../query_controllers/multimodal/types.py | 62 ++-- .../query_controllers/query_controller.py | 2 +- backend/modules/vector_db/qdrant.py | 6 +- backend/modules/vector_db/singlestore.py | 6 +- backend/modules/vector_db/weaviate.py | 4 +- backend/requirements.txt | 7 +- backend/server/decorators.py | 14 +- backend/server/routers/collection.py | 14 +- backend/server/routers/data_source.py | 4 +- backend/server/routers/internal.py | 6 +- backend/server/routers/rag_apps.py | 6 +- backend/settings.py | 21 +- backend/types.py | 91 +++--- .../dashboard/docsqa/NewCollection.tsx | 2 +- 29 files changed, 450 insertions(+), 278 deletions(-) diff --git a/backend/indexer/indexer.py b/backend/indexer/indexer.py index b8161fcc..cfd9162e 100644 --- a/backend/indexer/indexer.py +++ b/backend/indexer/indexer.py @@ -307,7 +307,7 @@ async def ingest_data(request: IngestDataToCollectionDto): # convert to pydantic model if not already -> For prisma models if not isinstance(collection, Collection): - collection = Collection(**collection.dict()) + collection = Collection(**collection.model_dump()) if not collection: logger.error( diff --git a/backend/migration/qdrant_migration.py b/backend/migration/qdrant_migration.py index 1f535ea2..0b653a83 100644 --- a/backend/migration/qdrant_migration.py +++ b/backend/migration/qdrant_migration.py @@ -90,7 +90,7 @@ def migrate_collection( "associated_data_sources" ).items() ], - ).dict() + ).model_dump() logger.debug( f"Creating '{dest_collection.get('name')}' collection at destination" diff --git a/backend/migration/utils.py b/backend/migration/utils.py index d0198915..afc15732 100644 --- a/backend/migration/utils.py +++ b/backend/migration/utils.py @@ -2,7 +2,6 @@ from typing import Dict import requests -from qdrant_client._pydantic_compat import to_dict from qdrant_client.client_base import QdrantBase from qdrant_client.http import models from tqdm import tqdm @@ -104,11 +103,11 @@ def _recreate_collection( replication_factor=src_config.params.replication_factor, write_consistency_factor=src_config.params.write_consistency_factor, on_disk_payload=src_config.params.on_disk_payload, - hnsw_config=models.HnswConfigDiff(**to_dict(src_config.hnsw_config)), + hnsw_config=models.HnswConfigDiff(**src_config.hnsw_config.model_dump()), optimizers_config=models.OptimizersConfigDiff( - **to_dict(src_config.optimizer_config) + **src_config.optimizer_config.model_dump() ), - wal_config=models.WalConfigDiff(**to_dict(src_config.wal_config)), + wal_config=models.WalConfigDiff(**src_config.wal_config.model_dump()), quantization_config=src_config.quantization_config, timeout=300, ) diff --git 
a/backend/modules/dataloaders/loader.py b/backend/modules/dataloaders/loader.py index 656c4ed0..336ebf8f 100644 --- a/backend/modules/dataloaders/loader.py +++ b/backend/modules/dataloaders/loader.py @@ -128,7 +128,7 @@ def list_dataloaders(): Returns a list of all the registered loaders. Returns: - List[dict]: A list of all the registered loaders. + List[Dict]: A list of all the registered loaders. """ global LOADER_REGISTRY return [ diff --git a/backend/modules/metadata_store/base.py b/backend/modules/metadata_store/base.py index 222e5207..534e36ab 100644 --- a/backend/modules/metadata_store/base.py +++ b/backend/modules/metadata_store/base.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from backend.constants import DATA_POINT_FQN_METADATA_KEY, FQN_SEPARATOR from backend.types import ( @@ -18,7 +18,7 @@ from backend.utils import run_in_executor -# TODO(chiragjn): Ideal would be we make `async def a*` abstract methods and drop sync ones +# TODO (chiragjn): Ideal would be we make `async def a*` abstract methods and drop sync ones # Implementations can then opt to call their sync versions using run_in_executor class BaseMetadataStore(ABC): def __init__(self, *args, **kwargs): @@ -300,7 +300,7 @@ async def aupdate_data_ingestion_run_status( def log_metrics_for_data_ingestion_run( self, data_ingestion_run_name: str, - metric_dict: dict[str, int | float], + metric_dict: Dict[str, Union[int, float]], step: int = 0, ): """ @@ -311,7 +311,7 @@ def log_metrics_for_data_ingestion_run( async def alog_metrics_for_data_ingestion_run( self, data_ingestion_run_name: str, - metric_dict: dict[str, int | float], + metric_dict: Dict[str, Union[int, float]], step: int = 0, ): """ @@ -346,6 +346,8 @@ async def alog_errors_for_data_ingestion_run( errors=errors, ) + # TODO (chiragjn): What is the difference between get_collections and this? + # TODO (chiragjn): Return complete entities, why return only str? def list_collections( self, ) -> List[str]: @@ -354,6 +356,7 @@ def list_collections( """ raise NotImplementedError() + # TODO (chiragjn): Return complete entities, why return only str? async def alist_collections( self, ) -> List[str]: @@ -365,17 +368,19 @@ async def alist_collections( self.list_collections, ) + # TODO (chiragjn): Return complete entities, why return dict? def list_data_sources( self, - ) -> List[str]: + ) -> List[Dict[str, str]]: """ List all data source names from metadata store """ raise NotImplementedError() + # TODO (chiragjn): Return complete entities, why return dict? async def alist_data_sources( self, - ) -> List[str]: + ) -> List[Dict[str, str]]: """ List all data source names from metadata store """ @@ -410,7 +415,7 @@ async def acreate_rag_app(self, app: RagApplication) -> RagApplicationDto: """ return await run_in_executor(None, self.create_rag_app, app=app) - def get_rag_app(self, app_name: str) -> RagApplicationDto | None: + def get_rag_app(self, app_name: str) -> Optional[RagApplicationDto]: """ Get a RAG application from the metadata store by name """ @@ -422,12 +427,14 @@ async def aget_rag_app(self, app_name: str) -> Optional[RagApplicationDto]: """ return await run_in_executor(None, self.get_rag_app, app_name=app_name) + # TODO (chiragjn): Return complete entities, why return only str? def list_rag_apps(self) -> List[str]: """ List all RAG application names from metadata store """ raise NotImplementedError() + # TODO (chiragjn): Return complete entities, why return only str? 
async def alist_rag_apps(self) -> List[str]: """ List all RAG application names from metadata store diff --git a/backend/modules/metadata_store/prismastore.py b/backend/modules/metadata_store/prismastore.py index b15c514f..4ecd8dd2 100644 --- a/backend/modules/metadata_store/prismastore.py +++ b/backend/modules/metadata_store/prismastore.py @@ -4,7 +4,7 @@ import random import shutil import string -from typing import Any, Dict, List +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from fastapi import HTTPException from prisma import Prisma @@ -23,11 +23,26 @@ DataIngestionRunStatus, DataSource, RagApplication, - RagApplicationDto, ) +if TYPE_CHECKING: + # TODO (chiragjn): Can we import these safely even if the prisma client might not be generated yet? + from prisma.models import Collection as PrismaCollection + from prisma.models import DataSource as PrismaDataSource + from prisma.models import IngestionRuns as PrismaDataIngestionRun + from prisma.models import RagApps as PrismaRagApplication + +# TODO (chiragjn): +# - Use transactions! +# - Some methods are using json.dumps - not sure if this is the right way to send data via prisma client +# - primsa generates its own DB entity classes - ideally we should be using those instead of call +# .model_dump() on the pydantic objects. See prisma.models and prisma.actions +# + # TODO (chiragjn): Either we make everything async or add sync method to this + + class PrismaStore(BaseMetadataStore): def __init__(self, *args, db, **kwargs) -> None: self.db = db @@ -48,6 +63,22 @@ async def aconnect(cls, **kwargs): # COLLECTIONS APIS ###### + async def aget_collection_by_name( + self, collection_name: str, no_cache: bool = True + ) -> Optional[Collection]: + try: + collection: Optional[ + "PrismaCollection" + ] = await self.db.collection.find_first(where={"name": collection_name}) + if collection: + return Collection.model_validate(collection.model_dump()) + return None + except Exception as e: + logger.exception(f"Failed to get collection by name: {e}") + raise HTTPException( + status_code=500, detail="Failed to get collection by name" + ) + async def acreate_collection(self, collection: CreateCollection) -> Collection: try: existing_collection = await self.aget_collection_by_name(collection.name) @@ -63,42 +94,33 @@ async def acreate_collection(self, collection: CreateCollection) -> Collection: ) try: - logger.info(f"Creating collection: {collection.dict()}") - collection_data = collection.dict() + logger.info(f"Creating collection: {collection.model_dump()}") + collection_data = collection.model_dump() collection_data["embedder_config"] = json.dumps( collection_data["embedder_config"] ) - collection = await self.db.collection.create(data=collection_data) - return collection + collection: "PrismaCollection" = await self.db.collection.create( + data=collection_data + ) + return Collection.model_validate(collection.model_dump()) except Exception as e: logger.exception(f"Error: {e}") raise HTTPException(status_code=500, detail=f"Error: {e}") - async def aget_collection_by_name( - self, collection_name: str, no_cache: bool = True - ) -> Collection | None: - try: - collection = await self.db.collection.find_first( - where={"name": collection_name} - ) - if collection: - return collection - return None - except Exception as e: - logger.exception(f"Failed to get collection by name: {e}") - raise HTTPException( - status_code=500, detail="Failed to get collection by name" - ) - async def aget_retrieve_collection_by_name( self, 
collection_name: str, no_cache: bool = True - ) -> Collection | None: - return await self.aget_collection_by_name(collection_name, no_cache) + ) -> Optional[Collection]: + collection: "PrismaCollection" = await self.aget_collection_by_name( + collection_name, no_cache + ) + return Collection.model_validate(collection.model_dump()) async def aget_collections(self) -> List[Collection]: try: - collections = await self.db.collection.find_many(order={"id": "desc"}) - return collections + collections: List["PrismaCollection"] = await self.db.collection.find_many( + order={"id": "desc"} + ) + return [Collection.model_validate(c.model_dump()) for c in collections] except Exception as e: logger.exception(f"Failed to get collections: {e}") raise HTTPException(status_code=500, detail="Failed to get collections") @@ -113,17 +135,17 @@ async def alist_collections(self) -> List[str]: async def adelete_collection(self, collection_name: str, include_runs=False): try: - collection = await self.aget_collection_by_name(collection_name) - if not collection: - logger.debug(f"Collection with name {collection_name} does not exist") - except Exception as e: - logger.exception(e) - - try: - await self.db.collection.delete(where={"name": collection_name}) + deleted_collection: Optional[ + "PrismaCollection" + ] = await self.db.collection.delete(where={"name": collection_name}) + if not deleted_collection: + raise HTTPException( + status_code=404, + detail=f"Failed to delete collection {collection_name!r}. No such record found", + ) if include_runs: try: - await self.db.ingestionruns.delete_many( + _deleted_count = await self.db.ingestionruns.delete_many( where={"collection_name": collection_name} ) except Exception as e: @@ -135,6 +157,18 @@ async def adelete_collection(self, collection_name: str, include_runs=False): ###### # DATA SOURCE APIS ###### + async def aget_data_source_from_fqn(self, fqn: str) -> Optional[DataSource]: + try: + data_source: Optional[ + "PrismaDataSource" + ] = await self.db.datasource.find_first(where={"fqn": fqn}) + if data_source: + return DataSource.model_validate(data_source.model_dump()) + return None + except Exception as e: + logger.exception(f"Error: {e}") + raise HTTPException(status_code=500, detail=f"Error: {e}") + async def acreate_data_source(self, data_source: CreateDataSource) -> DataSource: try: existing_data_source = await self.aget_data_source_from_fqn(data_source.fqn) @@ -150,29 +184,21 @@ async def acreate_data_source(self, data_source: CreateDataSource) -> DataSource ) try: - data = data_source.dict() + data = data_source.model_dump() data["metadata"] = json.dumps(data["metadata"]) - data_source = await self.db.datasource.create(data) + data_source: "PrismaDataSource" = await self.db.datasource.create(data) logger.info(f"Created data source: {data_source}") - return data_source + return DataSource.model_validate(data_source.model_dump()) except Exception as e: logger.exception(f"Failed to create data source: {e}") raise HTTPException(status_code=500, detail=f"Error: {e}") - async def aget_data_source_from_fqn(self, fqn: str) -> DataSource | None: - try: - data_source = await self.db.datasource.find_first(where={"fqn": fqn}) - if data_source: - return data_source - return None - except Exception as e: - logger.exception(f"Error: {e}") - raise HTTPException(status_code=500, detail=f"Error: {e}") - async def aget_data_sources(self) -> List[DataSource]: try: - data_sources = await self.db.datasource.find_many(order={"id": "desc"}) - return data_sources + data_sources: 
List["PrismaDataSource"] = await self.db.datasource.find_many( + order={"id": "desc"} + ) + return [DataSource.model_validate(ds.model_dump()) for ds in data_sources] except Exception as e: logger.exception(f"Error: {e}") raise HTTPException(status_code=500, detail=f"Error: {e}") @@ -242,13 +268,21 @@ async def aassociate_data_source_with_collection( data_source_fqn, data_source, ) in existing_collection_associated_data_sources.items(): - associated_data_sources[data_source_fqn] = data_source.dict() + associated_data_sources[data_source_fqn] = data_source.model_dump() - updated_collection = await self.db.collection.update( + updated_collection: Optional[ + "PrismaCollection" + ] = await self.db.collection.update( where={"name": collection_name}, data={"associated_data_sources": json.dumps(associated_data_sources)}, ) - return updated_collection + if not updated_collection: + raise HTTPException( + status_code=404, + detail=f"Failed to associate data source with collection {collection_name!r}. " + f"No such record found", + ) + return Collection.model_validate(updated_collection.model_dump()) except Exception as e: logger.exception(f"Error: {e}") @@ -306,11 +340,19 @@ async def aunassociate_data_source_with_collection( associated_data_sources.pop(data_source_fqn, None) try: - updated_collection = await self.db.collection.update( + updated_collection: Optional[ + "PrismaCollection" + ] = await self.db.collection.update( where={"name": collection_name}, data={"associated_data_sources": json.dumps(associated_data_sources)}, ) - return updated_collection + if not updated_collection: + raise HTTPException( + status_code=404, + detail=f"Failed to unassociate data source from collection. " + f"No collection found with name {collection_name}", + ) + return Collection.model_validate(updated_collection.model_dump()) except Exception as e: logger.exception(f"Failed to unassociate data source with collection: {e}") raise HTTPException( @@ -320,15 +362,15 @@ async def aunassociate_data_source_with_collection( async def alist_data_sources( self, - ) -> List[dict[str, str]]: + ) -> List[Dict[str, str]]: try: data_sources = await self.aget_data_sources() - return [data_source.dict() for data_source in data_sources] + return [data_source.model_dump() for data_source in data_sources] except Exception as e: logger.exception(f"Failed to list data sources: {e}") raise HTTPException(status_code=500, detail="Failed to list data sources") - async def adelete_data_source(self, data_source_fqn: str): + async def adelete_data_source(self, data_source_fqn: str) -> None: if not settings.LOCAL: logger.error(f"Data source deletion is not allowed in local mode") raise HTTPException( @@ -373,7 +415,14 @@ async def adelete_data_source(self, data_source_fqn: str): # Delete the data source try: logger.info(f"Data source to delete: {data_source}") - await self.db.datasource.delete(where={"fqn": data_source.fqn}) + deleted_datasource: Optional[ + PrismaDataSource + ] = await self.db.datasource.delete(where={"fqn": data_source.fqn}) + if not deleted_datasource: + raise HTTPException( + status_code=404, + detail=f"Failed to delete data source {data_source.fqn!r}. 
No such record found", + ) # Delete the data from `/users_data` directory if data source is of type `localdir` if data_source.type == "localdir": data_source_uri = data_source.uri @@ -416,10 +465,12 @@ async def acreate_data_ingestion_run( ) try: - run_data = created_data_ingestion_run.dict() + run_data = created_data_ingestion_run.model_dump() run_data["parser_config"] = json.dumps(run_data["parser_config"]) - data_ingestion_run = await self.db.ingestionruns.create(data=run_data) - return DataIngestionRun(**data_ingestion_run.dict()) + data_ingestion_run: "PrismaDataIngestionRun" = ( + await self.db.ingestionruns.create(data=run_data) + ) + return DataIngestionRun.model_validate(data_ingestion_run.model_dump()) except Exception as e: logger.exception(f"Failed to create data ingestion run: {e}") raise HTTPException( @@ -428,14 +479,16 @@ async def acreate_data_ingestion_run( async def aget_data_ingestion_run( self, data_ingestion_run_name: str, no_cache: bool = False - ) -> DataIngestionRun | None: + ) -> Optional[DataIngestionRun]: try: - data_ingestion_run = await self.db.ingestionruns.find_first( + data_ingestion_run: Optional[ + "PrismaDataIngestionRun" + ] = await self.db.ingestionruns.find_first( where={"name": data_ingestion_run_name} ) logger.info(f"Data ingestion run: {data_ingestion_run}") if data_ingestion_run: - return DataIngestionRun(**data_ingestion_run.dict()) + return DataIngestionRun.model_validate(data_ingestion_run.model_dump()) return None except Exception as e: logger.exception(f"Failed to get data ingestion run: {e}") @@ -446,10 +499,15 @@ async def aget_data_ingestion_runs( ) -> List[DataIngestionRun]: """Get all data ingestion runs for a collection""" try: - data_ingestion_runs = await self.db.ingestionruns.find_many( + data_ingestion_runs: List[ + "PrismaDataIngestionRun" + ] = await self.db.ingestionruns.find_many( where={"collection_name": collection_name}, order={"id": "desc"} ) - return data_ingestion_runs + return [ + DataIngestionRun.model_validate(data_ir.model_dump()) + for data_ir in data_ingestion_runs + ] except Exception as e: logger.exception(f"Failed to get data ingestion runs: {e}") raise HTTPException(status_code=500, detail=f"{e}") @@ -459,10 +517,20 @@ async def aupdate_data_ingestion_run_status( ) -> DataIngestionRun: """Update the status of a data ingestion run""" try: - updated_data_ingestion_run = await self.db.ingestionruns.update( + updated_data_ingestion_run: Optional[ + "PrismaDataIngestionRun" + ] = await self.db.ingestionruns.update( where={"name": data_ingestion_run_name}, data={"status": status} ) - return updated_data_ingestion_run + if not updated_data_ingestion_run: + raise HTTPException( + status_code=404, + detail=f"Failed to update ingestion run {data_ingestion_run_name!r}. 
No such record found", + ) + + return DataIngestionRun.model_validate( + updated_data_ingestion_run.model_dump() + ) except Exception as e: logger.exception(f"Failed to update data ingestion run status: {e}") raise HTTPException(status_code=500, detail=f"{e}") @@ -470,20 +538,27 @@ async def aupdate_data_ingestion_run_status( async def alog_metrics_for_data_ingestion_run( self, data_ingestion_run_name: str, - metric_dict: dict[str, int | float], + metric_dict: Dict[str, Union[int, float]], step: int = 0, ): - pass + raise NotImplementedError() async def alog_errors_for_data_ingestion_run( self, data_ingestion_run_name: str, errors: Dict[str, Any] - ): + ) -> None: """Log errors for the given data ingestion run""" try: - await self.db.ingestionruns.update( + updated_data_ingestion_run: Optional[ + "PrismaDataIngestionRun" + ] = await self.db.ingestionruns.update( where={"name": data_ingestion_run_name}, data={"errors": json.dumps(errors)}, ) + if not updated_data_ingestion_run: + raise HTTPException( + status_code=404, + detail=f"Failed to update ingestion run {data_ingestion_run_name!r}. No such record found", + ) except Exception as e: logger.exception( f"Failed to log errors data ingestion run {data_ingestion_run_name}: {e}" @@ -493,9 +568,22 @@ async def alog_errors_for_data_ingestion_run( ###### # RAG APPLICATION APIS ###### + async def aget_rag_app(self, app_name: str) -> Optional[RagApplication]: + """Get a RAG application from the metadata store""" + try: + rag_app: Optional[ + "PrismaRagApplication" + ] = await self.db.ragapps.find_first(where={"name": app_name}) + if rag_app: + return RagApplication.model_validate(rag_app.model_dump()) + return None + except Exception as e: + logger.exception(f"Failed to get RAG application by name: {e}") + raise HTTPException( + status_code=500, detail="Failed to get RAG application by name" + ) - # TODO (prathamesh): Implement these methods - async def acreate_rag_app(self, app: RagApplication) -> RagApplicationDto: + async def acreate_rag_app(self, app: RagApplication) -> RagApplication: """Create a RAG application in the metadata store""" try: existing_app = await self.aget_rag_app(app.name) @@ -511,32 +599,21 @@ async def acreate_rag_app(self, app: RagApplication) -> RagApplicationDto: ) try: - logger.info(f"Creating RAG application: {app.dict()}") - rag_app_data = app.dict() + logger.info(f"Creating RAG application: {app.model_dump()}") + rag_app_data = app.model_dump() rag_app_data["config"] = json.dumps(rag_app_data["config"]) - rag_app = await self.db.ragapps.create(data=rag_app_data) - return rag_app + rag_app: "PrismaRagApplication" = await self.db.ragapps.create( + data=rag_app_data + ) + return RagApplication.model_validate(rag_app.model_dump()) except Exception as e: logger.exception(f"Error: {e}") raise HTTPException(status_code=500, detail=f"Error: {e}") - async def aget_rag_app(self, app_name: str) -> RagApplicationDto | None: - """Get a RAG application from the metadata store""" - try: - rag_app = await self.db.ragapps.find_first(where={"name": app_name}) - if rag_app: - return rag_app - return None - except Exception as e: - logger.exception(f"Failed to get RAG application by name: {e}") - raise HTTPException( - status_code=500, detail="Failed to get RAG application by name" - ) - async def alist_rag_apps(self) -> List[str]: """List all RAG applications from the metadata store""" try: - rag_apps = await self.db.ragapps.find_many() + rag_apps: List["PrismaRagApplication"] = await self.db.ragapps.find_many() return [rag_app.name 
for rag_app in rag_apps] except Exception as e: logger.exception(f"Failed to list RAG applications: {e}") @@ -547,14 +624,14 @@ async def alist_rag_apps(self) -> List[str]: async def adelete_rag_app(self, app_name: str): """Delete a RAG application from the metadata store""" try: - rag_app = await self.aget_rag_app(app_name) - if not rag_app: - logger.debug(f"RAG application with name {app_name} does not exist") - except Exception as e: - logger.exception(e) - - try: - await self.db.ragapps.delete(where={"name": app_name}) + deleted_rag_app: Optional[ + "PrismaRagApplication" + ] = await self.db.ragapps.delete(where={"name": app_name}) + if not deleted_rag_app: + raise HTTPException( + status_code=404, + detail=f"Failed to delete RAG application {app_name!r}. No such record found", + ) except Exception as e: logger.exception(f"Failed to delete RAG application: {e}") raise HTTPException( diff --git a/backend/modules/metadata_store/truefoundry.py b/backend/modules/metadata_store/truefoundry.py index 671d2d5a..b47b0fb1 100644 --- a/backend/modules/metadata_store/truefoundry.py +++ b/backend/modules/metadata_store/truefoundry.py @@ -3,7 +3,7 @@ import os import tempfile import warnings -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Union import mlflow from fastapi import HTTPException @@ -38,7 +38,7 @@ class MLRunTypes(str, enum.Enum): class TrueFoundry(BaseMetadataStore): - ml_runs: dict[str, ml.MlFoundryRun] = {} + ml_runs: Dict[str, ml.MlFoundryRun] = {} CONSTANT_DATA_SOURCE_RUN_NAME = "tfy-datasource" def __init__(self, *args, ml_repo_name: str, **kwargs): @@ -58,7 +58,7 @@ def __init__(self, *args, ml_repo_name: str, **kwargs): def _get_run_by_name( self, run_name: str, no_cache: bool = False - ) -> ml.MlFoundryRun | None: + ) -> Optional[ml.MlFoundryRun]: """ Cache the runs to avoid too many requests to the backend. """ @@ -113,7 +113,7 @@ def create_collection(self, collection: CreateCollection) -> Collection: embedder_config=collection.embedder_config, ) self._save_entity_to_run( - run=run, metadata=created_collection.dict(), params=params + run=run, metadata=created_collection.model_dump(), params=params ) run.end() logger.debug(f"[Metadata Store] Collection Saved") @@ -134,7 +134,7 @@ def _get_entity_from_run( def _get_artifact_metadata_ml_run( self, run: ml.MlFoundryRun - ) -> ml.ArtifactVersion | None: + ) -> Optional[ml.ArtifactVersion]: params = run.get_params() metadata_artifact_fqn = params.get("metadata_artifact_fqn") if not metadata_artifact_fqn: @@ -177,7 +177,7 @@ def _update_entity_in_run( def get_collection_by_name( self, collection_name: str, no_cache: bool = True - ) -> Collection | None: + ) -> Optional[Collection]: """Get collection from given collection name.""" logger.debug(f"[Metadata Store] Getting collection with name {collection_name}") ml_run = self._get_run_by_name(run_name=collection_name, no_cache=no_cache) @@ -187,14 +187,14 @@ def get_collection_by_name( ) return None collection = self._populate_collection( - Collection.parse_obj(self._get_entity_from_run(run=ml_run)) + Collection.model_validate(self._get_entity_from_run(run=ml_run)) ) logger.debug(f"[Metadata Store] Fetched collection with name {collection_name}") return collection def get_retrieve_collection_by_name( self, collection_name: str, no_cache: bool = True - ) -> Collection | None: + ) -> Optional[Collection]: """Get collection from given collection name. 
Used during retrieval""" logger.debug(f"[Metadata Store] Getting collection with name {collection_name}") ml_run = self._get_run_by_name(run_name=collection_name, no_cache=no_cache) @@ -203,7 +203,7 @@ def get_retrieve_collection_by_name( f"[Metadata Store] Collection with name {collection_name} not found" ) return None - collection = Collection.parse_obj(self._get_entity_from_run(run=ml_run)) + collection = Collection.model_validate(self._get_entity_from_run(run=ml_run)) logger.debug(f"[Metadata Store] Fetched collection with name {collection_name}") return collection @@ -216,7 +216,9 @@ def get_collections(self) -> List[Collection]: ) collections = [] for ml_run in ml_runs: - collection = Collection.parse_obj(self._get_entity_from_run(run=ml_run)) + collection = Collection.model_validate( + self._get_entity_from_run(run=ml_run) + ) collections.append(self._populate_collection(collection)) logger.debug(f"[Metadata Store] Listed {len(collections)} collections") return collections @@ -262,7 +264,9 @@ def associate_data_source_with_collection( f"data source with fqn {data_source_association.data_source_fqn} not found", ) # Always do this to avoid race conditions - collection = Collection.parse_obj(self._get_entity_from_run(run=collection_run)) + collection = Collection.model_validate( + self._get_entity_from_run(run=collection_run) + ) associated_data_source = AssociatedDataSources( data_source_fqn=data_source_association.data_source_fqn, parser_config=data_source_association.parser_config, @@ -271,7 +275,7 @@ def associate_data_source_with_collection( data_source_association.data_source_fqn ] = associated_data_source - self._update_entity_in_run(run=collection_run, metadata=collection.dict()) + self._update_entity_in_run(run=collection_run, metadata=collection.model_dump()) logger.debug( f"[Metadata Store] Associated data_source {data_source_association.data_source_fqn} " f"to collection {collection_name}" @@ -296,9 +300,11 @@ def unassociate_data_source_with_collection( f"Collection {collection_name} not found.", ) # Always do this to avoid run conditions - collection = Collection.parse_obj(self._get_entity_from_run(run=collection_run)) + collection = Collection.model_validate( + self._get_entity_from_run(run=collection_run) + ) collection.associated_data_sources.pop(data_source_fqn) - self._update_entity_in_run(run=collection_run, metadata=collection.dict()) + self._update_entity_in_run(run=collection_run, metadata=collection.model_dump()) logger.debug( f"[Metadata Store] Unassociated data_source {data_source_fqn} to collection {collection_name}" ) @@ -331,7 +337,7 @@ def create_data_source(self, data_source: CreateDataSource) -> DataSource: metadata=data_source.metadata, ) self._save_entity_to_run( - run=run, metadata=created_data_source.dict(), params=params + run=run, metadata=created_data_source.model_dump(), params=params ) run.end() logger.debug( @@ -339,14 +345,14 @@ def create_data_source(self, data_source: CreateDataSource) -> DataSource: ) return created_data_source - def get_data_source_from_fqn(self, fqn: str) -> DataSource | None: + def get_data_source_from_fqn(self, fqn: str) -> Optional[DataSource]: logger.debug(f"[Metadata Store] Getting data_source by fqn {fqn}") runs = self.client.search_runs( ml_repo=self.ml_repo_name, filter_string=f"params.entity_type = '{MLRunTypes.DATA_SOURCE.value}' and params.data_source_fqn = '{fqn}'", ) for run in runs: - data_source = DataSource.parse_obj(self._get_entity_from_run(run=run)) + data_source = 
DataSource.model_validate(self._get_entity_from_run(run=run)) logger.debug(f"[Metadata Store] Fetched Data Source with fqn {fqn}") return data_source logger.debug(f"[Metadata Store] Data Source with fqn {fqn} not found") @@ -360,7 +366,7 @@ def get_data_sources(self) -> List[DataSource]: ) data_sources: List[DataSource] = [] for run in runs: - data_source = DataSource.parse_obj(self._get_entity_from_run(run=run)) + data_source = DataSource.model_validate(self._get_entity_from_run(run=run)) data_sources.append(data_source) logger.debug(f"[Metadata Store] Listed {len(data_sources)} data sources") return data_sources @@ -395,7 +401,7 @@ def create_data_ingestion_run( status=DataIngestionRunStatus.INITIALIZED, ) self._save_entity_to_run( - run=run, metadata=created_data_ingestion_run.dict(), params=params + run=run, metadata=created_data_ingestion_run.model_dump(), params=params ) run.end() logger.debug( @@ -406,7 +412,7 @@ def create_data_ingestion_run( def get_data_ingestion_run( self, data_ingestion_run_name: str, no_cache: bool = False - ) -> DataIngestionRun | None: + ) -> Optional[DataIngestionRun]: logger.debug( f"[Metadata Store] Getting ingestion run {data_ingestion_run_name}" ) @@ -416,7 +422,7 @@ def get_data_ingestion_run( f"[Metadata Store] Ingestion run with name {data_ingestion_run_name} not found" ) return None - data_ingestion_run = DataIngestionRun.parse_obj( + data_ingestion_run = DataIngestionRun.model_validate( self._get_entity_from_run(run=run) ) run_tags = run.get_tags() @@ -447,7 +453,7 @@ def get_data_ingestion_runs( ) data_ingestion_runs: List[DataIngestionRun] = [] for run in runs: - data_ingestion_run = DataIngestionRun.parse_obj( + data_ingestion_run = DataIngestionRun.model_validate( self._get_entity_from_run(run=run) ) run_tags = run.get_tags() @@ -512,7 +518,7 @@ def update_data_ingestion_run_status( def log_metrics_for_data_ingestion_run( self, data_ingestion_run_name: str, - metric_dict: dict[str, int | float], + metric_dict: Dict[str, Union[int, float]], step: int = 0, ): try: @@ -566,7 +572,7 @@ def list_collections(self) -> List[str]: ) return [run.run_name for run in ml_runs] - def list_data_sources(self) -> List[dict[str, str]]: + def list_data_sources(self) -> List[Dict[str, str]]: logger.info(f"[Metadata Store] Listing all data sources") ml_runs = self.client.search_runs( ml_repo=self.ml_repo_name, @@ -624,12 +630,14 @@ def create_rag_app(self, app: RagApplication) -> RagApplicationDto: name=app.name, config=app.config, ) - self._save_entity_to_run(run=run, metadata=created_app.dict(), params=params) + self._save_entity_to_run( + run=run, metadata=created_app.model_dump(), params=params + ) run.end() logger.debug(f"[Metadata Store] RAG Application Saved") return created_app - def get_rag_app(self, app_name: str) -> RagApplicationDto | None: + def get_rag_app(self, app_name: str) -> Optional[RagApplicationDto]: """ Get a RAG application from the metadata store """ @@ -641,7 +649,9 @@ def get_rag_app(self, app_name: str) -> RagApplicationDto | None: ) return None try: - app = RagApplicationDto.parse_obj(self._get_entity_from_run(run=ml_run)) + app = RagApplicationDto.model_validate( + self._get_entity_from_run(run=ml_run) + ) logger.debug( f"[Metadata Store] Fetched RAG application with name {app_name}" ) diff --git a/backend/modules/model_gateway/model_gateway.py b/backend/modules/model_gateway/model_gateway.py index e94eb77c..ecea699e 100644 --- a/backend/modules/model_gateway/model_gateway.py +++ b/backend/modules/model_gateway/model_gateway.py @@ -1,4 
+1,3 @@ -import json import os from typing import List @@ -28,7 +27,7 @@ def __init__(self): # parse the json data into a list of ModelProviderConfig objects self.provider_configs = [ - ModelProviderConfig.parse_obj(item) for item in _providers + ModelProviderConfig.model_validate(item) for item in _providers ] # load llm models diff --git a/backend/modules/model_gateway/reranker_svc.py b/backend/modules/model_gateway/reranker_svc.py index 370677c8..32a85043 100644 --- a/backend/modules/model_gateway/reranker_svc.py +++ b/backend/modules/model_gateway/reranker_svc.py @@ -6,14 +6,13 @@ from langchain.retrievers.document_compressors.base import BaseDocumentCompressor from backend.logger import logger -from backend.settings import settings # Reranking Service using Infinity API class InfinityRerankerSvc(BaseDocumentCompressor): """ Reranker Service that uses Infinity API - Github: https://github.com/michaelfeil/infinity + GitHub: https://github.com/michaelfeil/infinity """ model: str diff --git a/backend/modules/parsers/multimodalparser.py b/backend/modules/parsers/multimodalparser.py index 7adf2fc4..7e1b0ae1 100644 --- a/backend/modules/parsers/multimodalparser.py +++ b/backend/modules/parsers/multimodalparser.py @@ -3,7 +3,7 @@ import io import os from itertools import islice -from typing import Optional +from typing import Any, Dict, Optional import cv2 import fitz @@ -67,8 +67,10 @@ def __init__( """ # Multi-modal parser needs to be configured with the openai compatible client url and vision model - if model_configuration: - self.model_configuration = ModelConfig.parse_obj(model_configuration) + if "model_configuration" in additional_config: + self.model_configuration = ModelConfig.model_validate( + additional_config["model_configuration"] + ) logger.info(f"Using custom vision model..., {self.model_configuration}") else: # Truefoundry specific model configuration @@ -131,7 +133,7 @@ async def call_vlm_agent( return {"error": f"Error in page: {page_number}"} async def get_chunks( - self, filepath: str, metadata: Optional[dict] = None, **kwargs + self, filepath: str, metadata: Optional[Dict[Any, Any]] = None, *args, **kwargs ): """ Asynchronously extracts text from a PDF file and returns it in chunks. diff --git a/backend/modules/parsers/parser.py b/backend/modules/parsers/parser.py index e230b004..37c28849 100644 --- a/backend/modules/parsers/parser.py +++ b/backend/modules/parsers/parser.py @@ -1,7 +1,6 @@ -import typing from abc import ABC, abstractmethod from collections import defaultdict -from typing import Optional +from typing import Any, Dict, List, Optional from langchain.docstore.document import Document @@ -39,9 +38,10 @@ def __init__(self, **kwargs): async def get_chunks( self, filepath: str, - metadata: Optional[dict], + metadata: Optional[Dict[Any, Any]], + *args, **kwargs, - ) -> typing.List[Document]: + ) -> List[Document]: """ Abstract method. This should asynchronously read a file and return its content in chunks. @@ -54,7 +54,9 @@ async def get_chunks( pass -def get_parser_for_extension(file_extension, parsers_map, **kwargs) -> BaseParser: +def get_parser_for_extension( + file_extension, parsers_map, *args, **kwargs +) -> Optional[BaseParser]: """ During the indexing phase, given the file_extension and parsers mapping, return the appropriate mapper. If no mapping is given, use the default registry. @@ -94,7 +96,7 @@ def list_parsers(): Returns a list of all the registered parsers. Returns: - List[dict]: A list of all the registered parsers. 
+ List[Dict]: A list of all the registered parsers. """ global PARSER_REGISTRY return [ diff --git a/backend/modules/parsers/unstructured_io.py b/backend/modules/parsers/unstructured_io.py index 179df881..4d8ce2ba 100644 --- a/backend/modules/parsers/unstructured_io.py +++ b/backend/modules/parsers/unstructured_io.py @@ -51,6 +51,7 @@ def __init__(self, *, max_chunk_size: int = 2000, **kwargs): self.adapter = HTTPAdapter(max_retries=self.retry_strategy) self.session.mount("https://", self.adapter) self.session.mount("http://", self.adapter) + super().__init__(*args, **kwargs) super().__init__(**kwargs) diff --git a/backend/modules/query_controllers/example/controller.py b/backend/modules/query_controllers/example/controller.py index 7db59a42..f4599afb 100644 --- a/backend/modules/query_controllers/example/controller.py +++ b/backend/modules/query_controllers/example/controller.py @@ -86,7 +86,7 @@ async def _get_vector_store(self, collection_name: str): raise HTTPException(status_code=404, detail="Collection not found") if not isinstance(collection, Collection): - collection = Collection(**collection.dict()) + collection = Collection(**collection.model_dump()) return VECTOR_STORE_CLIENT.get_vector_store( collection_name=collection.name, @@ -329,7 +329,7 @@ async def answer( # "stream": True # } -# data = ExampleQueryInput(**payload).dict() +# data = ExampleQueryInput(**payload).model_dump() # ENDPOINT_URL = 'http://localhost:8000/retrievers/example-app/answer' diff --git a/backend/modules/query_controllers/example/types.py b/backend/modules/query_controllers/example/types.py index 6226b7aa..fb45fa77 100644 --- a/backend/modules/query_controllers/example/types.py +++ b/backend/modules/query_controllers/example/types.py @@ -1,13 +1,14 @@ -from typing import Any, ClassVar, Collection, Dict, List, Optional +from typing import Any, ClassVar, Dict, List, Optional, Sequence, Union -from langchain.docstore.document import Document -from pydantic import BaseModel, Field, root_validator +from pydantic import BaseModel, Field, model_validator from qdrant_client.models import Filter as QdrantFilter from backend.types import ModelConfig GENERATION_TIMEOUT_SEC = 60.0 * 10 +# TODO (chiragjn): Remove all asserts and replace them with proper pydantic validations or raises + class VectorStoreRetrieverConfig(BaseModel): """ @@ -16,29 +17,36 @@ class VectorStoreRetrieverConfig(BaseModel): search_type: str = Field( default="similarity", - title="""Defines the type of search that the Retriever should perform. Can be 'similarity' (default), 'mmr', or 'similarity_score_threshold'. - - "similarity": Retrieve the top k most similar documents to the query., - - "mmr": Retrieve the top k most similar documents to the query and then rerank them using Maximal Marginal Relevance (MMR)., - - "similarity_score_threshold": Retrieve all documents with similarity score greater than a threshold. - """, + title="""Defines the type of search that the Retriever should perform. +Can be 'similarity' (default), 'mmr', or 'similarity_score_threshold'. + - "similarity": Retrieve the top k most similar documents to the query., + - "mmr": Retrieve the top k most similar documents to the query and then rerank them using Maximal Marginal Relevance (MMR)., + - "similarity_score_threshold": Retrieve all documents with similarity score greater than a threshold. 
+""", ) search_kwargs: dict = Field(default_factory=dict) - filter: Optional[dict] = Field( + filter: Optional[Dict[Any, Any]] = Field( default_factory=dict, title="""Filter by document metadata""", ) - allowed_search_types: ClassVar[Collection[str]] = ( + allowed_search_types: ClassVar[Sequence[str]] = ( "similarity", "similarity_score_threshold", "mmr", ) - @root_validator - def validate_search_type(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def validate_search_type(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Validate search type.""" + if not isinstance(values, dict): + raise ValueError( + f"Unexpected Pydantic v2 Validation: values are of type {type(values)}" + ) + search_type = values.get("search_type") assert ( @@ -63,7 +71,7 @@ def validate_search_type(cls, values: Dict) -> Dict: filters = values.get("filter") if filters: - search_kwargs["filter"] = QdrantFilter.parse_obj(filters) + search_kwargs["filter"] = QdrantFilter.model_validate(filters) return values @@ -82,7 +90,7 @@ class ContextualCompressionRetrieverConfig(VectorStoreRetrieverConfig): title="Top K docs to collect post compression", ) - allowed_compressor_model_providers: ClassVar[Collection[str]] + allowed_compressor_model_providers: ClassVar[Sequence[str]] class ContextualCompressionMultiQueryRetrieverConfig( @@ -94,7 +102,7 @@ class ContextualCompressionMultiQueryRetrieverConfig( class ExampleQueryInput(BaseModel): """ Model for Query input. - Requires a collection name, retriever configuration, query, LLM configuration and prompt template. + Requires a Sequence name, retriever configuration, query, LLM configuration and prompt template. """ collection_name: str = Field( @@ -104,6 +112,7 @@ class ExampleQueryInput(BaseModel): query: str = Field(title="Question to search for") + # TODO (chiragjn): pydantic v2 does not like fields that start with model_ model_configuration: ModelConfig prompt_template: str = Field( @@ -114,26 +123,33 @@ class ExampleQueryInput(BaseModel): title="Retriever name", ) - retriever_config: Dict[str, Any] = Field( + retriever_config: Union[ + VectorStoreRetrieverConfig, + MultiQueryRetrieverConfig, + ContextualCompressionRetrieverConfig, + ContextualCompressionMultiQueryRetrieverConfig, + ] = Field( title="Retriever configuration", ) - allowed_retriever_types: ClassVar[Collection[str]] = ( + allowed_retriever_types: ClassVar[Sequence[str]] = ( "vectorstore", "multi-query", "contextual-compression", "contextual-compression-multi-query", ) - stream: Optional[bool] = Field(title="Stream the results", default=False) + stream: bool = Field(title="Stream the results", default=False) - @root_validator() - def validate_retriever_type(cls, values: Dict) -> Dict: - retriever_name = values.get("retriever_name") + @model_validator(mode="before") + @classmethod + def validate_retriever_type(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if not isinstance(values, dict): + raise ValueError( + f"Unexpected Pydantic v2 Validation: values are of type {type(values)}" + ) - assert ( - retriever_name in cls.allowed_retriever_types - ), f"retriever of {retriever_name} not allowed. 
Valid values are: {cls.allowed_retriever_types}" + retriever_name = values.get("retriever_name") if retriever_name == "vectorstore": values["retriever_config"] = VectorStoreRetrieverConfig( @@ -154,6 +170,11 @@ def validate_retriever_type(cls, values: Dict) -> Dict: values["retriever_config"] = ContextualCompressionMultiQueryRetrieverConfig( **values.get("retriever_config") ) + else: + raise ValueError( + f"Unexpected retriever name: {retriever_name}. " + f"Valid values are: {cls.allowed_retriever_types}" + ) return values diff --git a/backend/modules/query_controllers/multimodal/controller.py b/backend/modules/query_controllers/multimodal/controller.py index b54c3eb8..fcd1ec26 100644 --- a/backend/modules/query_controllers/multimodal/controller.py +++ b/backend/modules/query_controllers/multimodal/controller.py @@ -98,7 +98,7 @@ async def _get_vector_store(self, collection_name: str): raise HTTPException(status_code=404, detail="Collection not found") if not isinstance(collection, Collection): - collection = Collection(**collection.dict()) + collection = Collection(**collection.model_dump()) return VECTOR_STORE_CLIENT.get_vector_store( collection_name=collection.name, diff --git a/backend/modules/query_controllers/multimodal/types.py b/backend/modules/query_controllers/multimodal/types.py index 6226b7aa..53c037e2 100644 --- a/backend/modules/query_controllers/multimodal/types.py +++ b/backend/modules/query_controllers/multimodal/types.py @@ -1,13 +1,14 @@ -from typing import Any, ClassVar, Collection, Dict, List, Optional +from typing import Any, ClassVar, Dict, List, Optional, Sequence, Union -from langchain.docstore.document import Document -from pydantic import BaseModel, Field, root_validator +from pydantic import BaseModel, Field, model_validator from qdrant_client.models import Filter as QdrantFilter from backend.types import ModelConfig GENERATION_TIMEOUT_SEC = 60.0 * 10 +# TODO (chiragjn): Remove all asserts and replace them with proper pydantic validations or raises + class VectorStoreRetrieverConfig(BaseModel): """ @@ -25,20 +26,26 @@ class VectorStoreRetrieverConfig(BaseModel): search_kwargs: dict = Field(default_factory=dict) - filter: Optional[dict] = Field( + filter: Optional[Dict[Any, Any]] = Field( default_factory=dict, title="""Filter by document metadata""", ) - allowed_search_types: ClassVar[Collection[str]] = ( + allowed_search_types: ClassVar[Sequence[str]] = ( "similarity", "similarity_score_threshold", "mmr", ) - @root_validator - def validate_search_type(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def validate_search_type(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Validate search type.""" + if not isinstance(values, dict): + raise ValueError( + f"Unexpected Pydantic v2 Validation: values are of type {type(values)}" + ) + search_type = values.get("search_type") assert ( @@ -63,7 +70,7 @@ def validate_search_type(cls, values: Dict) -> Dict: filters = values.get("filter") if filters: - search_kwargs["filter"] = QdrantFilter.parse_obj(filters) + search_kwargs["filter"] = QdrantFilter.model_validate(filters) return values @@ -82,7 +89,7 @@ class ContextualCompressionRetrieverConfig(VectorStoreRetrieverConfig): title="Top K docs to collect post compression", ) - allowed_compressor_model_providers: ClassVar[Collection[str]] + allowed_compressor_model_providers: ClassVar[Sequence[str]] class ContextualCompressionMultiQueryRetrieverConfig( @@ -94,46 +101,55 @@ class ContextualCompressionMultiQueryRetrieverConfig( class 
ExampleQueryInput(BaseModel): """ Model for Query input. Requires a collection name, retriever configuration, query, LLM configuration and prompt template. """ collection_name: str = Field( default=None, title="Collection name on which to search", ) query: str = Field(title="Question to search for") + # TODO (chiragjn): pydantic v2 does not like fields that start with model_ model_configuration: ModelConfig prompt_template: str = Field( title="Prompt Template to use for generating answer to the question using the context", ) + # TODO (chiragjn): Move retriever name inside configuration to let pydantic discriminate between different retrievers retriever_name: str = Field( title="Retriever name", ) - retriever_config: Dict[str, Any] = Field( + retriever_config: Union[ + VectorStoreRetrieverConfig, + MultiQueryRetrieverConfig, + ContextualCompressionRetrieverConfig, + ContextualCompressionMultiQueryRetrieverConfig, + ] = Field( title="Retriever configuration", ) - allowed_retriever_types: ClassVar[Collection[str]] = ( + allowed_retriever_types: ClassVar[Sequence[str]] = ( "vectorstore", "multi-query", "contextual-compression", "contextual-compression-multi-query", ) - stream: Optional[bool] = Field(title="Stream the results", default=False) + stream: bool = Field(title="Stream the results", default=False) - @root_validator() - def validate_retriever_type(cls, values: Dict) -> Dict: - retriever_name = values.get("retriever_name") + @model_validator(mode="before") + @classmethod + def validate_retriever_type(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if not isinstance(values, dict): + raise ValueError( + f"Unexpected Pydantic v2 Validation: values are of type {type(values)}" + ) - assert ( - retriever_name in cls.allowed_retriever_types - ), f"retriever of {retriever_name} not allowed. Valid values are: {cls.allowed_retriever_types}" + retriever_name = values.get("retriever_name") if retriever_name == "vectorstore": values["retriever_config"] = VectorStoreRetrieverConfig( @@ -149,11 +165,15 @@ def validate_retriever_type(cls, values: Dict) -> Dict: values["retriever_config"] = ContextualCompressionRetrieverConfig( **values.get("retriever_config") ) - elif retriever_name == "contextual-compression-multi-query": values["retriever_config"] = ContextualCompressionMultiQueryRetrieverConfig( **values.get("retriever_config") ) + else: + raise ValueError( + f"Unexpected retriever name: {retriever_name}. " + f"Valid values are: {cls.allowed_retriever_types}" + ) return values diff --git a/backend/modules/query_controllers/query_controller.py b/backend/modules/query_controllers/query_controller.py index fd21aa7f..e81d8a1e 100644 --- a/backend/modules/query_controllers/query_controller.py +++ b/backend/modules/query_controllers/query_controller.py @@ -18,7 +18,7 @@ def list_query_controllers(): Returns a list of all the registered query controllers. Returns: - List[dict]: A list of all the registered query controllers. + List[Dict]: A list of all the registered query controllers. 
""" global QUERY_CONTROLLER_REGISTRY return [ diff --git a/backend/modules/vector_db/qdrant.py b/backend/modules/vector_db/qdrant.py index 5ac47d04..341493e3 100644 --- a/backend/modules/vector_db/qdrant.py +++ b/backend/modules/vector_db/qdrant.py @@ -17,7 +17,7 @@ class QdrantVectorDB(BaseVectorDB): def __init__(self, config: VectorDBConfig): - logger.debug(f"Connecting to qdrant using config: {config.dict()}") + logger.debug(f"Connecting to qdrant using config: {config.model_dump()}") if config.local is True: # TODO: make this path customizable self.qdrant_client = QdrantClient( @@ -28,7 +28,7 @@ def __init__(self, config: VectorDBConfig): api_key = config.api_key if not api_key: api_key = None - qdrant_kwargs = QdrantClientConfig.parse_obj(config.config or {}) + qdrant_kwargs = QdrantClientConfig.model_validate(config.config or {}) if url.startswith("http://") or url.startswith("https://"): if qdrant_kwargs.port is None: parsed_port = urlparse(url).port @@ -37,7 +37,7 @@ def __init__(self, config: VectorDBConfig): else: qdrant_kwargs.port = 443 if url.startswith("https://") else 6333 self.qdrant_client = QdrantClient( - url=url, api_key=api_key, **qdrant_kwargs.dict() + url=url, api_key=api_key, **qdrant_kwargs.model_dump() ) def create_collection(self, collection_name: str, embeddings: Embeddings): diff --git a/backend/modules/vector_db/singlestore.py b/backend/modules/vector_db/singlestore.py index fc1fba87..2b95c3b4 100644 --- a/backend/modules/vector_db/singlestore.py +++ b/backend/modules/vector_db/singlestore.py @@ -1,5 +1,5 @@ import json -from typing import Any, Iterable, List, Optional +from typing import Any, Dict, Iterable, List, Optional import singlestoredb as s2 from langchain.docstore.document import Document @@ -63,7 +63,7 @@ def _create_table(self: SingleStoreDB) -> None: def add_texts( self, texts: Iterable[str], - metadatas: Optional[List[dict]] = None, + metadatas: Optional[List[Dict[Any, Any]]] = None, embeddings: Optional[List[List[float]]] = None, **kwargs: Any, ) -> List[str]: @@ -71,7 +71,7 @@ def add_texts( Args: texts (Iterable[str]): Iterable of strings/text to add to the vectorstore. - metadatas (Optional[List[dict]], optional): Optional list of metadatas. + metadatas (Optional[List[Dict]], optional): Optional list of metadatas. Defaults to None. embeddings (Optional[List[List[float]]], optional): Optional pre-generated embeddings. Defaults to None. 
diff --git a/backend/modules/vector_db/weaviate.py b/backend/modules/vector_db/weaviate.py index 5b3986de..2aad43e5 100644 --- a/backend/modules/vector_db/weaviate.py +++ b/backend/modules/vector_db/weaviate.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Any, Dict, List import weaviate from langchain.embeddings.base import Embeddings @@ -99,7 +99,7 @@ def list_documents_in_collection( .with_fields("groupedBy { value }") .do() ) - groups: List[dict] = ( + groups: List[Dict[Any, Any]] = ( response.get("data", {}) .get("Aggregate", {}) .get(collection_name.capitalize(), []) diff --git a/backend/requirements.txt b/backend/requirements.txt index af95cf07..1317f818 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -5,11 +5,12 @@ langchain-openai==0.1.7 langchain-core==0.1.46 openai==1.35.3 tiktoken==0.7.0 -uvicorn==0.23.2 -fastapi==0.109.1 +uvicorn[standard]==0.23.2 +fastapi==0.111.1 qdrant-client==1.9.0 python-dotenv==1.0.1 -pydantic==1.10.17 +pydantic==2.7.4 +pydantic-settings==2.3.3 orjson==3.9.15 PyMuPDF==1.23.6 redis==5.0.1 diff --git a/backend/server/decorators.py b/backend/server/decorators.py index 2b17ff98..8f12dd81 100644 --- a/backend/server/decorators.py +++ b/backend/server/decorators.py @@ -4,10 +4,9 @@ """ import inspect -from typing import Any, Callable, List, Type, TypeVar, Union, get_type_hints +from typing import Any, Callable, ClassVar, List, Type, TypeVar, Union, get_type_hints from fastapi import APIRouter, Depends -from pydantic.typing import is_classvar from starlette.routing import Route, WebSocketRoute T = TypeVar("T") @@ -59,7 +58,8 @@ def _init_cbv(cls: Type[Any]) -> None: ] dependency_names: List[str] = [] for name, hint in get_type_hints(cls).items(): - if is_classvar(hint): + # TODO (chiragjn): Verify this + if getattr(hint, "__origin__", None) is ClassVar: continue parameter_kwargs = {"default": getattr(cls, name, Ellipsis)} dependency_names.append(name) @@ -127,13 +127,13 @@ def wrapper(cls) -> ClassBasedView: for name, method in cls.__dict__.items(): if callable(method) and hasattr(method, "method"): # Check if method is decorated with an HTTP method decorator - assert ( - hasattr(method, "__path__") and method.__path__ - ), f"Missing path for method {name}" + if not hasattr(method, "__path__") or not method.__path__: + raise ValueError(f"Missing path for method {name}") http_method = method.method # Ensure that the method is a valid HTTP method - assert http_method in http_method_names, f"Invalid method {http_method}" + if http_method not in http_method_names: + raise ValueError(f"Invalid method {http_method}") if prefix: method.__path__ = prefix + method.__path__ if not method.__path__.startswith("/"): diff --git a/backend/server/routers/collection.py b/backend/server/routers/collection.py index f97ffe38..a0155e16 100644 --- a/backend/server/routers/collection.py +++ b/backend/server/routers/collection.py @@ -29,7 +29,7 @@ async def get_collections(): if collections is None: return JSONResponse(content={"collections": []}) return JSONResponse( - content={"collections": [obj.dict() for obj in collections]} + content={"collections": [obj.model_dump() for obj in collections]} ) except Exception as exp: logger.exception("Failed to get collection") @@ -55,7 +55,7 @@ async def get_collection_by_name(collection_name: str = Path(title="Collection n collection = await client.aget_collection_by_name(collection_name) if collection is None: return JSONResponse(content={"collection": []}) - return JSONResponse(content={"collection": 
collection.dict()}) + return JSONResponse(content={"collection": collection.model_dump()}) except HTTPException as exp: raise exp except Exception as exp: @@ -98,7 +98,7 @@ async def create_collection(collection: CreateCollectionDto): collection_name=created_collection.name ) return JSONResponse( - content={"collection": created_collection.dict()}, status_code=201 + content={"collection": created_collection.model_dump()}, status_code=201 ) except HTTPException as exp: raise exp @@ -121,7 +121,7 @@ async def associate_data_source_to_collection( parser_config=request.parser_config, ), ) - return JSONResponse(content={"collection": collection.dict()}) + return JSONResponse(content={"collection": collection.model_dump()}) except HTTPException as exp: raise exp except Exception as exp: @@ -140,7 +140,7 @@ async def unassociate_data_source_from_collection( collection_name=request.collection_name, data_source_fqn=request.data_source_fqn, ) - return JSONResponse(content={"collection": collection.dict()}) + return JSONResponse(content={"collection": collection.model_dump()}) except HTTPException as exp: raise exp except Exception as exp: @@ -182,7 +182,9 @@ async def list_data_ingestion_runs(request: ListDataIngestionRunsDto): request.collection_name, request.data_source_fqn ) return JSONResponse( - content={"data_ingestion_runs": [obj.dict() for obj in data_ingestion_runs]} + content={ + "data_ingestion_runs": [obj.model_dump() for obj in data_ingestion_runs] + } ) diff --git a/backend/server/routers/data_source.py b/backend/server/routers/data_source.py index ff40d686..ef821b5d 100644 --- a/backend/server/routers/data_source.py +++ b/backend/server/routers/data_source.py @@ -17,7 +17,7 @@ async def get_data_source(): client = await get_client() data_sources = await client.aget_data_sources() return JSONResponse( - content={"data_sources": [obj.dict() for obj in data_sources]} + content={"data_sources": [obj.model_dump() for obj in data_sources]} ) except Exception as exp: logger.exception("Failed to get data source") @@ -45,7 +45,7 @@ async def add_data_source( client = await get_client() created_data_source = await client.acreate_data_source(data_source=data_source) return JSONResponse( - content={"data_source": created_data_source.dict()}, status_code=201 + content={"data_source": created_data_source.model_dump()}, status_code=201 ) except HTTPException as exp: raise exp diff --git a/backend/server/routers/internal.py b/backend/server/routers/internal.py index 9d02052c..6d152c42 100644 --- a/backend/server/routers/internal.py +++ b/backend/server/routers/internal.py @@ -88,12 +88,12 @@ async def upload_to_data_directory(req: UploadToDataDirectoryDto): paths=req.filepaths, ) - data = [url.dict() for url in urls] + data = [url.model_dump() for url in urls] return JSONResponse( content={"data": data, "data_directory_fqn": dataset.fqn}, ) except Exception as ex: - raise Exception(f"Error uploading files to data directory: {ex}") + raise Exception(f"Error uploading files to data directory: {ex}") from ex @router.get("/models") @@ -113,7 +113,7 @@ def get_enabled_models( ) # Serialized models - serialized_models = [model.dict() for model in enabled_models] + serialized_models = [model.model_dump() for model in enabled_models] return JSONResponse( content={"models": serialized_models}, ) diff --git a/backend/server/routers/rag_apps.py b/backend/server/routers/rag_apps.py index aadf3178..48978b82 100644 --- a/backend/server/routers/rag_apps.py +++ b/backend/server/routers/rag_apps.py @@ -3,7 +3,7 @@ 
from backend.logger import logger from backend.modules.metadata_store.client import get_client -from backend.types import CreateRagApplication, RagApplicationDto +from backend.types import CreateRagApplication router = APIRouter(prefix="/v1/apps", tags=["apps"]) @@ -18,7 +18,7 @@ async def register_rag_app( client = await get_client() created_rag_app = await client.acreate_rag_app(rag_app) return JSONResponse( - content={"rag_app": created_rag_app.dict()}, status_code=201 + content={"rag_app": created_rag_app.model_dump()}, status_code=201 ) except HTTPException as exp: raise exp @@ -47,7 +47,7 @@ async def get_rag_app_by_name(app_name: str = Path(title="App name")): rag_app = await client.aget_rag_app(app_name) if rag_app is None: return JSONResponse(content={"rag_app": []}) - return JSONResponse(content={"rag_app": rag_app.dict()}) + return JSONResponse(content={"rag_app": rag_app.model_dump()}) except HTTPException as exp: raise exp diff --git a/backend/settings.py b/backend/settings.py index 7945b2c6..fbbd1fb9 100644 --- a/backend/settings.py +++ b/backend/settings.py @@ -1,7 +1,8 @@ import os -from typing import Optional +from typing import Any, Dict -from pydantic import BaseSettings, root_validator +from pydantic import ConfigDict, model_validator +from pydantic_settings import BaseSettings from backend.types import MetadataStoreConfig, VectorDBConfig @@ -11,8 +12,7 @@ class Settings(BaseSettings): Settings class to hold all the environment variables """ - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") MODELS_CONFIG_PATH: str METADATA_STORE_CONFIG: MetadataStoreConfig @@ -30,8 +30,15 @@ class Config: os.path.join(os.path.dirname(os.path.dirname(__file__)), "user_data") ) - @root_validator(pre=True) - def _validate_values(cls, values): + @model_validator(mode="before") + @classmethod + def _validate_values(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Validate search type.""" + if not isinstance(values, dict): + raise ValueError( + f"Unexpected Pydantic v2 Validation: values are of type {type(values)}" + ) + models_config_path = values.get("MODELS_CONFIG_PATH") if not os.path.isabs(models_config_path): this_dir = os.path.abspath(os.path.dirname(__file__)) @@ -39,7 +46,7 @@ def _validate_values(cls, values): models_config_path = os.path.join(root_dir, models_config_path) if not models_config_path: - raise Exception( + raise ValueError( f"{models_config_path} does not exist. " f"You can copy models_config.sample.yaml to {settings.MODELS_CONFIG_PATH} to bootstrap config" ) diff --git a/backend/types.py b/backend/types.py index 67a167c8..5b8f628a 100644 --- a/backend/types.py +++ b/backend/types.py @@ -1,13 +1,23 @@ import enum -import json import uuid from enum import Enum from typing import Any, Dict, List, Literal, Optional, Union -from pydantic import BaseModel, Field, constr, root_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + StringConstraints, + computed_field, + model_serializer, + model_validator, +) +from typing_extensions import Annotated from backend.constants import FQN_SEPARATOR +# TODO (chiragjn): Remove Optional from Dict and List type fields. 
Instead just use a default_factory + class DataIngestionMode(str, Enum): """ @@ -42,7 +52,8 @@ class DataPoint(BaseModel): title="Hash of the data point for the given data source that is guaranteed to be updated for any update in data point at source", ) - metadata: Optional[Dict[str, str]] = Field( + metadata: Optional[Dict[str, Any]] = Field( + None, title="Additional metadata for the data point", ) @@ -84,9 +95,11 @@ class LoadedDataPoint(DataPoint): title="Local file path of the loaded data point", ) file_extension: Optional[str] = Field( + None, title="File extension of the loaded data point", ) local_metadata_file_path: Optional[str] = Field( + None, title="Local file path of the metadata file", ) @@ -118,12 +131,12 @@ def to_dict(self) -> Dict[str, Any]: class ModelProviderConfig(BaseModel): provider_name: str api_format: str + base_url: Optional[str] = None + api_key_env_var: str + default_headers: Dict[str, str] = Field(default_factory=dict) llm_model_ids: List[str] = Field(default_factory=list) embedding_model_ids: List[str] = Field(default_factory=list) reranking_model_ids: List[str] = Field(default_factory=list) - api_key_env_var: str - base_url: Optional[str] = None - default_headers: Dict[str, str] = Field(default_factory=dict) class EmbedderConfig(ModelConfig): @@ -152,7 +165,7 @@ class VectorDBConfig(BaseModel): local: bool = False url: Optional[str] = None api_key: Optional[str] = None - config: Optional[dict] = None + config: Optional[Dict[str, Any]] = Field(default_factory=dict) class QdrantClientConfig(BaseModel): @@ -160,8 +173,7 @@ class QdrantClientConfig(BaseModel): Qdrant extra configuration """ - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") port: Optional[int] = None grpc_port: int = 6334 @@ -176,7 +188,7 @@ class MetadataStoreConfig(BaseModel): """ provider: str - config: Optional[dict] = Field(default_factory=dict) + config: Optional[Dict[str, Any]] = Field(default_factory=dict) class RetrieverConfig(BaseModel): @@ -186,7 +198,8 @@ class RetrieverConfig(BaseModel): search_type: Literal["mmr", "similarity"] = Field( default="similarity", - title="""Defines the type of search that the Retriever should perform. Can be "similarity" (default), "mmr", or "similarity_score_threshold".""", + title="""Defines the type of search that the Retriever should perform. 
\ + Can be "similarity" (default), "mmr", or "similarity_score_threshold".""", ) k: int = Field( default=4, @@ -196,7 +209,7 @@ class RetrieverConfig(BaseModel): default=20, title="""Amount of documents to pass to MMR algorithm (Default: 20)""", ) - filter: Optional[dict] = Field( + filter: Optional[Dict[Any, Any]] = Field( default=None, title="""Filter by document metadata""", ) @@ -209,11 +222,12 @@ def get_search_type(self) -> str: @property def get_search_kwargs(self) -> dict: # Check at langchain.schema.vectorstore.VectorStore.as_retriever - match self.search_type: - case "similarity": - return {"k": self.k, "filter": self.filter} - case "mmr": - return {"k": self.k, "fetch_k": self.fetch_k, "filter": self.filter} + if self.search_type == "similarity": + return {"k": self.k, "filter": self.filter} + elif self.search_type == "mmr": + return {"k": self.k, "fetch_k": self.fetch_k, "filter": self.filter} + else: + raise ValueError(f"Search type {self.search_type} is not supported") class DataIngestionRunStatus(str, enum.Enum): @@ -255,7 +269,7 @@ class BaseDataIngestionRun(BaseModel): title="Data ingestion mode for the data ingestion", ) - raise_error_on_failure: Optional[bool] = Field( + raise_error_on_failure: bool = Field( title="Flag to configure weather to raise error on failure or not. Default is True", default=True, ) @@ -270,6 +284,7 @@ class DataIngestionRun(BaseDataIngestionRun): title="Name of the data ingestion run", ) status: Optional[DataIngestionRunStatus] = Field( + None, title="Status of the data ingestion run", ) @@ -286,18 +301,14 @@ class BaseDataSource(BaseModel): title="A unique identifier for the data source", ) metadata: Optional[Dict[str, Any]] = Field( - title="Additional config for your data source" + None, title="Additional config for your data source" ) + @computed_field @property - def fqn(self): + def fqn(self) -> str: return f"{FQN_SEPARATOR}".join([self.type, self.uri]) - @root_validator - def validate_fqn(cls, values: Dict) -> Dict: - values["fqn"] = f"{FQN_SEPARATOR}".join([values["type"], values["uri"]]) - return values - class CreateDataSource(BaseDataSource): pass @@ -319,7 +330,7 @@ class AssociatedDataSources(BaseModel): title="Parser configuration for the data transformation", default_factory=dict ) data_source: Optional[DataSource] = Field( - title="Data source associated with the collection" + None, title="Data source associated with the collection" ) @@ -333,6 +344,7 @@ class IngestDataToCollectionDto(BaseModel): ) data_source_fqn: Optional[str] = Field( + None, title="Fully qualified name of the data source", ) @@ -341,7 +353,7 @@ class IngestDataToCollectionDto(BaseModel): title="Data ingestion mode for the data ingestion", ) - raise_error_on_failure: Optional[bool] = Field( + raise_error_on_failure: bool = Field( title="Flag to configure weather to raise error on failure or not. 
Default is True", default=True, ) @@ -414,12 +426,13 @@ class BaseCollection(BaseModel): Base collection configuration """ - name: constr(regex=r"^[a-z][a-z0-9-]*$") = Field( # type: ignore + name: Annotated[str, StringConstraints(pattern=r"^[a-z][a-z0-9-]*$")] = Field( # type: ignore title="a unique name to your collection", description="Should only contain lowercase alphanumeric character and hypen, should start with alphabet", example="test-collection", ) description: Optional[str] = Field( + None, title="a description for your collection", example="This is a test collection", ) @@ -442,19 +455,29 @@ class Collection(BaseCollection): title="Data sources associated with the collection", default_factory=dict ) + @model_validator(mode="before") + @classmethod + def ensure_associated_data_sources_not_none( + cls, values: Dict[str, Any] + ) -> Dict[str, Any]: + if values.get("associated_data_sources") is None: + values["associated_data_sources"] = {} + return values + class CreateCollectionDto(CreateCollection): associated_data_sources: Optional[List[AssociateDataSourceWithCollection]] = Field( - title="Data sources associated with the collection" + None, title="Data sources associated with the collection" ) class UploadToDataDirectoryDto(BaseModel): filepaths: List[str] # allow only small case alphanumeric and hyphen, should contain at least one alphabet and begin with alphabet - upload_name: str = Field( + upload_name: Annotated[ + str, StringConstraints(pattern=r"^[a-z][a-z0-9-]*$") + ] = Field( # type:ignore title="Name of the upload", - regex=r"^[a-z][a-z0-9-]*$", default=str(uuid.uuid4()), ) @@ -469,9 +492,11 @@ class ListDataIngestionRunsDto(BaseModel): class RagApplication(BaseModel): - name: str = Field( + # allow only small case alphanumeric and hyphen, should contain at least one alphabet and begin with alphabet + name: Annotated[ + str, StringConstraints(pattern=r"^[a-z][a-z0-9-]*$") + ] = Field( # type:ignore title="Name of the rag app", - regex=r"^[a-z][a-z0-9-]*$", # allow only small case alphanumeric and hyphen, should contain at least one alphabet and begin with alphabet ) config: Dict[str, Any] = Field( title="Configuration for the rag app", diff --git a/frontend/src/screens/dashboard/docsqa/NewCollection.tsx b/frontend/src/screens/dashboard/docsqa/NewCollection.tsx index db86873d..40b8f79d 100644 --- a/frontend/src/screens/dashboard/docsqa/NewCollection.tsx +++ b/frontend/src/screens/dashboard/docsqa/NewCollection.tsx @@ -73,7 +73,7 @@ const NewCollection = ({ open, onClose, onSuccess }: NewCollectionProps) => { name: collectionName, embedder_config: { name: embeddingModel.name, - type: "embedding", + type: 'embedding', }, associated_data_sources: [ { From c3f6f62209410a29a79d571805ca51860da0249a Mon Sep 17 00:00:00 2001 From: Prathamesh Date: Fri, 9 Aug 2024 17:05:43 +0530 Subject: [PATCH 13/16] Fixed pydantic --- backend/modules/parsers/multimodalparser.py | 6 ++---- backend/modules/parsers/unstructured_io.py | 1 - .../query_controllers/example/controller.py | 2 +- .../modules/query_controllers/example/types.py | 5 +++++ .../query_controllers/multimodal/controller.py | 2 +- .../modules/query_controllers/multimodal/types.py | 15 +++++++++++++++ backend/types.py | 2 +- 7 files changed, 25 insertions(+), 8 deletions(-) diff --git a/backend/modules/parsers/multimodalparser.py b/backend/modules/parsers/multimodalparser.py index 7e1b0ae1..3d36d439 100644 --- a/backend/modules/parsers/multimodalparser.py +++ b/backend/modules/parsers/multimodalparser.py @@ -67,10 +67,8 @@ def 
__init__( """ # Multi-modal parser needs to be configured with the openai compatible client url and vision model - if "model_configuration" in additional_config: - self.model_configuration = ModelConfig.model_validate( - additional_config["model_configuration"] - ) + if model_configuration: + self.model_configuration = ModelConfig.model_validate(model_configuration) logger.info(f"Using custom vision model..., {self.model_configuration}") else: # Truefoundry specific model configuration diff --git a/backend/modules/parsers/unstructured_io.py b/backend/modules/parsers/unstructured_io.py index 4d8ce2ba..179df881 100644 --- a/backend/modules/parsers/unstructured_io.py +++ b/backend/modules/parsers/unstructured_io.py @@ -51,7 +51,6 @@ def __init__(self, *, max_chunk_size: int = 2000, **kwargs): self.adapter = HTTPAdapter(max_retries=self.retry_strategy) self.session.mount("https://", self.adapter) self.session.mount("http://", self.adapter) - super().__init__(*args, **kwargs) super().__init__(**kwargs) diff --git a/backend/modules/query_controllers/example/controller.py b/backend/modules/query_controllers/example/controller.py index f4599afb..e941a077 100644 --- a/backend/modules/query_controllers/example/controller.py +++ b/backend/modules/query_controllers/example/controller.py @@ -5,7 +5,6 @@ import async_timeout from fastapi import Body, HTTPException from fastapi.responses import StreamingResponse -from langchain.docstore.document import Document from langchain.prompts import PromptTemplate from langchain.retrievers import ContextualCompressionRetriever, MultiQueryRetriever from langchain.schema.vectorstore import VectorStoreRetriever @@ -26,6 +25,7 @@ GENERATION_TIMEOUT_SEC, Answer, Docs, + Document, ExampleQueryInput, ) from backend.modules.vector_db.client import VECTOR_STORE_CLIENT diff --git a/backend/modules/query_controllers/example/types.py b/backend/modules/query_controllers/example/types.py index fb45fa77..8ac0d99a 100644 --- a/backend/modules/query_controllers/example/types.py +++ b/backend/modules/query_controllers/example/types.py @@ -179,6 +179,11 @@ def validate_retriever_type(cls, values: Dict[str, Any]) -> Dict[str, Any]: return values +class Document(BaseModel): + page_content: str + metadata: dict = Field(default_factory=dict) + + class Answer(BaseModel): type: str = "answer" content: str diff --git a/backend/modules/query_controllers/multimodal/controller.py b/backend/modules/query_controllers/multimodal/controller.py index fcd1ec26..2d102f38 100644 --- a/backend/modules/query_controllers/multimodal/controller.py +++ b/backend/modules/query_controllers/multimodal/controller.py @@ -4,7 +4,6 @@ import async_timeout from fastapi import Body, HTTPException from fastapi.responses import StreamingResponse -from langchain.docstore.document import Document from langchain.prompts import PromptTemplate from langchain.retrievers import ContextualCompressionRetriever, MultiQueryRetriever from langchain.schema.vectorstore import VectorStoreRetriever @@ -25,6 +24,7 @@ GENERATION_TIMEOUT_SEC, Answer, Docs, + Document, ExampleQueryInput, ) from backend.modules.vector_db.client import VECTOR_STORE_CLIENT diff --git a/backend/modules/query_controllers/multimodal/types.py b/backend/modules/query_controllers/multimodal/types.py index 53c037e2..f4d9b8ab 100644 --- a/backend/modules/query_controllers/multimodal/types.py +++ b/backend/modules/query_controllers/multimodal/types.py @@ -178,6 +178,21 @@ def validate_retriever_type(cls, values: Dict[str, Any]) -> Dict[str, Any]: return values 
+class Document(BaseModel): + page_content: str + metadata: dict = Field(default_factory=dict) + + +class Answer(BaseModel): + type: str = "answer" + content: str + + +class Docs(BaseModel): + type: str = "docs" + content: List[Document] = Field(default_factory=list) + + class Answer(BaseModel): type: str = "answer" content: str diff --git a/backend/types.py index 5b8f628a..1e29c478 100644 --- a/backend/types.py +++ b/backend/types.py @@ -117,7 +117,7 @@ class ModelType(str, Enum): class ModelConfig(BaseModel): name: str - type: Optional[ModelType] + type: Optional[ModelType] = None parameters: Optional[Dict[str, Any]] = Field(default_factory=dict) def to_dict(self) -> Dict[str, Any]: From cb84e1691b7b59fa5b419be9a0d58536c59414dc Mon Sep 17 00:00:00 2001 From: Prathamesh Date: Fri, 9 Aug 2024 21:51:21 +0530 Subject: [PATCH 14/16] Removed model_serializer, replaced with model_dump, added enum_values --- backend/modules/metadata_store/prismastore.py | 18 +++++----- backend/types.py | 35 ++++++------------- 2 files changed, 20 insertions(+), 33 deletions(-) diff --git a/backend/modules/metadata_store/prismastore.py b/backend/modules/metadata_store/prismastore.py index cc0bdc49..0c3130ce 100644 --- a/backend/modules/metadata_store/prismastore.py +++ b/backend/modules/metadata_store/prismastore.py @@ -269,7 +269,7 @@ async def aassociate_data_source_with_collection( data_source_fqn, data_source, ) in existing_collection_associated_data_sources.items(): - associated_data_sources[data_source_fqn] = data_source.serialize() + associated_data_sources[data_source_fqn] = data_source.model_dump() updated_collection: Optional[ "PrismaCollection" ] = await self.db.collection.update( @@ -296,7 +296,7 @@ async def aunassociate_data_source_with_collection( self, collection_name: str, data_source_fqn: str ) -> Collection: try: - collection = await self.aget_collection_by_name(collection_name) + collection: Collection = await self.aget_collection_by_name(collection_name) except Exception as e: logger.exception(f"Error: {e}") raise HTTPException(status_code=500, detail=f"Error: {e}") @@ -321,7 +321,9 @@ async def aunassociate_data_source_with_collection( detail=f"Data source with fqn {data_source_fqn} does not exist", ) - associated_data_sources = collection.associated_data_sources + associated_data_sources: Dict[str, AssociatedDataSources] = ( + collection.associated_data_sources + ) if not associated_data_sources: logger.error( f"No associated data sources found for collection {collection_name}" ) @@ -341,16 +343,16 @@ async def aunassociate_data_source_with_collection( associated_data_sources.pop(data_source_fqn, None) - associated_data_sources = { - k: v.serialize() for k, v in associated_data_sources.items() - } - try: updated_collection: Optional[ "PrismaCollection" ] = await self.db.collection.update( where={"name": collection_name}, - data={"associated_data_sources": json.dumps(associated_data_sources)}, + data={ + "associated_data_sources": json.dumps( + {k: v.model_dump() for k, v in associated_data_sources.items()} + ) + }, ) if not updated_collection: raise HTTPException( diff --git a/backend/types.py b/backend/types.py index 65b51a50..363a6ca0 100644 --- a/backend/types.py +++ b/backend/types.py @@ -9,7 +9,6 @@ Field, StringConstraints, computed_field, - model_serializer, model_validator, ) from typing_extensions import Annotated @@ -28,6 +27,9 @@ class DataIngestionMode(str, Enum): INCREMENTAL = "INCREMENTAL" FULL = "FULL" + class Config: + use_enum_values = True + class DataPoint(BaseModel): """ @@ -114,6 +116,9 @@ class ModelType(str, Enum): reranking 
= "reranking" parser = "parser" + class Config: + use_enum_values = True + class ModelConfig(BaseModel): name: str @@ -156,14 +161,6 @@ class ParserConfig(ModelConfig): type: ModelType = ModelType.parser - @model_serializer - def serialize(self): - return { - "name": self.name, - "parameters": self.parameters, - "type": self.type.value, - } - class VectorDBConfig(BaseModel): """ @@ -255,6 +252,9 @@ class DataIngestionRunStatus(str, enum.Enum): COMPLETED = "COMPLETED" ERROR = "ERROR" + class Config: + use_enum_values = True + class BaseDataIngestionRun(BaseModel): """ @@ -324,14 +324,7 @@ class CreateDataSource(BaseDataSource): class DataSource(BaseDataSource): - @model_serializer - def serialize(self): - return { - "type": self.type, - "uri": self.uri, - "metadata": self.metadata, - "fqn": self.fqn, - } + pass class AssociatedDataSources(BaseModel): @@ -349,14 +342,6 @@ class AssociatedDataSources(BaseModel): None, title="Data source associated with the collection" ) - @model_serializer - def serialize(self): - return { - "data_source_fqn": self.data_source_fqn, - "parser_config": {k: v.serialize() for k, v in self.parser_config.items()}, - "data_source": self.data_source.serialize() if self.data_source else None, - } - class IngestDataToCollectionDto(BaseModel): """ From e062e09b09a7f7c711e99cdae8a6c3b618e93883 Mon Sep 17 00:00:00 2001 From: Chirag Jain Date: Sun, 11 Aug 2024 11:19:17 +0530 Subject: [PATCH 15/16] Fix enum values config and separate model type and module type --- backend/indexer/types.py | 14 +++- .../query_controllers/example/types.py | 12 +-- .../query_controllers/multimodal/types.py | 16 ++-- backend/types.py | 83 ++++++++++--------- 4 files changed, 68 insertions(+), 57 deletions(-) diff --git a/backend/indexer/types.py b/backend/indexer/types.py index 5bdc0500..e7fd2142 100644 --- a/backend/indexer/types.py +++ b/backend/indexer/types.py @@ -1,11 +1,17 @@ -from typing import Dict, Union +from typing import Dict -from pydantic import BaseModel, Field +from pydantic import Field -from backend.types import DataIngestionMode, DataSource, EmbedderConfig, ParserConfig +from backend.types import ( + ConfiguredBaseModel, + DataIngestionMode, + DataSource, + EmbedderConfig, + ParserConfig, +) -class DataIngestionConfig(BaseModel): +class DataIngestionConfig(ConfiguredBaseModel): """ Configuration to store Data Ingestion Configuration """ diff --git a/backend/modules/query_controllers/example/types.py b/backend/modules/query_controllers/example/types.py index 8ac0d99a..926f6b51 100644 --- a/backend/modules/query_controllers/example/types.py +++ b/backend/modules/query_controllers/example/types.py @@ -3,14 +3,14 @@ from pydantic import BaseModel, Field, model_validator from qdrant_client.models import Filter as QdrantFilter -from backend.types import ModelConfig +from backend.types import ConfiguredBaseModel, ModelConfig GENERATION_TIMEOUT_SEC = 60.0 * 10 # TODO (chiragjn): Remove all asserts and replace them with proper pydantic validations or raises -class VectorStoreRetrieverConfig(BaseModel): +class VectorStoreRetrieverConfig(ConfiguredBaseModel): """ Configuration for VectorStore Retriever """ @@ -99,7 +99,7 @@ class ContextualCompressionMultiQueryRetrieverConfig( pass -class ExampleQueryInput(BaseModel): +class ExampleQueryInput(ConfiguredBaseModel): """ Model for Query input. Requires a Sequence name, retriever configuration, query, LLM configuration and prompt template. 
@@ -179,16 +179,16 @@ def validate_retriever_type(cls, values: Dict[str, Any]) -> Dict[str, Any]: return values -class Document(BaseModel): +class Document(ConfiguredBaseModel): page_content: str metadata: dict = Field(default_factory=dict) -class Answer(BaseModel): +class Answer(ConfiguredBaseModel): type: str = "answer" content: str -class Docs(BaseModel): +class Docs(ConfiguredBaseModel): type: str = "docs" content: List[Document] = Field(default_factory=list) diff --git a/backend/modules/query_controllers/multimodal/types.py b/backend/modules/query_controllers/multimodal/types.py index f4d9b8ab..349b6697 100644 --- a/backend/modules/query_controllers/multimodal/types.py +++ b/backend/modules/query_controllers/multimodal/types.py @@ -3,14 +3,14 @@ from pydantic import BaseModel, Field, model_validator from qdrant_client.models import Filter as QdrantFilter -from backend.types import ModelConfig +from backend.types import ConfiguredBaseModel, ModelConfig GENERATION_TIMEOUT_SEC = 60.0 * 10 # TODO (chiragjn): Remove all asserts and replace them with proper pydantic validations or raises -class VectorStoreRetrieverConfig(BaseModel): +class VectorStoreRetrieverConfig(ConfiguredBaseModel): """ Configuration for VectorStore Retriever """ @@ -98,7 +98,7 @@ class ContextualCompressionMultiQueryRetrieverConfig( pass -class ExampleQueryInput(BaseModel): +class ExampleQueryInput(ConfiguredBaseModel): """ Model for Query input. Requires a Sequence name, retriever configuration, query, LLM configuration and prompt template. @@ -178,26 +178,26 @@ def validate_retriever_type(cls, values: Dict[str, Any]) -> Dict[str, Any]: return values -class Document(BaseModel): +class Document(ConfiguredBaseModel): page_content: str metadata: dict = Field(default_factory=dict) -class Answer(BaseModel): +class Answer(ConfiguredBaseModel): type: str = "answer" content: str -class Docs(BaseModel): +class Docs(ConfiguredBaseModel): type: str = "docs" content: List[Document] = Field(default_factory=list) -class Answer(BaseModel): +class Answer(ConfiguredBaseModel): type: str = "answer" content: str -class Docs(BaseModel): +class Docs(ConfiguredBaseModel): type: str = "docs" content: List[Document] = Field(default_factory=list) diff --git a/backend/types.py b/backend/types.py index 363a6ca0..c805fe5b 100644 --- a/backend/types.py +++ b/backend/types.py @@ -18,6 +18,10 @@ # TODO (chiragjn): Remove Optional from Dict and List type fields. Instead just use a default_factory +class ConfiguredBaseModel(BaseModel): + model_config = ConfigDict(use_enum_values=True) + + class DataIngestionMode(str, Enum): """ Data Ingestion Modes @@ -27,11 +31,8 @@ class DataIngestionMode(str, Enum): INCREMENTAL = "INCREMENTAL" FULL = "FULL" - class Config: - use_enum_values = True - -class DataPoint(BaseModel): +class DataPoint(ConfiguredBaseModel): """ Data point describes a single data point in the data source Properties: @@ -64,7 +65,7 @@ def data_point_fqn(self) -> str: return f"{FQN_SEPARATOR}".join([self.data_source_fqn, self.data_point_uri]) -class DataPointVector(BaseModel): +class DataPointVector(ConfiguredBaseModel): """ Data point vector describes a single data point in the vector store Additional Properties: @@ -114,28 +115,17 @@ class ModelType(str, Enum): chat = "chat" embedding = "embedding" reranking = "reranking" - parser = "parser" - - class Config: - use_enum_values = True -class ModelConfig(BaseModel): +class ModelConfig(ConfiguredBaseModel): name: str # TODO (chiragjn): This should not be Optional! 
Changing might break backward compatibility # Problem is we have shared these entities between DTO layers and Service / DB layers type: Optional[ModelType] = None parameters: Optional[Dict[str, Any]] = Field(default_factory=dict) - def to_dict(self) -> Dict[str, Any]: - return { - "name": self.name, - "type": self.type, - "parameters": self.parameters, - } - -class ModelProviderConfig(BaseModel): +class ModelProviderConfig(ConfiguredBaseModel): provider_name: str api_format: str base_url: Optional[str] = None @@ -146,23 +136,41 @@ class ModelProviderConfig(BaseModel): reranking_model_ids: List[str] = Field(default_factory=list) -class EmbedderConfig(ModelConfig): +class ModuleType(str, Enum): + embedding = "embedding" + parser = "parser" + + +class ModuleConfig(ConfiguredBaseModel): + name: str + type: ModuleType + parameters: Optional[Dict[str, Any]] = Field(default_factory=dict) + + @model_validator(mode="before") + @classmethod + def ensure_type_not_none(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if values.get("type") is None: + values.pop("type", None) + return values + + +class EmbedderConfig(ModuleConfig): """ Embedder configuration """ - pass + type: Literal[ModuleType.embedding] = ModuleType.embedding -class ParserConfig(ModelConfig): +class ParserConfig(ModuleConfig): """ Parser configuration """ - type: ModelType = ModelType.parser + type: Literal[ModuleType.parser] = ModuleType.parser -class VectorDBConfig(BaseModel): +class VectorDBConfig(ConfiguredBaseModel): """ Vector db configuration """ @@ -174,7 +182,7 @@ class VectorDBConfig(BaseModel): config: Optional[Dict[str, Any]] = Field(default_factory=dict) -class QdrantClientConfig(BaseModel): +class QdrantClientConfig(ConfiguredBaseModel): """ Qdrant extra configuration """ @@ -188,7 +196,7 @@ class QdrantClientConfig(BaseModel): timeout: int = 300 -class MetadataStoreConfig(BaseModel): +class MetadataStoreConfig(ConfiguredBaseModel): """ Metadata store configuration """ @@ -197,7 +205,7 @@ class MetadataStoreConfig(BaseModel): config: Optional[Dict[str, Any]] = Field(default_factory=dict) -class RetrieverConfig(BaseModel): +class RetrieverConfig(ConfiguredBaseModel): """ Retriever configuration """ @@ -252,11 +260,8 @@ class DataIngestionRunStatus(str, enum.Enum): COMPLETED = "COMPLETED" ERROR = "ERROR" - class Config: - use_enum_values = True - -class BaseDataIngestionRun(BaseModel): +class BaseDataIngestionRun(ConfiguredBaseModel): """ Base data ingestion run configuration """ @@ -298,7 +303,7 @@ class DataIngestionRun(BaseDataIngestionRun): ) -class BaseDataSource(BaseModel): +class BaseDataSource(ConfiguredBaseModel): """ Data source configuration """ @@ -327,7 +332,7 @@ class DataSource(BaseDataSource): pass -class AssociatedDataSources(BaseModel): +class AssociatedDataSources(ConfiguredBaseModel): """ Associated data source configuration """ @@ -343,7 +348,7 @@ class AssociatedDataSources(BaseModel): ) -class IngestDataToCollectionDto(BaseModel): +class IngestDataToCollectionDto(ConfiguredBaseModel): """ Configuration to ingest data to collection """ @@ -378,7 +383,7 @@ class IngestDataToCollectionDto(BaseModel): ) -class AssociateDataSourceWithCollection(BaseModel): +class AssociateDataSourceWithCollection(ConfiguredBaseModel): """ Configuration to associate data source to collection """ @@ -417,7 +422,7 @@ class AssociateDataSourceWithCollectionDto(AssociateDataSourceWithCollection): ) -class UnassociateDataSourceWithCollectionDto(BaseModel): +class UnassociateDataSourceWithCollectionDto(ConfiguredBaseModel): 
""" Configuration to unassociate data source to collection """ @@ -430,7 +435,7 @@ class UnassociateDataSourceWithCollectionDto(BaseModel): ) -class BaseCollection(BaseModel): +class BaseCollection(ConfiguredBaseModel): """ Base collection configuration """ @@ -480,7 +485,7 @@ class CreateCollectionDto(CreateCollection): ) -class UploadToDataDirectoryDto(BaseModel): +class UploadToDataDirectoryDto(ConfiguredBaseModel): filepaths: List[str] # allow only small case alphanumeric and hyphen, should contain at least one alphabet and begin with alphabet upload_name: Annotated[ @@ -491,7 +496,7 @@ class UploadToDataDirectoryDto(BaseModel): ) -class ListDataIngestionRunsDto(BaseModel): +class ListDataIngestionRunsDto(ConfiguredBaseModel): collection_name: str = Field( title="Name of the collection", ) @@ -500,7 +505,7 @@ class ListDataIngestionRunsDto(BaseModel): ) -class RagApplication(BaseModel): +class RagApplication(ConfiguredBaseModel): # allow only small case alphanumeric and hyphen, should contain at least one alphabet and begin with alphabet name: Annotated[ str, StringConstraints(pattern=r"^[a-z][a-z0-9-]*$") From 182bafeb38c6d7ef2bab21d99f02eece72266954 Mon Sep 17 00:00:00 2001 From: Chirag Jain Date: Sun, 11 Aug 2024 17:21:48 +0530 Subject: [PATCH 16/16] Fix frontend types to match up with backend types --- backend/modules/parsers/multimodalparser.py | 11 +++--- backend/server/routers/internal.py | 5 ++- backend/types.py | 38 +++++++++---------- docker-compose.yaml | 2 +- .../dashboard/docsqa/NewCollection.tsx | 1 - .../dashboard/docsqa/settings/index.tsx | 2 +- frontend/src/stores/qafoundry/index.ts | 28 +++++++++----- 7 files changed, 46 insertions(+), 41 deletions(-) diff --git a/backend/modules/parsers/multimodalparser.py b/backend/modules/parsers/multimodalparser.py index 3d36d439..3c26dbdb 100644 --- a/backend/modules/parsers/multimodalparser.py +++ b/backend/modules/parsers/multimodalparser.py @@ -17,7 +17,7 @@ from backend.modules.model_gateway.model_gateway import model_gateway from backend.modules.parsers.parser import BaseParser from backend.modules.parsers.utils import contains_text -from backend.types import ModelConfig +from backend.types import ModelConfig, ModelType def stringToRGB(base64_string: str): @@ -45,9 +45,8 @@ class MultiModalParser(BaseParser): Parser Configuration will look like the following while creating the collection: { ".pdf": { - "parser": "MultiModalParser", - "kwargs": { - "chunk_size": 1000, + "name": "MultiModalParser", + "parameters": { "model_configuration": { "name" : "truefoundry/openai-main/gpt-4o-mini" }, @@ -65,7 +64,6 @@ def __init__( """ Initializes the MultiModalParser object. 
""" - # Multi-modal parser needs to be configured with the openai compatible client url and vision model if model_configuration: self.model_configuration = ModelConfig.model_validate(model_configuration) @@ -73,7 +71,8 @@ def __init__( else: # Truefoundry specific model configuration self.model_configuration = ModelConfig( - name="truefoundry/openai-main/gpt-4o-mini" + name="truefoundry/openai-main/gpt-4o-mini", + type=ModelType.chat, ) if prompt: diff --git a/backend/server/routers/internal.py b/backend/server/routers/internal.py index 6d152c42..248f3ec1 100644 --- a/backend/server/routers/internal.py +++ b/backend/server/routers/internal.py @@ -31,7 +31,7 @@ async def upload_to_docker_directory( status_code=500, ) try: - logger.info(f"Uploading files to docker directory: {upload_name}") + logger.info(f"Uploading files to directory: {upload_name}") # create a folder within `/volumes/user_data/` that maps to `/app/user_data/` in the docker volume # this folder will be used to store the uploaded files folder_path = os.path.join(settings.LOCAL_DATA_DIRECTORY, upload_name) @@ -60,8 +60,9 @@ async def upload_to_docker_directory( # Add the data source to the metadata store. return await add_data_source(data_source) except Exception as ex: + logger.exception(f"Error uploading files to directory: {ex}") return JSONResponse( - content={"error": f"Error uploading files to docker directory: {ex}"}, + content={"error": f"Error uploading files to directory: {ex}"}, status_code=500, ) diff --git a/backend/types.py b/backend/types.py index c805fe5b..02bcab3a 100644 --- a/backend/types.py +++ b/backend/types.py @@ -136,38 +136,36 @@ class ModelProviderConfig(ConfiguredBaseModel): reranking_model_ids: List[str] = Field(default_factory=list) -class ModuleType(str, Enum): - embedding = "embedding" - parser = "parser" - +class EmbedderConfig(ConfiguredBaseModel): + """ + Embedder configuration + """ -class ModuleConfig(ConfiguredBaseModel): name: str - type: ModuleType - parameters: Optional[Dict[str, Any]] = Field(default_factory=dict) + parameters: Dict[str, Any] = Field(default_factory=dict) @model_validator(mode="before") @classmethod - def ensure_type_not_none(cls, values: Dict[str, Any]) -> Dict[str, Any]: - if values.get("type") is None: - values.pop("type", None) + def ensure_parameters_not_none(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if values.get("parameters") is None: + values.pop("parameters", None) return values -class EmbedderConfig(ModuleConfig): - """ - Embedder configuration - """ - - type: Literal[ModuleType.embedding] = ModuleType.embedding - - -class ParserConfig(ModuleConfig): +class ParserConfig(ConfiguredBaseModel): """ Parser configuration """ - type: Literal[ModuleType.parser] = ModuleType.parser + name: str + parameters: Dict[str, Any] = Field(default_factory=dict) + + @model_validator(mode="before") + @classmethod + def ensure_parameters_not_none(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if values.get("parameters") is None: + values.pop("parameters", None) + return values class VectorDBConfig(ConfiguredBaseModel): diff --git a/docker-compose.yaml b/docker-compose.yaml index d6951ffc..00c691ac 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -71,7 +71,7 @@ services: environment: - INFINITY_MODEL_ID=${INFINITY_EMBEDDING_MODEL};${INFINITY_RERANKING_MODEL} - INFINITY_BATCH_SIZE=8 - - API_KEY=${INFINITY_API_KEY} + - INFINITY_API_KEY=${INFINITY_API_KEY} command: v2 networks: - cognita-docker diff --git 
a/frontend/src/screens/dashboard/docsqa/NewCollection.tsx b/frontend/src/screens/dashboard/docsqa/NewCollection.tsx index 40b8f79d..8a0c7103 100644 --- a/frontend/src/screens/dashboard/docsqa/NewCollection.tsx +++ b/frontend/src/screens/dashboard/docsqa/NewCollection.tsx @@ -73,7 +73,6 @@ const NewCollection = ({ open, onClose, onSuccess }: NewCollectionProps) => { name: collectionName, embedder_config: { name: embeddingModel.name, - type: 'embedding', }, associated_data_sources: [ { diff --git a/frontend/src/screens/dashboard/docsqa/settings/index.tsx b/frontend/src/screens/dashboard/docsqa/settings/index.tsx index 6711527c..7d4e2e12 100644 --- a/frontend/src/screens/dashboard/docsqa/settings/index.tsx +++ b/frontend/src/screens/dashboard/docsqa/settings/index.tsx @@ -150,7 +150,7 @@ const DocsQASettings = () => { {collectionDetails && (
Embedder Used :{' '} - {collectionDetails?.embedder_config?.config?.model} + {collectionDetails?.embedder_config?.name}
)} diff --git a/frontend/src/stores/qafoundry/index.ts b/frontend/src/stores/qafoundry/index.ts index 10ce3d16..d7127b84 100644 --- a/frontend/src/stores/qafoundry/index.ts +++ b/frontend/src/stores/qafoundry/index.ts @@ -3,8 +3,15 @@ import { createApi } from '@reduxjs/toolkit/query/react' // import * as T from './types' import { createBaseQuery } from '../utils' +export enum ModelType { + chat = 'chat', + embedding = 'embedding', + reranking = 'reranking', +} + export interface ModelConfig { name: string + type?: ModelType parameters: { temperature?: number maximum_length?: number @@ -40,23 +47,25 @@ interface DataSource { fqn: string } +interface ParserConfig { + name: string + parameters?: { + [key: string]: any + } +} + export interface AssociatedDataSource { data_source_fqn: string parser_config: { - chunk_size: number - chunk_overlap: number - parser_map: { - [key: string]: string - } + [key: string]: ParserConfig } data_source: DataSource } interface EmbedderConfig { - description?: string - provider?: string - config?: { - model: string + name: string + parameters?: { + [key: string]: any } } @@ -67,7 +76,6 @@ export interface Collection { associated_data_sources: { [key: string]: AssociatedDataSource } - chunk_size?: number } interface AddDataSourcePayload {