Skip to content

Commit

Permalink
Fix VI with tree_map
Browse files Browse the repository at this point in the history
  • Loading branch information
SamDuffield committed Aug 2, 2024
1 parent f8ed216 commit 701550b
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions posteriors/vi/diag.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,9 @@ def nelbo(

# Don't use vmap for single sample, since vmap doesn't work with lots of models
if n_samples == 1:
log_p, aux = log_posterior(sampled_params[0], batch)
log_q = diag_normal_log_prob(sampled_params[0], mean, sd_diag)
single_param = tree_map(lambda x: x[0], sampled_params)
log_p, aux = log_posterior(single_param, batch)
log_q = diag_normal_log_prob(single_param, mean, sd_diag)
else:
log_p, aux = vmap(log_posterior, (0, None), (0, 0))(sampled_params, batch)
log_q = vmap(diag_normal_log_prob, (0, None, None))(
Expand Down

0 comments on commit 701550b

Please sign in to comment.