From b8c57e7d55dd3751c079bcbe0e2eee93d787b416 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Fri, 8 Jan 2021 21:48:17 -0600 Subject: [PATCH 1/2] support expand for jax backend --- funsor/distribution.py | 2 +- funsor/jax/distributions.py | 2 ++ test/test_distribution_generic.py | 14 ++++++++++---- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/funsor/distribution.py b/funsor/distribution.py index 8830c2f16..d8563be33 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -372,7 +372,7 @@ def expandeddist_to_funsor(backend_dist, output=None, dim_to_name=None): if name == "value": continue raw_param = to_data(funsor_param, name_to_dim=name_to_dim) - raw_expanded_params[name] = raw_param.expand(backend_dist.batch_shape + funsor_param.shape) + raw_expanded_params[name] = ops.expand(raw_param, backend_dist.batch_shape + funsor_param.shape) raw_expanded_dist = type(backend_dist.base_dist)(**raw_expanded_params) return to_funsor(raw_expanded_dist, output, dim_to_name) diff --git a/funsor/jax/distributions.py b/funsor/jax/distributions.py index e0e067e90..d9713a5a7 100644 --- a/funsor/jax/distributions.py +++ b/funsor/jax/distributions.py @@ -30,6 +30,7 @@ eager_mvn, eager_normal, eager_plate_multinomial, + expandeddist_to_funsor, indepdist_to_funsor, make_dist, maskeddist_to_funsor, @@ -211,6 +212,7 @@ def deltadist_to_data(funsor_dist, name_to_dim=None): dist.TransformedDistribution.has_rsample = property(lambda self: self.base_dist.has_rsample) dist.TransformedDistribution.rsample = dist.TransformedDistribution.sample +to_funsor.register(dist.ExpandedDistribution)(expandeddist_to_funsor) to_funsor.register(dist.Independent)(indepdist_to_funsor) if hasattr(dist, "MaskedDistribution"): to_funsor.register(dist.MaskedDistribution)(maskeddist_to_funsor) diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index 92f85dc29..aebadf880 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -40,6 +40,12 @@ def __getattribute__(self, attr): FAKES = _fakes() +if get_backend() == "jax": + _expanded_dist_path = "backend_dist.ExpandedDistribution" +elif get_backend() == "torch": + _expanded_dist_path = "backend_dist.torch_distribution.ExpandedDistribution" + + def normalize_with_subs(cls, *args): """ This interpretation is like normalize, except it also evaluates Subs eagerly. @@ -431,10 +437,9 @@ def __hash__(self): for extra_shape in [(), (3,), (2, 3)]: # Poisson DistTestCase( - f"backend_dist.torch_distribution.ExpandedDistribution(backend_dist.Poisson(rate=case.rate), {extra_shape + batch_shape})", # noqa: E501 + _expanded_dist_path + f"(backend_dist.Poisson(rate=case.rate), {extra_shape + batch_shape})", # noqa: E501 (("rate", f"rand({batch_shape})"),), funsor.Real, - xfail_reason="ExpandedDistribution only exists in torch backend" if get_backend() != "torch" else "", ) @@ -460,7 +465,8 @@ def test_generic_distribution_to_funsor(case): HIGHER_ORDER_DISTS = [ backend_dist.Independent, backend_dist.TransformedDistribution, - ] + ([backend_dist.torch_distribution.ExpandedDistribution] if get_backend() == "torch" else []) + ] + ([backend_dist.torch_distribution.ExpandedDistribution] if get_backend() == "torch" + else [backend_dist.ExpandedDistribution]) with xfail_if_not_found(): raw_dist, expected_value_domain = eval(case.raw_dist), case.expected_value_domain @@ -484,7 +490,7 @@ def test_generic_distribution_to_funsor(case): assert isinstance(actual_dist, backend_dist.Distribution) assert issubclass(type(actual_dist), type(raw_dist)) # subclass to handle wrappers - if get_backend() == "torch" and "ExpandedDistribution" in case.raw_dist: + if "ExpandedDistribution" in case.raw_dist: assert orig_raw_dist.batch_shape == actual_dist.batch_shape return From 6fc81da3ad7007fc88bb47e884801b15af00fadc Mon Sep 17 00:00:00 2001 From: Du Phan Date: Sat, 9 Jan 2021 22:54:05 -0600 Subject: [PATCH 2/2] add expanded_dist_path for numpy --- test/test_distribution_generic.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index aebadf880..f7bee7f94 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -44,6 +44,8 @@ def __getattribute__(self, attr): _expanded_dist_path = "backend_dist.ExpandedDistribution" elif get_backend() == "torch": _expanded_dist_path = "backend_dist.torch_distribution.ExpandedDistribution" +else: + _expanded_dist_path = "" def normalize_with_subs(cls, *args):