Skip to content

Commit

Permalink
refactor "group by" to use embedded documents
Browse files Browse the repository at this point in the history
Foreign fields are computed as {'T3': {'age': '$T3.age'}} instead of 
{'T3___age': '$T3.age'}.
  • Loading branch information
WaVEV authored Oct 21, 2024
1 parent b3fd2c4 commit 56cd604
Showing 1 changed file with 11 additions and 40 deletions.
51 changes: 11 additions & 40 deletions django_mongodb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ class SQLCompiler(compiler.SQLCompiler):
"""Base class for all Mongo compilers."""

query_class = MongoQuery
GROUP_SEPARATOR = "___"
PARENT_FIELD_TEMPLATE = "parent__field__{}"

def __init__(self, *args, **kwargs):
Expand All @@ -37,34 +36,6 @@ def __init__(self, *args, **kwargs):
self.order_by_objs = None
self.subqueries = []

def _unfold_column(self, col):
"""
Flatten a field by returning its target or by replacing dots with
GROUP_SEPARATOR for foreign fields.
"""
if self.collection_name == col.alias:
return col.target.column
# If this is a foreign field, replace the normal dot (.) with
# GROUP_SEPARATOR since FieldPath field names may not contain '.'.
return f"{col.alias}{self.GROUP_SEPARATOR}{col.target.column}"

def _fold_columns(self, unfold_columns):
"""
Convert flat columns into a nested dictionary, grouping fields by
table name.
"""
result = defaultdict(dict)
for key in unfold_columns:
value = f"$_id.{key}"
if self.GROUP_SEPARATOR in key:
table, field = key.split(self.GROUP_SEPARATOR)
result[table][field] = value
else:
result[key] = value
# Convert defaultdict to dict so it doesn't appear as
# "defaultdict(<CLASS 'dict'>, ..." in query logging.
return dict(result)

def _get_group_alias_column(self, expr, annotation_group_idx):
"""Generate a dummy field for use in the ids fields in $group."""
replacement = None
Expand All @@ -75,7 +46,7 @@ def _get_group_alias_column(self, expr, annotation_group_idx):
alias = f"__annotation_group{next(annotation_group_idx)}"
col = self._get_column_from_expression(expr, alias)
replacement = col
return self._unfold_column(col), replacement
return col.target.column, replacement

def _get_column_from_expression(self, expr, alias):
"""
Expand Down Expand Up @@ -198,18 +169,15 @@ def _get_group_id_expressions(self, order_by):
else:
annotation_group_idx = itertools.count(start=1)
ids = {}
columns = []
for col in group_expressions:
alias, replacement = self._get_group_alias_column(col, annotation_group_idx)
try:
ids[alias] = col.as_mql(self, self.connection)
except EmptyResultSet:
ids[alias] = Value(False).as_mql(self, self.connection)
except FullResultSet:
ids[alias] = Value(True).as_mql(self, self.connection)
columns.append((alias, col))
if replacement is not None:
replacements[col] = replacement
if isinstance(col, Ref):
replacements[col.source] = replacement
ids = self.get_project_fields(tuple(columns), force_expression=True)
return ids, replacements

def _build_aggregation_pipeline(self, ids, group):
Expand All @@ -234,7 +202,7 @@ def _build_aggregation_pipeline(self, ids, group):
else:
group["_id"] = ids
pipeline.append({"$group": group})
projected_fields = self._fold_columns(ids)
projected_fields = {key: f"$_id.{key}" for key in ids}
pipeline.append({"$addFields": projected_fields})
if "_id" not in projected_fields:
pipeline.append({"$unset": "_id"})
Expand Down Expand Up @@ -522,15 +490,18 @@ def get_combinator_queries(self):
else:
combinator_pipeline = inner_pipeline
if not self.query.combinator_all:
ids = {}
ids = defaultdict(dict)
for alias, expr in main_query_columns:
# Unfold foreign fields.
if isinstance(expr, Col) and expr.alias != self.collection_name:
ids[self._unfold_column(expr)] = expr.as_mql(self, self.connection)
ids[expr.alias][expr.target.column] = expr.as_mql(self, self.connection)
else:
ids[alias] = f"${alias}"
# Convert defaultdict to dict so it doesn't appear as
# "defaultdict(<class 'dict'>, ..." in query logging.
ids = dict(ids)
combinator_pipeline.append({"$group": {"_id": ids}})
projected_fields = self._fold_columns(ids)
projected_fields = {key: f"$_id.{key}" for key in ids}
combinator_pipeline.append({"$addFields": projected_fields})
if "_id" not in projected_fields:
combinator_pipeline.append({"$unset": "_id"})
Expand Down

0 comments on commit 56cd604

Please sign in to comment.