Skip to content

Commit

Permalink
Add helper to refine IntegrityErrors
Browse files Browse the repository at this point in the history
When we normally catch an IntegrityError from Django, it can be hard to
determine exactly what the cause of the error was. Sometimes we assume
that the error was of one type, where the error actually had
another cause.

This utility will make it easy for us to tell the difference between
different types of IntegrityError, allowing us to be more specific about
which errors we handle.

Co-authored-by: Samuel Searles-Bryant <[email protected]>
  • Loading branch information
meshy and samueljsb committed Feb 13, 2024
1 parent 5048a43 commit 80677a9
Show file tree
Hide file tree
Showing 3 changed files with 416 additions and 1 deletion.
145 changes: 145 additions & 0 deletions src/django_integrity/conversion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
from __future__ import annotations

import abc
import contextlib
import dataclasses
import re
from collections.abc import Iterator, Mapping

import psycopg2
from django import db as django_db


@contextlib.contextmanager
def refine_integrity_error(rules: Mapping[_Rule, Exception]) -> Iterator[None]:
"""
Convert a generic IntegrityError into a more specific exception.
The conversion is based on a mapping of rules to exceptions.
"""
try:
yield
except django_db.IntegrityError as e:
for rule, refined_error in rules.items():
if rule.is_match(e):
raise refined_error from e
raise


class _Rule(abc.ABC):
@abc.abstractmethod
def is_match(self, error: django_db.IntegrityError) -> bool:
...


@dataclasses.dataclass(frozen=True)
class Named(_Rule):
"""
A constraint identified by its name.
"""

name: str

def is_match(self, error: django_db.IntegrityError) -> bool:
if not isinstance(error.__cause__, psycopg2.errors.IntegrityError):
return False

return error.__cause__.diag.constraint_name == self.name


@dataclasses.dataclass(frozen=True)
class Unique(_Rule):
"""
A unique constraint defined by a model and a set of fields.
"""

model: django_db.models.Model
fields: tuple[str]

_pattern = re.compile(r"Key \((?P<fields>.+)\)=\(.*\) already exists.")

def is_match(self, error: django_db.IntegrityError) -> bool:
if not isinstance(error.__cause__, psycopg2.errors.UniqueViolation):
return False

match = self._pattern.match(error.__cause__.diag.message_detail)
if match is None:
return False

return (
tuple(match.group("fields").split(", ")) == self.fields
and error.__cause__.diag.table_name == self.model._meta.db_table
)


@dataclasses.dataclass(frozen=True)
class PrimaryKey(_Rule):
"""
A unique constraint on the primary key of a model.
"""

model: django_db.models.Model

_pattern = re.compile(r"Key \((?P<fields>.+)\)=\(.*\) already exists.")

def is_match(self, error: django_db.IntegrityError) -> bool:
if not isinstance(error.__cause__, psycopg2.errors.UniqueViolation):
return False

match = self._pattern.match(error.__cause__.diag.message_detail)
if match is None:
return False

# We assume that the model has a primary key,
# given that we're looking for a primary key constraint.
assert self.model._meta.pk is not None
return (
tuple(match.group("fields").split(", ")) == (self.model._meta.pk.name,)
and error.__cause__.diag.table_name == self.model._meta.db_table
)


@dataclasses.dataclass(frozen=True)
class NotNull(_Rule):
"""
A not-null constraint on a Model's field.
"""

model: django_db.models.Model
field: str

def is_match(self, error: django_db.IntegrityError) -> bool:
if not isinstance(error.__cause__, psycopg2.errors.NotNullViolation):
return False

return (
error.__cause__.diag.column_name == self.field
and error.__cause__.diag.table_name == self.model._meta.db_table
)


@dataclasses.dataclass(frozen=True)
class ForeignKey(_Rule):
"""
A foreign key constraint on a Model's field.
"""

model: django_db.models.Model
field: str

_detail_pattern = re.compile(
r"Key \((?P<field>.+)\)=\((?P<value>.+)\) is not present in table"
)

def is_match(self, error: django_db.IntegrityError) -> bool:
if not isinstance(error.__cause__, psycopg2.errors.ForeignKeyViolation):
return False

detail_match = self._detail_pattern.match(error.__cause__.diag.message_detail)
if detail_match is None:
return False

return (
detail_match.group("field") == self.field
and error.__cause__.diag.table_name == self.model._meta.db_table
)
230 changes: 230 additions & 0 deletions tests/django_integrity/test_conversion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
import pytest
from django import db as django_db

from django_integrity import constraints, conversion
from tests.example_app import models as test_models


class SimpleError(Exception):
pass


class TestRefineIntegrityError:
def test_no_rules(self):
# It is legal to call the context manager without any rules.
with conversion.refine_integrity_error(rules={}):
pass


@pytest.mark.django_db
class TestNamedConstraint:
def test_error_refined(self):
# Create a unique instance so that we can violate the constraint later.
test_models.UniqueModel.objects.create(unique_field=42)

rules = {conversion.Named(name="unique_model_unique_field_key"): SimpleError}

# The original error should be transformed into our expected error.
with pytest.raises(SimpleError):
with conversion.refine_integrity_error(rules):
test_models.UniqueModel.objects.create(unique_field=42)

def test_rules_mismatch(self):
# Create a unique instance so that we can violate the constraint later.
test_models.UniqueModel.objects.create(unique_field=42)

# No constraints match the error:
rules = {conversion.Named(name="nonexistent_constraint"): SimpleError}

# The original error should be raised.
with pytest.raises(django_db.IntegrityError):
with conversion.refine_integrity_error(rules):
test_models.UniqueModel.objects.create(unique_field=42)


@pytest.mark.django_db
class TestUnique:
def test_error_refined(self):
# Create a unique instance so that we can violate the constraint later.
test_models.UniqueModel.objects.create(unique_field=42)

rules = {
conversion.Unique(
model=test_models.UniqueModel, fields=("unique_field",)
): SimpleError
}

# The original error should be transformed into our expected error.
with pytest.raises(SimpleError):
with conversion.refine_integrity_error(rules):
test_models.UniqueModel.objects.create(unique_field=42)

def test_multiple_fields(self):
# Create a unique instance so that we can violate the constraint later.
test_models.UniqueTogetherModel.objects.create(field_1=1, field_2=2)

rules = {
conversion.Unique(
model=test_models.UniqueTogetherModel, fields=("field_1", "field_2")
): SimpleError
}

# The original error should be transformed into our expected error.
with pytest.raises(SimpleError):
with conversion.refine_integrity_error(rules):
test_models.UniqueTogetherModel.objects.create(field_1=1, field_2=2)

@pytest.mark.parametrize(
"Model, field",
(
# Wrong model, despite matching field name.
(
test_models.AlternativeUniqueModel,
"unique_field",
),
# Wrong field, despite matching model.
(
test_models.UniqueModel,
"id",
),
),
ids=("wrong_model", "wrong_field"),
)
def test_rules_mismatch(self, Model: conversion.Unique, field: str):
# A rule that matches a similar looking, but different, unique constraint.
# Create a unique instance so that we can violate the constraint later.
test_models.UniqueModel.objects.create(unique_field=42)

rules = {conversion.Unique(model=Model, fields=(field,)): SimpleError}

# We shouldn't transform the error, because it didn't match the rule.
with pytest.raises(django_db.IntegrityError):
with conversion.refine_integrity_error(rules):
test_models.UniqueModel.objects.create(unique_field=42)


@pytest.mark.django_db
class TestPrimaryKey:
@pytest.mark.parametrize(
"ModelClass",
(
test_models.PrimaryKeyModel,
test_models.AlternativePrimaryKeyModel,
),
)
def test_error_refined(self, ModelClass):
"""
The primary key of a model is extracted from the model.
This test internally refers to the models primary key using "pk".
"pk" is Django magic that refers to the primary key of the model.
On PrimaryKeyModel, the primary key is "id".
On AlternativePrimaryKeyModel, the primary key is "identity".
"""
# Create a unique instance so that we can violate the constraint later.
existing_primary_key = ModelClass.objects.create().pk

rules = {conversion.PrimaryKey(model=ModelClass): SimpleError}

# The original error should be transformed into our expected error.
with pytest.raises(SimpleError):
with conversion.refine_integrity_error(rules):
ModelClass.objects.create(pk=existing_primary_key)

def test_rules_mismatch(self):
# Create a unique instance so that we can violate the constraint later.
existing_primary_key = test_models.PrimaryKeyModel.objects.create().pk

# A similar rule, but for a different model with the same field name..
rules = {conversion.PrimaryKey(model=test_models.UniqueModel): SimpleError}

# The original error should be raised.
with pytest.raises(django_db.IntegrityError):
with conversion.refine_integrity_error(rules):
test_models.PrimaryKeyModel.objects.create(pk=existing_primary_key)


@pytest.mark.django_db
class TestNotNull:
def test_error_refined(self):
rules = {
conversion.NotNull(
model=test_models.UniqueModel, field="unique_field"
): SimpleError
}

# The original error should be transformed into our expected error.
with pytest.raises(SimpleError):
with conversion.refine_integrity_error(rules):
test_models.UniqueModel.objects.create(unique_field=None)

def test_model_mismatch(self):
# Same field, but different model.
rules = {
conversion.NotNull(
model=test_models.AlternativeUniqueModel, field="unique_field"
): SimpleError
}

with pytest.raises(django_db.IntegrityError):
with conversion.refine_integrity_error(rules):
test_models.UniqueModel.objects.create(unique_field=None)

def test_field_mismatch(self):
# Same model, but different field.
rules = {
conversion.NotNull(
model=test_models.AlternativeUniqueModel, field="unique_field_2"
): SimpleError
}

# The original error should be raised.
with pytest.raises(django_db.IntegrityError):
with conversion.refine_integrity_error(rules):
test_models.AlternativeUniqueModel.objects.create(
unique_field=None,
unique_field_2=42,
)


@pytest.mark.django_db
class TestForeignKey:
def test_error_refined(self):
rules = {
conversion.ForeignKey(
model=test_models.ForeignKeyModel, field="related_id"
): SimpleError
}
constraints.set_all_immediate(using="default")

# The original error should be transformed into our expected error.
with pytest.raises(SimpleError):
with conversion.refine_integrity_error(rules):
# Create a ForeignKeyModel with a related_id that doesn't exist.
test_models.ForeignKeyModel.objects.create(related_id=42)

def test_source_mismatch(self):
# The field name matches, but the source model is different.
rules = {
conversion.ForeignKey(
model=test_models.ForeignKeyModel2, field="related_id"
): SimpleError
}
constraints.set_all_immediate(using="default")

with pytest.raises(django_db.IntegrityError):
with conversion.refine_integrity_error(rules):
test_models.ForeignKeyModel.objects.create(related_id=42)

def test_field_mismatch(self):
# The source model matches, but the field name is different.
rules = {
conversion.ForeignKey(
model=test_models.ForeignKeyModel3, field="related_2_id"
): SimpleError
}
constraints.set_all_immediate(using="default")

with pytest.raises(django_db.IntegrityError):
with conversion.refine_integrity_error(rules):
test_models.ForeignKeyModel3.objects.create(related_1_id=42)
Loading

0 comments on commit 80677a9

Please sign in to comment.