Skip to content

Commit

Permalink
🎨 Recursive nested object pipeline
Browse files Browse the repository at this point in the history
  • Loading branch information
mauro-andre committed Mar 15, 2024
1 parent 424444e commit 73c11e7
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 24 deletions.
43 changes: 30 additions & 13 deletions pyodmongo/services/model_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,37 +96,54 @@ def field_annotation_infos(field, field_info) -> DbField:
)


def resolve_project_pipeline(cls: BaseModel):
project_dict = {"_id": True}
def resolve_project_pipeline(cls: BaseModel, path: list):
project = {}
for field, field_info in cls.model_fields.items():
field_alias = field_info.alias or field
project_dict[field_alias] = True
try:
project_dict.pop("id")
except KeyError:
pass
return [{"$project": project_dict}]
db_field_info = field_annotation_infos(field=field, field_info=field_info)
path.append(db_field_info.field_alias)
path_str = ".".join(path)
project[path_str] = True
if db_field_info.has_model_fields:
if not db_field_info.by_reference:
project.pop(path_str)
project.update(
resolve_project_pipeline(cls=db_field_info.field_type, path=path)
)
path.pop(-1)
return project


def resolve_ref_pipeline(cls: BaseModel, pipeline: list, path: list):
for field, field_info in cls.model_fields.items():
db_field_info = field_annotation_infos(field=field, field_info=field_info)
path.append(db_field_info.field_alias)
path_str = ".".join(path)
if db_field_info.has_model_fields:
if db_field_info.by_reference:
collection = db_field_info.field_type._collection
pipeline += lookup_and_set(
from_=collection,
local_field=db_field_info.field_alias,
local_field=path_str,
foreign_field="_id",
as_=db_field_info.field_alias,
as_=path_str,
pipeline=resolve_ref_pipeline(
cls=db_field_info.field_type, pipeline=[], path=[]
),
is_reference_list=db_field_info.is_list,
)
else:
resolve_ref_pipeline(cls=db_field_info.field_type, pipeline=pipeline)
pipeline += resolve_project_pipeline(cls=cls)
resolve_ref_pipeline(
cls=db_field_info.field_type,
pipeline=pipeline,
path=path,
)
path.pop(-1)
project = resolve_project_pipeline(cls=cls, path=[])
try:
project_index = [list(dct.keys())[0] for dct in pipeline].index("$project")
pipeline[project_index] = {"$project": project}
except ValueError:
pipeline += [{"$project": project}]
return pipeline


Expand Down
122 changes: 111 additions & 11 deletions tests/test_reference_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,51 +131,151 @@ class MyModel2(DbModel):
def test_recursive_reference_pipeline():
from pprint import pprint

class Zero(DbModel):
attr_0: str = "Zero"
_collection: ClassVar = "col_0"

class A(DbModel):
attr_1: str = "One"
_collection: ClassVar = "a"
zero_1: Zero | Id = Zero()
zero_2: Zero = Zero()
_collection: ClassVar = "col_a"

class B(BaseModel):
attr_2: str = "Two"
a: A | Id = A()
a1: A | Id = A()
a2: A = A()

class C(DbModel):
attr_3: str = "Three"
b: B = B()
a: A | Id = A()
_collection: ClassVar = "c"
b: B = B()
_collection: ClassVar = "col_c"

expected = [
assert C._reference_pipeline == [
{
"$lookup": {
"as": "a",
"foreignField": "_id",
"from": "a",
"from": "col_a",
"localField": "a",
"pipeline": [
{
"$lookup": {
"as": "zero_1",
"foreignField": "_id",
"from": "col_0",
"localField": "zero_1",
"pipeline": [
{
"$project": {
"_id": True,
"attr_0": True,
"created_at": True,
"updated_at": True,
}
}
],
}
},
{"$set": {"zero_1": {"$arrayElemAt": ["$zero_1", 0]}}},
{
"$project": {
"_id": True,
"attr_1": True,
"created_at": True,
"updated_at": True,
"zero_1": True,
"zero_2._id": True,
"zero_2.attr_0": True,
"zero_2.created_at": True,
"zero_2.updated_at": True,
}
}
},
],
}
},
{"$set": {"a": {"$arrayElemAt": ["$a", 0]}}},
{
"$lookup": {
"as": "b.a1",
"foreignField": "_id",
"from": "col_a",
"localField": "b.a1",
"pipeline": [
{
"$lookup": {
"as": "zero_1",
"foreignField": "_id",
"from": "col_0",
"localField": "zero_1",
"pipeline": [
{
"$project": {
"_id": True,
"attr_0": True,
"created_at": True,
"updated_at": True,
}
}
],
}
},
{"$set": {"zero_1": {"$arrayElemAt": ["$zero_1", 0]}}},
{
"$project": {
"_id": True,
"attr_1": True,
"created_at": True,
"updated_at": True,
"zero_1": True,
"zero_2._id": True,
"zero_2.attr_0": True,
"zero_2.created_at": True,
"zero_2.updated_at": True,
}
},
],
}
},
{"$set": {"b.a1": {"$arrayElemAt": ["$b.a1", 0]}}},
{
"$lookup": {
"as": "b.a2.zero_1",
"foreignField": "_id",
"from": "col_0",
"localField": "b.a2.zero_1",
"pipeline": [
{
"$project": {
"_id": True,
"attr_0": True,
"created_at": True,
"updated_at": True,
}
}
],
}
},
{"$set": {"b.a2.zero_1": {"$arrayElemAt": ["$b.a2.zero_1", 0]}}},
{
"$project": {
"_id": True,
"a": True,
"attr_3": True,
"b": True,
"b.a1": True,
"b.a2._id": True,
"b.a2.attr_1": True,
"b.a2.created_at": True,
"b.a2.updated_at": True,
"b.a2.zero_1": True,
"b.a2.zero_2._id": True,
"b.a2.zero_2.attr_0": True,
"b.a2.zero_2.created_at": True,
"b.a2.zero_2.updated_at": True,
"b.attr_2": True,
"created_at": True,
"updated_at": True,
}
},
]

print()
pprint(C._reference_pipeline)

0 comments on commit 73c11e7

Please sign in to comment.