diff --git a/pyodmongo/services/aggregate_stages.py b/pyodmongo/services/aggregate_stages.py index 3bdcd81..3c71a29 100644 --- a/pyodmongo/services/aggregate_stages.py +++ b/pyodmongo/services/aggregate_stages.py @@ -100,7 +100,9 @@ def set_(local_field: str, as_: str): return pipeline -def group_set_replace_root(id_: list[str], array_index: str, field: str, path_str: str): +def group_set_replace_root( + to_sort: dict, id_: list[str], array_index: str, field: str, path_str: str +): """ Constructs a combination of group, set, and replaceRoot stages for a MongoDB aggregation pipeline. @@ -118,6 +120,7 @@ def group_set_replace_root(id_: list[str], array_index: str, field: str, path_st to the top-level. """ return [ + {"$sort": to_sort}, { "$group": { "_id": id_, diff --git a/pyodmongo/services/reference_pipeline.py b/pyodmongo/services/reference_pipeline.py index 4158a1a..2fc798d 100644 --- a/pyodmongo/services/reference_pipeline.py +++ b/pyodmongo/services/reference_pipeline.py @@ -98,17 +98,21 @@ def resolve_reference_pipeline( if not db_field.is_list: pipeline += set_(local_field=path_str, as_=path_str) + index_to_unset = [] for index, path_str in enumerate(reversed(paths_str_to_group)): id_ = [ f"${e}" for e in unwind_index_list[: len(unwind_index_list) - index - 1] ] - to_unset = unwind_index_list[-(index + 1)] + index_to_unset.append(unwind_index_list[-(index + 1)]) + to_sort = {key: 1 for key in unwind_index_list[1:]} pipeline += group_set_replace_root( - id_=[id_], + to_sort=to_sort, + id_=id_, array_index=unwind_index_list[-1], field=path_str.split(".")[-1], path_str=path_str, ) - pipeline += unset(fields=[to_unset]) + if len(index_to_unset) > 0: + pipeline += unset(fields=index_to_unset) return pipeline