Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🎨 Pipeline reference at first #105

Merged
merged 1 commit into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyodmongo/engine/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def mount_base_pipeline(Model, query, populate: bool = False):
model_stage = Model._pipeline
reference_stage = Model._reference_pipeline
if populate:
return match_stage + model_stage + reference_stage
return reference_stage + match_stage + model_stage
# return match_stage + model_stage + reference_stage
else:
return match_stage + model_stage
3 changes: 2 additions & 1 deletion pyodmongo/queries/query_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,8 @@ def mount_query_filter(
for index, item in enumerate(value):
value[index] = js_regex_to_python(item)
try:
db_field_info: DbField = getattr(Model, field_name)
# db_field_info: DbField = getattr(Model, field_name)
db_field_info: DbField = eval("Model." + field_name)
except AttributeError:
raise AttributeError(f"There's no field '{field_name}' in {Model.__name__}")
initial_comparison_operators.append(
Expand Down
46 changes: 46 additions & 0 deletions tests/test_async_crud_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,3 +433,49 @@ async def test_recursive_reference_pipeline(create_find_dict_collection):
d: D = await db.find_one(Model=D, populate=True)

assert d.d1[0].b2[0].b1.a1 == "A"


class ClassA(DbModel):
attr_1: str = "A String 1"
attr_2: str = "A String 2"
_collection: ClassVar = "col_a"


class ClassB(DbModel):
attr_3: str = "A String 3"
a: ClassA | Id
_collection: ClassVar = "col_b"


@pytest_asyncio.fixture()
async def drop_collections():
await db._db[ClassA._collection].drop()
await db._db[ClassB._collection].drop()
yield
await db._db[ClassA._collection].drop()
await db._db[ClassB._collection].drop()


@pytest.mark.asyncio
async def test_find_nested_field_query(drop_collections):
obj_a = ClassA()
await db.save(obj=obj_a)
obj_b = ClassB(a=obj_a)
await db.save(obj=obj_b)
query = eq(ClassB.a.attr_2, "A String 2")
result = await db.find_many(Model=ClassB, query=query, populate=True)
assert result == [obj_b]


@pytest.mark.asyncio
async def test_find_nested_field_mount_query(drop_collections):
obj_a = ClassA()
await db.save(obj=obj_a)
obj_b = ClassB(a=obj_a)
await db.save(obj=obj_b)
input_dict = {"a.attr_2_eq": "A String 2"}
query = mount_query_filter(
Model=ClassB, items=input_dict, initial_comparison_operators=[]
)
result = await db.find_many(Model=ClassB, query=query, populate=True)
assert result == [obj_b]
Loading