From 78a8619239bf1683575e42b53482926ef67523b4 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Wed, 25 Sep 2024 10:33:48 -0400 Subject: [PATCH] fix incorrect GenericRelation joining By adding support for Field.get_extra_restriction(). --- django_mongodb/features.py | 4 +--- django_mongodb/query.py | 39 ++++++++++++++++++++++++++++++++++++-- 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/django_mongodb/features.py b/django_mongodb/features.py index c001de11..994a83df 100644 --- a/django_mongodb/features.py +++ b/django_mongodb/features.py @@ -68,9 +68,7 @@ class DatabaseFeatures(BaseDatabaseFeatures): "aggregation.tests.AggregateTestCase.test_reverse_fkey_annotate", "aggregation_regress.tests.AggregationTests.test_annotation_disjunction", "aggregation_regress.tests.AggregationTests.test_decimal_aggregate_annotation_filter", - # Incorrect JOIN with GenericRelation gives incorrect results. - "aggregation_regress.tests.AggregationTests.test_aggregation_with_generic_reverse_relation", - "generic_relations.tests.GenericRelationsTests.test_queries_content_type_restriction", + # Wrong result for GenericRelation annotation. "generic_relations_regress.tests.GenericRelationTests.test_annotate", # subclasses of BaseDatabaseWrapper may require an is_usable() method "backends.tests.BackendTestCase.test_is_usable_after_database_disconnects", diff --git a/django_mongodb/query.py b/django_mongodb/query.py index 4c8b4639..fb732c01 100644 --- a/django_mongodb/query.py +++ b/django_mongodb/query.py @@ -3,7 +3,7 @@ from django.core.exceptions import EmptyResultSet, FullResultSet from django.db import DatabaseError, IntegrityError, NotSupportedError -from django.db.models.expressions import Case, When +from django.db.models.expressions import Case, Col, When from django.db.models.functions import Mod from django.db.models.lookups import Exact from django.db.models.sql.constants import INNER @@ -105,6 +105,7 @@ def join(self, compiler, connection): lhs_fields = [] rhs_fields = [] # Add a join condition for each pair of joining fields. + parent_template = "parent__field__" for lhs, rhs in self.join_fields: lhs, rhs = connection.ops.prepare_join_on_clause( self.parent_alias, lhs, compiler.collection_name, rhs @@ -113,8 +114,41 @@ def join(self, compiler, connection): # In the lookup stage, the reference to this column doesn't include # the collection name. rhs_fields.append(rhs.as_mql(compiler, connection)) + # Handle any join conditions besides matching field pairs. + extra = self.join_field.get_extra_restriction(self.table_alias, self.parent_alias) + if extra: + columns = [] + for expr in extra.leaves(): + # Determine whether the column needs to be transformed or rerouted + # as part of the subquery. + for hand_side in ["lhs", "rhs"]: + hand_side_value = getattr(expr, hand_side, None) + if isinstance(hand_side_value, Col): + # If the column is not part of the joined table, add it to + # lhs_fields. + if hand_side_value.alias != self.table_name: + pos = len(lhs_fields) + lhs_fields.append(expr.lhs.as_mql(compiler, connection)) + else: + pos = None + columns.append((hand_side_value, pos)) + # Replace columns in the extra conditions with new column references + # based on their rerouted positions in the join pipeline. + replacements = {} + for col, parent_pos in columns: + column_target = Col(compiler.collection_name, expr.output_field.__class__()) + if parent_pos is not None: + target_col = f"${parent_template}{parent_pos}" + column_target.target.db_column = target_col + column_target.target.set_attributes_from_name(target_col) + else: + column_target.target = col.target + replacements[col] = column_target + # Apply the transformed expressions in the extra condition. + extra_condition = [extra.replace_expressions(replacements).as_mql(compiler, connection)] + else: + extra_condition = [] - parent_template = "parent__field__" lookup_pipeline = [ { "$lookup": { @@ -140,6 +174,7 @@ def join(self, compiler, connection): {"$eq": [f"$${parent_template}{i}", field]} for i, field in enumerate(rhs_fields) ] + + extra_condition } } }