diff --git a/pyodmongo/models/db_model.py b/pyodmongo/models/db_model.py index 7fef7f3..f018629 100644 --- a/pyodmongo/models/db_model.py +++ b/pyodmongo/models/db_model.py @@ -7,6 +7,7 @@ resolve_indexes, resolve_ref_pipeline, resolve_class_fields_db_info, + resolve_reference_pipeline, ) @@ -45,6 +46,7 @@ def __init__(self, **attrs): def __pydantic_init_subclass__(cls): resolve_class_fields_db_info(cls=cls) ref_pipeline = resolve_ref_pipeline(cls=cls, pipeline=[], path=[]) + pipeline = resolve_reference_pipeline(cls=cls) setattr(cls, "_reference_pipeline", ref_pipeline) indexes = resolve_indexes(cls=cls) setattr(cls, "_init_indexes", indexes) diff --git a/pyodmongo/services/aggregate_stages.py b/pyodmongo/services/aggregate_stages.py index e56d39f..c11412c 100644 --- a/pyodmongo/services/aggregate_stages.py +++ b/pyodmongo/services/aggregate_stages.py @@ -21,3 +21,47 @@ def lookup_and_set( if is_reference_list: return lookup_stage return lookup_stage + set_stage + + +def unwind(path: str, array_index: str, preserve_empty: bool): + return [ + { + "$unwind": { + "path": f"${path}", + "includeArrayIndex": array_index, + "preserveNullAndEmptyArrays": preserve_empty, + } + } + ] + + +def lookup(_from: str, local_field: str, foreign_field: str, _as: str, pipeline: list): + return [ + { + "$lookup": { + "from": _from, + "localField": local_field, + "foreignField": foreign_field, + "as": _as, + "pipeline": pipeline, + } + } + ] + + +def group_set_replace_root(id: str, field: str, path_str: str): + return [ + { + "$group": { + "_id": f"${id}", + "_document": {"$first": "$$ROOT"}, + field: {"$push": f"${path_str}"}, + } + }, + {"$set": {f"_document.{path_str}": f"${field}"}}, + {"$replaceRoot": {"newRoot": "$_document"}}, + ] + + +def unset(fields: list): + return [{"$unset": fields}] diff --git a/pyodmongo/services/model_init.py b/pyodmongo/services/model_init.py index d8f284a..3efc43e 100644 --- a/pyodmongo/services/model_init.py +++ b/pyodmongo/services/model_init.py @@ -4,7 +4,8 @@ from types import UnionType from ..models.id_model import Id from ..models.db_field_info import DbField -from .aggregate_stages import lookup_and_set +from .aggregate_stages import lookup_and_set, unwind +import copy def resolve_indexes(cls: BaseModel): @@ -147,6 +148,47 @@ def resolve_ref_pipeline(cls: BaseModel, pipeline: list, path: list): return pipeline +def _paths_to_ref_ids(cls: BaseModel, paths: list, single_path: list): + for field, field_info in cls.model_fields.items(): + db_field = field_annotation_infos(field=field, field_info=field_info) + single_path.append(db_field) + if db_field.by_reference: + paths.append(copy.deepcopy(single_path)) + elif db_field.has_model_fields: + _paths_to_ref_ids( + cls=db_field.field_type, paths=paths, single_path=single_path + ) + single_path.pop(-1) + return paths + + +def _mount_pipeline(cls: BaseModel, pipeline: list): + paths = _paths_to_ref_ids(cls=cls, paths=[], single_path=[]) + for path in paths: + path_str = "" + for index, db_filed in enumerate(path): + db_filed: DbField + path_str += ( + "." + db_filed.field_alias if path_str != "" else db_filed.field_alias + ) + if db_filed.is_list and not db_filed.by_reference: + pipeline += unwind( + path=path_str, array_index=f"__unwind_{index}", preserve_empty=True + ) + + return pipeline + + +def resolve_reference_pipeline(cls: BaseModel): + from pprint import pprint + + print("*-" * 40) + print(f"-------------- {cls.__name__} --------------") + pipeline = _mount_pipeline(cls=cls, pipeline=[]) + pprint(pipeline) + # print() + + def _recursice_db_fields_info(db_field_info: DbField, path: list) -> DbField: if db_field_info.has_model_fields: for field, field_info in db_field_info.field_type.model_fields.items(): diff --git a/tests/test_async_crud_db.py b/tests/test_async_crud_db.py index 08a758d..b6ec0a5 100644 --- a/tests/test_async_crud_db.py +++ b/tests/test_async_crud_db.py @@ -488,17 +488,19 @@ class ClassOne(DbModel): class ClassTwoA(BaseModel): attr_2_a: str = "attr 2 A" - class_one: list[ClassOne | Id] | None = [] + class_one: list[ClassOne | Id] | None class ClassTwoB(BaseModel): - attr_2_b: str + attr_2_b: str = "attr 2 B" class_two_a: ClassTwoA | None + class_two_a_list: list[ClassTwoA] | None class ClassThree(DbModel): attr_3: str = "attr 3" - class_two: ClassTwoB | None = ClassTwoB(attr_2_b="attr 2 b", class_two_a=None) + class_two_b: ClassTwoB | None + class_two_b_list: list[ClassTwoB] | None _collection: ClassVar = "class_three" @@ -507,14 +509,42 @@ async def drop_collections_one_three(): await db._db[ClassOne._collection].drop() await db._db[ClassThree._collection].drop() yield - await db._db[ClassOne._collection].drop() - await db._db[ClassThree._collection].drop() + # await db._db[ClassOne._collection].drop() + # await db._db[ClassThree._collection].drop() @pytest.mark.asyncio async def test_nested_none_object(drop_collections_one_three): - obj = ClassThree() - await db.save(obj=ClassOne()) - await db.save(obj=obj) - obj_found = await db.find_one(Model=ClassThree, populate=True) - assert obj_found == obj + # obj_1 = ClassOne(attr_1="obj_1") + # obj_2 = ClassOne(attr_1="obj_2") + # obj_3 = ClassOne(attr_1="obj_3") + # obj_4 = ClassOne(attr_1="obj_4") + # obj_5 = ClassOne(attr_1="obj_5") + # obj_6 = ClassOne(attr_1="obj_6") + # obj_7 = ClassOne(attr_1="obj_7") + # obj_8 = ClassOne(attr_1="obj_8") + # await db.save_all([obj_1, obj_2, obj_3, obj_4, obj_5, obj_6, obj_7, obj_8]) + # obj_9 = ClassTwoA(attr_2_a="obj_9", class_one=[obj_1, obj_2]) + # obj_10 = ClassTwoA(attr_2_a="obj_10", class_one=[obj_3, obj_4]) + # obj_11 = ClassTwoA(attr_2_a="obj_11", class_one=[obj_5, obj_6]) + # obj_12 = ClassTwoA(attr_2_a="obj_12", class_one=[obj_7, obj_8]) + # obj_13 = ClassTwoB( + # attr_2_b="obj_13", class_two_a=obj_9, class_two_a_list=[obj_9, obj_10] + # ) + # obj_14 = ClassTwoB( + # attr_2_b="obj_14", class_two_a=obj_11, class_two_a_list=[obj_11, obj_12] + # ) + # obj_15 = ClassThree( + # attr_3="obj_16", class_two_b=obj_13, class_two_b_list=[obj_13, obj_14] + # ) + # await db.save(obj_15) + + # obj_list = await db.find_many(Model=ClassThree) + # print(obj_list) + # obj = ClassThree() + # await db.save(obj=ClassOne()) + # await db.save(obj=obj) + # print(ClassThree._reference_pipeline) + # obj_found = await db.find_one(Model=ClassThree, populate=True) + # assert obj_found == obj + pass