Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add method for converting Pyro ExpandedDistribution to funsor #418

Merged
merged 1 commit into from
Jan 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions funsor/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,25 @@ def indepdist_to_funsor(backend_dist, output=None, dim_to_name=None):
return result


def expandeddist_to_funsor(backend_dist, output=None, dim_to_name=None):

funsor_base_dist = to_funsor(backend_dist.base_dist, output=output, dim_to_name=dim_to_name)
if not dim_to_name:
assert not backend_dist.batch_shape
return funsor_base_dist

name_to_dim = {name: dim for dim, name in dim_to_name.items()}
raw_expanded_params = {}
for name, funsor_param in funsor_base_dist.params.items():
if name == "value":
continue
raw_param = to_data(funsor_param, name_to_dim=name_to_dim)
raw_expanded_params[name] = raw_param.expand(backend_dist.batch_shape + funsor_param.shape)

raw_expanded_dist = type(backend_dist.base_dist)(**raw_expanded_params)
return to_funsor(raw_expanded_dist, output, dim_to_name)


def maskeddist_to_funsor(backend_dist, output=None, dim_to_name=None):
mask = to_funsor(ops.astype(backend_dist._mask, 'float32'), output=output, dim_to_name=dim_to_name)
funsor_base_dist = to_funsor(backend_dist.base_dist, output=output, dim_to_name=dim_to_name)
Expand Down
4 changes: 3 additions & 1 deletion funsor/torch/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pyro.distributions as dist
import pyro.distributions.testing.fakes as fakes
from pyro.distributions.torch_distribution import MaskedDistribution
from pyro.distributions.torch_distribution import ExpandedDistribution, MaskedDistribution
import torch

from funsor.cnf import Contraction
Expand All @@ -34,6 +34,7 @@
eager_mvn,
eager_normal,
eager_plate_multinomial,
expandeddist_to_funsor,
indepdist_to_funsor,
make_dist,
maskeddist_to_funsor,
Expand Down Expand Up @@ -296,6 +297,7 @@ def composetransform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=N
return expr


to_funsor.register(ExpandedDistribution)(expandeddist_to_funsor)
to_funsor.register(torch.distributions.Independent)(indepdist_to_funsor)
to_funsor.register(MaskedDistribution)(maskeddist_to_funsor)
to_funsor.register(torch.distributions.TransformedDistribution)(transformeddist_to_funsor)
Expand Down
27 changes: 23 additions & 4 deletions test/test_distribution_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,16 @@ def __hash__(self):
xfail_reason="to_funsor/to_data conversion is not yet reversible",
)

# ExpandedDistribution
for extra_shape in [(), (3,), (2, 3)]:
# Poisson
DistTestCase(
f"backend_dist.torch_distribution.ExpandedDistribution(backend_dist.Poisson(rate=case.rate), {extra_shape + batch_shape})", # noqa: E501
(("rate", f"rand({batch_shape})"),),
funsor.Real,
xfail_reason="ExpandedDistribution only exists in torch backend" if get_backend() != "torch" else "",
)


###########################
# Generic tests:
Expand All @@ -447,6 +457,11 @@ def _default_dim_to_name(inputs_shape, event_inputs=None):
@pytest.mark.parametrize("case", TEST_CASES, ids=str)
def test_generic_distribution_to_funsor(case):

HIGHER_ORDER_DISTS = [
backend_dist.Independent,
backend_dist.TransformedDistribution,
] + ([backend_dist.torch_distribution.ExpandedDistribution] if get_backend() == "torch" else [])

with xfail_if_not_found():
raw_dist, expected_value_domain = eval(case.raw_dist), case.expected_value_domain

Expand All @@ -462,12 +477,16 @@ def test_generic_distribution_to_funsor(case):
actual_dist = to_data(funsor_dist, name_to_dim=name_to_dim)

assert isinstance(actual_dist, backend_dist.Distribution)
assert issubclass(type(actual_dist), type(raw_dist)) # subclass to handle wrappers
while isinstance(raw_dist, backend_dist.Independent) or type(raw_dist) == backend_dist.TransformedDistribution:
orig_raw_dist = raw_dist
while type(raw_dist) in HIGHER_ORDER_DISTS:
raw_dist = raw_dist.base_dist
actual_dist = actual_dist.base_dist
actual_dist = actual_dist.base_dist if type(actual_dist) in HIGHER_ORDER_DISTS else actual_dist
assert isinstance(actual_dist, backend_dist.Distribution)
assert issubclass(type(actual_dist), type(raw_dist)) # subclass to handle wrappers
assert issubclass(type(actual_dist), type(raw_dist)) # subclass to handle wrappers

if get_backend() == "torch" and "ExpandedDistribution" in case.raw_dist:
assert orig_raw_dist.batch_shape == actual_dist.batch_shape
return

for param_name, _ in case.raw_params:
assert hasattr(raw_dist, param_name)
Expand Down