Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes cooperative multiple inheritance with __post_init__ #834

Merged
merged 1 commit into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 15 additions & 17 deletions equinox/_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def __new__(

# Add support for `eqx.field(converter=...)` when using `__post_init__`.
# (Scenario (c) above. Scenarios (a) and (b) are handled later.)
if has_dataclass_init and hasattr(cls, "__post_init__"):
if has_dataclass_init and "__post_init__" in cls.__dict__:
post_init = cls.__post_init__

@ft.wraps(post_init) # pyright: ignore
Expand All @@ -293,29 +293,23 @@ def __post_init__(self, *args, **kwargs):
# We want to only convert once, at the top level.
#
# This check is basically testing whether or not the function we're in
# now (`cls.__post_init__`) is at the top level
# (`self.__class__.__post_init__`). If we are, do conversion. If we're
# not, it's presumably because someone is calling us via `super()` in
# the middle of their own `__post_init__`. No conversion then; their own
# version of this wrapper will do it at the appropriate time instead.
#
# One small foible: we write `cls.__post_init__`, rather than just
# `__post_init__`, to refer to this function. This allows someone else
# to also monkey-patch `cls.__post_init__` if they wish, and this won't
# remove conversion. (Conversion is a at-the-top-level thing, not a
# this-particular-function thing.)
# now (`cls`) is at the top level (`self.__class__`). If we are, do
# conversion. If we're not, it's presumably because someone is calling
# us via `super()` in the middle of their own `__post_init__`. No
# conversion then; their own version of this wrapper will do it at the
# appropriate time instead.
#
# This top-level business means that this is very nearly the same as
# doing conversion in `_ModuleMeta.__call__`. The differences are that
# (a) that wouldn't allow us to convert fields before the user-provided
# `__post_init__`, and (b) it allows other libraries (i.e. jaxtyping)
# to later monkey-patch `__init__`, and we have our converter run before
# their own monkey-patched-in code.
if self.__class__.__post_init__ is cls.__post_init__:
if self.__class__ is _make_initable_wrapper(cls):
# Convert all fields currently available.
_convert_fields(self, init=True)
post_init(self, *args, **kwargs) # pyright: ignore
if self.__class__.__post_init__ is cls.__post_init__:
if self.__class__ is _make_initable_wrapper(cls):
# Convert all the fields filled in by `__post_init__` as well.
_convert_fields(self, init=False)

Expand Down Expand Up @@ -377,7 +371,7 @@ def __init__(self, *args, **kwargs):
__tracebackhide__ = True
init(self, *args, **kwargs)
# Same `if` trick as with `__post_init__`.
if self.__class__.__init__ is cls.__init__:
if self.__class__ is _make_initable_wrapper(cls):
_convert_fields(self, init=True)
_convert_fields(self, init=False)

Expand Down Expand Up @@ -566,8 +560,7 @@ def __call__(cls, *args, **kwargs):
# else it's handled in __setattr__, but that isn't called here.
# [Step 1] Modules are immutable -- except during construction. So defreeze
# before init.
post_init = getattr(cls, "__post_init__", None)
initable_cls = _make_initable(cls, cls.__init__, post_init, wraps=False)
initable_cls = _make_initable_wrapper(cls)
# [Step 2] Instantiate the class as normal.
self = super(_ActualModuleMeta, initable_cls).__call__(*args, **kwargs)
assert not _is_abstract(cls)
Expand Down Expand Up @@ -792,6 +785,11 @@ def __call__(self, ...):
break


def _make_initable_wrapper(cls: _ActualModuleMeta) -> _ActualModuleMeta:
post_init = getattr(cls, "__post_init__", None)
return _make_initable(cls, cls.__init__, post_init, wraps=False)


@ft.lru_cache(maxsize=128)
def _make_initable(
cls: _ActualModuleMeta, init, post_init, wraps: bool
Expand Down
35 changes: 35 additions & 0 deletions tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,3 +1169,38 @@ class InvalidArr(eqx.Module):
match="A JAX array is being set as static!",
):
InvalidArr((), jnp.ones(10))


# https://github.com/patrick-kidger/equinox/issues/832
def test_cooperative_multiple_inheritance():
called_a = False
called_b = False
called_d = False

class A(eqx.Module):
def __post_init__(self) -> None:
nonlocal called_a
called_a = True

class B(A):
def __post_init__(self) -> None:
nonlocal called_b
called_b = True
super().__post_init__()

class C(A):
pass

class D(C, A):
def __post_init__(self) -> None:
nonlocal called_d
called_d = True
super().__post_init__()

class E(D, B):
pass

E()
assert called_a
assert called_b
assert called_d
Loading