Skip to content

Commit

Permalink
n_vi_samps=1 bugfix + test
Browse files Browse the repository at this point in the history
  • Loading branch information
SamDuffield committed Nov 5, 2024
1 parent e8c2501 commit 38cedbc
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 16 deletions.
3 changes: 2 additions & 1 deletion posteriors/vi/dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,8 +246,9 @@ def nelbo(
# Don't use vmap for single sample, since vmap doesn't work with lots of models
if n_samples == 1:
single_param = tree_map(lambda x: x[0], sampled_params_tree)
single_param_flat, _ = tree_ravel(single_param)
log_p, aux = log_posterior(single_param, batch)
log_q = dist.log_prob(single_param)
log_q = dist.log_prob(single_param_flat)

else:
log_p, aux = vmap(log_posterior, (0, None), (0, 0))(sampled_params_tree, batch)
Expand Down
17 changes: 8 additions & 9 deletions tests/vi/test_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def log_prob(p, b):
assert bad_nelbo_100 > target_nelbo_100


def _test_vi_dense(optimizer_cls, stl):
def _test_vi_dense(optimizer_cls, stl, n_vi_samps):
torch.manual_seed(43)
target_mean = {"a": torch.randn(2, 1), "b": torch.randn(1, 1)}
num_params = tree_size(target_mean)
Expand All @@ -61,7 +61,7 @@ def log_prob(p, b):

init_mean = tree_map(lambda x: torch.zeros_like(x, requires_grad=True), target_mean)

optimizer = optimizer_cls(lr=7.5e-3)
optimizer = optimizer_cls(lr=1e-2)

state = vi.dense.init(init_mean, optimizer)

Expand All @@ -88,8 +88,7 @@ def log_prob(p, b):
assert torch.isclose(nelbo_target, torch.tensor(0.0), atol=1e-6)
assert nelbo_init > nelbo_target

n_steps = 500
n_vi_samps = 5
n_steps = 1000

transform = vi.dense.build(
log_prob,
Expand Down Expand Up @@ -145,7 +144,7 @@ def log_prob(p, b):

# Test sample
mean_copy = tree_map(lambda x: x.clone(), state.params)
samples = vi.dense.sample(state, (1000,))
samples = vi.dense.sample(state, (5000,))
flat_samples = torch.vmap(lambda s: tree_ravel(s)[0])(samples)
samples_cov = torch.cov(flat_samples.T)
samples_mean = tree_map(lambda x: x.mean(dim=0), samples)
Expand All @@ -158,16 +157,16 @@ def log_prob(p, b):


def test_vi_dense_sgd():
_test_vi_dense(torchopt.sgd, False)
_test_vi_dense(torchopt.sgd, False, 5)


def test_vi_dense_adamw():
_test_vi_dense(torchopt.adamw, False)
_test_vi_dense(torchopt.adamw, False, 1)


def test_vi_dense_sgd_stl():
_test_vi_dense(torchopt.sgd, True)
_test_vi_dense(torchopt.sgd, True, 1)


def test_vi_dense_adamw_stl():
_test_vi_dense(torchopt.adamw, True)
_test_vi_dense(torchopt.adamw, True, 5)
11 changes: 5 additions & 6 deletions tests/vi/test_diag.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def test_nelbo():
assert bad_nelbo_100 > target_nelbo_100


def _test_vi_diag(optimizer_cls, stl):
def _test_vi_diag(optimizer_cls, stl, n_vi_samps):
torch.manual_seed(42)
target_mean = {"a": torch.randn(2, 1), "b": torch.randn(1, 1)}
target_sds = tree_map(lambda x: torch.randn_like(x).abs(), target_mean)
Expand Down Expand Up @@ -78,7 +78,6 @@ def _test_vi_diag(optimizer_cls, stl):
assert nelbo_init > nelbo_target

n_steps = 500
n_vi_samps = 5

transform = vi.diag.build(
batch_normal_log_prob_spec,
Expand Down Expand Up @@ -141,16 +140,16 @@ def _test_vi_diag(optimizer_cls, stl):


def test_vi_diag_sgd():
_test_vi_diag(torchopt.sgd, False)
_test_vi_diag(torchopt.sgd, False, 5)


def test_vi_diag_adamw():
_test_vi_diag(torchopt.adamw, False)
_test_vi_diag(torchopt.adamw, False, 1)


def test_vi_diag_sgd_stl():
_test_vi_diag(torchopt.sgd, True)
_test_vi_diag(torchopt.sgd, True, 1)


def test_vi_diag_adamw_stl():
_test_vi_diag(torchopt.adamw, True)
_test_vi_diag(torchopt.adamw, True, 5)

0 comments on commit 38cedbc

Please sign in to comment.