From 84d0f04ea84f1e5ca03011508dc41523af470cd1 Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Tue, 19 Nov 2024 14:11:47 -0500 Subject: [PATCH] add support for partial indexes Co-authored-by: Emanuel Lupi --- .github/workflows/runtests.py | 1 + django_mongodb/__init__.py | 2 + django_mongodb/compiler.py | 25 +++++- django_mongodb/expressions.py | 7 +- django_mongodb/features.py | 20 +++-- django_mongodb/indexes.py | 74 ++++++++++++++++++ django_mongodb/schema.py | 28 +++++-- tests/indexes_/__init__.py | 0 tests/indexes_/models.py | 7 ++ tests/indexes_/test_condition.py | 126 +++++++++++++++++++++++++++++++ 10 files changed, 273 insertions(+), 17 deletions(-) create mode 100644 django_mongodb/indexes.py create mode 100644 tests/indexes_/__init__.py create mode 100644 tests/indexes_/models.py create mode 100644 tests/indexes_/test_condition.py diff --git a/.github/workflows/runtests.py b/.github/workflows/runtests.py index e5c26c34..3d01657d 100755 --- a/.github/workflows/runtests.py +++ b/.github/workflows/runtests.py @@ -66,6 +66,7 @@ "get_or_create", "i18n", "indexes", + "indexes_", "inline_formsets", "introspection", "invalid_models_tests", diff --git a/django_mongodb/__init__.py b/django_mongodb/__init__.py index 7994999d..31d8f2d3 100644 --- a/django_mongodb/__init__.py +++ b/django_mongodb/__init__.py @@ -10,6 +10,7 @@ from .expressions import register_expressions # noqa: E402 from .fields import register_fields # noqa: E402 from .functions import register_functions # noqa: E402 +from .indexes import register_indexes # noqa: E402 from .lookups import register_lookups # noqa: E402 from .query import register_nodes # noqa: E402 @@ -17,5 +18,6 @@ register_expressions() register_fields() register_functions() +register_indexes() register_lookups() register_nodes() diff --git a/django_mongodb/compiler.py b/django_mongodb/compiler.py index 73d01beb..ea130567 100644 --- a/django_mongodb/compiler.py +++ b/django_mongodb/compiler.py @@ -436,14 +436,31 @@ def project_field(column): @cached_property def base_table(self): return next( - v - for k, v in self.query.alias_map.items() - if isinstance(v, BaseTable) and self.query.alias_refcount[k] + ( + v + for k, v in self.query.alias_map.items() + if isinstance(v, BaseTable) and self.query.alias_refcount[k] + ), + None, ) @cached_property def collection_name(self): - return self.base_table.table_alias or self.base_table.table_name + if self.base_table: + return self.base_table.table_alias or self.base_table.table_name + # Use a dummy collection if the query doesn't specify a table. + # For Constraint.validate(): + # SELECT 1 WHERE EXISTS(subquery checking if a constraint is violated) + # is translated as: + # [{"$facet": {"__null": []}}, + # {"$lookup": {"the subquery"}}, + # {"$match": {"condition to check from the subquery"}}] + query = self.query_class(self) + # The "__null" document is a placeholder so that the subquery has + # somewhere to return its results. + query.aggregation_pipeline = [{"$facet": {"__null": []}}] + self.subqueries.insert(0, query) + return "__null" @cached_property def collection(self): diff --git a/django_mongodb/expressions.py b/django_mongodb/expressions.py index 626f7711..bf4bf23b 100644 --- a/django_mongodb/expressions.py +++ b/django_mongodb/expressions.py @@ -53,7 +53,9 @@ def case(self, compiler, connection): def col(self, compiler, connection): # noqa: ARG001 # If the column is part of a subquery and belongs to one of the parent # queries, it will be stored for reference using $let in a $lookup stage. - if ( + # If the query is built with `alias_cols=False`, treat the column as + # belonging to the current collection. + if self.alias is not None and ( self.alias not in compiler.query.alias_refcount or compiler.query.alias_refcount[self.alias] == 0 ): @@ -64,7 +66,8 @@ def col(self, compiler, connection): # noqa: ARG001 compiler.column_indices[self] = index return f"$${compiler.PARENT_FIELD_TEMPLATE.format(index)}" # Add the column's collection's alias for columns in joined collections. - prefix = f"{self.alias}." if self.alias != compiler.collection_name else "" + has_alias = self.alias and self.alias != compiler.collection_name + prefix = f"{self.alias}." if has_alias else "" return f"${prefix}{self.target.column}" diff --git a/django_mongodb/features.py b/django_mongodb/features.py index d9cb8b89..1dbdbf79 100644 --- a/django_mongodb/features.py +++ b/django_mongodb/features.py @@ -1,3 +1,5 @@ +import operator + from django.db.backends.base.features import BaseDatabaseFeatures from django.utils.functional import cached_property @@ -21,12 +23,14 @@ class DatabaseFeatures(BaseDatabaseFeatures): supports_expression_indexes = False supports_foreign_keys = False supports_ignore_conflicts = False + # Before MongoDB 6.0, $in cannot be used in partialFilterExpression. + supports_in_index_operator = property(operator.attrgetter("is_mongodb_6_0")) + # Before MongoDB 6.0, $or cannot be used in partialFilterExpression. + supports_or_index_operator = property(operator.attrgetter("is_mongodb_6_0")) supports_json_field_contains = False # BSON Date type doesn't support microsecond precision. supports_microsecond_precision = False supports_paramstyle_pyformat = False - # Not implemented. - supports_partial_indexes = False supports_select_difference = False supports_select_intersection = False supports_sequence_reset = False @@ -91,11 +95,16 @@ class DatabaseFeatures(BaseDatabaseFeatures): "expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor_right_null", "expressions.tests.ExpressionOperatorTests.test_lefthand_transformed_field_bitwise_or", } + _django_test_expected_failures_partial_expression_in = { + "schema.tests.SchemaTests.test_remove_ignored_unique_constraint_not_create_fk_index", + } @cached_property def django_test_expected_failures(self): expected_failures = super().django_test_expected_failures expected_failures.update(self._django_test_expected_failures) + if not self.supports_in_index_operator: + expected_failures.update(self._django_test_expected_failures_partial_expression_in) if not self.is_mongodb_6_3: expected_failures.update(self._django_test_expected_failures_bitwise) return expected_failures @@ -560,9 +569,6 @@ def django_test_expected_failures(self): # Probably something to do with lack of transaction support. "migration_test_data_persistence.tests.MigrationDataNormalPersistenceTestCase.test_persistence", }, - "Partial indexes to be supported.": { - "indexes.tests.PartialIndexConditionIgnoredTests.test_condition_ignored", - }, "Database caching not implemented.": { "cache.tests.CreateCacheTableForDBCacheTests", "cache.tests.DBCacheTests", @@ -597,6 +603,10 @@ def django_test_expected_failures(self): }, } + @cached_property + def is_mongodb_6_0(self): + return self.connection.get_database_version() >= (6, 0) + @cached_property def is_mongodb_6_3(self): return self.connection.get_database_version() >= (6, 3) diff --git a/django_mongodb/indexes.py b/django_mongodb/indexes.py new file mode 100644 index 00000000..0fc6646b --- /dev/null +++ b/django_mongodb/indexes.py @@ -0,0 +1,74 @@ +from django.db import NotSupportedError +from django.db.models import Index +from django.db.models.fields.related_lookups import In +from django.db.models.lookups import BuiltinLookup +from django.db.models.sql.query import Query +from django.db.models.sql.where import AND, XOR, WhereNode + +from .query_utils import process_rhs + +MONGO_INDEX_OPERATORS = { + "exact": "$eq", + "gt": "$gt", + "gte": "$gte", + "lt": "$lt", + "lte": "$lte", + "in": "$in", +} + + +def _get_condition_mql(self, model, schema_editor): + """Analogous to Index._get_condition_sql().""" + query = Query(model=model, alias_cols=False) + where = query.build_where(self.condition) + compiler = query.get_compiler(connection=schema_editor.connection) + return where.as_mql_idx(compiler, schema_editor.connection) + + +def builtin_lookup_idx(self, compiler, connection): + lhs_mql = self.lhs.target.column + value = process_rhs(self, compiler, connection) + try: + operator = MONGO_INDEX_OPERATORS[self.lookup_name] + except KeyError: + raise NotSupportedError( + f"MongoDB does not support the '{self.lookup_name}' lookup in indexes." + ) from None + return {lhs_mql: {operator: value}} + + +def in_idx(self, compiler, connection): + if not connection.features.supports_in_index_operator: + raise NotSupportedError("MongoDB < 6.0 does not support the 'in' lookup in indexes.") + return builtin_lookup_idx(self, compiler, connection) + + +def where_node_idx(self, compiler, connection): + if self.connector == AND: + operator = "$and" + elif self.connector == XOR: + raise NotSupportedError("MongoDB does not support the '^' operator lookup in indexes.") + else: + if not connection.features.supports_in_index_operator: + raise NotSupportedError("MongoDB < 6.0 does not support the '|' operator in indexes.") + operator = "$or" + if self.negated: + raise NotSupportedError("MongoDB does not support the '~' operator in indexes.") + children_mql = [] + for child in self.children: + mql = child.as_mql_idx(compiler, connection) + children_mql.append(mql) + if len(children_mql) == 1: + mql = children_mql[0] + elif len(children_mql) > 1: + mql = {operator: children_mql} + else: + mql = {} + return mql + + +def register_indexes(): + BuiltinLookup.as_mql_idx = builtin_lookup_idx + In.as_mql_idx = in_idx + Index._get_condition_mql = _get_condition_mql + WhereNode.as_mql_idx = where_node_idx diff --git a/django_mongodb/schema.py b/django_mongodb/schema.py index 8fd902ac..8fc18fea 100644 --- a/django_mongodb/schema.py +++ b/django_mongodb/schema.py @@ -1,3 +1,5 @@ +from collections import defaultdict + from django.db.backends.base.schema import BaseDatabaseSchemaEditor from django.db.models import Index, UniqueConstraint from pymongo import ASCENDING, DESCENDING @@ -166,17 +168,23 @@ def add_index(self, model, index, field=None, unique=False): if index.contains_expressions: return kwargs = {} + filter_expression = defaultdict(dict) + if index.condition: + filter_expression.update(index._get_condition_mql(model, self)) if unique: - filter_expression = {} + kwargs["unique"] = True # Indexing on $type matches the value of most SQL databases by # allowing multiple null values for the unique constraint. if field: - filter_expression[field.column] = {"$type": field.db_type(self.connection)} + filter_expression[field.column].update({"$type": field.db_type(self.connection)}) else: for field_name, _ in index.fields_orders: field_ = model._meta.get_field(field_name) - filter_expression[field_.column] = {"$type": field_.db_type(self.connection)} - kwargs = {"partialFilterExpression": filter_expression, "unique": True} + filter_expression[field_.column].update( + {"$type": field_.db_type(self.connection)} + ) + if filter_expression: + kwargs["partialFilterExpression"] = filter_expression index_orders = ( [(field.column, ASCENDING)] if field @@ -260,7 +268,11 @@ def add_constraint(self, model, constraint, field=None): expressions=constraint.expressions, nulls_distinct=constraint.nulls_distinct, ): - idx = Index(fields=constraint.fields, name=constraint.name) + idx = Index( + fields=constraint.fields, + name=constraint.name, + condition=constraint.condition, + ) self.add_index(model, idx, field=field, unique=True) def _add_field_unique(self, model, field): @@ -276,7 +288,11 @@ def remove_constraint(self, model, constraint): expressions=constraint.expressions, nulls_distinct=constraint.nulls_distinct, ): - idx = Index(fields=constraint.fields, name=constraint.name) + idx = Index( + fields=constraint.fields, + name=constraint.name, + condition=constraint.condition, + ) self.remove_index(model, idx) def _remove_field_unique(self, model, field): diff --git a/tests/indexes_/__init__.py b/tests/indexes_/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/indexes_/models.py b/tests/indexes_/models.py new file mode 100644 index 00000000..acfeb581 --- /dev/null +++ b/tests/indexes_/models.py @@ -0,0 +1,7 @@ +from django.db import models + + +class Article(models.Model): + headline = models.CharField(max_length=100) + number = models.IntegerField() + body = models.TextField() diff --git a/tests/indexes_/test_condition.py b/tests/indexes_/test_condition.py new file mode 100644 index 00000000..f0d67b36 --- /dev/null +++ b/tests/indexes_/test_condition.py @@ -0,0 +1,126 @@ +import operator + +from django.db import NotSupportedError, connection +from django.db.models import Index, Q +from django.test import TestCase, skipIfDBFeature, skipUnlessDBFeature + +from .models import Article + + +class PartialIndexTests(TestCase): + def assertAddRemoveIndex(self, editor, model, index): + editor.add_index(index=index, model=model) + self.assertIn( + index.name, + connection.introspection.get_constraints( + cursor=None, + table_name=model._meta.db_table, + ), + ) + editor.remove_index(index=index, model=model) + + def test_not_supported(self): + msg = "MongoDB does not support the 'isnull' lookup in indexes." + with connection.schema_editor() as editor, self.assertRaisesMessage(NotSupportedError, msg): + Index( + name="test", + fields=["headline"], + condition=Q(pk__isnull=True), + )._get_condition_mql(Article, schema_editor=editor) + + def test_negated_not_supported(self): + msg = "MongoDB does not support the '~' operator in indexes." + with self.assertRaisesMessage(NotSupportedError, msg), connection.schema_editor() as editor: + Index( + name="test", + fields=["headline"], + condition=~Q(pk=True), + )._get_condition_mql(Article, schema_editor=editor) + + def test_xor_not_supported(self): + msg = "MongoDB does not support the '^' operator lookup in indexes." + with self.assertRaisesMessage(NotSupportedError, msg), connection.schema_editor() as editor: + Index( + name="test", + fields=["headline"], + condition=Q(pk=True) ^ Q(pk=False), + )._get_condition_mql(Article, schema_editor=editor) + + @skipIfDBFeature("supports_or_index_operator") + def test_or_not_supported(self): + msg = "MongoDB < 6.0 does not support the '|' operator in indexes." + with self.assertRaisesMessage(NotSupportedError, msg), connection.schema_editor() as editor: + Index( + name="test", + fields=["headline"], + condition=Q(pk=True) | Q(pk=False), + )._get_condition_mql(Article, schema_editor=editor) + + @skipIfDBFeature("supports_in_index_operator") + def test_in_not_supported(self): + msg = "MongoDB < 6.0 does not support the 'in' lookup in indexes." + with self.assertRaisesMessage(NotSupportedError, msg), connection.schema_editor() as editor: + Index( + name="test", + fields=["headline"], + condition=Q(pk__in=[True]), + )._get_condition_mql(Article, schema_editor=editor) + + def test_operations(self): + operators = ( + ("gt", "$gt"), + ("gte", "$gte"), + ("lt", "$lt"), + ("lte", "$lte"), + ) + for op, mongo_operator in operators: + with self.subTest(operator=op), connection.schema_editor() as editor: + index = Index( + name="test", + fields=["headline"], + condition=Q(**{f"number__{op}": 3}), + ) + self.assertEqual( + {"number": {mongo_operator: 3}}, + index._get_condition_mql(Article, schema_editor=editor), + ) + self.assertAddRemoveIndex(editor, Article, index) + + @skipUnlessDBFeature("supports_in_index_operator") + def test_composite_index(self): + with connection.schema_editor() as editor: + index = Index( + name="test", + fields=["headline"], + condition=Q(number__gte=3) & (Q(body__gt="test1") | Q(body__in=["A", "B"])), + ) + self.assertEqual( + index._get_condition_mql(Article, schema_editor=editor), + { + "$and": [ + {"number": {"$gte": 3}}, + {"$or": [{"body": {"$gt": "test1"}}, {"body": {"$in": ["A", "B"]}}]}, + ] + }, + ) + self.assertAddRemoveIndex(editor, Article, index) + + def test_composite_op_index(self): + operators = ( + (operator.or_, "$or"), + (operator.and_, "$and"), + ) + if not connection.features.supports_or_index_operator: + operators = operators[1:] + for op, mongo_operator in operators: + with self.subTest(operator=op), connection.schema_editor() as editor: + index = Index( + name="test", + fields=["headline"], + condition=op(Q(number__gte=3), Q(body__gt="test1")), + ) + self.assertEqual( + {mongo_operator: [{"number": {"$gte": 3}}, {"body": {"$gt": "test1"}}]}, + index._get_condition_mql(Article, schema_editor=editor), + ) + self.assertAddRemoveIndex(editor, Article, index)