From e4d83b75dfb6882ad68047222a127c84cf2ea30e Mon Sep 17 00:00:00 2001 From: Mauro Andre Date: Wed, 15 Jan 2025 12:01:47 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=91=20Added=20sort=20stage=20before=20?= =?UTF-8?q?group=20stage=20in=20pipeline.=20Unset=20after=20all=20nested?= =?UTF-8?q?=20group=20stages?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyodmongo/services/aggregate_stages.py | 5 ++++- pyodmongo/services/reference_pipeline.py | 10 +++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/pyodmongo/services/aggregate_stages.py b/pyodmongo/services/aggregate_stages.py index 3bdcd81..3c71a29 100644 --- a/pyodmongo/services/aggregate_stages.py +++ b/pyodmongo/services/aggregate_stages.py @@ -100,7 +100,9 @@ def set_(local_field: str, as_: str): return pipeline -def group_set_replace_root(id_: list[str], array_index: str, field: str, path_str: str): +def group_set_replace_root( + to_sort: dict, id_: list[str], array_index: str, field: str, path_str: str +): """ Constructs a combination of group, set, and replaceRoot stages for a MongoDB aggregation pipeline. @@ -118,6 +120,7 @@ def group_set_replace_root(id_: list[str], array_index: str, field: str, path_st to the top-level. """ return [ + {"$sort": to_sort}, { "$group": { "_id": id_, diff --git a/pyodmongo/services/reference_pipeline.py b/pyodmongo/services/reference_pipeline.py index 4158a1a..2fc798d 100644 --- a/pyodmongo/services/reference_pipeline.py +++ b/pyodmongo/services/reference_pipeline.py @@ -98,17 +98,21 @@ def resolve_reference_pipeline( if not db_field.is_list: pipeline += set_(local_field=path_str, as_=path_str) + index_to_unset = [] for index, path_str in enumerate(reversed(paths_str_to_group)): id_ = [ f"${e}" for e in unwind_index_list[: len(unwind_index_list) - index - 1] ] - to_unset = unwind_index_list[-(index + 1)] + index_to_unset.append(unwind_index_list[-(index + 1)]) + to_sort = {key: 1 for key in unwind_index_list[1:]} pipeline += group_set_replace_root( - id_=[id_], + to_sort=to_sort, + id_=id_, array_index=unwind_index_list[-1], field=path_str.split(".")[-1], path_str=path_str, ) - pipeline += unset(fields=[to_unset]) + if len(index_to_unset) > 0: + pipeline += unset(fields=index_to_unset) return pipeline