Skip to content

Commit

Permalink
Port pydantic v1 models to pydantic v2 (#224)
Browse files Browse the repository at this point in the history
* Port pydantic v1 models to pydantic v2

* fad

* Port pydantic v1 models to pydantic v2

* Port pydantic v1 models to pydantic v2

* update default values, type annotations, validators

* Fixes on top of merging main

* More typing refactor for consistency

* Add todo for fields to fix

* Add exception if pydantic v2 sends us obj instead of dict when validating

* Fix regex type constraint and rename model_config to embedding_model_config

* Refactor primsa store to work with pydantic v2

* resolved conflicts

* Fixed pydantic

* Removed model_serializer, replaced with model_dump, added enum_values

* Fix enum values config and separate model type and module type

* Fix frontend types to match up with backend types

---------

Co-authored-by: Chirag Jain <[email protected]>
Co-authored-by: Prathamesh <[email protected]>
  • Loading branch information
3 people authored Aug 11, 2024
1 parent 9243efc commit 79423c5
Show file tree
Hide file tree
Showing 33 changed files with 573 additions and 351 deletions.
2 changes: 1 addition & 1 deletion backend/indexer/indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ async def ingest_data(request: IngestDataToCollectionDto):

# convert to pydantic model if not already -> For prisma models
if not isinstance(collection, Collection):
collection = Collection(**collection.dict())
collection = Collection(**collection.model_dump())

if not collection:
logger.error(
Expand Down
14 changes: 10 additions & 4 deletions backend/indexer/types.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
from typing import Dict, Union
from typing import Dict

from pydantic import BaseModel, Field
from pydantic import Field

from backend.types import DataIngestionMode, DataSource, EmbedderConfig, ParserConfig
from backend.types import (
ConfiguredBaseModel,
DataIngestionMode,
DataSource,
EmbedderConfig,
ParserConfig,
)


class DataIngestionConfig(BaseModel):
class DataIngestionConfig(ConfiguredBaseModel):
"""
Configuration to store Data Ingestion Configuration
"""
Expand Down
2 changes: 1 addition & 1 deletion backend/migration/qdrant_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def migrate_collection(
"associated_data_sources"
).items()
],
).dict()
).model_dump()

logger.debug(
f"Creating '{dest_collection.get('name')}' collection at destination"
Expand Down
7 changes: 3 additions & 4 deletions backend/migration/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Dict

import requests
from qdrant_client._pydantic_compat import to_dict
from qdrant_client.client_base import QdrantBase
from qdrant_client.http import models
from tqdm import tqdm
Expand Down Expand Up @@ -104,11 +103,11 @@ def _recreate_collection(
replication_factor=src_config.params.replication_factor,
write_consistency_factor=src_config.params.write_consistency_factor,
on_disk_payload=src_config.params.on_disk_payload,
hnsw_config=models.HnswConfigDiff(**to_dict(src_config.hnsw_config)),
hnsw_config=models.HnswConfigDiff(**src_config.hnsw_config.model_dump()),
optimizers_config=models.OptimizersConfigDiff(
**to_dict(src_config.optimizer_config)
**src_config.optimizer_config.model_dump()
),
wal_config=models.WalConfigDiff(**to_dict(src_config.wal_config)),
wal_config=models.WalConfigDiff(**src_config.wal_config.model_dump()),
quantization_config=src_config.quantization_config,
timeout=300,
)
Expand Down
2 changes: 1 addition & 1 deletion backend/modules/dataloaders/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def list_dataloaders():
Returns a list of all the registered loaders.
Returns:
List[dict]: A list of all the registered loaders.
List[Dict]: A list of all the registered loaders.
"""
global LOADER_REGISTRY
return [
Expand Down
21 changes: 14 additions & 7 deletions backend/modules/metadata_store/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

from backend.constants import DATA_POINT_FQN_METADATA_KEY, FQN_SEPARATOR
from backend.types import (
Expand All @@ -18,7 +18,7 @@
from backend.utils import run_in_executor


# TODO(chiragjn): Ideal would be we make `async def a*` abstract methods and drop sync ones
# TODO (chiragjn): Ideal would be we make `async def a*` abstract methods and drop sync ones
# Implementations can then opt to call their sync versions using run_in_executor
class BaseMetadataStore(ABC):
def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -300,7 +300,7 @@ async def aupdate_data_ingestion_run_status(
def log_metrics_for_data_ingestion_run(
self,
data_ingestion_run_name: str,
metric_dict: dict[str, int | float],
metric_dict: Dict[str, Union[int, float]],
step: int = 0,
):
"""
Expand All @@ -311,7 +311,7 @@ def log_metrics_for_data_ingestion_run(
async def alog_metrics_for_data_ingestion_run(
self,
data_ingestion_run_name: str,
metric_dict: dict[str, int | float],
metric_dict: Dict[str, Union[int, float]],
step: int = 0,
):
"""
Expand Down Expand Up @@ -346,6 +346,8 @@ async def alog_errors_for_data_ingestion_run(
errors=errors,
)

# TODO (chiragjn): What is the difference between get_collections and this?
# TODO (chiragjn): Return complete entities, why return only str?
def list_collections(
self,
) -> List[str]:
Expand All @@ -354,6 +356,7 @@ def list_collections(
"""
raise NotImplementedError()

# TODO (chiragjn): Return complete entities, why return only str?
async def alist_collections(
self,
) -> List[str]:
Expand All @@ -365,17 +368,19 @@ async def alist_collections(
self.list_collections,
)

# TODO (chiragjn): Return complete entities, why return dict?
def list_data_sources(
self,
) -> List[str]:
) -> List[Dict[str, str]]:
"""
List all data source names from metadata store
"""
raise NotImplementedError()

# TODO (chiragjn): Return complete entities, why return dict?
async def alist_data_sources(
self,
) -> List[str]:
) -> List[Dict[str, str]]:
"""
List all data source names from metadata store
"""
Expand Down Expand Up @@ -410,7 +415,7 @@ async def acreate_rag_app(self, app: RagApplication) -> RagApplicationDto:
"""
return await run_in_executor(None, self.create_rag_app, app=app)

def get_rag_app(self, app_name: str) -> RagApplicationDto | None:
def get_rag_app(self, app_name: str) -> Optional[RagApplicationDto]:
"""
Get a RAG application from the metadata store by name
"""
Expand All @@ -422,12 +427,14 @@ async def aget_rag_app(self, app_name: str) -> Optional[RagApplicationDto]:
"""
return await run_in_executor(None, self.get_rag_app, app_name=app_name)

# TODO (chiragjn): Return complete entities, why return only str?
def list_rag_apps(self) -> List[str]:
"""
List all RAG application names from metadata store
"""
raise NotImplementedError()

# TODO (chiragjn): Return complete entities, why return only str?
async def alist_rag_apps(self) -> List[str]:
"""
List all RAG application names from metadata store
Expand Down
Loading

0 comments on commit 79423c5

Please sign in to comment.