Skip to content

Commit

Permalink
add support for partial indexes
Browse files Browse the repository at this point in the history
Co-authored-by: Emanuel Lupi <[email protected]>
  • Loading branch information
timgraham and WaVEV committed Nov 21, 2024
1 parent c49e345 commit 76ca293
Show file tree
Hide file tree
Showing 10 changed files with 271 additions and 17 deletions.
1 change: 1 addition & 0 deletions .github/workflows/runtests.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
"get_or_create",
"i18n",
"indexes",
"indexes_",
"inline_formsets",
"introspection",
"invalid_models_tests",
Expand Down
2 changes: 2 additions & 0 deletions django_mongodb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
from .expressions import register_expressions # noqa: E402
from .fields import register_fields # noqa: E402
from .functions import register_functions # noqa: E402
from .indexes import register_indexes # noqa: E402
from .lookups import register_lookups # noqa: E402
from .query import register_nodes # noqa: E402

register_aggregates()
register_expressions()
register_fields()
register_functions()
register_indexes()
register_lookups()
register_nodes()
23 changes: 19 additions & 4 deletions django_mongodb/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,14 +436,29 @@ def project_field(column):
@cached_property
def base_table(self):
return next(
v
for k, v in self.query.alias_map.items()
if isinstance(v, BaseTable) and self.query.alias_refcount[k]
(
v
for k, v in self.query.alias_map.items()
if isinstance(v, BaseTable) and self.query.alias_refcount[k]
),
None,
)

@cached_property
def collection_name(self):
return self.base_table.table_alias or self.base_table.table_name
if self.base_table:
return self.base_table.table_alias or self.base_table.table_name
# Use a dummy collection if the query doesn't specify a table.
# For Constraint.validate() with a condition,
# SELECT 1 WHERE EXISTS(subquery checking if a constraint is violated)
# is translated as:
# [{"$facet": {"__null": []}},
# {"$lookup": {"the subquery"}},
# {"$match": {"condition to check from the subquery"}}]
query = self.query_class(self)
query.aggregation_pipeline = [{"$facet": {"__null": []}}]
self.subqueries.insert(0, query)
return "__null"

@cached_property
def collection(self):
Expand Down
7 changes: 5 additions & 2 deletions django_mongodb/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,9 @@ def case(self, compiler, connection):
def col(self, compiler, connection): # noqa: ARG001
# If the column is part of a subquery and belongs to one of the parent
# queries, it will be stored for reference using $let in a $lookup stage.
if (
# If the query is built with `alias_cols=False`, treat the column as
# belonging to the current collection.
if self.alias is not None and (
self.alias not in compiler.query.alias_refcount
or compiler.query.alias_refcount[self.alias] == 0
):
Expand All @@ -64,7 +66,8 @@ def col(self, compiler, connection): # noqa: ARG001
compiler.column_indices[self] = index
return f"$${compiler.PARENT_FIELD_TEMPLATE.format(index)}"
# Add the column's collection's alias for columns in joined collections.
prefix = f"{self.alias}." if self.alias != compiler.collection_name else ""
has_alias = self.alias and self.alias != compiler.collection_name
prefix = f"{self.alias}." if has_alias else ""
return f"${prefix}{self.target.column}"


Expand Down
20 changes: 15 additions & 5 deletions django_mongodb/features.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import operator

from django.db.backends.base.features import BaseDatabaseFeatures
from django.utils.functional import cached_property

Expand All @@ -21,12 +23,14 @@ class DatabaseFeatures(BaseDatabaseFeatures):
supports_expression_indexes = False
supports_foreign_keys = False
supports_ignore_conflicts = False
# Before MongoDB 6.0, $in cannot be used in partialFilterExpression.
supports_in_index_operator = property(operator.attrgetter("is_mongodb_6_0"))
# Before MongoDB 6.0, $or cannot be used in partialFilterExpression.
supports_or_index_operator = property(operator.attrgetter("is_mongodb_6_0"))
supports_json_field_contains = False
# BSON Date type doesn't support microsecond precision.
supports_microsecond_precision = False
supports_paramstyle_pyformat = False
# Not implemented.
supports_partial_indexes = False
supports_select_difference = False
supports_select_intersection = False
supports_sequence_reset = False
Expand Down Expand Up @@ -91,11 +95,16 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor_right_null",
"expressions.tests.ExpressionOperatorTests.test_lefthand_transformed_field_bitwise_or",
}
_django_test_expected_failures_partial_expression_in = {
"schema.tests.SchemaTests.test_remove_ignored_unique_constraint_not_create_fk_index",
}

@cached_property
def django_test_expected_failures(self):
expected_failures = super().django_test_expected_failures
expected_failures.update(self._django_test_expected_failures)
if not self.supports_in_index_operator:
expected_failures.update(self._django_test_expected_failures_partial_expression_in)
if not self.is_mongodb_6_3:
expected_failures.update(self._django_test_expected_failures_bitwise)
return expected_failures
Expand Down Expand Up @@ -560,9 +569,6 @@ def django_test_expected_failures(self):
# Probably something to do with lack of transaction support.
"migration_test_data_persistence.tests.MigrationDataNormalPersistenceTestCase.test_persistence",
},
"Partial indexes to be supported.": {
"indexes.tests.PartialIndexConditionIgnoredTests.test_condition_ignored",
},
"Database caching not implemented.": {
"cache.tests.CreateCacheTableForDBCacheTests",
"cache.tests.DBCacheTests",
Expand Down Expand Up @@ -597,6 +603,10 @@ def django_test_expected_failures(self):
},
}

@cached_property
def is_mongodb_6_0(self):
return self.connection.get_database_version() >= (6, 0)

@cached_property
def is_mongodb_6_3(self):
return self.connection.get_database_version() >= (6, 3)
74 changes: 74 additions & 0 deletions django_mongodb/indexes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
from django.db import NotSupportedError
from django.db.models import Index
from django.db.models.fields.related_lookups import In
from django.db.models.lookups import BuiltinLookup
from django.db.models.sql.query import Query
from django.db.models.sql.where import AND, XOR, WhereNode

from .query_utils import process_rhs

MONGO_INDEX_OPERATORS = {
"exact": "$eq",
"gt": "$gt",
"gte": "$gte",
"lt": "$lt",
"lte": "$lte",
"in": "$in",
}


def _get_condition_mql(self, model, schema_editor):
"""Analogous to Index._get_condition_sql()."""
query = Query(model=model, alias_cols=False)
where = query.build_where(self.condition)
compiler = query.get_compiler(connection=schema_editor.connection)
return where.as_mql_idx(compiler, schema_editor.connection)


def builtin_lookup_idx(self, compiler, connection):
lhs_mql = self.lhs.target.column
value = process_rhs(self, compiler, connection)
try:
operator = MONGO_INDEX_OPERATORS[self.lookup_name]
except KeyError:
raise NotSupportedError(
f"MongoDB does not support the '{self.lookup_name}' lookup in indexes."
) from None
return {lhs_mql: {operator: value}}


def in_idx(self, compiler, connection):
if not connection.features.supports_in_index_operator:
raise NotSupportedError("MongoDB < 6.0 does not support the 'in' lookup in indexes.")
return builtin_lookup_idx(self, compiler, connection)


def where_node_idx(self, compiler, connection):
if self.connector == AND:
operator = "$and"
elif self.connector == XOR:
raise NotSupportedError("MongoDB does not support the '^' operator lookup in indexes.")
else:
if not connection.features.supports_in_index_operator:
raise NotSupportedError("MongoDB < 6.0 does not support the '|' operator in indexes.")
operator = "$or"
if self.negated:
raise NotSupportedError("MongoDB does not support the '~' operator in indexes.")
children_mql = []
for child in self.children:
mql = child.as_mql_idx(compiler, connection)
children_mql.append(mql)
if len(children_mql) == 1:
mql = children_mql[0]
elif len(children_mql) > 1:
mql = {operator: children_mql}
else:
mql = {}
return mql


def register_indexes():
BuiltinLookup.as_mql_idx = builtin_lookup_idx
In.as_mql_idx = in_idx
Index._get_condition_mql = _get_condition_mql
WhereNode.as_mql_idx = where_node_idx
28 changes: 22 additions & 6 deletions django_mongodb/schema.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections import defaultdict

from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.models import Index, UniqueConstraint
from pymongo import ASCENDING, DESCENDING
Expand Down Expand Up @@ -166,17 +168,23 @@ def add_index(self, model, index, field=None, unique=False):
if index.contains_expressions:
return
kwargs = {}
filter_expression = defaultdict(dict)
if index.condition:
filter_expression.update(index._get_condition_mql(model, self))
if unique:
filter_expression = {}
kwargs["unique"] = True
# Indexing on $type matches the value of most SQL databases by
# allowing multiple null values for the unique constraint.
if field:
filter_expression[field.column] = {"$type": field.db_type(self.connection)}
filter_expression[field.column].update({"$type": field.db_type(self.connection)})
else:
for field_name, _ in index.fields_orders:
field_ = model._meta.get_field(field_name)
filter_expression[field_.column] = {"$type": field_.db_type(self.connection)}
kwargs = {"partialFilterExpression": filter_expression, "unique": True}
filter_expression[field_.column].update(
{"$type": field_.db_type(self.connection)}
)
if filter_expression:
kwargs["partialFilterExpression"] = filter_expression
index_orders = (
[(field.column, ASCENDING)]
if field
Expand Down Expand Up @@ -260,7 +268,11 @@ def add_constraint(self, model, constraint, field=None):
expressions=constraint.expressions,
nulls_distinct=constraint.nulls_distinct,
):
idx = Index(fields=constraint.fields, name=constraint.name)
idx = Index(
fields=constraint.fields,
name=constraint.name,
condition=constraint.condition,
)
self.add_index(model, idx, field=field, unique=True)

def _add_field_unique(self, model, field):
Expand All @@ -276,7 +288,11 @@ def remove_constraint(self, model, constraint):
expressions=constraint.expressions,
nulls_distinct=constraint.nulls_distinct,
):
idx = Index(fields=constraint.fields, name=constraint.name)
idx = Index(
fields=constraint.fields,
name=constraint.name,
condition=constraint.condition,
)
self.remove_index(model, idx)

def _remove_field_unique(self, model, field):
Expand Down
Empty file added tests/indexes_/__init__.py
Empty file.
7 changes: 7 additions & 0 deletions tests/indexes_/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from django.db import models


class Article(models.Model):
headline = models.CharField(max_length=100)
number = models.IntegerField()
body = models.TextField()
Loading

0 comments on commit 76ca293

Please sign in to comment.