diff --git a/pyodmongo/models/db_field_info.py b/pyodmongo/models/db_field_info.py index 95440bd..8754b02 100644 --- a/pyodmongo/models/db_field_info.py +++ b/pyodmongo/models/db_field_info.py @@ -1,8 +1,9 @@ from typing import Any -from pydantic import BaseModel, ConfigDict +from dataclasses import dataclass -class DbField(BaseModel): +@dataclass +class DbField: field_name: str = None field_alias: str = None path_str: str = None @@ -10,4 +11,3 @@ class DbField(BaseModel): by_reference: bool = None is_list: bool = None has_model_fields: bool = None - model_config = ConfigDict(extra="allow") diff --git a/pyodmongo/models/db_model.py b/pyodmongo/models/db_model.py index edddde3..6561cf7 100644 --- a/pyodmongo/models/db_model.py +++ b/pyodmongo/models/db_model.py @@ -6,20 +6,28 @@ resolve_indexes, resolve_class_fields_db_info, resolve_reference_pipeline, + resolve_db_fields, ) from pydantic import BaseModel from pydantic._internal._model_construction import ModelMetaclass from typing_extensions import dataclass_transform from typing import ClassVar +import copy -@dataclass_transform() +@dataclass_transform(kw_only_default=True) class DbMeta(ModelMetaclass): - def __new__(cls, name: str, bases: tuple, namespace: dict, **kwargs: Any) -> type: + def __new__( + cls, name: str, bases: tuple[Any], namespace: dict, **kwargs: Any + ) -> type: setattr(cls, "__pyodmongo_complete__", False) for base in bases: setattr(base, "__pyodmongo_complete__", False) + # TODO finish db_fields after ModelMetaclass + db_fields = copy.deepcopy(namespace.get("__annotations__")) + db_fields = resolve_db_fields(bases=bases, db_fields=db_fields) + cls: BaseModel = ModelMetaclass.__new__(cls, name, bases, namespace, **kwargs) setattr(cls, "__pyodmongo_complete__", True) diff --git a/pyodmongo/services/model_init.py b/pyodmongo/services/model_init.py index ce45531..abccba2 100644 --- a/pyodmongo/services/model_init.py +++ b/pyodmongo/services/model_init.py @@ -188,3 +188,21 @@ def resolve_class_fields_db_info(cls: BaseModel): db_field_info=db_field_info, path=[path] ) setattr(cls, field + "__pyodmongo", field_to_set) + + +# TODO finish resolve_single_db_field +def resolve_single_db_field(value: Any) -> DbField: + by_reference = False + return value + + +def resolve_db_fields(bases: tuple[Any], db_fields: dict): + for base in bases: + if base is object: + continue + base_annotations = base.__dict__.get("__annotations__") + for base_field, value in base_annotations.items(): + if base_field not in db_fields.keys() and not base_field.startswith("_"): + db_fields[base_field] = resolve_single_db_field(value=value) + resolve_db_fields(bases=base.__bases__, db_fields=db_fields) + return db_fields