Skip to content

Commit

Permalink
relocate low-rank utilities (#996)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #996

see title

Reviewed By: Balandat

Differential Revision: D32668194

fbshipit-source-id: c4eba63632c5b4e63fd0ae7db211efb413f49cd1
  • Loading branch information
sdaulton authored and facebook-github-bot committed Dec 8, 2021
1 parent 0a77787 commit 22de2f2
Show file tree
Hide file tree
Showing 7 changed files with 342 additions and 313 deletions.
8 changes: 4 additions & 4 deletions botorch/acquisition/multi_objective/monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@
IdentityMCMultiOutputObjective,
MCMultiOutputObjective,
)
from botorch.acquisition.multi_objective.utils import (
extract_batch_covar,
sample_cached_cholesky,
)
from botorch.acquisition.multi_objective.utils import (
prune_inferior_points_multi_objective,
)
Expand All @@ -50,6 +46,10 @@
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.posteriors.posterior import Posterior
from botorch.sampling.samplers import MCSampler, SobolQMCNormalSampler
from botorch.utils.low_rank import (
extract_batch_covar,
sample_cached_cholesky,
)
from botorch.utils.multi_objective.box_decompositions.box_decomposition_list import (
BoxDecompositionList,
)
Expand Down
146 changes: 1 addition & 145 deletions botorch/acquisition/multi_objective/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,13 @@
IdentityMCMultiOutputObjective,
MCMultiOutputObjective,
)
from botorch.exceptions.errors import BotorchError, UnsupportedError
from botorch.exceptions.errors import UnsupportedError
from botorch.exceptions.warnings import BotorchWarning
from botorch.exceptions.warnings import SamplingWarning
from botorch.models.model import Model
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.sampling.samplers import IIDNormalSampler, SobolQMCNormalSampler
from botorch.utils.multi_objective.pareto import is_non_dominated
from botorch.utils.transforms import normalize_indices
from gpytorch.distributions.multitask_multivariate_normal import (
MultitaskMultivariateNormal,
)
from gpytorch.lazy.block_diag_lazy_tensor import BlockDiagLazyTensor
from gpytorch.lazy.lazy_tensor import LazyTensor
from gpytorch.utils.cholesky import psd_safe_cholesky
from gpytorch.utils.errors import NanError
from torch import Tensor
from torch.quasirandom import SobolEngine

Expand Down Expand Up @@ -164,139 +156,3 @@ def prune_inferior_points_multi_objective(
idcs = order_idcs[:max_points]

return X[idcs]


def extract_batch_covar(mt_mvn: MultitaskMultivariateNormal) -> LazyTensor:
r"""Extract a batched independent covariance matrix from a MTMVN.
Args:
mt_mvn: A multi-task multivariate normal with a block diagonal
covariance matrix.
Returns:
A lazy covariance matrix consisting of a batch of the blocks of the
diagonal of the MultitaskMultivariateNormal.
"""
lazy_covar = mt_mvn.lazy_covariance_matrix
if not isinstance(lazy_covar, BlockDiagLazyTensor):
raise BotorchError(f"Expected BlockDiagLazyTensor, but got {type(lazy_covar)}.")
return lazy_covar.base_lazy_tensor


def _reshape_base_samples(
base_samples: Tensor, sample_shape: torch.Size, posterior: GPyTorchPosterior
) -> Tensor:
r"""Manipulate shape of base_samples to match `MultivariateNormal.rsample`.
This ensure that base_samples are used in the same way as in
gpytorch.distributions.MultivariateNormal. For CBD, it is important to ensure
that the same base samples are used for the in-sample points here and in the
cached box decompositions.
Args:
base_samples: The base samples.
sample_shape: The sample shape.
posterior: The joint posterior is over (X_baseline, X).
Returns:
Reshaped and expanded base samples.
"""
loc = posterior.mvn.loc
peshape = posterior.event_shape
base_samples = base_samples.view(
sample_shape + torch.Size([1 for _ in range(loc.ndim - 1)]) + peshape[-2:]
).expand(sample_shape + loc.shape[:-1] + peshape[-2:])
base_samples = base_samples.reshape(
-1, *loc.shape[:-1], posterior.mvn.lazy_covariance_matrix.shape[-1]
)
base_samples = base_samples.permute(*range(1, loc.dim() + 1), 0)
return base_samples.reshape(
*peshape[:-2],
peshape[-1],
peshape[-2],
*sample_shape,
)


def sample_cached_cholesky(
posterior: GPyTorchPosterior,
baseline_L: Tensor,
q: int,
base_samples: Tensor,
sample_shape: torch.Size,
max_tries: int = 6,
) -> Tensor:
r"""Get posterior samples at the `q` new points from the joint multi-output posterior.
Args:
posterior: The joint posterior is over (X_baseline, X).
baseline_L: The baseline lower triangular cholesky factor.
q: The number of new points in X.
base_samples: The base samples.
sample_shape: The sample shape.
max_tries: The number of tries for computing the Cholesky
decomposition with increasing jitter.
Returns:
A `sample_shape x batch_shape x q x m`-dim tensor of posterior
samples at the new points.
"""
# compute bottom left covariance block
if isinstance(posterior.mvn, MultitaskMultivariateNormal):
lazy_covar = extract_batch_covar(mt_mvn=posterior.mvn)
else:
lazy_covar = posterior.mvn.lazy_covariance_matrix
# Get the `q` new rows of the batched covariance matrix
bottom_rows = lazy_covar[..., -q:, :].evaluate()
# The covariance in block form is:
# [K(X_baseline, X_baseline), K(X_baseline, X)]
# [K(X, X_baseline), K(X, X)]
# bl := K(X, X_baseline)
# br := K(X, X)
# Get bottom right block of new covariance
bl, br = torch.split(bottom_rows, bottom_rows.shape[-1] - q, dim=-1)
# Solve Ax = b
# where A = K(X_baseline, X_baseline) and b = K(X, X_baseline)^T
# and bl_chol := x^T
# bl_chol is the new `(batch_shape) x q x n`-dim bottom left block
# of the cholesky decomposition
bl_chol = torch.triangular_solve(
bl.transpose(-2, -1), baseline_L, upper=False
).solution.transpose(-2, -1)
# Compute the new bottom right block of the Cholesky decomposition via:
# Cholesky(K(X, X) - bl_chol @ bl_chol^T)
br_to_chol = br - bl_chol @ bl_chol.transpose(-2, -1)
# TODO: technically we should make sure that we add a consistent
# nugget to the cached covariance and the new block
br_chol = psd_safe_cholesky(br_to_chol, max_tries=max_tries)
# Create a `(batch_shape) x q x (n+q)`-dim tensor containing the
# `q` new bottom rows of the Cholesky decomposition
new_Lq = torch.cat([bl_chol, br_chol], dim=-1)
mean = posterior.mvn.mean
base_samples = _reshape_base_samples(
base_samples=base_samples, sample_shape=sample_shape, posterior=posterior
)
if not isinstance(posterior.mvn, MultitaskMultivariateNormal):
# add output dim
mean = mean.unsqueeze(-1)
# add batch dim corresponding to output dim
new_Lq = new_Lq.unsqueeze(-3)
new_mean = mean[..., -q:, :]
res = (
new_Lq.matmul(base_samples)
.permute(-1, *range(mean.dim() - 2), -2, -3)
.contiguous()
.add(new_mean)
)
contains_nans = torch.isnan(res).any()
contains_infs = torch.isinf(res).any()
if contains_nans or contains_infs:
suffix_args = []
if contains_nans:
suffix_args.append("nans")
if contains_infs:
suffix_args.append("infs")
suffix = " and ".join(suffix_args)
raise NanError(f"Samples contain {suffix}.")
return res
157 changes: 157 additions & 0 deletions botorch/utils/low_rank.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import torch
from botorch.exceptions.errors import BotorchError
from botorch.posteriors.gpytorch import GPyTorchPosterior
from gpytorch.distributions.multitask_multivariate_normal import (
MultitaskMultivariateNormal,
)
from gpytorch.lazy import BlockDiagLazyTensor
from gpytorch.lazy.lazy_tensor import LazyTensor
from gpytorch.utils.cholesky import psd_safe_cholesky
from gpytorch.utils.errors import NanError
from torch import Tensor


def extract_batch_covar(mt_mvn: MultitaskMultivariateNormal) -> LazyTensor:
r"""Extract a batched independent covariance matrix from an MTMVN.
Args:
mt_mvn: A multi-task multivariate normal with a block diagonal
covariance matrix.
Returns:
A lazy covariance matrix consisting of a batch of the blocks of
the diagonal of the MultitaskMultivariateNormal.
"""
lazy_covar = mt_mvn.lazy_covariance_matrix
if not isinstance(lazy_covar, BlockDiagLazyTensor):
raise BotorchError(f"Expected BlockDiagLazyTensor, but got {type(lazy_covar)}.")
return lazy_covar.base_lazy_tensor


def _reshape_base_samples(
base_samples: Tensor, sample_shape: torch.Size, posterior: GPyTorchPosterior
) -> Tensor:
r"""Manipulate shape of base_samples to match `MultivariateNormal.rsample`.
This ensure that base_samples are used in the same way as in
gpytorch.distributions.MultivariateNormal. For CBD, it is important to ensure
that the same base samples are used for the in-sample points here and in the
cached box decompositions.
Args:
base_samples: The base samples.
sample_shape: The sample shape.
posterior: The joint posterior is over (X_baseline, X).
Returns:
Reshaped and expanded base samples.
"""
loc = posterior.mvn.loc
peshape = posterior.event_shape
base_samples = base_samples.view(
sample_shape + torch.Size([1 for _ in range(loc.ndim - 1)]) + peshape[-2:]
).expand(sample_shape + loc.shape[:-1] + peshape[-2:])
base_samples = base_samples.reshape(
-1, *loc.shape[:-1], posterior.mvn.lazy_covariance_matrix.shape[-1]
)
base_samples = base_samples.permute(*range(1, loc.dim() + 1), 0)
return base_samples.reshape(
*peshape[:-2],
peshape[-1],
peshape[-2],
*sample_shape,
)


def sample_cached_cholesky(
posterior: GPyTorchPosterior,
baseline_L: Tensor,
q: int,
base_samples: Tensor,
sample_shape: torch.Size,
max_tries: int = 6,
) -> Tensor:
r"""Get posterior samples at the `q` new points from the joint multi-output posterior.
Args:
posterior: The joint posterior is over (X_baseline, X).
baseline_L: The baseline lower triangular cholesky factor.
q: The number of new points in X.
base_samples: The base samples.
sample_shape: The sample shape.
max_tries: The number of tries for computing the Cholesky
decomposition with increasing jitter.
Returns:
A `sample_shape x batch_shape x q x m`-dim tensor of posterior
samples at the new points.
"""
# compute bottom left covariance block
if isinstance(posterior.mvn, MultitaskMultivariateNormal):
lazy_covar = extract_batch_covar(mt_mvn=posterior.mvn)
else:
lazy_covar = posterior.mvn.lazy_covariance_matrix
# Get the `q` new rows of the batched covariance matrix
bottom_rows = lazy_covar[..., -q:, :].evaluate()
# The covariance in block form is:
# [K(X_baseline, X_baseline), K(X_baseline, X)]
# [K(X, X_baseline), K(X, X)]
# bl := K(X, X_baseline)
# br := K(X, X)
# Get bottom right block of new covariance
bl, br = torch.split(bottom_rows, bottom_rows.shape[-1] - q, dim=-1)
# Solve Ax = b
# where A = K(X_baseline, X_baseline) and b = K(X, X_baseline)^T
# and bl_chol := x^T
# bl_chol is the new `(batch_shape) x q x n`-dim bottom left block
# of the cholesky decomposition
bl_chol = torch.triangular_solve(
bl.transpose(-2, -1), baseline_L, upper=False
).solution.transpose(-2, -1)
# Compute the new bottom right block of the Cholesky
# decomposition via:
# Cholesky(K(X, X) - bl_chol @ bl_chol^T)
br_to_chol = br - bl_chol @ bl_chol.transpose(-2, -1)
# TODO: technically we should make sure that we add a
# consistent nugget to the cached covariance and the new block
br_chol = psd_safe_cholesky(br_to_chol, max_tries=max_tries)
# Create a `(batch_shape) x q x (n+q)`-dim tensor containing the
# `q` new bottom rows of the Cholesky decomposition
new_Lq = torch.cat([bl_chol, br_chol], dim=-1)
mean = posterior.mvn.mean
base_samples = _reshape_base_samples(
base_samples=base_samples, sample_shape=sample_shape, posterior=posterior
)
if not isinstance(posterior.mvn, MultitaskMultivariateNormal):
# add output dim
mean = mean.unsqueeze(-1)
# add batch dim corresponding to output dim
new_Lq = new_Lq.unsqueeze(-3)
new_mean = mean[..., -q:, :]
res = (
new_Lq.matmul(base_samples)
.permute(-1, *range(mean.dim() - 2), -2, -3)
.contiguous()
.add(new_mean)
)
contains_nans = torch.isnan(res).any()
contains_infs = torch.isinf(res).any()
if contains_nans or contains_infs:
suffix_args = []
if contains_nans:
suffix_args.append("nans")
if contains_infs:
suffix_args.append("infs")
suffix = " and ".join(suffix_args)
raise NanError(f"Samples contain {suffix}.")
return res
5 changes: 5 additions & 0 deletions sphinx/source/utils.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,11 @@ Containers
.. automodule:: botorch.utils.containers
:members:

Low-Rank Cholesky Update Utils
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.utils.low_rank
:members:

Objective
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.utils.objective
Expand Down
2 changes: 1 addition & 1 deletion test/acquisition/multi_objective/test_monte_carlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@
IdentityMCMultiOutputObjective,
MCMultiOutputObjective,
)
from botorch.acquisition.multi_objective.utils import sample_cached_cholesky
from botorch.acquisition.objective import IdentityMCObjective
from botorch.exceptions.errors import BotorchError, UnsupportedError
from botorch.exceptions.warnings import BotorchWarning
from botorch.models.gp_regression import SingleTaskGP
from botorch.sampling.samplers import IIDNormalSampler, SobolQMCNormalSampler
from botorch.utils.low_rank import sample_cached_cholesky
from botorch.utils.multi_objective.box_decompositions.dominated import (
DominatedPartitioning,
)
Expand Down
Loading

0 comments on commit 22de2f2

Please sign in to comment.