Fix interaction between PyroParam and torch.func.grad #3328

Merged 9 commits on Feb 16, 2024
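Summary: this change makes PyroParam and plain nn.Parameter attributes of a PyroModule usable with torch.func transforms such as torch.func.grad when Pyro's module_local_params setting is enabled, by replacing reads and writes of the global parameter store with module-local "fake" param statements that still trigger Pyro's effect handlers. A minimal usage sketch, loosely mirroring the new test below (the Scaled module and its parameter names are illustrative, not taken from the diff):

import torch
from torch.distributions import constraints

import pyro
from pyro.nn import PyroModule, PyroParam


class Scaled(PyroModule):
    def __init__(self):
        super().__init__()
        # Constrained parameter; with module_local_params enabled it lives on
        # the module itself rather than in the global parameter store.
        self.scale = PyroParam(torch.tensor(2.0), constraints.positive)

    def forward(self, x):
        return (self.scale * x) ** 2


with pyro.settings.context(module_local_params=True):
    model = Scaled()
    params = dict(model.named_parameters())  # {"scale_unconstrained": Parameter}
    x = torch.tensor(1.5)

    # torch.func.grad differentiates a purely functional call that takes the
    # parameters explicitly; this is the interaction the PR fixes.
    grads = torch.func.grad(
        lambda p, x: torch.func.functional_call(model, p, (x,))
    )(params, x)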
62 changes: 57 additions & 5 deletions pyro/nn/module.py
@@ -21,6 +21,7 @@
 from torch.distributions import constraints, transform_to
 
 import pyro
+import pyro.params.param_store
 from pyro.ops.provenance import detach_provenance
 from pyro.poutine.runtime import _PYRO_PARAM_STORE

@@ -474,7 +475,7 @@ def __getattr__(self, name):
             if name in _pyro_params:
                 constraint, event_dim = _pyro_params[name]
                 unconstrained_value = getattr(self, name + "_unconstrained")
-                if self._pyro_context.active:
+                if self._pyro_context.active and not _is_module_local_param_enabled():
                     fullname = self._pyro_get_fullname(name)
                     if fullname in _PYRO_PARAM_STORE:
                         if (
@@ -503,6 +504,15 @@ def __getattr__(self, name):
                         _PYRO_PARAM_STORE._params[fullname] = unconstrained_value
                         _PYRO_PARAM_STORE._param_to_name[unconstrained_value] = fullname
                     return pyro.param(fullname, event_dim=event_dim)
+                elif self._pyro_context.active and _is_module_local_param_enabled():
+                    # fake param statement to ensure any handlers of pyro.param are applied,
+                    # even though we don't use the contents of the local parameter store
+                    fullname = self._pyro_get_fullname(name)
+                    constrained_value = transform_to(constraint)(unconstrained_value)
+                    constrained_value.unconstrained = weakref.ref(unconstrained_value)
+                    return pyro.poutine.runtime.effectful(type="param")(
+                        lambda *_, **__: constrained_value
+                    )(fullname, event_dim=event_dim, name=fullname)
                 else:  # Cannot determine supermodule and hence cannot compute fullname.
                     constrained_value = transform_to(constraint)(unconstrained_value)
                     constrained_value.unconstrained = weakref.ref(unconstrained_value)
@@ -534,8 +544,15 @@ def __getattr__(self, name):
         if isinstance(result, torch.nn.Parameter) and not name.endswith(
             "_unconstrained"
         ):
-            if self._pyro_context.active:
+            if self._pyro_context.active and not _is_module_local_param_enabled():
                 pyro.param(self._pyro_get_fullname(name), result)
+            elif self._pyro_context.active and _is_module_local_param_enabled():
+                # fake param statement to ensure any handlers of pyro.param are applied,
+                # even though we don't use the contents of the local parameter store
+                fullname = self._pyro_get_fullname(name)
+                pyro.poutine.runtime.effectful(type="param")(lambda *_, **__: result)(
+                    fullname, result, name=fullname
+                )
 
         if isinstance(result, torch.nn.Module):
             if isinstance(result, PyroModule):
@@ -546,8 +563,19 @@ def __getattr__(self, name):
                 )
             else:
                 # Regular nn.Modules trigger pyro.module statements.
-                if self._pyro_context.active:
+                if self._pyro_context.active and not _is_module_local_param_enabled():
                     pyro.module(self._pyro_get_fullname(name), result)
+                elif self._pyro_context.active and _is_module_local_param_enabled():
+                    # fake module statement to ensure any handlers of pyro.module are applied,
+                    # even though we don't use the contents of the local parameter store
+                    fullname_module = self._pyro_get_fullname(name)
+                    for param_name, param_value in result.named_parameters():
+                        fullname_param = pyro.params.param_store.param_with_module_name(
+                            fullname_module, param_name
+                        )
+                        pyro.poutine.runtime.effectful(type="param")(
+                            lambda *_, **__: param_value
+                        )(fullname_param, param_value, name=fullname_param)
 
         return result

@@ -569,7 +597,7 @@ def __setattr__(self, name, value):
                 pass
             constrained_value, constraint, event_dim = value
             self._pyro_params[name] = constraint, event_dim
-            if self._pyro_context.active:
+            if self._pyro_context.active and not _is_module_local_param_enabled():
                 fullname = self._pyro_get_fullname(name)
                 pyro.param(
                     fullname,
@@ -584,6 +612,21 @@
                     unconstrained_value = torch.nn.Parameter(unconstrained_value)
                     _PYRO_PARAM_STORE._params[fullname] = unconstrained_value
                     _PYRO_PARAM_STORE._param_to_name[unconstrained_value] = fullname
+            elif self._pyro_context.active and _is_module_local_param_enabled():
+                # fake param statement to ensure any handlers of pyro.param are applied,
+                # even though we don't use the contents of the local parameter store
+                fullname = self._pyro_get_fullname(name)
+                constrained_value = detach_provenance(
+                    pyro.poutine.runtime.effectful(type="param")(
+                        lambda *_, **__: constrained_value
+                    )(
+                        fullname,
+                        constraint=constraint,
+                        event_dim=event_dim,
+                        name=fullname,
+                    )
+                )
+                unconstrained_value = _unconstrain(constrained_value, constraint)
             else:  # Cannot determine supermodule and hence cannot compute fullname.
                 unconstrained_value = _unconstrain(constrained_value, constraint)
             super().__setattr__(name + "_unconstrained", unconstrained_value)
@@ -595,14 +638,23 @@ def __setattr__(self, name, value):
                 delattr(self, name)
             except AttributeError:
                 pass
-            if self._pyro_context.active:
+            if self._pyro_context.active and not _is_module_local_param_enabled():
                 fullname = self._pyro_get_fullname(name)
                 value = pyro.param(fullname, value)
                 if not isinstance(value, torch.nn.Parameter):
                     # Update PyroModule ---> ParamStore (type only; data is preserved).
                     value = torch.nn.Parameter(detach_provenance(value))
                     _PYRO_PARAM_STORE._params[fullname] = value
                     _PYRO_PARAM_STORE._param_to_name[value] = fullname
+            elif self._pyro_context.active and _is_module_local_param_enabled():
+                # fake param statement to ensure any handlers of pyro.param are applied,
+                # even though we don't use the contents of the local parameter store
+                fullname = self._pyro_get_fullname(name)
+                value = detach_provenance(
+                    pyro.poutine.runtime.effectful(type="param")(
+                        lambda *_, **__: value
+                    )(fullname, value, constraint=constraints.real, name=fullname)
+                )
             super().__setattr__(name, value)
             return

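The recurring pattern in the new branches above is the "fake" param statement: pyro.poutine.runtime.effectful(type="param") wraps a constant callable so that any installed effect handlers (trace, lift, scale, ...) see a param site, while the returned value is simply the module-local tensor and the global parameter store is never consulted. A rough standalone sketch of that idea (the site name and tensor here are made up for illustration, not taken from the diff):

import torch

import pyro.poutine as poutine
from pyro.poutine.runtime import effectful

# Stand-in for a PyroParam's module-local constrained value.
value = torch.nn.Parameter(torch.tensor(0.5))

# Wrapping a constant callable as an effectful "param" site.
fake_param = effectful(type="param")(lambda *_, **__: value)

with poutine.trace() as tr:
    result = fake_param("my_module.weight", name="my_module.weight")

# Handlers saw the site, but the value is just the closed-over local tensor.
assert result is value
assert tr.trace.nodes["my_module.weight"]["type"] == "param"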
81 changes: 80 additions & 1 deletion tests/nn/test_module.py
@@ -15,7 +15,7 @@
 from pyro.infer import SVI, Trace_ELBO
 from pyro.nn.module import PyroModule, PyroParam, PyroSample, clear, to_pyro_module_
 from pyro.optim import Adam
-from tests.common import assert_equal
+from tests.common import assert_equal, xfail_param
 
 
 def test_svi_smoke():
@@ -765,3 +765,82 @@ def test_bayesian_gru():
     assert output.shape == (seq_len, batch_size, hidden_size)
     output2, _ = gru(input_)
     assert not torch.allclose(output2, output)
+
+
+@pytest.mark.parametrize(
+    "use_local_params",
+    [
+        True,
+        xfail_param(
+            False, reason="torch.func not compatible with global parameter store"
+        ),
+    ],
+)
+def test_functorch_pyroparam(use_local_params):
+    class ParamModule(PyroModule):
+        def __init__(self):
+            super().__init__()
+            self.a2 = PyroParam(torch.tensor(0.678), constraints.positive)
+
+        @PyroParam(constraint=constraints.real)
+        def a1(self):
+            return torch.tensor(0.456)
+
+    class Model(PyroModule):
+        def __init__(self):
+            super().__init__()
+            self.param_module = ParamModule()
+            self.b1 = PyroParam(torch.tensor(0.123), constraints.positive)
+            self.b3 = torch.nn.Parameter(torch.tensor(0.789))
+            self.c = torch.nn.Linear(1, 1)
+
+        @PyroParam(constraint=constraints.positive)
+        def b2(self):
+            return torch.tensor(1.234)
+
+        def forward(self, x, y):
+            return (
+                (self.param_module.a1 + self.param_module.a2) * x
+                + self.b1
+                + self.b2
+                + self.b3
+                - self.c(y.unsqueeze(-1)).squeeze(-1)
+            ) ** 2
+
+    with pyro.settings.context(module_local_params=use_local_params):
+        model = Model()
+        x, y = torch.tensor(1.3), torch.tensor(0.2)
+
+        with pyro.poutine.trace() as tr:
+            model(x, y)
+
+        params = dict(model.named_parameters())
+
+        # Check that all parameters appear in the trace for SVI compatibility
+        assert len(params) == len(
+            {
+                name: node
+                for name, node in tr.trace.nodes.items()
+                if node["type"] == "param"
+            }
+        )
+
+        grad_model = torch.func.grad(
+            lambda p, x, y: torch.func.functional_call(model, p, (x, y))
+        )
+        grad_params_func = grad_model(params, x, y)
+
+        gs = torch.autograd.grad(model(x, y), tuple(params.values()))
+        grad_params_autograd = dict(zip(params.keys(), gs))
+
+        assert len(grad_params_autograd) == len(grad_params_func) != 0
+        assert (
+            set(grad_params_autograd.keys())
+            == set(grad_params_func.keys())
+            == set(params.keys())
+        )
+        for k in grad_params_autograd.keys():
+            assert not torch.allclose(
+                grad_params_func[k], torch.zeros_like(grad_params_func[k])
+            ), k
+            assert torch.allclose(grad_params_autograd[k], grad_params_func[k]), k