
Commit
Removed covariance tracking, removed shape specification from new util functions
jcqcai committed Nov 5, 2024
1 parent f232188 commit ae1dd14
Showing 4 changed files with 33 additions and 30 deletions.
6 changes: 4 additions & 2 deletions posteriors/utils.py
@@ -892,16 +892,18 @@ def is_scalar(x: Any) -> bool:
return isinstance(x, (int, float)) or (torch.is_tensor(x) and x.numel() == 1)


- def L_from_flat(L_flat: torch.Tensor, num_params: int) -> torch.Tensor:
+ def L_from_flat(L_flat: torch.Tensor) -> torch.Tensor:
"""Returns lower triangular matrix from a flat representation of its nonzero elements.
Args:
L_flat: Flat representation of nonzero lower triangular matrix elements.
- num_params: Width of the desired lower triangular matrix.
Returns:
Lower triangular matrix.
"""
+ k = torch.tensor(L_flat.shape[0], dtype=L_flat.dtype, device=L_flat.device)
+ n = (-1 + (1 + 8 * k).sqrt()) / 2
+ num_params = round(n.item())

tril_indices = torch.tril_indices(num_params, num_params)
L = torch.zeros((num_params, num_params), device=L_flat.device)
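The removed num_params argument is now inferred inside L_from_flat: an $n \times n$ lower triangular matrix has $k = n(n+1)/2$ nonzero elements, so $n = (-1 + \sqrt{1 + 8k})/2$, which is exactly what the added lines compute. A standalone sketch of the same inference (infer_width is a hypothetical helper, not part of posteriors):

import math

import torch

def infer_width(L_flat: torch.Tensor) -> int:
    # Invert k = n(n + 1) / 2 to recover the width n of the triangle.
    k = L_flat.shape[0]
    return round((-1 + math.sqrt(1 + 8 * k)) / 2)

# Six nonzero entries correspond to a 3 x 3 lower triangular matrix.
assert infer_width(torch.tensor([1.0, -4.1, 2.2, -1.7, 4.4, -5.5])) == 3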
40 changes: 18 additions & 22 deletions posteriors/vi/dense.py
@@ -22,7 +22,7 @@ def build(
temperature: float = 1.0,
n_samples: int = 1,
stl: bool = True,
- init_cov: torch.Tensor | float = 1.0,
+ init_L: torch.Tensor | float = 1.0,
) -> Transform:
"""Builds a transform for variational inference with a Normal
distribution over parameters.
@@ -44,12 +44,12 @@
n_samples: Number of samples to use for Monte Carlo estimate.
stl: Whether to use the stick-the-landing estimator
from (Roeder et al](https://arxiv.org/abs/1703.09194).
- init_cov: Initial covariance matrix of the variational distribution.
+ init_L: Initial lower triangular matrix $L$ satisfying $LL^T$ = $\\Sigma$.
Returns:
Dense VI transform instance.
"""
- init_fn = partial(init, optimizer=optimizer, init_cov=init_cov)
+ init_fn = partial(init, optimizer=optimizer, init_L=init_L)
update_fn = partial(
update,
log_posterior=log_posterior,
@@ -66,17 +66,16 @@ class VIDenseState(NamedTuple):
Attributes:
params: Mean of the variational distribution.
- cov: Covariance matrix of the variational distribution.
L_factor: Flat representation of the nonzero values of the lower
- triangular matrix $L$ satisfying $LL^T$ = cov.
+ triangular matrix $L$ satisfying $LL^T$ = $\\Sigma$, where $\\Sigma$
+ is the covariance matrix of the variational distribution.
opt_state: TorchOpt state storing optimizer data for updating the
variational parameters.
nelbo: Negative evidence lower bound (lower is better).
aux: Auxiliary information from the log_posterior call.
"""

params: TensorTree
- cov: torch.Tensor
L_factor: torch.Tensor
opt_state: torchopt.typing.OptState
nelbo: torch.tensor = torch.tensor([])
@@ -86,7 +85,7 @@ class VIDenseState(NamedTuple):
def init(
params: TensorTree,
optimizer: torchopt.base.GradientTransformation,
- init_cov: torch.Tensor | float = 1.0,
+ init_L: torch.Tensor | float = 1.0,
) -> VIDenseState:
"""Initialise diagonal Normal variational distribution over parameters.
@@ -106,21 +105,20 @@ def init(
params: Initial mean of the variational distribution.
optimizer: TorchOpt functional optimizer for updating the variational
parameters. Make sure to use lower case like torchopt.adam()
- init_cov: Initial covariance matrix of the variational distribution.
+ init_L: Initial lower triangular matrix $L$ satisfying $LL^T$ = $\\Sigma$,
+ where $\\Sigma$ is the covariance matrix of the variational distribution.
Returns:
Initial DenseVIState.
"""

num_params = tree_size(params)
- if is_scalar(init_cov):
- init_cov = init_cov * torch.eye(num_params, requires_grad=True)
+ if is_scalar(init_L):
+ init_L = init_L * torch.eye(num_params, requires_grad=True)

- init_L = torch.linalg.cholesky(init_cov)
init_L = L_to_flat(init_L)

opt_state = optimizer.init([params, init_L])
- return VIDenseState(params, init_cov, init_L, opt_state)
+ return VIDenseState(params, init_L, opt_state)


def update(
@@ -169,15 +167,12 @@ def nelbo_L_factor(m, L_flat):
mean, L_factor = torchopt.apply_updates(
(state.params, state.L_factor), updates, inplace=inplace
)
- L = L_from_flat(L_factor, state.cov.shape[0])
- cov = L @ L.T

if inplace:
tree_insert_(state.nelbo, nelbo_val.detach())
- tree_insert_(state.cov, cov)
return state._replace(aux=aux)

- return VIDenseState(mean, cov, L_factor, opt_state, nelbo_val.detach(), aux)
+ return VIDenseState(mean, L_factor, opt_state, nelbo_val.detach(), aux)


def nelbo(
@@ -211,7 +206,7 @@ def nelbo(
Args:
mean: Mean of the variational distribution.
L_factor: Flat representation of the nonzero values of the lower
- triangular matrix $L$ satisfying $LL^T$ = cov, where cov
+ triangular matrix $L$ satisfying $LL^T$ = $\\Sigma$, where $\\Sigma$
is the covariance matrix of the variational distribution.
batch: Input data to log_posterior.
log_posterior: Function that takes parameters and input batch and
@@ -226,8 +221,7 @@
"""

mean_flat, unravel_func = tree_ravel(mean)
- num_params = mean_flat.shape[0]
- L = L_from_flat(L_factor, num_params)
+ L = L_from_flat(L_factor)
cov = L @ L.T
dist = torch.distributions.MultivariateNormal(
loc=mean_flat,
@@ -240,7 +234,7 @@

if stl:
mean_flat.detach()
- L = L_from_flat(L_factor.detach(), num_params)
+ L = L_from_flat(L_factor.detach())
cov = L @ L.T
# Redefine distribution to sample from after stl
dist = torch.distributions.MultivariateNormal(
@@ -276,10 +270,12 @@ def sample(
"""

mean_flat, unravel_func = tree_ravel(state.params)
+ L = L_from_flat(state.L_factor)
+ cov = L @ L.T

samples = torch.distributions.MultivariateNormal(
loc=mean_flat,
- covariance_matrix=state.cov,
+ covariance_matrix=cov,
validate_args=False,
).rsample(sample_shape)

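With cov dropped from VIDenseState, the covariance is no longer stored; it is rebuilt on demand as $LL^T$ from the flat factor. A minimal usage sketch, assuming the build signature shown above (the toy standard-Normal log posterior and its shapes are illustrative only):

import torch
import torchopt
from posteriors import vi
from posteriors.utils import L_from_flat

def log_posterior(params, batch):
    # Toy target: standard Normal over three parameters; aux is unused.
    return -0.5 * (params ** 2).sum(), None

params = torch.zeros(3, requires_grad=True)
transform = vi.dense.build(log_posterior, torchopt.adam(lr=1e-2), init_L=0.1)
state = transform.init(params)

# Covariance of the variational distribution, reconstructed when needed.
L = L_from_flat(state.L_factor)
cov = L @ L.T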
3 changes: 1 addition & 2 deletions tests/test_utils.py
@@ -816,8 +816,7 @@ def test_L_from_flat():
]
)
L_flat = torch.tensor([1.0, -4.1, 2.2, -1.7, 4.4, -5.5])
- num_params = expected_L.shape[0]
- L = L_from_flat(L_flat, num_params)
+ L = L_from_flat(L_flat)
assert torch.allclose(expected_L, L)


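The test now exercises the argument-free call. Since the diff also shows L_to_flat mapping a dense lower triangular matrix to this flat form, the pair should round-trip; a quick sanity-check sketch:

import torch
from posteriors.utils import L_from_flat, L_to_flat

L = torch.tril(torch.randn(4, 4))
# Flatten the nonzero triangle, then rebuild the full matrix.
assert torch.allclose(L_from_flat(L_to_flat(L)), L)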
14 changes: 10 additions & 4 deletions tests/vi/test_dense.py
@@ -6,7 +6,7 @@

from posteriors import vi
from posteriors.tree_utils import tree_size
- from posteriors.utils import L_to_flat
+ from posteriors.utils import L_from_flat, L_to_flat


def test_nelbo():
@@ -118,7 +118,9 @@ def log_prob(p, b):
assert torch.allclose(
init_mean[key], init_mean_copy[key]
) # check init_mean was left untouched
- assert torch.allclose(state.cov, target_cov, atol=0.5)
+ state_L = L_from_flat(state.L_factor)
+ state_cov = state_L @ state_L.T
+ assert torch.allclose(state_cov, target_cov, atol=0.5)

# Test inplace = True
state = transform.init(init_mean)
@@ -137,7 +139,9 @@ def log_prob(p, b):
assert torch.allclose(
state.params[key], init_mean[key]
) # check init_mean was updated in place
- assert torch.allclose(state.cov, target_cov, atol=0.5)
+ state_L = L_from_flat(state.L_factor)
+ state_cov = state_L @ state_L.T
+ assert torch.allclose(state_cov, target_cov, atol=0.5)

# Test sample
mean_copy = tree_map(lambda x: x.clone(), state.params)
@@ -148,7 +152,9 @@ def log_prob(p, b):
for key in samples_mean:
assert torch.allclose(samples_mean[key], state.params[key], atol=1e-1)
assert not torch.allclose(samples_mean[key], mean_copy[key])
- assert torch.allclose(state.cov, samples_cov, atol=2e-1)
+ state_L = L_from_flat(state.L_factor)
+ state_cov = state_L @ state_L.T
+ assert torch.allclose(state_cov, samples_cov, atol=2e-1)


def test_vi_dense_sgd():
