diff --git a/django_mongodb/expressions.py b/django_mongodb/expressions.py index bf4bf23b..b6268a1c 100644 --- a/django_mongodb/expressions.py +++ b/django_mongodb/expressions.py @@ -204,6 +204,9 @@ def when(self, compiler, connection): def value(self, compiler, connection): # noqa: ARG001 value = self.value + # TODO: check this, not sure $literal is needed for all cases. + if isinstance(value, datetime.datetime): + return value if isinstance(value, Decimal): value = Decimal128(value) elif isinstance(value, datetime.date): diff --git a/django_mongodb/features.py b/django_mongodb/features.py index 0e29b200..c077fe69 100644 --- a/django_mongodb/features.py +++ b/django_mongodb/features.py @@ -82,13 +82,10 @@ class DatabaseFeatures(BaseDatabaseFeatures): "auth_tests.test_views.LoginTest.test_login_session_without_hash_session_key", # GenericRelation.value_to_string() assumes integer pk. "contenttypes_tests.test_fields.GenericRelationTests.test_value_to_string", - # contains with expressions/subqueries doesn't work. - "model_fields_.test_arrayfield.TestQuerying.test_contains_including_expression", + # contains with subqueries doesn't work. "model_fields_.test_arrayfield.TestQuerying.test_contains_subquery", # Unsupported conversion from array to string in $convert "model_fields_.test_arrayfield.TestQuerying.test_icontains", - # Field 'field' expected a number but got Value(1). - "model_fields_.test_arrayfield.TestQuerying.test_exact_with_expression", # $lt treats null values as zero. "model_fields_.test_arrayfield.TestQuerying.test_lt", "model_fields_.test_arrayfield.TestQuerying.test_len", diff --git a/django_mongodb/fields/array.py b/django_mongodb/fields/array.py index 938dadfe..a2c29afc 100644 --- a/django_mongodb/fields/array.py +++ b/django_mongodb/fields/array.py @@ -4,7 +4,7 @@ from django.core import checks, exceptions from django.db.models import DecimalField, Field, Func, IntegerField, Transform, Value from django.db.models.fields.mixins import CheckFieldDefaultMixin -from django.db.models.lookups import FieldGetDbPrepValueMixin, In, Lookup +from django.db.models.lookups import Exact, FieldGetDbPrepValueMixin, In, Lookup from django.utils.translation import gettext_lazy as _ from django_mongodb.forms import SimpleArrayField @@ -235,6 +235,11 @@ def formfield(self, **kwargs): ) +class Array(Func): + def as_mql(self, compiler, connection): + return [expr.as_mql(compiler, connection) for expr in self.get_source_expressions()] + + class ArrayRHSMixin: def __init__(self, lhs, rhs): # Don't wrap arrays that contains only None values, psycopg doesn't @@ -246,18 +251,9 @@ def __init__(self, lhs, rhs): field = lhs.output_field value = Value(field.base_field.get_prep_value(value)) expressions.append(value) - rhs = Func( - *expressions, - function="ARRAY", - template="%(function)s[%(expressions)s]", - ) + rhs = Array(*expressions) super().__init__(lhs, rhs) - def process_rhs(self, compiler, connection): - rhs, rhs_params = super().process_rhs(compiler, connection) - cast_type = self.lhs.output_field.cast_db_type(connection) - return f"{rhs}::{cast_type}", rhs_params - def _rhs_not_none_values(self, rhs): for x in rhs: if isinstance(x, list | tuple): @@ -267,29 +263,29 @@ def _rhs_not_none_values(self, rhs): @ArrayField.register_lookup -class ArrayContains(FieldGetDbPrepValueMixin, Lookup): +class ArrayContains(ArrayRHSMixin, FieldGetDbPrepValueMixin, Lookup): lookup_name = "contains" def as_mql(self, compiler, connection): lhs_mql = process_lhs(self, compiler, connection) value = process_rhs(self, compiler, connection) return { - "$gt": [ + "$eq": [ { "$cond": { "if": {"$eq": [lhs_mql, None]}, - "then": None, - "else": {"$size": {"$setIntersection": [lhs_mql, value]}}, + "then": False, + "else": {"$setIsSubset": [value, lhs_mql]}, } }, - 0, + True, ] } -# @ArrayField.register_lookup -# class ArrayExact(ArrayRHSMixin, Exact): -# pass +@ArrayField.register_lookup +class ArrayExact(ArrayRHSMixin, Exact): + pass @ArrayField.register_lookup