-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add helper to refine IntegrityErrors
When we normally catch an IntegrityError from Django, it can be hard to determine exactly what the cause of the error was. Sometimes we assume that the error was of one type, where the error actually had another cause. This utility will make it easy for us to tell the difference between different types of IntegrityError, allowing us to be more specific about which errors we handle. Co-authored-by: Samuel Searles-Bryant <[email protected]>
- Loading branch information
Showing
3 changed files
with
416 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
from __future__ import annotations | ||
|
||
import abc | ||
import contextlib | ||
import dataclasses | ||
import re | ||
from collections.abc import Iterator, Mapping | ||
|
||
import psycopg2 | ||
from django import db as django_db | ||
|
||
|
||
@contextlib.contextmanager | ||
def refine_integrity_error(rules: Mapping[_Rule, Exception]) -> Iterator[None]: | ||
""" | ||
Convert a generic IntegrityError into a more specific exception. | ||
The conversion is based on a mapping of rules to exceptions. | ||
""" | ||
try: | ||
yield | ||
except django_db.IntegrityError as e: | ||
for rule, refined_error in rules.items(): | ||
if rule.is_match(e): | ||
raise refined_error from e | ||
raise | ||
|
||
|
||
class _Rule(abc.ABC): | ||
@abc.abstractmethod | ||
def is_match(self, error: django_db.IntegrityError) -> bool: | ||
... | ||
|
||
|
||
@dataclasses.dataclass(frozen=True) | ||
class Named(_Rule): | ||
""" | ||
A constraint identified by its name. | ||
""" | ||
|
||
name: str | ||
|
||
def is_match(self, error: django_db.IntegrityError) -> bool: | ||
if not isinstance(error.__cause__, psycopg2.errors.IntegrityError): | ||
return False | ||
|
||
return error.__cause__.diag.constraint_name == self.name | ||
|
||
|
||
@dataclasses.dataclass(frozen=True) | ||
class Unique(_Rule): | ||
""" | ||
A unique constraint defined by a model and a set of fields. | ||
""" | ||
|
||
model: django_db.models.Model | ||
fields: tuple[str] | ||
|
||
_pattern = re.compile(r"Key \((?P<fields>.+)\)=\(.*\) already exists.") | ||
|
||
def is_match(self, error: django_db.IntegrityError) -> bool: | ||
if not isinstance(error.__cause__, psycopg2.errors.UniqueViolation): | ||
return False | ||
|
||
match = self._pattern.match(error.__cause__.diag.message_detail) | ||
if match is None: | ||
return False | ||
|
||
return ( | ||
tuple(match.group("fields").split(", ")) == self.fields | ||
and error.__cause__.diag.table_name == self.model._meta.db_table | ||
) | ||
|
||
|
||
@dataclasses.dataclass(frozen=True) | ||
class PrimaryKey(_Rule): | ||
""" | ||
A unique constraint on the primary key of a model. | ||
""" | ||
|
||
model: django_db.models.Model | ||
|
||
_pattern = re.compile(r"Key \((?P<fields>.+)\)=\(.*\) already exists.") | ||
|
||
def is_match(self, error: django_db.IntegrityError) -> bool: | ||
if not isinstance(error.__cause__, psycopg2.errors.UniqueViolation): | ||
return False | ||
|
||
match = self._pattern.match(error.__cause__.diag.message_detail) | ||
if match is None: | ||
return False | ||
|
||
# We assume that the model has a primary key, | ||
# given that we're looking for a primary key constraint. | ||
assert self.model._meta.pk is not None | ||
return ( | ||
tuple(match.group("fields").split(", ")) == (self.model._meta.pk.name,) | ||
and error.__cause__.diag.table_name == self.model._meta.db_table | ||
) | ||
|
||
|
||
@dataclasses.dataclass(frozen=True) | ||
class NotNull(_Rule): | ||
""" | ||
A not-null constraint on a Model's field. | ||
""" | ||
|
||
model: django_db.models.Model | ||
field: str | ||
|
||
def is_match(self, error: django_db.IntegrityError) -> bool: | ||
if not isinstance(error.__cause__, psycopg2.errors.NotNullViolation): | ||
return False | ||
|
||
return ( | ||
error.__cause__.diag.column_name == self.field | ||
and error.__cause__.diag.table_name == self.model._meta.db_table | ||
) | ||
|
||
|
||
@dataclasses.dataclass(frozen=True) | ||
class ForeignKey(_Rule): | ||
""" | ||
A foreign key constraint on a Model's field. | ||
""" | ||
|
||
model: django_db.models.Model | ||
field: str | ||
|
||
_detail_pattern = re.compile( | ||
r"Key \((?P<field>.+)\)=\((?P<value>.+)\) is not present in table" | ||
) | ||
|
||
def is_match(self, error: django_db.IntegrityError) -> bool: | ||
if not isinstance(error.__cause__, psycopg2.errors.ForeignKeyViolation): | ||
return False | ||
|
||
detail_match = self._detail_pattern.match(error.__cause__.diag.message_detail) | ||
if detail_match is None: | ||
return False | ||
|
||
return ( | ||
detail_match.group("field") == self.field | ||
and error.__cause__.diag.table_name == self.model._meta.db_table | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,230 @@ | ||
import pytest | ||
from django import db as django_db | ||
|
||
from django_integrity import constraints, conversion | ||
from tests.example_app import models as test_models | ||
|
||
|
||
class SimpleError(Exception): | ||
pass | ||
|
||
|
||
class TestRefineIntegrityError: | ||
def test_no_rules(self): | ||
# It is legal to call the context manager without any rules. | ||
with conversion.refine_integrity_error(rules={}): | ||
pass | ||
|
||
|
||
@pytest.mark.django_db | ||
class TestNamedConstraint: | ||
def test_error_refined(self): | ||
# Create a unique instance so that we can violate the constraint later. | ||
test_models.UniqueModel.objects.create(unique_field=42) | ||
|
||
rules = {conversion.Named(name="unique_model_unique_field_key"): SimpleError} | ||
|
||
# The original error should be transformed into our expected error. | ||
with pytest.raises(SimpleError): | ||
with conversion.refine_integrity_error(rules): | ||
test_models.UniqueModel.objects.create(unique_field=42) | ||
|
||
def test_rules_mismatch(self): | ||
# Create a unique instance so that we can violate the constraint later. | ||
test_models.UniqueModel.objects.create(unique_field=42) | ||
|
||
# No constraints match the error: | ||
rules = {conversion.Named(name="nonexistent_constraint"): SimpleError} | ||
|
||
# The original error should be raised. | ||
with pytest.raises(django_db.IntegrityError): | ||
with conversion.refine_integrity_error(rules): | ||
test_models.UniqueModel.objects.create(unique_field=42) | ||
|
||
|
||
@pytest.mark.django_db | ||
class TestUnique: | ||
def test_error_refined(self): | ||
# Create a unique instance so that we can violate the constraint later. | ||
test_models.UniqueModel.objects.create(unique_field=42) | ||
|
||
rules = { | ||
conversion.Unique( | ||
model=test_models.UniqueModel, fields=("unique_field",) | ||
): SimpleError | ||
} | ||
|
||
# The original error should be transformed into our expected error. | ||
with pytest.raises(SimpleError): | ||
with conversion.refine_integrity_error(rules): | ||
test_models.UniqueModel.objects.create(unique_field=42) | ||
|
||
def test_multiple_fields(self): | ||
# Create a unique instance so that we can violate the constraint later. | ||
test_models.UniqueTogetherModel.objects.create(field_1=1, field_2=2) | ||
|
||
rules = { | ||
conversion.Unique( | ||
model=test_models.UniqueTogetherModel, fields=("field_1", "field_2") | ||
): SimpleError | ||
} | ||
|
||
# The original error should be transformed into our expected error. | ||
with pytest.raises(SimpleError): | ||
with conversion.refine_integrity_error(rules): | ||
test_models.UniqueTogetherModel.objects.create(field_1=1, field_2=2) | ||
|
||
@pytest.mark.parametrize( | ||
"Model, field", | ||
( | ||
# Wrong model, despite matching field name. | ||
( | ||
test_models.AlternativeUniqueModel, | ||
"unique_field", | ||
), | ||
# Wrong field, despite matching model. | ||
( | ||
test_models.UniqueModel, | ||
"id", | ||
), | ||
), | ||
ids=("wrong_model", "wrong_field"), | ||
) | ||
def test_rules_mismatch(self, Model: conversion.Unique, field: str): | ||
# A rule that matches a similar looking, but different, unique constraint. | ||
# Create a unique instance so that we can violate the constraint later. | ||
test_models.UniqueModel.objects.create(unique_field=42) | ||
|
||
rules = {conversion.Unique(model=Model, fields=(field,)): SimpleError} | ||
|
||
# We shouldn't transform the error, because it didn't match the rule. | ||
with pytest.raises(django_db.IntegrityError): | ||
with conversion.refine_integrity_error(rules): | ||
test_models.UniqueModel.objects.create(unique_field=42) | ||
|
||
|
||
@pytest.mark.django_db | ||
class TestPrimaryKey: | ||
@pytest.mark.parametrize( | ||
"ModelClass", | ||
( | ||
test_models.PrimaryKeyModel, | ||
test_models.AlternativePrimaryKeyModel, | ||
), | ||
) | ||
def test_error_refined(self, ModelClass): | ||
""" | ||
The primary key of a model is extracted from the model. | ||
This test internally refers to the models primary key using "pk". | ||
"pk" is Django magic that refers to the primary key of the model. | ||
On PrimaryKeyModel, the primary key is "id". | ||
On AlternativePrimaryKeyModel, the primary key is "identity". | ||
""" | ||
# Create a unique instance so that we can violate the constraint later. | ||
existing_primary_key = ModelClass.objects.create().pk | ||
|
||
rules = {conversion.PrimaryKey(model=ModelClass): SimpleError} | ||
|
||
# The original error should be transformed into our expected error. | ||
with pytest.raises(SimpleError): | ||
with conversion.refine_integrity_error(rules): | ||
ModelClass.objects.create(pk=existing_primary_key) | ||
|
||
def test_rules_mismatch(self): | ||
# Create a unique instance so that we can violate the constraint later. | ||
existing_primary_key = test_models.PrimaryKeyModel.objects.create().pk | ||
|
||
# A similar rule, but for a different model with the same field name.. | ||
rules = {conversion.PrimaryKey(model=test_models.UniqueModel): SimpleError} | ||
|
||
# The original error should be raised. | ||
with pytest.raises(django_db.IntegrityError): | ||
with conversion.refine_integrity_error(rules): | ||
test_models.PrimaryKeyModel.objects.create(pk=existing_primary_key) | ||
|
||
|
||
@pytest.mark.django_db | ||
class TestNotNull: | ||
def test_error_refined(self): | ||
rules = { | ||
conversion.NotNull( | ||
model=test_models.UniqueModel, field="unique_field" | ||
): SimpleError | ||
} | ||
|
||
# The original error should be transformed into our expected error. | ||
with pytest.raises(SimpleError): | ||
with conversion.refine_integrity_error(rules): | ||
test_models.UniqueModel.objects.create(unique_field=None) | ||
|
||
def test_model_mismatch(self): | ||
# Same field, but different model. | ||
rules = { | ||
conversion.NotNull( | ||
model=test_models.AlternativeUniqueModel, field="unique_field" | ||
): SimpleError | ||
} | ||
|
||
with pytest.raises(django_db.IntegrityError): | ||
with conversion.refine_integrity_error(rules): | ||
test_models.UniqueModel.objects.create(unique_field=None) | ||
|
||
def test_field_mismatch(self): | ||
# Same model, but different field. | ||
rules = { | ||
conversion.NotNull( | ||
model=test_models.AlternativeUniqueModel, field="unique_field_2" | ||
): SimpleError | ||
} | ||
|
||
# The original error should be raised. | ||
with pytest.raises(django_db.IntegrityError): | ||
with conversion.refine_integrity_error(rules): | ||
test_models.AlternativeUniqueModel.objects.create( | ||
unique_field=None, | ||
unique_field_2=42, | ||
) | ||
|
||
|
||
@pytest.mark.django_db | ||
class TestForeignKey: | ||
def test_error_refined(self): | ||
rules = { | ||
conversion.ForeignKey( | ||
model=test_models.ForeignKeyModel, field="related_id" | ||
): SimpleError | ||
} | ||
constraints.set_all_immediate(using="default") | ||
|
||
# The original error should be transformed into our expected error. | ||
with pytest.raises(SimpleError): | ||
with conversion.refine_integrity_error(rules): | ||
# Create a ForeignKeyModel with a related_id that doesn't exist. | ||
test_models.ForeignKeyModel.objects.create(related_id=42) | ||
|
||
def test_source_mismatch(self): | ||
# The field name matches, but the source model is different. | ||
rules = { | ||
conversion.ForeignKey( | ||
model=test_models.ForeignKeyModel2, field="related_id" | ||
): SimpleError | ||
} | ||
constraints.set_all_immediate(using="default") | ||
|
||
with pytest.raises(django_db.IntegrityError): | ||
with conversion.refine_integrity_error(rules): | ||
test_models.ForeignKeyModel.objects.create(related_id=42) | ||
|
||
def test_field_mismatch(self): | ||
# The source model matches, but the field name is different. | ||
rules = { | ||
conversion.ForeignKey( | ||
model=test_models.ForeignKeyModel3, field="related_2_id" | ||
): SimpleError | ||
} | ||
constraints.set_all_immediate(using="default") | ||
|
||
with pytest.raises(django_db.IntegrityError): | ||
with conversion.refine_integrity_error(rules): | ||
test_models.ForeignKeyModel3.objects.create(related_1_id=42) |
Oops, something went wrong.