Full VI test coverage (#116)
* Fix Roeder citation

* Add aux to Going Bayesian doc

* n_vi_samps=1 bugfix + test
SamDuffield authored Nov 5, 2024
1 parent 6af158b commit 7b432fb
Showing 5 changed files with 21 additions and 19 deletions.
docs/log_posteriors.md (5 changes: 4 additions & 1 deletion)
@@ -98,9 +98,12 @@ either $N$ or $n$ increase.
log_prior = diag_normal_log_prob(params, sd=1., normalize=False)
mean_log_lik = Categorical(logits=logits).log_prob(batch['labels']).mean()
mean_log_post = log_prior / num_data + mean_log_lik
- return mean_log_post
+ return mean_log_post, torch.tensor([])
```

+ See [auxiliary information](#auxiliary-information) for why we return an
+ additional empty tensor.

The issue with running Bayesian methods (such as VI or SGHMC) on this mean log posterior
function is that naive application will result in approximating the tempered posterior

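For reference, here is the snippet above filled out into a complete log posterior with the now required `(value, aux)` return pair. This is a minimal sketch: `model_fn` and `num_data` are hypothetical stand-ins for your model's forward pass and dataset size, and `diag_normal_log_prob` is the posteriors helper already used in the doc.

```python
import torch
from torch.distributions import Categorical
from posteriors import diag_normal_log_prob  # helper used in the doc above

num_data = 1000  # hypothetical dataset size

def model_fn(params, inputs):
    # Hypothetical stand-in for a functional model forward pass.
    return inputs @ params["weights"]

def log_posterior(params, batch):
    logits = model_fn(params, batch["inputs"])
    log_prior = diag_normal_log_prob(params, sd=1.0, normalize=False)
    mean_log_lik = Categorical(logits=logits).log_prob(batch["labels"]).mean()
    mean_log_post = log_prior / num_data + mean_log_lik
    return mean_log_post, torch.tensor([])  # empty tensor as trivial aux
```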
posteriors/vi/dense.py (5 changes: 3 additions & 2 deletions)
@@ -43,7 +43,7 @@ def build(
temperature: Temperature to rescale (divide) log_posterior.
n_samples: Number of samples to use for Monte Carlo estimate.
stl: Whether to use the stick-the-landing estimator
- from (Roeder et al](https://arxiv.org/abs/1703.09194).
+ from [Roeder et al](https://arxiv.org/abs/1703.09194).
init_L: Initial lower triangular matrix $L$ satisfying $LL^T$ = $\\Sigma$.
Returns:
@@ -246,8 +246,9 @@ def nelbo(
# Don't use vmap for single sample, since vmap doesn't work with lots of models
if n_samples == 1:
single_param = tree_map(lambda x: x[0], sampled_params_tree)
+ single_param_flat, _ = tree_ravel(single_param)
log_p, aux = log_posterior(single_param, batch)
- log_q = dist.log_prob(single_param)
+ log_q = dist.log_prob(single_param_flat)

else:
log_p, aux = vmap(log_posterior, (0, None), (0, 0))(sampled_params_tree, batch)
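The substance of the fix: `vi.dense` models a Gaussian over the flattened parameter vector, so the single sampled parameter tree must be ravelled before `dist.log_prob` can score it. A self-contained illustration, with a minimal stand-in for posteriors' `tree_ravel`:

```python
import torch

def tree_ravel(tree):
    # Minimal stand-in for posteriors' tree_ravel: flatten a dict of
    # tensors into a single vector (the unravel function is omitted).
    return torch.cat([leaf.reshape(-1) for leaf in tree.values()]), None

single_param = {"a": torch.randn(2, 1), "b": torch.randn(1, 1)}
flat, _ = tree_ravel(single_param)  # shape (3,)

dist = torch.distributions.MultivariateNormal(torch.zeros(3), torch.eye(3))
log_q = dist.log_prob(flat)  # works; dist.log_prob(single_param) would raise
```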
posteriors/vi/diag.py (2 changes: 1 addition & 1 deletion)
@@ -42,7 +42,7 @@ def build(
temperature: Temperature to rescale (divide) log_posterior.
n_samples: Number of samples to use for Monte Carlo estimate.
stl: Whether to use the stick-the-landing estimator
- from (Roeder et al](https://arxiv.org/abs/1703.09194).
+ from [Roeder et al](https://arxiv.org/abs/1703.09194).
init_log_sds: Initial log of the square-root diagonal of the covariance matrix
of the variational distribution. Can be a tree matching params or scalar.
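Both citation fixes refer to the stick-the-landing estimator of [Roeder et al](https://arxiv.org/abs/1703.09194). For intuition, a minimal sketch of the idea under a diagonal Gaussian variational family; the function below is illustrative, not posteriors' internal API:

```python
import torch

def stl_nelbo(mean, log_sd, log_posterior, batch):
    # Reparameterised sample from q: gradients flow to the variational
    # parameters through the sample itself.
    sd = log_sd.exp()
    z = mean + sd * torch.randn_like(mean)
    log_p, _ = log_posterior(z, batch)
    # Stick-the-landing: detach the variational parameters inside log q.
    # The dropped score-function term has zero expectation but adds
    # variance to the gradient estimate.
    log_q = torch.distributions.Normal(mean.detach(), sd.detach()).log_prob(z).sum()
    return log_q - log_p  # single-sample negative ELBO estimate
```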
tests/vi/test_dense.py (17 changes: 8 additions & 9 deletions)
@@ -42,7 +42,7 @@ def log_prob(p, b):
assert bad_nelbo_100 > target_nelbo_100


- def _test_vi_dense(optimizer_cls, stl):
+ def _test_vi_dense(optimizer_cls, stl, n_vi_samps):
torch.manual_seed(43)
target_mean = {"a": torch.randn(2, 1), "b": torch.randn(1, 1)}
num_params = tree_size(target_mean)
@@ -61,7 +61,7 @@ def log_prob(p, b):

init_mean = tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean)

- optimizer = optimizer_cls(lr=7.5e-3)
+ optimizer = optimizer_cls(lr=1e-2)

state = vi.dense.init(init_mean, optimizer)

@@ -88,8 +88,7 @@ def log_prob(p, b):
assert torch.isclose(nelbo_target, torch.tensor(0.0), atol=1e-6)
assert nelbo_init > nelbo_target

- n_steps = 500
- n_vi_samps = 5
+ n_steps = 1000

transform = vi.dense.build(
log_prob,
@@ -145,7 +144,7 @@ def log_prob(p, b):

# Test sample
mean_copy = tree_map(lambda x: x.clone(), state.params)
- samples = vi.dense.sample(state, (1000,))
+ samples = vi.dense.sample(state, (5000,))
flat_samples = torch.vmap(lambda s: tree_ravel(s)[0])(samples)
samples_cov = torch.cov(flat_samples.T)
samples_mean = tree_map(lambda x: x.mean(dim=0), samples)
@@ -158,16 +157,16 @@ def log_prob(p, b):


def test_vi_dense_sgd():
- _test_vi_dense(torchopt.sgd, False)
+ _test_vi_dense(torchopt.sgd, False, 5)


def test_vi_dense_adamw():
- _test_vi_dense(torchopt.adamw, False)
+ _test_vi_dense(torchopt.adamw, False, 1)


def test_vi_dense_sgd_stl():
- _test_vi_dense(torchopt.sgd, True)
+ _test_vi_dense(torchopt.sgd, True, 1)


def test_vi_dense_adamw_stl():
- _test_vi_dense(torchopt.adamw, True)
+ _test_vi_dense(torchopt.adamw, True, 5)
tests/vi/test_diag.py (11 changes: 5 additions & 6 deletions)
@@ -37,7 +37,7 @@ def test_nelbo():
assert bad_nelbo_100 > target_nelbo_100


- def _test_vi_diag(optimizer_cls, stl):
+ def _test_vi_diag(optimizer_cls, stl, n_vi_samps):
torch.manual_seed(42)
target_mean = {"a": torch.randn(2, 1), "b": torch.randn(1, 1)}
target_sds = tree_map(lambda x: torch.randn_like(x).abs(), target_mean)
@@ -78,7 +78,6 @@ def _test_vi_diag(optimizer_cls, stl):
assert nelbo_init > nelbo_target

n_steps = 500
- n_vi_samps = 5

transform = vi.diag.build(
batch_normal_log_prob_spec,
@@ -141,16 +140,16 @@ def _test_vi_diag(optimizer_cls, stl):


def test_vi_diag_sgd():
- _test_vi_diag(torchopt.sgd, False)
+ _test_vi_diag(torchopt.sgd, False, 5)


def test_vi_diag_adamw():
- _test_vi_diag(torchopt.adamw, False)
+ _test_vi_diag(torchopt.adamw, False, 1)


def test_vi_diag_sgd_stl():
- _test_vi_diag(torchopt.sgd, True)
+ _test_vi_diag(torchopt.sgd, True, 1)


def test_vi_diag_adamw_stl():
- _test_vi_diag(torchopt.adamw, True)
+ _test_vi_diag(torchopt.adamw, True, 5)

