Skip to content

Commit 7e96d4c

Browse files
committed
fix expressions in array lookups
1 parent 706301f commit 7e96d4c

File tree

3 files changed

+19
-23
lines changed

3 files changed

+19
-23
lines changed

django_mongodb/expressions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,9 @@ def when(self, compiler, connection):
204204

205205
def value(self, compiler, connection): # noqa: ARG001
206206
value = self.value
207+
# TODO: check this, not sure $literal is needed for all cases.
208+
if isinstance(value, datetime.datetime):
209+
return value
207210
if isinstance(value, Decimal):
208211
value = Decimal128(value)
209212
elif isinstance(value, datetime.date):

django_mongodb/features.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,13 +82,10 @@ class DatabaseFeatures(BaseDatabaseFeatures):
8282
"auth_tests.test_views.LoginTest.test_login_session_without_hash_session_key",
8383
# GenericRelation.value_to_string() assumes integer pk.
8484
"contenttypes_tests.test_fields.GenericRelationTests.test_value_to_string",
85-
# contains with expressions/subqueries doesn't work.
86-
"model_fields_.test_arrayfield.TestQuerying.test_contains_including_expression",
85+
# contains with subqueries doesn't work.
8786
"model_fields_.test_arrayfield.TestQuerying.test_contains_subquery",
8887
# Unsupported conversion from array to string in $convert
8988
"model_fields_.test_arrayfield.TestQuerying.test_icontains",
90-
# Field 'field' expected a number but got Value(1).
91-
"model_fields_.test_arrayfield.TestQuerying.test_exact_with_expression",
9289
# $lt treats null values as zero.
9390
"model_fields_.test_arrayfield.TestQuerying.test_lt",
9491
"model_fields_.test_arrayfield.TestQuerying.test_len",

django_mongodb/fields/array.py

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from django.core import checks, exceptions
55
from django.db.models import DecimalField, Field, Func, IntegerField, Transform, Value
66
from django.db.models.fields.mixins import CheckFieldDefaultMixin
7-
from django.db.models.lookups import FieldGetDbPrepValueMixin, In, Lookup
7+
from django.db.models.lookups import Exact, FieldGetDbPrepValueMixin, In, Lookup
88
from django.utils.translation import gettext_lazy as _
99

1010
from django_mongodb.forms import SimpleArrayField
@@ -235,6 +235,11 @@ def formfield(self, **kwargs):
235235
)
236236

237237

238+
class Array(Func):
239+
def as_mql(self, compiler, connection):
240+
return [expr.as_mql(compiler, connection) for expr in self.get_source_expressions()]
241+
242+
238243
class ArrayRHSMixin:
239244
def __init__(self, lhs, rhs):
240245
# Don't wrap arrays that contains only None values, psycopg doesn't
@@ -246,18 +251,9 @@ def __init__(self, lhs, rhs):
246251
field = lhs.output_field
247252
value = Value(field.base_field.get_prep_value(value))
248253
expressions.append(value)
249-
rhs = Func(
250-
*expressions,
251-
function="ARRAY",
252-
template="%(function)s[%(expressions)s]",
253-
)
254+
rhs = Array(*expressions)
254255
super().__init__(lhs, rhs)
255256

256-
def process_rhs(self, compiler, connection):
257-
rhs, rhs_params = super().process_rhs(compiler, connection)
258-
cast_type = self.lhs.output_field.cast_db_type(connection)
259-
return f"{rhs}::{cast_type}", rhs_params
260-
261257
def _rhs_not_none_values(self, rhs):
262258
for x in rhs:
263259
if isinstance(x, list | tuple):
@@ -267,29 +263,29 @@ def _rhs_not_none_values(self, rhs):
267263

268264

269265
@ArrayField.register_lookup
270-
class ArrayContains(FieldGetDbPrepValueMixin, Lookup):
266+
class ArrayContains(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup):
271267
lookup_name = "contains"
272268

273269
def as_mql(self, compiler, connection):
274270
lhs_mql = process_lhs(self, compiler, connection)
275271
value = process_rhs(self, compiler, connection)
276272
return {
277-
"$gt": [
273+
"$eq": [
278274
{
279275
"$cond": {
280276
"if": {"$eq": [lhs_mql, None]},
281-
"then": None,
282-
"else": {"$size": {"$setIntersection": [lhs_mql, value]}},
277+
"then": False,
278+
"else": {"$setIsSubset": [value, lhs_mql]},
283279
}
284280
},
285-
0,
281+
True,
286282
]
287283
}
288284

289285

290-
# @ArrayField.register_lookup
291-
# class ArrayExact(ArrayRHSMixin, Exact):
292-
# pass
286+
@ArrayField.register_lookup
287+
class ArrayExact(ArrayRHSMixin, Exact):
288+
pass
293289

294290

295291
@ArrayField.register_lookup

0 commit comments

Comments
 (0)