Skip to content

Commit

Permalink
fix expressions in array lookups
Browse files Browse the repository at this point in the history
  • Loading branch information
timgraham committed Dec 28, 2024
1 parent 706301f commit 7e96d4c
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 23 deletions.
3 changes: 3 additions & 0 deletions django_mongodb/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,9 @@ def when(self, compiler, connection):

def value(self, compiler, connection): # noqa: ARG001
value = self.value
# TODO: check this, not sure $literal is needed for all cases.
if isinstance(value, datetime.datetime):
return value
if isinstance(value, Decimal):
value = Decimal128(value)
elif isinstance(value, datetime.date):
Expand Down
5 changes: 1 addition & 4 deletions django_mongodb/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,10 @@ class DatabaseFeatures(BaseDatabaseFeatures):
"auth_tests.test_views.LoginTest.test_login_session_without_hash_session_key",
# GenericRelation.value_to_string() assumes integer pk.
"contenttypes_tests.test_fields.GenericRelationTests.test_value_to_string",
# contains with expressions/subqueries doesn't work.
"model_fields_.test_arrayfield.TestQuerying.test_contains_including_expression",
# contains with subqueries doesn't work.
"model_fields_.test_arrayfield.TestQuerying.test_contains_subquery",
# Unsupported conversion from array to string in $convert
"model_fields_.test_arrayfield.TestQuerying.test_icontains",
# Field 'field' expected a number but got Value(1).
"model_fields_.test_arrayfield.TestQuerying.test_exact_with_expression",
# $lt treats null values as zero.
"model_fields_.test_arrayfield.TestQuerying.test_lt",
"model_fields_.test_arrayfield.TestQuerying.test_len",
Expand Down
34 changes: 15 additions & 19 deletions django_mongodb/fields/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from django.core import checks, exceptions
from django.db.models import DecimalField, Field, Func, IntegerField, Transform, Value
from django.db.models.fields.mixins import CheckFieldDefaultMixin
from django.db.models.lookups import FieldGetDbPrepValueMixin, In, Lookup
from django.db.models.lookups import Exact, FieldGetDbPrepValueMixin, In, Lookup
from django.utils.translation import gettext_lazy as _

from django_mongodb.forms import SimpleArrayField
Expand Down Expand Up @@ -235,6 +235,11 @@ def formfield(self, **kwargs):
)


class Array(Func):
def as_mql(self, compiler, connection):
return [expr.as_mql(compiler, connection) for expr in self.get_source_expressions()]


class ArrayRHSMixin:
def __init__(self, lhs, rhs):
# Don't wrap arrays that contains only None values, psycopg doesn't
Expand All @@ -246,18 +251,9 @@ def __init__(self, lhs, rhs):
field = lhs.output_field
value = Value(field.base_field.get_prep_value(value))
expressions.append(value)
rhs = Func(
*expressions,
function="ARRAY",
template="%(function)s[%(expressions)s]",
)
rhs = Array(*expressions)
super().__init__(lhs, rhs)

def process_rhs(self, compiler, connection):
rhs, rhs_params = super().process_rhs(compiler, connection)
cast_type = self.lhs.output_field.cast_db_type(connection)
return f"{rhs}::{cast_type}", rhs_params

def _rhs_not_none_values(self, rhs):
for x in rhs:
if isinstance(x, list | tuple):
Expand All @@ -267,29 +263,29 @@ def _rhs_not_none_values(self, rhs):


@ArrayField.register_lookup
class ArrayContains(FieldGetDbPrepValueMixin, Lookup):
class ArrayContains(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup):
lookup_name = "contains"

def as_mql(self, compiler, connection):
lhs_mql = process_lhs(self, compiler, connection)
value = process_rhs(self, compiler, connection)
return {
"$gt": [
"$eq": [
{
"$cond": {
"if": {"$eq": [lhs_mql, None]},
"then": None,
"else": {"$size": {"$setIntersection": [lhs_mql, value]}},
"then": False,
"else": {"$setIsSubset": [value, lhs_mql]},
}
},
0,
True,
]
}


# @ArrayField.register_lookup
# class ArrayExact(ArrayRHSMixin, Exact):
# pass
@ArrayField.register_lookup
class ArrayExact(ArrayRHSMixin, Exact):
pass


@ArrayField.register_lookup
Expand Down

0 comments on commit 7e96d4c

Please sign in to comment.