From 1df75fcd6dce97b13b9eb2a01422606379c91055 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sat, 7 Dec 2024 20:07:29 -0500 Subject: [PATCH 1/2] add ObjectIdField Co-authored-by: Tim Graham --- .github/workflows/runtests.py | 1 + django_mongodb/fields/__init__.py | 3 +- django_mongodb/fields/auto.py | 14 +-- django_mongodb/fields/objectid.py | 54 +++++++++ django_mongodb/forms/__init__.py | 3 + django_mongodb/forms/fields.py | 27 +++++ docs/source/fields.rst | 13 +++ docs/source/forms.rst | 13 +++ docs/source/index.rst | 2 + tests/forms_tests_/__init__.py | 0 tests/forms_tests_/test_objectidfield.py | 33 ++++++ tests/model_fields_/models.py | 15 +++ tests/model_fields_/test_objectidfield.py | 130 ++++++++++++++++++++++ tests/queries_/models.py | 39 +++++++ tests/queries_/test_objectid.py | 111 ++++++++++++++++++ 15 files changed, 445 insertions(+), 13 deletions(-) create mode 100644 django_mongodb/fields/objectid.py create mode 100644 django_mongodb/forms/__init__.py create mode 100644 django_mongodb/forms/fields.py create mode 100644 docs/source/fields.rst create mode 100644 docs/source/forms.rst create mode 100644 tests/forms_tests_/__init__.py create mode 100644 tests/forms_tests_/test_objectidfield.py create mode 100644 tests/model_fields_/models.py create mode 100644 tests/model_fields_/test_objectidfield.py create mode 100644 tests/queries_/test_objectid.py diff --git a/.github/workflows/runtests.py b/.github/workflows/runtests.py index 41d12812..8bc424e1 100755 --- a/.github/workflows/runtests.py +++ b/.github/workflows/runtests.py @@ -57,6 +57,7 @@ "force_insert_update", "foreign_object", "forms_tests", + "forms_tests_", "from_db_value", "generic_inline_admin", "generic_relations", diff --git a/django_mongodb/fields/__init__.py b/django_mongodb/fields/__init__.py index d558e0fe..9eb2518d 100644 --- a/django_mongodb/fields/__init__.py +++ b/django_mongodb/fields/__init__.py @@ -1,8 +1,9 @@ from .auto import ObjectIdAutoField from .duration import register_duration_field from .json import register_json_field +from .objectid import ObjectIdField -__all__ = ["register_fields", "ObjectIdAutoField"] +__all__ = ["register_fields", "ObjectIdAutoField", "ObjectIdField"] def register_fields(): diff --git a/django_mongodb/fields/auto.py b/django_mongodb/fields/auto.py index cc9ebda9..8dd535ba 100644 --- a/django_mongodb/fields/auto.py +++ b/django_mongodb/fields/auto.py @@ -2,15 +2,11 @@ from django.core import exceptions from django.db.models.fields import AutoField from django.utils.functional import cached_property -from django.utils.translation import gettext_lazy as _ +from .objectid import ObjectIdMixin -class ObjectIdAutoField(AutoField): - default_error_messages = { - "invalid": _("“%(value)s” value must be an Object Id."), - } - description = _("Object Id") +class ObjectIdAutoField(ObjectIdMixin, AutoField): def __init__(self, *args, **kwargs): kwargs["db_column"] = "_id" super().__init__(*args, **kwargs) @@ -42,12 +38,6 @@ def get_prep_value(self, value): def get_internal_type(self): return "ObjectIdAutoField" - def db_type(self, connection): - return "objectId" - - def rel_db_type(self, connection): - return "objectId" - def to_python(self, value): if value is None or isinstance(value, int): return value diff --git a/django_mongodb/fields/objectid.py b/django_mongodb/fields/objectid.py new file mode 100644 index 00000000..b60ed6fb --- /dev/null +++ b/django_mongodb/fields/objectid.py @@ -0,0 +1,54 @@ +from bson import ObjectId, errors +from django.core import exceptions +from django.db.models.fields import Field +from django.utils.translation import gettext_lazy as _ + +from django_mongodb import forms + + +class ObjectIdMixin: + default_error_messages = { + "invalid": _("“%(value)s” is not a valid Object Id."), + } + description = _("Object Id") + + def db_type(self, connection): + return "objectId" + + def rel_db_type(self, connection): + return "objectId" + + def get_prep_value(self, value): + value = super().get_prep_value(value) + return self.to_python(value) + + def to_python(self, value): + if value is None: + return value + try: + return ObjectId(value) + except (errors.InvalidId, TypeError): + raise exceptions.ValidationError( + self.error_messages["invalid"], + code="invalid", + params={"value": value}, + ) from None + + def formfield(self, **kwargs): + return super().formfield( + **{ + "form_class": forms.ObjectIdField, + **kwargs, + } + ) + + +class ObjectIdField(ObjectIdMixin, Field): + def deconstruct(self): + name, path, args, kwargs = super().deconstruct() + if path.startswith("django_mongodb.fields.objectid"): + path = path.replace("django_mongodb.fields.objectid", "django_mongodb.fields") + return name, path, args, kwargs + + def get_internal_type(self): + return "ObjectIdField" diff --git a/django_mongodb/forms/__init__.py b/django_mongodb/forms/__init__.py new file mode 100644 index 00000000..9009a3ee --- /dev/null +++ b/django_mongodb/forms/__init__.py @@ -0,0 +1,3 @@ +from .fields import ObjectIdField + +__all__ = ["ObjectIdField"] diff --git a/django_mongodb/forms/fields.py b/django_mongodb/forms/fields.py new file mode 100644 index 00000000..bb6f40c9 --- /dev/null +++ b/django_mongodb/forms/fields.py @@ -0,0 +1,27 @@ +from bson import ObjectId +from bson.errors import InvalidId +from django.core.exceptions import ValidationError +from django.forms import Field +from django.utils.translation import gettext_lazy as _ + + +class ObjectIdField(Field): + default_error_messages = { + "invalid": _("Enter a valid Object Id."), + } + + def prepare_value(self, value): + if isinstance(value, ObjectId): + return str(value) + return value + + def to_python(self, value): + value = super().to_python(value) + if value in self.empty_values: + return None + if not isinstance(value, ObjectId): + try: + value = ObjectId(value) + except InvalidId: + raise ValidationError(self.error_messages["invalid"], code="invalid") from None + return value diff --git a/docs/source/fields.rst b/docs/source/fields.rst new file mode 100644 index 00000000..0bb0feb3 --- /dev/null +++ b/docs/source/fields.rst @@ -0,0 +1,13 @@ +Model field reference +===================== + +.. module:: django_mongodb.fields + +Some MongoDB-specific fields are available in ``django_mongodb.fields``. + +``ObjectIdField`` +----------------- + +.. class:: ObjectIdField + +Stores an :class:`~bson.objectid.ObjectId`. diff --git a/docs/source/forms.rst b/docs/source/forms.rst new file mode 100644 index 00000000..01b4f2f9 --- /dev/null +++ b/docs/source/forms.rst @@ -0,0 +1,13 @@ +Forms API reference +=================== + +.. module:: django_mongodb.forms + +Some MongoDB-specific fields are available in ``django_mongodb.forms``. + +``ObjectIdField`` +----------------- + +.. class:: ObjectIdField + +Stores an :class:`~bson.objectid.ObjectId`. diff --git a/docs/source/index.rst b/docs/source/index.rst index e9a2ace6..8df60944 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -5,7 +5,9 @@ django-mongodb 5.0.x documentation :maxdepth: 1 :caption: Contents: + fields querysets + forms Indices and tables ================== diff --git a/tests/forms_tests_/__init__.py b/tests/forms_tests_/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/forms_tests_/test_objectidfield.py b/tests/forms_tests_/test_objectidfield.py new file mode 100644 index 00000000..fcbd1035 --- /dev/null +++ b/tests/forms_tests_/test_objectidfield.py @@ -0,0 +1,33 @@ +from bson import ObjectId +from django.core.exceptions import ValidationError +from django.test import SimpleTestCase + +from django_mongodb.forms.fields import ObjectIdField + + +class ObjectIdFieldTests(SimpleTestCase): + def test_clean(self): + field = ObjectIdField() + value = field.clean("675747ec45260945758d76bc") + self.assertEqual(value, ObjectId("675747ec45260945758d76bc")) + + def test_clean_objectid(self): + field = ObjectIdField() + value = field.clean(ObjectId("675747ec45260945758d76bc")) + self.assertEqual(value, ObjectId("675747ec45260945758d76bc")) + + def test_clean_empty_string(self): + field = ObjectIdField(required=False) + value = field.clean("") + self.assertEqual(value, None) + + def test_clean_invalid(self): + field = ObjectIdField() + with self.assertRaises(ValidationError) as cm: + field.clean("invalid") + self.assertEqual(cm.exception.messages[0], "Enter a valid Object Id.") + + def test_prepare_value(self): + field = ObjectIdField() + value = field.prepare_value(ObjectId("675747ec45260945758d76bc")) + self.assertEqual(value, "675747ec45260945758d76bc") diff --git a/tests/model_fields_/models.py b/tests/model_fields_/models.py new file mode 100644 index 00000000..10f21258 --- /dev/null +++ b/tests/model_fields_/models.py @@ -0,0 +1,15 @@ +from django.db import models + +from django_mongodb.fields import ObjectIdField + + +class ObjectIdModel(models.Model): + field = ObjectIdField() + + +class NullableObjectIdModel(models.Model): + field = ObjectIdField(blank=True, null=True) + + +class PrimaryKeyObjectIdModel(models.Model): + field = ObjectIdField(primary_key=True) diff --git a/tests/model_fields_/test_objectidfield.py b/tests/model_fields_/test_objectidfield.py new file mode 100644 index 00000000..c37a5b79 --- /dev/null +++ b/tests/model_fields_/test_objectidfield.py @@ -0,0 +1,130 @@ +import json + +from bson import ObjectId +from django.core import serializers +from django.core.exceptions import ValidationError +from django.test import SimpleTestCase, TestCase + +from django_mongodb import forms +from django_mongodb.fields import ObjectIdField + +from .models import NullableObjectIdModel, ObjectIdModel, PrimaryKeyObjectIdModel + + +class MethodTests(SimpleTestCase): + def test_deconstruct(self): + field = ObjectIdField() + name, path, args, kwargs = field.deconstruct() + self.assertEqual(path, "django_mongodb.fields.ObjectIdField") + self.assertEqual(args, []) + self.assertEqual(kwargs, {}) + + def test_formfield(self): + f = ObjectIdField().formfield() + self.assertIsInstance(f, forms.ObjectIdField) + + def test_get_internal_type(self): + f = ObjectIdField() + self.assertEqual(f.get_internal_type(), "ObjectIdField") + + def test_to_python_string(self): + value = "1" * 24 + self.assertEqual(ObjectIdField().to_python(value), ObjectId(value)) + + def test_to_python_objectid(self): + value = ObjectId("1" * 24) + self.assertEqual(ObjectIdField().to_python(value), value) + + def test_to_python_null(self): + self.assertIsNone(ObjectIdField().to_python(None)) + + def test_to_python_invalid_value(self): + f = ObjectIdField() + for invalid_value in ["None", "", {}, [], 123]: + with self.subTest(invalid_value=invalid_value): + msg = f"['“{invalid_value}” is not a valid Object Id.']" + with self.assertRaisesMessage(ValidationError, msg): + f.to_python(invalid_value) + + def test_get_prep_value_string(self): + value = "1" * 24 + self.assertEqual(ObjectIdField().get_prep_value(value), ObjectId(value)) + + def test_get_prep_value_objectid(self): + value = ObjectId("1" * 24) + self.assertEqual(ObjectIdField().get_prep_value(value), value) + + def test_get_prep_value_null(self): + self.assertIsNone(ObjectIdField().get_prep_value(None)) + + def test_get_prep_value_invalid_values(self): + f = ObjectIdField() + f.name = "test" + for invalid_value in ["None", "", {}, [], 123]: + with self.subTest(invalid_value=invalid_value): + msg = f"['“{invalid_value}” is not a valid Object Id.']" + with self.assertRaisesMessage(ValidationError, msg): + f.get_prep_value(invalid_value) + + +class SaveLoadTests(TestCase): + def test_objectid_instance(self): + instance = ObjectIdModel.objects.create(field=ObjectId()) + loaded = ObjectIdModel.objects.get() + self.assertEqual(loaded.field, instance.field) + + def test_str_instance(self): + ObjectIdModel.objects.create(field="6754ed8e584bc9ceaae3c072") + loaded = ObjectIdModel.objects.get() + self.assertEqual(loaded.field, ObjectId("6754ed8e584bc9ceaae3c072")) + + def test_null_handling(self): + NullableObjectIdModel.objects.create(field=None) + loaded = NullableObjectIdModel.objects.get() + self.assertIsNone(loaded.field) + + def test_pk_validated(self): + with self.assertRaisesMessage(ValidationError, "is not a valid Object Id."): + PrimaryKeyObjectIdModel.objects.get(pk={}) + + with self.assertRaisesMessage(ValidationError, "is not a valid Object Id."): + PrimaryKeyObjectIdModel.objects.get(pk=[]) + + def test_wrong_lookup_type(self): + with self.assertRaisesMessage(ValidationError, "is not a valid Object Id."): + ObjectIdModel.objects.get(field="not-a-objectid") + + with self.assertRaisesMessage(ValidationError, "is not a valid Object Id."): + ObjectIdModel.objects.create(field="not-a-objectid") + + +class SerializationTests(TestCase): + test_data = ( + '[{"fields": {"field": "6754ed8e584bc9ceaae3c072"}, "model": ' + '"model_fields_.objectidmodel", "pk": null}]' + ) + + def test_dumping(self): + instance = ObjectIdModel(field=ObjectId("6754ed8e584bc9ceaae3c072")) + data = serializers.serialize("json", [instance]) + self.assertEqual(json.loads(data), json.loads(self.test_data)) + + def test_loading(self): + instance = next(serializers.deserialize("json", self.test_data)).object + self.assertEqual(instance.field, ObjectId("6754ed8e584bc9ceaae3c072")) + + +class ValidationTests(TestCase): + def test_invalid_objectid(self): + field = ObjectIdField() + with self.assertRaises(ValidationError) as cm: + field.clean("550e8400", None) + self.assertEqual(cm.exception.code, "invalid") + self.assertEqual( + cm.exception.message % cm.exception.params, "“550e8400” is not a valid Object Id." + ) + + def test_objectid_instance_ok(self): + value = ObjectId() + field = ObjectIdField() + self.assertEqual(field.clean(value, None), value) diff --git a/tests/queries_/models.py b/tests/queries_/models.py index 61b93890..acf2bef2 100644 --- a/tests/queries_/models.py +++ b/tests/queries_/models.py @@ -1,5 +1,7 @@ from django.db import models +from django_mongodb.fields import ObjectIdAutoField, ObjectIdField + class Author(models.Model): name = models.CharField(max_length=10) @@ -14,3 +16,40 @@ class Book(models.Model): def __str__(self): return self.title + + +class Tag(models.Model): + name = models.CharField(max_length=10) + parent = models.ForeignKey( + "self", + models.SET_NULL, + blank=True, + null=True, + related_name="children", + ) + group_id = ObjectIdField(null=True) + + def __str__(self): + return self.name + + +class Order(models.Model): + id = ObjectIdAutoField(primary_key=True) + name = models.CharField(max_length=12, null=True, default="") + + class Meta: + ordering = ("pk",) + + def __str__(self): + return str(self.pk) + + +class OrderItem(models.Model): + order = models.ForeignKey(Order, models.CASCADE, related_name="items") + status = ObjectIdField(null=True) + + class Meta: + ordering = ("pk",) + + def __str__(self): + return str(self.pk) diff --git a/tests/queries_/test_objectid.py b/tests/queries_/test_objectid.py new file mode 100644 index 00000000..490d1b33 --- /dev/null +++ b/tests/queries_/test_objectid.py @@ -0,0 +1,111 @@ +from bson import ObjectId +from django.core.exceptions import ValidationError +from django.test import TestCase + +from .models import Order, OrderItem, Tag + + +class ObjectIdTests(TestCase): + @classmethod + def setUpTestData(cls): + cls.group_id_str_1 = "1" * 24 + cls.group_id_obj_1 = ObjectId(cls.group_id_str_1) + cls.group_id_str_2 = "2" * 24 + cls.group_id_obj_2 = ObjectId(cls.group_id_str_2) + + cls.t1 = Tag.objects.create(name="t1") + cls.t2 = Tag.objects.create(name="t2", parent=cls.t1) + cls.t3 = Tag.objects.create(name="t3", parent=cls.t1, group_id=cls.group_id_str_1) + cls.t4 = Tag.objects.create(name="t4", parent=cls.t3, group_id=cls.group_id_obj_2) + cls.t5 = Tag.objects.create(name="t5", parent=cls.t3) + + def test_filter_group_id_is_null_false(self): + """Filter objects where group_id is not null.""" + qs = Tag.objects.filter(group_id__isnull=False).order_by("name") + self.assertSequenceEqual(qs, [self.t3, self.t4]) + + def test_filter_group_id_is_null_true(self): + """Filter objects where group_id is null.""" + qs = Tag.objects.filter(group_id__isnull=True).order_by("name") + self.assertSequenceEqual(qs, [self.t1, self.t2, self.t5]) + + def test_filter_group_id_equal_str(self): + """Filter by group_id with a specific string value.""" + qs = Tag.objects.filter(group_id=self.group_id_str_1).order_by("name") + self.assertSequenceEqual(qs, [self.t3]) + + def test_filter_group_id_equal_obj(self): + """Filter by group_id with a specific ObjectId value.""" + qs = Tag.objects.filter(group_id=self.group_id_obj_1).order_by("name") + self.assertSequenceEqual(qs, [self.t3]) + + def test_filter_group_id_in_str_values(self): + """Filter by group_id with string values in a list.""" + ids = [self.group_id_str_1, self.group_id_str_2] + qs = Tag.objects.filter(group_id__in=ids).order_by("name") + self.assertSequenceEqual(qs, [self.t3, self.t4]) + + def test_filter_group_id_in_obj_values(self): + """Filter by group_id with ObjectId values in a list.""" + ids = [self.group_id_obj_1, self.group_id_obj_2] + qs = Tag.objects.filter(group_id__in=ids).order_by("name") + self.assertSequenceEqual(qs, [self.t3, self.t4]) + + def test_filter_group_id_equal_subquery(self): + """Filter by group_id using a subquery.""" + subquery = Tag.objects.filter(name="t3").values("group_id") + qs = Tag.objects.filter(group_id__in=subquery).order_by("name") + self.assertSequenceEqual(qs, [self.t3]) + + def test_filter_group_id_in_subquery(self): + """Filter by group_id using a subquery with multiple values.""" + subquery = Tag.objects.filter(name__in=["t3", "t4"]).values("group_id") + qs = Tag.objects.filter(group_id__in=subquery).order_by("name") + self.assertSequenceEqual(qs, [self.t3, self.t4]) + + def test_filter_parent_by_children_values_str(self): + """Query to select parents of children with specific string group_id.""" + child_ids = Tag.objects.filter(group_id=self.group_id_str_1).values_list("id", flat=True) + parent_qs = Tag.objects.filter(children__id__in=child_ids).distinct().order_by("name") + self.assertSequenceEqual(parent_qs, [self.t1]) + + def test_filter_parent_by_children_values_obj(self): + """Query to select parents of children with specific ObjectId group_id.""" + child_ids = Tag.objects.filter(group_id=self.group_id_obj_1).values_list("id", flat=True) + parent_qs = Tag.objects.filter(children__id__in=child_ids).distinct().order_by("name") + self.assertSequenceEqual(parent_qs, [self.t1]) + + def test_filter_group_id_union_with_str(self): + """Combine queries using union with string values.""" + qs_a = Tag.objects.filter(group_id=self.group_id_str_1) + qs_b = Tag.objects.filter(group_id=self.group_id_str_2) + union_qs = qs_a.union(qs_b).order_by("name") + self.assertSequenceEqual(union_qs, [self.t3, self.t4]) + + def test_filter_group_id_union_with_obj(self): + """Combine queries using union with ObjectId values.""" + qs_a = Tag.objects.filter(group_id=self.group_id_obj_1) + qs_b = Tag.objects.filter(group_id=self.group_id_obj_2) + union_qs = qs_a.union(qs_b).order_by("name") + self.assertSequenceEqual(union_qs, [self.t3, self.t4]) + + def test_filter_invalid_object_id(self): + msg = "“value1” is not a valid Object Id.'" + with self.assertRaisesMessage(ValidationError, msg): + Tag.objects.filter(group_id="value1") + + def test_values_in_subquery(self): + # If a values() queryset is used, then the given values will be used + # instead of forcing use of the relation's field. + o1 = Order.objects.create() + o2 = Order.objects.create() + oi1 = OrderItem.objects.create(order=o1, status=None) + oi1.status = oi1.pk + oi1.save() + OrderItem.objects.create(order=o2, status=None) + # The query below should match o1 as it has related order_item with + # id == status. + self.assertSequenceEqual( + Order.objects.filter(items__in=OrderItem.objects.values_list("status")), + [o1], + ) From 9ec493f1d5ee726d13914937662ae16dbfc5d680 Mon Sep 17 00:00:00 2001 From: Tim Graham Date: Mon, 9 Dec 2024 16:25:22 -0500 Subject: [PATCH 2/2] reclassify expected failures as skips Updating the affected models to use ObjectIdField breaks other tests. --- django_mongodb/features.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/django_mongodb/features.py b/django_mongodb/features.py index 1dbdbf79..ea7f69bd 100644 --- a/django_mongodb/features.py +++ b/django_mongodb/features.py @@ -77,10 +77,6 @@ class DatabaseFeatures(BaseDatabaseFeatures): # Connection creation doesn't follow the usual Django API. "backends.tests.ThreadTests.test_pass_connection_between_threads", "backends.tests.ThreadTests.test_default_connection_thread_local", - # ObjectId type mismatch in a subquery: - # https://github.com/mongodb-labs/django-mongodb/issues/161 - "queries.tests.RelatedLookupTypeTests.test_values_queryset_lookup", - "queries.tests.ValuesSubqueryTests.test_values_in_subquery", # Object of type ObjectId is not JSON serializable. "auth_tests.test_views.LoginTest.test_login_session_without_hash_session_key", # GenericRelation.value_to_string() assumes integer pk. @@ -225,6 +221,8 @@ def django_test_expected_failures(self): "expressions.tests.BasicExpressionsTests.test_nested_subquery_outer_ref_with_autofield", "model_fields.test_foreignkey.ForeignKeyTests.test_to_python", "queries.test_qs_combinators.QuerySetSetOperationTests.test_order_raises_on_non_selected_column", + "queries.tests.RelatedLookupTypeTests.test_values_queryset_lookup", + "queries.tests.ValuesSubqueryTests.test_values_in_subquery", }, "Cannot use QuerySet.delete() when querying across multiple collections on MongoDB.": { "admin_changelist.tests.ChangeListTests.test_distinct_for_many_to_many_at_second_level_in_search_fields",