diff --git a/posteriors/vi/diag.py b/posteriors/vi/diag.py index 7d5c883..7127f24 100644 --- a/posteriors/vi/diag.py +++ b/posteriors/vi/diag.py @@ -225,8 +225,9 @@ def nelbo( # Don't use vmap for single sample, since vmap doesn't work with lots of models if n_samples == 1: - log_p, aux = log_posterior(sampled_params[0], batch) - log_q = diag_normal_log_prob(sampled_params[0], mean, sd_diag) + single_param = tree_map(lambda x: x[0], sampled_params) + log_p, aux = log_posterior(single_param, batch) + log_q = diag_normal_log_prob(single_param, mean, sd_diag) else: log_p, aux = vmap(log_posterior, (0, None), (0, 0))(sampled_params, batch) log_q = vmap(diag_normal_log_prob, (0, None, None))(