Skip to content

Commit 76ca293

Browse files
timgrahamWaVEV
andcommitted
add support for partial indexes
Co-authored-by: Emanuel Lupi <[email protected]>
1 parent c49e345 commit 76ca293

File tree

10 files changed

+271
-17
lines changed

10 files changed

+271
-17
lines changed

.github/workflows/runtests.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
"get_or_create",
6767
"i18n",
6868
"indexes",
69+
"indexes_",
6970
"inline_formsets",
7071
"introspection",
7172
"invalid_models_tests",

django_mongodb/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,14 @@
1010
from .expressions import register_expressions # noqa: E402
1111
from .fields import register_fields # noqa: E402
1212
from .functions import register_functions # noqa: E402
13+
from .indexes import register_indexes # noqa: E402
1314
from .lookups import register_lookups # noqa: E402
1415
from .query import register_nodes # noqa: E402
1516

1617
register_aggregates()
1718
register_expressions()
1819
register_fields()
1920
register_functions()
21+
register_indexes()
2022
register_lookups()
2123
register_nodes()

django_mongodb/compiler.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -436,14 +436,29 @@ def project_field(column):
436436
@cached_property
437437
def base_table(self):
438438
return next(
439-
v
440-
for k, v in self.query.alias_map.items()
441-
if isinstance(v, BaseTable) and self.query.alias_refcount[k]
439+
(
440+
v
441+
for k, v in self.query.alias_map.items()
442+
if isinstance(v, BaseTable) and self.query.alias_refcount[k]
443+
),
444+
None,
442445
)
443446

444447
@cached_property
445448
def collection_name(self):
446-
return self.base_table.table_alias or self.base_table.table_name
449+
if self.base_table:
450+
return self.base_table.table_alias or self.base_table.table_name
451+
# Use a dummy collection if the query doesn't specify a table.
452+
# For Constraint.validate() with a condition,
453+
# SELECT 1 WHERE EXISTS(subquery checking if a constraint is violated)
454+
# is translated as:
455+
# [{"$facet": {"__null": []}},
456+
# {"$lookup": {"the subquery"}},
457+
# {"$match": {"condition to check from the subquery"}}]
458+
query = self.query_class(self)
459+
query.aggregation_pipeline = [{"$facet": {"__null": []}}]
460+
self.subqueries.insert(0, query)
461+
return "__null"
447462

448463
@cached_property
449464
def collection(self):

django_mongodb/expressions.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,9 @@ def case(self, compiler, connection):
5353
def col(self, compiler, connection): # noqa: ARG001
5454
# If the column is part of a subquery and belongs to one of the parent
5555
# queries, it will be stored for reference using $let in a $lookup stage.
56-
if (
56+
# If the query is built with `alias_cols=False`, treat the column as
57+
# belonging to the current collection.
58+
if self.alias is not None and (
5759
self.alias not in compiler.query.alias_refcount
5860
or compiler.query.alias_refcount[self.alias] == 0
5961
):
@@ -64,7 +66,8 @@ def col(self, compiler, connection): # noqa: ARG001
6466
compiler.column_indices[self] = index
6567
return f"$${compiler.PARENT_FIELD_TEMPLATE.format(index)}"
6668
# Add the column's collection's alias for columns in joined collections.
67-
prefix = f"{self.alias}." if self.alias != compiler.collection_name else ""
69+
has_alias = self.alias and self.alias != compiler.collection_name
70+
prefix = f"{self.alias}." if has_alias else ""
6871
return f"${prefix}{self.target.column}"
6972

7073

django_mongodb/features.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import operator
2+
13
from django.db.backends.base.features import BaseDatabaseFeatures
24
from django.utils.functional import cached_property
35

@@ -21,12 +23,14 @@ class DatabaseFeatures(BaseDatabaseFeatures):
2123
supports_expression_indexes = False
2224
supports_foreign_keys = False
2325
supports_ignore_conflicts = False
26+
# Before MongoDB 6.0, $in cannot be used in partialFilterExpression.
27+
supports_in_index_operator = property(operator.attrgetter("is_mongodb_6_0"))
28+
# Before MongoDB 6.0, $or cannot be used in partialFilterExpression.
29+
supports_or_index_operator = property(operator.attrgetter("is_mongodb_6_0"))
2430
supports_json_field_contains = False
2531
# BSON Date type doesn't support microsecond precision.
2632
supports_microsecond_precision = False
2733
supports_paramstyle_pyformat = False
28-
# Not implemented.
29-
supports_partial_indexes = False
3034
supports_select_difference = False
3135
supports_select_intersection = False
3236
supports_sequence_reset = False
@@ -91,11 +95,16 @@ class DatabaseFeatures(BaseDatabaseFeatures):
9195
"expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor_right_null",
9296
"expressions.tests.ExpressionOperatorTests.test_lefthand_transformed_field_bitwise_or",
9397
}
98+
_django_test_expected_failures_partial_expression_in = {
99+
"schema.tests.SchemaTests.test_remove_ignored_unique_constraint_not_create_fk_index",
100+
}
94101

95102
@cached_property
96103
def django_test_expected_failures(self):
97104
expected_failures = super().django_test_expected_failures
98105
expected_failures.update(self._django_test_expected_failures)
106+
if not self.supports_in_index_operator:
107+
expected_failures.update(self._django_test_expected_failures_partial_expression_in)
99108
if not self.is_mongodb_6_3:
100109
expected_failures.update(self._django_test_expected_failures_bitwise)
101110
return expected_failures
@@ -560,9 +569,6 @@ def django_test_expected_failures(self):
560569
# Probably something to do with lack of transaction support.
561570
"migration_test_data_persistence.tests.MigrationDataNormalPersistenceTestCase.test_persistence",
562571
},
563-
"Partial indexes to be supported.": {
564-
"indexes.tests.PartialIndexConditionIgnoredTests.test_condition_ignored",
565-
},
566572
"Database caching not implemented.": {
567573
"cache.tests.CreateCacheTableForDBCacheTests",
568574
"cache.tests.DBCacheTests",
@@ -597,6 +603,10 @@ def django_test_expected_failures(self):
597603
},
598604
}
599605

606+
@cached_property
607+
def is_mongodb_6_0(self):
608+
return self.connection.get_database_version() >= (6, 0)
609+
600610
@cached_property
601611
def is_mongodb_6_3(self):
602612
return self.connection.get_database_version() >= (6, 3)

django_mongodb/indexes.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
from django.db import NotSupportedError
2+
from django.db.models import Index
3+
from django.db.models.fields.related_lookups import In
4+
from django.db.models.lookups import BuiltinLookup
5+
from django.db.models.sql.query import Query
6+
from django.db.models.sql.where import AND, XOR, WhereNode
7+
8+
from .query_utils import process_rhs
9+
10+
MONGO_INDEX_OPERATORS = {
11+
"exact": "$eq",
12+
"gt": "$gt",
13+
"gte": "$gte",
14+
"lt": "$lt",
15+
"lte": "$lte",
16+
"in": "$in",
17+
}
18+
19+
20+
def _get_condition_mql(self, model, schema_editor):
21+
"""Analogous to Index._get_condition_sql()."""
22+
query = Query(model=model, alias_cols=False)
23+
where = query.build_where(self.condition)
24+
compiler = query.get_compiler(connection=schema_editor.connection)
25+
return where.as_mql_idx(compiler, schema_editor.connection)
26+
27+
28+
def builtin_lookup_idx(self, compiler, connection):
29+
lhs_mql = self.lhs.target.column
30+
value = process_rhs(self, compiler, connection)
31+
try:
32+
operator = MONGO_INDEX_OPERATORS[self.lookup_name]
33+
except KeyError:
34+
raise NotSupportedError(
35+
f"MongoDB does not support the '{self.lookup_name}' lookup in indexes."
36+
) from None
37+
return {lhs_mql: {operator: value}}
38+
39+
40+
def in_idx(self, compiler, connection):
41+
if not connection.features.supports_in_index_operator:
42+
raise NotSupportedError("MongoDB < 6.0 does not support the 'in' lookup in indexes.")
43+
return builtin_lookup_idx(self, compiler, connection)
44+
45+
46+
def where_node_idx(self, compiler, connection):
47+
if self.connector == AND:
48+
operator = "$and"
49+
elif self.connector == XOR:
50+
raise NotSupportedError("MongoDB does not support the '^' operator lookup in indexes.")
51+
else:
52+
if not connection.features.supports_in_index_operator:
53+
raise NotSupportedError("MongoDB < 6.0 does not support the '|' operator in indexes.")
54+
operator = "$or"
55+
if self.negated:
56+
raise NotSupportedError("MongoDB does not support the '~' operator in indexes.")
57+
children_mql = []
58+
for child in self.children:
59+
mql = child.as_mql_idx(compiler, connection)
60+
children_mql.append(mql)
61+
if len(children_mql) == 1:
62+
mql = children_mql[0]
63+
elif len(children_mql) > 1:
64+
mql = {operator: children_mql}
65+
else:
66+
mql = {}
67+
return mql
68+
69+
70+
def register_indexes():
71+
BuiltinLookup.as_mql_idx = builtin_lookup_idx
72+
In.as_mql_idx = in_idx
73+
Index._get_condition_mql = _get_condition_mql
74+
WhereNode.as_mql_idx = where_node_idx

django_mongodb/schema.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from collections import defaultdict
2+
13
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
24
from django.db.models import Index, UniqueConstraint
35
from pymongo import ASCENDING, DESCENDING
@@ -166,17 +168,23 @@ def add_index(self, model, index, field=None, unique=False):
166168
if index.contains_expressions:
167169
return
168170
kwargs = {}
171+
filter_expression = defaultdict(dict)
172+
if index.condition:
173+
filter_expression.update(index._get_condition_mql(model, self))
169174
if unique:
170-
filter_expression = {}
175+
kwargs["unique"] = True
171176
# Indexing on $type matches the value of most SQL databases by
172177
# allowing multiple null values for the unique constraint.
173178
if field:
174-
filter_expression[field.column] = {"$type": field.db_type(self.connection)}
179+
filter_expression[field.column].update({"$type": field.db_type(self.connection)})
175180
else:
176181
for field_name, _ in index.fields_orders:
177182
field_ = model._meta.get_field(field_name)
178-
filter_expression[field_.column] = {"$type": field_.db_type(self.connection)}
179-
kwargs = {"partialFilterExpression": filter_expression, "unique": True}
183+
filter_expression[field_.column].update(
184+
{"$type": field_.db_type(self.connection)}
185+
)
186+
if filter_expression:
187+
kwargs["partialFilterExpression"] = filter_expression
180188
index_orders = (
181189
[(field.column, ASCENDING)]
182190
if field
@@ -260,7 +268,11 @@ def add_constraint(self, model, constraint, field=None):
260268
expressions=constraint.expressions,
261269
nulls_distinct=constraint.nulls_distinct,
262270
):
263-
idx = Index(fields=constraint.fields, name=constraint.name)
271+
idx = Index(
272+
fields=constraint.fields,
273+
name=constraint.name,
274+
condition=constraint.condition,
275+
)
264276
self.add_index(model, idx, field=field, unique=True)
265277

266278
def _add_field_unique(self, model, field):
@@ -276,7 +288,11 @@ def remove_constraint(self, model, constraint):
276288
expressions=constraint.expressions,
277289
nulls_distinct=constraint.nulls_distinct,
278290
):
279-
idx = Index(fields=constraint.fields, name=constraint.name)
291+
idx = Index(
292+
fields=constraint.fields,
293+
name=constraint.name,
294+
condition=constraint.condition,
295+
)
280296
self.remove_index(model, idx)
281297

282298
def _remove_field_unique(self, model, field):

tests/indexes_/__init__.py

Whitespace-only changes.

tests/indexes_/models.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from django.db import models
2+
3+
4+
class Article(models.Model):
5+
headline = models.CharField(max_length=100)
6+
number = models.IntegerField()
7+
body = models.TextField()

0 commit comments

Comments
 (0)