-
-
Notifications
You must be signed in to change notification settings - Fork 411
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
(WIP?) vectorized log_likelihood function for NumPyro #2390
base: main
Are you sure you want to change the base?
Conversation
Hey, I looked at the checks that failed, and they are failing because they can't even find test cases. I don't think that is related to the updated code at all? Let me know if I am missing something, though. |
@virajpandya could you try it out and see how timing compares to the ~80 mins from the latest release and setting You can install the arviz version of this PR with:
|
The pylint checks are failing. These are the specific errors:
For the |
You might try this in your terminal.
This will format the code according to the benchmark. Once done re-add the changes ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I tried testing this locally on a variation of the model in https://python.arviz.org/en/stable/getting_started/CreatingInferenceData.html#from-numpyro but with random y and sigma with 30k elements plus generating 2k posterior samples.
The version with vmap (after the fixes mentioned in the review) and the current version took basically the same time. The log_likelihood function itself in numpyro calls a soft_vmap
so there might not even be any difference between using vmap directly on our side or calling numpyro directly.
I did still crash my computer multiple times with both versions when I attempted running things in a loop to get some average timings which makes me suspect there are memory leaks somewhere in the process which might even be the reason of the slowness.
I am sorry but I don't think it makes sense to merge this before we can get reproducible models that take extremely long with the current version yet run fast with this vmap version
@@ -181,20 +181,26 @@ def sample_stats_to_xarray(self): | |||
@requires("posterior") | |||
@requires("model") | |||
def log_likelihood_to_xarray(self): | |||
"""Extract log likelihood from NumPyro posterior.""" | |||
"""Extract log likelihood from NumPyro posterior using vectorization.""" | |||
if not self.log_likelihood: | |||
return None |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return None | |
return None | |
import jax |
|
||
# Vectorized log likelihood calculation using jax.vmap | ||
log_likelihood_dict = jax.vmap(lambda single_sample: | ||
self.numpyro.infer.log_likelihood(self.model, single_sample, *self._args, **self._kwargs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self.numpyro.infer.log_likelihood(self.model, single_sample, *self._args, **self._kwargs) | |
self.numpyro.infer.log_likelihood(self.model, single_sample, *self._args, batch_ndims=0, **self._kwargs) |
It doesn't work without this because batching is not taken care of directly in vmap but this function expects a batch dimension too and fails when it is not there (or seemingly changes with the different variables)
Description
Checklist
📚 Documentation preview 📚: https://arviz--2390.org.readthedocs.build/en/2390/