From 56e3dd43c8c1f24534c651feb5fe054ac30c5676 Mon Sep 17 00:00:00 2001 From: Mauro Andre Date: Wed, 13 Mar 2024 10:14:55 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=8E=A8=20Pipeline=20reference=20at=20firs?= =?UTF-8?q?t?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyodmongo/engine/utils.py | 3 +- pyodmongo/queries/query_string.py | 3 +- tests/test_async_crud_db.py | 46 +++++++++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 2 deletions(-) diff --git a/pyodmongo/engine/utils.py b/pyodmongo/engine/utils.py index 24b37af..790d8f5 100644 --- a/pyodmongo/engine/utils.py +++ b/pyodmongo/engine/utils.py @@ -53,6 +53,7 @@ def mount_base_pipeline(Model, query, populate: bool = False): model_stage = Model._pipeline reference_stage = Model._reference_pipeline if populate: - return match_stage + model_stage + reference_stage + return reference_stage + match_stage + model_stage + # return match_stage + model_stage + reference_stage else: return match_stage + model_stage diff --git a/pyodmongo/queries/query_string.py b/pyodmongo/queries/query_string.py index ed8ecf8..d220318 100644 --- a/pyodmongo/queries/query_string.py +++ b/pyodmongo/queries/query_string.py @@ -66,7 +66,8 @@ def mount_query_filter( for index, item in enumerate(value): value[index] = js_regex_to_python(item) try: - db_field_info: DbField = getattr(Model, field_name) + # db_field_info: DbField = getattr(Model, field_name) + db_field_info: DbField = eval("Model." + field_name) except AttributeError: raise AttributeError(f"There's no field '{field_name}' in {Model.__name__}") initial_comparison_operators.append( diff --git a/tests/test_async_crud_db.py b/tests/test_async_crud_db.py index 4ab9e47..15413e2 100644 --- a/tests/test_async_crud_db.py +++ b/tests/test_async_crud_db.py @@ -433,3 +433,49 @@ async def test_recursive_reference_pipeline(create_find_dict_collection): d: D = await db.find_one(Model=D, populate=True) assert d.d1[0].b2[0].b1.a1 == "A" + + +class ClassA(DbModel): + attr_1: str = "A String 1" + attr_2: str = "A String 2" + _collection: ClassVar = "col_a" + + +class ClassB(DbModel): + attr_3: str = "A String 3" + a: ClassA | Id + _collection: ClassVar = "col_b" + + +@pytest_asyncio.fixture() +async def drop_collections(): + await db._db[ClassA._collection].drop() + await db._db[ClassB._collection].drop() + yield + await db._db[ClassA._collection].drop() + await db._db[ClassB._collection].drop() + + +@pytest.mark.asyncio +async def test_find_nested_field_query(drop_collections): + obj_a = ClassA() + await db.save(obj=obj_a) + obj_b = ClassB(a=obj_a) + await db.save(obj=obj_b) + query = eq(ClassB.a.attr_2, "A String 2") + result = await db.find_many(Model=ClassB, query=query, populate=True) + assert result == [obj_b] + + +@pytest.mark.asyncio +async def test_find_nested_field_mount_query(drop_collections): + obj_a = ClassA() + await db.save(obj=obj_a) + obj_b = ClassB(a=obj_a) + await db.save(obj=obj_b) + input_dict = {"a.attr_2_eq": "A String 2"} + query = mount_query_filter( + Model=ClassB, items=input_dict, initial_comparison_operators=[] + ) + result = await db.find_many(Model=ClassB, query=query, populate=True) + assert result == [obj_b]