Skip to content

Commit

Permalink
add support for partial indexes
Browse files Browse the repository at this point in the history
  • Loading branch information
timgraham committed Oct 24, 2024
1 parent 17d8a5d commit 99623ee
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 9 deletions.
1 change: 1 addition & 0 deletions .github/workflows/test-python.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ jobs:
from_db_value
generic_relations
generic_relations_regress
indexes
introspection
known_related_objects
lookup
Expand Down
2 changes: 2 additions & 0 deletions django_mongodb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,14 @@
from .expressions import register_expressions # noqa: E402
from .fields import register_fields # noqa: E402
from .functions import register_functions # noqa: E402
from .indexes import register_indexes # noqa: E402
from .lookups import register_lookups # noqa: E402
from .query import register_nodes # noqa: E402

register_aggregates()
register_expressions()
register_fields()
register_functions()
register_indexes()
register_lookups()
register_nodes()
3 changes: 2 additions & 1 deletion django_mongodb/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def col(self, compiler, connection): # noqa: ARG001
compiler.column_indices[self] = index
return f"$${compiler.PARENT_FIELD_TEMPLATE.format(index)}"
# Add the column's collection's alias for columns in joined collections.
prefix = f"{self.alias}." if self.alias != compiler.collection_name else ""
has_alias = self.alias and self.alias != compiler.collection_name
prefix = f"{self.alias}." if has_alias else ""
return f"${prefix}{self.target.column}"


Expand Down
17 changes: 15 additions & 2 deletions django_mongodb/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ class DatabaseFeatures(BaseDatabaseFeatures):
# BSON Date type doesn't support microsecond precision.
supports_microsecond_precision = False
supports_paramstyle_pyformat = False
# Not implemented.
supports_partial_indexes = False
supports_select_difference = False
supports_select_intersection = False
supports_sequence_reset = False
Expand Down Expand Up @@ -58,6 +56,11 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"db_functions.datetime.test_extract_trunc.DateFunctionWithTimeZoneTests.test_trunc_timezone_applied_before_truncation",
# Length of null considered zero rather than null.
"db_functions.text.test_length.LengthTests.test_basic",
# Partial indexes don't support multiple conditions. It requires this
# backend to convert $aggregate MQL syntax to $find (or else just
# generate $find syntax in the first place).
"indexes.tests.PartialIndexTests.test_is_null_condition",
"indexes.tests.PartialIndexTests.test_multiple_conditions",
# Unexpected alias_refcount in alias_map.
"queries.tests.Queries1Tests.test_order_by_tables",
# The $sum aggregation returns 0 instead of None for null.
Expand Down Expand Up @@ -92,11 +95,17 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"expressions.tests.ExpressionOperatorTests.test_lefthand_bitwise_xor_right_null",
"expressions.tests.ExpressionOperatorTests.test_lefthand_transformed_field_bitwise_or",
}
# Before MongoDB 6.0, $in cannot be used in partialFilterExpression.
_django_test_expected_failures_partial_expression_in = {
"schema.tests.SchemaTests.test_remove_ignored_unique_constraint_not_create_fk_index",
}

@cached_property
def django_test_expected_failures(self):
expected_failures = super().django_test_expected_failures
expected_failures.update(self._django_test_expected_failures)
if not self.is_mongodb_6_0:
expected_failures.update(self._django_test_expected_failures_partial_expression_in)
if not self.is_mongodb_6_3:
expected_failures.update(self._django_test_expected_failures_bitwise)
return expected_failures
Expand Down Expand Up @@ -446,6 +455,10 @@ def django_test_expected_failures(self):
},
}

@cached_property
def is_mongodb_6_0(self):
return self.connection.get_database_version() >= (6, 0)

@cached_property
def is_mongodb_6_3(self):
return self.connection.get_database_version() >= (6, 3)
23 changes: 23 additions & 0 deletions django_mongodb/indexes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from django.db.models import Index
from django.db.models.sql.query import Query


def _get_condition_mql(self, model, schema_editor):
"""Analogous to Index._get_condition_sql()."""
query = Query(model=model, alias_cols=False)
where = query.build_where(self.condition)
compiler = query.get_compiler(connection=schema_editor.connection)
mql_ = where.as_mql(compiler, schema_editor.connection)
# Transform aggregate() query syntax into find() syntax.
mql = {}
for key in mql_:
col, value = mql_[key]
# multiple conditions don't work yet
if isinstance(col, dict):
return {}
mql[col.lstrip("$")] = {key: value}
return mql


def register_indexes():
Index._get_condition_mql = _get_condition_mql
28 changes: 22 additions & 6 deletions django_mongodb/schema.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from collections import defaultdict

from django.db.backends.base.schema import BaseDatabaseSchemaEditor
from django.db.models import Index, UniqueConstraint
from pymongo import ASCENDING, DESCENDING
Expand Down Expand Up @@ -166,17 +168,23 @@ def add_index(self, model, index, field=None, unique=False):
if index.contains_expressions:
return
kwargs = {}
filter_expression = defaultdict(dict)
if index.condition:
filter_expression.update(index._get_condition_mql(model, self))
if unique:
filter_expression = {}
kwargs["unique"] = True
# Indexing on $type matches the value of most SQL databases by
# allowing multiple null values for the unique constraint.
if field:
filter_expression[field.column] = {"$type": field.db_type(self.connection)}
filter_expression[field.column].update({"$type": field.db_type(self.connection)})
else:
for field_name, _ in index.fields_orders:
field_ = model._meta.get_field(field_name)
filter_expression[field_.column] = {"$type": field_.db_type(self.connection)}
kwargs = {"partialFilterExpression": filter_expression, "unique": True}
filter_expression[field_.column].update(
{"$type": field_.db_type(self.connection)}
)
if filter_expression:
kwargs["partialFilterExpression"] = filter_expression
index_orders = (
[(field.column, ASCENDING)]
if field
Expand Down Expand Up @@ -260,7 +268,11 @@ def add_constraint(self, model, constraint, field=None):
expressions=constraint.expressions,
nulls_distinct=constraint.nulls_distinct,
):
idx = Index(fields=constraint.fields, name=constraint.name)
idx = Index(
fields=constraint.fields,
name=constraint.name,
condition=constraint.condition,
)
self.add_index(model, idx, field=field, unique=True)

def _add_field_unique(self, model, field):
Expand All @@ -276,7 +288,11 @@ def remove_constraint(self, model, constraint):
expressions=constraint.expressions,
nulls_distinct=constraint.nulls_distinct,
):
idx = Index(fields=constraint.fields, name=constraint.name)
idx = Index(
fields=constraint.fields,
name=constraint.name,
condition=constraint.condition,
)
self.remove_index(model, idx)

def _remove_field_unique(self, model, field):
Expand Down

0 comments on commit 99623ee

Please sign in to comment.