Automate distribution testing #389

Merged: 40 commits (Nov 11, 2020)

Commits
0f904b9  add generic stat methods to distribution (eb8680, Oct 29, 2020)
3ad5c79  start some generic distribution test functions (eb8680, Oct 30, 2020)
4fbc801  more cleanup of _check_sample (eb8680, Oct 30, 2020)
ecad0ee  move generic harness into separate file for now (eb8680, Oct 30, 2020)
0c2c0fb  work up to actual tests (eb8680, Oct 30, 2020)
56ace75  lint (eb8680, Oct 30, 2020)
bace658  nits (eb8680, Oct 30, 2020)
c38d560  more nits (eb8680, Oct 30, 2020)
040f417  make case.raw_dist a string and eval it (eb8680, Oct 30, 2020)
4f40aaf  nit (eb8680, Oct 30, 2020)
fb75a15  add a bunch of test cases, most passing (eb8680, Oct 30, 2020)
e6dcabd  rename file (eb8680, Oct 30, 2020)
8be24c6  gradient test at least runs for pytorch backend... (eb8680, Oct 30, 2020)
50f0343  break up sample test into smoke and stats (eb8680, Oct 30, 2020)
6f69e62  remove with_lazy (eb8680, Oct 30, 2020)
a06744e  get basic jax tests passing (eb8680, Oct 31, 2020)
e01de62  Merge branch 'master' into distribution-test-harness (eb8680, Oct 31, 2020)
5a8fd42  alphabetize test cases and fix incorrect mvn spec (eb8680, Oct 31, 2020)
127b12c  break up stat test (eb8680, Oct 31, 2020)
80750b0  binomial test case (eb8680, Oct 31, 2020)
3ceadfd  add more test cases (eb8680, Oct 31, 2020)
151ca2b  add dirichletmultinomial test case for parity (eb8680, Oct 31, 2020)
06f6d06  add CategoricalLogits test case (eb8680, Oct 31, 2020)
d1e8af5  custom conversion for multinomial (eb8680, Oct 31, 2020)
39ced3a  use get_raw_dist in sample, add appropriate xfails (eb8680, Oct 31, 2020)
35d283e  fix error in gradient test (eb8680, Oct 31, 2020)
34919a4  simplify binomial test case, remove delta test cases, weaken mc grads… (eb8680, Oct 31, 2020)
bc0a518  fix bug in jax binomial conversion (eb8680, Nov 3, 2020)
81f8a37  add ops.isnan and fix jax multinomial conversion (eb8680, Nov 3, 2020)
bdff5d4  compatible nomial test cases for both backends (eb8680, Nov 3, 2020)
3d2ff83  deterministic tests pass for both backends (eb8680, Nov 3, 2020)
f46b22b  add environment variable to disable Monte Carlo tests (eb8680, Nov 4, 2020)
e35ad5c  deduplicate random_scale_tril (eb8680, Nov 4, 2020)
44e7783  tweak tolerance on binomial (eb8680, Nov 4, 2020)
4de34b6  Merge branch 'master' into distribution-test-harness (eb8680, Nov 5, 2020)
4f32686  try splitting up make test command to unbreak travis (eb8680, Nov 5, 2020)
a0dfc66  nit (eb8680, Nov 5, 2020)
ce6465a  automatically append DistTestCase instances to TEST_CASE list (eb8680, Nov 5, 2020)
c408c71  Remove sample/gradient tests from this PR (eb8680, Nov 11, 2020)
7227e19  lint (eb8680, Nov 11, 2020)
4 changes: 3 additions & 1 deletion Makefile
@@ -43,7 +43,9 @@ ifeq (${FUNSOR_BACKEND}, torch)
python examples/sensor.py --seed=0 --num-frames=2 -n 1
@echo PASS
else ifeq (${FUNSOR_BACKEND}, jax)
pytest -v -n auto --ignore=test/examples --ignore=test/pyro --ignore=test/pyroapi
pytest -v -n auto --ignore=test/examples --ignore=test/pyro --ignore=test/pyroapi --ignore=test/test_distribution.py --ignore=test/test_distribution_generic.py
pytest -v -n auto test/test_distribution.py
pytest -v -n auto test/test_distribution_generic.py
@echo PASS
else
# default backend
25 changes: 15 additions & 10 deletions funsor/distribution.py
@@ -134,23 +134,28 @@ def unscaled_sample(self, sampled_vars, sample_inputs, rng_key=None):
value = params.pop("value")
assert all(isinstance(v, (Number, Tensor)) for v in params.values())
assert isinstance(value, Variable) and value.name in sampled_vars
inputs_, tensors = align_tensors(*params.values())
inputs = OrderedDict(sample_inputs.items())
inputs.update(inputs_)
sample_shape = tuple(v.size for v in sample_inputs.values())

raw_dist = self.dist_class(**dict(zip(self._ast_fields[:-1], tensors)))
value_name = value.name
raw_dist, value_output, dim_to_name = self._get_raw_dist()
for d, name in zip(range(len(sample_inputs), 0, -1), sample_inputs.keys()):
dim_to_name[-d - len(raw_dist.batch_shape)] = name

sample_shape = tuple(v.size for v in sample_inputs.values())
sample_args = (sample_shape,) if get_backend() == "torch" else (rng_key, sample_shape)
if self.has_rsample:
raw_sample = raw_dist.rsample(*sample_args)
raw_value = raw_dist.rsample(*sample_args)
else:
raw_sample = ops.detach(raw_dist.sample(*sample_args))
raw_value = ops.detach(raw_dist.sample(*sample_args))

result = funsor.delta.Delta(value.name, Tensor(raw_sample, inputs, value.output.dtype))
funsor_value = to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name)

Member Author commented:
I had to update funsor.distribution.Distribution.unscaled_sample to use to_funsor and to_data throughout to get some tests to pass.

funsor_value = funsor_value.align(
tuple(sample_inputs) + tuple(inp for inp in self.inputs if inp in funsor_value.inputs))
result = funsor.delta.Delta(value_name, funsor_value)
if not self.has_rsample:
# scaling of dice_factor by num samples should already be handled by Funsor.sample
raw_log_prob = raw_dist.log_prob(raw_sample)
dice_factor = Tensor(raw_log_prob - ops.detach(raw_log_prob), inputs)
raw_log_prob = raw_dist.log_prob(raw_value)
dice_factor = to_funsor(raw_log_prob - ops.detach(raw_log_prob),
output=self.output, dim_to_name=dim_to_name)
result = result + dice_factor
return result

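For orientation, here is a minimal sketch (not part of the PR; the PyTorch backend, names, and shapes are illustrative) of the dim_to_name convention the rewritten unscaled_sample relies on: new sample dimensions are named to the left of the distribution's existing batch dimensions, and to_funsor then converts the raw draw into a Tensor funsor with those named inputs.

from collections import OrderedDict

import torch
import pyro.distributions as backend_dist

import funsor
from funsor.domains import Real
from funsor.terms import to_funsor

funsor.set_backend("torch")

# A raw backend distribution with one batch dimension, already named "i".
raw_dist = backend_dist.Normal(torch.randn(4), torch.ones(4))   # batch_shape == (4,)
dim_to_name = OrderedDict([(-1, "i")])

# One new sample input, analogous to sample_inputs in unscaled_sample.
sample_inputs = OrderedDict(particle=10)

# Mirror the loop in the diff: sample dims go to the left of the batch dims.
for d, name in zip(range(len(sample_inputs), 0, -1), sample_inputs.keys()):
    dim_to_name[-d - len(raw_dist.batch_shape)] = name   # now {-2: "particle", -1: "i"}

raw_value = raw_dist.rsample((10,))                       # shape (10, 4)
funsor_value = to_funsor(raw_value, output=Real, dim_to_name=dim_to_name)
print(funsor_value.inputs)                                # particle -> Bint[10], i -> Bint[4]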
7 changes: 2 additions & 5 deletions funsor/jax/distributions.py
@@ -175,14 +175,11 @@ def _infer_param_domain(cls, name, raw_shape):
@to_funsor.register(dist.BinomialProbs)
@to_funsor.register(dist.BinomialLogits)
def categorical_to_funsor(numpyro_dist, output=None, dim_to_name=None):
new_pyro_dist = _NumPyroWrapper_Binomial(probs=numpyro_dist.probs)
new_pyro_dist = _NumPyroWrapper_Binomial(total_count=numpyro_dist.total_count, probs=numpyro_dist.probs)
return backenddist_to_funsor(Binomial, new_pyro_dist, output, dim_to_name) # noqa: F821


@to_funsor.register(dist.CategoricalProbs)
# XXX: in Pyro backend, we always convert pyro.distributions.Categorical
# to funsor.torch.distributions.Categorical
@to_funsor.register(dist.CategoricalLogits)
def categorical_to_funsor(numpyro_dist, output=None, dim_to_name=None):
new_pyro_dist = _NumPyroWrapper_Categorical(probs=numpyro_dist.probs)
return backenddist_to_funsor(Categorical, new_pyro_dist, output, dim_to_name) # noqa: F821
@@ -191,7 +188,7 @@ def categorical_to_funsor(numpyro_dist, output=None, dim_to_name=None):
@to_funsor.register(dist.MultinomialProbs)
@to_funsor.register(dist.MultinomialLogits)
def categorical_to_funsor(numpyro_dist, output=None, dim_to_name=None):
new_pyro_dist = _NumPyroWrapper_Multinomial(probs=numpyro_dist.probs)
new_pyro_dist = _NumPyroWrapper_Multinomial(total_count=numpyro_dist.total_count, probs=numpyro_dist.probs)
return backenddist_to_funsor(Multinomial, new_pyro_dist, output, dim_to_name) # noqa: F821


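As a hedged illustration (not part of the diff; toy values), the bug fixed here is that the wrappers were rebuilt from probs alone, and numpyro's binomial and multinomial constructors default to total_count=1, which silently changes the converted distribution:

import jax.numpy as jnp
import numpyro.distributions as dist

original = dist.BinomialProbs(probs=jnp.array(0.3), total_count=10)
rebuilt_without_count = dist.BinomialProbs(probs=original.probs)   # defaults to total_count=1
print(original.total_count, rebuilt_without_count.total_count)     # 10 vs 1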
5 changes: 5 additions & 0 deletions funsor/jax/ops.py
@@ -121,6 +121,11 @@ def _is_numeric_array(x):
return True


@ops.isnan.register(array)
def _isnan(x):
return np.isnan(x)


@ops.lgamma.register(array)
def _lgamma(x):
return gammaln(x)
2 changes: 2 additions & 0 deletions funsor/ops/array.py
@@ -24,6 +24,7 @@
diagonal = Op("diagonal")
einsum = Op("einsum")
full_like = Op(np.full_like)
isnan = Op(np.isnan)
prod = Op(np.prod)
stack = Op("stack")
sum = Op(np.sum)
@@ -300,6 +301,7 @@ def unsqueeze(x, dim):
'finfo',
'full_like',
'is_numeric_array',
'isnan',
'logaddexp',
'logsumexp',
'new_arange',
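A brief usage sketch of the new op (not part of the diff; a plain NumPy array is shown, for which the Op's default np.isnan implementation applies, while the registrations in funsor/jax/ops.py and funsor/torch/ops.py cover jax and torch arrays):

import numpy as np
from funsor import ops

x = np.array([0.0, float("nan"), 2.0])
print(ops.isnan(x))   # [False  True False]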
19 changes: 19 additions & 0 deletions funsor/testing.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import contextlib
import importlib
import itertools
import numbers
import operator
@@ -265,6 +266,24 @@ def randn(*args):
return np.array(np.random.randn(*shape))


def random_scale_tril(*args):
if isinstance(args[0], tuple):
assert len(args) == 1
shape = args[0]
else:
shape = args

from funsor.distribution import BACKEND_TO_DISTRIBUTIONS_BACKEND
backend_dist = importlib.import_module(BACKEND_TO_DISTRIBUTIONS_BACKEND[get_backend()]).dist

if get_backend() == "torch":
data = randn(shape)
return backend_dist.transforms.transform_to(backend_dist.constraints.lower_cholesky)(data)
else:
data = randn(shape[:-2] + (shape[-1] * (shape[-1] + 1) // 2,))
return backend_dist.biject_to(backend_dist.constraints.lower_cholesky)(data)


def zeros(*args):
if isinstance(args[0], tuple):
assert len(args) == 1
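For context, a hedged usage sketch of the now-shared helper (assumes the PyTorch backend; shapes are illustrative): it draws an unconstrained sample and maps it through the backend's lower-Cholesky constraint transform, so the result is a batch of valid scale_tril factors.

import funsor
funsor.set_backend("torch")

from funsor.testing import random_scale_tril

L = random_scale_tril((2, 3, 3))   # batch of two 3x3 lower-triangular factors
assert L.shape == (2, 3, 3)
assert (L.tril() == L).all()       # lower-triangular, with a positive diagonal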
16 changes: 15 additions & 1 deletion funsor/torch/distributions.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

import functools
import numbers
from typing import Tuple, Union

import pyro.distributions as dist
@@ -40,7 +41,7 @@
from funsor.domains import Real, Reals
import funsor.ops as ops
from funsor.tensor import Tensor, dummy_numeric_array
from funsor.terms import Binary, Funsor, Variable, eager, to_funsor
from funsor.terms import Binary, Funsor, Variable, eager, to_data, to_funsor
from funsor.util import methodof


@@ -153,6 +154,19 @@ def _infer_param_domain(cls, name, raw_shape):
return Real


###########################################################
# Converting distribution funsors to PyTorch distributions
###########################################################

@to_data.register(Multinomial) # noqa: F821
def multinomial_to_data(funsor_dist, name_to_dim=None):
probs = to_data(funsor_dist.probs, name_to_dim)
total_count = to_data(funsor_dist.total_count, name_to_dim)
if isinstance(total_count, numbers.Number) or len(total_count.shape) == 0:

Member commented:
Do we need to worry about int(total_count) thwarting PyTorch tracing? Should we preserve scalar Tensors?

Member Author replied:
I made this change because torch.distributions.Multinomial raises an error given a tensor total_count, even if it is a scalar tensor, and the generic conversion tests were failing. It's definitely an issue, but there are already lots of JIT issues with scalars and tuple shapes throughout the codebase (encountered while working on pyro.contrib.funsor), and my inclination with this PR and further work on distributions is to avoid special-casing to the greatest extent possible even if that means deferring to odd implementation details in the backends. I think a better fix in this instance would be to allow scalar tensor total_count upstream so this change could be reverted.

This is also a good reminder to add some generic JIT tests for distribution wrappers in a followup PR.

return dist.Multinomial(int(total_count), probs=probs)
raise NotImplementedError("inhomogeneous total_count not supported")
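
For context, a standalone illustration (not part of the diff) of the backend behavior the discussion above refers to: torch.distributions.Multinomial only accepts a Python int total_count and rejects even a scalar tensor, hence the int(total_count) cast.

import torch
import torch.distributions as td

probs = torch.tensor([0.2, 0.3, 0.5])
td.Multinomial(total_count=5, probs=probs)                    # ok
try:
    td.Multinomial(total_count=torch.tensor(5), probs=probs)  # rejected by the backend
except NotImplementedError as e:                              # torch raises for non-int total_count
    print(e)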


###############################################
# Converting PyTorch Distributions to funsors
###############################################
5 changes: 5 additions & 0 deletions funsor/torch/ops.py
@@ -107,6 +107,11 @@ def _is_numeric_array(x):
return True


@ops.isnan.register(torch.Tensor)
def _isnan(x):
return torch.isnan(x)


@ops.lgamma.register(torch.Tensor)
def _lgamma(x):
return x.lgamma()
20 changes: 6 additions & 14 deletions test/test_distribution.py
@@ -19,7 +19,8 @@
from funsor.interpreter import interpretation, reinterpret
from funsor.tensor import Einsum, Tensor, numeric_array, stack
from funsor.terms import Independent, Variable, eager, lazy, to_funsor
from funsor.testing import assert_close, check_funsor, rand, randint, randn, random_mvn, random_tensor, xfail_param
from funsor.testing import assert_close, check_funsor, rand, randint, randn, \
random_mvn, random_scale_tril, random_tensor, xfail_param
from funsor.util import get_backend

pytestmark = pytest.mark.skipif(get_backend() == "numpy",
@@ -472,15 +473,6 @@ def test_mvn_defaults():
assert dist.MultivariateNormal(loc, scale_tril) is dist.MultivariateNormal(loc, scale_tril, value)


def _random_scale_tril(shape):
if get_backend() == "torch":
data = randn(shape)
return backend_dist.transforms.transform_to(backend_dist.constraints.lower_cholesky)(data)
else:
data = randn(shape[:-2] + (shape[-1] * (shape[-1] + 1) // 2,))
return backend_dist.biject_to(backend_dist.constraints.lower_cholesky)(data)


@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
def test_mvn_density(batch_shape):
batch_dims = ('i', 'j', 'k')[:len(batch_shape)]
@@ -493,7 +485,7 @@ def mvn(loc: Reals[3], scale_tril: Reals[3, 3], value: Reals[3]) -> Real:
check_funsor(mvn, {'loc': Reals[3], 'scale_tril': Reals[3, 3], 'value': Reals[3]}, Real)

loc = Tensor(randn(batch_shape + (3,)), inputs)
scale_tril = Tensor(_random_scale_tril(batch_shape + (3, 3)), inputs)
scale_tril = Tensor(random_scale_tril(batch_shape + (3, 3)), inputs)
value = Tensor(randn(batch_shape + (3,)), inputs)
expected = mvn(loc, scale_tril, value)
check_funsor(expected, inputs, Real)
@@ -509,7 +501,7 @@ def test_mvn_gaussian(batch_shape):
inputs = OrderedDict((k, Bint[v]) for k, v in zip(batch_dims, batch_shape))

loc = Tensor(randn(batch_shape + (3,)), inputs)
scale_tril = Tensor(_random_scale_tril(batch_shape + (3, 3)), inputs)
scale_tril = Tensor(random_scale_tril(batch_shape + (3, 3)), inputs)
value = Tensor(randn(batch_shape + (3,)), inputs)

expected = dist.MultivariateNormal(loc, scale_tril, value)
@@ -808,7 +800,7 @@ def test_mvn_sample(with_lazy, batch_shape, sample_inputs, event_shape):
inputs = OrderedDict((k, Bint[v]) for k, v in zip(batch_dims, batch_shape))

loc = randn(batch_shape + event_shape)
scale_tril = _random_scale_tril(batch_shape + event_shape * 2)
scale_tril = random_scale_tril(batch_shape + event_shape * 2)
funsor_dist_class = dist.MultivariateNormal
params = (loc, scale_tril)

Expand Down Expand Up @@ -893,7 +885,7 @@ def test_binomial_sample(with_lazy, batch_shape, sample_inputs):
funsor_dist_class = dist.Binomial
params = (total_count, probs)

_check_sample(funsor_dist_class, params, sample_inputs, inputs, atol=2e-2, skip_grad=True, with_lazy=with_lazy)
_check_sample(funsor_dist_class, params, sample_inputs, inputs, atol=5e-2, skip_grad=True, with_lazy=with_lazy)


@pytest.mark.parametrize('sample_inputs', [(), ('ii',), ('ii', 'jj'), ('ii', 'jj', 'kk')])