From 0f904b9505899f409fe7bcf44a1c077df28814d2 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 29 Oct 2020 17:32:06 -0400 Subject: [PATCH 01/38] add generic stat methods to distribution --- funsor/distribution.py | 42 +++++++++++++++++++++++++++++++-------- test/test_distribution.py | 23 ++++++++++----------- 2 files changed, 44 insertions(+), 21 deletions(-) diff --git a/funsor/distribution.py b/funsor/distribution.py index 98a98c0be..51fbd0a68 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -105,6 +105,22 @@ def eager_log_prob(cls, *params): data = cls.dist_class(**params).log_prob(value) return Tensor(data, inputs) + def _get_raw_dist(self): + """ + Internal method for working with underlying distribution attributes + """ + if isinstance(self.value, Variable): + value_name = self.value.name + else: + raise NotImplementedError("cannot get raw dist for {}".format(self)) + # arbitrary name-dim mapping, since we're converting back to a funsor anyway + name_to_dim = {name: -dim-1 for dim, (name, domain) in enumerate(self.inputs.items()) + if isinstance(domain.dtype, int) and name != value_name} + raw_dist = to_data(self, name_to_dim=name_to_dim) + dim_to_name = {dim: name for name, dim in name_to_dim.items()} + # also return value output, dim_to_name for converting results back to funsor + return raw_dist, self.value.output, dim_to_name + @property def has_rsample(self): return getattr(self.dist_class, "has_rsample", False) @@ -139,16 +155,26 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): return result def enumerate_support(self, expand=False): - if not self.has_enumerate_support or not isinstance(self.value, Variable): - raise ValueError("cannot enumerate support of {}".format(repr(self))) - # arbitrary name-dim mapping, since we're converting back to a funsor anyway - name_to_dim = {name: -dim-1 for dim, (name, domain) in enumerate(self.inputs.items()) - if isinstance(domain.dtype, int) and name != self.value.name} - raw_dist = to_data(self, name_to_dim=name_to_dim) + assert self.has_enumerate_support and isinstance(self.value, Variable) + raw_dist, value_output, dim_to_name = self._get_raw_dist() raw_value = raw_dist.enumerate_support(expand=expand) - dim_to_name = {dim: name for name, dim in name_to_dim.items()} dim_to_name[min(dim_to_name.keys(), default=0)-1] = self.value.name - return to_funsor(raw_value, output=self.value.output, dim_to_name=dim_to_name) + return to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name) + + def entropy(self): + raw_dist, value_output, dim_to_name = self._get_raw_dist() + raw_value = raw_dist.entropy() + return to_funsor(raw_value, output=self.output, dim_to_name=dim_to_name) + + def mean(self): + raw_dist, value_output, dim_to_name = self._get_raw_dist() + raw_value = raw_dist.mean + return to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name) + + def variance(self): + raw_dist, value_output, dim_to_name = self._get_raw_dist() + raw_value = raw_dist.variance + return to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name) def __getattribute__(self, attr): if attr in type(self)._ast_fields and attr != 'name': diff --git a/test/test_distribution.py b/test/test_distribution.py index 0e7489da5..f789414b4 100644 --- a/test/test_distribution.py +++ b/test/test_distribution.py @@ -17,7 +17,7 @@ from funsor.domains import Bint, Real, Reals from funsor.integrate import Integrate from funsor.interpreter import interpretation, reinterpret -from funsor.tensor import Einsum, Tensor, 
align_tensors, numeric_array, stack +from funsor.tensor import Einsum, Tensor, numeric_array, stack from funsor.terms import Independent, Variable, eager, lazy, to_funsor from funsor.testing import assert_close, check_funsor, rand, randint, randn, random_mvn, random_tensor, xfail_param from funsor.util import get_backend @@ -701,29 +701,26 @@ def _get_stat_diff(funsor_dist_class, sample_inputs, inputs, num_samples, statis check_funsor(sample_value, expected_inputs, Real) if sample_inputs: - - actual_mean = Integrate( - sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value']) - ).reduce(ops.add, frozenset(sample_inputs)) - - inputs, tensors = align_tensors(*list(funsor_dist.params.values())[:-1]) - raw_dist = funsor_dist.dist_class(**dict(zip(funsor_dist._ast_fields[:-1], tensors))) - expected_mean = Tensor(raw_dist.mean, inputs) - if statistic == "mean": - actual_stat, expected_stat = actual_mean, expected_mean + actual_stat = Integrate( + sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value']) + ).reduce(ops.add, frozenset(sample_inputs)) + expected_stat = funsor_dist.mean() elif statistic == "variance": + actual_mean = Integrate( + sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value']) + ).reduce(ops.add, frozenset(sample_inputs)) actual_stat = Integrate( sample_value, (Variable('value', funsor_dist.inputs['value']) - actual_mean) ** 2, frozenset(['value']) ).reduce(ops.add, frozenset(sample_inputs)) - expected_stat = Tensor(raw_dist.variance, inputs) + expected_stat = funsor_dist.variance() elif statistic == "entropy": actual_stat = -Integrate( sample_value, funsor_dist, frozenset(['value']) ).reduce(ops.add, frozenset(sample_inputs)) - expected_stat = Tensor(raw_dist.entropy(), inputs) + expected_stat = funsor_dist.entropy() else: raise ValueError("invalid test statistic") From 3ad5c7997e61f70a252c46b3d4fc45cd42d3e68c Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 29 Oct 2020 21:56:48 -0400 Subject: [PATCH 02/38] start some generic distribution test functions --- test/test_distribution.py | 55 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 53 insertions(+), 2 deletions(-) diff --git a/test/test_distribution.py b/test/test_distribution.py index f789414b4..c6715c247 100644 --- a/test/test_distribution.py +++ b/test/test_distribution.py @@ -3,7 +3,7 @@ import functools import math -from collections import OrderedDict +from collections import OrderedDict, namedtuple from importlib import import_module import numpy as np @@ -728,7 +728,7 @@ def _get_stat_diff(funsor_dist_class, sample_inputs, inputs, num_samples, statis return diff.sum(), diff -def _check_sample(funsor_dist_class, params, sample_inputs, inputs, atol=1e-2, +def _check_sample(raw_dist, sample_shape=(), atol=1e-2, num_samples=100000, statistic="mean", skip_grad=False, with_lazy=None): """utility that compares a Monte Carlo estimate of a distribution mean with the true mean""" samples_per_dim = int(num_samples ** (1./max(1, len(sample_inputs)))) @@ -766,6 +766,57 @@ def _check_sample(funsor_dist_class, params, sample_inputs, inputs, atol=1e-2, _get_stat_diff_fn(params) +def _check_distribution_to_funsor(raw_dist, expected_value_domain): + + # TODO specify dim_to_name + funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) + actual_dist = to_data(funsor_dist, name_to_dim=name_to_dim) + + assert isinstance(actual_dist, backend_dist.Distribution) + assert type(raw_dist) == type(actual_dist) + assert 
funsor_dist.inputs["value"] == expected_value_domain + for param_name in funsor_dist.params.keys(): + if param_name == "value": + continue + assert hasattr(raw_dist, param_name) + assert_close(getattr(actual_dist, param_name), getattr(raw_dist, param_name)) + + +def _check_log_prob(raw_dist, expected_value_domain): + + # TODO specify dim_to_name + funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) + expected_inputs = {name: raw_dist.batch_shape[dim] for dim, name in dim_to_name.items()} + expected_inputs.update({"value": expected_value_domain}) + + check_funsor(funsor_dist, expected_inputs, funsor.Real) + + # TODO handle JAX rng here + expected_logprob = to_funsor(raw_dist.log_prob(raw_dist.sample()), output=funsor.Real, dim_to_name=dim_to_name) + assert_close(funsor_dist(value=value), expected_logprob) + + +def _check_enumerate_support(raw_dist, expand=False): + + # TODO specify dim_to_name + funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) + + assert getattr(raw_dist, "has_enumerate_support", False) == funsor_dist.has_enumerate_support + if funsor_dist.has_enumerate_support: + raw_support = raw_dist.enumerate_support(expand=expand) + funsor_support = funsor_dist.enumerate_support(expand=expand) + assert_equal(to_data(funsor_support, name_to_dim=name_to_dim), raw_support) + + +# High-level distribution testing strategy: a fixed sequence of increasingly semantically strong distribution-agnostic tests +# conversion invertibility -> density type and value -> enumerate_support type and value -> statistic types and values -> samplers -> gradients +DistTestCase = namedtuple("DistTestCase", ["raw_dist", "expected_value_domain"]) + +TEST_CASES = [ + DistTestCase(raw_dist=...), +] + + @pytest.mark.parametrize('sample_inputs', [(), ('ii',), ('ii', 'jj'), ('ii', 'jj', 'kk')]) @pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str) @pytest.mark.parametrize('reparametrized', [True, False]) From 4fbc801c2c182df3cc503ab402dfca021aa38c3f Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 29 Oct 2020 22:54:29 -0400 Subject: [PATCH 03/38] more cleanup of _check_sample --- test/test_distribution.py | 148 ++++++++++++++++++++------------------ 1 file changed, 78 insertions(+), 70 deletions(-) diff --git a/test/test_distribution.py b/test/test_distribution.py index c6715c247..f6eacd2e5 100644 --- a/test/test_distribution.py +++ b/test/test_distribution.py @@ -685,90 +685,95 @@ def von_mises(loc: Real, concentration: Real, value: Real) -> Real: assert_close(actual, expected) -def _get_stat_diff(funsor_dist_class, sample_inputs, inputs, num_samples, statistic, with_lazy, params): - params = [Tensor(p, inputs) for p in params] - if isinstance(with_lazy, bool): - with interpretation(lazy if with_lazy else eager): - funsor_dist = funsor_dist_class(*params) - else: - funsor_dist = funsor_dist_class(*params) +def default_dim_to_name(inputs_shape, event_inputs=None): + DIM_TO_NAME = tuple(map("_pyro_dim_{}".format, range(-100, 0))) + NAME_TO_DIM = dict(zip(DIM_TO_NAME, range(-100, 0))) + + dim_to_name_list = TESTS_DIM_TO_NAME + event_inputs if event_inputs else DIM_TO_NAME + dim_to_name = OrderedDict(zip( + range(-len(inputs_shape), 0), + dim_to_name_list[len(dim_to_name_list) - len(inputs_shape):])) + name_to_dim = OrderedDict((name, dim) for dim, name in dim_to_name.items()) + return name_to_dim + +def _get_stat(raw_dist, sample_shape, statistic, with_lazy): + dim_to_name, name_to_dim = default_dim_to_name(sample_shape + raw_dist.batch_shape) + with 
interpretation(lazy if with_lazy else eager): + funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) + + sample_inputs = ... # TODO compute sample_inputs from dim_to_name rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32) sample_value = funsor_dist.sample(frozenset(['value']), sample_inputs, rng_key=rng_key) - expected_inputs = OrderedDict( - tuple(sample_inputs.items()) + tuple(inputs.items()) + (('value', funsor_dist.inputs['value']),) - ) + expected_inputs = OrderedDict(tuple(sample_inputs.items()) + tuple(funsor_dist.inputs.items())) check_funsor(sample_value, expected_inputs, Real) - if sample_inputs: - if statistic == "mean": - actual_stat = Integrate( - sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value']) - ).reduce(ops.add, frozenset(sample_inputs)) - expected_stat = funsor_dist.mean() - elif statistic == "variance": - actual_mean = Integrate( - sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value']) - ).reduce(ops.add, frozenset(sample_inputs)) - actual_stat = Integrate( - sample_value, - (Variable('value', funsor_dist.inputs['value']) - actual_mean) ** 2, - frozenset(['value']) - ).reduce(ops.add, frozenset(sample_inputs)) - expected_stat = funsor_dist.variance() - elif statistic == "entropy": - actual_stat = -Integrate( - sample_value, funsor_dist, frozenset(['value']) - ).reduce(ops.add, frozenset(sample_inputs)) - expected_stat = funsor_dist.entropy() - else: - raise ValueError("invalid test statistic") - - diff = actual_stat.reduce(ops.add).data - expected_stat.reduce(ops.add).data - return diff.sum(), diff - - -def _check_sample(raw_dist, sample_shape=(), atol=1e-2, - num_samples=100000, statistic="mean", skip_grad=False, with_lazy=None): - """utility that compares a Monte Carlo estimate of a distribution mean with the true mean""" - samples_per_dim = int(num_samples ** (1./max(1, len(sample_inputs)))) - sample_inputs = OrderedDict((k, Bint[samples_per_dim]) for k in sample_inputs) - _get_stat_diff_fn = functools.partial( - _get_stat_diff, funsor_dist_class, sample_inputs, inputs, num_samples, statistic, with_lazy) + if statistic == "mean": + actual_stat = Integrate( + sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value']) + ).reduce(ops.add, frozenset(sample_inputs)) + expected_stat = funsor_dist.mean() + elif statistic == "variance": + actual_mean = Integrate( + sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value']) + ).reduce(ops.add, frozenset(sample_inputs)) + actual_stat = Integrate( + sample_value, + (Variable('value', funsor_dist.inputs['value']) - actual_mean) ** 2, + frozenset(['value']) + ).reduce(ops.add, frozenset(sample_inputs)) + expected_stat = funsor_dist.variance() + elif statistic == "entropy": + actual_stat = -Integrate( + sample_value, funsor_dist, frozenset(['value']) + ).reduce(ops.add, frozenset(sample_inputs)) + expected_stat = funsor_dist.entropy() + else: + raise ValueError("invalid test statistic") + + return actual_stat.reduce(ops.add), expected_stat.reduce(ops.add) + + +def _check_sample_grads(raw_dist, sample_shape=(), atol=1e-2, statistic="mean", with_lazy=False): + + def _get_stat_diff_fn(raw_dist): + actual_stat, expected_stat = _get_stat(raw_dist, sample_shape, statistic, with_lazy) + return to_data((actual_stat - expected_stat).sum()) if get_backend() == "torch": import torch + # TODO compute params here for param in params: param.requires_grad_() - res = 
_get_stat_diff_fn(params) - if sample_inputs: - diff_sum, diff = res - assert_close(diff, ops.new_zeros(diff, diff.shape), atol=atol, rtol=None) - if not skip_grad: - diff_grads = torch.autograd.grad(diff_sum, params, allow_unused=True) - for diff_grad in diff_grads: - assert_close(diff_grad, ops.new_zeros(diff_grad, diff_grad.shape), atol=atol, rtol=None) + diff = _get_stat_diff_fn(raw_dist) + assert_close(diff, ops.new_zeros(diff, diff.shape), atol=atol, rtol=None) + diff_grads = torch.autograd.grad(diff, params, allow_unused=True) + for diff_grad in diff_grads: + assert_close(diff_grad, ops.new_zeros(diff_grad, diff_grad.shape), atol=atol, rtol=None) + elif get_backend() == "jax": import jax - if sample_inputs: - if skip_grad: - _, diff = _get_stat_diff_fn(params) - assert_close(diff, ops.new_zeros(diff, diff.shape), atol=atol, rtol=None) - else: - (_, diff), diff_grads = jax.value_and_grad(_get_stat_diff_fn, has_aux=True)(params) - assert_close(diff, ops.new_zeros(diff, diff.shape), atol=atol, rtol=None) - for diff_grad in diff_grads: - assert_close(diff_grad, ops.new_zeros(diff_grad, diff_grad.shape), atol=atol, rtol=None) - else: - _get_stat_diff_fn(params) + # TODO compute gradient wrt distribution instance PyTree + diff, diff_grads = jax.value_and_grad(lambda *args: _get_stat_diff_fn(*args).sum(), has_aux=True)(params) + assert_close(diff, ops.new_zeros(diff, diff.shape), atol=atol, rtol=None) + for diff_grad in diff_grads: + assert_close(diff_grad, ops.new_zeros(diff_grad, diff_grad.shape), atol=atol, rtol=None) + + +def _check_sample(raw_dist, sample_shape=(), atol=1e-2, statistic="mean", with_lazy=False): + + actual_stat, expected_stat = _get_stat(raw_dist, sample_shape, statistic, with_lazy) + check_funsor(actual_stat, expected_stat.inputs, expected_stat.output) + if sample_inputs: + assert_close(actual_stat, expected_stat, atol=atol, rtol=None) def _check_distribution_to_funsor(raw_dist, expected_value_domain): - # TODO specify dim_to_name + dim_to_name, name_to_dim = default_dim_to_name(raw_dist.batch_shape) funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) actual_dist = to_data(funsor_dist, name_to_dim=name_to_dim) @@ -784,21 +789,24 @@ def _check_distribution_to_funsor(raw_dist, expected_value_domain): def _check_log_prob(raw_dist, expected_value_domain): - # TODO specify dim_to_name + dim_to_name, name_to_dim = default_dim_to_name(raw_dist.batch_shape) funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) - expected_inputs = {name: raw_dist.batch_shape[dim] for dim, name in dim_to_name.items()} + expected_inputs = {name: funsor.Bint[raw_dist.batch_shape[dim]] for dim, name in dim_to_name.items()} expected_inputs.update({"value": expected_value_domain}) check_funsor(funsor_dist, expected_inputs, funsor.Real) - # TODO handle JAX rng here - expected_logprob = to_funsor(raw_dist.log_prob(raw_dist.sample()), output=funsor.Real, dim_to_name=dim_to_name) + if get_backend() == "jax": + raw_value = raw_dist.sample(rng_key=np.array([0, 0], dtype=np.uint32)) + else: + raw_value = raw_dist.sample() + expected_logprob = to_funsor(raw_dist.log_prob(raw_value), output=funsor.Real, dim_to_name=dim_to_name) assert_close(funsor_dist(value=value), expected_logprob) def _check_enumerate_support(raw_dist, expand=False): - # TODO specify dim_to_name + dim_to_name, name_to_dim = default_dim_to_name(raw_dist.batch_shape) funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) assert getattr(raw_dist, 
"has_enumerate_support", False) == funsor_dist.has_enumerate_support From ecad0ee5467e9618a5fcca37985776940cabc32f Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 29 Oct 2020 23:03:45 -0400 Subject: [PATCH 04/38] move generic harness into separate file for now --- test/test_distribution.py | 189 +++++++++++++----------------------- test/test_distribution_2.py | 181 ++++++++++++++++++++++++++++++++++ 2 files changed, 246 insertions(+), 124 deletions(-) create mode 100644 test/test_distribution_2.py diff --git a/test/test_distribution.py b/test/test_distribution.py index f6eacd2e5..f789414b4 100644 --- a/test/test_distribution.py +++ b/test/test_distribution.py @@ -3,7 +3,7 @@ import functools import math -from collections import OrderedDict, namedtuple +from collections import OrderedDict from importlib import import_module import numpy as np @@ -685,144 +685,85 @@ def von_mises(loc: Real, concentration: Real, value: Real) -> Real: assert_close(actual, expected) -def default_dim_to_name(inputs_shape, event_inputs=None): - DIM_TO_NAME = tuple(map("_pyro_dim_{}".format, range(-100, 0))) - NAME_TO_DIM = dict(zip(DIM_TO_NAME, range(-100, 0))) - - dim_to_name_list = TESTS_DIM_TO_NAME + event_inputs if event_inputs else DIM_TO_NAME - dim_to_name = OrderedDict(zip( - range(-len(inputs_shape), 0), - dim_to_name_list[len(dim_to_name_list) - len(inputs_shape):])) - name_to_dim = OrderedDict((name, dim) for dim, name in dim_to_name.items()) - return name_to_dim - - -def _get_stat(raw_dist, sample_shape, statistic, with_lazy): - dim_to_name, name_to_dim = default_dim_to_name(sample_shape + raw_dist.batch_shape) - with interpretation(lazy if with_lazy else eager): - funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) +def _get_stat_diff(funsor_dist_class, sample_inputs, inputs, num_samples, statistic, with_lazy, params): + params = [Tensor(p, inputs) for p in params] + if isinstance(with_lazy, bool): + with interpretation(lazy if with_lazy else eager): + funsor_dist = funsor_dist_class(*params) + else: + funsor_dist = funsor_dist_class(*params) - sample_inputs = ... 
# TODO compute sample_inputs from dim_to_name rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32) sample_value = funsor_dist.sample(frozenset(['value']), sample_inputs, rng_key=rng_key) - expected_inputs = OrderedDict(tuple(sample_inputs.items()) + tuple(funsor_dist.inputs.items())) + expected_inputs = OrderedDict( + tuple(sample_inputs.items()) + tuple(inputs.items()) + (('value', funsor_dist.inputs['value']),) + ) check_funsor(sample_value, expected_inputs, Real) - if statistic == "mean": - actual_stat = Integrate( - sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value']) - ).reduce(ops.add, frozenset(sample_inputs)) - expected_stat = funsor_dist.mean() - elif statistic == "variance": - actual_mean = Integrate( - sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value']) - ).reduce(ops.add, frozenset(sample_inputs)) - actual_stat = Integrate( - sample_value, - (Variable('value', funsor_dist.inputs['value']) - actual_mean) ** 2, - frozenset(['value']) - ).reduce(ops.add, frozenset(sample_inputs)) - expected_stat = funsor_dist.variance() - elif statistic == "entropy": - actual_stat = -Integrate( - sample_value, funsor_dist, frozenset(['value']) - ).reduce(ops.add, frozenset(sample_inputs)) - expected_stat = funsor_dist.entropy() - else: - raise ValueError("invalid test statistic") - - return actual_stat.reduce(ops.add), expected_stat.reduce(ops.add) - - -def _check_sample_grads(raw_dist, sample_shape=(), atol=1e-2, statistic="mean", with_lazy=False): - - def _get_stat_diff_fn(raw_dist): - actual_stat, expected_stat = _get_stat(raw_dist, sample_shape, statistic, with_lazy) - return to_data((actual_stat - expected_stat).sum()) + if sample_inputs: + if statistic == "mean": + actual_stat = Integrate( + sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value']) + ).reduce(ops.add, frozenset(sample_inputs)) + expected_stat = funsor_dist.mean() + elif statistic == "variance": + actual_mean = Integrate( + sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value']) + ).reduce(ops.add, frozenset(sample_inputs)) + actual_stat = Integrate( + sample_value, + (Variable('value', funsor_dist.inputs['value']) - actual_mean) ** 2, + frozenset(['value']) + ).reduce(ops.add, frozenset(sample_inputs)) + expected_stat = funsor_dist.variance() + elif statistic == "entropy": + actual_stat = -Integrate( + sample_value, funsor_dist, frozenset(['value']) + ).reduce(ops.add, frozenset(sample_inputs)) + expected_stat = funsor_dist.entropy() + else: + raise ValueError("invalid test statistic") + + diff = actual_stat.reduce(ops.add).data - expected_stat.reduce(ops.add).data + return diff.sum(), diff + + +def _check_sample(funsor_dist_class, params, sample_inputs, inputs, atol=1e-2, + num_samples=100000, statistic="mean", skip_grad=False, with_lazy=None): + """utility that compares a Monte Carlo estimate of a distribution mean with the true mean""" + samples_per_dim = int(num_samples ** (1./max(1, len(sample_inputs)))) + sample_inputs = OrderedDict((k, Bint[samples_per_dim]) for k in sample_inputs) + _get_stat_diff_fn = functools.partial( + _get_stat_diff, funsor_dist_class, sample_inputs, inputs, num_samples, statistic, with_lazy) if get_backend() == "torch": import torch - # TODO compute params here for param in params: param.requires_grad_() - diff = _get_stat_diff_fn(raw_dist) - assert_close(diff, ops.new_zeros(diff, diff.shape), atol=atol, rtol=None) - diff_grads = torch.autograd.grad(diff, 
params, allow_unused=True) - for diff_grad in diff_grads: - assert_close(diff_grad, ops.new_zeros(diff_grad, diff_grad.shape), atol=atol, rtol=None) - + res = _get_stat_diff_fn(params) + if sample_inputs: + diff_sum, diff = res + assert_close(diff, ops.new_zeros(diff, diff.shape), atol=atol, rtol=None) + if not skip_grad: + diff_grads = torch.autograd.grad(diff_sum, params, allow_unused=True) + for diff_grad in diff_grads: + assert_close(diff_grad, ops.new_zeros(diff_grad, diff_grad.shape), atol=atol, rtol=None) elif get_backend() == "jax": import jax - # TODO compute gradient wrt distribution instance PyTree - diff, diff_grads = jax.value_and_grad(lambda *args: _get_stat_diff_fn(*args).sum(), has_aux=True)(params) - assert_close(diff, ops.new_zeros(diff, diff.shape), atol=atol, rtol=None) - for diff_grad in diff_grads: - assert_close(diff_grad, ops.new_zeros(diff_grad, diff_grad.shape), atol=atol, rtol=None) - - -def _check_sample(raw_dist, sample_shape=(), atol=1e-2, statistic="mean", with_lazy=False): - - actual_stat, expected_stat = _get_stat(raw_dist, sample_shape, statistic, with_lazy) - check_funsor(actual_stat, expected_stat.inputs, expected_stat.output) - if sample_inputs: - assert_close(actual_stat, expected_stat, atol=atol, rtol=None) - - -def _check_distribution_to_funsor(raw_dist, expected_value_domain): - - dim_to_name, name_to_dim = default_dim_to_name(raw_dist.batch_shape) - funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) - actual_dist = to_data(funsor_dist, name_to_dim=name_to_dim) - - assert isinstance(actual_dist, backend_dist.Distribution) - assert type(raw_dist) == type(actual_dist) - assert funsor_dist.inputs["value"] == expected_value_domain - for param_name in funsor_dist.params.keys(): - if param_name == "value": - continue - assert hasattr(raw_dist, param_name) - assert_close(getattr(actual_dist, param_name), getattr(raw_dist, param_name)) - - -def _check_log_prob(raw_dist, expected_value_domain): - - dim_to_name, name_to_dim = default_dim_to_name(raw_dist.batch_shape) - funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) - expected_inputs = {name: funsor.Bint[raw_dist.batch_shape[dim]] for dim, name in dim_to_name.items()} - expected_inputs.update({"value": expected_value_domain}) - - check_funsor(funsor_dist, expected_inputs, funsor.Real) - - if get_backend() == "jax": - raw_value = raw_dist.sample(rng_key=np.array([0, 0], dtype=np.uint32)) - else: - raw_value = raw_dist.sample() - expected_logprob = to_funsor(raw_dist.log_prob(raw_value), output=funsor.Real, dim_to_name=dim_to_name) - assert_close(funsor_dist(value=value), expected_logprob) - - -def _check_enumerate_support(raw_dist, expand=False): - - dim_to_name, name_to_dim = default_dim_to_name(raw_dist.batch_shape) - funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) - - assert getattr(raw_dist, "has_enumerate_support", False) == funsor_dist.has_enumerate_support - if funsor_dist.has_enumerate_support: - raw_support = raw_dist.enumerate_support(expand=expand) - funsor_support = funsor_dist.enumerate_support(expand=expand) - assert_equal(to_data(funsor_support, name_to_dim=name_to_dim), raw_support) - - -# High-level distribution testing strategy: a fixed sequence of increasingly semantically strong distribution-agnostic tests -# conversion invertibility -> density type and value -> enumerate_support type and value -> statistic types and values -> samplers -> gradients -DistTestCase = namedtuple("DistTestCase", 
["raw_dist", "expected_value_domain"]) - -TEST_CASES = [ - DistTestCase(raw_dist=...), -] + if sample_inputs: + if skip_grad: + _, diff = _get_stat_diff_fn(params) + assert_close(diff, ops.new_zeros(diff, diff.shape), atol=atol, rtol=None) + else: + (_, diff), diff_grads = jax.value_and_grad(_get_stat_diff_fn, has_aux=True)(params) + assert_close(diff, ops.new_zeros(diff, diff.shape), atol=atol, rtol=None) + for diff_grad in diff_grads: + assert_close(diff_grad, ops.new_zeros(diff_grad, diff_grad.shape), atol=atol, rtol=None) + else: + _get_stat_diff_fn(params) @pytest.mark.parametrize('sample_inputs', [(), ('ii',), ('ii', 'jj'), ('ii', 'jj', 'kk')]) diff --git a/test/test_distribution_2.py b/test/test_distribution_2.py new file mode 100644 index 000000000..bb1ec7aef --- /dev/null +++ b/test/test_distribution_2.py @@ -0,0 +1,181 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + +import functools +import math +from collections import OrderedDict, namedtuple +from importlib import import_module + +import numpy as np +import pytest + +import funsor +import funsor.ops as ops +from funsor.cnf import Contraction, GaussianMixture +from funsor.delta import Delta +from funsor.distribution import BACKEND_TO_DISTRIBUTIONS_BACKEND +from funsor.domains import Bint, Real, Reals +from funsor.integrate import Integrate +from funsor.interpreter import interpretation, reinterpret +from funsor.tensor import Einsum, Tensor, numeric_array, stack +from funsor.terms import Independent, Variable, eager, lazy, to_funsor +from funsor.testing import assert_close, check_funsor, rand, randint, randn, random_mvn, random_tensor, xfail_param +from funsor.util import get_backend + +pytestmark = pytest.mark.skipif(get_backend() == "numpy", + reason="numpy does not have distributions backend") +if get_backend() != "numpy": + dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]) + backend_dist = dist.dist + + +def _skip_for_numpyro_version(version="0.2.4"): + if get_backend() == "jax": + import numpyro + + if numpyro.__version__ <= version: + return True + + return False + + +def default_dim_to_name(inputs_shape, event_inputs=None): + DIM_TO_NAME = tuple(map("_pyro_dim_{}".format, range(-100, 0))) + NAME_TO_DIM = dict(zip(DIM_TO_NAME, range(-100, 0))) + + dim_to_name_list = TESTS_DIM_TO_NAME + event_inputs if event_inputs else DIM_TO_NAME + dim_to_name = OrderedDict(zip( + range(-len(inputs_shape), 0), + dim_to_name_list[len(dim_to_name_list) - len(inputs_shape):])) + name_to_dim = OrderedDict((name, dim) for dim, name in dim_to_name.items()) + return name_to_dim + + +def _get_stat(raw_dist, sample_shape, statistic, with_lazy): + dim_to_name, name_to_dim = default_dim_to_name(sample_shape + raw_dist.batch_shape) + with interpretation(lazy if with_lazy else eager): + funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) + + sample_inputs = ... 
# TODO compute sample_inputs from dim_to_name + rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32) + sample_value = funsor_dist.sample(frozenset(['value']), sample_inputs, rng_key=rng_key) + expected_inputs = OrderedDict(tuple(sample_inputs.items()) + tuple(funsor_dist.inputs.items())) + check_funsor(sample_value, expected_inputs, Real) + + if statistic == "mean": + actual_stat = Integrate( + sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value']) + ).reduce(ops.add, frozenset(sample_inputs)) + expected_stat = funsor_dist.mean() + elif statistic == "variance": + actual_mean = Integrate( + sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value']) + ).reduce(ops.add, frozenset(sample_inputs)) + actual_stat = Integrate( + sample_value, + (Variable('value', funsor_dist.inputs['value']) - actual_mean) ** 2, + frozenset(['value']) + ).reduce(ops.add, frozenset(sample_inputs)) + expected_stat = funsor_dist.variance() + elif statistic == "entropy": + actual_stat = -Integrate( + sample_value, funsor_dist, frozenset(['value']) + ).reduce(ops.add, frozenset(sample_inputs)) + expected_stat = funsor_dist.entropy() + else: + raise ValueError("invalid test statistic") + + return actual_stat.reduce(ops.add), expected_stat.reduce(ops.add) + + +def _check_sample_grads(raw_dist, sample_shape=(), atol=1e-2, statistic="mean", with_lazy=False): + + def _get_stat_diff_fn(raw_dist): + actual_stat, expected_stat = _get_stat(raw_dist, sample_shape, statistic, with_lazy) + return to_data((actual_stat - expected_stat).sum()) + + if get_backend() == "torch": + import torch + + # TODO compute params here + for param in params: + param.requires_grad_() + + diff = _get_stat_diff_fn(raw_dist) + assert_close(diff, ops.new_zeros(diff, diff.shape), atol=atol, rtol=None) + diff_grads = torch.autograd.grad(diff, params, allow_unused=True) + for diff_grad in diff_grads: + assert_close(diff_grad, ops.new_zeros(diff_grad, diff_grad.shape), atol=atol, rtol=None) + + elif get_backend() == "jax": + import jax + + # TODO compute gradient wrt distribution instance PyTree + diff, diff_grads = jax.value_and_grad(lambda *args: _get_stat_diff_fn(*args).sum(), has_aux=True)(params) + assert_close(diff, ops.new_zeros(diff, diff.shape), atol=atol, rtol=None) + for diff_grad in diff_grads: + assert_close(diff_grad, ops.new_zeros(diff_grad, diff_grad.shape), atol=atol, rtol=None) + + +def _check_sample(raw_dist, sample_shape=(), atol=1e-2, statistic="mean", with_lazy=False): + + actual_stat, expected_stat = _get_stat(raw_dist, sample_shape, statistic, with_lazy) + check_funsor(actual_stat, expected_stat.inputs, expected_stat.output) + if sample_inputs: + assert_close(actual_stat, expected_stat, atol=atol, rtol=None) + + +def _check_distribution_to_funsor(raw_dist, expected_value_domain): + + dim_to_name, name_to_dim = default_dim_to_name(raw_dist.batch_shape) + funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) + actual_dist = to_data(funsor_dist, name_to_dim=name_to_dim) + + assert isinstance(actual_dist, backend_dist.Distribution) + assert type(raw_dist) == type(actual_dist) + assert funsor_dist.inputs["value"] == expected_value_domain + for param_name in funsor_dist.params.keys(): + if param_name == "value": + continue + assert hasattr(raw_dist, param_name) + assert_close(getattr(actual_dist, param_name), getattr(raw_dist, param_name)) + + +def _check_log_prob(raw_dist, expected_value_domain): + + dim_to_name, name_to_dim = 
default_dim_to_name(raw_dist.batch_shape) + funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) + expected_inputs = {name: funsor.Bint[raw_dist.batch_shape[dim]] for dim, name in dim_to_name.items()} + expected_inputs.update({"value": expected_value_domain}) + + check_funsor(funsor_dist, expected_inputs, funsor.Real) + + if get_backend() == "jax": + raw_value = raw_dist.sample(rng_key=np.array([0, 0], dtype=np.uint32)) + else: + raw_value = raw_dist.sample() + expected_logprob = to_funsor(raw_dist.log_prob(raw_value), output=funsor.Real, dim_to_name=dim_to_name) + assert_close(funsor_dist(value=value), expected_logprob) + + +def _check_enumerate_support(raw_dist, expand=False): + + dim_to_name, name_to_dim = default_dim_to_name(raw_dist.batch_shape) + funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) + + assert getattr(raw_dist, "has_enumerate_support", False) == funsor_dist.has_enumerate_support + if funsor_dist.has_enumerate_support: + raw_support = raw_dist.enumerate_support(expand=expand) + funsor_support = funsor_dist.enumerate_support(expand=expand) + assert_equal(to_data(funsor_support, name_to_dim=name_to_dim), raw_support) + + +# High-level distribution testing strategy: a fixed sequence of increasingly semantically strong distribution-agnostic tests +# conversion invertibility -> density type and value -> enumerate_support type and value -> statistic types and values -> samplers -> gradients +DistTestCase = namedtuple("DistTestCase", ["raw_dist", "expected_value_domain"]) + +TEST_CASES = [ + DistTestCase(raw_dist=...), +] + + From 0c2c0fb33dc075c92afa7b92d4a38e186c337a3c Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 29 Oct 2020 23:16:07 -0400 Subject: [PATCH 05/38] work up to actual tests --- test/test_distribution_2.py | 109 +++++++++++++++++++++++------------- 1 file changed, 69 insertions(+), 40 deletions(-) diff --git a/test/test_distribution_2.py b/test/test_distribution_2.py index bb1ec7aef..15da6cd94 100644 --- a/test/test_distribution_2.py +++ b/test/test_distribution_2.py @@ -1,6 +1,9 @@ # Copyright Contributors to the Pyro project. 
# SPDX-License-Identifier: Apache-2.0 +# High-level distribution testing strategy: a fixed sequence of increasingly semantically strong distribution-agnostic tests +# conversion invertibility -> density type and value -> enumerate_support type and value -> statistic types and values -> samplers -> gradients + import functools import math from collections import OrderedDict, namedtuple @@ -88,44 +91,25 @@ def _get_stat(raw_dist, sample_shape, statistic, with_lazy): return actual_stat.reduce(ops.add), expected_stat.reduce(ops.add) -def _check_sample_grads(raw_dist, sample_shape=(), atol=1e-2, statistic="mean", with_lazy=False): - - def _get_stat_diff_fn(raw_dist): - actual_stat, expected_stat = _get_stat(raw_dist, sample_shape, statistic, with_lazy) - return to_data((actual_stat - expected_stat).sum()) - - if get_backend() == "torch": - import torch - - # TODO compute params here - for param in params: - param.requires_grad_() - - diff = _get_stat_diff_fn(raw_dist) - assert_close(diff, ops.new_zeros(diff, diff.shape), atol=atol, rtol=None) - diff_grads = torch.autograd.grad(diff, params, allow_unused=True) - for diff_grad in diff_grads: - assert_close(diff_grad, ops.new_zeros(diff_grad, diff_grad.shape), atol=atol, rtol=None) - - elif get_backend() == "jax": - import jax +################################################## +# Test cases +################################################## - # TODO compute gradient wrt distribution instance PyTree - diff, diff_grads = jax.value_and_grad(lambda *args: _get_stat_diff_fn(*args).sum(), has_aux=True)(params) - assert_close(diff, ops.new_zeros(diff, diff.shape), atol=atol, rtol=None) - for diff_grad in diff_grads: - assert_close(diff_grad, ops.new_zeros(diff_grad, diff_grad.shape), atol=atol, rtol=None) +DistTestCase = namedtuple("DistTestCase", ["raw_dist", "expected_value_domain"]) +TEST_CASES = [ + DistTestCase(raw_dist=...), +] -def _check_sample(raw_dist, sample_shape=(), atol=1e-2, statistic="mean", with_lazy=False): - actual_stat, expected_stat = _get_stat(raw_dist, sample_shape, statistic, with_lazy) - check_funsor(actual_stat, expected_stat.inputs, expected_stat.output) - if sample_inputs: - assert_close(actual_stat, expected_stat, atol=atol, rtol=None) +########################### +# Generic tests +########################### +@pytest.mark.parametrize("case", TEST_CASES) +def test_generic_distribution_to_funsor(case): -def _check_distribution_to_funsor(raw_dist, expected_value_domain): + raw_dist, expected_value_domain = case.raw_dist, case.expected_value_domain dim_to_name, name_to_dim = default_dim_to_name(raw_dist.batch_shape) funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) @@ -141,7 +125,10 @@ def _check_distribution_to_funsor(raw_dist, expected_value_domain): assert_close(getattr(actual_dist, param_name), getattr(raw_dist, param_name)) -def _check_log_prob(raw_dist, expected_value_domain): +@pytest.mark.parametrize("case", TEST_CASES) +def test_generic_log_prob(case): + + raw_dist, expected_value_domain = case.raw_dist, case.expected_value_domain dim_to_name, name_to_dim = default_dim_to_name(raw_dist.batch_shape) funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) @@ -158,7 +145,11 @@ def _check_log_prob(raw_dist, expected_value_domain): assert_close(funsor_dist(value=value), expected_logprob) -def _check_enumerate_support(raw_dist, expand=False): +@pytest.mark.parametrize("case", TEST_CASES) +@pytest.mark.parametrize("expand", [False, True]) +def 
test_generic_enumerate_support(case, expand): + + raw_dist = case.raw_dist dim_to_name, name_to_dim = default_dim_to_name(raw_dist.batch_shape) funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) @@ -170,12 +161,50 @@ def _check_enumerate_support(raw_dist, expand=False): assert_equal(to_data(funsor_support, name_to_dim=name_to_dim), raw_support) -# High-level distribution testing strategy: a fixed sequence of increasingly semantically strong distribution-agnostic tests -# conversion invertibility -> density type and value -> enumerate_support type and value -> statistic types and values -> samplers -> gradients -DistTestCase = namedtuple("DistTestCase", ["raw_dist", "expected_value_domain"]) +@pytest.mark.parametrize("case", TEST_CASES) +@pytest.mark.parametrize("statistic", ["mean", "variance", "entropy"]) +@pytest.mark.parametrize("with_lazy", [True, False]) +def test_generic_sample(case, statistic, with_lazy): + + raw_dist, sample_shape = case.raw_dist, case.sample_shape + + actual_stat, expected_stat = _get_stat(raw_dist, sample_shape, statistic, with_lazy) + check_funsor(actual_stat, expected_stat.inputs, expected_stat.output) + if sample_shape: + assert_close(actual_stat, expected_stat, atol=atol, rtol=None) -TEST_CASES = [ - DistTestCase(raw_dist=...), -] +@pytest.mark.parametrize("case", TEST_CASES) +@pytest.mark.parametrize("statistic", ["mean", "variance", "entropy"]) +@pytest.mark.parametrize("with_lazy", [True, False]) +def test_generic_sample_grads(case, statistic, with_lazy): + + raw_dist, sample_shape = case.raw_dist, case.sample_shape + + atol = 1e-2 + + def _get_stat_diff_fn(raw_dist): + actual_stat, expected_stat = _get_stat(raw_dist, sample_shape, statistic, with_lazy) + return to_data((actual_stat - expected_stat).sum()) + + if get_backend() == "torch": + import torch + # TODO compute params here + for param in params: + param.requires_grad_() + + diff = _get_stat_diff_fn(raw_dist) + assert_close(diff, ops.new_zeros(diff, diff.shape), atol=atol, rtol=None) + diff_grads = torch.autograd.grad(diff, params, allow_unused=True) + for diff_grad in diff_grads: + assert_close(diff_grad, ops.new_zeros(diff_grad, diff_grad.shape), atol=atol, rtol=None) + + elif get_backend() == "jax": + import jax + + # TODO compute gradient wrt distribution instance PyTree + diff, diff_grads = jax.value_and_grad(lambda *args: _get_stat_diff_fn(*args).sum(), has_aux=True)(params) + assert_close(diff, ops.new_zeros(diff, diff.shape), atol=atol, rtol=None) + for diff_grad in diff_grads: + assert_close(diff_grad, ops.new_zeros(diff_grad, diff_grad.shape), atol=atol, rtol=None) From 56ace75c93ccf64bb3b1dfdd93d1dadebc38c994 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 29 Oct 2020 23:24:04 -0400 Subject: [PATCH 06/38] lint --- test/test_distribution_2.py | 34 ++++++++++++++++------------------ 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/test/test_distribution_2.py b/test/test_distribution_2.py index 15da6cd94..1bf916347 100644 --- a/test/test_distribution_2.py +++ b/test/test_distribution_2.py @@ -1,11 +1,9 @@ # Copyright Contributors to the Pyro project. 
# SPDX-License-Identifier: Apache-2.0 -# High-level distribution testing strategy: a fixed sequence of increasingly semantically strong distribution-agnostic tests -# conversion invertibility -> density type and value -> enumerate_support type and value -> statistic types and values -> samplers -> gradients +# High-level distribution testing strategy: sequence of increasingly semantically strong distribution-agnostic tests +# conversion invertibility -> density type and value -> enumerate_support type and value -> samplers -> gradients -import functools -import math from collections import OrderedDict, namedtuple from importlib import import_module @@ -14,15 +12,11 @@ import funsor import funsor.ops as ops -from funsor.cnf import Contraction, GaussianMixture -from funsor.delta import Delta from funsor.distribution import BACKEND_TO_DISTRIBUTIONS_BACKEND -from funsor.domains import Bint, Real, Reals from funsor.integrate import Integrate -from funsor.interpreter import interpretation, reinterpret -from funsor.tensor import Einsum, Tensor, numeric_array, stack -from funsor.terms import Independent, Variable, eager, lazy, to_funsor -from funsor.testing import assert_close, check_funsor, rand, randint, randn, random_mvn, random_tensor, xfail_param +from funsor.interpreter import interpretation +from funsor.terms import Variable, eager, lazy, to_data, to_funsor +from funsor.testing import assert_close, check_funsor, rand, randint, randn from funsor.util import get_backend pytestmark = pytest.mark.skipif(get_backend() == "numpy", @@ -44,9 +38,7 @@ def _skip_for_numpyro_version(version="0.2.4"): def default_dim_to_name(inputs_shape, event_inputs=None): DIM_TO_NAME = tuple(map("_pyro_dim_{}".format, range(-100, 0))) - NAME_TO_DIM = dict(zip(DIM_TO_NAME, range(-100, 0))) - - dim_to_name_list = TESTS_DIM_TO_NAME + event_inputs if event_inputs else DIM_TO_NAME + dim_to_name_list = DIM_TO_NAME + event_inputs if event_inputs else DIM_TO_NAME dim_to_name = OrderedDict(zip( range(-len(inputs_shape), 0), dim_to_name_list[len(dim_to_name_list) - len(inputs_shape):])) @@ -63,7 +55,7 @@ def _get_stat(raw_dist, sample_shape, statistic, with_lazy): rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32) sample_value = funsor_dist.sample(frozenset(['value']), sample_inputs, rng_key=rng_key) expected_inputs = OrderedDict(tuple(sample_inputs.items()) + tuple(funsor_dist.inputs.items())) - check_funsor(sample_value, expected_inputs, Real) + check_funsor(sample_value, expected_inputs, funsor.Real) if statistic == "mean": actual_stat = Integrate( @@ -95,6 +87,8 @@ def _get_stat(raw_dist, sample_shape, statistic, with_lazy): # Test cases ################################################## +# TODO how to make this work with multiple pytest workers? 
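# For illustration, a hypothetical filled-in entry would pair a backend
# distribution with the funsor domain of its support, e.g.
#   DistTestCase(raw_dist=backend_dist.Normal(randn(()), rand(())),
#                expected_value_domain=funsor.Real)
# Concrete entries of this form are added later in this series.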
+ DistTestCase = namedtuple("DistTestCase", ["raw_dist", "expected_value_domain"]) TEST_CASES = [ @@ -114,7 +108,7 @@ def test_generic_distribution_to_funsor(case): dim_to_name, name_to_dim = default_dim_to_name(raw_dist.batch_shape) funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) actual_dist = to_data(funsor_dist, name_to_dim=name_to_dim) - + assert isinstance(actual_dist, backend_dist.Distribution) assert type(raw_dist) == type(actual_dist) assert funsor_dist.inputs["value"] == expected_value_domain @@ -142,7 +136,8 @@ def test_generic_log_prob(case): else: raw_value = raw_dist.sample() expected_logprob = to_funsor(raw_dist.log_prob(raw_value), output=funsor.Real, dim_to_name=dim_to_name) - assert_close(funsor_dist(value=value), expected_logprob) + funsor_value = to_funsor(raw_value, output=expected_value_domain, dim_to_name=dim_to_name) + assert_close(funsor_dist(value=funsor_value), expected_logprob) @pytest.mark.parametrize("case", TEST_CASES) @@ -158,7 +153,7 @@ def test_generic_enumerate_support(case, expand): if funsor_dist.has_enumerate_support: raw_support = raw_dist.enumerate_support(expand=expand) funsor_support = funsor_dist.enumerate_support(expand=expand) - assert_equal(to_data(funsor_support, name_to_dim=name_to_dim), raw_support) + assert_close(to_data(funsor_support, name_to_dim=name_to_dim), raw_support) @pytest.mark.parametrize("case", TEST_CASES) @@ -168,12 +163,15 @@ def test_generic_sample(case, statistic, with_lazy): raw_dist, sample_shape = case.raw_dist, case.sample_shape + atol = 1e-2 + actual_stat, expected_stat = _get_stat(raw_dist, sample_shape, statistic, with_lazy) check_funsor(actual_stat, expected_stat.inputs, expected_stat.output) if sample_shape: assert_close(actual_stat, expected_stat, atol=atol, rtol=None) +@pytest.mark.skipif(True, reason="broken") @pytest.mark.parametrize("case", TEST_CASES) @pytest.mark.parametrize("statistic", ["mean", "variance", "entropy"]) @pytest.mark.parametrize("with_lazy", [True, False]) From bace6583bcd2f9ec886866b5950e24027634a276 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 29 Oct 2020 23:26:36 -0400 Subject: [PATCH 07/38] nits --- test/test_distribution_2.py | 52 ++++++++++++++----------------------- 1 file changed, 20 insertions(+), 32 deletions(-) diff --git a/test/test_distribution_2.py b/test/test_distribution_2.py index 1bf916347..03a78b38b 100644 --- a/test/test_distribution_2.py +++ b/test/test_distribution_2.py @@ -26,17 +26,24 @@ backend_dist = dist.dist -def _skip_for_numpyro_version(version="0.2.4"): - if get_backend() == "jax": - import numpyro +################################################## +# Test cases +################################################## + +# TODO how to make this work with multiple pytest workers? 
+ +DistTestCase = namedtuple("DistTestCase", ["raw_dist", "expected_value_domain"]) - if numpyro.__version__ <= version: - return True +TEST_CASES = [ + DistTestCase(raw_dist=...), +] - return False +########################### +# Generic tests +########################### -def default_dim_to_name(inputs_shape, event_inputs=None): +def _default_dim_to_name(inputs_shape, event_inputs=None): DIM_TO_NAME = tuple(map("_pyro_dim_{}".format, range(-100, 0))) dim_to_name_list = DIM_TO_NAME + event_inputs if event_inputs else DIM_TO_NAME dim_to_name = OrderedDict(zip( @@ -47,7 +54,7 @@ def default_dim_to_name(inputs_shape, event_inputs=None): def _get_stat(raw_dist, sample_shape, statistic, with_lazy): - dim_to_name, name_to_dim = default_dim_to_name(sample_shape + raw_dist.batch_shape) + dim_to_name, name_to_dim = _default_dim_to_name(sample_shape + raw_dist.batch_shape) with interpretation(lazy if with_lazy else eager): funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) @@ -57,11 +64,11 @@ def _get_stat(raw_dist, sample_shape, statistic, with_lazy): expected_inputs = OrderedDict(tuple(sample_inputs.items()) + tuple(funsor_dist.inputs.items())) check_funsor(sample_value, expected_inputs, funsor.Real) + expected_stat = getattr(funsor_dist, statistic)() if statistic == "mean": actual_stat = Integrate( sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value']) ).reduce(ops.add, frozenset(sample_inputs)) - expected_stat = funsor_dist.mean() elif statistic == "variance": actual_mean = Integrate( sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value']) @@ -71,41 +78,22 @@ def _get_stat(raw_dist, sample_shape, statistic, with_lazy): (Variable('value', funsor_dist.inputs['value']) - actual_mean) ** 2, frozenset(['value']) ).reduce(ops.add, frozenset(sample_inputs)) - expected_stat = funsor_dist.variance() elif statistic == "entropy": actual_stat = -Integrate( sample_value, funsor_dist, frozenset(['value']) ).reduce(ops.add, frozenset(sample_inputs)) - expected_stat = funsor_dist.entropy() else: - raise ValueError("invalid test statistic") + raise ValueError("invalid test statistic: {}".format(statistic)) return actual_stat.reduce(ops.add), expected_stat.reduce(ops.add) -################################################## -# Test cases -################################################## - -# TODO how to make this work with multiple pytest workers? 
- -DistTestCase = namedtuple("DistTestCase", ["raw_dist", "expected_value_domain"]) - -TEST_CASES = [ - DistTestCase(raw_dist=...), -] - - -########################### -# Generic tests -########################### - @pytest.mark.parametrize("case", TEST_CASES) def test_generic_distribution_to_funsor(case): raw_dist, expected_value_domain = case.raw_dist, case.expected_value_domain - dim_to_name, name_to_dim = default_dim_to_name(raw_dist.batch_shape) + dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape) funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) actual_dist = to_data(funsor_dist, name_to_dim=name_to_dim) @@ -124,7 +112,7 @@ def test_generic_log_prob(case): raw_dist, expected_value_domain = case.raw_dist, case.expected_value_domain - dim_to_name, name_to_dim = default_dim_to_name(raw_dist.batch_shape) + dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape) funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) expected_inputs = {name: funsor.Bint[raw_dist.batch_shape[dim]] for dim, name in dim_to_name.items()} expected_inputs.update({"value": expected_value_domain}) @@ -146,7 +134,7 @@ def test_generic_enumerate_support(case, expand): raw_dist = case.raw_dist - dim_to_name, name_to_dim = default_dim_to_name(raw_dist.batch_shape) + dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape) funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) assert getattr(raw_dist, "has_enumerate_support", False) == funsor_dist.has_enumerate_support From c38d560ba8bac09cf23831433f8665253c32569a Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 29 Oct 2020 23:30:39 -0400 Subject: [PATCH 08/38] more nits --- test/test_distribution_2.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/test/test_distribution_2.py b/test/test_distribution_2.py index 03a78b38b..b35865e5d 100644 --- a/test/test_distribution_2.py +++ b/test/test_distribution_2.py @@ -1,9 +1,6 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -# High-level distribution testing strategy: sequence of increasingly semantically strong distribution-agnostic tests -# conversion invertibility -> density type and value -> enumerate_support type and value -> samplers -> gradients - from collections import OrderedDict, namedtuple from importlib import import_module @@ -40,7 +37,9 @@ ########################### -# Generic tests +# Generic tests: +# High-level distribution testing strategy: sequence of increasingly semantically strong distribution-agnostic tests +# Conversion invertibility -> density type and value -> enumerate_support type and value -> samplers -> gradients ########################### def _default_dim_to_name(inputs_shape, event_inputs=None): From 040f4177388831522965cb156a1d7fb07cffedeb Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 29 Oct 2020 23:41:43 -0400 Subject: [PATCH 09/38] make case.raw_dist a string and eval it --- test/test_distribution_2.py | 24 +++++++++++++++--------- 1 file changed, 15 insertions(+), 9 deletions(-) diff --git a/test/test_distribution_2.py b/test/test_distribution_2.py index b35865e5d..b5f9b2a03 100644 --- a/test/test_distribution_2.py +++ b/test/test_distribution_2.py @@ -29,11 +29,17 @@ # TODO how to make this work with multiple pytest workers? 
-DistTestCase = namedtuple("DistTestCase", ["raw_dist", "expected_value_domain"]) +DistTestCase = namedtuple("DistTestCase", ["raw_dist", "expected_value_domain", "sample_shape"]) -TEST_CASES = [ - DistTestCase(raw_dist=...), -] +TEST_CASES = [] + +for batch_shape in [(), (5,), (2, 3)]: + TEST_CASES += [ + DistTestCase( + f"backend_dist.Normal(rand({batch_shape}), rand({batch_shape}))", + funsor.Real + ), + ] ########################### @@ -90,7 +96,7 @@ def _get_stat(raw_dist, sample_shape, statistic, with_lazy): @pytest.mark.parametrize("case", TEST_CASES) def test_generic_distribution_to_funsor(case): - raw_dist, expected_value_domain = case.raw_dist, case.expected_value_domain + raw_dist, expected_value_domain = eval(case.raw_dist), case.expected_value_domain dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape) funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) @@ -109,7 +115,7 @@ def test_generic_distribution_to_funsor(case): @pytest.mark.parametrize("case", TEST_CASES) def test_generic_log_prob(case): - raw_dist, expected_value_domain = case.raw_dist, case.expected_value_domain + raw_dist, expected_value_domain = eval(case.raw_dist), case.expected_value_domain dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape) funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) @@ -131,7 +137,7 @@ def test_generic_log_prob(case): @pytest.mark.parametrize("expand", [False, True]) def test_generic_enumerate_support(case, expand): - raw_dist = case.raw_dist + raw_dist = eval(case.raw_dist) dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape) funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) @@ -148,7 +154,7 @@ def test_generic_enumerate_support(case, expand): @pytest.mark.parametrize("with_lazy", [True, False]) def test_generic_sample(case, statistic, with_lazy): - raw_dist, sample_shape = case.raw_dist, case.sample_shape + raw_dist, sample_shape = eval(case.raw_dist), case.sample_shape atol = 1e-2 @@ -164,7 +170,7 @@ def test_generic_sample(case, statistic, with_lazy): @pytest.mark.parametrize("with_lazy", [True, False]) def test_generic_sample_grads(case, statistic, with_lazy): - raw_dist, sample_shape = case.raw_dist, case.sample_shape + raw_dist, sample_shape = eval(case.raw_dist), case.sample_shape atol = 1e-2 From 4f40aaf14a3986f9654b18046426f409c3cbc6c6 Mon Sep 17 00:00:00 2001 From: Eli Date: Thu, 29 Oct 2020 23:43:59 -0400 Subject: [PATCH 10/38] nit --- test/test_distribution_2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_distribution_2.py b/test/test_distribution_2.py index b5f9b2a03..374cee177 100644 --- a/test/test_distribution_2.py +++ b/test/test_distribution_2.py @@ -27,7 +27,7 @@ # Test cases ################################################## -# TODO how to make this work with multiple pytest workers? +# TODO separate sample_shape from DistTestCase? 
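# One hypothetical way to separate it: drop the field and let pytest supply it,
#   @pytest.mark.parametrize("sample_shape", [(), (200000,), (400, 400)])
# which is roughly the form the sampler tests take later in this series.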
DistTestCase = namedtuple("DistTestCase", ["raw_dist", "expected_value_domain", "sample_shape"]) From fb75a152bb9700397fb622f91e4cef059b8f23fe Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 30 Oct 2020 13:16:34 -0400 Subject: [PATCH 11/38] add a bunch of test cases, most passing --- test/test_distribution_2.py | 126 +++++++++++++++++++++++++++--------- 1 file changed, 97 insertions(+), 29 deletions(-) diff --git a/test/test_distribution_2.py b/test/test_distribution_2.py index 374cee177..33c84affb 100644 --- a/test/test_distribution_2.py +++ b/test/test_distribution_2.py @@ -13,7 +13,7 @@ from funsor.integrate import Integrate from funsor.interpreter import interpretation from funsor.terms import Variable, eager, lazy, to_data, to_funsor -from funsor.testing import assert_close, check_funsor, rand, randint, randn +from funsor.testing import assert_close, check_funsor, rand, randint, randn # noqa: F401 from funsor.util import get_backend pytestmark = pytest.mark.skipif(get_backend() == "numpy", @@ -29,17 +29,68 @@ # TODO separate sample_shape from DistTestCase? -DistTestCase = namedtuple("DistTestCase", ["raw_dist", "expected_value_domain", "sample_shape"]) +DistTestCase = namedtuple("DistTestCase", ["raw_dist", "expected_value_domain"]) TEST_CASES = [] for batch_shape in [(), (5,), (2, 3)]: - TEST_CASES += [ - DistTestCase( - f"backend_dist.Normal(rand({batch_shape}), rand({batch_shape}))", - funsor.Real - ), - ] + + # Normal + TEST_CASES += [DistTestCase( + f"backend_dist.Normal(randn({batch_shape}), rand({batch_shape}))", + funsor.Real, + )] + # NonreparametrizedNormal + TEST_CASES += [DistTestCase( + f"backend_dist.testing.fakes.NonreparameterizedNormal(rand({batch_shape}), rand({batch_shape}))", + funsor.Real, + )] + + # Beta + TEST_CASES += [DistTestCase( + f"backend_dist.Beta(ops.exp(randn({batch_shape})), ops.exp(randn({batch_shape})))", + funsor.Real, + )] + # NonreparametrizedBeta + TEST_CASES += [DistTestCase( + f"backend_dist.testing.fakes.NonreparameterizedBeta(ops.exp(randn({batch_shape})), ops.exp(randn({batch_shape})))", # noqa: E501 + funsor.Real, + )] + + # Gamma + TEST_CASES += [DistTestCase( + f"backend_dist.Gamma(rand({batch_shape}), rand({batch_shape}))", + funsor.Real, + )] + # NonreparametrizedGamma + TEST_CASES += [DistTestCase( + f"backend_dist.testing.fakes.NonreparameterizedGamma(rand({batch_shape}), rand({batch_shape}))", + funsor.Real, + )] + + # Dirichlet + for event_shape in [(1,), (4,), (5,)]: + TEST_CASES += [DistTestCase( + f"backend_dist.Dirichlet(rand({batch_shape + event_shape}))", + funsor.Reals[event_shape], + )] + TEST_CASES += [DistTestCase( + f"backend_dist.testing.fakes.NonreparameterizedDirichlet(rand({batch_shape + event_shape}))", + funsor.Reals[event_shape], + )] + + # MultivariateNormal + for event_shape in [(1,), (3,)]: + TEST_CASES += [DistTestCase( + f"backend_dist.MultivariateNormal(randn({batch_shape + event_shape}), random_scale_tril({batch_shape + event_shape * 2}))", # noqa: E501 + funsor.Reals[event_shape], + )] + + # BernoulliLogits + TEST_CASES += [DistTestCase( + f"backend_dist.Bernoulli(logits=rand({batch_shape}))", + funsor.Real, + )] ########################### @@ -48,6 +99,19 @@ # Conversion invertibility -> density type and value -> enumerate_support type and value -> samplers -> gradients ########################### +def case_id(case): + return str(case.raw_dist) + + +def random_scale_tril(shape): + if get_backend() == "torch": + data = randn(shape) + return 
backend_dist.transforms.transform_to(backend_dist.constraints.lower_cholesky)(data) + else: + data = randn(shape[:-2] + (shape[-1] * (shape[-1] + 1) // 2,)) + return backend_dist.biject_to(backend_dist.constraints.lower_cholesky)(data) + + def _default_dim_to_name(inputs_shape, event_inputs=None): DIM_TO_NAME = tuple(map("_pyro_dim_{}".format, range(-100, 0))) dim_to_name_list = DIM_TO_NAME + event_inputs if event_inputs else DIM_TO_NAME @@ -55,15 +119,16 @@ def _default_dim_to_name(inputs_shape, event_inputs=None): range(-len(inputs_shape), 0), dim_to_name_list[len(dim_to_name_list) - len(inputs_shape):])) name_to_dim = OrderedDict((name, dim) for dim, name in dim_to_name.items()) - return name_to_dim + return dim_to_name, name_to_dim -def _get_stat(raw_dist, sample_shape, statistic, with_lazy): +def _get_stat(raw_dist, sample_shape, statistic): dim_to_name, name_to_dim = _default_dim_to_name(sample_shape + raw_dist.batch_shape) - with interpretation(lazy if with_lazy else eager): + with interpretation(lazy): funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) - sample_inputs = ... # TODO compute sample_inputs from dim_to_name + sample_inputs = OrderedDict((dim_to_name[dim - len(raw_dist.batch_shape)], funsor.Bint[sample_shape[dim]]) + for dim in range(-len(sample_shape), 0)) rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32) sample_value = funsor_dist.sample(frozenset(['value']), sample_inputs, rng_key=rng_key) expected_inputs = OrderedDict(tuple(sample_inputs.items()) + tuple(funsor_dist.inputs.items())) @@ -93,13 +158,15 @@ def _get_stat(raw_dist, sample_shape, statistic, with_lazy): return actual_stat.reduce(ops.add), expected_stat.reduce(ops.add) -@pytest.mark.parametrize("case", TEST_CASES) -def test_generic_distribution_to_funsor(case): +@pytest.mark.parametrize("case", TEST_CASES, ids=case_id) +@pytest.mark.parametrize("with_lazy", [True, False]) +def test_generic_distribution_to_funsor(case, with_lazy): raw_dist, expected_value_domain = eval(case.raw_dist), case.expected_value_domain dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape) - funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) + with interpretation(lazy if with_lazy else eager): + funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) actual_dist = to_data(funsor_dist, name_to_dim=name_to_dim) assert isinstance(actual_dist, backend_dist.Distribution) @@ -112,7 +179,7 @@ def test_generic_distribution_to_funsor(case): assert_close(getattr(actual_dist, param_name), getattr(raw_dist, param_name)) -@pytest.mark.parametrize("case", TEST_CASES) +@pytest.mark.parametrize("case", TEST_CASES, ids=case_id) def test_generic_log_prob(case): raw_dist, expected_value_domain = eval(case.raw_dist), case.expected_value_domain @@ -133,41 +200,42 @@ def test_generic_log_prob(case): assert_close(funsor_dist(value=funsor_value), expected_logprob) -@pytest.mark.parametrize("case", TEST_CASES) +@pytest.mark.parametrize("case", TEST_CASES, ids=case_id) @pytest.mark.parametrize("expand", [False, True]) def test_generic_enumerate_support(case, expand): raw_dist = eval(case.raw_dist) dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape) - funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) + with interpretation(lazy): + funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) - assert getattr(raw_dist, "has_enumerate_support", False) == 
funsor_dist.has_enumerate_support - if funsor_dist.has_enumerate_support: + assert getattr(raw_dist, "has_enumerate_support", False) == getattr(funsor_dist, "has_enumerate_support", False) + if getattr(funsor_dist, "has_enumerate_support", False): + name_to_dim["value"] = -1 if not name_to_dim else min(name_to_dim.values()) - 1 raw_support = raw_dist.enumerate_support(expand=expand) funsor_support = funsor_dist.enumerate_support(expand=expand) assert_close(to_data(funsor_support, name_to_dim=name_to_dim), raw_support) -@pytest.mark.parametrize("case", TEST_CASES) +@pytest.mark.parametrize("case", TEST_CASES, ids=case_id) @pytest.mark.parametrize("statistic", ["mean", "variance", "entropy"]) -@pytest.mark.parametrize("with_lazy", [True, False]) -def test_generic_sample(case, statistic, with_lazy): +@pytest.mark.parametrize("sample_shape", [(), (200000,), (400, 400)]) +def test_generic_sample(case, statistic, sample_shape): - raw_dist, sample_shape = eval(case.raw_dist), case.sample_shape + raw_dist = eval(case.raw_dist) atol = 1e-2 - actual_stat, expected_stat = _get_stat(raw_dist, sample_shape, statistic, with_lazy) + actual_stat, expected_stat = _get_stat(raw_dist, sample_shape, statistic) check_funsor(actual_stat, expected_stat.inputs, expected_stat.output) if sample_shape: assert_close(actual_stat, expected_stat, atol=atol, rtol=None) -@pytest.mark.skipif(True, reason="broken") -@pytest.mark.parametrize("case", TEST_CASES) +@pytest.mark.skipif(True, reason="not working yet") +@pytest.mark.parametrize("case", TEST_CASES, ids=case_id) @pytest.mark.parametrize("statistic", ["mean", "variance", "entropy"]) -@pytest.mark.parametrize("with_lazy", [True, False]) def test_generic_sample_grads(case, statistic, with_lazy): raw_dist, sample_shape = eval(case.raw_dist), case.sample_shape @@ -175,7 +243,7 @@ def test_generic_sample_grads(case, statistic, with_lazy): atol = 1e-2 def _get_stat_diff_fn(raw_dist): - actual_stat, expected_stat = _get_stat(raw_dist, sample_shape, statistic, with_lazy) + actual_stat, expected_stat = _get_stat(raw_dist, sample_shape, statistic) return to_data((actual_stat - expected_stat).sum()) if get_backend() == "torch": From e6dcabd1fa82dd757442077d3bab7a7420a3a749 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 30 Oct 2020 13:24:41 -0400 Subject: [PATCH 12/38] rename file --- test/{test_distribution_2.py => test_distribution_generic.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename test/{test_distribution_2.py => test_distribution_generic.py} (100%) diff --git a/test/test_distribution_2.py b/test/test_distribution_generic.py similarity index 100% rename from test/test_distribution_2.py rename to test/test_distribution_generic.py From 8be24c6af13425b674c9862322e5f74f9462211b Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 30 Oct 2020 14:28:05 -0400 Subject: [PATCH 13/38] gradient test at least runs for pytorch backend... 
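
The gradient test needs direct handles on the raw parameter tensors, so
DistTestCase becomes a small class that eval()s each parameter string and
stores the resulting tensor as an attribute; random_scale_tril also moves
into funsor.testing so both distribution test files can share it.

For reference, the torch branch of the gradient check proceeds roughly as
sketched below. This is an illustrative sketch only: the autograd call and
the final assertion are assumptions about where the test is heading, not
the committed code.

    import torch

    # collect the eval()ed parameter tensors stored on the test case
    params = tuple(getattr(case, name) for name, _ in case.raw_params)
    for param in params:
        param.requires_grad_()  # mark leaf tensors as differentiable

    # scalar gap between the Monte Carlo estimate and the exact statistic
    diff = _get_stat_diff_fn(raw_dist)
    grads = torch.autograd.grad(diff, params, allow_unused=True)
    for grad in grads:
        assert grad is None or not torch.isnan(grad).any()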
--- funsor/testing.py | 19 +++++++ test/test_distribution_generic.py | 82 +++++++++++++++++-------------- 2 files changed, 65 insertions(+), 36 deletions(-) diff --git a/funsor/testing.py b/funsor/testing.py index afc9ff8db..1e42ff1de 100644 --- a/funsor/testing.py +++ b/funsor/testing.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import contextlib +import importlib import itertools import numbers import operator @@ -265,6 +266,24 @@ def randn(*args): return np.array(np.random.randn(*shape)) +def random_scale_tril(*args): + if isinstance(args[0], tuple): + assert len(args) == 1 + shape = args[0] + else: + shape = args + + from funsor.distribution import BACKEND_TO_DISTRIBUTIONS_BACKEND + backend_dist = importlib.import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]).dist + + if get_backend() == "torch": + data = randn(shape) + return backend_dist.transforms.transform_to(backend_dist.constraints.lower_cholesky)(data) + else: + data = randn(shape[:-2] + (shape[-1] * (shape[-1] + 1) // 2,)) + return backend_dist.biject_to(backend_dist.constraints.lower_cholesky)(data) + + def zeros(*args): if isinstance(args[0], tuple): assert len(args) == 1 diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index 33c84affb..5b932cc89 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -1,7 +1,7 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 -from collections import OrderedDict, namedtuple +from collections import OrderedDict from importlib import import_module import numpy as np @@ -13,7 +13,7 @@ from funsor.integrate import Integrate from funsor.interpreter import interpretation from funsor.terms import Variable, eager, lazy, to_data, to_funsor -from funsor.testing import assert_close, check_funsor, rand, randint, randn # noqa: F401 +from funsor.testing import assert_close, check_funsor, rand, randint, randn, random_scale_tril # noqa: F401 from funsor.util import get_backend pytestmark = pytest.mark.skipif(get_backend() == "numpy", @@ -27,9 +27,21 @@ # Test cases ################################################## -# TODO separate sample_shape from DistTestCase? 
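+# Distributions and their parameters are written as eval()-able strings.
+# Parameter tensors are evaluated once, at collection time, and stored as
+# attributes because the gradient tests need direct handles on them; each
+# test then re-eval()s raw_dist to build a fresh backend distribution.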
+class DistTestCase: + + def __init__(self, raw_dist, raw_params, expected_value_domain): + self.raw_dist = raw_dist + self.raw_params = raw_params + self.expected_value_domain = expected_value_domain + for name, raw_param in self.raw_params: + setattr(self, name, eval(raw_param)) + + def __str__(self): + return self.raw_dist + + def __hash__(self): + return hash((self.raw_dist, self.raw_params, self.expected_value_domain)) -DistTestCase = namedtuple("DistTestCase", ["raw_dist", "expected_value_domain"]) TEST_CASES = [] @@ -37,58 +49,68 @@ # Normal TEST_CASES += [DistTestCase( - f"backend_dist.Normal(randn({batch_shape}), rand({batch_shape}))", + "backend_dist.Normal(case.loc, case.scale)", + (("loc", f"randn({batch_shape})"), ("scale", f"rand({batch_shape})")), funsor.Real, )] # NonreparametrizedNormal TEST_CASES += [DistTestCase( - f"backend_dist.testing.fakes.NonreparameterizedNormal(rand({batch_shape}), rand({batch_shape}))", + "backend_dist.testing.fakes.NonreparameterizedNormal(case.loc, case.scale)", + (("loc", f"randn({batch_shape})"), ("scale", f"rand({batch_shape})")), funsor.Real, )] # Beta TEST_CASES += [DistTestCase( - f"backend_dist.Beta(ops.exp(randn({batch_shape})), ops.exp(randn({batch_shape})))", + "backend_dist.Beta(case.concentration1, case.concentration0)", + (("concentration1", f"ops.exp(randn({batch_shape}))"), ("concentration0", f"ops.exp(randn({batch_shape}))")), funsor.Real, )] # NonreparametrizedBeta TEST_CASES += [DistTestCase( - f"backend_dist.testing.fakes.NonreparameterizedBeta(ops.exp(randn({batch_shape})), ops.exp(randn({batch_shape})))", # noqa: E501 + "backend_dist.testing.fakes.NonreparameterizedBeta(case.concentration1, case.concentration0)", + (("concentration1", f"ops.exp(randn({batch_shape}))"), ("concentration0", f"ops.exp(randn({batch_shape}))")), funsor.Real, )] # Gamma TEST_CASES += [DistTestCase( - f"backend_dist.Gamma(rand({batch_shape}), rand({batch_shape}))", + "backend_dist.Gamma(case.concentration, case.rate)", + (("concentration", f"rand({batch_shape})"), ("rate", f"rand({batch_shape})")), funsor.Real, )] # NonreparametrizedGamma TEST_CASES += [DistTestCase( - f"backend_dist.testing.fakes.NonreparameterizedGamma(rand({batch_shape}), rand({batch_shape}))", + "backend_dist.testing.fakes.NonreparameterizedGamma(case.concentration, case.rate)", + (("concentration", f"rand({batch_shape})"), ("rate", f"rand({batch_shape})")), funsor.Real, )] # Dirichlet for event_shape in [(1,), (4,), (5,)]: TEST_CASES += [DistTestCase( - f"backend_dist.Dirichlet(rand({batch_shape + event_shape}))", + "backend_dist.Dirichlet(case.concentration)", + (("concentration", f"rand({batch_shape + event_shape})"),), funsor.Reals[event_shape], )] TEST_CASES += [DistTestCase( - f"backend_dist.testing.fakes.NonreparameterizedDirichlet(rand({batch_shape + event_shape}))", + "backend_dist.testing.fakes.NonreparameterizedDirichlet(case.concentration)", + (("concentration", f"rand({batch_shape + event_shape})"),), funsor.Reals[event_shape], )] # MultivariateNormal for event_shape in [(1,), (3,)]: TEST_CASES += [DistTestCase( - f"backend_dist.MultivariateNormal(randn({batch_shape + event_shape}), random_scale_tril({batch_shape + event_shape * 2}))", # noqa: E501 + "backend_dist.MultivariateNormal(case.loc, case.scale_tril)", + (("loc", f"randn({batch_shape + event_shape})"), ("scale_tril", f"random_scale_tril({batch_shape + event_shape * 2})")), # noqa: E501 funsor.Reals[event_shape], )] # BernoulliLogits TEST_CASES += [DistTestCase( - 
f"backend_dist.Bernoulli(logits=rand({batch_shape}))", + "backend_dist.Bernoulli(logits=case.logits)", + (("logits", f"rand({batch_shape})"),), funsor.Real, )] @@ -99,19 +121,6 @@ # Conversion invertibility -> density type and value -> enumerate_support type and value -> samplers -> gradients ########################### -def case_id(case): - return str(case.raw_dist) - - -def random_scale_tril(shape): - if get_backend() == "torch": - data = randn(shape) - return backend_dist.transforms.transform_to(backend_dist.constraints.lower_cholesky)(data) - else: - data = randn(shape[:-2] + (shape[-1] * (shape[-1] + 1) // 2,)) - return backend_dist.biject_to(backend_dist.constraints.lower_cholesky)(data) - - def _default_dim_to_name(inputs_shape, event_inputs=None): DIM_TO_NAME = tuple(map("_pyro_dim_{}".format, range(-100, 0))) dim_to_name_list = DIM_TO_NAME + event_inputs if event_inputs else DIM_TO_NAME @@ -158,7 +167,7 @@ def _get_stat(raw_dist, sample_shape, statistic): return actual_stat.reduce(ops.add), expected_stat.reduce(ops.add) -@pytest.mark.parametrize("case", TEST_CASES, ids=case_id) +@pytest.mark.parametrize("case", TEST_CASES, ids=str) @pytest.mark.parametrize("with_lazy", [True, False]) def test_generic_distribution_to_funsor(case, with_lazy): @@ -179,7 +188,7 @@ def test_generic_distribution_to_funsor(case, with_lazy): assert_close(getattr(actual_dist, param_name), getattr(raw_dist, param_name)) -@pytest.mark.parametrize("case", TEST_CASES, ids=case_id) +@pytest.mark.parametrize("case", TEST_CASES, ids=str) def test_generic_log_prob(case): raw_dist, expected_value_domain = eval(case.raw_dist), case.expected_value_domain @@ -200,7 +209,7 @@ def test_generic_log_prob(case): assert_close(funsor_dist(value=funsor_value), expected_logprob) -@pytest.mark.parametrize("case", TEST_CASES, ids=case_id) +@pytest.mark.parametrize("case", TEST_CASES, ids=str) @pytest.mark.parametrize("expand", [False, True]) def test_generic_enumerate_support(case, expand): @@ -218,7 +227,7 @@ def test_generic_enumerate_support(case, expand): assert_close(to_data(funsor_support, name_to_dim=name_to_dim), raw_support) -@pytest.mark.parametrize("case", TEST_CASES, ids=case_id) +@pytest.mark.parametrize("case", TEST_CASES, ids=str) @pytest.mark.parametrize("statistic", ["mean", "variance", "entropy"]) @pytest.mark.parametrize("sample_shape", [(), (200000,), (400, 400)]) def test_generic_sample(case, statistic, sample_shape): @@ -233,12 +242,13 @@ def test_generic_sample(case, statistic, sample_shape): assert_close(actual_stat, expected_stat, atol=atol, rtol=None) -@pytest.mark.skipif(True, reason="not working yet") -@pytest.mark.parametrize("case", TEST_CASES, ids=case_id) +@pytest.mark.skipif(get_backend() != "torch", reason="not working yet") +@pytest.mark.parametrize("case", TEST_CASES, ids=str) @pytest.mark.parametrize("statistic", ["mean", "variance", "entropy"]) -def test_generic_sample_grads(case, statistic, with_lazy): +@pytest.mark.parametrize("sample_shape", [(200000,), (400, 400)]) +def test_generic_sample_grads(case, statistic, sample_shape): - raw_dist, sample_shape = eval(case.raw_dist), case.sample_shape + raw_dist = eval(case.raw_dist) atol = 1e-2 @@ -249,7 +259,7 @@ def _get_stat_diff_fn(raw_dist): if get_backend() == "torch": import torch - # TODO compute params here + params = tuple(getattr(case, param) for param, _ in case.raw_params) for param in params: param.requires_grad_() From 50f03438786effc37363cbd4bb3f5fd39716babf Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 30 Oct 2020 14:44:59 -0400 
Subject: [PATCH 14/38] break up sample test into smoke and stats --- test/test_distribution_generic.py | 30 ++++++++++++++++++++++++------ 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index 5b932cc89..32b7560d6 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -164,7 +164,7 @@ def _get_stat(raw_dist, sample_shape, statistic): else: raise ValueError("invalid test statistic: {}".format(statistic)) - return actual_stat.reduce(ops.add), expected_stat.reduce(ops.add) + return actual_stat, expected_stat @pytest.mark.parametrize("case", TEST_CASES, ids=str) @@ -229,24 +229,42 @@ def test_generic_enumerate_support(case, expand): @pytest.mark.parametrize("case", TEST_CASES, ids=str) @pytest.mark.parametrize("statistic", ["mean", "variance", "entropy"]) -@pytest.mark.parametrize("sample_shape", [(), (200000,), (400, 400)]) +@pytest.mark.parametrize("sample_shape", [(), (2,), (4, 3)], ids=str) def test_generic_sample(case, statistic, sample_shape): raw_dist = eval(case.raw_dist) + dim_to_name, name_to_dim = _default_dim_to_name(sample_shape + raw_dist.batch_shape) + with interpretation(lazy): + funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) + + sample_inputs = OrderedDict((dim_to_name[dim - len(raw_dist.batch_shape)], funsor.Bint[sample_shape[dim]]) + for dim in range(-len(sample_shape), 0)) + rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32) + sample_value = funsor_dist.sample(frozenset(['value']), sample_inputs, rng_key=rng_key) + expected_inputs = OrderedDict(tuple(sample_inputs.items()) + tuple(funsor_dist.inputs.items())) + check_funsor(sample_value, expected_inputs, funsor.Real) + + +@pytest.mark.parametrize("case", TEST_CASES, ids=str) +@pytest.mark.parametrize("statistic", ["mean", "variance", "entropy"]) +@pytest.mark.parametrize("sample_shape", [(), (200000,), (400, 400)], ids=str) +def test_generic_stats(case, statistic, sample_shape): + + raw_dist = eval(case.raw_dist) + atol = 1e-2 actual_stat, expected_stat = _get_stat(raw_dist, sample_shape, statistic) check_funsor(actual_stat, expected_stat.inputs, expected_stat.output) - if sample_shape: - assert_close(actual_stat, expected_stat, atol=atol, rtol=None) + assert_close(actual_stat.reduce(ops.add), expected_stat.reduce(ops.add), atol=atol, rtol=None) @pytest.mark.skipif(get_backend() != "torch", reason="not working yet") @pytest.mark.parametrize("case", TEST_CASES, ids=str) @pytest.mark.parametrize("statistic", ["mean", "variance", "entropy"]) -@pytest.mark.parametrize("sample_shape", [(200000,), (400, 400)]) -def test_generic_sample_grads(case, statistic, sample_shape): +@pytest.mark.parametrize("sample_shape", [(200000,), (400, 400)], ids=str) +def test_generic_grads(case, statistic, sample_shape): raw_dist = eval(case.raw_dist) From 6f69e62b2be58676691e959de8a36fe5da52c788 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 30 Oct 2020 14:47:48 -0400 Subject: [PATCH 15/38] remove with_lazy --- test/test_distribution_generic.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index 32b7560d6..fbae5f64d 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -168,13 +168,12 @@ def _get_stat(raw_dist, sample_shape, statistic): @pytest.mark.parametrize("case", TEST_CASES, ids=str) -@pytest.mark.parametrize("with_lazy", [True, False]) 
-def test_generic_distribution_to_funsor(case, with_lazy): +def test_generic_distribution_to_funsor(case): raw_dist, expected_value_domain = eval(case.raw_dist), case.expected_value_domain dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape) - with interpretation(lazy if with_lazy else eager): + with interpretation(lazy): funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) actual_dist = to_data(funsor_dist, name_to_dim=name_to_dim) From a06744ef904cded2c91e3588cde7c20ee061bbce Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 30 Oct 2020 21:35:29 -0400 Subject: [PATCH 16/38] get basic jax tests passing --- test/test_distribution_generic.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index fbae5f64d..df656bcfe 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -12,7 +12,7 @@ from funsor.distribution import BACKEND_TO_DISTRIBUTIONS_BACKEND from funsor.integrate import Integrate from funsor.interpreter import interpretation -from funsor.terms import Variable, eager, lazy, to_data, to_funsor +from funsor.terms import Variable, lazy, to_data, to_funsor from funsor.testing import assert_close, check_funsor, rand, randint, randn, random_scale_tril # noqa: F401 from funsor.util import get_backend @@ -22,6 +22,16 @@ dist = import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]) backend_dist = dist.dist + class _fakes: + def __getattribute__(self, attr): + if get_backend() == "torch": + return getattr(backend_dist.testing.fakes, attr) + elif get_backend() == "jax": + return getattr(dist, "_NumPyroWrapper_" + attr) + raise ValueError(attr) + + FAKES = _fakes() + ################################################## # Test cases @@ -34,6 +44,7 @@ def __init__(self, raw_dist, raw_params, expected_value_domain): self.raw_params = raw_params self.expected_value_domain = expected_value_domain for name, raw_param in self.raw_params: + # we need direct access to these tensors for gradient tests setattr(self, name, eval(raw_param)) def __str__(self): @@ -53,9 +64,9 @@ def __hash__(self): (("loc", f"randn({batch_shape})"), ("scale", f"rand({batch_shape})")), funsor.Real, )] - # NonreparametrizedNormal + # NonreparameterizedNormal TEST_CASES += [DistTestCase( - "backend_dist.testing.fakes.NonreparameterizedNormal(case.loc, case.scale)", + "FAKES.NonreparameterizedNormal(case.loc, case.scale)", (("loc", f"randn({batch_shape})"), ("scale", f"rand({batch_shape})")), funsor.Real, )] @@ -66,9 +77,9 @@ def __hash__(self): (("concentration1", f"ops.exp(randn({batch_shape}))"), ("concentration0", f"ops.exp(randn({batch_shape}))")), funsor.Real, )] - # NonreparametrizedBeta + # NonreparameterizedBeta TEST_CASES += [DistTestCase( - "backend_dist.testing.fakes.NonreparameterizedBeta(case.concentration1, case.concentration0)", + "FAKES.NonreparameterizedBeta(case.concentration1, case.concentration0)", (("concentration1", f"ops.exp(randn({batch_shape}))"), ("concentration0", f"ops.exp(randn({batch_shape}))")), funsor.Real, )] @@ -81,7 +92,7 @@ def __hash__(self): )] # NonreparametrizedGamma TEST_CASES += [DistTestCase( - "backend_dist.testing.fakes.NonreparameterizedGamma(case.concentration, case.rate)", + "FAKES.NonreparameterizedGamma(case.concentration, case.rate)", (("concentration", f"rand({batch_shape})"), ("rate", f"rand({batch_shape})")), funsor.Real, )] @@ -93,8 +104,9 @@ def __hash__(self): (("concentration", 
f"rand({batch_shape + event_shape})"),), funsor.Reals[event_shape], )] + # NonreparameterizedDirichlet TEST_CASES += [DistTestCase( - "backend_dist.testing.fakes.NonreparameterizedDirichlet(case.concentration)", + "FAKES.NonreparameterizedDirichlet(case.concentration)", (("concentration", f"rand({batch_shape + event_shape})"),), funsor.Reals[event_shape], )] @@ -200,7 +212,7 @@ def test_generic_log_prob(case): check_funsor(funsor_dist, expected_inputs, funsor.Real) if get_backend() == "jax": - raw_value = raw_dist.sample(rng_key=np.array([0, 0], dtype=np.uint32)) + raw_value = raw_dist.sample(key=np.array([0, 0], dtype=np.uint32)) else: raw_value = raw_dist.sample() expected_logprob = to_funsor(raw_dist.log_prob(raw_value), output=funsor.Real, dim_to_name=dim_to_name) From 5a8fd4281235095b1705c26ff37f970889f5906f Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 30 Oct 2020 22:34:42 -0400 Subject: [PATCH 17/38] alphabetize test cases and fix incorrect mvn spec --- test/test_distribution_generic.py | 59 ++++++++++++++++--------------- 1 file changed, 30 insertions(+), 29 deletions(-) diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index df656bcfe..85ec001af 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -23,6 +23,7 @@ backend_dist = dist.dist class _fakes: + """alias for accessing nonreparameterized distributions""" def __getattribute__(self, attr): if get_backend() == "torch": return getattr(backend_dist.testing.fakes, attr) @@ -48,7 +49,7 @@ def __init__(self, raw_dist, raw_params, expected_value_domain): setattr(self, name, eval(raw_param)) def __str__(self): - return self.raw_dist + return self.raw_dist + " " + str(self.raw_params) def __hash__(self): return hash((self.raw_dist, self.raw_params, self.expected_value_domain)) @@ -58,16 +59,10 @@ def __hash__(self): for batch_shape in [(), (5,), (2, 3)]: - # Normal - TEST_CASES += [DistTestCase( - "backend_dist.Normal(case.loc, case.scale)", - (("loc", f"randn({batch_shape})"), ("scale", f"rand({batch_shape})")), - funsor.Real, - )] - # NonreparameterizedNormal + # BernoulliLogits TEST_CASES += [DistTestCase( - "FAKES.NonreparameterizedNormal(case.loc, case.scale)", - (("loc", f"randn({batch_shape})"), ("scale", f"rand({batch_shape})")), + "backend_dist.Bernoulli(logits=case.logits)", + (("logits", f"rand({batch_shape})"),), funsor.Real, )] @@ -84,21 +79,8 @@ def __hash__(self): funsor.Real, )] - # Gamma - TEST_CASES += [DistTestCase( - "backend_dist.Gamma(case.concentration, case.rate)", - (("concentration", f"rand({batch_shape})"), ("rate", f"rand({batch_shape})")), - funsor.Real, - )] - # NonreparametrizedGamma - TEST_CASES += [DistTestCase( - "FAKES.NonreparameterizedGamma(case.concentration, case.rate)", - (("concentration", f"rand({batch_shape})"), ("rate", f"rand({batch_shape})")), - funsor.Real, - )] - # Dirichlet - for event_shape in [(1,), (4,), (5,)]: + for event_shape in [(1,), (4,)]: TEST_CASES += [DistTestCase( "backend_dist.Dirichlet(case.concentration)", (("concentration", f"rand({batch_shape + event_shape})"),), @@ -111,18 +93,37 @@ def __hash__(self): funsor.Reals[event_shape], )] + # Gamma + TEST_CASES += [DistTestCase( + "backend_dist.Gamma(case.concentration, case.rate)", + (("concentration", f"rand({batch_shape})"), ("rate", f"rand({batch_shape})")), + funsor.Real, + )] + # NonreparametrizedGamma + TEST_CASES += [DistTestCase( + "FAKES.NonreparameterizedGamma(case.concentration, case.rate)", + (("concentration", f"rand({batch_shape})"), ("rate", 
f"rand({batch_shape})")), + funsor.Real, + )] + # MultivariateNormal for event_shape in [(1,), (3,)]: TEST_CASES += [DistTestCase( - "backend_dist.MultivariateNormal(case.loc, case.scale_tril)", + "backend_dist.MultivariateNormal(loc=case.loc, scale_tril=case.scale_tril)", (("loc", f"randn({batch_shape + event_shape})"), ("scale_tril", f"random_scale_tril({batch_shape + event_shape * 2})")), # noqa: E501 funsor.Reals[event_shape], )] - # BernoulliLogits + # Normal TEST_CASES += [DistTestCase( - "backend_dist.Bernoulli(logits=case.logits)", - (("logits", f"rand({batch_shape})"),), + "backend_dist.Normal(case.loc, case.scale)", + (("loc", f"randn({batch_shape})"), ("scale", f"rand({batch_shape})")), + funsor.Real, + )] + # NonreparameterizedNormal + TEST_CASES += [DistTestCase( + "FAKES.NonreparameterizedNormal(case.loc, case.scale)", + (("loc", f"randn({batch_shape})"), ("scale", f"rand({batch_shape})")), funsor.Real, )] @@ -190,7 +191,7 @@ def test_generic_distribution_to_funsor(case): actual_dist = to_data(funsor_dist, name_to_dim=name_to_dim) assert isinstance(actual_dist, backend_dist.Distribution) - assert type(raw_dist) == type(actual_dist) + assert issubclass(type(actual_dist), type(raw_dist)) # subclass to handle wrappers assert funsor_dist.inputs["value"] == expected_value_domain for param_name in funsor_dist.params.keys(): if param_name == "value": From 127b12c71a55820a3727625c578f75f743178613 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 30 Oct 2020 22:55:36 -0400 Subject: [PATCH 18/38] break up stat test --- test/test_distribution_generic.py | 66 +++++++++++++++++++++++++++---- 1 file changed, 58 insertions(+), 8 deletions(-) diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index 85ec001af..88d210e77 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -13,7 +13,7 @@ from funsor.integrate import Integrate from funsor.interpreter import interpretation from funsor.terms import Variable, lazy, to_data, to_funsor -from funsor.testing import assert_close, check_funsor, rand, randint, randn, random_scale_tril # noqa: F401 +from funsor.testing import assert_close, check_funsor, rand, randint, randn, random_scale_tril, xfail_if_not_implemented # noqa: F401,E501 from funsor.util import get_backend pytestmark = pytest.mark.skipif(get_backend() == "numpy", @@ -66,6 +66,13 @@ def __hash__(self): funsor.Real, )] + # BernoulliProbs + TEST_CASES += [DistTestCase( + "backend_dist.Bernoulli(probs=case.probs)", + (("probs", f"rand({batch_shape})"),), + funsor.Real, + )] + # Beta TEST_CASES += [DistTestCase( "backend_dist.Beta(case.concentration1, case.concentration0)", @@ -127,6 +134,13 @@ def __hash__(self): funsor.Real, )] + # Poisson + TEST_CASES += [DistTestCase( + "backend_dist.Poisson(rate=case.rate)", + (("rate", f"rand({batch_shape})"),), + funsor.Real, + )] + ########################### # Generic tests: @@ -240,9 +254,8 @@ def test_generic_enumerate_support(case, expand): @pytest.mark.parametrize("case", TEST_CASES, ids=str) -@pytest.mark.parametrize("statistic", ["mean", "variance", "entropy"]) @pytest.mark.parametrize("sample_shape", [(), (2,), (4, 3)], ids=str) -def test_generic_sample(case, statistic, sample_shape): +def test_generic_sample(case, sample_shape): raw_dist = eval(case.raw_dist) @@ -259,23 +272,59 @@ def test_generic_sample(case, statistic, sample_shape): @pytest.mark.parametrize("case", TEST_CASES, ids=str) -@pytest.mark.parametrize("statistic", ["mean", "variance", "entropy"]) 
+@pytest.mark.parametrize("statistic", [ + "mean", + "variance", + pytest.param("entropy", marks=[pytest.mark.skipif(get_backend() == "jax", reason="entropy not implemented")]) +]) +def test_generic_stats_smoke(case, statistic): + + raw_dist = eval(case.raw_dist) + + dim_to_name, name_to_dim = _default_dim_to_name(raw_dist.batch_shape) + with interpretation(lazy): + funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) + + with xfail_if_not_implemented(msg="entropy not implemented for some distributions"): + actual_stat = getattr(funsor_dist, statistic)() + + expected_stat_raw = getattr(raw_dist, statistic) + if statistic == "entropy": + expected_stat = to_funsor(expected_stat_raw(), output=funsor.Real, dim_to_name=dim_to_name) + else: + expected_stat = to_funsor(expected_stat_raw, output=case.expected_value_domain, dim_to_name=dim_to_name) + + check_funsor(actual_stat, expected_stat.inputs, expected_stat.output) + assert_close(to_data(actual_stat, name_to_dim), to_data(expected_stat, name_to_dim)) + + +@pytest.mark.parametrize("case", TEST_CASES, ids=str) @pytest.mark.parametrize("sample_shape", [(), (200000,), (400, 400)], ids=str) -def test_generic_stats(case, statistic, sample_shape): +@pytest.mark.parametrize("statistic", [ + "mean", + "variance", + pytest.param("entropy", marks=[pytest.mark.skipif(get_backend() == "jax", reason="entropy not implemented")]) +]) +def test_generic_stats_sample(case, statistic, sample_shape): raw_dist = eval(case.raw_dist) atol = 1e-2 + with xfail_if_not_implemented(msg="entropy not implemented for some distributions"): + actual_stat, expected_stat = _get_stat(raw_dist, sample_shape, statistic) - actual_stat, expected_stat = _get_stat(raw_dist, sample_shape, statistic) check_funsor(actual_stat, expected_stat.inputs, expected_stat.output) assert_close(actual_stat.reduce(ops.add), expected_stat.reduce(ops.add), atol=atol, rtol=None) @pytest.mark.skipif(get_backend() != "torch", reason="not working yet") @pytest.mark.parametrize("case", TEST_CASES, ids=str) -@pytest.mark.parametrize("statistic", ["mean", "variance", "entropy"]) @pytest.mark.parametrize("sample_shape", [(200000,), (400, 400)], ids=str) +@pytest.mark.parametrize("statistic", [ + "mean", + "variance", + pytest.param("entropy", marks=[pytest.mark.skipif(get_backend() == "jax", reason="entropy not implemented")]) +]) def test_generic_grads(case, statistic, sample_shape): raw_dist = eval(case.raw_dist) @@ -283,7 +332,8 @@ def test_generic_grads(case, statistic, sample_shape): atol = 1e-2 def _get_stat_diff_fn(raw_dist): - actual_stat, expected_stat = _get_stat(raw_dist, sample_shape, statistic) + with xfail_if_not_implemented(msg="entropy not implemented for some distributions"): + actual_stat, expected_stat = _get_stat(raw_dist, sample_shape, statistic) return to_data((actual_stat - expected_stat).sum()) if get_backend() == "torch": From 80750b02a8ee8061ddbdffd3a34c0e2e728da350 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 30 Oct 2020 23:04:24 -0400 Subject: [PATCH 19/38] binomial test case --- test/test_distribution_generic.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index 88d210e77..2a1d89b15 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -86,6 +86,13 @@ def __hash__(self): funsor.Real, )] + # Binomial + TEST_CASES += [DistTestCase( + "backend_dist.Binomial(total_count=case.total_count, probs=case.probs)", + 
(("total_count", f"ops.astype(randint(0, 10, {batch_shape}), 'float')"), ("probs", f"rand({batch_shape})")), + funsor.Real, + )] + # Dirichlet for event_shape in [(1,), (4,)]: TEST_CASES += [DistTestCase( @@ -248,9 +255,10 @@ def test_generic_enumerate_support(case, expand): assert getattr(raw_dist, "has_enumerate_support", False) == getattr(funsor_dist, "has_enumerate_support", False) if getattr(funsor_dist, "has_enumerate_support", False): name_to_dim["value"] = -1 if not name_to_dim else min(name_to_dim.values()) - 1 - raw_support = raw_dist.enumerate_support(expand=expand) - funsor_support = funsor_dist.enumerate_support(expand=expand) - assert_close(to_data(funsor_support, name_to_dim=name_to_dim), raw_support) + with xfail_if_not_implemented("enumerate support not implemented"): + raw_support = raw_dist.enumerate_support(expand=expand) + funsor_support = funsor_dist.enumerate_support(expand=expand) + assert_close(to_data(funsor_support, name_to_dim=name_to_dim), raw_support) @pytest.mark.parametrize("case", TEST_CASES, ids=str) From 3ceadfdcf85053e842079843d730e316be31d50e Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 30 Oct 2020 23:18:35 -0400 Subject: [PATCH 20/38] add more test cases --- test/test_distribution_generic.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index 2a1d89b15..60c7e0cc1 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -93,6 +93,22 @@ def __hash__(self): funsor.Real, )] + # CategoricalProbs + for size in [2, 4]: + TEST_CASES += [DistTestCase( + "backend_dist.Categorical(probs=case.probs)", + (("probs", f"rand({batch_shape + (size,)})"),), + funsor.Bint[size], + )] + + # Delta + for event_shape in [(), (4,), (3, 2)]: + TEST_CASES += [DistTestCase( + "backend_dist.Delta(case.v, case.log_density)", + (("v", f"rand({batch_shape + event_shape})"), ("log_density", f"rand({batch_shape})")), + funsor.Reals[event_shape], + )] + # Dirichlet for event_shape in [(1,), (4,)]: TEST_CASES += [DistTestCase( @@ -120,6 +136,14 @@ def __hash__(self): funsor.Real, )] + # Multinomial + for event_shape in [(1,), (4,)]: + TEST_CASES += [DistTestCase( + "backend_dist.Multinomial(case.total_count, probs=case.probs)", + (("total_count", "5"), ("probs", f"rand({batch_shape + event_shape})")), + funsor.Reals[event_shape], + )] + # MultivariateNormal for event_shape in [(1,), (3,)]: TEST_CASES += [DistTestCase( @@ -148,6 +172,13 @@ def __hash__(self): funsor.Real, )] + # VonMises + TEST_CASES += [DistTestCase( + "backend_dist.VonMises(case.loc, case.concentration)", + (("loc", f"rand({batch_shape})"), ("concentration", f"rand({batch_shape})")), + funsor.Real, + )] + ########################### # Generic tests: From 151ca2ba7b77ee3e838f2551c5990920a46549a1 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 30 Oct 2020 23:24:53 -0400 Subject: [PATCH 21/38] add dirichletmultinomial test case for parity --- test/test_distribution_generic.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index 60c7e0cc1..7a27e229b 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -123,6 +123,14 @@ def __hash__(self): funsor.Reals[event_shape], )] + # DirichletMultinomial + for event_shape in [(1,), (4,)]: + TEST_CASES += [DistTestCase( + "backend_dist.DirichletMultinomial(case.concentration, case.total_count)", + (("concentration", 
f"rand({batch_shape + event_shape})"), ("total_count", "10")), + funsor.Reals[event_shape], + )] + # Gamma TEST_CASES += [DistTestCase( "backend_dist.Gamma(case.concentration, case.rate)", From 06f6d06bb231946981e6564e727f2dc18905f7d9 Mon Sep 17 00:00:00 2001 From: Eli Date: Fri, 30 Oct 2020 23:27:34 -0400 Subject: [PATCH 22/38] add CategoricalLogits test case --- test/test_distribution_generic.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index 7a27e229b..53d1f20a2 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -93,6 +93,14 @@ def __hash__(self): funsor.Real, )] + # CategoricalLogits + for size in [2, 4]: + TEST_CASES += [DistTestCase( + "backend_dist.Categorical(logits=case.logits)", + (("logits", f"rand({batch_shape + (size,)})"),), + funsor.Bint[size], + )] + # CategoricalProbs for size in [2, 4]: TEST_CASES += [DistTestCase( From d1e8af5ada0a255d403161e660c5641a1694c4e3 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 31 Oct 2020 15:34:41 -0400 Subject: [PATCH 23/38] custom conversion for multinomial --- funsor/torch/distributions.py | 16 +++++++++++++++- test/test_distribution_generic.py | 2 +- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/funsor/torch/distributions.py b/funsor/torch/distributions.py index c2918eb97..cf6f94950 100644 --- a/funsor/torch/distributions.py +++ b/funsor/torch/distributions.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import functools +import numbers from typing import Tuple, Union import pyro.distributions as dist @@ -41,7 +42,7 @@ from funsor.domains import Real, Reals import funsor.ops as ops from funsor.tensor import Tensor, dummy_numeric_array -from funsor.terms import Binary, Funsor, Variable, eager, to_funsor +from funsor.terms import Binary, Funsor, Variable, eager, to_data, to_funsor from funsor.util import methodof @@ -154,6 +155,19 @@ def _infer_param_domain(cls, name, raw_shape): return Real +########################################################### +# Converting distribution funsors to PyTorch distributions +########################################################### + +@to_data.register(Multinomial) +def multinomial_to_data(funsor_dist, name_to_dim=None): + probs = to_data(funsor_dist.probs, name_to_dim) + total_count = to_data(funsor_dist.total_count, name_to_dim) + if isinstance(total_count, numbers.Number) or len(total_count.shape) == 0: + return dist.Multinomial(int(total_count), probs=probs) + raise NotImplementedError("inhomogeneous total_count not supported") + + ############################################### # Converting PyTorch Distributions to funsors ############################################### diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index 53d1f20a2..8fb9e63f4 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -104,7 +104,7 @@ def __hash__(self): # CategoricalProbs for size in [2, 4]: TEST_CASES += [DistTestCase( - "backend_dist.Categorical(probs=case.probs)", + "backend_dist.Categorical(probs=case.probs / case.probs.sum(-1, True))", (("probs", f"rand({batch_shape + (size,)})"),), funsor.Bint[size], )] From 39ced3afbc704a083a878801435d0a942be725eb Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 31 Oct 2020 16:07:52 -0400 Subject: [PATCH 24/38] use get_raw_dist in sample, add appropriate xfails --- funsor/distribution.py | 25 +++++++++++++++---------- funsor/torch/distributions.py | 2 +- 
test/test_distribution_generic.py | 16 ++++++++++++---- 3 files changed, 28 insertions(+), 15 deletions(-) diff --git a/funsor/distribution.py b/funsor/distribution.py index 51fbd0a68..44cdf65c7 100644 --- a/funsor/distribution.py +++ b/funsor/distribution.py @@ -134,23 +134,28 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None): value = params.pop("value") assert all(isinstance(v, (Number, Tensor)) for v in params.values()) assert isinstance(value, Variable) and value.name in sampled_vars - inputs_, tensors = align_tensors(*params.values()) - inputs = OrderedDict(sample_inputs.items()) - inputs.update(inputs_) - sample_shape = tuple(v.size for v in sample_inputs.values()) - raw_dist = self.dist_class(**dict(zip(self._ast_fields[:-1], tensors))) + value_name = value.name + raw_dist, value_output, dim_to_name = self._get_raw_dist() + for d, name in zip(range(len(sample_inputs), 0, -1), sample_inputs.keys()): + dim_to_name[-d - len(raw_dist.batch_shape)] = name + + sample_shape = tuple(v.size for v in sample_inputs.values()) sample_args = (sample_shape,) if get_backend() == "torch" else (rng_key, sample_shape) if self.has_rsample: - raw_sample = raw_dist.rsample(*sample_args) + raw_value = raw_dist.rsample(*sample_args) else: - raw_sample = ops.detach(raw_dist.sample(*sample_args)) + raw_value = ops.detach(raw_dist.sample(*sample_args)) - result = funsor.delta.Delta(value.name, Tensor(raw_sample, inputs, value.output.dtype)) + funsor_value = to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name) + funsor_value = funsor_value.align( + tuple(sample_inputs) + tuple(inp for inp in self.inputs if inp in funsor_value.inputs)) + result = funsor.delta.Delta(value_name, funsor_value) if not self.has_rsample: # scaling of dice_factor by num samples should already be handled by Funsor.sample - raw_log_prob = raw_dist.log_prob(raw_sample) - dice_factor = Tensor(raw_log_prob - ops.detach(raw_log_prob), inputs) + raw_log_prob = raw_dist.log_prob(raw_value) + dice_factor = to_funsor(raw_log_prob - ops.detach(raw_log_prob), + output=self.output, dim_to_name=dim_to_name) result = result + dice_factor return result diff --git a/funsor/torch/distributions.py b/funsor/torch/distributions.py index cf6f94950..eb661ce77 100644 --- a/funsor/torch/distributions.py +++ b/funsor/torch/distributions.py @@ -159,7 +159,7 @@ def _infer_param_domain(cls, name, raw_shape): # Converting distribution funsors to PyTorch distributions ########################################################### -@to_data.register(Multinomial) +@to_data.register(Multinomial) # noqa: F821 def multinomial_to_data(funsor_dist, name_to_dim=None): probs = to_data(funsor_dist.probs, name_to_dim) total_count = to_data(funsor_dist.total_count, name_to_dim) diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index 8fb9e63f4..b366adfd3 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -104,7 +104,7 @@ def __hash__(self): # CategoricalProbs for size in [2, 4]: TEST_CASES += [DistTestCase( - "backend_dist.Categorical(probs=case.probs / case.probs.sum(-1, True))", + "backend_dist.Categorical(probs=case.probs)", (("probs", f"rand({batch_shape + (size,)})"),), funsor.Bint[size], )] @@ -286,7 +286,7 @@ def test_generic_log_prob(case): raw_value = raw_dist.sample() expected_logprob = to_funsor(raw_dist.log_prob(raw_value), output=funsor.Real, dim_to_name=dim_to_name) funsor_value = to_funsor(raw_value, output=expected_value_domain, dim_to_name=dim_to_name) - 
assert_close(funsor_dist(value=funsor_value), expected_logprob) + assert_close(funsor_dist(value=funsor_value), expected_logprob, rtol=1e-4) @pytest.mark.parametrize("case", TEST_CASES, ids=str) @@ -350,7 +350,10 @@ def test_generic_stats_smoke(case, statistic): expected_stat = to_funsor(expected_stat_raw, output=case.expected_value_domain, dim_to_name=dim_to_name) check_funsor(actual_stat, expected_stat.inputs, expected_stat.output) - assert_close(to_data(actual_stat, name_to_dim), to_data(expected_stat, name_to_dim)) + if expected_stat.data.isnan().all(): + pytest.xfail(reason="base stat returns nan") + else: + assert_close(to_data(actual_stat, name_to_dim), to_data(expected_stat, name_to_dim), rtol=1e-4) @pytest.mark.parametrize("case", TEST_CASES, ids=str) @@ -369,7 +372,10 @@ def test_generic_stats_sample(case, statistic, sample_shape): actual_stat, expected_stat = _get_stat(raw_dist, sample_shape, statistic) check_funsor(actual_stat, expected_stat.inputs, expected_stat.output) - assert_close(actual_stat.reduce(ops.add), expected_stat.reduce(ops.add), atol=atol, rtol=None) + if expected_stat.data.isnan().all(): + pytest.xfail(reason="base stat returns nan") + else: + assert_close(actual_stat.reduce(ops.add), expected_stat.reduce(ops.add), atol=atol, rtol=None) @pytest.mark.skipif(get_backend() != "torch", reason="not working yet") @@ -389,6 +395,8 @@ def test_generic_grads(case, statistic, sample_shape): def _get_stat_diff_fn(raw_dist): with xfail_if_not_implemented(msg="entropy not implemented for some distributions"): actual_stat, expected_stat = _get_stat(raw_dist, sample_shape, statistic) + if expected_stat.data.isnan().all(): + pytest.xfail(reason="base stat returns nan") return to_data((actual_stat - expected_stat).sum()) if get_backend() == "torch": From 35d283e7f1f79ae47745cd2774f5061e85a17e27 Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 31 Oct 2020 16:27:18 -0400 Subject: [PATCH 25/38] fix error in gradient test --- test/test_distribution_generic.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index b366adfd3..0b9e81f9c 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -397,12 +397,13 @@ def _get_stat_diff_fn(raw_dist): actual_stat, expected_stat = _get_stat(raw_dist, sample_shape, statistic) if expected_stat.data.isnan().all(): pytest.xfail(reason="base stat returns nan") - return to_data((actual_stat - expected_stat).sum()) + return to_data((actual_stat - expected_stat).reduce(ops.add).sum()) if get_backend() == "torch": import torch params = tuple(getattr(case, param) for param, _ in case.raw_params) + params = tuple(param for param in params if isinstance(param, torch.Tensor)) for param in params: param.requires_grad_() From 34919a4e25df6553c4104203c465641b447b5f3f Mon Sep 17 00:00:00 2001 From: Eli Date: Sat, 31 Oct 2020 16:59:19 -0400 Subject: [PATCH 26/38] simplify binomial test case, remove delta test cases, weaken mc grads test --- test/test_distribution_generic.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index 0b9e81f9c..c016ff04b 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -89,7 +89,7 @@ def __hash__(self): # Binomial TEST_CASES += [DistTestCase( "backend_dist.Binomial(total_count=case.total_count, probs=case.probs)", - (("total_count", f"ops.astype(randint(0, 10, 
{batch_shape}), 'float')"), ("probs", f"rand({batch_shape})")), + (("total_count", f"10"), ("probs", f"rand({batch_shape})")), funsor.Real, )] @@ -109,13 +109,14 @@ def __hash__(self): funsor.Bint[size], )] - # Delta - for event_shape in [(), (4,), (3, 2)]: - TEST_CASES += [DistTestCase( - "backend_dist.Delta(case.v, case.log_density)", - (("v", f"rand({batch_shape + event_shape})"), ("log_density", f"rand({batch_shape})")), - funsor.Reals[event_shape], - )] + # TODO figure out what this should be... + # # Delta + # for event_shape in [(),]: # (4,), (3, 2)]: + # TEST_CASES += [DistTestCase( + # "backend_dist.Delta(case.v, case.log_density)", + # (("v", f"rand({batch_shape + event_shape})"), ("log_density", f"rand({batch_shape})")), + # funsor.Real, # s[event_shape], + # )] # Dirichlet for event_shape in [(1,), (4,)]: @@ -390,7 +391,7 @@ def test_generic_grads(case, statistic, sample_shape): raw_dist = eval(case.raw_dist) - atol = 1e-2 + atol = 1e-2 if raw_dist.has_rsample else 1e-1 def _get_stat_diff_fn(raw_dist): with xfail_if_not_implemented(msg="entropy not implemented for some distributions"): From bc0a518dd09fff2a143bc57cb2ab25b6d0aea3e7 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 3 Nov 2020 14:51:31 -0500 Subject: [PATCH 27/38] fix bug in jax binomial conversion --- funsor/jax/distributions.py | 2 +- test/test_distribution_generic.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/funsor/jax/distributions.py b/funsor/jax/distributions.py index 70b6b2aec..b10af0e2f 100644 --- a/funsor/jax/distributions.py +++ b/funsor/jax/distributions.py @@ -178,7 +178,7 @@ def _infer_param_domain(cls, name, raw_shape): @to_funsor.register(dist.BinomialProbs) @to_funsor.register(dist.BinomialLogits) def categorical_to_funsor(numpyro_dist, output=None, dim_to_name=None): - new_pyro_dist = _NumPyroWrapper_Binomial(probs=numpyro_dist.probs) + new_pyro_dist = _NumPyroWrapper_Binomial(total_count=numpyro_dist.total_count, probs=numpyro_dist.probs) return backenddist_to_funsor(new_pyro_dist, output, dim_to_name) diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index c016ff04b..b25ed237e 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -387,7 +387,7 @@ def test_generic_stats_sample(case, statistic, sample_shape): "variance", pytest.param("entropy", marks=[pytest.mark.skipif(get_backend() == "jax", reason="entropy not implemented")]) ]) -def test_generic_grads(case, statistic, sample_shape): +def test_generic_grads_sample(case, statistic, sample_shape): raw_dist = eval(case.raw_dist) From 81f8a37c35a5e63a81d647a19555bb6308148e3a Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 3 Nov 2020 15:39:50 -0500 Subject: [PATCH 28/38] add ops.isnan and fix jax multinomial conversion --- funsor/jax/distributions.py | 2 +- funsor/jax/ops.py | 5 +++++ funsor/ops/array.py | 2 ++ funsor/torch/ops.py | 5 +++++ test/test_distribution_generic.py | 8 ++++---- 5 files changed, 17 insertions(+), 5 deletions(-) diff --git a/funsor/jax/distributions.py b/funsor/jax/distributions.py index b10af0e2f..d040aa3e9 100644 --- a/funsor/jax/distributions.py +++ b/funsor/jax/distributions.py @@ -194,7 +194,7 @@ def categorical_to_funsor(numpyro_dist, output=None, dim_to_name=None): @to_funsor.register(dist.MultinomialProbs) @to_funsor.register(dist.MultinomialLogits) def categorical_to_funsor(numpyro_dist, output=None, dim_to_name=None): - new_pyro_dist = _NumPyroWrapper_Multinomial(probs=numpyro_dist.probs) + new_pyro_dist = 
_NumPyroWrapper_Multinomial(total_count=numpyro_dist.total_count, probs=numpyro_dist.probs) return backenddist_to_funsor(new_pyro_dist, output, dim_to_name) diff --git a/funsor/jax/ops.py b/funsor/jax/ops.py index d709de2ed..6daa1edd2 100644 --- a/funsor/jax/ops.py +++ b/funsor/jax/ops.py @@ -121,6 +121,11 @@ def _is_numeric_array(x): return True +@ops.isnan.register(array) +def _isnan(x): + return np.isnan(x) + + @ops.lgamma.register(array) def _lgamma(x): return gammaln(x) diff --git a/funsor/ops/array.py b/funsor/ops/array.py index 915935553..c38449e76 100644 --- a/funsor/ops/array.py +++ b/funsor/ops/array.py @@ -24,6 +24,7 @@ diagonal = Op("diagonal") einsum = Op("einsum") full_like = Op(np.full_like) +isnan = Op(np.isnan) prod = Op(np.prod) stack = Op("stack") sum = Op(np.sum) @@ -300,6 +301,7 @@ def unsqueeze(x, dim): 'finfo', 'full_like', 'is_numeric_array', + 'isnan', 'logaddexp', 'logsumexp', 'new_arange', diff --git a/funsor/torch/ops.py b/funsor/torch/ops.py index 9a77fe520..d466934f1 100644 --- a/funsor/torch/ops.py +++ b/funsor/torch/ops.py @@ -107,6 +107,11 @@ def _is_numeric_array(x): return True +@ops.isnan.register(torch.Tensor) +def _isnan(x): + return torch.isnan(x) + + @ops.lgamma.register(torch.Tensor) def _lgamma(x): return x.lgamma() diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index b25ed237e..b1321a046 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -89,7 +89,7 @@ def __hash__(self): # Binomial TEST_CASES += [DistTestCase( "backend_dist.Binomial(total_count=case.total_count, probs=case.probs)", - (("total_count", f"10"), ("probs", f"rand({batch_shape})")), + (("total_count", "10"), ("probs", f"rand({batch_shape})")), funsor.Real, )] @@ -351,7 +351,7 @@ def test_generic_stats_smoke(case, statistic): expected_stat = to_funsor(expected_stat_raw, output=case.expected_value_domain, dim_to_name=dim_to_name) check_funsor(actual_stat, expected_stat.inputs, expected_stat.output) - if expected_stat.data.isnan().all(): + if ops.isnan(expected_stat.data).all(): pytest.xfail(reason="base stat returns nan") else: assert_close(to_data(actual_stat, name_to_dim), to_data(expected_stat, name_to_dim), rtol=1e-4) @@ -373,7 +373,7 @@ def test_generic_stats_sample(case, statistic, sample_shape): actual_stat, expected_stat = _get_stat(raw_dist, sample_shape, statistic) check_funsor(actual_stat, expected_stat.inputs, expected_stat.output) - if expected_stat.data.isnan().all(): + if ops.isnan(expected_stat.data).all(): pytest.xfail(reason="base stat returns nan") else: assert_close(actual_stat.reduce(ops.add), expected_stat.reduce(ops.add), atol=atol, rtol=None) @@ -396,7 +396,7 @@ def test_generic_grads_sample(case, statistic, sample_shape): def _get_stat_diff_fn(raw_dist): with xfail_if_not_implemented(msg="entropy not implemented for some distributions"): actual_stat, expected_stat = _get_stat(raw_dist, sample_shape, statistic) - if expected_stat.data.isnan().all(): + if ops.isnan(expected_stat.data).all(): pytest.xfail(reason="base stat returns nan") return to_data((actual_stat - expected_stat).reduce(ops.add).sum()) From bdff5d45143feae4b1d2fd3203c1eb08af0b81cf Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 3 Nov 2020 18:52:00 -0500 Subject: [PATCH 29/38] compatible nomial test cases for both backends --- test/test_distribution_generic.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index b1321a046..f04dfa8dd 
100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -89,7 +89,7 @@ def __hash__(self): # Binomial TEST_CASES += [DistTestCase( "backend_dist.Binomial(total_count=case.total_count, probs=case.probs)", - (("total_count", "10"), ("probs", f"rand({batch_shape})")), + (("total_count", "randint(10, 12, ())" if get_backend() == "jax" else "5"), ("probs", f"rand({batch_shape})")), funsor.Real, )] @@ -136,7 +136,7 @@ def __hash__(self): for event_shape in [(1,), (4,)]: TEST_CASES += [DistTestCase( "backend_dist.DirichletMultinomial(case.concentration, case.total_count)", - (("concentration", f"rand({batch_shape + event_shape})"), ("total_count", "10")), + (("concentration", f"rand({batch_shape + event_shape})"), ("total_count", "randint(10, 12, ())")), funsor.Reals[event_shape], )] @@ -157,7 +157,8 @@ def __hash__(self): for event_shape in [(1,), (4,)]: TEST_CASES += [DistTestCase( "backend_dist.Multinomial(case.total_count, probs=case.probs)", - (("total_count", "5"), ("probs", f"rand({batch_shape + event_shape})")), + (("total_count", "randint(5, 7, ())" if get_backend() == "jax" else "5"), + ("probs", f"rand({batch_shape + event_shape})")), funsor.Reals[event_shape], )] From 3d2ff83c533d7fd59889637f861ad17e382d83f3 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 3 Nov 2020 18:57:18 -0500 Subject: [PATCH 30/38] deterministic tests pass for both backends --- funsor/jax/distributions.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/funsor/jax/distributions.py b/funsor/jax/distributions.py index d040aa3e9..9557d7972 100644 --- a/funsor/jax/distributions.py +++ b/funsor/jax/distributions.py @@ -183,9 +183,6 @@ def categorical_to_funsor(numpyro_dist, output=None, dim_to_name=None): @to_funsor.register(dist.CategoricalProbs) -# XXX: in Pyro backend, we always convert pyro.distributions.Categorical -# to funsor.torch.distributions.Categorical -@to_funsor.register(dist.CategoricalLogits) def categorical_to_funsor(numpyro_dist, output=None, dim_to_name=None): new_pyro_dist = _NumPyroWrapper_Categorical(probs=numpyro_dist.probs) return backenddist_to_funsor(new_pyro_dist, output, dim_to_name) From f46b22b252a4ffbf377e73ee5d044e557eb50265 Mon Sep 17 00:00:00 2001 From: Eli Date: Tue, 3 Nov 2020 19:17:57 -0500 Subject: [PATCH 31/38] add environment variable to disable Monte Carlo tests --- test/test_distribution_generic.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index f04dfa8dd..1445264dc 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -1,6 +1,7 @@ # Copyright Contributors to the Pyro project. 
From f46b22b252a4ffbf377e73ee5d044e557eb50265 Mon Sep 17 00:00:00 2001
From: Eli
Date: Tue, 3 Nov 2020 19:17:57 -0500
Subject: [PATCH 31/38] add environment variable to disable Monte Carlo tests

---
 test/test_distribution_generic.py | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py
index f04dfa8dd..1445264dc 100644
--- a/test/test_distribution_generic.py
+++ b/test/test_distribution_generic.py
@@ -1,6 +1,7 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
 
+import os
 from collections import OrderedDict
 from importlib import import_module
 
@@ -16,6 +17,9 @@
 from funsor.testing import assert_close, check_funsor, rand, randint, randn, random_scale_tril, xfail_if_not_implemented  # noqa: F401,E501
 from funsor.util import get_backend
+
+_ENABLE_MC_DIST_TESTS = int(os.environ.get("FUNSOR_ENABLE_MC_DIST_TESTS", 0))
+
 pytestmark = pytest.mark.skipif(get_backend() == "numpy",
                                 reason="numpy does not have distributions backend")
 
 if get_backend() != "numpy":
@@ -45,8 +49,9 @@ def __init__(self, raw_dist, raw_params, expected_value_domain):
         self.raw_params = raw_params
         self.expected_value_domain = expected_value_domain
         for name, raw_param in self.raw_params:
-            # we need direct access to these tensors for gradient tests
-            setattr(self, name, eval(raw_param))
+            if get_backend() != "numpy":
+                # we need direct access to these tensors for gradient tests
+                setattr(self, name, eval(raw_param))
 
     def __str__(self):
         return self.raw_dist + " " + str(self.raw_params)
@@ -325,6 +330,7 @@ def test_generic_sample(case, sample_shape):
     rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32)
     sample_value = funsor_dist.sample(frozenset(['value']), sample_inputs, rng_key=rng_key)
     expected_inputs = OrderedDict(tuple(sample_inputs.items()) + tuple(funsor_dist.inputs.items()))
+    # TODO compare sample values on jax backend
     check_funsor(sample_value, expected_inputs, funsor.Real)
 
 
@@ -334,7 +340,7 @@ def test_generic_sample(case, sample_shape):
     "variance",
     pytest.param("entropy", marks=[pytest.mark.skipif(get_backend() == "jax", reason="entropy not implemented")])
 ])
-def test_generic_stats_smoke(case, statistic):
+def test_generic_stats(case, statistic):
 
     raw_dist = eval(case.raw_dist)
 
@@ -358,8 +364,9 @@ def test_generic_stats_smoke(case, statistic):
         assert_close(to_data(actual_stat, name_to_dim), to_data(expected_stat, name_to_dim), rtol=1e-4)
 
 
+@pytest.mark.skipif(not _ENABLE_MC_DIST_TESTS, reason="slow and finicky Monte Carlo tests disabled by default")
 @pytest.mark.parametrize("case", TEST_CASES, ids=str)
-@pytest.mark.parametrize("sample_shape", [(), (200000,), (400, 400)], ids=str)
+@pytest.mark.parametrize("sample_shape", [(200000,), (400, 400)], ids=str)
 @pytest.mark.parametrize("statistic", [
     "mean",
     "variance",
@@ -380,7 +387,8 @@ def test_generic_stats_sample(case, statistic, sample_shape):
         assert_close(actual_stat.reduce(ops.add), expected_stat.reduce(ops.add), atol=atol, rtol=None)
 
 
-@pytest.mark.skipif(get_backend() != "torch", reason="not working yet")
+@pytest.mark.skipif(not _ENABLE_MC_DIST_TESTS, reason="slow and finicky Monte Carlo tests disabled by default")
+@pytest.mark.skipif(get_backend() != "torch", reason="not working yet on jax")
 @pytest.mark.parametrize("case", TEST_CASES, ids=str)
 @pytest.mark.parametrize("sample_shape", [(200000,), (400, 400)], ids=str)
 @pytest.mark.parametrize("statistic", [
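
Note on the gate introduced in patch 31: any nonzero integer value of
FUNSOR_ENABLE_MC_DIST_TESTS opts in to the slow Monte Carlo tests, e.g.
FUNSOR_ENABLE_MC_DIST_TESTS=1 pytest -v test/test_distribution_generic.py.
A minimal sketch of the gate's semantics (not part of the patch):

    import os

    enabled = int(os.environ.get("FUNSOR_ENABLE_MC_DIST_TESTS", 0))
    skip_monte_carlo = not enabled  # mirrors @pytest.mark.skipif(not _ENABLE_MC_DIST_TESTS, ...)
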
From e35ad5ca01996b5534ce2a38d8a9158b2fef3577 Mon Sep 17 00:00:00 2001
From: Eli
Date: Tue, 3 Nov 2020 19:27:22 -0500
Subject: [PATCH 32/38] deduplicate random_scale_tril

---
 test/test_distribution.py | 18 +++++-------------
 1 file changed, 5 insertions(+), 13 deletions(-)

diff --git a/test/test_distribution.py b/test/test_distribution.py
index f789414b4..0f673957e 100644
--- a/test/test_distribution.py
+++ b/test/test_distribution.py
@@ -19,7 +19,8 @@
 from funsor.interpreter import interpretation, reinterpret
 from funsor.tensor import Einsum, Tensor, numeric_array, stack
 from funsor.terms import Independent, Variable, eager, lazy, to_funsor
-from funsor.testing import assert_close, check_funsor, rand, randint, randn, random_mvn, random_tensor, xfail_param
+from funsor.testing import assert_close, check_funsor, rand, randint, randn, \
+    random_mvn, random_scale_tril, random_tensor, xfail_param
 from funsor.util import get_backend
 
 pytestmark = pytest.mark.skipif(get_backend() == "numpy",
@@ -472,15 +473,6 @@ def test_mvn_defaults():
     assert dist.MultivariateNormal(loc, scale_tril) is dist.MultivariateNormal(loc, scale_tril, value)
 
 
-def _random_scale_tril(shape):
-    if get_backend() == "torch":
-        data = randn(shape)
-        return backend_dist.transforms.transform_to(backend_dist.constraints.lower_cholesky)(data)
-    else:
-        data = randn(shape[:-2] + (shape[-1] * (shape[-1] + 1) // 2,))
-        return backend_dist.biject_to(backend_dist.constraints.lower_cholesky)(data)
-
-
 @pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
 def test_mvn_density(batch_shape):
     batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
@@ -493,7 +485,7 @@ def mvn(loc: Reals[3], scale_tril: Reals[3, 3], value: Reals[3]) -> Real:
     check_funsor(mvn, {'loc': Reals[3], 'scale_tril': Reals[3, 3], 'value': Reals[3]}, Real)
 
     loc = Tensor(randn(batch_shape + (3,)), inputs)
-    scale_tril = Tensor(_random_scale_tril(batch_shape + (3, 3)), inputs)
+    scale_tril = Tensor(random_scale_tril(batch_shape + (3, 3)), inputs)
     value = Tensor(randn(batch_shape + (3,)), inputs)
     expected = mvn(loc, scale_tril, value)
     check_funsor(expected, inputs, Real)
@@ -509,7 +501,7 @@ def test_mvn_gaussian(batch_shape):
     inputs = OrderedDict((k, Bint[v]) for k, v in zip(batch_dims, batch_shape))
 
     loc = Tensor(randn(batch_shape + (3,)), inputs)
-    scale_tril = Tensor(_random_scale_tril(batch_shape + (3, 3)), inputs)
+    scale_tril = Tensor(random_scale_tril(batch_shape + (3, 3)), inputs)
     value = Tensor(randn(batch_shape + (3,)), inputs)
 
     expected = dist.MultivariateNormal(loc, scale_tril, value)
@@ -808,7 +800,7 @@ def test_mvn_sample(with_lazy, batch_shape, sample_inputs, event_shape):
     inputs = OrderedDict((k, Bint[v]) for k, v in zip(batch_dims, batch_shape))
 
     loc = randn(batch_shape + event_shape)
-    scale_tril = _random_scale_tril(batch_shape + event_shape * 2)
+    scale_tril = random_scale_tril(batch_shape + event_shape * 2)
 
     funsor_dist_class = dist.MultivariateNormal
     params = (loc, scale_tril)

From 44e77831bbfc9e3b55adfe936bc539e0d3428bb0 Mon Sep 17 00:00:00 2001
From: Eli
Date: Wed, 4 Nov 2020 00:44:57 -0500
Subject: [PATCH 33/38] tweak tolerance on binomial

---
 test/test_distribution.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/test_distribution.py b/test/test_distribution.py
index 0f673957e..4d02de93f 100644
--- a/test/test_distribution.py
+++ b/test/test_distribution.py
@@ -885,7 +885,7 @@ def test_binomial_sample(with_lazy, batch_shape, sample_inputs):
     funsor_dist_class = dist.Binomial
     params = (total_count, probs)
 
-    _check_sample(funsor_dist_class, params, sample_inputs, inputs, atol=2e-2, skip_grad=True, with_lazy=with_lazy)
+    _check_sample(funsor_dist_class, params, sample_inputs, inputs, atol=5e-2, skip_grad=True, with_lazy=with_lazy)
 
 
 @pytest.mark.parametrize('sample_inputs', [(), ('ii',), ('ii', 'jj'), ('ii', 'jj', 'kk')])
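
For reference, the local helper deleted in patch 32 built a lower-Cholesky
factor differently per backend; the shared funsor.testing.random_scale_tril
it is replaced with presumably behaves the same way. In the deleted version,
backend_dist, randn, and get_backend are the test module's existing helpers:

    def random_scale_tril(shape):
        if get_backend() == "torch":
            # torch: transform an unconstrained square matrix to lower-Cholesky
            data = randn(shape)
            return backend_dist.transforms.transform_to(backend_dist.constraints.lower_cholesky)(data)
        else:
            # jax: biject a packed vector of n*(n+1)/2 free entries
            data = randn(shape[:-2] + (shape[-1] * (shape[-1] + 1) // 2,))
            return backend_dist.biject_to(backend_dist.constraints.lower_cholesky)(data)
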
From 4f3268614bf689739ad71039a3eb7002f632c195 Mon Sep 17 00:00:00 2001
From: Eli
Date: Thu, 5 Nov 2020 13:49:03 -0500
Subject: [PATCH 34/38] try splitting up make test command to unbreak travis

---
 Makefile | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/Makefile b/Makefile
index 46bd0306c..9359e11d4 100644
--- a/Makefile
+++ b/Makefile
@@ -20,7 +20,9 @@ format: FORCE
 
 test: lint FORCE
 ifeq (${FUNSOR_BACKEND}, torch)
-	pytest -v -n auto test/
+	pytest -v -n auto test/ --ignore=test/test_distribution.py --ignore=test/test_distribution_generic.py
+	pytest -v -n auto test/test_distribution.py
+	pytest -v -n auto test/test_distribution_generic.py
 	FUNSOR_DEBUG=1 pytest -v test/test_gaussian.py
 	FUNSOR_USE_TCO=1 pytest -v test/test_terms.py
 	FUNSOR_USE_TCO=1 pytest -v test/test_einsum.py
@@ -43,7 +45,9 @@ ifeq (${FUNSOR_BACKEND}, torch)
 	python examples/sensor.py --seed=0 --num-frames=2 -n 1
 	@echo PASS
 else ifeq (${FUNSOR_BACKEND}, jax)
-	pytest -v -n auto --ignore=test/examples --ignore=test/pyro --ignore=test/pyroapi
+	pytest -v -n auto --ignore=test/examples --ignore=test/pyro --ignore=test/pyroapi --ignore=test/test_distribution.py --ignore=test/test_distribution_generic.py
+	pytest -v -n auto test/test_distribution.py
+	pytest -v -n auto test/test_distribution_generic.py
 	@echo PASS
 else
 	# default backend

From a0dfc669bf9a7cbb9b3c517d40ff2619534361f2 Mon Sep 17 00:00:00 2001
From: Eli
Date: Thu, 5 Nov 2020 13:49:42 -0500
Subject: [PATCH 35/38] nit

---
 Makefile | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/Makefile b/Makefile
index 9359e11d4..20e6681c1 100644
--- a/Makefile
+++ b/Makefile
@@ -20,9 +20,7 @@ format: FORCE
 
 test: lint FORCE
 ifeq (${FUNSOR_BACKEND}, torch)
-	pytest -v -n auto test/ --ignore=test/test_distribution.py --ignore=test/test_distribution_generic.py
-	pytest -v -n auto test/test_distribution.py
-	pytest -v -n auto test/test_distribution_generic.py
+	pytest -v -n auto test/
 	FUNSOR_DEBUG=1 pytest -v test/test_gaussian.py
 	FUNSOR_USE_TCO=1 pytest -v test/test_terms.py
 	FUNSOR_USE_TCO=1 pytest -v test/test_einsum.py
("concentration0", f"ops.exp(randn({batch_shape}))")), funsor.Real, - )] + ) # NonreparameterizedBeta - TEST_CASES += [DistTestCase( + DistTestCase( "FAKES.NonreparameterizedBeta(case.concentration1, case.concentration0)", (("concentration1", f"ops.exp(randn({batch_shape}))"), ("concentration0", f"ops.exp(randn({batch_shape}))")), funsor.Real, - )] + ) # Binomial - TEST_CASES += [DistTestCase( + DistTestCase( "backend_dist.Binomial(total_count=case.total_count, probs=case.probs)", (("total_count", "randint(10, 12, ())" if get_backend() == "jax" else "5"), ("probs", f"rand({batch_shape})")), funsor.Real, - )] + ) # CategoricalLogits for size in [2, 4]: - TEST_CASES += [DistTestCase( + DistTestCase( "backend_dist.Categorical(logits=case.logits)", (("logits", f"rand({batch_shape + (size,)})"),), funsor.Bint[size], - )] + ) # CategoricalProbs for size in [2, 4]: - TEST_CASES += [DistTestCase( + DistTestCase( "backend_dist.Categorical(probs=case.probs)", (("probs", f"rand({batch_shape + (size,)})"),), funsor.Bint[size], - )] + ) # TODO figure out what this should be... # # Delta @@ -125,82 +127,82 @@ def __hash__(self): # Dirichlet for event_shape in [(1,), (4,)]: - TEST_CASES += [DistTestCase( + DistTestCase( "backend_dist.Dirichlet(case.concentration)", (("concentration", f"rand({batch_shape + event_shape})"),), funsor.Reals[event_shape], - )] + ) # NonreparameterizedDirichlet - TEST_CASES += [DistTestCase( + DistTestCase( "FAKES.NonreparameterizedDirichlet(case.concentration)", (("concentration", f"rand({batch_shape + event_shape})"),), funsor.Reals[event_shape], - )] + ) # DirichletMultinomial for event_shape in [(1,), (4,)]: - TEST_CASES += [DistTestCase( + DistTestCase( "backend_dist.DirichletMultinomial(case.concentration, case.total_count)", (("concentration", f"rand({batch_shape + event_shape})"), ("total_count", "randint(10, 12, ())")), funsor.Reals[event_shape], - )] + ) # Gamma - TEST_CASES += [DistTestCase( + DistTestCase( "backend_dist.Gamma(case.concentration, case.rate)", (("concentration", f"rand({batch_shape})"), ("rate", f"rand({batch_shape})")), funsor.Real, - )] + ) # NonreparametrizedGamma - TEST_CASES += [DistTestCase( + DistTestCase( "FAKES.NonreparameterizedGamma(case.concentration, case.rate)", (("concentration", f"rand({batch_shape})"), ("rate", f"rand({batch_shape})")), funsor.Real, - )] + ) # Multinomial for event_shape in [(1,), (4,)]: - TEST_CASES += [DistTestCase( + DistTestCase( "backend_dist.Multinomial(case.total_count, probs=case.probs)", (("total_count", "randint(5, 7, ())" if get_backend() == "jax" else "5"), ("probs", f"rand({batch_shape + event_shape})")), funsor.Reals[event_shape], - )] + ) # MultivariateNormal for event_shape in [(1,), (3,)]: - TEST_CASES += [DistTestCase( + DistTestCase( "backend_dist.MultivariateNormal(loc=case.loc, scale_tril=case.scale_tril)", (("loc", f"randn({batch_shape + event_shape})"), ("scale_tril", f"random_scale_tril({batch_shape + event_shape * 2})")), # noqa: E501 funsor.Reals[event_shape], - )] + ) # Normal - TEST_CASES += [DistTestCase( + DistTestCase( "backend_dist.Normal(case.loc, case.scale)", (("loc", f"randn({batch_shape})"), ("scale", f"rand({batch_shape})")), funsor.Real, - )] + ) # NonreparameterizedNormal - TEST_CASES += [DistTestCase( + DistTestCase( "FAKES.NonreparameterizedNormal(case.loc, case.scale)", (("loc", f"randn({batch_shape})"), ("scale", f"rand({batch_shape})")), funsor.Real, - )] + ) # Poisson - TEST_CASES += [DistTestCase( + DistTestCase( "backend_dist.Poisson(rate=case.rate)", (("rate", 
f"rand({batch_shape})"),), funsor.Real, - )] + ) # VonMises - TEST_CASES += [DistTestCase( + DistTestCase( "backend_dist.VonMises(case.loc, case.concentration)", (("loc", f"rand({batch_shape})"), ("concentration", f"rand({batch_shape})")), funsor.Real, - )] + ) ########################### From c408c7164c7f28d93b9d7dab489be00dc5e25fb6 Mon Sep 17 00:00:00 2001 From: Eli Date: Wed, 11 Nov 2020 12:13:00 -0500 Subject: [PATCH 37/38] Remove sample/gradient tests from this PR --- test/test_distribution_generic.py | 105 ------------------------------ 1 file changed, 105 deletions(-) diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py index 0aa1a8690..644ee6fca 100644 --- a/test/test_distribution_generic.py +++ b/test/test_distribution_generic.py @@ -221,42 +221,6 @@ def _default_dim_to_name(inputs_shape, event_inputs=None): return dim_to_name, name_to_dim -def _get_stat(raw_dist, sample_shape, statistic): - dim_to_name, name_to_dim = _default_dim_to_name(sample_shape + raw_dist.batch_shape) - with interpretation(lazy): - funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name) - - sample_inputs = OrderedDict((dim_to_name[dim - len(raw_dist.batch_shape)], funsor.Bint[sample_shape[dim]]) - for dim in range(-len(sample_shape), 0)) - rng_key = None if get_backend() == "torch" else np.array([0, 0], dtype=np.uint32) - sample_value = funsor_dist.sample(frozenset(['value']), sample_inputs, rng_key=rng_key) - expected_inputs = OrderedDict(tuple(sample_inputs.items()) + tuple(funsor_dist.inputs.items())) - check_funsor(sample_value, expected_inputs, funsor.Real) - - expected_stat = getattr(funsor_dist, statistic)() - if statistic == "mean": - actual_stat = Integrate( - sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value']) - ).reduce(ops.add, frozenset(sample_inputs)) - elif statistic == "variance": - actual_mean = Integrate( - sample_value, Variable('value', funsor_dist.inputs['value']), frozenset(['value']) - ).reduce(ops.add, frozenset(sample_inputs)) - actual_stat = Integrate( - sample_value, - (Variable('value', funsor_dist.inputs['value']) - actual_mean) ** 2, - frozenset(['value']) - ).reduce(ops.add, frozenset(sample_inputs)) - elif statistic == "entropy": - actual_stat = -Integrate( - sample_value, funsor_dist, frozenset(['value']) - ).reduce(ops.add, frozenset(sample_inputs)) - else: - raise ValueError("invalid test statistic: {}".format(statistic)) - - return actual_stat, expected_stat - - @pytest.mark.parametrize("case", TEST_CASES, ids=str) def test_generic_distribution_to_funsor(case): @@ -364,72 +328,3 @@ def test_generic_stats(case, statistic): pytest.xfail(reason="base stat returns nan") else: assert_close(to_data(actual_stat, name_to_dim), to_data(expected_stat, name_to_dim), rtol=1e-4) - - -@pytest.mark.skipif(not _ENABLE_MC_DIST_TESTS, reason="slow and finicky Monte Carlo tests disabled by default") -@pytest.mark.parametrize("case", TEST_CASES, ids=str) -@pytest.mark.parametrize("sample_shape", [(200000,), (400, 400)], ids=str) -@pytest.mark.parametrize("statistic", [ - "mean", - "variance", - pytest.param("entropy", marks=[pytest.mark.skipif(get_backend() == "jax", reason="entropy not implemented")]) -]) -def test_generic_stats_sample(case, statistic, sample_shape): - - raw_dist = eval(case.raw_dist) - - atol = 1e-2 - with xfail_if_not_implemented(msg="entropy not implemented for some distributions"): - actual_stat, expected_stat = _get_stat(raw_dist, sample_shape, statistic) - - check_funsor(actual_stat, 
From 7227e19dd62c67ae113ad5658585d46b65bac1c6 Mon Sep 17 00:00:00 2001
From: Eli
Date: Wed, 11 Nov 2020 12:13:45 -0500
Subject: [PATCH 38/38] lint

---
 test/test_distribution_generic.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/test/test_distribution_generic.py b/test/test_distribution_generic.py
index 644ee6fca..356f109e4 100644
--- a/test/test_distribution_generic.py
+++ b/test/test_distribution_generic.py
@@ -11,9 +11,8 @@
 import funsor
 import funsor.ops as ops
 from funsor.distribution import BACKEND_TO_DISTRIBUTIONS_BACKEND
-from funsor.integrate import Integrate
 from funsor.interpreter import interpretation
-from funsor.terms import Variable, lazy, to_data, to_funsor
+from funsor.terms import lazy, to_data, to_funsor
 from funsor.testing import assert_close, check_funsor, rand, randint, randn, random_scale_tril, xfail_if_not_implemented  # noqa: F401,E501
 from funsor.util import get_backend
 