diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 59be5aa8e..52dc74ff4 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -14,6 +14,7 @@ Changelog Added ^^^^^ - Implement savepoints for transactions (#1816) +- Added type validation for foreign key fields to ensure type safety. Now raises `ValidationError` when assigning foreign key values with incorrect model types (#1792) Fixed ^^^^^ @@ -1498,4 +1499,4 @@ Docs/examples: await Tournament.filter( events__name__in=['1', '3'] - ).order_by('-events__participants__name').distinct() + ).order_by('-events__participants__name').distinct() \ No newline at end of file diff --git a/tests/fields/test_fk.py b/tests/fields/test_fk.py index 299e459d9..ff69cbbd2 100644 --- a/tests/fields/test_fk.py +++ b/tests/fields/test_fk.py @@ -1,10 +1,20 @@ from tests import testmodels from tortoise.contrib import test -from tortoise.exceptions import IntegrityError, NoValuesFetched, OperationalError +from tortoise.exceptions import ( + IntegrityError, + NoValuesFetched, + OperationalError, + ValidationError, +) from tortoise.queryset import QuerySet class TestForeignKeyField(test.TestCase): + def assertRaisesWrongTypeException(self, relation_name: str): + return self.assertRaisesRegex( + ValidationError, f"Invalid type for relationship field '{relation_name}'" + ) + async def test_empty(self): with self.assertRaises(IntegrityError): await testmodels.MinRelation.create() @@ -151,6 +161,11 @@ async def test_minimal__instantiated_create(self): tour = await testmodels.Tournament.create(name="Team1") await testmodels.MinRelation.create(tournament=tour) + async def test_minimal__instantiated_create_wrong_type(self): + author = await testmodels.Author.create(name="Author1") + with self.assertRaisesWrongTypeException("tournament"): + await testmodels.MinRelation.create(tournament=author) + async def test_minimal__instantiated_iterate(self): tour = await testmodels.Tournament.create(name="Team1") async for _ in tour.minrelations: @@ -229,3 +244,57 @@ async def test_event__offset(self): event2 = await testmodels.Event.create(name="Event2", tournament=tour) event3 = await testmodels.Event.create(name="Event3", tournament=tour) self.assertEqual(await tour.events.offset(1).order_by("name"), [event2, event3]) + + async def test_fk_correct_type_assignment(self): + tour1 = await testmodels.Tournament.create(name="Team1") + tour2 = await testmodels.Tournament.create(name="Team2") + event = await testmodels.Event(name="Event1", tournament=tour1) + + event.tournament = tour2 + await event.save() + self.assertEqual(event.tournament_id, tour2.id) + + async def test_fk_wrong_type_assignment(self): + tour = await testmodels.Tournament.create(name="Team1") + author = await testmodels.Author.create(name="Author") + rel = await testmodels.MinRelation.create(tournament=tour) + + with self.assertRaisesWrongTypeException("tournament"): + rel.tournament = author + + async def test_fk_none_assignment(self): + manager = await testmodels.Employee.create(name="Manager") + employee = await testmodels.Employee.create(name="Employee", manager=manager) + + employee.manager = None + await employee.save() + self.assertIsNone(employee.manager) + + async def test_fk_update_wrong_type(self): + tour = await testmodels.Tournament.create(name="Team1") + rel = await testmodels.MinRelation.create(tournament=tour) + author = await testmodels.Author.create(name="Author1") + + with self.assertRaisesWrongTypeException("tournament"): + await testmodels.MinRelation.filter(id=rel.id).update(tournament=author) + + async def test_fk_bulk_create_wrong_type(self): + author = await testmodels.Author.create(name="Author") + with self.assertRaisesWrongTypeException("tournament"): + await testmodels.MinRelation.bulk_create( + [testmodels.MinRelation(tournament=author) for _ in range(10)] + ) + + async def test_fk_bulk_update_wrong_type(self): + tour = await testmodels.Tournament.create(name="Team1") + await testmodels.MinRelation.bulk_create( + [testmodels.MinRelation(tournament=tour) for _ in range(1, 10)] + ) + author = await testmodels.Author.create(name="Author") + + with self.assertRaisesWrongTypeException("tournament"): + relations = await testmodels.MinRelation.all() + await testmodels.MinRelation.bulk_update( + [testmodels.MinRelation(id=rel.id, tournament=author) for rel in relations], + fields=["tournament"], + ) diff --git a/tortoise/models.py b/tortoise/models.py index 7581ff060..779ef1033 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -30,11 +30,13 @@ from tortoise.exceptions import ( ConfigurationError, DoesNotExist, + FieldError, IncompleteInstanceError, IntegrityError, ObjectDoesNotExistError, OperationalError, ParamsError, + ValidationError, ) from tortoise.expressions import Expression from tortoise.fields.base import Field @@ -685,6 +687,8 @@ def __setattr__(self, key, value) -> None: # set field value override async default function if hasattr(self, "_await_when_save"): self._await_when_save.pop(key, None) + if key in self._meta.fk_fields or key in self._meta.o2o_fields: + self._validate_relation_type(key, value) super().__setattr__(key, value) def _set_kwargs(self, kwargs: dict) -> Set[str]: @@ -806,6 +810,27 @@ def _set_pk_val(self, value: Any) -> None: Can be used as a field name when doing filtering e.g. ``.filter(pk=...)`` etc... """ + @classmethod + def _validate_relation_type(cls, field_key: str, value: Optional["Model"]) -> None: + if value is None: + return + + field = cls._meta.fields_map[field_key] + if not isinstance(field, (OneToOneFieldInstance, ForeignKeyFieldInstance)): + raise FieldError( + f"Field '{field_key}' must be a OneToOne or ForeignKey relation, " + f"got {type(field).__name__}" + ) + + expected_model = field.related_model + received_model = type(value) + if received_model is not expected_model: + raise ValidationError( + f"Invalid type for relationship field '{field_key}'. " + f"Expected model type '{expected_model.__name__}', but got '{received_model.__name__}'. " + "Make sure you're using the correct model class for this relationship." + ) + @classmethod async def _getbypk(cls: Type[MODEL], key: Any) -> MODEL: try: diff --git a/tortoise/queryset.py b/tortoise/queryset.py index 0911421e7..636221797 100644 --- a/tortoise/queryset.py +++ b/tortoise/queryset.py @@ -1204,6 +1204,7 @@ def _make_query(self) -> None: if field_object.pk: raise IntegrityError(f"Field {key} is PK and can not be updated") if isinstance(field_object, (ForeignKeyFieldInstance, OneToOneFieldInstance)): + self.model._validate_relation_type(key, value) fk_field: str = field_object.source_field # type: ignore db_field = self.model._meta.fields_map[fk_field].source_field value = executor.column_map[fk_field](