Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Automate distribution testing #389

Merged
merged 40 commits into from
Nov 11, 2020
Merged
Changes from 1 commit
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
0f904b9
add generic stat methods to distribution
eb8680 Oct 29, 2020
3ad5c79
start some generic distribution test functions
eb8680 Oct 30, 2020
4fbc801
more cleanup of _check_sample
eb8680 Oct 30, 2020
ecad0ee
move generic harness into separate file for now
eb8680 Oct 30, 2020
0c2c0fb
work up to actual tests
eb8680 Oct 30, 2020
56ace75
lint
eb8680 Oct 30, 2020
bace658
nits
eb8680 Oct 30, 2020
c38d560
more nits
eb8680 Oct 30, 2020
040f417
make case.raw_dist a string and eval it
eb8680 Oct 30, 2020
4f40aaf
nit
eb8680 Oct 30, 2020
fb75a15
add a bunch of test cases, most passing
eb8680 Oct 30, 2020
e6dcabd
rename file
eb8680 Oct 30, 2020
8be24c6
gradient test at least runs for pytorch backend...
eb8680 Oct 30, 2020
50f0343
break up sample test into smoke and stats
eb8680 Oct 30, 2020
6f69e62
remove with_lazy
eb8680 Oct 30, 2020
a06744e
get basic jax tests passing
eb8680 Oct 31, 2020
e01de62
Merge branch 'master' into distribution-test-harness
eb8680 Oct 31, 2020
5a8fd42
alphabetize test cases and fix incorrect mvn spec
eb8680 Oct 31, 2020
127b12c
break up stat test
eb8680 Oct 31, 2020
80750b0
binomial test case
eb8680 Oct 31, 2020
3ceadfd
add more test cases
eb8680 Oct 31, 2020
151ca2b
add dirichletmultinomial test case for parity
eb8680 Oct 31, 2020
06f6d06
add CategoricalLogits test case
eb8680 Oct 31, 2020
d1e8af5
custom conversion for multinomial
eb8680 Oct 31, 2020
39ced3a
use get_raw_dist in sample, add appropriate xfails
eb8680 Oct 31, 2020
35d283e
fix error in gradient test
eb8680 Oct 31, 2020
34919a4
simplify binomial test case, remove delta test cases, weaken mc grads…
eb8680 Oct 31, 2020
bc0a518
fix bug in jax binomial conversion
eb8680 Nov 3, 2020
81f8a37
add ops.isnan and fix jax multinomial conversion
eb8680 Nov 3, 2020
bdff5d4
compatible nomial test cases for both backends
eb8680 Nov 3, 2020
3d2ff83
deterministic tests pass for both backends
eb8680 Nov 3, 2020
f46b22b
add environment variable to disable Monte Carlo tests
eb8680 Nov 4, 2020
e35ad5c
deduplicate random_scale_tril
eb8680 Nov 4, 2020
44e7783
tweak tolerance on binomial
eb8680 Nov 4, 2020
4de34b6
Merge branch 'master' into distribution-test-harness
eb8680 Nov 5, 2020
4f32686
try splitting up make test command to unbreak travis
eb8680 Nov 5, 2020
a0dfc66
nit
eb8680 Nov 5, 2020
ce6465a
automatically append DistTestCase instances to TEST_CASE list
eb8680 Nov 5, 2020
c408c71
Remove sample/gradient tests from this PR
eb8680 Nov 11, 2020
7227e19
lint
eb8680 Nov 11, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
start some generic distribution test functions
eb8680 committed Oct 30, 2020
commit 3ad5c7997e61f70a252c46b3d4fc45cd42d3e68c
55 changes: 53 additions & 2 deletions test/test_distribution.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@

import functools
import math
from collections import OrderedDict
from collections import OrderedDict, namedtuple
from importlib import import_module

import numpy as np
@@ -728,7 +728,7 @@ def _get_stat_diff(funsor_dist_class, sample_inputs, inputs, num_samples, statis
return diff.sum(), diff


def _check_sample(funsor_dist_class, params, sample_inputs, inputs, atol=1e-2,
def _check_sample(raw_dist, sample_shape=(), atol=1e-2,
num_samples=100000, statistic="mean", skip_grad=False, with_lazy=None):
"""utility that compares a Monte Carlo estimate of a distribution mean with the true mean"""
samples_per_dim = int(num_samples ** (1./max(1, len(sample_inputs))))
@@ -766,6 +766,57 @@ def _check_sample(funsor_dist_class, params, sample_inputs, inputs, atol=1e-2,
_get_stat_diff_fn(params)


def _check_distribution_to_funsor(raw_dist, expected_value_domain):

# TODO specify dim_to_name
funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
actual_dist = to_data(funsor_dist, name_to_dim=name_to_dim)

assert isinstance(actual_dist, backend_dist.Distribution)
assert type(raw_dist) == type(actual_dist)
assert funsor_dist.inputs["value"] == expected_value_domain
for param_name in funsor_dist.params.keys():
if param_name == "value":
continue
assert hasattr(raw_dist, param_name)
assert_close(getattr(actual_dist, param_name), getattr(raw_dist, param_name))


def _check_log_prob(raw_dist, expected_value_domain):

# TODO specify dim_to_name
funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)
expected_inputs = {name: raw_dist.batch_shape[dim] for dim, name in dim_to_name.items()}
expected_inputs.update({"value": expected_value_domain})

check_funsor(funsor_dist, expected_inputs, funsor.Real)

# TODO handle JAX rng here
expected_logprob = to_funsor(raw_dist.log_prob(raw_dist.sample()), output=funsor.Real, dim_to_name=dim_to_name)
assert_close(funsor_dist(value=value), expected_logprob)


def _check_enumerate_support(raw_dist, expand=False):

# TODO specify dim_to_name
funsor_dist = to_funsor(raw_dist, output=funsor.Real, dim_to_name=dim_to_name)

assert getattr(raw_dist, "has_enumerate_support", False) == funsor_dist.has_enumerate_support
if funsor_dist.has_enumerate_support:
raw_support = raw_dist.enumerate_support(expand=expand)
funsor_support = funsor_dist.enumerate_support(expand=expand)
assert_equal(to_data(funsor_support, name_to_dim=name_to_dim), raw_support)


# High-level distribution testing strategy: a fixed sequence of increasingly semantically strong distribution-agnostic tests
# conversion invertibility -> density type and value -> enumerate_support type and value -> statistic types and values -> samplers -> gradients
DistTestCase = namedtuple("DistTestCase", ["raw_dist", "expected_value_domain"])

TEST_CASES = [
DistTestCase(raw_dist=...),
]


@pytest.mark.parametrize('sample_inputs', [(), ('ii',), ('ii', 'jj'), ('ii', 'jj', 'kk')])
@pytest.mark.parametrize('batch_shape', [(), (5,), (2, 3)], ids=str)
@pytest.mark.parametrize('reparametrized', [True, False])