diff --git a/django_mongodb/fields/embedded_model.py b/django_mongodb/fields/embedded_model.py index 64670cf1..6b6fc899 100644 --- a/django_mongodb/fields/embedded_model.py +++ b/django_mongodb/fields/embedded_model.py @@ -2,6 +2,8 @@ from django.db.models.fields.related import lazy_related_operation from django.db.models.lookups import Transform +from .. import forms + class EmbeddedModelField(models.Field): """Field that stores a model instance.""" @@ -123,6 +125,16 @@ def validate(self, value, model_instance): attname = field.attname field.validate(getattr(value, attname), model_instance) + def formfield(self, **kwargs): + return super().formfield( + **{ + "form_class": forms.EmbeddedModelFormField, + "model": self.embedded_model, + "name": self.name, + **kwargs, + } + ) + class KeyTransform(Transform): def __init__(self, key_name, *args, **kwargs): diff --git a/django_mongodb/forms.py b/django_mongodb/forms.py new file mode 100644 index 00000000..0f78f868 --- /dev/null +++ b/django_mongodb/forms.py @@ -0,0 +1,61 @@ +from django import forms +from django.forms.models import modelform_factory +from django.utils.safestring import mark_safe +from django.utils.translation import gettext_lazy as _ + + +class EmbeddedModelWidget(forms.MultiWidget): + def __init__(self, field_names, *args, **kwargs): + self.field_names = field_names + super().__init__(*args, **kwargs) + # The default widget names are "_0", "_1", etc. Use the field names + # instead since that's how they'll be rendered by the model form. + self.widgets_names = ["-" + name for name in field_names] + + def decompress(self, value): + if value is None: + return [] + # Get the data from `value` (a model) for each field. + return [getattr(value, name) for name in self.field_names] + + +class EmbeddedModelBoundField(forms.BoundField): + def __str__(self): + """Render the model form as the representation for this field.""" + form = self.field.model_form_cls(instance=self.value(), **self.field.form_kwargs) + return mark_safe(f"{form.as_div()}") # noqa: S308 + + +class EmbeddedModelFormField(forms.MultiValueField): + default_error_messages = { + "invalid": _("Enter a list of values."), + "incomplete": _("Enter all required values."), + } + + def __init__(self, model, name, *args, **kwargs): + form_kwargs = {} + # The field must be prefixed with the name of the field. + form_kwargs["prefix"] = name + self.form_kwargs = form_kwargs + self.model_form_cls = modelform_factory(model, fields="__all__") + self.model_form = self.model_form_cls(**form_kwargs) + self.field_names = list(self.model_form.fields.keys()) + fields = self.model_form.fields.values() + widgets = [field.widget for field in fields] + widget = EmbeddedModelWidget(self.field_names, widgets) + super().__init__(*args, fields=fields, widget=widget, require_all_fields=False, **kwargs) + + def compress(self, data_dict): + if not data_dict: + return None + values = dict(zip(self.field_names, data_dict, strict=False)) + return self.model_form._meta.model(**values) + + def get_bound_field(self, form, field_name): + return EmbeddedModelBoundField(form, self, field_name) + + def bound_data(self, data, initial): + if self.disabled: + return initial + # The bound data must be transformed into a model instance. + return self.compress(data) diff --git a/tests/model_forms_/__init__.py b/tests/model_forms_/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/model_forms_/forms.py b/tests/model_forms_/forms.py new file mode 100644 index 00000000..7bfed3fb --- /dev/null +++ b/tests/model_forms_/forms.py @@ -0,0 +1,9 @@ +from django import forms + +from .models import Author + + +class AuthorForm(forms.ModelForm): + class Meta: + fields = "__all__" + model = Author diff --git a/tests/model_forms_/models.py b/tests/model_forms_/models.py new file mode 100644 index 00000000..ef169756 --- /dev/null +++ b/tests/model_forms_/models.py @@ -0,0 +1,22 @@ +from django.db import models + +from django_mongodb.fields import EmbeddedModelField + + +class Address(models.Model): + po_box = models.CharField(max_length=50, blank=True, verbose_name="PO Box") + city = models.CharField(max_length=20) + state = models.CharField(max_length=2) + zip_code = models.IntegerField() + + +class Author(models.Model): + name = models.CharField(max_length=10) + age = models.IntegerField() + address = EmbeddedModelField(Address) + billing_address = EmbeddedModelField(Address, blank=True, null=True) + + +class Book(models.Model): + name = models.CharField(max_length=100) + author = EmbeddedModelField(Author) diff --git a/tests/model_forms_/test_embedded_model.py b/tests/model_forms_/test_embedded_model.py new file mode 100644 index 00000000..240f8c6d --- /dev/null +++ b/tests/model_forms_/test_embedded_model.py @@ -0,0 +1,130 @@ +from django.test import TestCase + +from .forms import AuthorForm +from .models import Address, Author + + +class ModelFormTests(TestCase): + def test_update(self): + author = Author.objects.create( + name="Bob", age=50, address=Address(city="NYC", state="NY", zip_code="10001") + ) + data = { + "name": "Bob", + "age": 51, + "address-po_box": "", + "address-city": "New York City", + "address-state": "NY", + "address-zip_code": "10001", + } + form = AuthorForm(data, instance=author) + self.assertTrue(form.is_valid()) + form.save() + author.refresh_from_db() + self.assertEqual(author.age, 51) + self.assertEqual(author.address.city, "New York City") + + def test_some_missing_data(self): + author = Author.objects.create( + name="Bob", age=50, address=Address(city="NYC", state="NY", zip_code="10001") + ) + data = { + "name": "Bob", + "age": 51, + "address-po_box": "", + "address-city": "New York City", + "address-state": "NY", + "address-zip_code": "", + } + form = AuthorForm(data, instance=author) + self.assertFalse(form.is_valid()) + self.assertEqual(form.errors["address"], ["Enter all required values."]) + + def test_invalid_field_data(self): + """A field's data (state) is too long.""" + author = Author.objects.create( + name="Bob", age=50, address=Address(city="NYC", state="NY", zip_code="10001") + ) + data = { + "name": "Bob", + "age": 51, + "address-po_box": "", + "address-city": "New York City", + "address-state": "TOO LONG", + "address-zip_code": "", + } + form = AuthorForm(data, instance=author) + self.assertFalse(form.is_valid()) + self.assertEqual( + form.errors["address"], + [ + "Ensure this value has at most 2 characters (it has 8).", + "Enter all required values.", + ], + ) + + def test_all_missing_data(self): + author = Author.objects.create( + name="Bob", age=50, address=Address(city="NYC", state="NY", zip_code="10001") + ) + data = { + "name": "Bob", + "age": 51, + "address-po_box": "", + "address-city": "", + "address-state": "", + "address-zip_code": "", + } + form = AuthorForm(data, instance=author) + self.assertFalse(form.is_valid()) + self.assertEqual(form.errors["address"], ["This field is required."]) + + def test_nullable_field(self): + """A nullable EmbeddedModelField is removed if all fields are empty.""" + author = Author.objects.create( + name="Bob", + age=50, + address=Address(city="NYC", state="NY", zip_code="10001"), + billing_address=Address(city="NYC", state="NY", zip_code="10001"), + ) + data = { + "name": "Bob", + "age": 51, + "address-po_box": "", + "address-city": "New York City", + "address-state": "NY", + "address-zip_code": "10001", + "billing_address-po_box": "", + "billing_address-city": "", + "billing_address-state": "", + "billing_address-zip_code": "", + } + form = AuthorForm(data, instance=author) + self.assertTrue(form.is_valid()) + form.save() + author.refresh_from_db() + self.assertIsNone(author.billing_address) + + def test_rendering(self): + form = AuthorForm() + self.assertHTMLEqual( + str(form.fields["address"].get_bound_field(form, "address")), + """ +