Add two ensemble sampling methods #1692
Conversation
@amifalk hello what is the status of this PR? are you waiting for feedback from us? i think we sort of lost track of this over the holiday break.
@martinjankowiak Would you mind looking through it and making comments about what you feel is missing/needs to change? I believe all of the code lints with the exception of the batch_ravel_pytree function.

At the core, these are gradient-free methods that update the state of each chain by looking at the current states of the other chains, which I've implemented by storing an (n_chains, n_params) array (thus the need to write a batched version of ravel_pytree).

I'm also happy to reduce the scope of the PR if that will make it easier to review (the affine invariant ensemble sampler, AIES, is much less complex than the ensemble slice sampler, ESS).
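For readers following along, a batched ravel in the spirit described above can be sketched by vmapping JAX's ravel_pytree over the leading chain dimension (a minimal sketch, not the PR's actual implementation; function name matches the one discussed):

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def batch_ravel_pytree(pytree):
    # Flatten each chain's slice of the pytree into one row, producing an
    # (n_chains, n_params) array plus a batched unravel function.
    _, unravel_one = ravel_pytree(jax.tree_util.tree_map(lambda x: x[0], pytree))
    flat = jax.vmap(lambda p: ravel_pytree(p)[0])(pytree)
    return flat, jax.vmap(unravel_one)

# Example: 4 chains, parameters of shape (2,) and (3,) per chain.
pytree = {"loc": jnp.zeros((4, 2)), "scale": jnp.ones((4, 3))}
flat, unravel = batch_ravel_pytree(pytree)
```

This keeps the whole ensemble in a single dense array, which is what lets each chain's update index into the states of all the others.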
@amifalk thanks i'll try to review this weekend
so i'm not familiar with these algorithms in detail, so it's a bit hard for me to follow everything, but the code is generally very clean and readable : )
i think the main thing that is missing is tests:
- you can put a test_ensemble_util.py in tests/infer
- you should add at least one test that does inference on a conjugate model where you can e.g. compute analytic posterior means/variances; for example you can add your methods to a few of the following tests: test_mcmc.py::test_unnormalized_normal_x64, test_logistic_regression_x64, test_beta_bernoulli_x64
- you have a fair number of different ways to initialize your kernels so you probably want to add some simple smoke tests or the like to test_ensemble_mcmc.py or similar. e.g. make sure that various combinations of init args do not error out. make simple checks about expected shapes of outputs and the like. you can also use the with pytest.raises(ValueError, match="..expected message...") context manager to check that some invalid initializations and the like are being caught as expected.
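A minimal smoke test in that spirit might look like this (the kernel factory and its error message are hypothetical stand-ins for illustration, not the PR's actual API):

```python
import pytest

def make_ensemble_kernel(n_chains):
    # Hypothetical stand-in for an ensemble kernel constructor:
    # ensemble methods need several chains to update one another.
    if n_chains < 2:
        raise ValueError("ensemble samplers require at least 2 chains")
    return {"n_chains": n_chains}

def test_smoke():
    kernel = make_ensemble_kernel(4)
    assert kernel["n_chains"] == 4  # simple shape/attribute check

def test_invalid_init():
    # pytest.raises works as a plain context manager, so invalid
    # initializations can be checked directly.
    with pytest.raises(ValueError, match="at least 2 chains"):
        make_ensemble_kernel(1)
```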
numpyro/infer/ensemble.py (outdated):

    super().__init__(model, potential_fn, randomize_split, init_strategy)

    # XXX: this doesn't show because state_method='vectorized' shuts off diagnostics_str
@fehiepsi any workarounds?
Is this chain_method, @amifalk? I don't think progress_bar works with the vectorized method there.
Yes it is, sorry for the typo. Can we adjust this line to allow the diagnostics_str for these ensemble methods? If I remove the prng_key check, it displays correctly.
That line checks whether the vectorized method is used. If rng_key.ndim == 1, parallel is used; otherwise, vectorized is used. Looking at your code, it seems that rng_key.ndim == 2 in both cases, which is strange to me. Could you double-check the logic for both methods?

I think you can add an attribute to MCMCKernel to indicate whether it is an ensemble kernel and skip the is_prng_key check at that line if the kernel is an ensemble one.
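A sketch of that suggestion (the attribute name and helper are illustrative, not numpyro's actual API); note that a single jax.random.PRNGKey has ndim == 1 while a batch from jax.random.split has ndim == 2:

```python
import jax

class MCMCKernel:
    # Hypothetical flag; defaults to False so existing kernels are unaffected.
    is_ensemble_kernel = False

class EnsembleSampler(MCMCKernel):
    # Ensemble kernels always carry a batched (n_chains, ...) rng_key.
    is_ensemble_kernel = True

def uses_vectorized_chains(kernel, rng_key):
    # Skip the ndim-based check for ensemble kernels.
    if getattr(kernel, "is_ensemble_kernel", False):
        return True
    return rng_key.ndim > 1

key = jax.random.PRNGKey(0)      # single key: shape (2,), ndim == 1
keys = jax.random.split(key, 4)  # batched keys: shape (4, 2), ndim == 2
```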
Good catch. Parallel sampling does not work with pmap since the chains need to talk to each other. To make that work, we would probably need to add an internal utility that does parallel sharding on the chains, kind of like what is requested in #1192. In the interest of not overcomplicating the API, I'll opt not to support it for now and make a note in the code, but if there were sufficient interest I think we could add a shard_chains argument to EnsembleSampler.
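As a toy illustration of why a vectorized single-device update is the natural fit here (not the PR's code): every chain's proposal reads the state of the other chains, which is trivial when the ensemble lives in one (n_chains, n_params) array but would require cross-device collectives under pmap.

```python
import jax
import jax.numpy as jnp

def step(ensemble, rng_key):
    # Toy update: each chain moves toward the ensemble mean plus noise,
    # so every chain depends on the state of all the others.
    center = ensemble.mean(axis=0)  # cross-chain communication
    noise = jax.random.normal(rng_key, ensemble.shape)
    return 0.5 * (ensemble + center) + 0.1 * noise

key = jax.random.PRNGKey(0)
ensemble = jax.random.normal(key, (8, 3))  # (n_chains, n_params)
ensemble = step(ensemble, jax.random.fold_in(key, 1))
```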
Thanks for the review @martinjankowiak! Documentation has been updated and tests have been added (a distribution test seems to be failing for reasons unrelated to this PR). If there are any more comments, and/or if @fehiepsi has a better solution for the diagnostics_str issue, let me know.
Looks great to me! Could you add the classes to docs/source/mcmc file?
looks great thanks!
can you check how much additional time your tests are adding?
i don't know how slow the additions to test_beta_bernoulli_x64 and test_logistic_regression_x64 might be. if they're adding too much time we may want to add something like pytest.mark.skipif("CI" in os.environ, reason="reduce time for CI")
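As a usage sketch of that marker (the test name here is invented for illustration):

```python
import os
import pytest

# Skip long-running ensemble inference tests on CI machines.
slow = pytest.mark.skipif("CI" in os.environ, reason="reduce time for CI")

@slow
def test_logistic_regression_x64_ensemble():
    ...  # full inference check, only run where time permits
```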
thanks @amifalk ! btw what's your interest in these algorithms? do you have a non-differentiable log density? because unless the problem is very multi-modal or otherwise difficult HMC should work pretty well if a gradient is available. also have you tried comparing these algorithms to SA?
@martinjankowiak We work with some likelihood-free models like LCA that we perform approximate inference on through simulation. We actually did try SA first, but it didn't seem robust to the noisy likelihood.
* ensemble sampling draft
* rewrite for loop as fori_loop
* added efficiency comment for ESS GaussianMove
* fix typo
* fixed ravel for mixed dtype
* add defaults
* add support for potential_fn
* AIES tests, warnings for AIES
* AIES input validation
* better docs, more input validation
* ESS passing test cases
* add tests for other files
* linting
* refactor ensemble_util
* make test result less close to margin in CI, swap deprecated function
* rename get_nondiagonal_indices, fix batch_ravel_pytree
* print ensemble kernel diagnostics, smoke test parallel arg
* fix docstring build
* documentation
* skip slow CI tests, unnest test if statements
* fix doctest
* doc rewrite
* fix distribution test
As described in #1691
Currently, I've only implemented a subset of emcee and zeus moves, but it should be trivial to extend in the future. I also don't have support for potential_fn.
Should there be separate tests for these modules, or should I try to work them into existing ones? The pattern isn't terribly clean with existing tests because AIES and ESS can only be run with multiple chains.
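For readers unfamiliar with the method, the core emcee move (the Goodman–Weare "stretch" move that AIES is built on) can be sketched in plain numpy; this serial toy targets a standard normal and is not the PR's implementation:

```python
import numpy as np

rng = np.random.default_rng(0)

def log_prob(x):
    # Standard normal target (illustration only).
    return -0.5 * np.sum(x**2, axis=-1)

def stretch_move(ensemble, a=2.0):
    """One serial sweep of the Goodman-Weare stretch move over all walkers."""
    n, d = ensemble.shape
    out = ensemble.copy()
    for j in range(n):
        # Pick a complementary walker k != j.
        k = rng.integers(n - 1)
        k = k if k < j else k + 1
        # Sample z ~ g(z) proportional to 1/sqrt(z) on [1/a, a] (inverse CDF).
        z = ((a - 1.0) * rng.random() + 1.0) ** 2 / a
        proposal = out[k] + z * (out[j] - out[k])
        # Accept with probability min(1, z^(d-1) * p(Y)/p(X_j)).
        log_accept = (d - 1) * np.log(z) + log_prob(proposal) - log_prob(out[j])
        if np.log(rng.random()) < log_accept:
            out[j] = proposal
    return out

ensemble = rng.normal(size=(10, 3))
for _ in range(100):
    ensemble = stretch_move(ensemble)
```

Note how the update for walker j reads the position of another walker k, which is the cross-chain dependence discussed earlier in the thread.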