diff --git a/Makefile b/Makefile
index 46bd0306c..20e6681c1 100644
--- a/Makefile
+++ b/Makefile
@@ -43,7 +43,9 @@ ifeq (${FUNSOR_BACKEND}, torch)
 	python examples/sensor.py --seed=0 --num-frames=2 -n 1
 	@echo PASS
 else ifeq (${FUNSOR_BACKEND}, jax)
-	pytest -v -n auto --ignore=test/examples --ignore=test/pyro --ignore=test/pyroapi
+	pytest -v -n auto --ignore=test/examples --ignore=test/pyro --ignore=test/pyroapi --ignore=test/test_distribution.py --ignore=test/test_distribution_generic.py
+	pytest -v -n auto test/test_distribution.py
+	pytest -v -n auto test/test_distribution_generic.py
 	@echo PASS
 else
 	# default backend
diff --git a/funsor/distribution.py b/funsor/distribution.py
index 3e41dc932..e49e4e1d3 100644
--- a/funsor/distribution.py
+++ b/funsor/distribution.py
@@ -134,23 +134,28 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
         value = params.pop("value")
         assert all(isinstance(v, (Number, Tensor)) for v in params.values())
         assert isinstance(value, Variable) and value.name in sampled_vars
-        inputs_, tensors = align_tensors(*params.values())
-        inputs = OrderedDict(sample_inputs.items())
-        inputs.update(inputs_)
-        sample_shape = tuple(v.size for v in sample_inputs.values())
-        raw_dist = self.dist_class(**dict(zip(self._ast_fields[:-1], tensors)))
+        value_name = value.name
+        raw_dist, value_output, dim_to_name = self._get_raw_dist()
+        for d, name in zip(range(len(sample_inputs), 0, -1), sample_inputs.keys()):
+            dim_to_name[-d - len(raw_dist.batch_shape)] = name
+
+        sample_shape = tuple(v.size for v in sample_inputs.values())
         sample_args = (sample_shape,) if get_backend() == "torch" else (rng_key, sample_shape)
         if self.has_rsample:
-            raw_sample = raw_dist.rsample(*sample_args)
+            raw_value = raw_dist.rsample(*sample_args)
         else:
-            raw_sample = ops.detach(raw_dist.sample(*sample_args))
+            raw_value = ops.detach(raw_dist.sample(*sample_args))
 
-        result = funsor.delta.Delta(value.name, Tensor(raw_sample, inputs, value.output.dtype))
+        funsor_value = to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name)
+        funsor_value = funsor_value.align(
+            tuple(sample_inputs) + tuple(inp for inp in self.inputs if inp in funsor_value.inputs))
+        result = funsor.delta.Delta(value_name, funsor_value)
         if not self.has_rsample:
             # scaling of dice_factor by num samples should already be handled by Funsor.sample
-            raw_log_prob = raw_dist.log_prob(raw_sample)
-            dice_factor = Tensor(raw_log_prob - ops.detach(raw_log_prob), inputs)
+            raw_log_prob = raw_dist.log_prob(raw_value)
+            dice_factor = to_funsor(raw_log_prob - ops.detach(raw_log_prob),
+                                    output=self.output, dim_to_name=dim_to_name)
             result = result + dice_factor
         return result
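The rewritten `unscaled_sample` hinges on `dim_to_name` bookkeeping: the loop assigns each new sample input a fresh negative dimension to the left of the distribution's existing batch dimensions. A minimal worked sketch of just that indexing (the names here are illustrative, not from the patch):

    # Suppose the raw distribution has batch_shape (3,), already mapped as
    # {-1: "i"}, and two sample inputs arrive in order: "particle", "chain".
    sample_names = ["particle", "chain"]
    batch_ndim = 1  # len(raw_dist.batch_shape)
    dim_to_name = {-1: "i"}
    for d, name in zip(range(len(sample_names), 0, -1), sample_names):
        dim_to_name[-d - batch_ndim] = name
    # Sample dims stack to the left of batch dims:
    assert dim_to_name == {-3: "particle", -2: "chain", -1: "i"}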
diff --git a/funsor/jax/distributions.py b/funsor/jax/distributions.py
index dd780adab..e676ee015 100644
--- a/funsor/jax/distributions.py
+++ b/funsor/jax/distributions.py
@@ -175,14 +175,11 @@ def _infer_param_domain(cls, name, raw_shape):
 @to_funsor.register(dist.BinomialProbs)
 @to_funsor.register(dist.BinomialLogits)
 def categorical_to_funsor(numpyro_dist, output=None, dim_to_name=None):
-    new_pyro_dist = _NumPyroWrapper_Binomial(probs=numpyro_dist.probs)
+    new_pyro_dist = _NumPyroWrapper_Binomial(total_count=numpyro_dist.total_count, probs=numpyro_dist.probs)
     return backenddist_to_funsor(Binomial, new_pyro_dist, output, dim_to_name)  # noqa: F821
 
 
 @to_funsor.register(dist.CategoricalProbs)
-# XXX: in Pyro backend, we always convert pyro.distributions.Categorical
-# to funsor.torch.distributions.Categorical
-@to_funsor.register(dist.CategoricalLogits)
 def categorical_to_funsor(numpyro_dist, output=None, dim_to_name=None):
     new_pyro_dist = _NumPyroWrapper_Categorical(probs=numpyro_dist.probs)
     return backenddist_to_funsor(Categorical, new_pyro_dist, output, dim_to_name)  # noqa: F821
@@ -191,7 +188,7 @@ def categorical_to_funsor(numpyro_dist, output=None, dim_to_name=None):
 @to_funsor.register(dist.MultinomialProbs)
 @to_funsor.register(dist.MultinomialLogits)
 def categorical_to_funsor(numpyro_dist, output=None, dim_to_name=None):
-    new_pyro_dist = _NumPyroWrapper_Multinomial(probs=numpyro_dist.probs)
+    new_pyro_dist = _NumPyroWrapper_Multinomial(total_count=numpyro_dist.total_count, probs=numpyro_dist.probs)
     return backenddist_to_funsor(Multinomial, new_pyro_dist, output, dim_to_name)  # noqa: F821
 
 
diff --git a/funsor/jax/ops.py b/funsor/jax/ops.py
index d709de2ed..6daa1edd2 100644
--- a/funsor/jax/ops.py
+++ b/funsor/jax/ops.py
@@ -121,6 +121,11 @@ def _is_numeric_array(x):
     return True
 
 
+@ops.isnan.register(array)
+def _isnan(x):
+    return np.isnan(x)
+
+
 @ops.lgamma.register(array)
 def _lgamma(x):
     return gammaln(x)
diff --git a/funsor/ops/array.py b/funsor/ops/array.py
index 915935553..c38449e76 100644
--- a/funsor/ops/array.py
+++ b/funsor/ops/array.py
@@ -24,6 +24,7 @@
 diagonal = Op("diagonal")
 einsum = Op("einsum")
 full_like = Op(np.full_like)
+isnan = Op(np.isnan)
 prod = Op(np.prod)
 stack = Op("stack")
 sum = Op(np.sum)
@@ -300,6 +301,7 @@ def unsqueeze(x, dim):
     'finfo',
     'full_like',
     'is_numeric_array',
+    'isnan',
     'logaddexp',
     'logsumexp',
     'new_arange',
diff --git a/funsor/testing.py b/funsor/testing.py
index afc9ff8db..1e42ff1de 100644
--- a/funsor/testing.py
+++ b/funsor/testing.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import contextlib
+import importlib
 import itertools
 import numbers
 import operator
@@ -265,6 +266,24 @@ def randn(*args):
         return np.array(np.random.randn(*shape))
 
 
+def random_scale_tril(*args):
+    if isinstance(args[0], tuple):
+        assert len(args) == 1
+        shape = args[0]
+    else:
+        shape = args
+
+    from funsor.distribution import BACKEND_TO_DISTRIBUTIONS_BACKEND
+    backend_dist = importlib.import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]).dist
+
+    if get_backend() == "torch":
+        data = randn(shape)
+        return backend_dist.transforms.transform_to(backend_dist.constraints.lower_cholesky)(data)
+    else:
+        data = randn(shape[:-2] + (shape[-1] * (shape[-1] + 1) // 2,))
+        return backend_dist.biject_to(backend_dist.constraints.lower_cholesky)(data)
+
+
 def zeros(*args):
     if isinstance(args[0], tuple):
         assert len(args) == 1
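`random_scale_tril` centralizes the helper that `test/test_distribution.py` previously defined locally (see the test diff below). Note the backend difference: torch's `transform_to` maps a full square matrix to a lower-Cholesky factor, while numpyro's `biject_to` expects a flattened vector of n*(n+1)/2 unconstrained entries per factor. The new `ops.isnan` op registered here (and for torch below) is consumed by the nan guard in `test_distribution_generic.py` at the end of this patch. A hedged usage sketch, assuming the torch backend is active:

    import torch

    import funsor
    from funsor.testing import random_scale_tril

    funsor.set_backend("torch")
    L = random_scale_tril((2, 3, 3))  # batch of two 3x3 Cholesky factors
    assert L.shape == (2, 3, 3)
    assert (torch.tril(L) == L).all()  # lower triangular ...
    assert (torch.diagonal(L, dim1=-2, dim2=-1) > 0).all()  # ... with positive diagonal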
diff --git a/funsor/torch/distributions.py b/funsor/torch/distributions.py
index ba1798a5a..4731605a9 100644
--- a/funsor/torch/distributions.py
+++ b/funsor/torch/distributions.py
@@ -2,6 +2,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import functools
+import numbers
 from typing import Tuple, Union
 
 import pyro.distributions as dist
@@ -40,7 +41,7 @@
 from funsor.domains import Real, Reals
 import funsor.ops as ops
 from funsor.tensor import Tensor, dummy_numeric_array
-from funsor.terms import Binary, Funsor, Variable, eager, to_funsor
+from funsor.terms import Binary, Funsor, Variable, eager, to_data, to_funsor
 from funsor.util import methodof
 
 
@@ -153,6 +154,19 @@ def _infer_param_domain(cls, name, raw_shape):
         return Real
 
 
+###########################################################
+# Converting distribution funsors to PyTorch distributions
+###########################################################
+
+@to_data.register(Multinomial)  # noqa: F821
+def multinomial_to_data(funsor_dist, name_to_dim=None):
+    probs = to_data(funsor_dist.probs, name_to_dim)
+    total_count = to_data(funsor_dist.total_count, name_to_dim)
+    if isinstance(total_count, numbers.Number) or len(total_count.shape) == 0:
+        return dist.Multinomial(int(total_count), probs=probs)
+    raise NotImplementedError("inhomogeneous total_count not supported")
+
+
 ###############################################
 # Converting PyTorch Distributions to funsors
 ###############################################
diff --git a/funsor/torch/ops.py b/funsor/torch/ops.py
index 9a77fe520..d466934f1 100644
--- a/funsor/torch/ops.py
+++ b/funsor/torch/ops.py
@@ -107,6 +107,11 @@ def _is_numeric_array(x):
     return True
 
 
+@ops.isnan.register(torch.Tensor)
+def _isnan(x):
+    return torch.isnan(x)
+
+
 @ops.lgamma.register(torch.Tensor)
 def _lgamma(x):
     return x.lgamma()
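A sketch of what the new `to_data` rule enables, assuming the torch backend (exact conversion entry points as used by the generic tests below): a `Multinomial` funsor with a homogeneous `total_count` converts back to a `pyro.distributions.Multinomial`, while a batched `total_count` falls through to the `NotImplementedError` branch:

    import pyro.distributions as dist
    import torch

    import funsor
    from funsor.terms import to_data, to_funsor

    funsor.set_backend("torch")
    raw = dist.Multinomial(5, probs=torch.full((2, 4), 0.25))  # batch_shape (2,)
    f = to_funsor(raw, output=funsor.Real, dim_to_name={-1: "i"})
    back = to_data(f, name_to_dim={"i": -1})
    assert isinstance(back, dist.Multinomial)
    assert back.total_count == 5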
diff --git a/test/test_distribution.py b/test/test_distribution.py
index f789414b4..4d02de93f 100644
--- a/test/test_distribution.py
+++ b/test/test_distribution.py
@@ -19,7 +19,8 @@
 from funsor.interpreter import interpretation, reinterpret
 from funsor.tensor import Einsum, Tensor, numeric_array, stack
 from funsor.terms import Independent, Variable, eager, lazy, to_funsor
-from funsor.testing import assert_close, check_funsor, rand, randint, randn, random_mvn, random_tensor, xfail_param
+from funsor.testing import assert_close, check_funsor, rand, randint, randn, \
+    random_mvn, random_scale_tril, random_tensor, xfail_param
 from funsor.util import get_backend
 
 pytestmark = pytest.mark.skipif(get_backend() == "numpy",
@@ -472,15 +473,6 @@ def test_mvn_defaults():
     assert dist.MultivariateNormal(loc, scale_tril) is dist.MultivariateNormal(loc, scale_tril, value)
 
 
-def _random_scale_tril(shape):
-    if get_backend() == "torch":
-        data = randn(shape)
-        return backend_dist.transforms.transform_to(backend_dist.constraints.lower_cholesky)(data)
-    else:
-        data = randn(shape[:-2] + (shape[-1] * (shape[-1] + 1) // 2,))
-        return backend_dist.biject_to(backend_dist.constraints.lower_cholesky)(data)
-
-
 @pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
 def test_mvn_density(batch_shape):
     batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
@@ -493,7 +485,7 @@ def mvn(loc: Reals[3], scale_tril: Reals[3, 3], value: Reals[3]) -> Real:
     check_funsor(mvn, {'loc': Reals[3], 'scale_tril': Reals[3, 3], 'value': Reals[3]}, Real)
 
     loc = Tensor(randn(batch_shape + (3,)), inputs)
-    scale_tril = Tensor(_random_scale_tril(batch_shape + (3, 3)), inputs)
+    scale_tril = Tensor(random_scale_tril(batch_shape + (3, 3)), inputs)
     value = Tensor(randn(batch_shape + (3,)), inputs)
     expected = mvn(loc, scale_tril, value)
     check_funsor(expected, inputs, Real)
@@ -509,7 +501,7 @@ def test_mvn_gaussian(batch_shape):
     inputs = OrderedDict((k, Bint[v]) for k, v in zip(batch_dims, batch_shape))
 
     loc = Tensor(randn(batch_shape + (3,)), inputs)
-    scale_tril = Tensor(_random_scale_tril(batch_shape + (3, 3)), inputs)
+    scale_tril = Tensor(random_scale_tril(batch_shape + (3, 3)), inputs)
     value = Tensor(randn(batch_shape + (3,)), inputs)
 
     expected = dist.MultivariateNormal(loc, scale_tril, value)
@@ -808,7 +800,7 @@ def test_mvn_sample(with_lazy, batch_shape, sample_inputs, event_shape):
     inputs = OrderedDict((k, Bint[v]) for k, v in zip(batch_dims, batch_shape))
 
     loc = randn(batch_shape + event_shape)
-    scale_tril = _random_scale_tril(batch_shape + event_shape * 2)
+    scale_tril = random_scale_tril(batch_shape + event_shape * 2)
 
     funsor_dist_class = dist.MultivariateNormal
     params = (loc, scale_tril)
@@ -893,7 +885,7 @@ def test_binomial_sample(with_lazy, batch_shape, sample_inputs):
     funsor_dist_class = dist.Binomial
     params = (total_count, probs)
 
-    _check_sample(funsor_dist_class, params, sample_inputs, inputs, atol=2e-2, skip_grad=True, with_lazy=with_lazy)
+    _check_sample(funsor_dist_class, params, sample_inputs, inputs, atol=5e-2, skip_grad=True, with_lazy=with_lazy)
 
 
 @pytest.mark.parametrize('sample_inputs', [(), ('ii',), ('ii', 'jj'), ('ii', 'jj', 'kk')])
diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py
new file mode 100644
index 000000000..356f109e4
--- /dev/null
+++ b/test/test_distribution_generic.py
@@ -0,0 +1,329 @@
+# Copyright Contributors to the Pyro project.
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+from collections import OrderedDict
+from importlib import import_module
+
+import numpy as np
+import pytest
+
+import funsor
+import funsor.ops as ops
+from funsor.distribution import BACKEND_TO_DISTRIBUTIONS_BACKEND
+from funsor.interpreter import interpretation
+from funsor.terms import lazy, to_data, to_funsor
+from funsor.testing import assert_close, check_funsor, rand, randint, randn, random_scale_tril, xfail_if_not_implemented  # noqa: F401,E501
+from funsor.util import get_backend
+
+
+_ENABLE_MC_DIST_TESTS = int(os.environ.get("FUNSOR_ENABLE_MC_DIST_TESTS", 0))
+
+pytestmark = pytest.mark.skipif(get_backend() == "numpy",
+                                reason="numpy does not have distributions backend")
+if get_backend() != "numpy":
+    dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()])
+    backend_dist = dist.dist
+
+    class _fakes:
+        """alias for accessing nonreparameterized distributions"""
+        def __getattribute__(self, attr):
+            if get_backend() == "torch":
+                return getattr(backend_dist.testing.fakes, attr)
+            elif get_backend() == "jax":
+                return getattr(dist, "_NumPyroWrapper_" + attr)
+            raise ValueError(attr)
+
+    FAKES = _fakes()
+
+
+##################################################
+# Test cases
+##################################################
+
+TEST_CASES = []
+
+
+class DistTestCase:
+
+    def __init__(self, raw_dist, raw_params, expected_value_domain):
+        self.raw_dist = raw_dist
+        self.raw_params = raw_params
+        self.expected_value_domain = expected_value_domain
+        for name, raw_param in self.raw_params:
+            if get_backend() != "numpy":
+                # we need direct access to these tensors for gradient tests
+                setattr(self, name, eval(raw_param))
+        TEST_CASES.append(self)
+
+    def __str__(self):
+        return self.raw_dist + " " + str(self.raw_params)
+
+    def __hash__(self):
+        return hash((self.raw_dist, self.raw_params, self.expected_value_domain))
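+# Note on usage (illustrative, mirroring the cases registered below):
+# constructing a DistTestCase eval's each raw parameter string once and
+# stores the resulting tensor on the case, then appends the case to
+# TEST_CASES. The generic tests later rebuild the distribution with
+#     raw_dist = eval(case.raw_dist)
+# which resolves expressions like case.loc and case.scale against the
+# stored parameters, so gradient tests can reach the original tensors.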
"randint(10, 12, ())" if get_backend() == "jax" else "5"), ("probs", f"rand({batch_shape})")), + funsor.Real, + ) + + # CategoricalLogits + for size in [2, 4]: + DistTestCase( + "backend_dist.Categorical(logits=case.logits)", + (("logits", f"rand({batch_shape + (size,)})"),), + funsor.Bint[size], + ) + + # CategoricalProbs + for size in [2, 4]: + DistTestCase( + "backend_dist.Categorical(probs=case.probs)", + (("probs", f"rand({batch_shape + (size,)})"),), + funsor.Bint[size], + ) + + # TODO figure out what this should be... + # # Delta + # for event_shape in [(),]: # (4,), (3, 2)]: + # TEST_CASES += [DistTestCase( + # "backend_dist.Delta(case.v, case.log_density)", + # (("v", f"rand({batch_shape + event_shape})"), ("log_density", f"rand({batch_shape})")), + # funsor.Real, # s[event_shape], + # )] + + # Dirichlet + for event_shape in [(1,), (4,)]: + DistTestCase( + "backend_dist.Dirichlet(case.concentration)", + (("concentration", f"rand({batch_shape + event_shape})"),), + funsor.Reals[event_shape], + ) + # NonreparameterizedDirichlet + DistTestCase( + "FAKES.NonreparameterizedDirichlet(case.concentration)", + (("concentration", f"rand({batch_shape + event_shape})"),), + funsor.Reals[event_shape], + ) + + # DirichletMultinomial + for event_shape in [(1,), (4,)]: + DistTestCase( + "backend_dist.DirichletMultinomial(case.concentration, case.total_count)", + (("concentration", f"rand({batch_shape + event_shape})"), ("total_count", "randint(10, 12, ())")), + funsor.Reals[event_shape], + ) + + # Gamma + DistTestCase( + "backend_dist.Gamma(case.concentration, case.rate)", + (("concentration", f"rand({batch_shape})"), ("rate", f"rand({batch_shape})")), + funsor.Real, + ) + # NonreparametrizedGamma + DistTestCase( + "FAKES.NonreparameterizedGamma(case.concentration, case.rate)", + (("concentration", f"rand({batch_shape})"), ("rate", f"rand({batch_shape})")), + funsor.Real, + ) + + # Multinomial + for event_shape in [(1,), (4,)]: + DistTestCase( + "backend_dist.Multinomial(case.total_count, probs=case.probs)", + (("total_count", "randint(5, 7, ())" if get_backend() == "jax" else "5"), + ("probs", f"rand({batch_shape + event_shape})")), + funsor.Reals[event_shape], + ) + + # MultivariateNormal + for event_shape in [(1,), (3,)]: + DistTestCase( + "backend_dist.MultivariateNormal(loc=case.loc, scale_tril=case.scale_tril)", + (("loc", f"randn({batch_shape + event_shape})"), ("scale_tril", f"random_scale_tril({batch_shape + event_shape * 2})")), # noqa: E501 + funsor.Reals[event_shape], + ) + + # Normal + DistTestCase( + "backend_dist.Normal(case.loc, case.scale)", + (("loc", f"randn({batch_shape})"), ("scale", f"rand({batch_shape})")), + funsor.Real, + ) + # NonreparameterizedNormal + DistTestCase( + "FAKES.NonreparameterizedNormal(case.loc, case.scale)", + (("loc", f"randn({batch_shape})"), ("scale", f"rand({batch_shape})")), + funsor.Real, + ) + + # Poisson + DistTestCase( + "backend_dist.Poisson(rate=case.rate)", + (("rate", f"rand({batch_shape})"),), + funsor.Real, + ) + + # VonMises + DistTestCase( + "backend_dist.VonMises(case.loc, case.concentration)", + (("loc", f"rand({batch_shape})"), ("concentration", f"rand({batch_shape})")), + funsor.Real, + ) + + +########################### +# Generic tests: +# High-level distribution testing strategy: sequence of increasingly semantically strong distribution-agnostic tests +# Conversion invertibility -> density type and value -> enumerate_support type and value -> samplers -> gradients +########################### + +def 
+
+
+###########################
+# Generic tests:
+# High-level distribution testing strategy: sequence of increasingly semantically strong distribution-agnostic tests
+# Conversion invertibility -> density type and value -> enumerate_support type and value -> samplers -> gradients
+###########################
+
+def _default_dim_to_name(inputs_shape, event_inputs=None):
+    DIM_TO_NAME = tuple(map("_pyro_dim_{}".format, range(-100, 0)))
+    dim_to_name_list = DIM_TO_NAME + event_inputs if event_inputs else DIM_TO_NAME
+    dim_to_name = OrderedDict(zip(
+        range(-len(inputs_shape), 0),
+        dim_to_name_list[len(dim_to_name_list) - len(inputs_shape):]))
+    name_to_dim = OrderedDict((name, dim) for dim, name in dim_to_name.items())
+    return dim_to_name, name_to_dim
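+# For example (illustrative only): _default_dim_to_name((2, 3)) returns
+#     ({-2: "_pyro_dim_-2", -1: "_pyro_dim_-1"},
+#      {"_pyro_dim_-2": -2, "_pyro_dim_-1": -1})
+# so the rightmost batch dim always maps to "_pyro_dim_-1".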
+
+
+@pytest.mark.parametrize("case", TEST_CASES, ids=str)
+def test_generic_distribution_to_funsor(case):
+
+    raw_dist, expected_value_domain = eval(case.raw_dist), case.expected_value_domain
+
+    dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape)
+    with interpretation(lazy):
+        funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
+        actual_dist = to_data(funsor_dist, name_to_dim=name_to_dim)
+
+    assert isinstance(actual_dist, backend_dist.Distribution)
+    assert issubclass(type(actual_dist), type(raw_dist))  # subclass to handle wrappers
+    assert funsor_dist.inputs["value"] == expected_value_domain
+    for param_name in funsor_dist.params.keys():
+        if param_name == "value":
+            continue
+        assert hasattr(raw_dist, param_name)
+        assert_close(getattr(actual_dist, param_name), getattr(raw_dist, param_name))
+
+
+@pytest.mark.parametrize("case", TEST_CASES, ids=str)
+def test_generic_log_prob(case):
+
+    raw_dist, expected_value_domain = eval(case.raw_dist), case.expected_value_domain
+
+    dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape)
+    funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
+    expected_inputs = {name: funsor.Bint[raw_dist.batch_shape[dim]] for dim, name in dim_to_name.items()}
+    expected_inputs.update({"value": expected_value_domain})
+
+    check_funsor(funsor_dist, expected_inputs, funsor.Real)
+
+    if get_backend() == "jax":
+        raw_value = raw_dist.sample(key=np.array([0, 0], dtype=np.uint32))
+    else:
+        raw_value = raw_dist.sample()
+    expected_logprob = to_funsor(raw_dist.log_prob(raw_value), output=funsor.Real, dim_to_name=dim_to_name)
+    funsor_value = to_funsor(raw_value, output=expected_value_domain, dim_to_name=dim_to_name)
+    assert_close(funsor_dist(value=funsor_value), expected_logprob, rtol=1e-4)
+
+
+@pytest.mark.parametrize("case", TEST_CASES, ids=str)
+@pytest.mark.parametrize("expand", [False, True])
+def test_generic_enumerate_support(case, expand):
+
+    raw_dist = eval(case.raw_dist)
+
+    dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape)
+    with interpretation(lazy):
+        funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
+
+    assert getattr(raw_dist, "has_enumerate_support", False) == getattr(funsor_dist, "has_enumerate_support", False)
+    if getattr(funsor_dist, "has_enumerate_support", False):
+        name_to_dim["value"] = -1 if not name_to_dim else min(name_to_dim.values()) - 1
+        with xfail_if_not_implemented("enumerate support not implemented"):
+            raw_support = raw_dist.enumerate_support(expand=expand)
+            funsor_support = funsor_dist.enumerate_support(expand=expand)
+            assert_close(to_data(funsor_support, name_to_dim=name_to_dim), raw_support)
+
+
+@pytest.mark.parametrize("case", TEST_CASES, ids=str)
+@pytest.mark.parametrize("sample_shape", [(), (2,), (4, 3)], ids=str)
+def test_generic_sample(case, sample_shape):
+
+    raw_dist = eval(case.raw_dist)
+
+    dim_to_name, name_to_dim = _default_dim_to_name(sample_shape + raw_dist.batch_shape)
+    with interpretation(lazy):
+        funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
+
+    sample_inputs = OrderedDict((dim_to_name[dim - len(raw_dist.batch_shape)], funsor.Bint[sample_shape[dim]])
+                                for dim in range(-len(sample_shape), 0))
+    rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32)
+    sample_value = funsor_dist.sample(frozenset(['value']), sample_inputs, rng_key=rng_key)
+    expected_inputs = OrderedDict(tuple(sample_inputs.items()) + tuple(funsor_dist.inputs.items()))
+    # TODO compare sample values on jax backend
+    check_funsor(sample_value, expected_inputs, funsor.Real)
+
+
+@pytest.mark.parametrize("case", TEST_CASES, ids=str)
+@pytest.mark.parametrize("statistic", [
+    "mean",
+    "variance",
+    pytest.param("entropy", marks=[pytest.mark.skipif(get_backend() == "jax", reason="entropy not implemented")])
+])
+def test_generic_stats(case, statistic):
+
+    raw_dist = eval(case.raw_dist)
+
+    dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape)
+    with interpretation(lazy):
+        funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
+
+    with xfail_if_not_implemented(msg="entropy not implemented for some distributions"):
+        actual_stat = getattr(funsor_dist, statistic)()
+
+    expected_stat_raw = getattr(raw_dist, statistic)
+    if statistic == "entropy":
+        expected_stat = to_funsor(expected_stat_raw(), output=funsor.Real, dim_to_name=dim_to_name)
+    else:
+        expected_stat = to_funsor(expected_stat_raw, output=case.expected_value_domain, dim_to_name=dim_to_name)
+
+    check_funsor(actual_stat, expected_stat.inputs, expected_stat.output)
+    if ops.isnan(expected_stat.data).all():
+        pytest.xfail(reason="base stat returns nan")
+    else:
+        assert_close(to_data(actual_stat, name_to_dim), to_data(expected_stat, name_to_dim), rtol=1e-4)
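Closing note on the nan guard in `test_generic_stats`: some backend statistics are undefined for a given distribution and come back as all-nan arrays, so the test xfails rather than comparing meaningless values; this is what the new `ops.isnan` op is for. A minimal sketch of the guard, assuming the torch backend:

    import torch

    import funsor
    import funsor.ops as ops

    funsor.set_backend("torch")
    expected = torch.full((3,), float("nan"))
    if ops.isnan(expected).all():
        pass  # the real test calls pytest.xfail(reason="base stat returns nan")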