Skip to content

Commit 99623ee

Browse files
committed
add support for partial indexes
1 parent 17d8a5d commit 99623ee

File tree

6 files changed

+65
-9
lines changed

6 files changed

+65
-9
lines changed

.github/workflows/test-python.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ jobs:
7575
from_db_value
7676
generic_relations
7777
generic_relations_regress
78+
indexes
7879
introspection
7980
known_related_objects
8081
lookup

django_mongodb/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010
from .expressions import register_expressions # noqa: E402
1111
from .fields import register_fields # noqa: E402
1212
from .functions import register_functions # noqa: E402
13+
from .indexes import register_indexes # noqa: E402
1314
from .lookups import register_lookups # noqa: E402
1415
from .query import register_nodes # noqa: E402
1516

1617
register_aggregates()
1718
register_expressions()
1819
register_fields()
1920
register_functions()
21+
register_indexes()
2022
register_lookups()
2123
register_nodes()

django_mongodb/expressions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,8 @@ def col(self, compiler, connection): # noqa: ARG001
6464
compiler.column_indices[self] = index
6565
return f"$${compiler.PARENT_FIELD_TEMPLATE.format(index)}"
6666
# Add the column's collection's alias for columns in joined collections.
67-
prefix = f"{self.alias}." if self.alias != compiler.collection_name else ""
67+
has_alias = self.alias and self.alias != compiler.collection_name
68+
prefix = f"{self.alias}." if has_alias else ""
6869
return f"${prefix}{self.target.column}"
6970

7071

django_mongodb/features.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
2424
# BSON Date type doesn't support microsecond precision.
2525
supports_microsecond_precision = False
2626
supports_paramstyle_pyformat = False
27-
# Not implemented.
28-
supports_partial_indexes = False
2927
supports_select_difference = False
3028
supports_select_intersection = False
3129
supports_sequence_reset = False
@@ -58,6 +56,11 @@ class DatabaseFeatures(BaseDatabaseFeatures):
5856
"db_functions.datetime.test_extract_trunc.DateFunctionWithTimeZoneTests.test_trunc_timezone_applied_before_truncation",
5957
# Length of null considered zero rather than null.
6058
"db_functions.text.test_length.LengthTests.test_basic",
59+
# Partial indexes don't support multiple conditions. It requires this
60+
# backend to convert $aggregate MQL syntax to $find (or else just
61+
# generate $find syntax in the first place).
62+
"indexes.tests.PartialIndexTests.test_is_null_condition",
63+
"indexes.tests.PartialIndexTests.test_multiple_conditions",
6164
# Unexpected alias_refcount in alias_map.
6265
"queries.tests.Queries1Tests.test_order_by_tables",
6366
# The $sum aggregation returns 0 instead of None for null.
@@ -92,11 +95,17 @@ class DatabaseFeatures(BaseDatabaseFeatures):
9295
"expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor_right_null",
9396
"expressions.tests.ExpressionOperatorTests.test_lefthand_transformed_field_bitwise_or",
9497
}
98+
# Before MongoDB 6.0, $in cannot be used in partialFilterExpression.
99+
_django_test_expected_failures_partial_expression_in = {
100+
"schema.tests.SchemaTests.test_remove_ignored_unique_constraint_not_create_fk_index",
101+
}
95102

96103
@cached_property
97104
def django_test_expected_failures(self):
98105
expected_failures = super().django_test_expected_failures
99106
expected_failures.update(self._django_test_expected_failures)
107+
if not self.is_mongodb_6_0:
108+
expected_failures.update(self._django_test_expected_failures_partial_expression_in)
100109
if not self.is_mongodb_6_3:
101110
expected_failures.update(self._django_test_expected_failures_bitwise)
102111
return expected_failures
@@ -446,6 +455,10 @@ def django_test_expected_failures(self):
446455
},
447456
}
448457

458+
@cached_property
459+
def is_mongodb_6_0(self):
460+
return self.connection.get_database_version() >= (6, 0)
461+
449462
@cached_property
450463
def is_mongodb_6_3(self):
451464
return self.connection.get_database_version() >= (6, 3)

django_mongodb/indexes.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from django.db.models import Index
2+
from django.db.models.sql.query import Query
3+
4+
5+
def _get_condition_mql(self, model, schema_editor):
6+
"""Analogous to Index._get_condition_sql()."""
7+
query = Query(model=model, alias_cols=False)
8+
where = query.build_where(self.condition)
9+
compiler = query.get_compiler(connection=schema_editor.connection)
10+
mql_ = where.as_mql(compiler, schema_editor.connection)
11+
# Transform aggregate() query syntax into find() syntax.
12+
mql = {}
13+
for key in mql_:
14+
col, value = mql_[key]
15+
# multiple conditions don't work yet
16+
if isinstance(col, dict):
17+
return {}
18+
mql[col.lstrip("$")] = {key: value}
19+
return mql
20+
21+
22+
def register_indexes():
23+
Index._get_condition_mql = _get_condition_mql

django_mongodb/schema.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from collections import defaultdict
2+
13
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
24
from django.db.models import Index, UniqueConstraint
35
from pymongo import ASCENDING, DESCENDING
@@ -166,17 +168,23 @@ def add_index(self, model, index, field=None, unique=False):
166168
if index.contains_expressions:
167169
return
168170
kwargs = {}
171+
filter_expression = defaultdict(dict)
172+
if index.condition:
173+
filter_expression.update(index._get_condition_mql(model, self))
169174
if unique:
170-
filter_expression = {}
175+
kwargs["unique"] = True
171176
# Indexing on $type matches the value of most SQL databases by
172177
# allowing multiple null values for the unique constraint.
173178
if field:
174-
filter_expression[field.column] = {"$type": field.db_type(self.connection)}
179+
filter_expression[field.column].update({"$type": field.db_type(self.connection)})
175180
else:
176181
for field_name, _ in index.fields_orders:
177182
field_ = model._meta.get_field(field_name)
178-
filter_expression[field_.column] = {"$type": field_.db_type(self.connection)}
179-
kwargs = {"partialFilterExpression": filter_expression, "unique": True}
183+
filter_expression[field_.column].update(
184+
{"$type": field_.db_type(self.connection)}
185+
)
186+
if filter_expression:
187+
kwargs["partialFilterExpression"] = filter_expression
180188
index_orders = (
181189
[(field.column, ASCENDING)]
182190
if field
@@ -260,7 +268,11 @@ def add_constraint(self, model, constraint, field=None):
260268
expressions=constraint.expressions,
261269
nulls_distinct=constraint.nulls_distinct,
262270
):
263-
idx = Index(fields=constraint.fields, name=constraint.name)
271+
idx = Index(
272+
fields=constraint.fields,
273+
name=constraint.name,
274+
condition=constraint.condition,
275+
)
264276
self.add_index(model, idx, field=field, unique=True)
265277

266278
def _add_field_unique(self, model, field):
@@ -276,7 +288,11 @@ def remove_constraint(self, model, constraint):
276288
expressions=constraint.expressions,
277289
nulls_distinct=constraint.nulls_distinct,
278290
):
279-
idx = Index(fields=constraint.fields, name=constraint.name)
291+
idx = Index(
292+
fields=constraint.fields,
293+
name=constraint.name,
294+
condition=constraint.condition,
295+
)
280296
self.remove_index(model, idx)
281297

282298
def _remove_field_unique(self, model, field):

0 commit comments

Comments
 (0)