From 4a55960cdaf60426b3147e1e96c57c8a66fb8822 Mon Sep 17 00:00:00 2001
From: eb8680
Date: Fri, 16 Feb 2024 15:09:28 -0500
Subject: [PATCH] Fix interaction between PyroParam and torch.func.grad (#3328)

* fix

* lint

* context

* nit

* lint

* bump black version and lint

* add failing case and lint

* strengthen test and add extra fixes

* last edge case
---
 pyro/nn/module.py       | 62 ++++++++++++++++++++++++++++---
 tests/nn/test_module.py | 81 ++++++++++++++++++++++++++++++++++++++++-
 2 files changed, 137 insertions(+), 6 deletions(-)

diff --git a/pyro/nn/module.py b/pyro/nn/module.py
index cc38517ac4..323fe470a5 100644
--- a/pyro/nn/module.py
+++ b/pyro/nn/module.py
@@ -21,6 +21,7 @@
 from torch.distributions import constraints, transform_to

 import pyro
+import pyro.params.param_store
 from pyro.ops.provenance import detach_provenance
 from pyro.poutine.runtime import _PYRO_PARAM_STORE

@@ -474,7 +475,7 @@ def __getattr__(self, name):
         if name in _pyro_params:
             constraint, event_dim = _pyro_params[name]
             unconstrained_value = getattr(self, name + "_unconstrained")
-            if self._pyro_context.active:
+            if self._pyro_context.active and not _is_module_local_param_enabled():
                 fullname = self._pyro_get_fullname(name)
                 if fullname in _PYRO_PARAM_STORE:
                     if (
@@ -503,6 +504,15 @@ def __getattr__(self, name):
                     _PYRO_PARAM_STORE._params[fullname] = unconstrained_value
                     _PYRO_PARAM_STORE._param_to_name[unconstrained_value] = fullname
                 return pyro.param(fullname, event_dim=event_dim)
+            elif self._pyro_context.active and _is_module_local_param_enabled():
+                # fake param statement to ensure any handlers of pyro.param are applied,
+                # even though we don't use the contents of the local parameter store
+                fullname = self._pyro_get_fullname(name)
+                constrained_value = transform_to(constraint)(unconstrained_value)
+                constrained_value.unconstrained = weakref.ref(unconstrained_value)
+                return pyro.poutine.runtime.effectful(type="param")(
+                    lambda *_, **__: constrained_value
+                )(fullname, event_dim=event_dim, name=fullname)
             else:  # Cannot determine supermodule and hence cannot compute fullname.
                 constrained_value = transform_to(constraint)(unconstrained_value)
                 constrained_value.unconstrained = weakref.ref(unconstrained_value)
@@ -534,8 +544,15 @@ def __getattr__(self, name):
         if isinstance(result, torch.nn.Parameter) and not name.endswith(
             "_unconstrained"
         ):
-            if self._pyro_context.active:
+            if self._pyro_context.active and not _is_module_local_param_enabled():
                 pyro.param(self._pyro_get_fullname(name), result)
+            elif self._pyro_context.active and _is_module_local_param_enabled():
+                # fake param statement to ensure any handlers of pyro.param are applied,
+                # even though we don't use the contents of the local parameter store
+                fullname = self._pyro_get_fullname(name)
+                pyro.poutine.runtime.effectful(type="param")(lambda *_, **__: result)(
+                    fullname, result, name=fullname
+                )

         if isinstance(result, torch.nn.Module):
             if isinstance(result, PyroModule):
@@ -546,8 +563,19 @@ def __getattr__(self, name):
                 )
             else:
                 # Regular nn.Modules trigger pyro.module statements.
-                if self._pyro_context.active:
+                if self._pyro_context.active and not _is_module_local_param_enabled():
                     pyro.module(self._pyro_get_fullname(name), result)
+                elif self._pyro_context.active and _is_module_local_param_enabled():
+                    # fake module statement to ensure any handlers of pyro.module are applied,
+                    # even though we don't use the contents of the local parameter store
+                    fullname_module = self._pyro_get_fullname(name)
+                    for param_name, param_value in result.named_parameters():
+                        fullname_param = pyro.params.param_store.param_with_module_name(
+                            fullname_module, param_name
+                        )
+                        pyro.poutine.runtime.effectful(type="param")(
+                            lambda *_, **__: param_value
+                        )(fullname_param, param_value, name=fullname_param)

         return result

@@ -569,7 +597,7 @@ def __setattr__(self, name, value):
                 pass
             constrained_value, constraint, event_dim = value
             self._pyro_params[name] = constraint, event_dim
-            if self._pyro_context.active:
+            if self._pyro_context.active and not _is_module_local_param_enabled():
                 fullname = self._pyro_get_fullname(name)
                 pyro.param(
                     fullname,
@@ -584,6 +612,21 @@ def __setattr__(self, name, value):
                     unconstrained_value = torch.nn.Parameter(unconstrained_value)
                     _PYRO_PARAM_STORE._params[fullname] = unconstrained_value
                     _PYRO_PARAM_STORE._param_to_name[unconstrained_value] = fullname
+            elif self._pyro_context.active and _is_module_local_param_enabled():
+                # fake param statement to ensure any handlers of pyro.param are applied,
+                # even though we don't use the contents of the local parameter store
+                fullname = self._pyro_get_fullname(name)
+                constrained_value = detach_provenance(
+                    pyro.poutine.runtime.effectful(type="param")(
+                        lambda *_, **__: constrained_value
+                    )(
+                        fullname,
+                        constraint=constraint,
+                        event_dim=event_dim,
+                        name=fullname,
+                    )
+                )
+                unconstrained_value = _unconstrain(constrained_value, constraint)
             else:  # Cannot determine supermodule and hence cannot compute fullname.
                 unconstrained_value = _unconstrain(constrained_value, constraint)
             super().__setattr__(name + "_unconstrained", unconstrained_value)
@@ -595,7 +638,7 @@ def __setattr__(self, name, value):
                 delattr(self, name)
             except AttributeError:
                 pass
-            if self._pyro_context.active:
+            if self._pyro_context.active and not _is_module_local_param_enabled():
                 fullname = self._pyro_get_fullname(name)
                 value = pyro.param(fullname, value)
                 if not isinstance(value, torch.nn.Parameter):
@@ -603,6 +646,15 @@ def __setattr__(self, name, value):
                     value = torch.nn.Parameter(detach_provenance(value))
                     _PYRO_PARAM_STORE._params[fullname] = value
                     _PYRO_PARAM_STORE._param_to_name[value] = fullname
+            elif self._pyro_context.active and _is_module_local_param_enabled():
+                # fake param statement to ensure any handlers of pyro.param are applied,
+                # even though we don't use the contents of the local parameter store
+                fullname = self._pyro_get_fullname(name)
+                value = detach_provenance(
+                    pyro.poutine.runtime.effectful(type="param")(
+                        lambda *_, **__: value
+                    )(fullname, value, constraint=constraints.real, name=fullname)
+                )
             super().__setattr__(name, value)
             return

diff --git a/tests/nn/test_module.py b/tests/nn/test_module.py
index 90717ee180..67c4b98108 100644
--- a/tests/nn/test_module.py
+++ b/tests/nn/test_module.py
@@ -15,7 +15,7 @@
 from pyro.infer import SVI, Trace_ELBO
 from pyro.nn.module import PyroModule, PyroParam, PyroSample, clear, to_pyro_module_
 from pyro.optim import Adam
-from tests.common import assert_equal
+from tests.common import assert_equal, xfail_param


 def test_svi_smoke():
@@ -765,3 +765,82 @@ def test_bayesian_gru():
     assert output.shape == (seq_len, batch_size, hidden_size)
     output2, _ = gru(input_)
     assert not torch.allclose(output2, output)
+
+
+@pytest.mark.parametrize(
+    "use_local_params",
+    [
+        True,
+        xfail_param(
+            False, reason="torch.func not compatible with global parameter store"
+        ),
+    ],
+)
+def test_functorch_pyroparam(use_local_params):
+    class ParamModule(PyroModule):
+        def __init__(self):
+            super().__init__()
+            self.a2 = PyroParam(torch.tensor(0.678), constraints.positive)
+
+        @PyroParam(constraint=constraints.real)
+        def a1(self):
+            return torch.tensor(0.456)
+
+    class Model(PyroModule):
+        def __init__(self):
+            super().__init__()
+            self.param_module = ParamModule()
+            self.b1 = PyroParam(torch.tensor(0.123), constraints.positive)
+            self.b3 = torch.nn.Parameter(torch.tensor(0.789))
+            self.c = torch.nn.Linear(1, 1)
+
+        @PyroParam(constraint=constraints.positive)
+        def b2(self):
+            return torch.tensor(1.234)
+
+        def forward(self, x, y):
+            return (
+                (self.param_module.a1 + self.param_module.a2) * x
+                + self.b1
+                + self.b2
+                + self.b3
+                - self.c(y.unsqueeze(-1)).squeeze(-1)
+            ) ** 2
+
+    with pyro.settings.context(module_local_params=use_local_params):
+        model = Model()
+        x, y = torch.tensor(1.3), torch.tensor(0.2)
+
+        with pyro.poutine.trace() as tr:
+            model(x, y)
+
+        params = dict(model.named_parameters())
+
+        # Check that all parameters appear in the trace for SVI compatibility
+        assert len(params) == len(
+            {
+                name: node
+                for name, node in tr.trace.nodes.items()
+                if node["type"] == "param"
+            }
+        )
+
+        grad_model = torch.func.grad(
+            lambda p, x, y: torch.func.functional_call(model, p, (x, y))
+        )
+        grad_params_func = grad_model(params, x, y)
+
+        gs = torch.autograd.grad(model(x, y), tuple(params.values()))
+        grad_params_autograd = dict(zip(params.keys(), gs))
+
+        assert len(grad_params_autograd) == len(grad_params_func) != 0
+        assert (
+            set(grad_params_autograd.keys())
+            == set(grad_params_func.keys())
+            == set(params.keys())
+        )
+        for k in grad_params_autograd.keys():
+            assert not torch.allclose(
+                grad_params_func[k], torch.zeros_like(grad_params_func[k])
+            ), k
+            assert torch.allclose(grad_params_autograd[k], grad_params_func[k]), k
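
For reviewers trying the change locally, the pattern exercised by the new test_functorch_pyroparam test can be reduced to the short sketch below. It is illustrative only and not part of the patch: the Model class and tensor values are hypothetical, while pyro.settings.context(module_local_params=True), torch.func.grad, and torch.func.functional_call are the same APIs the test uses.

import torch
from torch.distributions import constraints

import pyro
from pyro.nn.module import PyroModule, PyroParam


class Model(PyroModule):  # hypothetical example module, not code from this PR
    def __init__(self):
        super().__init__()
        # constrained PyroParam; attribute reads return the transformed value
        self.scale = PyroParam(torch.tensor(1.0), constraints.positive)
        self.loc = torch.nn.Parameter(torch.tensor(0.0))

    def forward(self, x):
        return ((x - self.loc) / self.scale) ** 2


# With module_local_params=True, parameters live on the module rather than in
# the global param store, so torch.func.functional_call can substitute them.
with pyro.settings.context(module_local_params=True):
    model = Model()
    params = dict(model.named_parameters())  # includes "scale_unconstrained"
    grad_fn = torch.func.grad(
        lambda p, x: torch.func.functional_call(model, p, (x,))
    )
    grads = grad_fn(params, torch.tensor(0.3))
    print(grads)

With module_local_params=False the same call is expected to fail, which is why the False case is marked xfail in the new test.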