def expandeddist_to_funsor(backend_dist, output=None, dim_to_name=None):
    """Convert an expanded backend distribution wrapper into a funsor.

    The wrapped ``base_dist`` is converted first. When ``dim_to_name`` is
    empty or ``None`` the wrapper is asserted to have an empty
    ``batch_shape`` and the converted base distribution is returned as is.
    Otherwise every parameter of the converted base distribution except
    ``"value"`` is turned back into backend data, expanded to
    ``backend_dist.batch_shape + <parameter shape>``, and used to rebuild a
    distribution of the base distribution's type, which is then converted
    to a funsor.

    :param backend_dist: wrapper exposing ``base_dist`` and ``batch_shape``.
    :param output: forwarded unchanged to ``to_funsor``.
    :param dim_to_name: mapping from batch dimension to input name; it is
        inverted here to drive ``to_data``.
    :return: the result of ``to_funsor`` on either the base distribution or
        the rebuilt, explicitly expanded distribution.
    """
    base_funsor = to_funsor(backend_dist.base_dist, output=output, dim_to_name=dim_to_name)

    # Without named batch dims there is nothing to expand into.
    if not dim_to_name:
        assert not backend_dist.batch_shape
        return base_funsor

    name_to_dim = {name: dim for dim, name in dim_to_name.items()}

    # "value" is skipped: it is not passed to the distribution constructor below.
    expanded_params = {
        param_name: to_data(param, name_to_dim=name_to_dim).expand(
            backend_dist.batch_shape + param.shape
        )
        for param_name, param in base_funsor.params.items()
        if param_name != "value"
    }

    base_type = type(backend_dist.base_dist)
    expanded_dist = base_type(**expanded_params)
    return to_funsor(expanded_dist, output, dim_to_name)
return expr +to_funsor.register(ExpandedDistribution)(expandeddist_to_funsor) to_funsor.register(torch.distributions.Independent)(indepdist_to_funsor) to_funsor.register(MaskedDistribution)(maskeddist_to_funsor) to_funsor.register(torch.distributions.TransformedDistribution)(transformeddist_to_funsor) diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index 50a00e237..92f85dc29 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -427,6 +427,16 @@ def __hash__(self): xfail_reason="to_funsor/to_data conversion is not yet reversible", ) + # ExpandedDistribution + for extra_shape in [(), (3,), (2, 3)]: + # Poisson + DistTestCase( + f"backend_dist.torch_distribution.ExpandedDistribution(backend_dist.Poisson(rate=case.rate), {extra_shape + batch_shape})", # noqa: E501 + (("rate", f"rand({batch_shape})"),), + funsor.Real, + xfail_reason="ExpandedDistribution only exists in torch backend" if get_backend() != "torch" else "", + ) + ########################### # Generic tests: @@ -447,6 +457,11 @@ def _default_dim_to_name(inputs_shape, event_inputs=None): @pytest.mark.parametrize("case", TEST_CASES, ids=str) def test_generic_distribution_to_funsor(case): + HIGHER_ORDER_DISTS = [ + backend_dist.Independent, + backend_dist.TransformedDistribution, + ] + ([backend_dist.torch_distribution.ExpandedDistribution] if get_backend() == "torch" else []) + with xfail_if_not_found(): raw_dist, expected_value_domain = eval(case.raw_dist), case.expected_value_domain @@ -462,12 +477,16 @@ def test_generic_distribution_to_funsor(case): actual_dist = to_data(funsor_dist, name_to_dim=name_to_dim) assert isinstance(actual_dist, backend_dist.Distribution) - assert issubclass(type(actual_dist), type(raw_dist)) # subclass to handle wrappers - while isinstance(raw_dist, backend_dist.Independent) or type(raw_dist) == backend_dist.TransformedDistribution: + orig_raw_dist = raw_dist + while type(raw_dist) in HIGHER_ORDER_DISTS: 
raw_dist = raw_dist.base_dist - actual_dist = actual_dist.base_dist + actual_dist = actual_dist.base_dist if type(actual_dist) in HIGHER_ORDER_DISTS else actual_dist assert isinstance(actual_dist, backend_dist.Distribution) - assert issubclass(type(actual_dist), type(raw_dist)) # subclass to handle wrappers + assert issubclass(type(actual_dist), type(raw_dist)) # subclass to handle wrappers + + if get_backend() == "torch" and "ExpandedDistribution" in case.raw_dist: + assert orig_raw_dist.batch_shape == actual_dist.batch_shape + return for param_name, _ in case.raw_params: assert hasattr(raw_dist, param_name)