From 99623eeb8f60b5d697164118452ab1d511d82215 Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Mon, 14 Oct 2024 19:00:47 -0400 Subject: [PATCH] add support for partial indexes --- .github/workflows/test-python.yml | 1 + django_mongodb/__init__.py | 2 ++ django_mongodb/expressions.py | 3 ++- django_mongodb/features.py | 17 +++++++++++++++-- django_mongodb/indexes.py | 23 +++++++++++++++++++++++ django_mongodb/schema.py | 28 ++++++++++++++++++++++------ 6 files changed, 65 insertions(+), 9 deletions(-) create mode 100644 django_mongodb/indexes.py diff --git a/.github/workflows/test-python.yml b/.github/workflows/test-python.yml index befc7a0f..1f26a3ea 100644 --- a/.github/workflows/test-python.yml +++ b/.github/workflows/test-python.yml @@ -75,6 +75,7 @@ jobs: from_db_value generic_relations generic_relations_regress + indexes introspection known_related_objects lookup diff --git a/django_mongodb/__init__.py b/django_mongodb/__init__.py index 7994999d..31d8f2d3 100644 --- a/django_mongodb/__init__.py +++ b/django_mongodb/__init__.py @@ -10,6 +10,7 @@ from .expressions import register_expressions # noqa: E402 from .fields import register_fields # noqa: E402 from .functions import register_functions # noqa: E402 +from .indexes import register_indexes # noqa: E402 from .lookups import register_lookups # noqa: E402 from .query import register_nodes # noqa: E402 @@ -17,5 +18,6 @@ register_expressions() register_fields() register_functions() +register_indexes() register_lookups() register_nodes() diff --git a/django_mongodb/expressions.py b/django_mongodb/expressions.py index a95e8b1c..854694ce 100644 --- a/django_mongodb/expressions.py +++ b/django_mongodb/expressions.py @@ -64,7 +64,8 @@ def col(self, compiler, connection): # noqa: ARG001 compiler.column_indices[self] = index return f"$${compiler.PARENT_FIELD_TEMPLATE.format(index)}" # Add the column's collection's alias for columns in joined collections. - prefix = f"{self.alias}." if self.alias != compiler.collection_name else "" + has_alias = self.alias and self.alias != compiler.collection_name + prefix = f"{self.alias}." if has_alias else "" return f"${prefix}{self.target.column}" diff --git a/django_mongodb/features.py b/django_mongodb/features.py index b17f3abe..9a06f62d 100644 --- a/django_mongodb/features.py +++ b/django_mongodb/features.py @@ -24,8 +24,6 @@ class DatabaseFeatures(BaseDatabaseFeatures): # BSON Date type doesn't support microsecond precision. supports_microsecond_precision = False supports_paramstyle_pyformat = False - # Not implemented. - supports_partial_indexes = False supports_select_difference = False supports_select_intersection = False supports_sequence_reset = False @@ -58,6 +56,11 @@ class DatabaseFeatures(BaseDatabaseFeatures): "db_functions.datetime.test_extract_trunc.DateFunctionWithTimeZoneTests.test_trunc_timezone_applied_before_truncation", # Length of null considered zero rather than null. "db_functions.text.test_length.LengthTests.test_basic", + # Partial indexes don't support multiple conditions. It requires this + # backend to convert $aggregate MQL syntax to $find (or else just + # generate $find syntax in the first place). + "indexes.tests.PartialIndexTests.test_is_null_condition", + "indexes.tests.PartialIndexTests.test_multiple_conditions", # Unexpected alias_refcount in alias_map. "queries.tests.Queries1Tests.test_order_by_tables", # The $sum aggregation returns 0 instead of None for null. @@ -92,11 +95,17 @@ class DatabaseFeatures(BaseDatabaseFeatures): "expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor_right_null", "expressions.tests.ExpressionOperatorTests.test_lefthand_transformed_field_bitwise_or", } + # Before MongoDB 6.0, $in cannot be used in partialFilterExpression. + _django_test_expected_failures_partial_expression_in = { + "schema.tests.SchemaTests.test_remove_ignored_unique_constraint_not_create_fk_index", + } @cached_property def django_test_expected_failures(self): expected_failures = super().django_test_expected_failures expected_failures.update(self._django_test_expected_failures) + if not self.is_mongodb_6_0: + expected_failures.update(self._django_test_expected_failures_partial_expression_in) if not self.is_mongodb_6_3: expected_failures.update(self._django_test_expected_failures_bitwise) return expected_failures @@ -446,6 +455,10 @@ def django_test_expected_failures(self): }, } + @cached_property + def is_mongodb_6_0(self): + return self.connection.get_database_version() >= (6, 0) + @cached_property def is_mongodb_6_3(self): return self.connection.get_database_version() >= (6, 3) diff --git a/django_mongodb/indexes.py b/django_mongodb/indexes.py new file mode 100644 index 00000000..0e5f2e9b --- /dev/null +++ b/django_mongodb/indexes.py @@ -0,0 +1,23 @@ +from django.db.models import Index +from django.db.models.sql.query import Query + + +def _get_condition_mql(self, model, schema_editor): + """Analogous to Index._get_condition_sql().""" + query = Query(model=model, alias_cols=False) + where = query.build_where(self.condition) + compiler = query.get_compiler(connection=schema_editor.connection) + mql_ = where.as_mql(compiler, schema_editor.connection) + # Transform aggregate() query syntax into find() syntax. + mql = {} + for key in mql_: + col, value = mql_[key] + # multiple conditions don't work yet + if isinstance(col, dict): + return {} + mql[col.lstrip("$")] = {key: value} + return mql + + +def register_indexes(): + Index._get_condition_mql = _get_condition_mql diff --git a/django_mongodb/schema.py b/django_mongodb/schema.py index 8fd902ac..8fc18fea 100644 --- a/django_mongodb/schema.py +++ b/django_mongodb/schema.py @@ -1,3 +1,5 @@ +from collections import defaultdict + from django.db.backends.base.schema import BaseDatabaseSchemaEditor from django.db.models import Index, UniqueConstraint from pymongo import ASCENDING, DESCENDING @@ -166,17 +168,23 @@ def add_index(self, model, index, field=None, unique=False): if index.contains_expressions: return kwargs = {} + filter_expression = defaultdict(dict) + if index.condition: + filter_expression.update(index._get_condition_mql(model, self)) if unique: - filter_expression = {} + kwargs["unique"] = True # Indexing on $type matches the value of most SQL databases by # allowing multiple null values for the unique constraint. if field: - filter_expression[field.column] = {"$type": field.db_type(self.connection)} + filter_expression[field.column].update({"$type": field.db_type(self.connection)}) else: for field_name, _ in index.fields_orders: field_ = model._meta.get_field(field_name) - filter_expression[field_.column] = {"$type": field_.db_type(self.connection)} - kwargs = {"partialFilterExpression": filter_expression, "unique": True} + filter_expression[field_.column].update( + {"$type": field_.db_type(self.connection)} + ) + if filter_expression: + kwargs["partialFilterExpression"] = filter_expression index_orders = ( [(field.column, ASCENDING)] if field @@ -260,7 +268,11 @@ def add_constraint(self, model, constraint, field=None): expressions=constraint.expressions, nulls_distinct=constraint.nulls_distinct, ): - idx = Index(fields=constraint.fields, name=constraint.name) + idx = Index( + fields=constraint.fields, + name=constraint.name, + condition=constraint.condition, + ) self.add_index(model, idx, field=field, unique=True) def _add_field_unique(self, model, field): @@ -276,7 +288,11 @@ def remove_constraint(self, model, constraint): expressions=constraint.expressions, nulls_distinct=constraint.nulls_distinct, ): - idx = Index(fields=constraint.fields, name=constraint.name) + idx = Index( + fields=constraint.fields, + name=constraint.name, + condition=constraint.condition, + ) self.remove_index(model, idx) def _remove_field_unique(self, model, field):