Skip to content

Commit

Permalink
Add helper methods for deferrable constraints
Browse files Browse the repository at this point in the history
These new tools give us control over constraints within the scope of a
transaction.

It can be helpful to ensure constraints are immediately checked so that
we can catch violations in the flow of code, rather than when we're
committing a transaction.
  • Loading branch information
meshy committed Feb 12, 2024
1 parent 5ade696 commit b4e07bc
Show file tree
Hide file tree
Showing 3 changed files with 286 additions and 0 deletions.
126 changes: 126 additions & 0 deletions src/django_integrity/constraints.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,130 @@
import contextlib
from collections.abc import Iterator, Sequence

from django import db as django_db
from psycopg2 import sql


# Note [Deferrable constraints]
# -----------------------------
# Only some types of PostgreSQL constraint can be DEFERRED, and
# they may be deferred if they are created with the DEFERRABLE option.
#
# These types of constraints can be DEFERRABLE:
# - UNIQUE
# - PRIMARY KEY
# - REFERENCES (foreign key)
# - EXCLUDE
#
# These types of constraints can never be DEFERRABLE:
# - CHECK
# - NOT NULL
#
# By default, Django makes foreign key constraints DEFERRABLE INITIALLY DEFERRED,
# so they are checked at the end of the transaction,
# rather than when the statement is executed.
#
# All other constraints are IMMEDIATE (and not DEFERRABLE) by default.
# This can be changed by passing the `deferrable` argument to the constraint.
#
# Further reading:
# - https://www.postgresql.org/docs/current/sql-set-constraints.html
# - https://www.postgresql.org/docs/current/sql-createtable.html
# - https://docs.djangoproject.com/en/5.0/ref/models/constraints/#deferrable


@contextlib.contextmanager
def immediate(names: Sequence[str], *, using: str) -> Iterator[None]:
"""
Temporarily set named DEFERRABLE constraints to IMMEDIATE.
This is useful for catching constraint violations as soon as they occur,
rather than at the end of the transaction.
This is especially useful for foreign key constraints in Django,
which are DEFERRED by default.
We presume that any provided constraints were previously DEFERRED,
and we restore them to that state after the context manager exits.
To be sure that the constraints are restored to DEFERRED
even if an exception is raised, we use a savepoint.
This could be expensive if used in a loop because on every iteration we would
create and close (or roll back) a savepoint, and set and unset the constraint state.
# See Note [Deferrable constraints]
"""
set_immediate(names, using=using)
try:
with django_db.transaction.atomic(using=using):
yield
finally:
set_deferred(names, using=using)


def set_all_immediate(*, using: str) -> None:
"""
Set all constraints to IMMEDIATE for the remainder of the transaction.
# See Note [Deferrable constraints]
"""
if django_db.transaction.get_autocommit(using):
raise NotInTransaction

with django_db.connections[using].cursor() as cursor:
cursor.execute("SET CONSTRAINTS ALL IMMEDIATE")


def set_immediate(names: Sequence[str], *, using: str) -> None:
"""
Set particular constraints to IMMEDIATE for the remainder of the transaction.
# See Note [Deferrable constraints]
"""
if django_db.transaction.get_autocommit(using):
raise NotInTransaction

if not names:
return

query = sql.SQL("SET CONSTRAINTS {names} IMMEDIATE").format(
names=sql.SQL(", ").join(sql.Identifier(name) for name in names)
)

with django_db.connections[using].cursor() as cursor:
cursor.execute(query)


def set_deferred(names: Sequence[str], *, using: str) -> None:
"""
Set particular constraints to DEFERRED for the remainder of the transaction.
# See Note [Deferrable constraints]
"""
if django_db.transaction.get_autocommit(using):
raise NotInTransaction

if not names:
return

query = sql.SQL("SET CONSTRAINTS {names} DEFERRED").format(
names=sql.SQL(", ").join(sql.Identifier(name) for name in names)
)

with django_db.connections[using].cursor() as cursor:
cursor.execute(query)


class NotInTransaction(Exception):
"""
Raised when we try to change the state of constraints outside of a transaction.
It doesn't make sense to change the state of constraints outside of a transaction,
because the change of state would only last for the remainder of the transaction.
See https://www.postgresql.org/docs/current/sql-set-constraints.html
"""


def foreign_key_constraint_name(
Expand Down
147 changes: 147 additions & 0 deletions tests/django_integrity/test_constraints.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from django import db as django_db
from django.core import exceptions

from django_integrity import constraints
Expand Down Expand Up @@ -36,3 +37,149 @@ def test_wrong_field_name(self) -> None:
field_name="does_not_exist",
using="default",
)


class TestSetAllImmediate:
@pytest.mark.django_db
def test_all_constraints_set(self):
constraints.set_all_immediate(using="default")

with pytest.raises(django_db.IntegrityError):
# The ForeignKey constraint should be enforced immediately.
test_models.ForeignKeyModel.objects.create(related_id=42)

# lint-ignore NoTransactionLessIntegrationTests
@pytest.mark.django_db(transaction=True)
def test_not_in_transaction(self):
# Fail if we're not in a transaction.
with pytest.raises(constraints.NotInTransaction):
constraints.set_all_immediate(using="default")


class TestSetImmediate:
@pytest.mark.django_db
def test_set(self):
constraint_name = constraints.foreign_key_constraint_name(
model=test_models.ForeignKeyModel,
field_name="related_id",
using="default",
)

constraints.set_immediate(names=(constraint_name,), using="default")

# An error should be raised immediately.
with pytest.raises(django_db.IntegrityError):
test_models.ForeignKeyModel.objects.create(related_id=42)

@pytest.mark.django_db
def test_not_set(self):
# No constraint name is passed, so no constraints should be set to immediate.
constraints.set_immediate(names=(), using="default")

# No error should be raised.
test_models.ForeignKeyModel.objects.create(related_id=42)

# We catch the error here to prevent the test from failing in shutdown.
with pytest.raises(django_db.IntegrityError):
constraints.set_all_immediate(using="default")

# lint-ignore NoTransactionLessIntegrationTests
@pytest.mark.django_db(transaction=True)
def test_not_in_transaction(self):
# Fail if we're not in a transaction.
with pytest.raises(constraints.NotInTransaction):
constraints.set_immediate(names=(), using="default")


class TestSetDeferred:
@pytest.mark.django_db
def test_not_set(self):
test_models.UniqueModel.objects.create(unique_field=42)

# We pass no names, so no constraints should be set to deferred.
constraints.set_deferred(names=(), using="default")

# This constraint defaults to IMMEDIATE,
# so an error should be raised immediately.
with pytest.raises(django_db.IntegrityError):
test_models.UniqueModel.objects.create(unique_field=42)

@pytest.mark.django_db
def test_set(self):
test_models.UniqueModel.objects.create(unique_field=42)

# We defer the constraint...
constraint_name = "unique_model_unique_field_key"
constraints.set_deferred(names=(constraint_name,), using="default")

# ... so no error should be raised.
test_models.UniqueModel.objects.create(unique_field=42)

# We catch the error here to prevent the test from failing in shutdown.
with pytest.raises(django_db.IntegrityError):
constraints.set_all_immediate(using="default")

# lint-ignore NoTransactionLessIntegrationTests
@pytest.mark.django_db(transaction=True)
def test_not_in_transaction(self):
# Fail if we're not in a transaction.
with pytest.raises(constraints.NotInTransaction):
constraints.set_deferred(names=(), using="default")


class TestImmediate:
@pytest.mark.django_db
def test_constraint_not_enforced(self):
"""Constraints are not changed when not explicitly enforced."""
# Call the context manager without any constraint names.
with constraints.immediate((), using="default"):
# Create an instance that violates a deferred constraint.
# No error should be raised.
test_models.ForeignKeyModel.objects.create(related_id=42)

# We catch the error here to prevent the test from failing in shutdown.
with pytest.raises(django_db.IntegrityError):
constraints.set_all_immediate(using="default")

@pytest.mark.django_db
def test_constraint_enforced(self):
"""Constraints are enforced when explicitly enforced."""
constraint_name = constraints.foreign_key_constraint_name(
model=test_models.ForeignKeyModel,
field_name="related_id",
using="default",
)

# An error should be raised immediately.
with pytest.raises(django_db.IntegrityError):
with constraints.immediate((constraint_name,), using="default"):
# Create an instance that violates a deferred constraint.
test_models.ForeignKeyModel.objects.create(related_id=42)

@pytest.mark.django_db
def test_deferral_restored(self):
"""Constraints are restored to DEFERRED after the context manager."""
constraint_name = constraints.foreign_key_constraint_name(
model=test_models.ForeignKeyModel,
field_name="related_id",
using="default",
)

with constraints.immediate((constraint_name,), using="default"):
pass

# Create an instance that violates a deferred constraint.
# No error should be raised, because the constraint should be deferred again.
test_models.ForeignKeyModel.objects.create(related_id=42)

# We catch the error here to prevent the test from failing in shutdown.
with pytest.raises(django_db.IntegrityError):
constraints.set_all_immediate(using="default")

# lint-ignore NoTransactionLessIntegrationTests
@pytest.mark.django_db(transaction=True)
def test_not_in_transaction(self):
# Fail if we're not in a transaction.
with pytest.raises(constraints.NotInTransaction):
with constraints.immediate((), using="default"):
pass
13 changes: 13 additions & 0 deletions tests/example_app/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,16 @@ class PrimaryKeyModel(models.Model):

class ForeignKeyModel(models.Model):
related = models.ForeignKey(PrimaryKeyModel, on_delete=models.CASCADE)


class UniqueModel(models.Model):
unique_field = models.IntegerField()

class Meta:
constraints = (
models.UniqueConstraint(
fields=["unique_field"],
name="unique_model_unique_field_key",
deferrable=models.Deferrable.IMMEDIATE,
),
)

0 comments on commit b4e07bc

Please sign in to comment.