Skip to content

Commit

Permalink
♻️ Start refator with paths to reference ids
Browse files Browse the repository at this point in the history
  • Loading branch information
mauro-andre committed Apr 9, 2024
1 parent 4f27240 commit d61011b
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 11 deletions.
2 changes: 2 additions & 0 deletions pyodmongo/models/db_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
resolve_indexes,
resolve_ref_pipeline,
resolve_class_fields_db_info,
resolve_reference_pipeline,
)


Expand Down Expand Up @@ -45,6 +46,7 @@ def __init__(self, **attrs):
def __pydantic_init_subclass__(cls):
resolve_class_fields_db_info(cls=cls)
ref_pipeline = resolve_ref_pipeline(cls=cls, pipeline=[], path=[])
pipeline = resolve_reference_pipeline(cls=cls)
setattr(cls, "_reference_pipeline", ref_pipeline)
indexes = resolve_indexes(cls=cls)
setattr(cls, "_init_indexes", indexes)
44 changes: 44 additions & 0 deletions pyodmongo/services/aggregate_stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,47 @@ def lookup_and_set(
if is_reference_list:
return lookup_stage
return lookup_stage + set_stage


def unwind(path: str, array_index: str, preserve_empty: bool):
return [
{
"$unwind": {
"path": f"${path}",
"includeArrayIndex": array_index,
"preserveNullAndEmptyArrays": preserve_empty,
}
}
]


def lookup(_from: str, local_field: str, foreign_field: str, _as: str, pipeline: list):
return [
{
"$lookup": {
"from": _from,
"localField": local_field,
"foreignField": foreign_field,
"as": _as,
"pipeline": pipeline,
}
}
]


def group_set_replace_root(id: str, field: str, path_str: str):
return [
{
"$group": {
"_id": f"${id}",
"_document": {"$first": "$$ROOT"},
field: {"$push": f"${path_str}"},
}
},
{"$set": {f"_document.{path_str}": f"${field}"}},
{"$replaceRoot": {"newRoot": "$_document"}},
]


def unset(fields: list):
return [{"$unset": fields}]
44 changes: 43 additions & 1 deletion pyodmongo/services/model_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from types import UnionType
from ..models.id_model import Id
from ..models.db_field_info import DbField
from .aggregate_stages import lookup_and_set
from .aggregate_stages import lookup_and_set, unwind
import copy


def resolve_indexes(cls: BaseModel):
Expand Down Expand Up @@ -147,6 +148,47 @@ def resolve_ref_pipeline(cls: BaseModel, pipeline: list, path: list):
return pipeline


def _paths_to_ref_ids(cls: BaseModel, paths: list, single_path: list):
for field, field_info in cls.model_fields.items():
db_field = field_annotation_infos(field=field, field_info=field_info)
single_path.append(db_field)
if db_field.by_reference:
paths.append(copy.deepcopy(single_path))
elif db_field.has_model_fields:
_paths_to_ref_ids(
cls=db_field.field_type, paths=paths, single_path=single_path
)
single_path.pop(-1)
return paths


def _mount_pipeline(cls: BaseModel, pipeline: list):
paths = _paths_to_ref_ids(cls=cls, paths=[], single_path=[])
for path in paths:
path_str = ""
for index, db_filed in enumerate(path):
db_filed: DbField
path_str += (
"." + db_filed.field_alias if path_str != "" else db_filed.field_alias
)
if db_filed.is_list and not db_filed.by_reference:
pipeline += unwind(
path=path_str, array_index=f"__unwind_{index}", preserve_empty=True
)

return pipeline


def resolve_reference_pipeline(cls: BaseModel):
from pprint import pprint

print("*-" * 40)
print(f"-------------- {cls.__name__} --------------")
pipeline = _mount_pipeline(cls=cls, pipeline=[])
pprint(pipeline)
# print()


def _recursice_db_fields_info(db_field_info: DbField, path: list) -> DbField:
if db_field_info.has_model_fields:
for field, field_info in db_field_info.field_type.model_fields.items():
Expand Down
50 changes: 40 additions & 10 deletions tests/test_async_crud_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,17 +488,19 @@ class ClassOne(DbModel):

class ClassTwoA(BaseModel):
attr_2_a: str = "attr 2 A"
class_one: list[ClassOne | Id] | None = []
class_one: list[ClassOne | Id] | None


class ClassTwoB(BaseModel):
attr_2_b: str
attr_2_b: str = "attr 2 B"
class_two_a: ClassTwoA | None
class_two_a_list: list[ClassTwoA] | None


class ClassThree(DbModel):
attr_3: str = "attr 3"
class_two: ClassTwoB | None = ClassTwoB(attr_2_b="attr 2 b", class_two_a=None)
class_two_b: ClassTwoB | None
class_two_b_list: list[ClassTwoB] | None
_collection: ClassVar = "class_three"


Expand All @@ -507,14 +509,42 @@ async def drop_collections_one_three():
await db._db[ClassOne._collection].drop()
await db._db[ClassThree._collection].drop()
yield
await db._db[ClassOne._collection].drop()
await db._db[ClassThree._collection].drop()
# await db._db[ClassOne._collection].drop()
# await db._db[ClassThree._collection].drop()


@pytest.mark.asyncio
async def test_nested_none_object(drop_collections_one_three):
obj = ClassThree()
await db.save(obj=ClassOne())
await db.save(obj=obj)
obj_found = await db.find_one(Model=ClassThree, populate=True)
assert obj_found == obj
# obj_1 = ClassOne(attr_1="obj_1")
# obj_2 = ClassOne(attr_1="obj_2")
# obj_3 = ClassOne(attr_1="obj_3")
# obj_4 = ClassOne(attr_1="obj_4")
# obj_5 = ClassOne(attr_1="obj_5")
# obj_6 = ClassOne(attr_1="obj_6")
# obj_7 = ClassOne(attr_1="obj_7")
# obj_8 = ClassOne(attr_1="obj_8")
# await db.save_all([obj_1, obj_2, obj_3, obj_4, obj_5, obj_6, obj_7, obj_8])
# obj_9 = ClassTwoA(attr_2_a="obj_9", class_one=[obj_1, obj_2])
# obj_10 = ClassTwoA(attr_2_a="obj_10", class_one=[obj_3, obj_4])
# obj_11 = ClassTwoA(attr_2_a="obj_11", class_one=[obj_5, obj_6])
# obj_12 = ClassTwoA(attr_2_a="obj_12", class_one=[obj_7, obj_8])
# obj_13 = ClassTwoB(
# attr_2_b="obj_13", class_two_a=obj_9, class_two_a_list=[obj_9, obj_10]
# )
# obj_14 = ClassTwoB(
# attr_2_b="obj_14", class_two_a=obj_11, class_two_a_list=[obj_11, obj_12]
# )
# obj_15 = ClassThree(
# attr_3="obj_16", class_two_b=obj_13, class_two_b_list=[obj_13, obj_14]
# )
# await db.save(obj_15)

# obj_list = await db.find_many(Model=ClassThree)
# print(obj_list)
# obj = ClassThree()
# await db.save(obj=ClassOne())
# await db.save(obj=obj)
# print(ClassThree._reference_pipeline)
# obj_found = await db.find_one(Model=ClassThree, populate=True)
# assert obj_found == obj
pass

0 comments on commit d61011b

Please sign in to comment.