diff --git a/docs/api/module/advanced_fields.md b/docs/api/module/advanced_fields.md index 2c18549a..e910e0a9 100644 --- a/docs/api/module/advanced_fields.md +++ b/docs/api/module/advanced_fields.md @@ -18,6 +18,69 @@ Equinox modules can be used as [abstract base classes](https://docs.python.org/3 selection: members: false +## Checking invariants + +Equinox extends dataclasses with a `__check_init__` method, which is automatically ran after initialisation. This can be used to check invariants like so: + +```python +class Positive(eqx.Module): + x: int + + def __check_init__(self): + if self.x <= 0: + raise ValueError("Oh no!") +``` + +This method has three key differences compared to the `__post_init__` provided by dataclasses: + +- It is not overridden by an `__init__` method of a subclass. In contrast, the following code has a silent bug: + + ```python + class Parent(eqx.Module): + x: int + + def __check_init__(self): + if self.x <= 0: + raise ValueError("Oh no!") + + class Child(Parent): + x_as_str: str + + def __init__(self, x): + self.x = x + self.x_as_str = str(x) + + Child(-1) # No error! + ``` + +- It is automatically called for parent classes; `super().__check_init__()` is not required: + + ```python + class Parent(eqx.Module): + def __check_init__(self): + print("Parent") + + class Child(Parent): + def __check_init__(self): + print("Child") + + Child() # prints out both Child and Parent + ``` + + As with the previous bullet point, this is to prevent child classes accidentally failing to check that the invariants of their parent hold. + +- Assignment is not allowed: + + ```python + class MyModule(eqx.Module): + foo: int + + def __check_init__(self): + self.foo = 1 # will raise an error + ``` + + This is to prevent `__check_init__` from doing anything too surprising: as the name suggests, it's meant to be used for checking invariants. + ## Creating wrapper modules ::: equinox.module_update_wrapper diff --git a/equinox/_module.py b/equinox/_module.py index 724826b7..4f4cfc54 100644 --- a/equinox/_module.py +++ b/equinox/_module.py @@ -196,6 +196,13 @@ def __call__(cls, *args, **kwargs): else: setattr(self, field.name, converter(getattr(self, field.name))) object.__setattr__(self, "__class__", cls) + for kls in cls.__mro__: + try: + check = kls.__dict__["__check_init__"] + except KeyError: + pass + else: + check(self) return self diff --git a/tests/test_module.py b/tests/test_module.py index 35c0e725..a235ddad 100644 --- a/tests/test_module.py +++ b/tests/test_module.py @@ -1,3 +1,4 @@ +import dataclasses import functools as ft from typing import Any @@ -285,3 +286,104 @@ class B(A, foo=True): pass assert called + + +def test_check_init(): + class FooException(Exception): + pass + + called_a = False + called_b = False + + class A(eqx.Module): + a: int + + def __check_init__(self): + nonlocal called_a + called_a = True + if self.a >= 0: + raise FooException + + class B(A): + def __check_init__(self): + nonlocal called_b + called_b = True + + class C(A): + pass + + assert not called_a + assert not called_b + A(-1) + assert called_a + assert not called_b + + called_a = False + with pytest.raises(FooException): + A(1) + assert called_a + assert not called_b + + called_a = False + B(-1) + assert called_a + assert called_b + + called_a = False + called_b = False + with pytest.raises(FooException): + B(1) + assert called_a + assert called_b # B.__check_init__ is called before A.__check_init__ + + called_a = False + called_b = False + C(-1) + assert called_a + assert not called_b + + called_a = False + with pytest.raises(FooException): + C(1) + assert called_a + assert not called_b + + +def test_check_init_order(): + called_a = False + called_b = False + called_c = False + + class A(eqx.Module): + def __check_init__(self): + nonlocal called_a + called_a = True + + class B(A): + def __check_init__(self): + nonlocal called_b + called_b = True + raise ValueError + + class C(B): + def __check_init__(self): # pyright: ignore + nonlocal called_c + called_c = True + + with pytest.raises(ValueError): + C() + + assert called_c + assert called_b + assert not called_a + + +def test_check_init_no_assignment(): + class A(eqx.Module): + x: int + + def __check_init__(self): + self.x = 4 + + with pytest.raises(dataclasses.FrozenInstanceError): + A(1)