From f5daed6a0562306a9281b28d9e274c42c2ab2016 Mon Sep 17 00:00:00 2001 From: Charlie Denton Date: Fri, 9 Feb 2024 21:57:23 +0000 Subject: [PATCH] Add helper to refine IntegrityErrors When we normally catch an IntegrityError from Django, it can be hard to determine exactly what the cause of the error was. Sometimes we assume that the error was of one type, where the error actually had another cause. This utility will make it easy for us to tell the difference between different types of IntegrityError, allowing us to be more specific about which errors we handle. --- src/django_integrity/conversion.py | 145 ++++++++++++++ tests/django_integrity/test_conversion.py | 230 ++++++++++++++++++++++ tests/example_app/models.py | 42 +++- 3 files changed, 416 insertions(+), 1 deletion(-) create mode 100644 src/django_integrity/conversion.py create mode 100644 tests/django_integrity/test_conversion.py diff --git a/src/django_integrity/conversion.py b/src/django_integrity/conversion.py new file mode 100644 index 0000000..89193e7 --- /dev/null +++ b/src/django_integrity/conversion.py @@ -0,0 +1,145 @@ +from __future__ import annotations + +import abc +import contextlib +import dataclasses +import re +from collections.abc import Iterator, Mapping + +import psycopg2 +from django import db as django_db + + +@contextlib.contextmanager +def refine_integrity_error(rules: Mapping[_Rule, Exception]) -> Iterator[None]: + """ + Convert a generic IntegrityError into a more specific exception. + + The conversion is based on a mapping of rules to exceptions. + """ + try: + yield + except django_db.IntegrityError as e: + for rule, refined_error in rules.items(): + if rule.is_match(e): + raise refined_error from e + raise + + +class _Rule(abc.ABC): + @abc.abstractmethod + def is_match(self, error: django_db.IntegrityError) -> bool: + ... + + +@dataclasses.dataclass(frozen=True) +class Named(_Rule): + """ + A constraint identified by its name. + """ + + name: str + + def is_match(self, error: django_db.IntegrityError) -> bool: + if not isinstance(error.__cause__, psycopg2.errors.IntegrityError): + return False + + return error.__cause__.diag.constraint_name == self.name + + +@dataclasses.dataclass(frozen=True) +class Unique(_Rule): + """ + A unique constraint defined by a model and a set of fields. + """ + + model: django_db.models.Model + fields: tuple[str] + + _pattern = re.compile(r"Key \((?P.+)\)=\(.*\) already exists.") + + def is_match(self, error: django_db.IntegrityError) -> bool: + if not isinstance(error.__cause__, psycopg2.errors.UniqueViolation): + return False + + match = self._pattern.match(error.__cause__.diag.message_detail) + if match is None: + return False + + return ( + tuple(match.group("fields").split(", ")) == self.fields + and error.__cause__.diag.table_name == self.model._meta.db_table + ) + + +@dataclasses.dataclass(frozen=True) +class PrimaryKey(_Rule): + """ + A unique constraint on the primary key of a model. + """ + + model: django_db.models.Model + + _pattern = re.compile(r"Key \((?P.+)\)=\(.*\) already exists.") + + def is_match(self, error: django_db.IntegrityError) -> bool: + if not isinstance(error.__cause__, psycopg2.errors.UniqueViolation): + return False + + match = self._pattern.match(error.__cause__.diag.message_detail) + if match is None: + return False + + # We assume that the model has a primary key, + # given that we're looking for a primary key constraint. + assert self.model._meta.pk is not None + return ( + tuple(match.group("fields").split(", ")) == (self.model._meta.pk.name,) + and error.__cause__.diag.table_name == self.model._meta.db_table + ) + + +@dataclasses.dataclass(frozen=True) +class NotNull(_Rule): + """ + A not-null constraint on a Model's field. + """ + + model: django_db.models.Model + field: str + + def is_match(self, error: django_db.IntegrityError) -> bool: + if not isinstance(error.__cause__, psycopg2.errors.NotNullViolation): + return False + + return ( + error.__cause__.diag.column_name == self.field + and error.__cause__.diag.table_name == self.model._meta.db_table + ) + + +@dataclasses.dataclass(frozen=True) +class ForeignKey(_Rule): + """ + A foreign key constraint on a Model's field. + """ + + model: django_db.models.Model + field: str + + _detail_pattern = re.compile( + r"Key \((?P.+)\)=\((?P.+)\) is not present in table" + ) + + def is_match(self, error: django_db.IntegrityError) -> bool: + if not isinstance(error.__cause__, psycopg2.errors.ForeignKeyViolation): + return False + + detail_match = self._detail_pattern.match(error.__cause__.diag.message_detail) + if detail_match is None: + return False + + return ( + detail_match.group("field") == self.field + and error.__cause__.diag.table_name == self.model._meta.db_table + ) diff --git a/tests/django_integrity/test_conversion.py b/tests/django_integrity/test_conversion.py new file mode 100644 index 0000000..1fcd8b3 --- /dev/null +++ b/tests/django_integrity/test_conversion.py @@ -0,0 +1,230 @@ +import pytest +from django import db as django_db + +from django_integrity import constraints, conversion +from tests.example_app import models as test_models + + +class SimpleError(Exception): + pass + + +class TestRefineIntegrityError: + def test_no_rules(self): + # It is legal to call the context manager without any rules. + with conversion.refine_integrity_error(rules={}): + pass + + +@pytest.mark.django_db +class TestNamedConstraint: + def test_error_refined(self): + # Create a unique instance so that we can violate the constraint later. + test_models.UniqueModel.objects.create(unique_field=42) + + rules = {conversion.Named(name="unique_model_unique_field_key"): SimpleError} + + # The original error should be transformed into our expected error. + with pytest.raises(SimpleError): + with conversion.refine_integrity_error(rules): + test_models.UniqueModel.objects.create(unique_field=42) + + def test_rules_mismatch(self): + # Create a unique instance so that we can violate the constraint later. + test_models.UniqueModel.objects.create(unique_field=42) + + # No constraints match the error: + rules = {conversion.Named(name="nonexistent_constraint"): SimpleError} + + # The original error should be raised. + with pytest.raises(django_db.IntegrityError): + with conversion.refine_integrity_error(rules): + test_models.UniqueModel.objects.create(unique_field=42) + + +@pytest.mark.django_db +class TestUnique: + def test_error_refined(self): + # Create a unique instance so that we can violate the constraint later. + test_models.UniqueModel.objects.create(unique_field=42) + + rules = { + conversion.Unique( + model=test_models.UniqueModel, fields=("unique_field",) + ): SimpleError + } + + # The original error should be transformed into our expected error. + with pytest.raises(SimpleError): + with conversion.refine_integrity_error(rules): + test_models.UniqueModel.objects.create(unique_field=42) + + def test_multiple_fields(self): + # Create a unique instance so that we can violate the constraint later. + test_models.UniqueTogetherModel.objects.create(field_1=1, field_2=2) + + rules = { + conversion.Unique( + model=test_models.UniqueTogetherModel, fields=("field_1", "field_2") + ): SimpleError + } + + # The original error should be transformed into our expected error. + with pytest.raises(SimpleError): + with conversion.refine_integrity_error(rules): + test_models.UniqueTogetherModel.objects.create(field_1=1, field_2=2) + + @pytest.mark.parametrize( + "Model, field", + ( + # Wrong model, despite matching field name. + ( + test_models.AlternativeUniqueModel, + "unique_field", + ), + # Wrong field, despite matching model. + ( + test_models.UniqueModel, + "id", + ), + ), + ids=("wrong_model", "wrong_field"), + ) + def test_rules_mismatch(self, Model: conversion.Unique, field: str): + # A rule that matches a similar looking, but different, unique constraint. + # Create a unique instance so that we can violate the constraint later. + test_models.UniqueModel.objects.create(unique_field=42) + + rules = {conversion.Unique(model=Model, fields=(field,)): SimpleError} + + # We shouldn't transform the error, because it didn't match the rule. + with pytest.raises(django_db.IntegrityError): + with conversion.refine_integrity_error(rules): + test_models.UniqueModel.objects.create(unique_field=42) + + +@pytest.mark.django_db +class TestPrimaryKey: + @pytest.mark.parametrize( + "ModelClass", + ( + test_models.PrimaryKeyModel, + test_models.AlternativePrimaryKeyModel, + ), + ) + def test_error_refined(self, ModelClass): + """ + The primary key of a model is extracted from the model. + + This test internally refers to the models primary key using "pk". + "pk" is Django magic that refers to the primary key of the model. + On PrimaryKeyModel, the primary key is "id". + On AlternativePrimaryKeyModel, the primary key is "identity". + """ + # Create a unique instance so that we can violate the constraint later. + existing_primary_key = ModelClass.objects.create().pk + + rules = {conversion.PrimaryKey(model=ModelClass): SimpleError} + + # The original error should be transformed into our expected error. + with pytest.raises(SimpleError): + with conversion.refine_integrity_error(rules): + ModelClass.objects.create(pk=existing_primary_key) + + def test_rules_mismatch(self): + # Create a unique instance so that we can violate the constraint later. + existing_primary_key = test_models.PrimaryKeyModel.objects.create().pk + + # A similar rule, but for a different model with the same field name.. + rules = {conversion.PrimaryKey(model=test_models.UniqueModel): SimpleError} + + # The original error should be raised. + with pytest.raises(django_db.IntegrityError): + with conversion.refine_integrity_error(rules): + test_models.PrimaryKeyModel.objects.create(pk=existing_primary_key) + + +@pytest.mark.django_db +class TestNotNull: + def test_error_refined(self): + rules = { + conversion.NotNull( + model=test_models.UniqueModel, field="unique_field" + ): SimpleError + } + + # The original error should be transformed into our expected error. + with pytest.raises(SimpleError): + with conversion.refine_integrity_error(rules): + test_models.UniqueModel.objects.create(unique_field=None) + + def test_model_mismatch(self): + # Same field, but different model. + rules = { + conversion.NotNull( + model=test_models.AlternativeUniqueModel, field="unique_field" + ): SimpleError + } + + with pytest.raises(django_db.IntegrityError): + with conversion.refine_integrity_error(rules): + test_models.UniqueModel.objects.create(unique_field=None) + + def test_field_mismatch(self): + # Same model, but different field. + rules = { + conversion.NotNull( + model=test_models.AlternativeUniqueModel, field="unique_field_2" + ): SimpleError + } + + # The original error should be raised. + with pytest.raises(django_db.IntegrityError): + with conversion.refine_integrity_error(rules): + test_models.AlternativeUniqueModel.objects.create( + unique_field=None, + unique_field_2=42, + ) + + +@pytest.mark.django_db +class TestForeignKey: + def test_error_refined(self): + rules = { + conversion.ForeignKey( + model=test_models.ForeignKeyModel, field="related_id" + ): SimpleError + } + constraints.set_all_immediate(using="default") + + # The original error should be transformed into our expected error. + with pytest.raises(SimpleError): + with conversion.refine_integrity_error(rules): + # Create a ForeignKeyModel with a related_id that doesn't exist. + test_models.ForeignKeyModel.objects.create(related_id=42) + + def test_source_mismatch(self): + # The field name matches, but the source model is different. + rules = { + conversion.ForeignKey( + model=test_models.ForeignKeyModel2, field="related_id" + ): SimpleError + } + constraints.set_all_immediate(using="default") + + with pytest.raises(django_db.IntegrityError): + with conversion.refine_integrity_error(rules): + test_models.ForeignKeyModel.objects.create(related_id=42) + + def test_field_mismatch(self): + # The source model matches, but the field name is different. + rules = { + conversion.ForeignKey( + model=test_models.ForeignKeyModel3, field="related_2_id" + ): SimpleError + } + constraints.set_all_immediate(using="default") + + with pytest.raises(django_db.IntegrityError): + with conversion.refine_integrity_error(rules): + test_models.ForeignKeyModel3.objects.create(related_1_id=42) diff --git a/tests/example_app/models.py b/tests/example_app/models.py index a66fa1a..c4d379b 100644 --- a/tests/example_app/models.py +++ b/tests/example_app/models.py @@ -2,13 +2,35 @@ class PrimaryKeyModel(models.Model): - pass + # Serialize must not be on because our tests try to create instances with clashing IDs. + id = models.BigAutoField(primary_key=True, serialize=False) + + +class AlternativePrimaryKeyModel(models.Model): + # Serialize must not be on because our tests try to create instances with clashing IDs. + identity = models.BigAutoField(primary_key=True, serialize=False) class ForeignKeyModel(models.Model): related = models.ForeignKey(PrimaryKeyModel, on_delete=models.CASCADE) +class ForeignKeyModel2(models.Model): + related = models.ForeignKey(PrimaryKeyModel, on_delete=models.CASCADE) + + +class ForeignKeyModel3(models.Model): + related_1 = models.ForeignKey( + AlternativePrimaryKeyModel, on_delete=models.CASCADE, related_name="+" + ) + related_2 = models.ForeignKey( + AlternativePrimaryKeyModel, + on_delete=models.CASCADE, + related_name="+", + null=True, + ) + + class UniqueModel(models.Model): unique_field = models.IntegerField() @@ -20,3 +42,21 @@ class Meta: deferrable=models.Deferrable.IMMEDIATE, ), ) + + +class AlternativeUniqueModel(models.Model): + unique_field = models.IntegerField(unique=True) + unique_field_2 = models.IntegerField(unique=True) + + +class UniqueTogetherModel(models.Model): + field_1 = models.IntegerField() + field_2 = models.IntegerField() + + class Meta: + constraints = ( + models.UniqueConstraint( + fields=["field_1", "field_2"], + name="unique_together_model_field_1_field_2_key", + ), + )