diff --git a/haiku/_src/lift_test.py b/haiku/_src/lift_test.py index 1cdf3fcaa..4e8366209 100644 --- a/haiku/_src/lift_test.py +++ b/haiku/_src/lift_test.py @@ -170,14 +170,27 @@ def f(): ValueError, "must be used within the same call to init/apply"): f.init(None) - def test_transparent_lift(self): + @parameterized.parameters( + (True,), # passes, previous test + (False,), # fails + ) + def test_transparent_lift(self, inner_module_instantiated_inside_transform): class OuterModule(module.Module): def __call__(self, x): x += base.get_parameter("a", shape=[10, 10], init=jnp.zeros) - def inner_fn(x): - return InnerModule(name="inner")(x) + if inner_module_instantiated_inside_transform: + # These weights will end up being called "outer/inner", ok! + def inner_fn(x): + return InnerModule(name="inner")(x) + else: + inner_module = InnerModule(name="inner") + # These weights will end up being called "outer/outer/inner", + # causing the test to fail. I would be very surprised if this is + # desider behavior. + def inner_fn(x): + return inner_module(x) inner_transformed = transform.transform(inner_fn) inner_params = lift.transparent_lift(inner_transformed.init)(