Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix queries when subquery has a union #188

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions django_mongodb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,13 +434,16 @@ def project_field(column):
)

@cached_property
def collection_name(self):
base_table = next(
def base_table(self):
return next(
v
for k, v in self.query.alias_map.items()
if isinstance(v, BaseTable) and self.query.alias_refcount[k]
)
return base_table.table_alias or base_table.table_name

@cached_property
def collection_name(self):
return self.base_table.table_alias or self.base_table.table_name

@cached_property
def collection(self):
Expand Down Expand Up @@ -469,6 +472,7 @@ def get_combinator_queries(self):
)
)
compiler_.pre_sql_setup()
compiler_.column_indices = self.column_indices
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If self.column_indices is a list, should we copy it instead?

Copy link
Collaborator Author

@WaVEV WaVEV Nov 18, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, passing by reference was on purpose, the compiler in a union query is made by multiples compilers, so those sub-compilers could use some field from the parent's parent query, then column_indices should be shared.

columns = compiler_.get_columns()
parts.append((compiler_.build_query(columns), compiler_, columns))
except EmptyResultSet:
Expand Down Expand Up @@ -496,7 +500,12 @@ def get_combinator_queries(self):
# Combine query with the current combinator pipeline.
if combinator_pipeline:
combinator_pipeline.append(
{"$unionWith": {"coll": compiler_.collection_name, "pipeline": inner_pipeline}}
{
"$unionWith": {
"coll": compiler_.base_table.table_name,
"pipeline": inner_pipeline,
}
}
)
else:
combinator_pipeline = inner_pipeline
Expand Down
63 changes: 34 additions & 29 deletions django_mongodb/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,39 +122,44 @@ def query(self, compiler, connection, lookup_name=None):
if lookup_name in ("in", "range"):
if subquery.aggregation_pipeline is None:
subquery.aggregation_pipeline = []
subquery.aggregation_pipeline.extend(
[
{
"$facet": {
"group": [
wrapping_result_pipeline = [
{
"$facet": {
"group": [
{
"$group": {
"_id": None,
"tmp_name": {
"$addToSet": expr.as_mql(subquery_compiler, connection)
},
}
}
]
}
},
{
"$project": {
field_name: {
"$ifNull": [
{
"$group": {
"_id": None,
"tmp_name": {
"$addToSet": expr.as_mql(subquery_compiler, connection)
},
"$getField": {
"input": {"$arrayElemAt": ["$group", 0]},
"field": "tmp_name",
}
}
},
[],
]
}
},
{
"$project": {
field_name: {
"$ifNull": [
{
"$getField": {
"input": {"$arrayElemAt": ["$group", 0]},
"field": "tmp_name",
}
},
[],
]
}
}
},
]
)
}
},
]
# If the subquery is a combinator, wrap the result at the end of the
# combinator pipeline...
if subquery.query.combinator:
subquery.combinator_pipeline.extend(wrapping_result_pipeline)
# ... otherwise put at the end of subquery's pipeline.
else:
subquery.aggregation_pipeline.extend(wrapping_result_pipeline)
# Erase project_fields since the required value is projected above.
subquery.project_fields = None
compiler.subqueries.append(subquery)
Expand Down
5 changes: 0 additions & 5 deletions django_mongodb/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
# Connection creation doesn't follow the usual Django API.
"backends.tests.ThreadTests.test_pass_connection_between_threads",
"backends.tests.ThreadTests.test_default_connection_thread_local",
# Union as subquery is not mapping the parent parameter and collections:
# https://github.com/mongodb-labs/django-mongodb/issues/156
"queries.test_qs_combinators.QuerySetSetOperationTests.test_union_in_subquery_related_outerref",
"queries.test_qs_combinators.QuerySetSetOperationTests.test_union_in_subquery",
"queries.test_qs_combinators.QuerySetSetOperationTests.test_union_in_with_ordering",
# ObjectId type mismatch in a subquery:
# https://github.com/mongodb-labs/django-mongodb/issues/161
"queries.tests.RelatedLookupTypeTests.test_values_queryset_lookup",
Expand Down