Skip to content

Commit

Permalink
add method for converting Pyro ExpandedDistribution to funsor (#418)
Browse files Browse the repository at this point in the history
  • Loading branch information
eb8680 authored Jan 9, 2021
1 parent d3d2ce7 commit 180a946
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 5 deletions.
19 changes: 19 additions & 0 deletions funsor/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,25 @@ def indepdist_to_funsor(backend_dist, output=None, dim_to_name=None):
return result


def expandeddist_to_funsor(backend_dist, output=None, dim_to_name=None):

funsor_base_dist = to_funsor(backend_dist.base_dist, output=output, dim_to_name=dim_to_name)
if not dim_to_name:
assert not backend_dist.batch_shape
return funsor_base_dist

name_to_dim = {name: dim for dim, name in dim_to_name.items()}
raw_expanded_params = {}
for name, funsor_param in funsor_base_dist.params.items():
if name == "value":
continue
raw_param = to_data(funsor_param, name_to_dim=name_to_dim)
raw_expanded_params[name] = raw_param.expand(backend_dist.batch_shape + funsor_param.shape)

raw_expanded_dist = type(backend_dist.base_dist)(**raw_expanded_params)
return to_funsor(raw_expanded_dist, output, dim_to_name)


def maskeddist_to_funsor(backend_dist, output=None, dim_to_name=None):
mask = to_funsor(ops.astype(backend_dist._mask, 'float32'), output=output, dim_to_name=dim_to_name)
funsor_base_dist = to_funsor(backend_dist.base_dist, output=output, dim_to_name=dim_to_name)
Expand Down
4 changes: 3 additions & 1 deletion funsor/torch/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import pyro.distributions as dist
import pyro.distributions.testing.fakes as fakes
from pyro.distributions.torch_distribution import MaskedDistribution
from pyro.distributions.torch_distribution import ExpandedDistribution, MaskedDistribution
import torch

from funsor.cnf import Contraction
Expand All @@ -34,6 +34,7 @@
eager_mvn,
eager_normal,
eager_plate_multinomial,
expandeddist_to_funsor,
indepdist_to_funsor,
make_dist,
maskeddist_to_funsor,
Expand Down Expand Up @@ -296,6 +297,7 @@ def composetransform_to_funsor(tfm, output=None, dim_to_name=None, real_inputs=N
return expr


to_funsor.register(ExpandedDistribution)(expandeddist_to_funsor)
to_funsor.register(torch.distributions.Independent)(indepdist_to_funsor)
to_funsor.register(MaskedDistribution)(maskeddist_to_funsor)
to_funsor.register(torch.distributions.TransformedDistribution)(transformeddist_to_funsor)
Expand Down
27 changes: 23 additions & 4 deletions test/test_distribution_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,16 @@ def __hash__(self):
xfail_reason="to_funsor/to_data conversion is not yet reversible",
)

# ExpandedDistribution
for extra_shape in [(), (3,), (2, 3)]:
# Poisson
DistTestCase(
f"backend_dist.torch_distribution.ExpandedDistribution(backend_dist.Poisson(rate=case.rate), {extra_shape + batch_shape})", # noqa: E501
(("rate", f"rand({batch_shape})"),),
funsor.Real,
xfail_reason="ExpandedDistribution only exists in torch backend" if get_backend() != "torch" else "",
)


###########################
# Generic tests:
Expand All @@ -447,6 +457,11 @@ def _default_dim_to_name(inputs_shape, event_inputs=None):
@pytest.mark.parametrize("case", TEST_CASES, ids=str)
def test_generic_distribution_to_funsor(case):

HIGHER_ORDER_DISTS = [
backend_dist.Independent,
backend_dist.TransformedDistribution,
] + ([backend_dist.torch_distribution.ExpandedDistribution] if get_backend() == "torch" else [])

with xfail_if_not_found():
raw_dist, expected_value_domain = eval(case.raw_dist), case.expected_value_domain

Expand All @@ -462,12 +477,16 @@ def test_generic_distribution_to_funsor(case):
actual_dist = to_data(funsor_dist, name_to_dim=name_to_dim)

assert isinstance(actual_dist, backend_dist.Distribution)
assert issubclass(type(actual_dist), type(raw_dist)) # subclass to handle wrappers
while isinstance(raw_dist, backend_dist.Independent) or type(raw_dist) == backend_dist.TransformedDistribution:
orig_raw_dist = raw_dist
while type(raw_dist) in HIGHER_ORDER_DISTS:
raw_dist = raw_dist.base_dist
actual_dist = actual_dist.base_dist
actual_dist = actual_dist.base_dist if type(actual_dist) in HIGHER_ORDER_DISTS else actual_dist
assert isinstance(actual_dist, backend_dist.Distribution)
assert issubclass(type(actual_dist), type(raw_dist)) # subclass to handle wrappers
assert issubclass(type(actual_dist), type(raw_dist)) # subclass to handle wrappers

if get_backend() == "torch" and "ExpandedDistribution" in case.raw_dist:
assert orig_raw_dist.batch_shape == actual_dist.batch_shape
return

for param_name, _ in case.raw_params:
assert hasattr(raw_dist, param_name)
Expand Down

0 comments on commit 180a946

Please sign in to comment.