diff --git a/backend/indexer/indexer.py b/backend/indexer/indexer.py index b8161fcc..cfd9162e 100644 --- a/backend/indexer/indexer.py +++ b/backend/indexer/indexer.py @@ -307,7 +307,7 @@ async def ingest_data(request: IngestDataToCollectionDto): # convert to pydantic model if not already -> For prisma models if not isinstance(collection, Collection): - collection = Collection(**collection.dict()) + collection = Collection(**collection.model_dump()) if not collection: logger.error( diff --git a/backend/indexer/types.py b/backend/indexer/types.py index 5bdc0500..e7fd2142 100644 --- a/backend/indexer/types.py +++ b/backend/indexer/types.py @@ -1,11 +1,17 @@ -from typing import Dict, Union +from typing import Dict -from pydantic import BaseModel, Field +from pydantic import Field -from backend.types import DataIngestionMode, DataSource, EmbedderConfig, ParserConfig +from backend.types import ( + ConfiguredBaseModel, + DataIngestionMode, + DataSource, + EmbedderConfig, + ParserConfig, +) -class DataIngestionConfig(BaseModel): +class DataIngestionConfig(ConfiguredBaseModel): """ Configuration to store Data Ingestion Configuration """ diff --git a/backend/migration/qdrant_migration.py b/backend/migration/qdrant_migration.py index 1f535ea2..0b653a83 100644 --- a/backend/migration/qdrant_migration.py +++ b/backend/migration/qdrant_migration.py @@ -90,7 +90,7 @@ def migrate_collection( "associated_data_sources" ).items() ], - ).dict() + ).model_dump() logger.debug( f"Creating '{dest_collection.get('name')}' collection at destination" diff --git a/backend/migration/utils.py b/backend/migration/utils.py index d0198915..afc15732 100644 --- a/backend/migration/utils.py +++ b/backend/migration/utils.py @@ -2,7 +2,6 @@ from typing import Dict import requests -from qdrant_client._pydantic_compat import to_dict from qdrant_client.client_base import QdrantBase from qdrant_client.http import models from tqdm import tqdm @@ -104,11 +103,11 @@ def _recreate_collection( replication_factor=src_config.params.replication_factor, write_consistency_factor=src_config.params.write_consistency_factor, on_disk_payload=src_config.params.on_disk_payload, - hnsw_config=models.HnswConfigDiff(**to_dict(src_config.hnsw_config)), + hnsw_config=models.HnswConfigDiff(**src_config.hnsw_config.model_dump()), optimizers_config=models.OptimizersConfigDiff( - **to_dict(src_config.optimizer_config) + **src_config.optimizer_config.model_dump() ), - wal_config=models.WalConfigDiff(**to_dict(src_config.wal_config)), + wal_config=models.WalConfigDiff(**src_config.wal_config.model_dump()), quantization_config=src_config.quantization_config, timeout=300, ) diff --git a/backend/modules/dataloaders/loader.py b/backend/modules/dataloaders/loader.py index 656c4ed0..336ebf8f 100644 --- a/backend/modules/dataloaders/loader.py +++ b/backend/modules/dataloaders/loader.py @@ -128,7 +128,7 @@ def list_dataloaders(): Returns a list of all the registered loaders. Returns: - List[dict]: A list of all the registered loaders. + List[Dict]: A list of all the registered loaders. 
""" global LOADER_REGISTRY return [ diff --git a/backend/modules/metadata_store/base.py b/backend/modules/metadata_store/base.py index 222e5207..534e36ab 100644 --- a/backend/modules/metadata_store/base.py +++ b/backend/modules/metadata_store/base.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from backend.constants import DATA_POINT_FQN_METADATA_KEY, FQN_SEPARATOR from backend.types import ( @@ -18,7 +18,7 @@ from backend.utils import run_in_executor -# TODO(chiragjn): Ideal would be we make `async def a*` abstract methods and drop sync ones +# TODO (chiragjn): Ideal would be we make `async def a*` abstract methods and drop sync ones # Implementations can then opt to call their sync versions using run_in_executor class BaseMetadataStore(ABC): def __init__(self, *args, **kwargs): @@ -300,7 +300,7 @@ async def aupdate_data_ingestion_run_status( def log_metrics_for_data_ingestion_run( self, data_ingestion_run_name: str, - metric_dict: dict[str, int | float], + metric_dict: Dict[str, Union[int, float]], step: int = 0, ): """ @@ -311,7 +311,7 @@ def log_metrics_for_data_ingestion_run( async def alog_metrics_for_data_ingestion_run( self, data_ingestion_run_name: str, - metric_dict: dict[str, int | float], + metric_dict: Dict[str, Union[int, float]], step: int = 0, ): """ @@ -346,6 +346,8 @@ async def alog_errors_for_data_ingestion_run( errors=errors, ) + # TODO (chiragjn): What is the difference between get_collections and this? + # TODO (chiragjn): Return complete entities, why return only str? def list_collections( self, ) -> List[str]: @@ -354,6 +356,7 @@ def list_collections( """ raise NotImplementedError() + # TODO (chiragjn): Return complete entities, why return only str? async def alist_collections( self, ) -> List[str]: @@ -365,17 +368,19 @@ async def alist_collections( self.list_collections, ) + # TODO (chiragjn): Return complete entities, why return dict? def list_data_sources( self, - ) -> List[str]: + ) -> List[Dict[str, str]]: """ List all data source names from metadata store """ raise NotImplementedError() + # TODO (chiragjn): Return complete entities, why return dict? async def alist_data_sources( self, - ) -> List[str]: + ) -> List[Dict[str, str]]: """ List all data source names from metadata store """ @@ -410,7 +415,7 @@ async def acreate_rag_app(self, app: RagApplication) -> RagApplicationDto: """ return await run_in_executor(None, self.create_rag_app, app=app) - def get_rag_app(self, app_name: str) -> RagApplicationDto | None: + def get_rag_app(self, app_name: str) -> Optional[RagApplicationDto]: """ Get a RAG application from the metadata store by name """ @@ -422,12 +427,14 @@ async def aget_rag_app(self, app_name: str) -> Optional[RagApplicationDto]: """ return await run_in_executor(None, self.get_rag_app, app_name=app_name) + # TODO (chiragjn): Return complete entities, why return only str? def list_rag_apps(self) -> List[str]: """ List all RAG application names from metadata store """ raise NotImplementedError() + # TODO (chiragjn): Return complete entities, why return only str? 
async def alist_rag_apps(self) -> List[str]: """ List all RAG application names from metadata store diff --git a/backend/modules/metadata_store/prismastore.py b/backend/modules/metadata_store/prismastore.py index b3cf0272..0c3130ce 100644 --- a/backend/modules/metadata_store/prismastore.py +++ b/backend/modules/metadata_store/prismastore.py @@ -4,7 +4,7 @@ import random import shutil import string -from typing import Any, Dict, List +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from fastapi import HTTPException from prisma import Prisma @@ -23,11 +23,26 @@ DataIngestionRunStatus, DataSource, RagApplication, - RagApplicationDto, ) +if TYPE_CHECKING: + # TODO (chiragjn): Can we import these safely even if the prisma client might not be generated yet? + from prisma.models import Collection as PrismaCollection + from prisma.models import DataSource as PrismaDataSource + from prisma.models import IngestionRuns as PrismaDataIngestionRun + from prisma.models import RagApps as PrismaRagApplication + +# TODO (chiragjn): +# - Use transactions! +# - Some methods are using json.dumps - not sure if this is the right way to send data via prisma client +# - primsa generates its own DB entity classes - ideally we should be using those instead of call +# .model_dump() on the pydantic objects. See prisma.models and prisma.actions +# + # TODO (chiragjn): Either we make everything async or add sync method to this + + class PrismaStore(BaseMetadataStore): def __init__(self, *args, db, **kwargs) -> None: self.db = db @@ -48,6 +63,22 @@ async def aconnect(cls, **kwargs): # COLLECTIONS APIS ###### + async def aget_collection_by_name( + self, collection_name: str, no_cache: bool = True + ) -> Optional[Collection]: + try: + collection: Optional[ + "PrismaCollection" + ] = await self.db.collection.find_first(where={"name": collection_name}) + if collection: + return Collection.model_validate(collection.model_dump()) + return None + except Exception as e: + logger.exception(f"Failed to get collection by name: {e}") + raise HTTPException( + status_code=500, detail="Failed to get collection by name" + ) + async def acreate_collection(self, collection: CreateCollection) -> Collection: try: existing_collection = await self.aget_collection_by_name(collection.name) @@ -63,42 +94,33 @@ async def acreate_collection(self, collection: CreateCollection) -> Collection: ) try: - logger.info(f"Creating collection: {collection.dict()}") - collection_data = collection.dict() + logger.info(f"Creating collection: {collection.model_dump()}") + collection_data = collection.model_dump() collection_data["embedder_config"] = json.dumps( collection_data["embedder_config"] ) - collection = await self.db.collection.create(data=collection_data) - return collection + collection: "PrismaCollection" = await self.db.collection.create( + data=collection_data + ) + return Collection.model_validate(collection.model_dump()) except Exception as e: logger.exception(f"Error: {e}") raise HTTPException(status_code=500, detail=f"Error: {e}") - async def aget_collection_by_name( - self, collection_name: str, no_cache: bool = True - ) -> Collection | None: - try: - collection = await self.db.collection.find_first( - where={"name": collection_name} - ) - if collection: - return collection - return None - except Exception as e: - logger.exception(f"Failed to get collection by name: {e}") - raise HTTPException( - status_code=500, detail="Failed to get collection by name" - ) - async def aget_retrieve_collection_by_name( self, 
collection_name: str, no_cache: bool = True - ) -> Collection | None: - return await self.aget_collection_by_name(collection_name, no_cache) + ) -> Optional[Collection]: + collection: "PrismaCollection" = await self.aget_collection_by_name( + collection_name, no_cache + ) + return Collection.model_validate(collection.model_dump()) async def aget_collections(self) -> List[Collection]: try: - collections = await self.db.collection.find_many(order={"id": "desc"}) - return collections + collections: List["PrismaCollection"] = await self.db.collection.find_many( + order={"id": "desc"} + ) + return [Collection.model_validate(c.model_dump()) for c in collections] except Exception as e: logger.exception(f"Failed to get collections: {e}") raise HTTPException(status_code=500, detail="Failed to get collections") @@ -113,17 +135,17 @@ async def alist_collections(self) -> List[str]: async def adelete_collection(self, collection_name: str, include_runs=False): try: - collection = await self.aget_collection_by_name(collection_name) - if not collection: - logger.debug(f"Collection with name {collection_name} does not exist") - except Exception as e: - logger.exception(e) - - try: - await self.db.collection.delete(where={"name": collection_name}) + deleted_collection: Optional[ + "PrismaCollection" + ] = await self.db.collection.delete(where={"name": collection_name}) + if not deleted_collection: + raise HTTPException( + status_code=404, + detail=f"Failed to delete collection {collection_name!r}. No such record found", + ) if include_runs: try: - await self.db.ingestionruns.delete_many( + _deleted_count = await self.db.ingestionruns.delete_many( where={"collection_name": collection_name} ) except Exception as e: @@ -135,6 +157,18 @@ async def adelete_collection(self, collection_name: str, include_runs=False): ###### # DATA SOURCE APIS ###### + async def aget_data_source_from_fqn(self, fqn: str) -> Optional[DataSource]: + try: + data_source: Optional[ + "PrismaDataSource" + ] = await self.db.datasource.find_first(where={"fqn": fqn}) + if data_source: + return DataSource.model_validate(data_source.model_dump()) + return None + except Exception as e: + logger.exception(f"Error: {e}") + raise HTTPException(status_code=500, detail=f"Error: {e}") + async def acreate_data_source(self, data_source: CreateDataSource) -> DataSource: try: existing_data_source = await self.aget_data_source_from_fqn(data_source.fqn) @@ -150,29 +184,21 @@ async def acreate_data_source(self, data_source: CreateDataSource) -> DataSource ) try: - data = data_source.dict() + data = data_source.model_dump() data["metadata"] = json.dumps(data["metadata"]) - data_source = await self.db.datasource.create(data) + data_source: "PrismaDataSource" = await self.db.datasource.create(data) logger.info(f"Created data source: {data_source}") - return data_source + return DataSource.model_validate(data_source.model_dump()) except Exception as e: logger.exception(f"Failed to create data source: {e}") raise HTTPException(status_code=500, detail=f"Error: {e}") - async def aget_data_source_from_fqn(self, fqn: str) -> DataSource | None: - try: - data_source = await self.db.datasource.find_first(where={"fqn": fqn}) - if data_source: - return data_source - return None - except Exception as e: - logger.exception(f"Error: {e}") - raise HTTPException(status_code=500, detail=f"Error: {e}") - async def aget_data_sources(self) -> List[DataSource]: try: - data_sources = await self.db.datasource.find_many(order={"id": "desc"}) - return data_sources + data_sources: 
List["PrismaDataSource"] = await self.db.datasource.find_many( + order={"id": "desc"} + ) + return [DataSource.model_validate(ds.model_dump()) for ds in data_sources] except Exception as e: logger.exception(f"Error: {e}") raise HTTPException(status_code=500, detail=f"Error: {e}") @@ -231,24 +257,33 @@ async def aassociate_data_source_with_collection( if existing_collection_associated_data_sources: existing_collection_associated_data_sources[ data_src_to_associate.data_source_fqn - ] = data_src_to_associate.dict() + ] = data_src_to_associate else: existing_collection_associated_data_sources = { - data_src_to_associate.data_source_fqn: data_src_to_associate.dict() + data_src_to_associate.data_source_fqn: data_src_to_associate } + logger.info(existing_collection_associated_data_sources) associated_data_sources: Dict[str, Dict[str, Any]] = {} for ( data_source_fqn, data_source, ) in existing_collection_associated_data_sources.items(): - associated_data_sources[data_source_fqn] = data_source + associated_data_sources[data_source_fqn] = data_source.model_dump() - updated_collection = await self.db.collection.update( + updated_collection: Optional[ + "PrismaCollection" + ] = await self.db.collection.update( where={"name": collection_name}, data={"associated_data_sources": json.dumps(associated_data_sources)}, ) - return updated_collection + if not updated_collection: + raise HTTPException( + status_code=404, + detail=f"Failed to associate data source with collection {collection_name!r}. " + f"No such record found", + ) + return Collection.model_validate(updated_collection.model_dump()) except Exception as e: logger.exception(f"Error: {e}") @@ -261,7 +296,7 @@ async def aunassociate_data_source_with_collection( self, collection_name: str, data_source_fqn: str ) -> Collection: try: - collection = await self.aget_collection_by_name(collection_name) + collection: Collection = await self.aget_collection_by_name(collection_name) except Exception as e: logger.exception(f"Error: {e}") raise HTTPException(status_code=500, detail=f"Error: {e}") @@ -286,7 +321,9 @@ async def aunassociate_data_source_with_collection( detail=f"Data source with fqn {data_source_fqn} does not exist", ) - associated_data_sources = collection.associated_data_sources + associated_data_sources: AssociatedDataSources = ( + collection.associated_data_sources + ) if not associated_data_sources: logger.error( f"No associated data sources found for collection {collection_name}" @@ -305,12 +342,26 @@ async def aunassociate_data_source_with_collection( ) associated_data_sources.pop(data_source_fqn, None) + try: - updated_collection = await self.db.collection.update( + updated_collection: Optional[ + "PrismaCollection" + ] = await self.db.collection.update( where={"name": collection_name}, - data={"associated_data_sources": json.dumps(associated_data_sources)}, + data={ + "associated_data_sources": json.dumps( + associated_data_sources.model_dump() + ) + }, ) - return updated_collection + if not updated_collection: + raise HTTPException( + status_code=404, + detail=f"Failed to unassociate data source from collection. 
" + f"No collection found with name {collection_name}", + ) + logger.info(f"Updated collection: {updated_collection}") + return Collection.model_validate(updated_collection.model_dump()) except Exception as e: logger.exception(f"Failed to unassociate data source with collection: {e}") raise HTTPException( @@ -320,15 +371,15 @@ async def aunassociate_data_source_with_collection( async def alist_data_sources( self, - ) -> List[dict[str, str]]: + ) -> List[Dict[str, str]]: try: data_sources = await self.aget_data_sources() - return [data_source.dict() for data_source in data_sources] + return [data_source.model_dump() for data_source in data_sources] except Exception as e: logger.exception(f"Failed to list data sources: {e}") raise HTTPException(status_code=500, detail="Failed to list data sources") - async def adelete_data_source(self, data_source_fqn: str): + async def adelete_data_source(self, data_source_fqn: str) -> None: if not settings.LOCAL: logger.error(f"Data source deletion is not allowed in local mode") raise HTTPException( @@ -373,7 +424,14 @@ async def adelete_data_source(self, data_source_fqn: str): # Delete the data source try: logger.info(f"Data source to delete: {data_source}") - await self.db.datasource.delete(where={"fqn": data_source.fqn}) + deleted_datasource: Optional[ + PrismaDataSource + ] = await self.db.datasource.delete(where={"fqn": data_source.fqn}) + if not deleted_datasource: + raise HTTPException( + status_code=404, + detail=f"Failed to delete data source {data_source.fqn!r}. No such record found", + ) # Delete the data from `/users_data` directory if data source is of type `localdir` if data_source.type == "localdir": data_source_uri = data_source.uri @@ -416,10 +474,12 @@ async def acreate_data_ingestion_run( ) try: - run_data = created_data_ingestion_run.dict() + run_data = created_data_ingestion_run.model_dump() run_data["parser_config"] = json.dumps(run_data["parser_config"]) - data_ingestion_run = await self.db.ingestionruns.create(data=run_data) - return DataIngestionRun(**data_ingestion_run.dict()) + data_ingestion_run: "PrismaDataIngestionRun" = ( + await self.db.ingestionruns.create(data=run_data) + ) + return DataIngestionRun.model_validate(data_ingestion_run.model_dump()) except Exception as e: logger.exception(f"Failed to create data ingestion run: {e}") raise HTTPException( @@ -428,14 +488,16 @@ async def acreate_data_ingestion_run( async def aget_data_ingestion_run( self, data_ingestion_run_name: str, no_cache: bool = False - ) -> DataIngestionRun | None: + ) -> Optional[DataIngestionRun]: try: - data_ingestion_run = await self.db.ingestionruns.find_first( + data_ingestion_run: Optional[ + "PrismaDataIngestionRun" + ] = await self.db.ingestionruns.find_first( where={"name": data_ingestion_run_name} ) logger.info(f"Data ingestion run: {data_ingestion_run}") if data_ingestion_run: - return DataIngestionRun(**data_ingestion_run.dict()) + return DataIngestionRun.model_validate(data_ingestion_run.model_dump()) return None except Exception as e: logger.exception(f"Failed to get data ingestion run: {e}") @@ -446,10 +508,15 @@ async def aget_data_ingestion_runs( ) -> List[DataIngestionRun]: """Get all data ingestion runs for a collection""" try: - data_ingestion_runs = await self.db.ingestionruns.find_many( + data_ingestion_runs: List[ + "PrismaDataIngestionRun" + ] = await self.db.ingestionruns.find_many( where={"collection_name": collection_name}, order={"id": "desc"} ) - return data_ingestion_runs + return [ + 
DataIngestionRun.model_validate(data_ir.model_dump()) + for data_ir in data_ingestion_runs + ] except Exception as e: logger.exception(f"Failed to get data ingestion runs: {e}") raise HTTPException(status_code=500, detail=f"{e}") @@ -459,10 +526,20 @@ async def aupdate_data_ingestion_run_status( ) -> DataIngestionRun: """Update the status of a data ingestion run""" try: - updated_data_ingestion_run = await self.db.ingestionruns.update( + updated_data_ingestion_run: Optional[ + "PrismaDataIngestionRun" + ] = await self.db.ingestionruns.update( where={"name": data_ingestion_run_name}, data={"status": status} ) - return updated_data_ingestion_run + if not updated_data_ingestion_run: + raise HTTPException( + status_code=404, + detail=f"Failed to update ingestion run {data_ingestion_run_name!r}. No such record found", + ) + + return DataIngestionRun.model_validate( + updated_data_ingestion_run.model_dump() + ) except Exception as e: logger.exception(f"Failed to update data ingestion run status: {e}") raise HTTPException(status_code=500, detail=f"{e}") @@ -470,20 +547,27 @@ async def aupdate_data_ingestion_run_status( async def alog_metrics_for_data_ingestion_run( self, data_ingestion_run_name: str, - metric_dict: dict[str, int | float], + metric_dict: Dict[str, Union[int, float]], step: int = 0, ): - pass + raise NotImplementedError() async def alog_errors_for_data_ingestion_run( self, data_ingestion_run_name: str, errors: Dict[str, Any] - ): + ) -> None: """Log errors for the given data ingestion run""" try: - await self.db.ingestionruns.update( + updated_data_ingestion_run: Optional[ + "PrismaDataIngestionRun" + ] = await self.db.ingestionruns.update( where={"name": data_ingestion_run_name}, data={"errors": json.dumps(errors)}, ) + if not updated_data_ingestion_run: + raise HTTPException( + status_code=404, + detail=f"Failed to update ingestion run {data_ingestion_run_name!r}. 
No such record found", + ) except Exception as e: logger.exception( f"Failed to log errors data ingestion run {data_ingestion_run_name}: {e}" @@ -493,9 +577,22 @@ async def alog_errors_for_data_ingestion_run( ###### # RAG APPLICATION APIS ###### + async def aget_rag_app(self, app_name: str) -> Optional[RagApplication]: + """Get a RAG application from the metadata store""" + try: + rag_app: Optional[ + "PrismaRagApplication" + ] = await self.db.ragapps.find_first(where={"name": app_name}) + if rag_app: + return RagApplication.model_validate(rag_app.model_dump()) + return None + except Exception as e: + logger.exception(f"Failed to get RAG application by name: {e}") + raise HTTPException( + status_code=500, detail="Failed to get RAG application by name" + ) - # TODO (prathamesh): Implement these methods - async def acreate_rag_app(self, app: RagApplication) -> RagApplicationDto: + async def acreate_rag_app(self, app: RagApplication) -> RagApplication: """Create a RAG application in the metadata store""" try: existing_app = await self.aget_rag_app(app.name) @@ -511,32 +608,21 @@ async def acreate_rag_app(self, app: RagApplication) -> RagApplicationDto: ) try: - logger.info(f"Creating RAG application: {app.dict()}") - rag_app_data = app.dict() + logger.info(f"Creating RAG application: {app.model_dump()}") + rag_app_data = app.model_dump() rag_app_data["config"] = json.dumps(rag_app_data["config"]) - rag_app = await self.db.ragapps.create(data=rag_app_data) - return rag_app + rag_app: "PrismaRagApplication" = await self.db.ragapps.create( + data=rag_app_data + ) + return RagApplication.model_validate(rag_app.model_dump()) except Exception as e: logger.exception(f"Error: {e}") raise HTTPException(status_code=500, detail=f"Error: {e}") - async def aget_rag_app(self, app_name: str) -> RagApplicationDto | None: - """Get a RAG application from the metadata store""" - try: - rag_app = await self.db.ragapps.find_first(where={"name": app_name}) - if rag_app: - return rag_app - return None - except Exception as e: - logger.exception(f"Failed to get RAG application by name: {e}") - raise HTTPException( - status_code=500, detail="Failed to get RAG application by name" - ) - async def alist_rag_apps(self) -> List[str]: """List all RAG applications from the metadata store""" try: - rag_apps = await self.db.ragapps.find_many() + rag_apps: List["PrismaRagApplication"] = await self.db.ragapps.find_many() return [rag_app.name for rag_app in rag_apps] except Exception as e: logger.exception(f"Failed to list RAG applications: {e}") @@ -547,14 +633,14 @@ async def alist_rag_apps(self) -> List[str]: async def adelete_rag_app(self, app_name: str): """Delete a RAG application from the metadata store""" try: - rag_app = await self.aget_rag_app(app_name) - if not rag_app: - logger.debug(f"RAG application with name {app_name} does not exist") - except Exception as e: - logger.exception(e) - - try: - await self.db.ragapps.delete(where={"name": app_name}) + deleted_rag_app: Optional[ + "PrismaRagApplication" + ] = await self.db.ragapps.delete(where={"name": app_name}) + if not deleted_rag_app: + raise HTTPException( + status_code=404, + detail=f"Failed to delete RAG application {app_name!r}. 
No such record found", + ) except Exception as e: logger.exception(f"Failed to delete RAG application: {e}") raise HTTPException( diff --git a/backend/modules/metadata_store/truefoundry.py b/backend/modules/metadata_store/truefoundry.py index 671d2d5a..b47b0fb1 100644 --- a/backend/modules/metadata_store/truefoundry.py +++ b/backend/modules/metadata_store/truefoundry.py @@ -3,7 +3,7 @@ import os import tempfile import warnings -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional, Union import mlflow from fastapi import HTTPException @@ -38,7 +38,7 @@ class MLRunTypes(str, enum.Enum): class TrueFoundry(BaseMetadataStore): - ml_runs: dict[str, ml.MlFoundryRun] = {} + ml_runs: Dict[str, ml.MlFoundryRun] = {} CONSTANT_DATA_SOURCE_RUN_NAME = "tfy-datasource" def __init__(self, *args, ml_repo_name: str, **kwargs): @@ -58,7 +58,7 @@ def __init__(self, *args, ml_repo_name: str, **kwargs): def _get_run_by_name( self, run_name: str, no_cache: bool = False - ) -> ml.MlFoundryRun | None: + ) -> Optional[ml.MlFoundryRun]: """ Cache the runs to avoid too many requests to the backend. """ @@ -113,7 +113,7 @@ def create_collection(self, collection: CreateCollection) -> Collection: embedder_config=collection.embedder_config, ) self._save_entity_to_run( - run=run, metadata=created_collection.dict(), params=params + run=run, metadata=created_collection.model_dump(), params=params ) run.end() logger.debug(f"[Metadata Store] Collection Saved") @@ -134,7 +134,7 @@ def _get_entity_from_run( def _get_artifact_metadata_ml_run( self, run: ml.MlFoundryRun - ) -> ml.ArtifactVersion | None: + ) -> Optional[ml.ArtifactVersion]: params = run.get_params() metadata_artifact_fqn = params.get("metadata_artifact_fqn") if not metadata_artifact_fqn: @@ -177,7 +177,7 @@ def _update_entity_in_run( def get_collection_by_name( self, collection_name: str, no_cache: bool = True - ) -> Collection | None: + ) -> Optional[Collection]: """Get collection from given collection name.""" logger.debug(f"[Metadata Store] Getting collection with name {collection_name}") ml_run = self._get_run_by_name(run_name=collection_name, no_cache=no_cache) @@ -187,14 +187,14 @@ def get_collection_by_name( ) return None collection = self._populate_collection( - Collection.parse_obj(self._get_entity_from_run(run=ml_run)) + Collection.model_validate(self._get_entity_from_run(run=ml_run)) ) logger.debug(f"[Metadata Store] Fetched collection with name {collection_name}") return collection def get_retrieve_collection_by_name( self, collection_name: str, no_cache: bool = True - ) -> Collection | None: + ) -> Optional[Collection]: """Get collection from given collection name. 
Used during retrieval""" logger.debug(f"[Metadata Store] Getting collection with name {collection_name}") ml_run = self._get_run_by_name(run_name=collection_name, no_cache=no_cache) @@ -203,7 +203,7 @@ def get_retrieve_collection_by_name( f"[Metadata Store] Collection with name {collection_name} not found" ) return None - collection = Collection.parse_obj(self._get_entity_from_run(run=ml_run)) + collection = Collection.model_validate(self._get_entity_from_run(run=ml_run)) logger.debug(f"[Metadata Store] Fetched collection with name {collection_name}") return collection @@ -216,7 +216,9 @@ def get_collections(self) -> List[Collection]: ) collections = [] for ml_run in ml_runs: - collection = Collection.parse_obj(self._get_entity_from_run(run=ml_run)) + collection = Collection.model_validate( + self._get_entity_from_run(run=ml_run) + ) collections.append(self._populate_collection(collection)) logger.debug(f"[Metadata Store] Listed {len(collections)} collections") return collections @@ -262,7 +264,9 @@ def associate_data_source_with_collection( f"data source with fqn {data_source_association.data_source_fqn} not found", ) # Always do this to avoid race conditions - collection = Collection.parse_obj(self._get_entity_from_run(run=collection_run)) + collection = Collection.model_validate( + self._get_entity_from_run(run=collection_run) + ) associated_data_source = AssociatedDataSources( data_source_fqn=data_source_association.data_source_fqn, parser_config=data_source_association.parser_config, @@ -271,7 +275,7 @@ def associate_data_source_with_collection( data_source_association.data_source_fqn ] = associated_data_source - self._update_entity_in_run(run=collection_run, metadata=collection.dict()) + self._update_entity_in_run(run=collection_run, metadata=collection.model_dump()) logger.debug( f"[Metadata Store] Associated data_source {data_source_association.data_source_fqn} " f"to collection {collection_name}" @@ -296,9 +300,11 @@ def unassociate_data_source_with_collection( f"Collection {collection_name} not found.", ) # Always do this to avoid run conditions - collection = Collection.parse_obj(self._get_entity_from_run(run=collection_run)) + collection = Collection.model_validate( + self._get_entity_from_run(run=collection_run) + ) collection.associated_data_sources.pop(data_source_fqn) - self._update_entity_in_run(run=collection_run, metadata=collection.dict()) + self._update_entity_in_run(run=collection_run, metadata=collection.model_dump()) logger.debug( f"[Metadata Store] Unassociated data_source {data_source_fqn} to collection {collection_name}" ) @@ -331,7 +337,7 @@ def create_data_source(self, data_source: CreateDataSource) -> DataSource: metadata=data_source.metadata, ) self._save_entity_to_run( - run=run, metadata=created_data_source.dict(), params=params + run=run, metadata=created_data_source.model_dump(), params=params ) run.end() logger.debug( @@ -339,14 +345,14 @@ def create_data_source(self, data_source: CreateDataSource) -> DataSource: ) return created_data_source - def get_data_source_from_fqn(self, fqn: str) -> DataSource | None: + def get_data_source_from_fqn(self, fqn: str) -> Optional[DataSource]: logger.debug(f"[Metadata Store] Getting data_source by fqn {fqn}") runs = self.client.search_runs( ml_repo=self.ml_repo_name, filter_string=f"params.entity_type = '{MLRunTypes.DATA_SOURCE.value}' and params.data_source_fqn = '{fqn}'", ) for run in runs: - data_source = DataSource.parse_obj(self._get_entity_from_run(run=run)) + data_source = 
DataSource.model_validate(self._get_entity_from_run(run=run)) logger.debug(f"[Metadata Store] Fetched Data Source with fqn {fqn}") return data_source logger.debug(f"[Metadata Store] Data Source with fqn {fqn} not found") @@ -360,7 +366,7 @@ def get_data_sources(self) -> List[DataSource]: ) data_sources: List[DataSource] = [] for run in runs: - data_source = DataSource.parse_obj(self._get_entity_from_run(run=run)) + data_source = DataSource.model_validate(self._get_entity_from_run(run=run)) data_sources.append(data_source) logger.debug(f"[Metadata Store] Listed {len(data_sources)} data sources") return data_sources @@ -395,7 +401,7 @@ def create_data_ingestion_run( status=DataIngestionRunStatus.INITIALIZED, ) self._save_entity_to_run( - run=run, metadata=created_data_ingestion_run.dict(), params=params + run=run, metadata=created_data_ingestion_run.model_dump(), params=params ) run.end() logger.debug( @@ -406,7 +412,7 @@ def create_data_ingestion_run( def get_data_ingestion_run( self, data_ingestion_run_name: str, no_cache: bool = False - ) -> DataIngestionRun | None: + ) -> Optional[DataIngestionRun]: logger.debug( f"[Metadata Store] Getting ingestion run {data_ingestion_run_name}" ) @@ -416,7 +422,7 @@ def get_data_ingestion_run( f"[Metadata Store] Ingestion run with name {data_ingestion_run_name} not found" ) return None - data_ingestion_run = DataIngestionRun.parse_obj( + data_ingestion_run = DataIngestionRun.model_validate( self._get_entity_from_run(run=run) ) run_tags = run.get_tags() @@ -447,7 +453,7 @@ def get_data_ingestion_runs( ) data_ingestion_runs: List[DataIngestionRun] = [] for run in runs: - data_ingestion_run = DataIngestionRun.parse_obj( + data_ingestion_run = DataIngestionRun.model_validate( self._get_entity_from_run(run=run) ) run_tags = run.get_tags() @@ -512,7 +518,7 @@ def update_data_ingestion_run_status( def log_metrics_for_data_ingestion_run( self, data_ingestion_run_name: str, - metric_dict: dict[str, int | float], + metric_dict: Dict[str, Union[int, float]], step: int = 0, ): try: @@ -566,7 +572,7 @@ def list_collections(self) -> List[str]: ) return [run.run_name for run in ml_runs] - def list_data_sources(self) -> List[dict[str, str]]: + def list_data_sources(self) -> List[Dict[str, str]]: logger.info(f"[Metadata Store] Listing all data sources") ml_runs = self.client.search_runs( ml_repo=self.ml_repo_name, @@ -624,12 +630,14 @@ def create_rag_app(self, app: RagApplication) -> RagApplicationDto: name=app.name, config=app.config, ) - self._save_entity_to_run(run=run, metadata=created_app.dict(), params=params) + self._save_entity_to_run( + run=run, metadata=created_app.model_dump(), params=params + ) run.end() logger.debug(f"[Metadata Store] RAG Application Saved") return created_app - def get_rag_app(self, app_name: str) -> RagApplicationDto | None: + def get_rag_app(self, app_name: str) -> Optional[RagApplicationDto]: """ Get a RAG application from the metadata store """ @@ -641,7 +649,9 @@ def get_rag_app(self, app_name: str) -> RagApplicationDto | None: ) return None try: - app = RagApplicationDto.parse_obj(self._get_entity_from_run(run=ml_run)) + app = RagApplicationDto.model_validate( + self._get_entity_from_run(run=ml_run) + ) logger.debug( f"[Metadata Store] Fetched RAG application with name {app_name}" ) diff --git a/backend/modules/model_gateway/model_gateway.py b/backend/modules/model_gateway/model_gateway.py index e94eb77c..ecea699e 100644 --- a/backend/modules/model_gateway/model_gateway.py +++ b/backend/modules/model_gateway/model_gateway.py @@ -1,4 
+1,3 @@ -import json import os from typing import List @@ -28,7 +27,7 @@ def __init__(self): # parse the json data into a list of ModelProviderConfig objects self.provider_configs = [ - ModelProviderConfig.parse_obj(item) for item in _providers + ModelProviderConfig.model_validate(item) for item in _providers ] # load llm models diff --git a/backend/modules/model_gateway/reranker_svc.py b/backend/modules/model_gateway/reranker_svc.py index 370677c8..32a85043 100644 --- a/backend/modules/model_gateway/reranker_svc.py +++ b/backend/modules/model_gateway/reranker_svc.py @@ -6,14 +6,13 @@ from langchain.retrievers.document_compressors.base import BaseDocumentCompressor from backend.logger import logger -from backend.settings import settings # Reranking Service using Infinity API class InfinityRerankerSvc(BaseDocumentCompressor): """ Reranker Service that uses Infinity API - Github: https://github.com/michaelfeil/infinity + GitHub: https://github.com/michaelfeil/infinity """ model: str diff --git a/backend/modules/parsers/multimodalparser.py b/backend/modules/parsers/multimodalparser.py index 7adf2fc4..3c26dbdb 100644 --- a/backend/modules/parsers/multimodalparser.py +++ b/backend/modules/parsers/multimodalparser.py @@ -3,7 +3,7 @@ import io import os from itertools import islice -from typing import Optional +from typing import Any, Dict, Optional import cv2 import fitz @@ -17,7 +17,7 @@ from backend.modules.model_gateway.model_gateway import model_gateway from backend.modules.parsers.parser import BaseParser from backend.modules.parsers.utils import contains_text -from backend.types import ModelConfig +from backend.types import ModelConfig, ModelType def stringToRGB(base64_string: str): @@ -45,9 +45,8 @@ class MultiModalParser(BaseParser): Parser Configuration will look like the following while creating the collection: { ".pdf": { - "parser": "MultiModalParser", - "kwargs": { - "chunk_size": 1000, + "name": "MultiModalParser", + "parameters": { "model_configuration": { "name" : "truefoundry/openai-main/gpt-4o-mini" }, @@ -65,15 +64,15 @@ def __init__( """ Initializes the MultiModalParser object. """ - # Multi-modal parser needs to be configured with the openai compatible client url and vision model if model_configuration: - self.model_configuration = ModelConfig.parse_obj(model_configuration) + self.model_configuration = ModelConfig.model_validate(model_configuration) logger.info(f"Using custom vision model..., {self.model_configuration}") else: # Truefoundry specific model configuration self.model_configuration = ModelConfig( - name="truefoundry/openai-main/gpt-4o-mini" + name="truefoundry/openai-main/gpt-4o-mini", + type=ModelType.chat, ) if prompt: @@ -131,7 +130,7 @@ async def call_vlm_agent( return {"error": f"Error in page: {page_number}"} async def get_chunks( - self, filepath: str, metadata: Optional[dict] = None, **kwargs + self, filepath: str, metadata: Optional[Dict[Any, Any]] = None, *args, **kwargs ): """ Asynchronously extracts text from a PDF file and returns it in chunks. 
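Most of the churn in this diff is mechanical renaming from the Pydantic v1 API to its v2 equivalents. As a rough reference only (an illustrative sketch, not code from any of the changed files), the mapping applied throughout looks like this:

from pydantic import BaseModel, model_validator


class Example(BaseModel):
    name: str

    # v1 @root_validator becomes @model_validator(mode="before") stacked with @classmethod
    @model_validator(mode="before")
    @classmethod
    def check(cls, values):
        return values


example = Example(name="x")
example.model_dump()                   # v1: example.dict()
Example.model_validate({"name": "x"})  # v1: Example.parse_obj({"name": "x"})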
diff --git a/backend/modules/parsers/parser.py b/backend/modules/parsers/parser.py index e230b004..37c28849 100644 --- a/backend/modules/parsers/parser.py +++ b/backend/modules/parsers/parser.py @@ -1,7 +1,6 @@ -import typing from abc import ABC, abstractmethod from collections import defaultdict -from typing import Optional +from typing import Any, Dict, List, Optional from langchain.docstore.document import Document @@ -39,9 +38,10 @@ def __init__(self, **kwargs): async def get_chunks( self, filepath: str, - metadata: Optional[dict], + metadata: Optional[Dict[Any, Any]], + *args, **kwargs, - ) -> typing.List[Document]: + ) -> List[Document]: """ Abstract method. This should asynchronously read a file and return its content in chunks. @@ -54,7 +54,9 @@ async def get_chunks( pass -def get_parser_for_extension(file_extension, parsers_map, **kwargs) -> BaseParser: +def get_parser_for_extension( + file_extension, parsers_map, *args, **kwargs +) -> Optional[BaseParser]: """ During the indexing phase, given the file_extension and parsers mapping, return the appropriate mapper. If no mapping is given, use the default registry. @@ -94,7 +96,7 @@ def list_parsers(): Returns a list of all the registered parsers. Returns: - List[dict]: A list of all the registered parsers. + List[Dict]: A list of all the registered parsers. """ global PARSER_REGISTRY return [ diff --git a/backend/modules/parsers/unstructured_io.py b/backend/modules/parsers/unstructured_io.py index 179df881..0951a102 100644 --- a/backend/modules/parsers/unstructured_io.py +++ b/backend/modules/parsers/unstructured_io.py @@ -51,7 +51,6 @@ def __init__(self, *, max_chunk_size: int = 2000, **kwargs): self.adapter = HTTPAdapter(max_retries=self.retry_strategy) self.session.mount("https://", self.adapter) self.session.mount("http://", self.adapter) - super().__init__(**kwargs) async def get_chunks(self, filepath: str, metadata: dict, **kwargs): diff --git a/backend/modules/query_controllers/example/controller.py b/backend/modules/query_controllers/example/controller.py index 7db59a42..e941a077 100644 --- a/backend/modules/query_controllers/example/controller.py +++ b/backend/modules/query_controllers/example/controller.py @@ -5,7 +5,6 @@ import async_timeout from fastapi import Body, HTTPException from fastapi.responses import StreamingResponse -from langchain.docstore.document import Document from langchain.prompts import PromptTemplate from langchain.retrievers import ContextualCompressionRetriever, MultiQueryRetriever from langchain.schema.vectorstore import VectorStoreRetriever @@ -26,6 +25,7 @@ GENERATION_TIMEOUT_SEC, Answer, Docs, + Document, ExampleQueryInput, ) from backend.modules.vector_db.client import VECTOR_STORE_CLIENT @@ -86,7 +86,7 @@ async def _get_vector_store(self, collection_name: str): raise HTTPException(status_code=404, detail="Collection not found") if not isinstance(collection, Collection): - collection = Collection(**collection.dict()) + collection = Collection(**collection.model_dump()) return VECTOR_STORE_CLIENT.get_vector_store( collection_name=collection.name, @@ -329,7 +329,7 @@ async def answer( # "stream": True # } -# data = ExampleQueryInput(**payload).dict() +# data = ExampleQueryInput(**payload).model_dump() # ENDPOINT_URL = 'http://localhost:8000/retrievers/example-app/answer' diff --git a/backend/modules/query_controllers/example/types.py b/backend/modules/query_controllers/example/types.py index 6226b7aa..926f6b51 100644 --- a/backend/modules/query_controllers/example/types.py +++ 
b/backend/modules/query_controllers/example/types.py
@@ -1,44 +1,52 @@
-from typing import Any, ClassVar, Collection, Dict, List, Optional
+from typing import Any, ClassVar, Dict, List, Optional, Sequence, Union
 
-from langchain.docstore.document import Document
-from pydantic import BaseModel, Field, root_validator
+from pydantic import BaseModel, Field, model_validator
 from qdrant_client.models import Filter as QdrantFilter
 
-from backend.types import ModelConfig
+from backend.types import ConfiguredBaseModel, ModelConfig
 
 GENERATION_TIMEOUT_SEC = 60.0 * 10
 
+# TODO (chiragjn): Remove all asserts and replace them with proper pydantic validations or raises
 
-class VectorStoreRetrieverConfig(BaseModel):
+
+class VectorStoreRetrieverConfig(ConfiguredBaseModel):
     """
     Configuration for VectorStore Retriever
     """
 
     search_type: str = Field(
         default="similarity",
-        title="""Defines the type of search that the Retriever should perform. Can be 'similarity' (default), 'mmr', or 'similarity_score_threshold'.
-        - "similarity": Retrieve the top k most similar documents to the query.,
-        - "mmr": Retrieve the top k most similar documents to the query and then rerank them using Maximal Marginal Relevance (MMR).,
-        - "similarity_score_threshold": Retrieve all documents with similarity score greater than a threshold.
-        """,
+        title="""Defines the type of search that the Retriever should perform.
+Can be 'similarity' (default), 'mmr', or 'similarity_score_threshold'.
+    - "similarity": Retrieve the top k most similar documents to the query.,
+    - "mmr": Retrieve the top k most similar documents to the query and then rerank them using Maximal Marginal Relevance (MMR).,
+    - "similarity_score_threshold": Retrieve all documents with similarity score greater than a threshold.
+""",
     )
 
     search_kwargs: dict = Field(default_factory=dict)
 
-    filter: Optional[dict] = Field(
+    filter: Optional[Dict[Any, Any]] = Field(
         default_factory=dict,
         title="""Filter by document metadata""",
     )
 
-    allowed_search_types: ClassVar[Collection[str]] = (
+    allowed_search_types: ClassVar[Sequence[str]] = (
         "similarity",
         "similarity_score_threshold",
         "mmr",
     )
 
-    @root_validator
-    def validate_search_type(cls, values: Dict) -> Dict:
+    @model_validator(mode="before")
+    @classmethod
+    def validate_search_type(cls, values: Dict[str, Any]) -> Dict[str, Any]:
         """Validate search type."""
+        if not isinstance(values, dict):
+            raise ValueError(
+                f"Unexpected Pydantic v2 Validation: values are of type {type(values)}"
+            )
+
         search_type = values.get("search_type")
 
         assert (
@@ -63,7 +71,7 @@ def validate_search_type(cls, values: Dict) -> Dict:
         filters = values.get("filter")
         if filters:
-            search_kwargs["filter"] = QdrantFilter.parse_obj(filters)
+            search_kwargs["filter"] = QdrantFilter.model_validate(filters)
 
         return values
 
@@ -82,7 +90,7 @@ class ContextualCompressionRetrieverConfig(VectorStoreRetrieverConfig):
         title="Top K docs to collect post compression",
     )
 
-    allowed_compressor_model_providers: ClassVar[Collection[str]]
+    allowed_compressor_model_providers: ClassVar[Sequence[str]]
 
 
 class ContextualCompressionMultiQueryRetrieverConfig(
@@ -91,10 +99,10 @@ class ContextualCompressionMultiQueryRetrieverConfig(
     pass
 
 
-class ExampleQueryInput(BaseModel):
+class ExampleQueryInput(ConfiguredBaseModel):
     """
     Model for Query input.
     Requires a collection name, retriever configuration, query, LLM configuration and prompt template.
""" collection_name: str = Field( @@ -104,6 +112,7 @@ class ExampleQueryInput(BaseModel): query: str = Field(title="Question to search for") + # TODO (chiragjn): pydantic v2 does not like fields that start with model_ model_configuration: ModelConfig prompt_template: str = Field( @@ -114,26 +123,33 @@ class ExampleQueryInput(BaseModel): title="Retriever name", ) - retriever_config: Dict[str, Any] = Field( + retriever_config: Union[ + VectorStoreRetrieverConfig, + MultiQueryRetrieverConfig, + ContextualCompressionRetrieverConfig, + ContextualCompressionMultiQueryRetrieverConfig, + ] = Field( title="Retriever configuration", ) - allowed_retriever_types: ClassVar[Collection[str]] = ( + allowed_retriever_types: ClassVar[Sequence[str]] = ( "vectorstore", "multi-query", "contextual-compression", "contextual-compression-multi-query", ) - stream: Optional[bool] = Field(title="Stream the results", default=False) + stream: bool = Field(title="Stream the results", default=False) - @root_validator() - def validate_retriever_type(cls, values: Dict) -> Dict: - retriever_name = values.get("retriever_name") + @model_validator(mode="before") + @classmethod + def validate_retriever_type(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if not isinstance(values, dict): + raise ValueError( + f"Unexpected Pydantic v2 Validation: values are of type {type(values)}" + ) - assert ( - retriever_name in cls.allowed_retriever_types - ), f"retriever of {retriever_name} not allowed. Valid values are: {cls.allowed_retriever_types}" + retriever_name = values.get("retriever_name") if retriever_name == "vectorstore": values["retriever_config"] = VectorStoreRetrieverConfig( @@ -154,15 +170,25 @@ def validate_retriever_type(cls, values: Dict) -> Dict: values["retriever_config"] = ContextualCompressionMultiQueryRetrieverConfig( **values.get("retriever_config") ) + else: + raise ValueError( + f"Unexpected retriever name: {retriever_name}. 
" + f"Valid values are: {cls.allowed_retriever_types}" + ) return values -class Answer(BaseModel): +class Document(ConfiguredBaseModel): + page_content: str + metadata: dict = Field(default_factory=dict) + + +class Answer(ConfiguredBaseModel): type: str = "answer" content: str -class Docs(BaseModel): +class Docs(ConfiguredBaseModel): type: str = "docs" content: List[Document] = Field(default_factory=list) diff --git a/backend/modules/query_controllers/multimodal/controller.py b/backend/modules/query_controllers/multimodal/controller.py index b54c3eb8..2d102f38 100644 --- a/backend/modules/query_controllers/multimodal/controller.py +++ b/backend/modules/query_controllers/multimodal/controller.py @@ -4,7 +4,6 @@ import async_timeout from fastapi import Body, HTTPException from fastapi.responses import StreamingResponse -from langchain.docstore.document import Document from langchain.prompts import PromptTemplate from langchain.retrievers import ContextualCompressionRetriever, MultiQueryRetriever from langchain.schema.vectorstore import VectorStoreRetriever @@ -25,6 +24,7 @@ GENERATION_TIMEOUT_SEC, Answer, Docs, + Document, ExampleQueryInput, ) from backend.modules.vector_db.client import VECTOR_STORE_CLIENT @@ -98,7 +98,7 @@ async def _get_vector_store(self, collection_name: str): raise HTTPException(status_code=404, detail="Collection not found") if not isinstance(collection, Collection): - collection = Collection(**collection.dict()) + collection = Collection(**collection.model_dump()) return VECTOR_STORE_CLIENT.get_vector_store( collection_name=collection.name, diff --git a/backend/modules/query_controllers/multimodal/types.py b/backend/modules/query_controllers/multimodal/types.py index 6226b7aa..349b6697 100644 --- a/backend/modules/query_controllers/multimodal/types.py +++ b/backend/modules/query_controllers/multimodal/types.py @@ -1,15 +1,16 @@ -from typing import Any, ClassVar, Collection, Dict, List, Optional +from typing import Any, ClassVar, Dict, List, Optional, Sequence, Union -from langchain.docstore.document import Document -from pydantic import BaseModel, Field, root_validator +from pydantic import BaseModel, Field, model_validator from qdrant_client.models import Filter as QdrantFilter -from backend.types import ModelConfig +from backend.types import ConfiguredBaseModel, ModelConfig GENERATION_TIMEOUT_SEC = 60.0 * 10 +# TODO (chiragjn): Remove all asserts and replace them with proper pydantic validations or raises -class VectorStoreRetrieverConfig(BaseModel): + +class VectorStoreRetrieverConfig(ConfiguredBaseModel): """ Configuration for VectorStore Retriever """ @@ -25,20 +26,26 @@ class VectorStoreRetrieverConfig(BaseModel): search_kwargs: dict = Field(default_factory=dict) - filter: Optional[dict] = Field( + filter: Optional[Dict[Any, Any]] = Field( default_factory=dict, title="""Filter by document metadata""", ) - allowed_search_types: ClassVar[Collection[str]] = ( + allowed_search_types: ClassVar[Sequence[str]] = ( "similarity", "similarity_score_threshold", "mmr", ) - @root_validator - def validate_search_type(cls, values: Dict) -> Dict: + @model_validator(mode="before") + @classmethod + def validate_search_type(cls, values: Dict[str, Any]) -> Dict[str, Any]: """Validate search type.""" + if not isinstance(values, dict): + raise ValueError( + f"Unexpected Pydantic v2 Validation: values are of type {type(values)}" + ) + search_type = values.get("search_type") assert ( @@ -63,7 +70,7 @@ def validate_search_type(cls, values: Dict) -> Dict: filters = values.get("filter") 
         if filters:
-            search_kwargs["filter"] = QdrantFilter.parse_obj(filters)
+            search_kwargs["filter"] = QdrantFilter.model_validate(filters)
 
         return values
 
@@ -82,7 +89,7 @@ class ContextualCompressionRetrieverConfig(VectorStoreRetrieverConfig):
         title="Top K docs to collect post compression",
     )
 
-    allowed_compressor_model_providers: ClassVar[Collection[str]]
+    allowed_compressor_model_providers: ClassVar[Sequence[str]]
 
 
 class ContextualCompressionMultiQueryRetrieverConfig(
@@ -91,49 +98,58 @@ class ContextualCompressionMultiQueryRetrieverConfig(
     pass
 
 
-class ExampleQueryInput(BaseModel):
+class ExampleQueryInput(ConfiguredBaseModel):
     """
     Model for Query input.
     Requires a collection name, retriever configuration, query, LLM configuration and prompt template.
     """
 
     collection_name: str = Field(
         default=None,
         title="Collection name on which to search",
     )
 
     query: str = Field(title="Question to search for")
 
+    # TODO (chiragjn): pydantic v2 does not like fields that start with model_
     model_configuration: ModelConfig
 
     prompt_template: str = Field(
         title="Prompt Template to use for generating answer to the question using the context",
     )
 
+    # TODO (chiragjn): Move retriever name inside configuration to let pydantic discriminate between different retrievers
     retriever_name: str = Field(
         title="Retriever name",
     )
 
-    retriever_config: Dict[str, Any] = Field(
+    retriever_config: Union[
+        VectorStoreRetrieverConfig,
+        MultiQueryRetrieverConfig,
+        ContextualCompressionRetrieverConfig,
+        ContextualCompressionMultiQueryRetrieverConfig,
+    ] = Field(
        title="Retriever configuration",
    )
 
-    allowed_retriever_types: ClassVar[Collection[str]] = (
+    allowed_retriever_types: ClassVar[Sequence[str]] = (
         "vectorstore",
         "multi-query",
         "contextual-compression",
         "contextual-compression-multi-query",
     )
 
-    stream: Optional[bool] = Field(title="Stream the results", default=False)
+    stream: bool = Field(title="Stream the results", default=False)
 
-    @root_validator()
-    def validate_retriever_type(cls, values: Dict) -> Dict:
-        retriever_name = values.get("retriever_name")
+    @model_validator(mode="before")
+    @classmethod
+    def validate_retriever_type(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        if not isinstance(values, dict):
+            raise ValueError(
+                f"Unexpected Pydantic v2 Validation: values are of type {type(values)}"
+            )
 
-        assert (
-            retriever_name in cls.allowed_retriever_types
-        ), f"retriever of {retriever_name} not allowed. Valid values are: {cls.allowed_retriever_types}"
+        retriever_name = values.get("retriever_name")
 
         if retriever_name == "vectorstore":
             values["retriever_config"] = VectorStoreRetrieverConfig(
@@ -149,20 +165,39 @@ def validate_retriever_type(cls, values: Dict) -> Dict:
             values["retriever_config"] = ContextualCompressionRetrieverConfig(
                 **values.get("retriever_config")
             )
-
         elif retriever_name == "contextual-compression-multi-query":
             values["retriever_config"] = ContextualCompressionMultiQueryRetrieverConfig(
                 **values.get("retriever_config")
             )
+        else:
+            raise ValueError(
+                f"Unexpected retriever name: {retriever_name}. "
+                f"Valid values are: {cls.allowed_retriever_types}"
+            )
 
         return values
 
 
-class Answer(BaseModel):
+class Document(ConfiguredBaseModel):
+    page_content: str
+    metadata: dict = Field(default_factory=dict)
+
+
+class Answer(ConfiguredBaseModel):
     type: str = "answer"
     content: str
 
 
-class Docs(BaseModel):
+class Docs(ConfiguredBaseModel):
     type: str = "docs"
     content: List[Document] = Field(default_factory=list)
diff --git a/backend/modules/query_controllers/query_controller.py b/backend/modules/query_controllers/query_controller.py
index fd21aa7f..e81d8a1e 100644
--- a/backend/modules/query_controllers/query_controller.py
+++ b/backend/modules/query_controllers/query_controller.py
@@ -18,7 +18,7 @@ def list_query_controllers():
     Returns a list of all the registered query controllers.
 
     Returns:
-        List[dict]: A list of all the registered query controllers.
+        List[Dict]: A list of all the registered query controllers.
     """
     global QUERY_CONTROLLER_REGISTRY
     return [
diff --git a/backend/modules/vector_db/qdrant.py b/backend/modules/vector_db/qdrant.py
index 5ac47d04..341493e3 100644
--- a/backend/modules/vector_db/qdrant.py
+++ b/backend/modules/vector_db/qdrant.py
@@ -17,7 +17,7 @@ class QdrantVectorDB(BaseVectorDB):
     def __init__(self, config: VectorDBConfig):
-        logger.debug(f"Connecting to qdrant using config: {config.dict()}")
+        logger.debug(f"Connecting to qdrant using config: {config.model_dump()}")
         if config.local is True:
             # TODO: make this path customizable
             self.qdrant_client = QdrantClient(
@@ -28,7 +28,7 @@ def __init__(self, config: VectorDBConfig):
             api_key = config.api_key
             if not api_key:
                 api_key = None
-            qdrant_kwargs = QdrantClientConfig.parse_obj(config.config or {})
+            qdrant_kwargs = QdrantClientConfig.model_validate(config.config or {})
             if url.startswith("http://") or url.startswith("https://"):
                 if qdrant_kwargs.port is None:
                     parsed_port = urlparse(url).port
@@ -37,7 +37,7 @@ def __init__(self, config: VectorDBConfig):
             else:
                 qdrant_kwargs.port = 443 if url.startswith("https://") else 6333
             self.qdrant_client = QdrantClient(
-                url=url, api_key=api_key, **qdrant_kwargs.dict()
+                url=url, api_key=api_key, **qdrant_kwargs.model_dump()
             )
 
     def create_collection(self, collection_name: str, embeddings: Embeddings):
diff --git a/backend/modules/vector_db/singlestore.py b/backend/modules/vector_db/singlestore.py
index fc1fba87..2b95c3b4 100644
--- a/backend/modules/vector_db/singlestore.py
+++ b/backend/modules/vector_db/singlestore.py
@@ -1,5 +1,5 @@
 import json
-from typing import Any, Iterable, List, Optional
+from typing import Any, Dict, Iterable, List, Optional
 
 import singlestoredb as s2
 from langchain.docstore.document import Document
@@ -63,7 +63,7 @@ def _create_table(self: SingleStoreDB) -> None:
     def add_texts(
         self,
         texts: Iterable[str],
-        metadatas: Optional[List[dict]] = None,
+        metadatas: Optional[List[Dict[Any, Any]]] = None,
         embeddings: Optional[List[List[float]]] = None,
         **kwargs: Any,
     ) -> List[str]:
@@ -71,7 +71,7 @@ def add_texts(
 
         Args:
             texts (Iterable[str]): Iterable of strings/text to add to the vectorstore.
-            metadatas (Optional[List[dict]], optional): Optional list of metadatas.
+            metadatas (Optional[List[Dict]], optional): Optional list of metadatas.
                 Defaults to None.
             embeddings (Optional[List[List[float]]], optional): Optional pre-generated
                 embeddings. Defaults to None.
diff --git a/backend/modules/vector_db/weaviate.py b/backend/modules/vector_db/weaviate.py index 5b3986de..2aad43e5 100644 --- a/backend/modules/vector_db/weaviate.py +++ b/backend/modules/vector_db/weaviate.py @@ -1,4 +1,4 @@ -from typing import List +from typing import Any, Dict, List import weaviate from langchain.embeddings.base import Embeddings @@ -99,7 +99,7 @@ def list_documents_in_collection( .with_fields("groupedBy { value }") .do() ) - groups: List[dict] = ( + groups: List[Dict[Any, Any]] = ( response.get("data", {}) .get("Aggregate", {}) .get(collection_name.capitalize(), []) diff --git a/backend/requirements.txt b/backend/requirements.txt index af95cf07..1317f818 100644 --- a/backend/requirements.txt +++ b/backend/requirements.txt @@ -5,11 +5,12 @@ langchain-openai==0.1.7 langchain-core==0.1.46 openai==1.35.3 tiktoken==0.7.0 -uvicorn==0.23.2 -fastapi==0.109.1 +uvicorn[standard]==0.23.2 +fastapi==0.111.1 qdrant-client==1.9.0 python-dotenv==1.0.1 -pydantic==1.10.17 +pydantic==2.7.4 +pydantic-settings==2.3.3 orjson==3.9.15 PyMuPDF==1.23.6 redis==5.0.1 diff --git a/backend/server/decorators.py b/backend/server/decorators.py index 2b17ff98..8f12dd81 100644 --- a/backend/server/decorators.py +++ b/backend/server/decorators.py @@ -4,10 +4,9 @@ """ import inspect -from typing import Any, Callable, List, Type, TypeVar, Union, get_type_hints +from typing import Any, Callable, ClassVar, List, Type, TypeVar, Union, get_type_hints from fastapi import APIRouter, Depends -from pydantic.typing import is_classvar from starlette.routing import Route, WebSocketRoute T = TypeVar("T") @@ -59,7 +58,8 @@ def _init_cbv(cls: Type[Any]) -> None: ] dependency_names: List[str] = [] for name, hint in get_type_hints(cls).items(): - if is_classvar(hint): + # TODO (chiragjn): Verify this + if getattr(hint, "__origin__", None) is ClassVar: continue parameter_kwargs = {"default": getattr(cls, name, Ellipsis)} dependency_names.append(name) @@ -127,13 +127,13 @@ def wrapper(cls) -> ClassBasedView: for name, method in cls.__dict__.items(): if callable(method) and hasattr(method, "method"): # Check if method is decorated with an HTTP method decorator - assert ( - hasattr(method, "__path__") and method.__path__ - ), f"Missing path for method {name}" + if not hasattr(method, "__path__") or not method.__path__: + raise ValueError(f"Missing path for method {name}") http_method = method.method # Ensure that the method is a valid HTTP method - assert http_method in http_method_names, f"Invalid method {http_method}" + if http_method not in http_method_names: + raise ValueError(f"Invalid method {http_method}") if prefix: method.__path__ = prefix + method.__path__ if not method.__path__.startswith("/"): diff --git a/backend/server/routers/collection.py b/backend/server/routers/collection.py index f97ffe38..a0155e16 100644 --- a/backend/server/routers/collection.py +++ b/backend/server/routers/collection.py @@ -29,7 +29,7 @@ async def get_collections(): if collections is None: return JSONResponse(content={"collections": []}) return JSONResponse( - content={"collections": [obj.dict() for obj in collections]} + content={"collections": [obj.model_dump() for obj in collections]} ) except Exception as exp: logger.exception("Failed to get collection") @@ -55,7 +55,7 @@ async def get_collection_by_name(collection_name: str = Path(title="Collection n collection = await client.aget_collection_by_name(collection_name) if collection is None: return JSONResponse(content={"collection": []}) - return JSONResponse(content={"collection": 
collection.dict()}) + return JSONResponse(content={"collection": collection.model_dump()}) except HTTPException as exp: raise exp except Exception as exp: @@ -98,7 +98,7 @@ async def create_collection(collection: CreateCollectionDto): collection_name=created_collection.name ) return JSONResponse( - content={"collection": created_collection.dict()}, status_code=201 + content={"collection": created_collection.model_dump()}, status_code=201 ) except HTTPException as exp: raise exp @@ -121,7 +121,7 @@ async def associate_data_source_to_collection( parser_config=request.parser_config, ), ) - return JSONResponse(content={"collection": collection.dict()}) + return JSONResponse(content={"collection": collection.model_dump()}) except HTTPException as exp: raise exp except Exception as exp: @@ -140,7 +140,7 @@ async def unassociate_data_source_from_collection( collection_name=request.collection_name, data_source_fqn=request.data_source_fqn, ) - return JSONResponse(content={"collection": collection.dict()}) + return JSONResponse(content={"collection": collection.model_dump()}) except HTTPException as exp: raise exp except Exception as exp: @@ -182,7 +182,9 @@ async def list_data_ingestion_runs(request: ListDataIngestionRunsDto): request.collection_name, request.data_source_fqn ) return JSONResponse( - content={"data_ingestion_runs": [obj.dict() for obj in data_ingestion_runs]} + content={ + "data_ingestion_runs": [obj.model_dump() for obj in data_ingestion_runs] + } ) diff --git a/backend/server/routers/data_source.py b/backend/server/routers/data_source.py index ff40d686..ef821b5d 100644 --- a/backend/server/routers/data_source.py +++ b/backend/server/routers/data_source.py @@ -17,7 +17,7 @@ async def get_data_source(): client = await get_client() data_sources = await client.aget_data_sources() return JSONResponse( - content={"data_sources": [obj.dict() for obj in data_sources]} + content={"data_sources": [obj.model_dump() for obj in data_sources]} ) except Exception as exp: logger.exception("Failed to get data source") @@ -45,7 +45,7 @@ async def add_data_source( client = await get_client() created_data_source = await client.acreate_data_source(data_source=data_source) return JSONResponse( - content={"data_source": created_data_source.dict()}, status_code=201 + content={"data_source": created_data_source.model_dump()}, status_code=201 ) except HTTPException as exp: raise exp diff --git a/backend/server/routers/internal.py b/backend/server/routers/internal.py index 9d02052c..248f3ec1 100644 --- a/backend/server/routers/internal.py +++ b/backend/server/routers/internal.py @@ -31,7 +31,7 @@ async def upload_to_docker_directory( status_code=500, ) try: - logger.info(f"Uploading files to docker directory: {upload_name}") + logger.info(f"Uploading files to directory: {upload_name}") # create a folder within `/volumes/user_data/` that maps to `/app/user_data/` in the docker volume # this folder will be used to store the uploaded files folder_path = os.path.join(settings.LOCAL_DATA_DIRECTORY, upload_name) @@ -60,8 +60,9 @@ async def upload_to_docker_directory( # Add the data source to the metadata store. 
return await add_data_source(data_source) except Exception as ex: + logger.exception(f"Error uploading files to directory: {ex}") return JSONResponse( - content={"error": f"Error uploading files to docker directory: {ex}"}, + content={"error": f"Error uploading files to directory: {ex}"}, status_code=500, ) @@ -88,12 +89,12 @@ async def upload_to_data_directory(req: UploadToDataDirectoryDto): paths=req.filepaths, ) - data = [url.dict() for url in urls] + data = [url.model_dump() for url in urls] return JSONResponse( content={"data": data, "data_directory_fqn": dataset.fqn}, ) except Exception as ex: - raise Exception(f"Error uploading files to data directory: {ex}") + raise Exception(f"Error uploading files to data directory: {ex}") from ex @router.get("/models") @@ -113,7 +114,7 @@ def get_enabled_models( ) # Serialized models - serialized_models = [model.dict() for model in enabled_models] + serialized_models = [model.model_dump() for model in enabled_models] return JSONResponse( content={"models": serialized_models}, ) diff --git a/backend/server/routers/rag_apps.py b/backend/server/routers/rag_apps.py index aadf3178..48978b82 100644 --- a/backend/server/routers/rag_apps.py +++ b/backend/server/routers/rag_apps.py @@ -3,7 +3,7 @@ from backend.logger import logger from backend.modules.metadata_store.client import get_client -from backend.types import CreateRagApplication, RagApplicationDto +from backend.types import CreateRagApplication router = APIRouter(prefix="/v1/apps", tags=["apps"]) @@ -18,7 +18,7 @@ async def register_rag_app( client = await get_client() created_rag_app = await client.acreate_rag_app(rag_app) return JSONResponse( - content={"rag_app": created_rag_app.dict()}, status_code=201 + content={"rag_app": created_rag_app.model_dump()}, status_code=201 ) except HTTPException as exp: raise exp @@ -47,7 +47,7 @@ async def get_rag_app_by_name(app_name: str = Path(title="App name")): rag_app = await client.aget_rag_app(app_name) if rag_app is None: return JSONResponse(content={"rag_app": []}) - return JSONResponse(content={"rag_app": rag_app.dict()}) + return JSONResponse(content={"rag_app": rag_app.model_dump()}) except HTTPException as exp: raise exp diff --git a/backend/settings.py b/backend/settings.py index 7945b2c6..fbbd1fb9 100644 --- a/backend/settings.py +++ b/backend/settings.py @@ -1,7 +1,8 @@ import os -from typing import Optional +from typing import Any, Dict -from pydantic import BaseSettings, root_validator +from pydantic import ConfigDict, model_validator +from pydantic_settings import BaseSettings from backend.types import MetadataStoreConfig, VectorDBConfig @@ -11,8 +12,7 @@ class Settings(BaseSettings): Settings class to hold all the environment variables """ - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") MODELS_CONFIG_PATH: str METADATA_STORE_CONFIG: MetadataStoreConfig @@ -30,8 +30,15 @@ class Config: os.path.join(os.path.dirname(os.path.dirname(__file__)), "user_data") ) - @root_validator(pre=True) - def _validate_values(cls, values): + @model_validator(mode="before") + @classmethod + def _validate_values(cls, values: Dict[str, Any]) -> Dict[str, Any]: + """Resolve and validate MODELS_CONFIG_PATH.""" + if not isinstance(values, dict): + raise ValueError( + f"Unexpected Pydantic v2 Validation: values are of type {type(values)}" + ) + models_config_path = values.get("MODELS_CONFIG_PATH") if not os.path.isabs(models_config_path): this_dir = os.path.abspath(os.path.dirname(__file__)) @@ -39,7 +46,7 @@ def _validate_values(cls, values): 
models_config_path = os.path.join(root_dir, models_config_path) if not models_config_path: - raise Exception( + raise ValueError( f"{models_config_path} does not exist. " f"You can copy models_config.sample.yaml to {settings.MODELS_CONFIG_PATH} to bootstrap config" ) diff --git a/backend/types.py b/backend/types.py index 67a167c8..02bcab3a 100644 --- a/backend/types.py +++ b/backend/types.py @@ -1,13 +1,26 @@ import enum -import json import uuid from enum import Enum from typing import Any, Dict, List, Literal, Optional, Union -from pydantic import BaseModel, Field, constr, root_validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + StringConstraints, + computed_field, + model_validator, +) +from typing_extensions import Annotated from backend.constants import FQN_SEPARATOR +# TODO (chiragjn): Remove Optional from Dict and List type fields. Instead just use a default_factory + + +class ConfiguredBaseModel(BaseModel): + model_config = ConfigDict(use_enum_values=True) + class DataIngestionMode(str, Enum): """ @@ -19,7 +32,7 @@ class DataIngestionMode(str, Enum): FULL = "FULL" -class DataPoint(BaseModel): +class DataPoint(ConfiguredBaseModel): """ Data point describes a single data point in the data source Properties: @@ -42,7 +55,8 @@ class DataPoint(BaseModel): title="Hash of the data point for the given data source that is guaranteed to be updated for any update in data point at source", ) - metadata: Optional[Dict[str, str]] = Field( + metadata: Optional[Dict[str, Any]] = Field( + None, title="Additional metadata for the data point", ) @@ -51,7 +65,7 @@ def data_point_fqn(self) -> str: return f"{FQN_SEPARATOR}".join([self.data_source_fqn, self.data_point_uri]) -class DataPointVector(BaseModel): +class DataPointVector(ConfiguredBaseModel): """ Data point vector describes a single data point in the vector store Additional Properties: @@ -84,9 +98,11 @@ class LoadedDataPoint(DataPoint): title="Local file path of the loaded data point", ) file_extension: Optional[str] = Field( + None, title="File extension of the loaded data point", ) local_metadata_file_path: Optional[str] = Field( + None, title="Local file path of the metadata file", ) @@ -99,51 +115,60 @@ class ModelType(str, Enum): chat = "chat" embedding = "embedding" reranking = "reranking" - parser = "parser" -class ModelConfig(BaseModel): +class ModelConfig(ConfiguredBaseModel): name: str - type: Optional[ModelType] + # TODO (chiragjn): This should not be Optional! 
Changing might break backward compatibility + # Problem is we have shared these entities between DTO layers and Service / DB layers + type: Optional[ModelType] = None parameters: Optional[Dict[str, Any]] = Field(default_factory=dict) - def to_dict(self) -> Dict[str, Any]: - return { - "name": self.name, - "type": self.type, - "parameters": self.parameters, - } - -class ModelProviderConfig(BaseModel): +class ModelProviderConfig(ConfiguredBaseModel): provider_name: str api_format: str + base_url: Optional[str] = None + api_key_env_var: str + default_headers: Dict[str, str] = Field(default_factory=dict) llm_model_ids: List[str] = Field(default_factory=list) embedding_model_ids: List[str] = Field(default_factory=list) reranking_model_ids: List[str] = Field(default_factory=list) - api_key_env_var: str - base_url: Optional[str] = None - default_headers: Dict[str, str] = Field(default_factory=dict) -class EmbedderConfig(ModelConfig): +class EmbedderConfig(ConfiguredBaseModel): """ Embedder configuration """ - pass + name: str + parameters: Dict[str, Any] = Field(default_factory=dict) + + @model_validator(mode="before") + @classmethod + def ensure_parameters_not_none(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if values.get("parameters") is None: + values.pop("parameters", None) + return values -class ParserConfig(ModelConfig): +class ParserConfig(ConfiguredBaseModel): """ Parser configuration """ - type: ModelType = ModelType.parser - pass + name: str + parameters: Dict[str, Any] = Field(default_factory=dict) + + @model_validator(mode="before") + @classmethod + def ensure_parameters_not_none(cls, values: Dict[str, Any]) -> Dict[str, Any]: + if values.get("parameters") is None: + values.pop("parameters", None) + return values -class VectorDBConfig(BaseModel): +class VectorDBConfig(ConfiguredBaseModel): """ Vector db configuration """ @@ -152,16 +177,15 @@ class VectorDBConfig(BaseModel): local: bool = False url: Optional[str] = None api_key: Optional[str] = None - config: Optional[dict] = None + config: Optional[Dict[str, Any]] = Field(default_factory=dict) -class QdrantClientConfig(BaseModel): +class QdrantClientConfig(ConfiguredBaseModel): """ Qdrant extra configuration """ - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") port: Optional[int] = None grpc_port: int = 6334 @@ -170,23 +194,24 @@ class Config: timeout: int = 300 -class MetadataStoreConfig(BaseModel): +class MetadataStoreConfig(ConfiguredBaseModel): """ Metadata store configuration """ provider: str - config: Optional[dict] = Field(default_factory=dict) + config: Optional[Dict[str, Any]] = Field(default_factory=dict) -class RetrieverConfig(BaseModel): +class RetrieverConfig(ConfiguredBaseModel): """ Retriever configuration """ search_type: Literal["mmr", "similarity"] = Field( default="similarity", - title="""Defines the type of search that the Retriever should perform. Can be "similarity" (default), "mmr", or "similarity_score_threshold".""", + title="""Defines the type of search that the Retriever should perform. 
\ + Can be "similarity" (default), "mmr", or "similarity_score_threshold".""", ) k: int = Field( default=4, @@ -196,7 +221,7 @@ class RetrieverConfig(BaseModel): default=20, title="""Amount of documents to pass to MMR algorithm (Default: 20)""", ) - filter: Optional[dict] = Field( + filter: Optional[Dict[Any, Any]] = Field( default=None, title="""Filter by document metadata""", ) @@ -209,11 +234,12 @@ def get_search_type(self) -> str: @property def get_search_kwargs(self) -> dict: # Check at langchain.schema.vectorstore.VectorStore.as_retriever - match self.search_type: - case "similarity": - return {"k": self.k, "filter": self.filter} - case "mmr": - return {"k": self.k, "fetch_k": self.fetch_k, "filter": self.filter} + if self.search_type == "similarity": + return {"k": self.k, "filter": self.filter} + elif self.search_type == "mmr": + return {"k": self.k, "fetch_k": self.fetch_k, "filter": self.filter} + else: + raise ValueError(f"Search type {self.search_type} is not supported") class DataIngestionRunStatus(str, enum.Enum): @@ -233,7 +259,7 @@ class DataIngestionRunStatus(str, enum.Enum): ERROR = "ERROR" -class BaseDataIngestionRun(BaseModel): +class BaseDataIngestionRun(ConfiguredBaseModel): """ Base data ingestion run configuration """ @@ -255,7 +281,7 @@ class BaseDataIngestionRun(BaseModel): title="Data ingestion mode for the data ingestion", ) - raise_error_on_failure: Optional[bool] = Field( + raise_error_on_failure: bool = Field( title="Flag to configure weather to raise error on failure or not. Default is True", default=True, ) @@ -270,11 +296,12 @@ class DataIngestionRun(BaseDataIngestionRun): title="Name of the data ingestion run", ) status: Optional[DataIngestionRunStatus] = Field( + None, title="Status of the data ingestion run", ) -class BaseDataSource(BaseModel): +class BaseDataSource(ConfiguredBaseModel): """ Data source configuration """ @@ -286,18 +313,14 @@ class BaseDataSource(BaseModel): title="A unique identifier for the data source", ) metadata: Optional[Dict[str, Any]] = Field( - title="Additional config for your data source" + None, title="Additional config for your data source" ) + @computed_field @property - def fqn(self): + def fqn(self) -> str: return f"{FQN_SEPARATOR}".join([self.type, self.uri]) - @root_validator - def validate_fqn(cls, values: Dict) -> Dict: - values["fqn"] = f"{FQN_SEPARATOR}".join([values["type"], values["uri"]]) - return values - class CreateDataSource(BaseDataSource): pass @@ -307,7 +330,7 @@ class DataSource(BaseDataSource): pass -class AssociatedDataSources(BaseModel): +class AssociatedDataSources(ConfiguredBaseModel): """ Associated data source configuration """ @@ -319,11 +342,11 @@ class AssociatedDataSources(BaseModel): title="Parser configuration for the data transformation", default_factory=dict ) data_source: Optional[DataSource] = Field( - title="Data source associated with the collection" + None, title="Data source associated with the collection" ) -class IngestDataToCollectionDto(BaseModel): +class IngestDataToCollectionDto(ConfiguredBaseModel): """ Configuration to ingest data to collection """ @@ -333,6 +356,7 @@ class IngestDataToCollectionDto(BaseModel): ) data_source_fqn: Optional[str] = Field( + None, title="Fully qualified name of the data source", ) @@ -341,7 +365,7 @@ class IngestDataToCollectionDto(BaseModel): title="Data ingestion mode for the data ingestion", ) - raise_error_on_failure: Optional[bool] = Field( + raise_error_on_failure: bool = Field( title="Flag to configure weather to raise error on failure or 
not. Default is True", default=True, ) @@ -357,7 +381,7 @@ class IngestDataToCollectionDto(BaseModel): ) -class AssociateDataSourceWithCollection(BaseModel): +class AssociateDataSourceWithCollection(ConfiguredBaseModel): """ Configuration to associate data source to collection """ @@ -396,7 +420,7 @@ class AssociateDataSourceWithCollectionDto(AssociateDataSourceWithCollection): ) -class UnassociateDataSourceWithCollectionDto(BaseModel): +class UnassociateDataSourceWithCollectionDto(ConfiguredBaseModel): """ Configuration to unassociate data source to collection """ @@ -409,17 +433,18 @@ class UnassociateDataSourceWithCollectionDto(BaseModel): ) -class BaseCollection(BaseModel): +class BaseCollection(ConfiguredBaseModel): """ Base collection configuration """ - name: constr(regex=r"^[a-z][a-z0-9-]*$") = Field( # type: ignore + name: Annotated[str, StringConstraints(pattern=r"^[a-z][a-z0-9-]*$")] = Field( # type: ignore title="a unique name to your collection", description="Should only contain lowercase alphanumeric character and hypen, should start with alphabet", example="test-collection", ) description: Optional[str] = Field( + None, title="a description for your collection", example="This is a test collection", ) @@ -442,24 +467,34 @@ class Collection(BaseCollection): title="Data sources associated with the collection", default_factory=dict ) + @model_validator(mode="before") + @classmethod + def ensure_associated_data_sources_not_none( + cls, values: Dict[str, Any] + ) -> Dict[str, Any]: + if values.get("associated_data_sources") is None: + values["associated_data_sources"] = {} + return values + class CreateCollectionDto(CreateCollection): associated_data_sources: Optional[List[AssociateDataSourceWithCollection]] = Field( - title="Data sources associated with the collection" + None, title="Data sources associated with the collection" ) -class UploadToDataDirectoryDto(BaseModel): +class UploadToDataDirectoryDto(ConfiguredBaseModel): filepaths: List[str] # allow only small case alphanumeric and hyphen, should contain at least one alphabet and begin with alphabet - upload_name: str = Field( + upload_name: Annotated[ + str, StringConstraints(pattern=r"^[a-z][a-z0-9-]*$") + ] = Field( # type:ignore title="Name of the upload", - regex=r"^[a-z][a-z0-9-]*$", default=str(uuid.uuid4()), ) -class ListDataIngestionRunsDto(BaseModel): +class ListDataIngestionRunsDto(ConfiguredBaseModel): collection_name: str = Field( title="Name of the collection", ) @@ -468,10 +503,12 @@ class ListDataIngestionRunsDto(BaseModel): ) -class RagApplication(BaseModel): - name: str = Field( +class RagApplication(ConfiguredBaseModel): + # allow only small case alphanumeric and hyphen, should contain at least one alphabet and begin with alphabet + name: Annotated[ + str, StringConstraints(pattern=r"^[a-z][a-z0-9-]*$") + ] = Field( # type:ignore title="Name of the rag app", - regex=r"^[a-z][a-z0-9-]*$", # allow only small case alphanumeric and hyphen, should contain at least one alphabet and begin with alphabet ) config: Dict[str, Any] = Field( title="Configuration for the rag app", diff --git a/docker-compose.yaml b/docker-compose.yaml index d6951ffc..00c691ac 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -71,7 +71,7 @@ services: environment: - INFINITY_MODEL_ID=${INFINITY_EMBEDDING_MODEL};${INFINITY_RERANKING_MODEL} - INFINITY_BATCH_SIZE=8 - - API_KEY=${INFINITY_API_KEY} + - INFINITY_API_KEY=${INFINITY_API_KEY} command: v2 networks: - cognita-docker diff --git 
a/frontend/src/screens/dashboard/docsqa/NewCollection.tsx b/frontend/src/screens/dashboard/docsqa/NewCollection.tsx index db86873d..8a0c7103 100644 --- a/frontend/src/screens/dashboard/docsqa/NewCollection.tsx +++ b/frontend/src/screens/dashboard/docsqa/NewCollection.tsx @@ -73,7 +73,6 @@ const NewCollection = ({ open, onClose, onSuccess }: NewCollectionProps) => { name: collectionName, embedder_config: { name: embeddingModel.name, - type: "embedding", }, associated_data_sources: [ { diff --git a/frontend/src/screens/dashboard/docsqa/settings/index.tsx b/frontend/src/screens/dashboard/docsqa/settings/index.tsx index 6711527c..7d4e2e12 100644 --- a/frontend/src/screens/dashboard/docsqa/settings/index.tsx +++ b/frontend/src/screens/dashboard/docsqa/settings/index.tsx @@ -150,7 +150,7 @@ const DocsQASettings = () => { {collectionDetails && (
Embedder Used :{' '} - {collectionDetails?.embedder_config?.config?.model} + {collectionDetails?.embedder_config?.name}
)} diff --git a/frontend/src/stores/qafoundry/index.ts b/frontend/src/stores/qafoundry/index.ts index 10ce3d16..d7127b84 100644 --- a/frontend/src/stores/qafoundry/index.ts +++ b/frontend/src/stores/qafoundry/index.ts @@ -3,8 +3,15 @@ import { createApi } from '@reduxjs/toolkit/query/react' // import * as T from './types' import { createBaseQuery } from '../utils' +export enum ModelType { + chat = 'chat', + embedding = 'embedding', + reranking = 'reranking', +} + export interface ModelConfig { name: string + type?: ModelType parameters: { temperature?: number maximum_length?: number @@ -40,23 +47,25 @@ interface DataSource { fqn: string } +interface ParserConfig { + name: string + parameters?: { + [key: string]: any + } +} + export interface AssociatedDataSource { data_source_fqn: string parser_config: { - chunk_size: number - chunk_overlap: number - parser_map: { - [key: string]: string - } + [key: string]: ParserConfig } data_source: DataSource } interface EmbedderConfig { - description?: string - provider?: string - config?: { - model: string + name: string + parameters?: { + [key: string]: any } } @@ -67,7 +76,6 @@ export interface Collection { associated_data_sources: { [key: string]: AssociatedDataSource } - chunk_size?: number } interface AddDataSourcePayload {
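The backend/types.py and backend/settings.py hunks earlier in this diff repeat three more v2 idioms: constrained strings move from constr(regex=...) to Annotated[str, StringConstraints(pattern=...)], derived values move from a root_validator that wrote into values to a @computed_field property, and settings classes now come from pydantic-settings with model_config replacing the inner Config class. A small sketch of those idioms under stated assumptions; DataSourceSketch, SettingsSketch, and the "::" separator are hypothetical stand-ins for the project's real types and FQN_SEPARATOR:

import os
from typing import Any, Dict, Optional

from pydantic import BaseModel, ConfigDict, StringConstraints, computed_field, model_validator
from pydantic_settings import BaseSettings
from typing_extensions import Annotated


class DataSourceSketch(BaseModel):
    model_config = ConfigDict(use_enum_values=True)

    # v1: name: constr(regex=r"^[a-z][a-z0-9-]*$")
    name: Annotated[str, StringConstraints(pattern=r"^[a-z][a-z0-9-]*$")]
    uri: str

    # v1 computed this inside a root_validator by mutating `values`; v2 exposes it
    # as a computed field so it still appears in model_dump() / API responses.
    @computed_field
    @property
    def fqn(self) -> str:
        return "::".join([self.name, self.uri])


class SettingsSketch(BaseSettings):
    # v1: `class Config: extra = "allow"` becomes model_config in v2.
    model_config = ConfigDict(extra="allow")

    MODELS_CONFIG_PATH: str
    LOCAL_DATA_DIRECTORY: Optional[str] = None

    @model_validator(mode="before")
    @classmethod
    def _resolve_paths(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        # Replaces the v1 @root_validator(pre=True): normalize raw env/init values.
        path = values.get("MODELS_CONFIG_PATH", "")
        if path and not os.path.isabs(path):
            values["MODELS_CONFIG_PATH"] = os.path.abspath(path)
        return values


print(DataSourceSketch(name="local-dir", uri="/tmp/user_data").model_dump())

Because computed fields are included in model_dump(), the fqn value keeps showing up in serialized payloads, which is what the removed validate_fqn root_validator achieved in v1.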