diff --git a/django_mongodb/compiler.py b/django_mongodb/compiler.py index 0b724114..73d01beb 100644 --- a/django_mongodb/compiler.py +++ b/django_mongodb/compiler.py @@ -434,13 +434,16 @@ def project_field(column): ) @cached_property - def collection_name(self): - base_table = next( + def base_table(self): + return next( v for k, v in self.query.alias_map.items() if isinstance(v, BaseTable) and self.query.alias_refcount[k] ) - return base_table.table_alias or base_table.table_name + + @cached_property + def collection_name(self): + return self.base_table.table_alias or self.base_table.table_name @cached_property def collection(self): @@ -469,6 +472,7 @@ def get_combinator_queries(self): ) ) compiler_.pre_sql_setup() + compiler_.column_indices = self.column_indices columns = compiler_.get_columns() parts.append((compiler_.build_query(columns), compiler_, columns)) except EmptyResultSet: @@ -496,7 +500,12 @@ def get_combinator_queries(self): # Combine query with the current combinator pipeline. if combinator_pipeline: combinator_pipeline.append( - {"$unionWith": {"coll": compiler_.collection_name, "pipeline": inner_pipeline}} + { + "$unionWith": { + "coll": compiler_.base_table.table_name, + "pipeline": inner_pipeline, + } + } ) else: combinator_pipeline = inner_pipeline diff --git a/django_mongodb/expressions.py b/django_mongodb/expressions.py index a95e8b1c..626f7711 100644 --- a/django_mongodb/expressions.py +++ b/django_mongodb/expressions.py @@ -122,39 +122,44 @@ def query(self, compiler, connection, lookup_name=None): if lookup_name in ("in", "range"): if subquery.aggregation_pipeline is None: subquery.aggregation_pipeline = [] - subquery.aggregation_pipeline.extend( - [ - { - "$facet": { - "group": [ + wrapping_result_pipeline = [ + { + "$facet": { + "group": [ + { + "$group": { + "_id": None, + "tmp_name": { + "$addToSet": expr.as_mql(subquery_compiler, connection) + }, + } + } + ] + } + }, + { + "$project": { + field_name: { + "$ifNull": [ { - "$group": { - "_id": None, - "tmp_name": { - "$addToSet": expr.as_mql(subquery_compiler, connection) - }, + "$getField": { + "input": {"$arrayElemAt": ["$group", 0]}, + "field": "tmp_name", } - } + }, + [], ] } - }, - { - "$project": { - field_name: { - "$ifNull": [ - { - "$getField": { - "input": {"$arrayElemAt": ["$group", 0]}, - "field": "tmp_name", - } - }, - [], - ] - } - } - }, - ] - ) + } + }, + ] + # If the subquery is a combinator, wrap the result at the end of the + # combinator pipeline... + if subquery.query.combinator: + subquery.combinator_pipeline.extend(wrapping_result_pipeline) + # ... otherwise put at the end of subquery's pipeline. + else: + subquery.aggregation_pipeline.extend(wrapping_result_pipeline) # Erase project_fields since the required value is projected above. subquery.project_fields = None compiler.subqueries.append(subquery) diff --git a/django_mongodb/features.py b/django_mongodb/features.py index 3be98779..d9cb8b89 100644 --- a/django_mongodb/features.py +++ b/django_mongodb/features.py @@ -73,11 +73,6 @@ class DatabaseFeatures(BaseDatabaseFeatures): # Connection creation doesn't follow the usual Django API. "backends.tests.ThreadTests.test_pass_connection_between_threads", "backends.tests.ThreadTests.test_default_connection_thread_local", - # Union as subquery is not mapping the parent parameter and collections: - # https://github.com/mongodb-labs/django-mongodb/issues/156 - "queries.test_qs_combinators.QuerySetSetOperationTests.test_union_in_subquery_related_outerref", - "queries.test_qs_combinators.QuerySetSetOperationTests.test_union_in_subquery", - "queries.test_qs_combinators.QuerySetSetOperationTests.test_union_in_with_ordering", # ObjectId type mismatch in a subquery: # https://github.com/mongodb-labs/django-mongodb/issues/161 "queries.tests.RelatedLookupTypeTests.test_values_queryset_lookup",