Skip to content

Commit

Permalink
Fix interaction between PyroParam and torch.func.grad (#3328)
Browse files Browse the repository at this point in the history
* fix

* lint

* context

* nit

* lint

* bump black version and lint

* add failing case and lint

* strengthen test and add extra fixes

* last edge case
  • Loading branch information
eb8680 authored Feb 16, 2024
1 parent 800a484 commit 4a55960
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 6 deletions.
62 changes: 57 additions & 5 deletions pyro/nn/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from torch.distributions import constraints, transform_to

import pyro
import pyro.params.param_store
from pyro.ops.provenance import detach_provenance
from pyro.poutine.runtime import _PYRO_PARAM_STORE

Expand Down Expand Up @@ -474,7 +475,7 @@ def __getattr__(self, name):
if name in _pyro_params:
constraint, event_dim = _pyro_params[name]
unconstrained_value = getattr(self, name + "_unconstrained")
if self._pyro_context.active:
if self._pyro_context.active and not _is_module_local_param_enabled():
fullname = self._pyro_get_fullname(name)
if fullname in _PYRO_PARAM_STORE:
if (
Expand Down Expand Up @@ -503,6 +504,15 @@ def __getattr__(self, name):
_PYRO_PARAM_STORE._params[fullname] = unconstrained_value
_PYRO_PARAM_STORE._param_to_name[unconstrained_value] = fullname
return pyro.param(fullname, event_dim=event_dim)
elif self._pyro_context.active and _is_module_local_param_enabled():
# fake param statement to ensure any handlers of pyro.param are applied,
# even though we don't use the contents of the local parameter store
fullname = self._pyro_get_fullname(name)
constrained_value = transform_to(constraint)(unconstrained_value)
constrained_value.unconstrained = weakref.ref(unconstrained_value)
return pyro.poutine.runtime.effectful(type="param")(
lambda *_, **__: constrained_value
)(fullname, event_dim=event_dim, name=fullname)
else: # Cannot determine supermodule and hence cannot compute fullname.
constrained_value = transform_to(constraint)(unconstrained_value)
constrained_value.unconstrained = weakref.ref(unconstrained_value)
Expand Down Expand Up @@ -534,8 +544,15 @@ def __getattr__(self, name):
if isinstance(result, torch.nn.Parameter) and not name.endswith(
"_unconstrained"
):
if self._pyro_context.active:
if self._pyro_context.active and not _is_module_local_param_enabled():
pyro.param(self._pyro_get_fullname(name), result)
elif self._pyro_context.active and _is_module_local_param_enabled():
# fake param statement to ensure any handlers of pyro.param are applied,
# even though we don't use the contents of the local parameter store
fullname = self._pyro_get_fullname(name)
pyro.poutine.runtime.effectful(type="param")(lambda *_, **__: result)(
fullname, result, name=fullname
)

if isinstance(result, torch.nn.Module):
if isinstance(result, PyroModule):
Expand All @@ -546,8 +563,19 @@ def __getattr__(self, name):
)
else:
# Regular nn.Modules trigger pyro.module statements.
if self._pyro_context.active:
if self._pyro_context.active and not _is_module_local_param_enabled():
pyro.module(self._pyro_get_fullname(name), result)
elif self._pyro_context.active and _is_module_local_param_enabled():
# fake module statement to ensure any handlers of pyro.module are applied,
# even though we don't use the contents of the local parameter store
fullname_module = self._pyro_get_fullname(name)
for param_name, param_value in result.named_parameters():
fullname_param = pyro.params.param_store.param_with_module_name(
fullname_module, param_name
)
pyro.poutine.runtime.effectful(type="param")(
lambda *_, **__: param_value
)(fullname_param, param_value, name=fullname_param)

return result

Expand All @@ -569,7 +597,7 @@ def __setattr__(self, name, value):
pass
constrained_value, constraint, event_dim = value
self._pyro_params[name] = constraint, event_dim
if self._pyro_context.active:
if self._pyro_context.active and not _is_module_local_param_enabled():
fullname = self._pyro_get_fullname(name)
pyro.param(
fullname,
Expand All @@ -584,6 +612,21 @@ def __setattr__(self, name, value):
unconstrained_value = torch.nn.Parameter(unconstrained_value)
_PYRO_PARAM_STORE._params[fullname] = unconstrained_value
_PYRO_PARAM_STORE._param_to_name[unconstrained_value] = fullname
elif self._pyro_context.active and _is_module_local_param_enabled():
# fake param statement to ensure any handlers of pyro.param are applied,
# even though we don't use the contents of the local parameter store
fullname = self._pyro_get_fullname(name)
constrained_value = detach_provenance(
pyro.poutine.runtime.effectful(type="param")(
lambda *_, **__: constrained_value
)(
fullname,
constraint=constraint,
event_dim=event_dim,
name=fullname,
)
)
unconstrained_value = _unconstrain(constrained_value, constraint)
else: # Cannot determine supermodule and hence cannot compute fullname.
unconstrained_value = _unconstrain(constrained_value, constraint)
super().__setattr__(name + "_unconstrained", unconstrained_value)
Expand All @@ -595,14 +638,23 @@ def __setattr__(self, name, value):
delattr(self, name)
except AttributeError:
pass
if self._pyro_context.active:
if self._pyro_context.active and not _is_module_local_param_enabled():
fullname = self._pyro_get_fullname(name)
value = pyro.param(fullname, value)
if not isinstance(value, torch.nn.Parameter):
# Update PyroModule ---> ParamStore (type only; data is preserved).
value = torch.nn.Parameter(detach_provenance(value))
_PYRO_PARAM_STORE._params[fullname] = value
_PYRO_PARAM_STORE._param_to_name[value] = fullname
elif self._pyro_context.active and _is_module_local_param_enabled():
# fake param statement to ensure any handlers of pyro.param are applied,
# even though we don't use the contents of the local parameter store
fullname = self._pyro_get_fullname(name)
value = detach_provenance(
pyro.poutine.runtime.effectful(type="param")(
lambda *_, **__: value
)(fullname, value, constraint=constraints.real, name=fullname)
)
super().__setattr__(name, value)
return

Expand Down
81 changes: 80 additions & 1 deletion tests/nn/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from pyro.infer import SVI, Trace_ELBO
from pyro.nn.module import PyroModule, PyroParam, PyroSample, clear, to_pyro_module_
from pyro.optim import Adam
from tests.common import assert_equal
from tests.common import assert_equal, xfail_param


def test_svi_smoke():
Expand Down Expand Up @@ -765,3 +765,82 @@ def test_bayesian_gru():
assert output.shape == (seq_len, batch_size, hidden_size)
output2, _ = gru(input_)
assert not torch.allclose(output2, output)


@pytest.mark.parametrize(
"use_local_params",
[
True,
xfail_param(
False, reason="torch.func not compatible with global parameter store"
),
],
)
def test_functorch_pyroparam(use_local_params):
class ParamModule(PyroModule):
def __init__(self):
super().__init__()
self.a2 = PyroParam(torch.tensor(0.678), constraints.positive)

@PyroParam(constraint=constraints.real)
def a1(self):
return torch.tensor(0.456)

class Model(PyroModule):
def __init__(self):
super().__init__()
self.param_module = ParamModule()
self.b1 = PyroParam(torch.tensor(0.123), constraints.positive)
self.b3 = torch.nn.Parameter(torch.tensor(0.789))
self.c = torch.nn.Linear(1, 1)

@PyroParam(constraint=constraints.positive)
def b2(self):
return torch.tensor(1.234)

def forward(self, x, y):
return (
(self.param_module.a1 + self.param_module.a2) * x
+ self.b1
+ self.b2
+ self.b3
- self.c(y.unsqueeze(-1)).squeeze(-1)
) ** 2

with pyro.settings.context(module_local_params=use_local_params):
model = Model()
x, y = torch.tensor(1.3), torch.tensor(0.2)

with pyro.poutine.trace() as tr:
model(x, y)

params = dict(model.named_parameters())

# Check that all parameters appear in the trace for SVI compatibility
assert len(params) == len(
{
name: node
for name, node in tr.trace.nodes.items()
if node["type"] == "param"
}
)

grad_model = torch.func.grad(
lambda p, x, y: torch.func.functional_call(model, p, (x, y))
)
grad_params_func = grad_model(params, x, y)

gs = torch.autograd.grad(model(x, y), tuple(params.values()))
grad_params_autograd = dict(zip(params.keys(), gs))

assert len(grad_params_autograd) == len(grad_params_func) != 0
assert (
set(grad_params_autograd.keys())
== set(grad_params_func.keys())
== set(params.keys())
)
for k in grad_params_autograd.keys():
assert not torch.allclose(
grad_params_func[k], torch.zeros_like(grad_params_func[k])
), k
assert torch.allclose(grad_params_autograd[k], grad_params_func[k]), k

0 comments on commit 4a55960

Please sign in to comment.