diff --git a/django_mongodb/fields/array.py b/django_mongodb/fields/array.py index 12cc0bae..883ba051 100644 --- a/django_mongodb/fields/array.py +++ b/django_mongodb/fields/array.py @@ -1,18 +1,17 @@ import json -from django.contrib.postgres.forms import SimpleArrayField from django.contrib.postgres.validators import ArrayMaxLengthValidator from django.core import checks, exceptions -from django.db.models import DecimalField, Field, Func, Transform, Value +from django.db.models import DecimalField, Field, Func, IntegerField, Transform, Value from django.db.models.fields.mixins import CheckFieldDefaultMixin +from django.db.models.lookups import In from django.utils.translation import gettext_lazy as _ -__all__ = ["ArrayField"] +from django_mongodb.forms import SimpleArrayField +from ..utils import prefix_validation_error -from django.core.exceptions import ValidationError -from django.utils.functional import SimpleLazyObject -from django.utils.text import format_lazy +__all__ = ["ArrayField"] class AttributeSetter: @@ -20,32 +19,6 @@ def __init__(self, name, value): setattr(self, name, value) -def prefix_validation_error(error, prefix, code, params): - """ - Prefix a validation error message while maintaining the existing - validation data structure. - """ - if error.error_list == [error]: - error_params = error.params or {} - return ValidationError( - # We can't simply concatenate messages since they might require - # their associated parameters to be expressed correctly which - # is not something `format_lazy` does. For example, proxied - # ngettext calls require a count parameter and are converted - # to an empty string if they are missing it. - message=format_lazy( - "{} {}", - SimpleLazyObject(lambda: prefix % params), - SimpleLazyObject(lambda: error.message % error_params), - ), - code=code, - params={**error_params, **params}, - ) - return ValidationError( - [prefix_validation_error(e, prefix, code, params) for e in error.error_list] - ) - - class ArrayField(CheckFieldDefaultMixin, Field): empty_strings_allowed = False default_error_messages = { @@ -293,55 +266,44 @@ def _rhs_not_none_values(self, rhs): yield True -# @ArrayField.register_lookup -# class ArrayContains(ArrayRHSMixin, lookups.DataContains): -# pass - - -# @ArrayField.register_lookup -# class ArrayContainedBy(ArrayRHSMixin, lookups.ContainedBy): -# pass - - # @ArrayField.register_lookup # class ArrayExact(ArrayRHSMixin, Exact): -# pass +# pass -# @ArrayField.register_lookup -# class ArrayOverlap(ArrayRHSMixin, lookups.Overlap): -# pass +@ArrayField.register_lookup +class ArrayLenTransform(Transform): + lookup_name = "len" + output_field = IntegerField() - -# @ArrayField.register_lookup -# class ArrayLenTransform(Transform): -# lookup_name = "len" -# output_field = IntegerField() - -# def as_sql(self, compiler, connection): -# lhs, params = compiler.compile(self.lhs) -# # Distinguish NULL and empty arrays -# return ( -# "CASE WHEN %(lhs)s IS NULL THEN NULL ELSE " -# "coalesce(array_length(%(lhs)s, 1), 0) END" -# ) % {"lhs": lhs}, params * 2 + def as_sql(self, compiler, connection): + lhs, params = compiler.compile(self.lhs) + # Distinguish NULL and empty arrays + return ( + ( + "" # "CASE WHEN %(lhs)s IS NULL THEN NULL ELSE " + # "coalesce(array_length(%(lhs)s, 1), 0) END" + ) + % {}, + params * 2, + ) -# @ArrayField.register_lookup -# class ArrayInLookup(In): -# def get_prep_lookup(self): -# values = super().get_prep_lookup() -# if hasattr(values, "resolve_expression"): -# return values -# # In.process_rhs() expects values to be hashable, so convert lists -# # to tuples. -# prepared_values = [] -# for value in values: -# if hasattr(value, "resolve_expression"): -# prepared_values.append(value) -# else: -# prepared_values.append(tuple(value)) -# return prepared_values +@ArrayField.register_lookup +class ArrayInLookup(In): + def get_prep_lookup(self): + values = super().get_prep_lookup() + if hasattr(values, "resolve_expression"): + return values + # In.process_rhs() expects values to be hashable, so convert lists + # to tuples. + prepared_values = [] + for value in values: + if hasattr(value, "resolve_expression"): + prepared_values.append(value) + else: + prepared_values.append(tuple(value)) + return prepared_values class IndexTransform(Transform): @@ -388,6 +350,5 @@ def __init__(self, start, end): self.start = start self.end = end - -# def __call__(self, *args, **kwargs): -# return SliceTransform(self.start, self.end, *args, **kwargs) + def __call__(self, *args, **kwargs): + return SliceTransform(self.start, self.end, *args, **kwargs) diff --git a/django_mongodb/forms/__init__.py b/django_mongodb/forms/__init__.py new file mode 100644 index 00000000..d7ab09b9 --- /dev/null +++ b/django_mongodb/forms/__init__.py @@ -0,0 +1 @@ +from .array import * # NOQA: F403 diff --git a/django_mongodb/forms/array.py b/django_mongodb/forms/array.py new file mode 100644 index 00000000..cd45dff4 --- /dev/null +++ b/django_mongodb/forms/array.py @@ -0,0 +1,245 @@ +import copy +from itertools import chain + +from django import forms +from django.contrib.postgres.validators import ( + ArrayMaxLengthValidator, + ArrayMinLengthValidator, +) +from django.core.exceptions import ValidationError +from django.utils.translation import gettext_lazy as _ + +from ..utils import prefix_validation_error + + +class SimpleArrayField(forms.CharField): + default_error_messages = { + "item_invalid": _("Item %(nth)s in the array did not validate:"), + } + + def __init__(self, base_field, *, delimiter=",", max_length=None, min_length=None, **kwargs): + self.base_field = base_field + self.delimiter = delimiter + super().__init__(**kwargs) + if min_length is not None: + self.min_length = min_length + self.validators.append(ArrayMinLengthValidator(int(min_length))) + if max_length is not None: + self.max_length = max_length + self.validators.append(ArrayMaxLengthValidator(int(max_length))) + + def clean(self, value): + value = super().clean(value) + return [self.base_field.clean(val) for val in value] + + def prepare_value(self, value): + if isinstance(value, list): + return self.delimiter.join(str(self.base_field.prepare_value(v)) for v in value) + return value + + def to_python(self, value): + if isinstance(value, list): + items = value + elif value: + items = value.split(self.delimiter) + else: + items = [] + errors = [] + values = [] + for index, item in enumerate(items): + try: + values.append(self.base_field.to_python(item)) + except ValidationError as error: + errors.append( + prefix_validation_error( + error, + prefix=self.error_messages["item_invalid"], + code="item_invalid", + params={"nth": index + 1}, + ) + ) + if errors: + raise ValidationError(errors) + return values + + def validate(self, value): + super().validate(value) + errors = [] + for index, item in enumerate(value): + try: + self.base_field.validate(item) + except ValidationError as error: + errors.append( + prefix_validation_error( + error, + prefix=self.error_messages["item_invalid"], + code="item_invalid", + params={"nth": index + 1}, + ) + ) + if errors: + raise ValidationError(errors) + + def run_validators(self, value): + super().run_validators(value) + errors = [] + for index, item in enumerate(value): + try: + self.base_field.run_validators(item) + except ValidationError as error: + errors.append( + prefix_validation_error( + error, + prefix=self.error_messages["item_invalid"], + code="item_invalid", + params={"nth": index + 1}, + ) + ) + if errors: + raise ValidationError(errors) + + def has_changed(self, initial, data): + try: + value = self.to_python(data) + except ValidationError: + pass + else: + if initial in self.empty_values and value in self.empty_values: + return False + return super().has_changed(initial, data) + + +class SplitArrayWidget(forms.Widget): + template_name = "mongodb/widgets/split_array.html" + + def __init__(self, widget, size, **kwargs): + self.widget = widget() if isinstance(widget, type) else widget + self.size = size + super().__init__(**kwargs) + + @property + def is_hidden(self): + return self.widget.is_hidden + + def value_from_datadict(self, data, files, name): + return [ + self.widget.value_from_datadict(data, files, f"{name}_{index}") + for index in range(self.size) + ] + + def value_omitted_from_data(self, data, files, name): + return all( + self.widget.value_omitted_from_data(data, files, f"{name}_{index}") + for index in range(self.size) + ) + + def id_for_label(self, id_): + # See the comment for RadioSelect.id_for_label() + if id_: + id_ += "_0" + return id_ + + def get_context(self, name, value, attrs=None): + attrs = {} if attrs is None else attrs + context = super().get_context(name, value, attrs) + if self.is_localized: + self.widget.is_localized = self.is_localized + value = value or [] + context["widget"]["subwidgets"] = [] + final_attrs = self.build_attrs(attrs) + id_ = final_attrs.get("id") + for i in range(max(len(value), self.size)): + try: + widget_value = value[i] + except IndexError: + widget_value = None + if id_: + final_attrs = {**final_attrs, "id": f"{id_}_{i}"} + context["widget"]["subwidgets"].append( + self.widget.get_context(name + "_%s" % i, widget_value, final_attrs)["widget"] + ) + return context + + @property + def media(self): + return self.widget.media + + def __deepcopy__(self, memo): + obj = super().__deepcopy__(memo) + obj.widget = copy.deepcopy(self.widget) + return obj + + @property + def needs_multipart_form(self): + return self.widget.needs_multipart_form + + +class SplitArrayField(forms.Field): + default_error_messages = { + "item_invalid": _("Item %(nth)s in the array did not validate:"), + } + + def __init__(self, base_field, size, *, remove_trailing_nulls=False, **kwargs): + self.base_field = base_field + self.size = size + self.remove_trailing_nulls = remove_trailing_nulls + widget = SplitArrayWidget(widget=base_field.widget, size=size) + kwargs.setdefault("widget", widget) + super().__init__(**kwargs) + + def _remove_trailing_nulls(self, values): + index = None + if self.remove_trailing_nulls: + for i, value in reversed(list(enumerate(values))): + if value in self.base_field.empty_values: + index = i + else: + break + if index is not None: + values = values[:index] + return values, index + + def to_python(self, value): + value = super().to_python(value) + return [self.base_field.to_python(item) for item in value] + + def clean(self, value): + cleaned_data = [] + errors = [] + if not any(value) and self.required: + raise ValidationError(self.error_messages["required"]) + max_size = max(self.size, len(value)) + for index in range(max_size): + item = value[index] + try: + cleaned_data.append(self.base_field.clean(item)) + except ValidationError as error: + errors.append( + prefix_validation_error( + error, + self.error_messages["item_invalid"], + code="item_invalid", + params={"nth": index + 1}, + ) + ) + cleaned_data.append(None) + else: + errors.append(None) + cleaned_data, null_index = self._remove_trailing_nulls(cleaned_data) + if null_index is not None: + errors = errors[:null_index] + errors = list(filter(None, errors)) + if errors: + raise ValidationError(list(chain.from_iterable(errors))) + return cleaned_data + + def has_changed(self, initial, data): + try: + data = self.to_python(data) + except ValidationError: + pass + else: + data, _ = self._remove_trailing_nulls(data) + if initial in self.empty_values and data in self.empty_values: + return False + return super().has_changed(initial, data) diff --git a/django_mongodb/jinja2/mongodb/widgets/split_array.html b/django_mongodb/jinja2/mongodb/widgets/split_array.html new file mode 100644 index 00000000..32fda826 --- /dev/null +++ b/django_mongodb/jinja2/mongodb/widgets/split_array.html @@ -0,0 +1 @@ +{% include 'django/forms/widgets/multiwidget.html' %} diff --git a/django_mongodb/templates/mongodb/widgets/split_array.html b/django_mongodb/templates/mongodb/widgets/split_array.html new file mode 100644 index 00000000..32fda826 --- /dev/null +++ b/django_mongodb/templates/mongodb/widgets/split_array.html @@ -0,0 +1 @@ +{% include 'django/forms/widgets/multiwidget.html' %} diff --git a/django_mongodb/utils.py b/django_mongodb/utils.py index b4d87cc7..e862211e 100644 --- a/django_mongodb/utils.py +++ b/django_mongodb/utils.py @@ -3,8 +3,10 @@ import django from django.conf import settings -from django.core.exceptions import ImproperlyConfigured +from django.core.exceptions import ImproperlyConfigured, ValidationError from django.db.backends.utils import logger +from django.utils.functional import SimpleLazyObject +from django.utils.text import format_lazy from django.utils.version import get_version_tuple @@ -25,6 +27,32 @@ def check_django_compatability(): ) +def prefix_validation_error(error, prefix, code, params): + """ + Prefix a validation error message while maintaining the existing + validation data structure. + """ + if error.error_list == [error]: + error_params = error.params or {} + return ValidationError( + # We can't simply concatenate messages since they might require + # their associated parameters to be expressed correctly which + # is not something `format_lazy` does. For example, proxied + # ngettext calls require a count parameter and are converted + # to an empty string if they are missing it. + message=format_lazy( + "{} {}", + SimpleLazyObject(lambda: prefix % params), + SimpleLazyObject(lambda: error.message % error_params), + ), + code=code, + params={**error_params, **params}, + ) + return ValidationError( + [prefix_validation_error(e, prefix, code, params) for e in error.error_list] + ) + + def set_wrapped_methods(cls): """Initialize the wrapped methods on cls.""" if hasattr(cls, "logging_wrapper"): diff --git a/tests/model_fields_/array_default_migrations/0001_initial.py b/tests/model_fields_/array_default_migrations/0001_initial.py new file mode 100644 index 00000000..ae7aa89b --- /dev/null +++ b/tests/model_fields_/array_default_migrations/0001_initial.py @@ -0,0 +1,29 @@ +import django.contrib.postgres.fields +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [] + + operations = [ + migrations.CreateModel( + name="IntegerArrayDefaultModel", + fields=[ + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ( + "field", + django.contrib.postgres.fields.ArrayField(models.IntegerField(), size=None), + ), + ], + options={}, + bases=(models.Model,), + ), + ] diff --git a/tests/model_fields_/array_default_migrations/0002_integerarraymodel_field_2.py b/tests/model_fields_/array_default_migrations/0002_integerarraymodel_field_2.py new file mode 100644 index 00000000..ab1f06b5 --- /dev/null +++ b/tests/model_fields_/array_default_migrations/0002_integerarraymodel_field_2.py @@ -0,0 +1,19 @@ +import django.contrib.postgres.fields +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [ + ("postgres_tests", "0001_initial"), + ] + + operations = [ + migrations.AddField( + model_name="integerarraydefaultmodel", + name="field_2", + field=django.contrib.postgres.fields.ArrayField( + models.IntegerField(), default=[], size=None + ), + preserve_default=False, + ), + ] diff --git a/tests/model_fields_/array_default_migrations/__init__.py b/tests/model_fields_/array_default_migrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/model_fields_/array_index_migrations/0001_initial.py b/tests/model_fields_/array_index_migrations/0001_initial.py new file mode 100644 index 00000000..84667cbb --- /dev/null +++ b/tests/model_fields_/array_index_migrations/0001_initial.py @@ -0,0 +1,36 @@ +import django.contrib.postgres.fields +from django.db import migrations, models + + +class Migration(migrations.Migration): + dependencies = [] + + operations = [ + migrations.CreateModel( + name="CharTextArrayIndexModel", + fields=[ + ( + "id", + models.AutoField( + verbose_name="ID", + serialize=False, + auto_created=True, + primary_key=True, + ), + ), + ( + "char", + django.contrib.postgres.fields.ArrayField( + models.CharField(max_length=10), db_index=True, size=100 + ), + ), + ("char2", models.CharField(max_length=11, db_index=True)), + ( + "text", + django.contrib.postgres.fields.ArrayField(models.TextField(), db_index=True), + ), + ], + options={}, + bases=(models.Model,), + ), + ] diff --git a/tests/model_fields_/array_index_migrations/__init__.py b/tests/model_fields_/array_index_migrations/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/model_fields_/test_arrayfield.py b/tests/model_fields_/test_arrayfield.py index b000a212..0af9518a 100644 --- a/tests/model_fields_/test_arrayfield.py +++ b/tests/model_fields_/test_arrayfield.py @@ -10,17 +10,19 @@ from django.core.exceptions import FieldError from django.core.management import call_command from django.db import IntegrityError, connection, models -from django.db.models.expressions import Exists, OuterRef, RawSQL, Value -from django.db.models.functions import Cast, JSONObject, Upper -from django.test import ( # , PostgreSQLWidgetTestCase +from django.db.models.expressions import Exists, OuterRef, Value +from django.test import ( SimpleTestCase, TestCase, TransactionTestCase, override_settings, - skipUnlessDBFeature, ) -from django.test.utils import isolate_apps +from django.test.utils import isolate_apps, modify_settings from django.utils import timezone +from forms_tests.widget_tests.base import WidgetTest + +from django_mongodb.fields import ArrayField +from django_mongodb.forms import SimpleArrayField, SplitArrayField, SplitArrayWidget from .models import ( ArrayEnumModel, @@ -34,19 +36,6 @@ Tag, ) -try: - from django.contrib.postgres.aggregates import ArrayAgg - from django.contrib.postgres.expressions import ArraySubquery - from django.contrib.postgres.fields import ArrayField - from django.contrib.postgres.fields.array import IndexTransform, SliceTransform - from django.contrib.postgres.forms import ( - SimpleArrayField, - SplitArrayField, - SplitArrayWidget, - ) -except ImportError: - pass - @isolate_apps("model_fields_") class BasicTests(SimpleTestCase): @@ -318,18 +307,6 @@ def test_in_as_F_object(self): self.objs[:4], ) - def test_contained_by(self): - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__contained_by=[1, 2]), - self.objs[:2], - ) - - def test_contained_by_including_F_object(self): - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__contained_by=[models.F("order"), 2]), - self.objs[:3], - ) - def test_contains(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__contains=[2]), @@ -363,90 +340,8 @@ def test_icontains(self): self.assertSequenceEqual(CharArrayModel.objects.filter(field__icontains="foo"), [instance]) def test_contains_charfield(self): - # Regression for #22907 self.assertSequenceEqual(CharArrayModel.objects.filter(field__contains=["text"]), []) - def test_contained_by_charfield(self): - self.assertSequenceEqual(CharArrayModel.objects.filter(field__contained_by=["text"]), []) - - def test_overlap_charfield(self): - self.assertSequenceEqual(CharArrayModel.objects.filter(field__overlap=["text"]), []) - - def test_overlap_charfield_including_expression(self): - obj_1 = CharArrayModel.objects.create(field=["TEXT", "lower text"]) - obj_2 = CharArrayModel.objects.create(field=["lower text", "TEXT"]) - CharArrayModel.objects.create(field=["lower text", "text"]) - self.assertSequenceEqual( - CharArrayModel.objects.filter( - field__overlap=[ - Upper(Value("text")), - "other", - ] - ), - [obj_1, obj_2], - ) - - def test_overlap_values(self): - qs = NullableIntegerArrayModel.objects.filter(order__lt=3) - self.assertCountEqual( - NullableIntegerArrayModel.objects.filter( - field__overlap=qs.values_list("field"), - ), - self.objs[:3], - ) - self.assertCountEqual( - NullableIntegerArrayModel.objects.filter( - field__overlap=qs.values("field"), - ), - self.objs[:3], - ) - - def test_lookups_autofield_array(self): - qs = ( - NullableIntegerArrayModel.objects.filter( - field__0__isnull=False, - ) - .values("field__0") - .annotate( - arrayagg=ArrayAgg("id"), - ) - .order_by("field__0") - ) - tests = ( - ("contained_by", [self.objs[1].pk, self.objs[2].pk, 0], [2]), - ("contains", [self.objs[2].pk], [2]), - ("exact", [self.objs[3].pk], [20]), - ("overlap", [self.objs[1].pk, self.objs[3].pk], [2, 20]), - ) - for lookup, value, expected in tests: - with self.subTest(lookup=lookup): - self.assertSequenceEqual( - qs.filter( - **{"arrayagg__" + lookup: value}, - ).values_list("field__0", flat=True), - expected, - ) - - @skipUnlessDBFeature("allows_group_by_select_index") - def test_group_by_order_by_select_index(self): - with self.assertNumQueries(1) as ctx: - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter( - field__0__isnull=False, - ) - .values("field__0") - .annotate(arrayagg=ArrayAgg("id")) - .order_by("field__0"), - [ - {"field__0": 1, "arrayagg": [self.objs[0].pk]}, - {"field__0": 2, "arrayagg": [self.objs[1].pk, self.objs[2].pk]}, - {"field__0": 20, "arrayagg": [self.objs[3].pk]}, - ], - ) - sql = ctx[0]["sql"] - self.assertIn("GROUP BY 2", sql) - self.assertIn("ORDER BY 2", sql) - def test_index(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__0=2), self.objs[1:3] @@ -468,18 +363,6 @@ def test_index_used_on_nested_data(self): NestedIntegerArrayModel.objects.filter(field__0=[1, 2]), [instance] ) - def test_index_transform_expression(self): - expr = RawSQL("string_to_array(%s, ';')", ["1;2"]) - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter( - field__0=Cast( - IndexTransform(1, models.IntegerField, expr), - output_field=models.IntegerField(), - ), - ), - self.objs[:1], - ) - def test_index_annotation(self): qs = NullableIntegerArrayModel.objects.annotate(second=models.F("field__1")) self.assertCountEqual( @@ -487,12 +370,6 @@ def test_index_annotation(self): [None, None, None, 3, 30], ) - def test_overlap(self): - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__overlap=[1, 2]), - self.objs[0:3], - ) - def test_len(self): self.assertSequenceEqual( NullableIntegerArrayModel.objects.filter(field__len__lte=2), self.objs[0:3] @@ -540,13 +417,6 @@ def test_slice_nested(self): NestedIntegerArrayModel.objects.filter(field__0__0_1=[1]), [instance] ) - def test_slice_transform_expression(self): - expr = RawSQL("string_to_array(%s, ';')", ["9;2;3"]) - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.filter(field__0_2=SliceTransform(2, 3, expr)), - self.objs[2:3], - ) - def test_slice_annotation(self): qs = NullableIntegerArrayModel.objects.annotate( first_two=models.F("field__0_2"), @@ -602,74 +472,6 @@ def test_grouping_by_annotations_with_array_field_param(self): 1, ) - def test_filter_by_array_subquery(self): - inner_qs = NullableIntegerArrayModel.objects.filter( - field__len=models.OuterRef("field__len"), - ).values("field") - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.alias( - same_sized_fields=ArraySubquery(inner_qs), - ).filter(same_sized_fields__len__gt=1), - self.objs[0:2], - ) - - def test_annotated_array_subquery(self): - inner_qs = NullableIntegerArrayModel.objects.exclude(pk=models.OuterRef("pk")).values( - "order" - ) - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.annotate( - sibling_ids=ArraySubquery(inner_qs), - ) - .get(order=1) - .sibling_ids, - [2, 3, 4, 5], - ) - - def test_group_by_with_annotated_array_subquery(self): - inner_qs = NullableIntegerArrayModel.objects.exclude(pk=models.OuterRef("pk")).values( - "order" - ) - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.annotate( - sibling_ids=ArraySubquery(inner_qs), - sibling_count=models.Max("sibling_ids__len"), - ).values_list("sibling_count", flat=True), - [len(self.objs) - 1] * len(self.objs), - ) - - def test_annotated_ordered_array_subquery(self): - inner_qs = NullableIntegerArrayModel.objects.order_by("-order").values("order") - self.assertSequenceEqual( - NullableIntegerArrayModel.objects.annotate( - ids=ArraySubquery(inner_qs), - ) - .first() - .ids, - [5, 4, 3, 2, 1], - ) - - def test_annotated_array_subquery_with_json_objects(self): - inner_qs = NullableIntegerArrayModel.objects.exclude(pk=models.OuterRef("pk")).values( - json=JSONObject(order="order", field="field") - ) - siblings_json = ( - NullableIntegerArrayModel.objects.annotate( - siblings_json=ArraySubquery(inner_qs), - ) - .values_list("siblings_json", flat=True) - .get(order=1) - ) - self.assertSequenceEqual( - siblings_json, - [ - {"field": [2], "order": 2}, - {"field": [2, 3], "order": 3}, - {"field": [20, 30, 40], "order": 4}, - {"field": None, "order": 5}, - ], - ) - class TestDateTimeExactQuerying(TestCase): @classmethod @@ -840,7 +642,7 @@ def test_deconstruct_args(self): def test_subclass_deconstruct(self): field = ArrayField(models.IntegerField()) name, path, args, kwargs = field.deconstruct() - self.assertEqual(path, "django.contrib.postgres.fields.ArrayField") + self.assertEqual(path, "django_mongodb.fields.ArrayField") field = ArrayFieldSubclass() name, path, args, kwargs = field.deconstruct() @@ -1138,6 +940,8 @@ def test_has_changed_empty(self): self.assertIs(field.has_changed([], ""), False) +# To locate the widget's template. +@modify_settings(INSTALLED_APPS={"append": "django_mongodb"}) class TestSplitFormField(SimpleTestCase): def test_valid(self): class SplitForm(forms.Form): @@ -1296,7 +1100,9 @@ class Meta: self.assertIs(form.has_changed(), expected_result) -class TestSplitFormWidget(SimpleTestCase): +# To locate the widget's template. +@modify_settings(INSTALLED_APPS={"append": "django_mongodb"}) +class TestSplitFormWidget(WidgetTest, SimpleTestCase): def test_get_context(self): self.assertEqual( SplitArrayWidget(forms.TextInput(), size=2).get_context("name", ["val1", "val2"]), @@ -1307,7 +1113,7 @@ def test_get_context(self): "required": False, "value": "['val1', 'val2']", "attrs": {}, - "template_name": "postgres/widgets/split_array.html", + "template_name": "mongodb/widgets/split_array.html", "subwidgets": [ { "name": "name_0",