diff --git a/django_mongodb/compiler.py b/django_mongodb/compiler.py index c187542..a21ce9b 100644 --- a/django_mongodb/compiler.py +++ b/django_mongodb/compiler.py @@ -25,7 +25,6 @@ class SQLCompiler(compiler.SQLCompiler): """Base class for all Mongo compilers.""" query_class = MongoQuery - GROUP_SEPARATOR = "___" PARENT_FIELD_TEMPLATE = "parent__field__{}" def __init__(self, *args, **kwargs): @@ -37,34 +36,6 @@ def __init__(self, *args, **kwargs): self.order_by_objs = None self.subqueries = [] - def _unfold_column(self, col): - """ - Flatten a field by returning its target or by replacing dots with - GROUP_SEPARATOR for foreign fields. - """ - if self.collection_name == col.alias: - return col.target.column - # If this is a foreign field, replace the normal dot (.) with - # GROUP_SEPARATOR since FieldPath field names may not contain '.'. - return f"{col.alias}{self.GROUP_SEPARATOR}{col.target.column}" - - def _fold_columns(self, unfold_columns): - """ - Convert flat columns into a nested dictionary, grouping fields by - table name. - """ - result = defaultdict(dict) - for key in unfold_columns: - value = f"$_id.{key}" - if self.GROUP_SEPARATOR in key: - table, field = key.split(self.GROUP_SEPARATOR) - result[table][field] = value - else: - result[key] = value - # Convert defaultdict to dict so it doesn't appear as - # "defaultdict(, ..." in query logging. - return dict(result) - def _get_group_alias_column(self, expr, annotation_group_idx): """Generate a dummy field for use in the ids fields in $group.""" replacement = None @@ -75,7 +46,7 @@ def _get_group_alias_column(self, expr, annotation_group_idx): alias = f"__annotation_group{next(annotation_group_idx)}" col = self._get_column_from_expression(expr, alias) replacement = col - return self._unfold_column(col), replacement + return col.target.column, replacement def _get_column_from_expression(self, expr, alias): """ @@ -198,18 +169,15 @@ def _get_group_id_expressions(self, order_by): else: annotation_group_idx = itertools.count(start=1) ids = {} + columns = [] for col in group_expressions: alias, replacement = self._get_group_alias_column(col, annotation_group_idx) - try: - ids[alias] = col.as_mql(self, self.connection) - except EmptyResultSet: - ids[alias] = Value(False).as_mql(self, self.connection) - except FullResultSet: - ids[alias] = Value(True).as_mql(self, self.connection) + columns.append((alias, col)) if replacement is not None: replacements[col] = replacement if isinstance(col, Ref): replacements[col.source] = replacement + ids = self.get_project_fields(tuple(columns), force_expression=True) return ids, replacements def _build_aggregation_pipeline(self, ids, group): @@ -234,7 +202,7 @@ def _build_aggregation_pipeline(self, ids, group): else: group["_id"] = ids pipeline.append({"$group": group}) - projected_fields = self._fold_columns(ids) + projected_fields = {key: f"$_id.{key}" for key in ids} pipeline.append({"$addFields": projected_fields}) if "_id" not in projected_fields: pipeline.append({"$unset": "_id"}) @@ -522,15 +490,18 @@ def get_combinator_queries(self): else: combinator_pipeline = inner_pipeline if not self.query.combinator_all: - ids = {} + ids = defaultdict(dict) for alias, expr in main_query_columns: # Unfold foreign fields. if isinstance(expr, Col) and expr.alias != self.collection_name: - ids[self._unfold_column(expr)] = expr.as_mql(self, self.connection) + ids[expr.alias][expr.target.column] = expr.as_mql(self, self.connection) else: ids[alias] = f"${alias}" + # Convert defaultdict to dict so it doesn't appear as + # "defaultdict(, ..." in query logging. + ids = dict(ids) combinator_pipeline.append({"$group": {"_id": ids}}) - projected_fields = self._fold_columns(ids) + projected_fields = {key: f"$_id.{key}" for key in ids} combinator_pipeline.append({"$addFields": projected_fields}) if "_id" not in projected_fields: combinator_pipeline.append({"$unset": "_id"})