diff --git a/docs/log_posteriors.md b/docs/log_posteriors.md
index 1d46ada..df7d0bb 100644
--- a/docs/log_posteriors.md
+++ b/docs/log_posteriors.md
@@ -98,9 +98,12 @@ either $N$ or $n$ increaase.
     log_prior = diag_normal_log_prob(params, sd=1., normalize=False)
     mean_log_lik = Categorical(logits=logits).log_prob(batch['labels']).mean()
     mean_log_post = log_prior / num_data + mean_log_lik
-    return mean_log_post
+    return mean_log_post, torch.tensor([])
 ```
 
+See [auxiliary information](#auxiliary-information) for why we return an
+additional empty tensor.
+
 The issue with running Bayesian methods (such as VI or SGHMC) on this
 mean log posterior function is that naive application will result in
 approximating the tempered posterior
diff --git a/posteriors/vi/dense.py b/posteriors/vi/dense.py
index fd6c42b..f451798 100644
--- a/posteriors/vi/dense.py
+++ b/posteriors/vi/dense.py
@@ -43,7 +43,7 @@ def build(
         temperature: Temperature to rescale (divide) log_posterior.
         n_samples: Number of samples to use for Monte Carlo estimate.
         stl: Whether to use the stick-the-landing estimator
-            from (Roeder et al](https://arxiv.org/abs/1703.09194).
+            from [Roeder et al](https://arxiv.org/abs/1703.09194).
         init_L: Initial lower triangular matrix $L$ satisfying $LL^T$ = $\\Sigma$.
 
     Returns:
@@ -246,8 +246,9 @@ def nelbo(
     # Don't use vmap for single sample, since vmap doesn't work with lots of models
     if n_samples == 1:
         single_param = tree_map(lambda x: x[0], sampled_params_tree)
+        single_param_flat, _ = tree_ravel(single_param)
         log_p, aux = log_posterior(single_param, batch)
-        log_q = dist.log_prob(single_param)
+        log_q = dist.log_prob(single_param_flat)
     else:
         log_p, aux = vmap(log_posterior, (0, None), (0, 0))(sampled_params_tree, batch)
 
diff --git a/posteriors/vi/diag.py b/posteriors/vi/diag.py
index 7127f24..593f8da 100644
--- a/posteriors/vi/diag.py
+++ b/posteriors/vi/diag.py
@@ -42,7 +42,7 @@ def build(
         temperature: Temperature to rescale (divide) log_posterior.
         n_samples: Number of samples to use for Monte Carlo estimate.
         stl: Whether to use the stick-the-landing estimator
-            from (Roeder et al](https://arxiv.org/abs/1703.09194).
+            from [Roeder et al](https://arxiv.org/abs/1703.09194).
         init_log_sds: Initial log of the square-root diagonal of the covariance matrix
             of the variational distribution. Can be a tree matching params or scalar.
diff --git a/tests/vi/test_dense.py b/tests/vi/test_dense.py
index 43fbda5..2f2b509 100644
--- a/tests/vi/test_dense.py
+++ b/tests/vi/test_dense.py
@@ -42,7 +42,7 @@ def log_prob(p, b):
     assert bad_nelbo_100 > target_nelbo_100
 
 
-def _test_vi_dense(optimizer_cls, stl):
+def _test_vi_dense(optimizer_cls, stl, n_vi_samps):
     torch.manual_seed(43)
     target_mean = {"a": torch.randn(2, 1), "b": torch.randn(1, 1)}
     num_params = tree_size(target_mean)
@@ -61,7 +61,7 @@ def log_prob(p, b):
     init_mean = tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean)
 
-    optimizer = optimizer_cls(lr=7.5e-3)
+    optimizer = optimizer_cls(lr=1e-2)
 
     state = vi.dense.init(init_mean, optimizer)
 
@@ -88,8 +88,7 @@ def log_prob(p, b):
     assert torch.isclose(nelbo_target, torch.tensor(0.0), atol=1e-6)
     assert nelbo_init > nelbo_target
 
-    n_steps = 500
-    n_vi_samps = 5
+    n_steps = 1000
 
     transform = vi.dense.build(
         log_prob,
@@ -145,7 +144,7 @@ def log_prob(p, b):
     # Test sample
     mean_copy = tree_map(lambda x: x.clone(), state.params)
 
-    samples = vi.dense.sample(state, (1000,))
+    samples = vi.dense.sample(state, (5000,))
     flat_samples = torch.vmap(lambda s: tree_ravel(s)[0])(samples)
     samples_cov = torch.cov(flat_samples.T)
     samples_mean = tree_map(lambda x: x.mean(dim=0), samples)
@@ -158,16 +157,16 @@ def log_prob(p, b):
 
 
 def test_vi_dense_sgd():
-    _test_vi_dense(torchopt.sgd, False)
+    _test_vi_dense(torchopt.sgd, False, 5)
 
 
 def test_vi_dense_adamw():
-    _test_vi_dense(torchopt.adamw, False)
+    _test_vi_dense(torchopt.adamw, False, 1)
 
 
 def test_vi_dense_sgd_stl():
-    _test_vi_dense(torchopt.sgd, True)
+    _test_vi_dense(torchopt.sgd, True, 1)
 
 
 def test_vi_dense_adamw_stl():
-    _test_vi_dense(torchopt.adamw, True)
+    _test_vi_dense(torchopt.adamw, True, 5)
diff --git a/tests/vi/test_diag.py b/tests/vi/test_diag.py
index 6c2169b..ca14c28 100644
--- a/tests/vi/test_diag.py
+++ b/tests/vi/test_diag.py
@@ -37,7 +37,7 @@ def test_nelbo():
     assert bad_nelbo_100 > target_nelbo_100
 
 
-def _test_vi_diag(optimizer_cls, stl):
+def _test_vi_diag(optimizer_cls, stl, n_vi_samps):
     torch.manual_seed(42)
     target_mean = {"a": torch.randn(2, 1), "b": torch.randn(1, 1)}
     target_sds = tree_map(lambda x: torch.randn_like(x).abs(), target_mean)
@@ -78,7 +78,6 @@ def _test_vi_diag(optimizer_cls, stl):
     assert nelbo_init > nelbo_target
 
     n_steps = 500
-    n_vi_samps = 5
 
     transform = vi.diag.build(
         batch_normal_log_prob_spec,
@@ -141,16 +140,16 @@ def _test_vi_diag(optimizer_cls, stl):
 
 
 def test_vi_diag_sgd():
-    _test_vi_diag(torchopt.sgd, False)
+    _test_vi_diag(torchopt.sgd, False, 5)
 
 
 def test_vi_diag_adamw():
-    _test_vi_diag(torchopt.adamw, False)
+    _test_vi_diag(torchopt.adamw, False, 1)
 
 
 def test_vi_diag_sgd_stl():
-    _test_vi_diag(torchopt.sgd, True)
+    _test_vi_diag(torchopt.sgd, True, 1)
 
 
 def test_vi_diag_adamw_stl():
-    _test_vi_diag(torchopt.adamw, True)
+    _test_vi_diag(torchopt.adamw, True, 5)
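For context on the `docs/log_posteriors.md` change above, a fuller version of the excerpted mean log posterior might look roughly like the sketch below. Only the prior/likelihood lines and the `(value, aux)` return come from the diff; `model`, `num_data`, the batch keys (`"inputs"`, `"labels"`) and the import locations are illustrative assumptions that do not appear in the hunk.

```python
import torch
from torch.distributions import Categorical
from torch.func import functional_call

from posteriors import diag_normal_log_prob

# Assumed setup for illustration only (not part of the diff).
model = torch.nn.Linear(10, 3)
num_data = 1000


def mean_log_posterior(params, batch):
    # Unnormalized N(0, 1) prior over all parameters, scaled by 1/N so the
    # result is a per-datum mean log posterior.
    log_prior = diag_normal_log_prob(params, normalize=False)
    logits = functional_call(model, params, (batch["inputs"],))
    mean_log_lik = Categorical(logits=logits).log_prob(batch["labels"]).mean()
    mean_log_post = log_prior / num_data + mean_log_lik
    # Empty tensor as auxiliary information: the (value, aux) return convention
    # that the docs hunk introduces.
    return mean_log_post, torch.tensor([])
```

A function with this signature can be passed to `vi.dense.build` or `vi.diag.build` from the hunks above; since their `temperature` argument divides the log posterior, setting it to `1 / num_data` cancels the 1/N scaling and targets the untempered posterior rather than the tempered one the docs warn about.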