Automate distribution testing #389
Conversation
test/test_distribution_generic.py (outdated)

    "variance",
    pytest.param("entropy", marks=[pytest.mark.skipif(get_backend() == "jax", reason="entropy not implemented")])
])
def test_generic_stats_sample(case, statistic, sample_shape):
This test comparing Monte Carlo estimates of summary statistics with ground-truth values is slow (especially on the JAX backend) and finicky. I've disabled it by default but could also remove it entirely - I'm not sure failures are very informative.
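For context, the check being disabled is essentially of the following kind; the snippet below is only an illustrative stand-in (the distribution, sample size, and tolerances are my own choices, not the PR's):

import torch
import pyro.distributions as dist  # assuming the PyTorch/Pyro backend

d = dist.Beta(torch.tensor(2.0), torch.tensor(3.0))
samples = d.sample((100000,))
# Compare Monte Carlo estimates of summary statistics against ground-truth values.
# Even with 1e5 samples the tolerances have to be fairly loose, which is what
# makes this kind of test finicky.
assert torch.allclose(samples.mean(), d.mean, atol=1e-2)
assert torch.allclose(samples.var(), d.variance, atol=1e-2)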
I removed these tests from this PR.
test/test_distribution_generic.py (outdated)

    "variance",
    pytest.param("entropy", marks=[pytest.mark.skipif(get_backend() == "jax", reason="entropy not implemented")])
])
def test_generic_grads_sample(case, statistic, sample_shape):
Ditto for this test comparing Monte Carlo estimates of gradients of summary statistics wrt parameters versus gradients of ground-truth values - it's slow and finicky, and I'm not sure failures are very informative. It's disabled by default, but I am open to removing it entirely.
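The gradient analogue looks roughly like this, again only as an illustrative stand-in (the real test parametrizes over distributions, statistics, and sample shapes):

import torch
import pyro.distributions as dist

scale = torch.tensor(2.0, requires_grad=True)
d = dist.Normal(torch.tensor(0.0), scale)
# Gradient of a Monte Carlo estimate of the variance wrt the scale parameter...
mc_variance = d.rsample((100000,)).var()
(mc_grad,) = torch.autograd.grad(mc_variance, scale)
# ...versus the gradient of the ground-truth variance (d/dscale of scale**2 is 2 * scale).
(exact_grad,) = torch.autograd.grad(d.variance, scale)
assert torch.allclose(mc_grad, exact_grad, rtol=0.05)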
test/test_distribution_generic.py (outdated)

TEST_CASES += [DistTestCase(
    "backend_dist.Bernoulli(logits=case.logits)",
    (("logits", f"rand({batch_shape})"),),
    funsor.Real,
)]
After this PR, adding a new distribution to funsor.distributions will be as simple as adding it to the list of distributions to be wrapped in funsor/{backend}/distributions.py and adding a new test case to this list.
> adding it to the list

Is there a .register() command allowing users to dynamically extend the list in their own code? Similar to the dynamic registration mechanisms in kl_divergence.register() or biject_to.register? Even better, could we automatically register distributions in Funsor the first time they are encountered?
> Is there a .register() command allowing users to dynamically extend the list in their own code?

Yes, funsor.distribution.make_dist basically plays this role, especially after #391 - it takes a backend distribution class and (optionally) some parameter names as input and generates a new funsor.distribution.Distribution with generic eager_subs, unscaled_sample, to_funsor and to_data patterns that should work for most use cases, provided the user has correctly implemented .arg_constraints and .support in their custom distribution.
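A rough sketch of what that might look like for a user-defined distribution (the param_names keyword and the example class are assumptions on my part, not verbatim from the library):

import pyro.distributions as dist  # assuming the PyTorch/Pyro backend
from funsor.distribution import make_dist

class MyExponential(dist.Exponential):
    # A custom backend distribution; .arg_constraints and .support are
    # inherited here and must be correct for the generic patterns to work.
    pass

# Dynamically generate a funsor.distribution.Distribution subclass with generic
# eager_subs, unscaled_sample, to_funsor and to_data behavior.
MyExponentialFunsor = make_dist(MyExponential, param_names=("rate",))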
# BernoulliLogits
TEST_CASES += [DistTestCase(
    "backend_dist.Bernoulli(logits=case.logits)",
Note I've chosen to encode test cases with strings of Python code. This should allow us to apply these generic tests even to quite complicated backend distribution expressions, e.g. TransformedDistribution(Normal(loc, scale).to_event(1), [TanhTransform,]).mask(mask)
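To make the string-encoding idea concrete, here is a minimal stand-in for how a case string can be turned into a live distribution; the case object and the eval-based construction are my own sketch of the mechanism, not the exact harness code:

import torch
import pyro.distributions as backend_dist  # assuming the PyTorch backend

batch_shape = (2,)

class case:  # stand-in for a DistTestCase with its random parameters realized
    logits = torch.rand(batch_shape)

# Because the test case is just a string of Python code, arbitrarily complicated
# backend expressions (TransformedDistribution, .to_event, .mask, ...) can be
# encoded and evaluated the same way.
raw_dist = eval("backend_dist.Bernoulli(logits=case.logits)")
raw_dist.log_prob(raw_dist.sample())  # smoke check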
result = funsor.delta.Delta(value.name, Tensor(raw_sample, inputs, value.output.dtype))
funsor_value = to_funsor(raw_value, output=value_output, dim_to_name=dim_to_name)
I had to update funsor.distribution.Distribution.unscaled_sample to use to_funsor and to_data throughout to get some tests to pass.
@fehiepsi Travis seems to be hanging somewhere in the JAX tests, but I haven't been able to reproduce it locally and I can't figure out where from the logs, although I assume it's related to the distribution tests. Any idea what might be going on?
One thing I'm noticing is that the new sampler tests in this PR (
...still reviewing test_distribution_generic.py. I think it will be easier to read after adding TEST_CASES.append(self) to .__init__().
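A minimal sketch of that refactor; the field names come from the snippets above, but the class body itself is illustrative rather than the PR's actual code:

import funsor

TEST_CASES = []

class DistTestCase:
    def __init__(self, raw_dist, raw_params, expected_value_domain):
        self.raw_dist = raw_dist                        # string of Python code for the backend distribution
        self.raw_params = raw_params                    # ((param name, string recipe for a random value), ...)
        self.expected_value_domain = expected_value_domain
        TEST_CASES.append(self)                         # self-registering: call sites no longer need TEST_CASES += [...]

# Adding a case then reduces to constructing it:
DistTestCase(
    "backend_dist.Bernoulli(logits=case.logits)",
    (("logits", "rand((2,))"),),
    funsor.Real,
)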
def multinomial_to_data(funsor_dist, name_to_dim=None):
    probs = to_data(funsor_dist.probs, name_to_dim)
    total_count = to_data(funsor_dist.total_count, name_to_dim)
    if isinstance(total_count, numbers.Number) or len(total_count.shape) == 0:
Do we need to worry about int(total_count) thwarting PyTorch tracing? Should we preserve scalar Tensors?
I made this change because torch.distributions.Multinomial raises an error given a tensor total_count, even if it is a scalar tensor, and the generic conversion tests were failing. It's definitely an issue, but there are already lots of JIT issues with scalars and tuple shapes throughout the codebase (encountered while working on pyro.contrib.funsor), and my inclination with this PR and further work on distributions is to avoid special-casing to the greatest extent possible, even if that means deferring to odd implementation details in the backends. I think a better fix in this instance would be to allow scalar tensor total_count upstream so this change could be reverted.

This is also a good reminder to add some generic JIT tests for distribution wrappers in a followup PR.
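To make the tracing concern concrete, a small illustration (not from the PR) of how calling int() on a tensor inside a traced function bakes the value into the trace as a constant:

import torch

def scale_probs(total_count, probs):
    # int() on a traced tensor emits a TracerWarning and freezes the value
    # into the graph as a constant, like the int(total_count) conversion above.
    return probs * int(total_count)

traced = torch.jit.trace(scale_probs, (torch.tensor(5), torch.ones(3)))
print(traced(torch.tensor(10), torch.ones(3)))  # still multiplies by 5, not 10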
@eb8680 I think the memory issue is due to the "caching" mechanism of JAX: it caches compiled code so that the next time a function is called with inputs of the same shape, it can run fast. I guess we can split

pytest -v -n auto --ignore=test/examples --ignore=test/pyro --ignore=test/pyroapi

into smaller invocations, e.g.

pytest -v -n auto --ignore=test/examples --ignore=test/pyro --ignore=test/pyroapi --ignore=test/distribution.py
pytest -v -n auto test/distribution.py
Would that also explain the run time differences? For comparison, the new sampler smoke tests in this PR were taking ~1s on a single core on my laptop for PyTorch vs ~100s for JAX.
How many tests are there? If there are ~100 tests, then that seems normal to me.
Around that, yeah |
Addresses #386. Blocked by #388.

This PR adds a new file test_distribution_generic.py that refactors the distribution tests. Here there is one generic test for each distribution method, and all that is needed to test a new distribution is to add a recipe for constructing a random instance to a list of test cases. This change is necessary if we want to approach full coverage of the dozens of distributions in PyTorch, Pyro/NumPyro and TFP without thousands of lines of manually duplicated testing logic.

I have also had to add a number of small fixes to get the new tests to pass, which I would argue is a sign of the value of these tests - even the first version here exercises many small distribution API edge cases that would be hard to catch with the current approach. There are also a couple of new testing utilities (ops.isnan and testing.random_scale_tril).

If the approach in this PR works, I will delete many of the one-off tests in test_distribution.py in a followup PR.

Remaining tasks:
- test_distribution.py - Triaged
- Get all Monte Carlo tests to pass - I am skeptical of the viability and usefulness of doing this and have simply disabled and removed these tests for now.