From b4e07bccc0955d607b4ae48c69f2428f21a72a97 Mon Sep 17 00:00:00 2001 From: Charlie Denton Date: Fri, 9 Feb 2024 20:37:33 +0000 Subject: [PATCH] Add helper methods for deferrable constraints These new tools give us control over constraints within the scope of a transaction. It can be helpful to ensure constraints are immediately checked so that we can catch violations in the flow of code, rather than when we're committing a transaction. --- src/django_integrity/constraints.py | 126 ++++++++++++++++++ tests/django_integrity/test_constraints.py | 147 +++++++++++++++++++++ tests/example_app/models.py | 13 ++ 3 files changed, 286 insertions(+) diff --git a/src/django_integrity/constraints.py b/src/django_integrity/constraints.py index 2bd2a44..c8c9320 100644 --- a/src/django_integrity/constraints.py +++ b/src/django_integrity/constraints.py @@ -1,4 +1,130 @@ +import contextlib +from collections.abc import Iterator, Sequence + from django import db as django_db +from psycopg2 import sql + + +# Note [Deferrable constraints] +# ----------------------------- +# Only some types of PostgreSQL constraint can be DEFERRED, and +# they may be deferred if they are created with the DEFERRABLE option. +# +# These types of constraints can be DEFERRABLE: +# - UNIQUE +# - PRIMARY KEY +# - REFERENCES (foreign key) +# - EXCLUDE +# +# These types of constraints can never be DEFERRABLE: +# - CHECK +# - NOT NULL +# +# By default, Django makes foreign key constraints DEFERRABLE INITIALLY DEFERRED, +# so they are checked at the end of the transaction, +# rather than when the statement is executed. +# +# All other constraints are IMMEDIATE (and not DEFERRABLE) by default. +# This can be changed by passing the `deferrable` argument to the constraint. +# +# Further reading: +# - https://www.postgresql.org/docs/current/sql-set-constraints.html +# - https://www.postgresql.org/docs/current/sql-createtable.html +# - https://docs.djangoproject.com/en/5.0/ref/models/constraints/#deferrable + + +@contextlib.contextmanager +def immediate(names: Sequence[str], *, using: str) -> Iterator[None]: + """ + Temporarily set named DEFERRABLE constraints to IMMEDIATE. + + This is useful for catching constraint violations as soon as they occur, + rather than at the end of the transaction. + + This is especially useful for foreign key constraints in Django, + which are DEFERRED by default. + + We presume that any provided constraints were previously DEFERRED, + and we restore them to that state after the context manager exits. + + To be sure that the constraints are restored to DEFERRED + even if an exception is raised, we use a savepoint. + + This could be expensive if used in a loop because on every iteration we would + create and close (or roll back) a savepoint, and set and unset the constraint state. + + # See Note [Deferrable constraints] + """ + set_immediate(names, using=using) + try: + with django_db.transaction.atomic(using=using): + yield + finally: + set_deferred(names, using=using) + + +def set_all_immediate(*, using: str) -> None: + """ + Set all constraints to IMMEDIATE for the remainder of the transaction. + + # See Note [Deferrable constraints] + """ + if django_db.transaction.get_autocommit(using): + raise NotInTransaction + + with django_db.connections[using].cursor() as cursor: + cursor.execute("SET CONSTRAINTS ALL IMMEDIATE") + + +def set_immediate(names: Sequence[str], *, using: str) -> None: + """ + Set particular constraints to IMMEDIATE for the remainder of the transaction. + + # See Note [Deferrable constraints] + """ + if django_db.transaction.get_autocommit(using): + raise NotInTransaction + + if not names: + return + + query = sql.SQL("SET CONSTRAINTS {names} IMMEDIATE").format( + names=sql.SQL(", ").join(sql.Identifier(name) for name in names) + ) + + with django_db.connections[using].cursor() as cursor: + cursor.execute(query) + + +def set_deferred(names: Sequence[str], *, using: str) -> None: + """ + Set particular constraints to DEFERRED for the remainder of the transaction. + + # See Note [Deferrable constraints] + """ + if django_db.transaction.get_autocommit(using): + raise NotInTransaction + + if not names: + return + + query = sql.SQL("SET CONSTRAINTS {names} DEFERRED").format( + names=sql.SQL(", ").join(sql.Identifier(name) for name in names) + ) + + with django_db.connections[using].cursor() as cursor: + cursor.execute(query) + + +class NotInTransaction(Exception): + """ + Raised when we try to change the state of constraints outside of a transaction. + + It doesn't make sense to change the state of constraints outside of a transaction, + because the change of state would only last for the remainder of the transaction. + + See https://www.postgresql.org/docs/current/sql-set-constraints.html + """ def foreign_key_constraint_name( diff --git a/tests/django_integrity/test_constraints.py b/tests/django_integrity/test_constraints.py index afcd1c4..6978c2a 100644 --- a/tests/django_integrity/test_constraints.py +++ b/tests/django_integrity/test_constraints.py @@ -1,4 +1,5 @@ import pytest +from django import db as django_db from django.core import exceptions from django_integrity import constraints @@ -36,3 +37,149 @@ def test_wrong_field_name(self) -> None: field_name="does_not_exist", using="default", ) + + +class TestSetAllImmediate: + @pytest.mark.django_db + def test_all_constraints_set(self): + constraints.set_all_immediate(using="default") + + with pytest.raises(django_db.IntegrityError): + # The ForeignKey constraint should be enforced immediately. + test_models.ForeignKeyModel.objects.create(related_id=42) + + # lint-ignore NoTransactionLessIntegrationTests + @pytest.mark.django_db(transaction=True) + def test_not_in_transaction(self): + # Fail if we're not in a transaction. + with pytest.raises(constraints.NotInTransaction): + constraints.set_all_immediate(using="default") + + +class TestSetImmediate: + @pytest.mark.django_db + def test_set(self): + constraint_name = constraints.foreign_key_constraint_name( + model=test_models.ForeignKeyModel, + field_name="related_id", + using="default", + ) + + constraints.set_immediate(names=(constraint_name,), using="default") + + # An error should be raised immediately. + with pytest.raises(django_db.IntegrityError): + test_models.ForeignKeyModel.objects.create(related_id=42) + + @pytest.mark.django_db + def test_not_set(self): + # No constraint name is passed, so no constraints should be set to immediate. + constraints.set_immediate(names=(), using="default") + + # No error should be raised. + test_models.ForeignKeyModel.objects.create(related_id=42) + + # We catch the error here to prevent the test from failing in shutdown. + with pytest.raises(django_db.IntegrityError): + constraints.set_all_immediate(using="default") + + # lint-ignore NoTransactionLessIntegrationTests + @pytest.mark.django_db(transaction=True) + def test_not_in_transaction(self): + # Fail if we're not in a transaction. + with pytest.raises(constraints.NotInTransaction): + constraints.set_immediate(names=(), using="default") + + +class TestSetDeferred: + @pytest.mark.django_db + def test_not_set(self): + test_models.UniqueModel.objects.create(unique_field=42) + + # We pass no names, so no constraints should be set to deferred. + constraints.set_deferred(names=(), using="default") + + # This constraint defaults to IMMEDIATE, + # so an error should be raised immediately. + with pytest.raises(django_db.IntegrityError): + test_models.UniqueModel.objects.create(unique_field=42) + + @pytest.mark.django_db + def test_set(self): + test_models.UniqueModel.objects.create(unique_field=42) + + # We defer the constraint... + constraint_name = "unique_model_unique_field_key" + constraints.set_deferred(names=(constraint_name,), using="default") + + # ... so no error should be raised. + test_models.UniqueModel.objects.create(unique_field=42) + + # We catch the error here to prevent the test from failing in shutdown. + with pytest.raises(django_db.IntegrityError): + constraints.set_all_immediate(using="default") + + # lint-ignore NoTransactionLessIntegrationTests + @pytest.mark.django_db(transaction=True) + def test_not_in_transaction(self): + # Fail if we're not in a transaction. + with pytest.raises(constraints.NotInTransaction): + constraints.set_deferred(names=(), using="default") + + +class TestImmediate: + @pytest.mark.django_db + def test_constraint_not_enforced(self): + """Constraints are not changed when not explicitly enforced.""" + # Call the context manager without any constraint names. + with constraints.immediate((), using="default"): + # Create an instance that violates a deferred constraint. + # No error should be raised. + test_models.ForeignKeyModel.objects.create(related_id=42) + + # We catch the error here to prevent the test from failing in shutdown. + with pytest.raises(django_db.IntegrityError): + constraints.set_all_immediate(using="default") + + @pytest.mark.django_db + def test_constraint_enforced(self): + """Constraints are enforced when explicitly enforced.""" + constraint_name = constraints.foreign_key_constraint_name( + model=test_models.ForeignKeyModel, + field_name="related_id", + using="default", + ) + + # An error should be raised immediately. + with pytest.raises(django_db.IntegrityError): + with constraints.immediate((constraint_name,), using="default"): + # Create an instance that violates a deferred constraint. + test_models.ForeignKeyModel.objects.create(related_id=42) + + @pytest.mark.django_db + def test_deferral_restored(self): + """Constraints are restored to DEFERRED after the context manager.""" + constraint_name = constraints.foreign_key_constraint_name( + model=test_models.ForeignKeyModel, + field_name="related_id", + using="default", + ) + + with constraints.immediate((constraint_name,), using="default"): + pass + + # Create an instance that violates a deferred constraint. + # No error should be raised, because the constraint should be deferred again. + test_models.ForeignKeyModel.objects.create(related_id=42) + + # We catch the error here to prevent the test from failing in shutdown. + with pytest.raises(django_db.IntegrityError): + constraints.set_all_immediate(using="default") + + # lint-ignore NoTransactionLessIntegrationTests + @pytest.mark.django_db(transaction=True) + def test_not_in_transaction(self): + # Fail if we're not in a transaction. + with pytest.raises(constraints.NotInTransaction): + with constraints.immediate((), using="default"): + pass diff --git a/tests/example_app/models.py b/tests/example_app/models.py index ef99773..a66fa1a 100644 --- a/tests/example_app/models.py +++ b/tests/example_app/models.py @@ -7,3 +7,16 @@ class PrimaryKeyModel(models.Model): class ForeignKeyModel(models.Model): related = models.ForeignKey(PrimaryKeyModel, on_delete=models.CASCADE) + + +class UniqueModel(models.Model): + unique_field = models.IntegerField() + + class Meta: + constraints = ( + models.UniqueConstraint( + fields=["unique_field"], + name="unique_model_unique_field_key", + deferrable=models.Deferrable.IMMEDIATE, + ), + )