Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support ExpandedDistribution for jax backend #419

Merged
merged 2 commits into from
Jan 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion funsor/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,7 +372,7 @@ def expandeddist_to_funsor(backend_dist, output=None, dim_to_name=None):
if name == "value":
continue
raw_param = to_data(funsor_param, name_to_dim=name_to_dim)
raw_expanded_params[name] = raw_param.expand(backend_dist.batch_shape + funsor_param.shape)
raw_expanded_params[name] = ops.expand(raw_param, backend_dist.batch_shape + funsor_param.shape)

raw_expanded_dist = type(backend_dist.base_dist)(**raw_expanded_params)
return to_funsor(raw_expanded_dist, output, dim_to_name)
Expand Down
2 changes: 2 additions & 0 deletions funsor/jax/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
eager_mvn,
eager_normal,
eager_plate_multinomial,
expandeddist_to_funsor,
indepdist_to_funsor,
make_dist,
maskeddist_to_funsor,
Expand Down Expand Up @@ -211,6 +212,7 @@ def deltadist_to_data(funsor_dist, name_to_dim=None):
dist.TransformedDistribution.has_rsample = property(lambda self: self.base_dist.has_rsample)
dist.TransformedDistribution.rsample = dist.TransformedDistribution.sample

to_funsor.register(dist.ExpandedDistribution)(expandeddist_to_funsor)
to_funsor.register(dist.Independent)(indepdist_to_funsor)
if hasattr(dist, "MaskedDistribution"):
to_funsor.register(dist.MaskedDistribution)(maskeddist_to_funsor)
Expand Down
16 changes: 12 additions & 4 deletions test/test_distribution_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ def __getattribute__(self, attr):
FAKES = _fakes()


if get_backend() == "jax":
_expanded_dist_path = "backend_dist.ExpandedDistribution"
elif get_backend() == "torch":
_expanded_dist_path = "backend_dist.torch_distribution.ExpandedDistribution"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like you need to define a dummy value of _expanded_dist_path when get_backend() == "numpy" to get test collection to work, even though the tests themselves are all skipped.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Eli!

else:
_expanded_dist_path = ""


def normalize_with_subs(cls, *args):
"""
This interpretation is like normalize, except it also evaluates Subs eagerly.
Expand Down Expand Up @@ -431,10 +439,9 @@ def __hash__(self):
for extra_shape in [(), (3,), (2, 3)]:
# Poisson
DistTestCase(
f"backend_dist.torch_distribution.ExpandedDistribution(backend_dist.Poisson(rate=case.rate), {extra_shape + batch_shape})", # noqa: E501
_expanded_dist_path + f"(backend_dist.Poisson(rate=case.rate), {extra_shape + batch_shape})", # noqa: E501
(("rate", f"rand({batch_shape})"),),
funsor.Real,
xfail_reason="ExpandedDistribution only exists in torch backend" if get_backend() != "torch" else "",
)


Expand All @@ -460,7 +467,8 @@ def test_generic_distribution_to_funsor(case):
HIGHER_ORDER_DISTS = [
backend_dist.Independent,
backend_dist.TransformedDistribution,
] + ([backend_dist.torch_distribution.ExpandedDistribution] if get_backend() == "torch" else [])
] + ([backend_dist.torch_distribution.ExpandedDistribution] if get_backend() == "torch"
else [backend_dist.ExpandedDistribution])

with xfail_if_not_found():
raw_dist, expected_value_domain = eval(case.raw_dist), case.expected_value_domain
Expand All @@ -484,7 +492,7 @@ def test_generic_distribution_to_funsor(case):
assert isinstance(actual_dist, backend_dist.Distribution)
assert issubclass(type(actual_dist), type(raw_dist)) # subclass to handle wrappers

if get_backend() == "torch" and "ExpandedDistribution" in case.raw_dist:
if "ExpandedDistribution" in case.raw_dist:
assert orig_raw_dist.batch_shape == actual_dist.batch_shape
return

Expand Down