diff --git a/src/django_integrity/conversion.py b/src/django_integrity/conversion.py new file mode 100644 index 0000000..89193e7 --- /dev/null +++ b/src/django_integrity/conversion.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +import abc +import contextlib +import dataclasses +import re +from collections.abc import Iterator, Mapping + +import psycopg2 +from django import db as django_db + + +@contextlib.contextmanager +def refine_integrity_error(rules: Mapping[_Rule, Exception]) -> Iterator[None]: + """ + Convert a generic IntegrityError into a more specific exception. + + The conversion is based on a mapping of rules to exceptions. + """ + try: + yield + except django_db.IntegrityError as e: + for rule, refined_error in rules.items(): + if rule.is_match(e): + raise refined_error from e + raise + + +class _Rule(abc.ABC): + @abc.abstractmethod + def is_match(self, error: django_db.IntegrityError) -> bool: + ... + + +@dataclasses.dataclass(frozen=True) +class Named(_Rule): + """ + A constraint identified by its name. + """ + + name: str + + def is_match(self, error: django_db.IntegrityError) -> bool: + if not isinstance(error.__cause__, psycopg2.errors.IntegrityError): + return False + + return error.__cause__.diag.constraint_name == self.name + + +@dataclasses.dataclass(frozen=True) +class Unique(_Rule): + """ + A unique constraint defined by a model and a set of fields. + """ + + model: django_db.models.Model + fields: tuple[str] + + _pattern = re.compile(r"Key \((?P.+)\)=\(.*\) already exists.") + + def is_match(self, error: django_db.IntegrityError) -> bool: + if not isinstance(error.__cause__, psycopg2.errors.UniqueViolation): + return False + + match = self._pattern.match(error.__cause__.diag.message_detail) + if match is None: + return False + + return ( + tuple(match.group("fields").split(", ")) == self.fields + and error.__cause__.diag.table_name == self.model._meta.db_table + ) + + +@dataclasses.dataclass(frozen=True) +class PrimaryKey(_Rule): + """ + A unique constraint on the primary key of a model. + """ + + model: django_db.models.Model + + _pattern = re.compile(r"Key \((?P.+)\)=\(.*\) already exists.") + + def is_match(self, error: django_db.IntegrityError) -> bool: + if not isinstance(error.__cause__, psycopg2.errors.UniqueViolation): + return False + + match = self._pattern.match(error.__cause__.diag.message_detail) + if match is None: + return False + + # We assume that the model has a primary key, + # given that we're looking for a primary key constraint. + assert self.model._meta.pk is not None + return ( + tuple(match.group("fields").split(", ")) == (self.model._meta.pk.name,) + and error.__cause__.diag.table_name == self.model._meta.db_table + ) + + +@dataclasses.dataclass(frozen=True) +class NotNull(_Rule): + """ + A not-null constraint on a Model's field. + """ + + model: django_db.models.Model + field: str + + def is_match(self, error: django_db.IntegrityError) -> bool: + if not isinstance(error.__cause__, psycopg2.errors.NotNullViolation): + return False + + return ( + error.__cause__.diag.column_name == self.field + and error.__cause__.diag.table_name == self.model._meta.db_table + ) + + +@dataclasses.dataclass(frozen=True) +class ForeignKey(_Rule): + """ + A foreign key constraint on a Model's field. + """ + + model: django_db.models.Model + field: str + + _detail_pattern = re.compile( + r"Key \((?P.+)\)=\((?P.+)\) is not present in table" + ) + + def is_match(self, error: django_db.IntegrityError) -> bool: + if not isinstance(error.__cause__, psycopg2.errors.ForeignKeyViolation): + return False + + detail_match = self._detail_pattern.match(error.__cause__.diag.message_detail) + if detail_match is None: + return False + + return ( + detail_match.group("field") == self.field + and error.__cause__.diag.table_name == self.model._meta.db_table + ) diff --git a/tests/django_integrity/test_conversion.py b/tests/django_integrity/test_conversion.py new file mode 100644 index 0000000..1fcd8b3 --- /dev/null +++ b/tests/django_integrity/test_conversion.py @@ -0,0 +1,230 @@ +import pytest +from django import db as django_db + +from django_integrity import constraints, conversion +from tests.example_app import models as test_models + + +class SimpleError(Exception): + pass + + +class TestRefineIntegrityError: + def test_no_rules(self): + # It is legal to call the context manager without any rules. + with conversion.refine_integrity_error(rules={}): + pass + + +@pytest.mark.django_db +class TestNamedConstraint: + def test_error_refined(self): + # Create a unique instance so that we can violate the constraint later. + test_models.UniqueModel.objects.create(unique_field=42) + + rules = {conversion.Named(name="unique_model_unique_field_key"): SimpleError} + + # The original error should be transformed into our expected error. + with pytest.raises(SimpleError): + with conversion.refine_integrity_error(rules): + test_models.UniqueModel.objects.create(unique_field=42) + + def test_rules_mismatch(self): + # Create a unique instance so that we can violate the constraint later. + test_models.UniqueModel.objects.create(unique_field=42) + + # No constraints match the error: + rules = {conversion.Named(name="nonexistent_constraint"): SimpleError} + + # The original error should be raised. + with pytest.raises(django_db.IntegrityError): + with conversion.refine_integrity_error(rules): + test_models.UniqueModel.objects.create(unique_field=42) + + +@pytest.mark.django_db +class TestUnique: + def test_error_refined(self): + # Create a unique instance so that we can violate the constraint later. + test_models.UniqueModel.objects.create(unique_field=42) + + rules = { + conversion.Unique( + model=test_models.UniqueModel, fields=("unique_field",) + ): SimpleError + } + + # The original error should be transformed into our expected error. + with pytest.raises(SimpleError): + with conversion.refine_integrity_error(rules): + test_models.UniqueModel.objects.create(unique_field=42) + + def test_multiple_fields(self): + # Create a unique instance so that we can violate the constraint later. + test_models.UniqueTogetherModel.objects.create(field_1=1, field_2=2) + + rules = { + conversion.Unique( + model=test_models.UniqueTogetherModel, fields=("field_1", "field_2") + ): SimpleError + } + + # The original error should be transformed into our expected error. + with pytest.raises(SimpleError): + with conversion.refine_integrity_error(rules): + test_models.UniqueTogetherModel.objects.create(field_1=1, field_2=2) + + @pytest.mark.parametrize( + "Model, field", + ( + # Wrong model, despite matching field name. + ( + test_models.AlternativeUniqueModel, + "unique_field", + ), + # Wrong field, despite matching model. + ( + test_models.UniqueModel, + "id", + ), + ), + ids=("wrong_model", "wrong_field"), + ) + def test_rules_mismatch(self, Model: conversion.Unique, field: str): + # A rule that matches a similar looking, but different, unique constraint. + # Create a unique instance so that we can violate the constraint later. + test_models.UniqueModel.objects.create(unique_field=42) + + rules = {conversion.Unique(model=Model, fields=(field,)): SimpleError} + + # We shouldn't transform the error, because it didn't match the rule. + with pytest.raises(django_db.IntegrityError): + with conversion.refine_integrity_error(rules): + test_models.UniqueModel.objects.create(unique_field=42) + + +@pytest.mark.django_db +class TestPrimaryKey: + @pytest.mark.parametrize( + "ModelClass", + ( + test_models.PrimaryKeyModel, + test_models.AlternativePrimaryKeyModel, + ), + ) + def test_error_refined(self, ModelClass): + """ + The primary key of a model is extracted from the model. + + This test internally refers to the models primary key using "pk". + "pk" is Django magic that refers to the primary key of the model. + On PrimaryKeyModel, the primary key is "id". + On AlternativePrimaryKeyModel, the primary key is "identity". + """ + # Create a unique instance so that we can violate the constraint later. + existing_primary_key = ModelClass.objects.create().pk + + rules = {conversion.PrimaryKey(model=ModelClass): SimpleError} + + # The original error should be transformed into our expected error. + with pytest.raises(SimpleError): + with conversion.refine_integrity_error(rules): + ModelClass.objects.create(pk=existing_primary_key) + + def test_rules_mismatch(self): + # Create a unique instance so that we can violate the constraint later. + existing_primary_key = test_models.PrimaryKeyModel.objects.create().pk + + # A similar rule, but for a different model with the same field name.. + rules = {conversion.PrimaryKey(model=test_models.UniqueModel): SimpleError} + + # The original error should be raised. + with pytest.raises(django_db.IntegrityError): + with conversion.refine_integrity_error(rules): + test_models.PrimaryKeyModel.objects.create(pk=existing_primary_key) + + +@pytest.mark.django_db +class TestNotNull: + def test_error_refined(self): + rules = { + conversion.NotNull( + model=test_models.UniqueModel, field="unique_field" + ): SimpleError + } + + # The original error should be transformed into our expected error. + with pytest.raises(SimpleError): + with conversion.refine_integrity_error(rules): + test_models.UniqueModel.objects.create(unique_field=None) + + def test_model_mismatch(self): + # Same field, but different model. + rules = { + conversion.NotNull( + model=test_models.AlternativeUniqueModel, field="unique_field" + ): SimpleError + } + + with pytest.raises(django_db.IntegrityError): + with conversion.refine_integrity_error(rules): + test_models.UniqueModel.objects.create(unique_field=None) + + def test_field_mismatch(self): + # Same model, but different field. + rules = { + conversion.NotNull( + model=test_models.AlternativeUniqueModel, field="unique_field_2" + ): SimpleError + } + + # The original error should be raised. + with pytest.raises(django_db.IntegrityError): + with conversion.refine_integrity_error(rules): + test_models.AlternativeUniqueModel.objects.create( + unique_field=None, + unique_field_2=42, + ) + + +@pytest.mark.django_db +class TestForeignKey: + def test_error_refined(self): + rules = { + conversion.ForeignKey( + model=test_models.ForeignKeyModel, field="related_id" + ): SimpleError + } + constraints.set_all_immediate(using="default") + + # The original error should be transformed into our expected error. + with pytest.raises(SimpleError): + with conversion.refine_integrity_error(rules): + # Create a ForeignKeyModel with a related_id that doesn't exist. + test_models.ForeignKeyModel.objects.create(related_id=42) + + def test_source_mismatch(self): + # The field name matches, but the source model is different. + rules = { + conversion.ForeignKey( + model=test_models.ForeignKeyModel2, field="related_id" + ): SimpleError + } + constraints.set_all_immediate(using="default") + + with pytest.raises(django_db.IntegrityError): + with conversion.refine_integrity_error(rules): + test_models.ForeignKeyModel.objects.create(related_id=42) + + def test_field_mismatch(self): + # The source model matches, but the field name is different. + rules = { + conversion.ForeignKey( + model=test_models.ForeignKeyModel3, field="related_2_id" + ): SimpleError + } + constraints.set_all_immediate(using="default") + + with pytest.raises(django_db.IntegrityError): + with conversion.refine_integrity_error(rules): + test_models.ForeignKeyModel3.objects.create(related_1_id=42) diff --git a/tests/example_app/models.py b/tests/example_app/models.py index a66fa1a..c4d379b 100644 --- a/tests/example_app/models.py +++ b/tests/example_app/models.py @@ -2,13 +2,35 @@ class PrimaryKeyModel(models.Model): - pass + # Serialize must not be on because our tests try to create instances with clashing IDs. + id = models.BigAutoField(primary_key=True, serialize=False) + + +class AlternativePrimaryKeyModel(models.Model): + # Serialize must not be on because our tests try to create instances with clashing IDs. + identity = models.BigAutoField(primary_key=True, serialize=False) class ForeignKeyModel(models.Model): related = models.ForeignKey(PrimaryKeyModel, on_delete=models.CASCADE) +class ForeignKeyModel2(models.Model): + related = models.ForeignKey(PrimaryKeyModel, on_delete=models.CASCADE) + + +class ForeignKeyModel3(models.Model): + related_1 = models.ForeignKey( + AlternativePrimaryKeyModel, on_delete=models.CASCADE, related_name="+" + ) + related_2 = models.ForeignKey( + AlternativePrimaryKeyModel, + on_delete=models.CASCADE, + related_name="+", + null=True, + ) + + class UniqueModel(models.Model): unique_field = models.IntegerField() @@ -20,3 +42,21 @@ class Meta: deferrable=models.Deferrable.IMMEDIATE, ), ) + + +class AlternativeUniqueModel(models.Model): + unique_field = models.IntegerField(unique=True) + unique_field_2 = models.IntegerField(unique=True) + + +class UniqueTogetherModel(models.Model): + field_1 = models.IntegerField() + field_2 = models.IntegerField() + + class Meta: + constraints = ( + models.UniqueConstraint( + fields=["field_1", "field_2"], + name="unique_together_model_field_1_field_2_key", + ), + )