From 785bcfec3737c522ff378e43e223896028a20665 Mon Sep 17 00:00:00 2001 From: Emanuel Lupi Date: Sat, 23 Nov 2024 19:02:59 -0300 Subject: [PATCH] Allows int as ObjectIdField --- django_mongodb/fields/auto.py | 33 --------------- django_mongodb/fields/objectid.py | 50 ++++++++++++++--------- tests/model_fields_/test_objectidfield.py | 4 +- 3 files changed, 32 insertions(+), 55 deletions(-) diff --git a/django_mongodb/fields/auto.py b/django_mongodb/fields/auto.py index 8dd535ba..46c3237d 100644 --- a/django_mongodb/fields/auto.py +++ b/django_mongodb/fields/auto.py @@ -1,5 +1,3 @@ -from bson import ObjectId, errors -from django.core import exceptions from django.db.models.fields import AutoField from django.utils.functional import cached_property @@ -19,40 +17,9 @@ def deconstruct(self): path = path.replace("django_mongodb.fields.auto", "django_mongodb.fields") return name, path, args, kwargs - def get_prep_value(self, value): - if value is None: - return None - # Accept int for compatibility with Django's test suite which has many - # instances of manually assigned integer IDs, as well as for things - # like settings.SITE_ID which has a system check requiring an integer. - if isinstance(value, (ObjectId | int)): - return value - try: - return ObjectId(value) - except errors.InvalidId as e: - # A manually assigned integer ID? - if isinstance(value, str) and value.isdigit(): - return int(value) - raise ValueError(f"Field '{self.name}' expected an ObjectId but got {value!r}.") from e - def get_internal_type(self): return "ObjectIdAutoField" - def to_python(self, value): - if value is None or isinstance(value, int): - return value - try: - return ObjectId(value) - except errors.InvalidId: - try: - return int(value) - except ValueError: - raise exceptions.ValidationError( - self.error_messages["invalid"], - code="invalid", - params={"value": value}, - ) from None - @cached_property def validators(self): # Avoid IntegerField validators inherited from AutoField. diff --git a/django_mongodb/fields/objectid.py b/django_mongodb/fields/objectid.py index 9d8d5a16..3e0a59f1 100644 --- a/django_mongodb/fields/objectid.py +++ b/django_mongodb/fields/objectid.py @@ -1,5 +1,4 @@ from bson import ObjectId, errors -from bson.errors import InvalidId from django.core import exceptions from django.db.models.fields import Field from django.utils.translation import gettext_lazy as _ @@ -14,33 +13,44 @@ class ObjectIdMixin: def db_type(self, connection): return "objectId" + def get_prep_value(self, value): + if value is None: + return None + # Accept int for compatibility with Django's test suite which has many + # instances of manually assigned integer IDs, as well as for things + # like settings.SITE_ID which has a system check requiring an integer. + if isinstance(value, (ObjectId | int)): + return value + try: + return ObjectId(value) + except (errors.InvalidId, TypeError) as e: + # A manually assigned integer ID? + if isinstance(value, str) and value.isdigit(): + return int(value) + raise ValueError(f"Field '{self.name}' expected an ObjectId but got {value!r}.") from e + def rel_db_type(self, connection): return "objectId" - -class ObjectIdField(ObjectIdMixin, Field): - def get_internal_type(self): - return "ObjectIdField" - def to_python(self, value): - if value is None: + if value is None or isinstance(value, int): return value try: return ObjectId(value) - except (TypeError, InvalidId): - raise exceptions.ValidationError( - self.error_messages["invalid"], - code="invalid", - params={"value": value}, - ) from None + except (errors.InvalidId, TypeError): + try: + return int(value) + except (ValueError, TypeError): + raise exceptions.ValidationError( + self.error_messages["invalid"], + code="invalid", + params={"value": value}, + ) from None - def get_prep_value(self, value): - if value is None: - return None - try: - return ObjectId(value) - except (errors.InvalidId, TypeError) as e: - raise ValueError(f"Field '{self.name}' expected an ObjectId but got {value!r}.") from e + +class ObjectIdField(ObjectIdMixin, Field): + def get_internal_type(self): + return "ObjectIdField" def deconstruct(self): name, path, args, kwargs = super().deconstruct() diff --git a/tests/model_fields_/test_objectidfield.py b/tests/model_fields_/test_objectidfield.py index e11406b5..5e7e83f4 100644 --- a/tests/model_fields_/test_objectidfield.py +++ b/tests/model_fields_/test_objectidfield.py @@ -27,14 +27,14 @@ def test_to_python(self): self.assertIsNone(f.to_python(None)) def test_to_python_invalid_value(self): - for invalid_value in [3, "None", {}, []]: + for invalid_value in ["None", {}, []]: with self.subTest(invalid_value=invalid_value): msg = f"['“{invalid_value}” value must be an Object Id.']" with self.assertRaisesMessage(exceptions.ValidationError, msg): ObjectIdField().to_python(invalid_value) def test_get_prep_value_invalud_values(self): - for invalid_value in [3, "None", {}, []]: + for invalid_value in ["None", {}, []]: with self.subTest(invalid_value=invalid_value): msg = f"Field '{None}' expected an ObjectId but got {invalid_value!r}." with self.assertRaisesMessage(ValueError, msg):