-
Notifications
You must be signed in to change notification settings - Fork 400
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Summary: Pull Request resolved: #996 see title Reviewed By: Balandat Differential Revision: D32668194 fbshipit-source-id: c4eba63632c5b4e63fd0ae7db211efb413f49cd1
- Loading branch information
1 parent
0a77787
commit 22de2f2
Showing
7 changed files
with
342 additions
and
313 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from __future__ import annotations | ||
|
||
import torch | ||
from botorch.exceptions.errors import BotorchError | ||
from botorch.posteriors.gpytorch import GPyTorchPosterior | ||
from gpytorch.distributions.multitask_multivariate_normal import ( | ||
MultitaskMultivariateNormal, | ||
) | ||
from gpytorch.lazy import BlockDiagLazyTensor | ||
from gpytorch.lazy.lazy_tensor import LazyTensor | ||
from gpytorch.utils.cholesky import psd_safe_cholesky | ||
from gpytorch.utils.errors import NanError | ||
from torch import Tensor | ||
|
||
|
||
def extract_batch_covar(mt_mvn: MultitaskMultivariateNormal) -> LazyTensor: | ||
r"""Extract a batched independent covariance matrix from an MTMVN. | ||
Args: | ||
mt_mvn: A multi-task multivariate normal with a block diagonal | ||
covariance matrix. | ||
Returns: | ||
A lazy covariance matrix consisting of a batch of the blocks of | ||
the diagonal of the MultitaskMultivariateNormal. | ||
""" | ||
lazy_covar = mt_mvn.lazy_covariance_matrix | ||
if not isinstance(lazy_covar, BlockDiagLazyTensor): | ||
raise BotorchError(f"Expected BlockDiagLazyTensor, but got {type(lazy_covar)}.") | ||
return lazy_covar.base_lazy_tensor | ||
|
||
|
||
def _reshape_base_samples( | ||
base_samples: Tensor, sample_shape: torch.Size, posterior: GPyTorchPosterior | ||
) -> Tensor: | ||
r"""Manipulate shape of base_samples to match `MultivariateNormal.rsample`. | ||
This ensure that base_samples are used in the same way as in | ||
gpytorch.distributions.MultivariateNormal. For CBD, it is important to ensure | ||
that the same base samples are used for the in-sample points here and in the | ||
cached box decompositions. | ||
Args: | ||
base_samples: The base samples. | ||
sample_shape: The sample shape. | ||
posterior: The joint posterior is over (X_baseline, X). | ||
Returns: | ||
Reshaped and expanded base samples. | ||
""" | ||
loc = posterior.mvn.loc | ||
peshape = posterior.event_shape | ||
base_samples = base_samples.view( | ||
sample_shape + torch.Size([1 for _ in range(loc.ndim - 1)]) + peshape[-2:] | ||
).expand(sample_shape + loc.shape[:-1] + peshape[-2:]) | ||
base_samples = base_samples.reshape( | ||
-1, *loc.shape[:-1], posterior.mvn.lazy_covariance_matrix.shape[-1] | ||
) | ||
base_samples = base_samples.permute(*range(1, loc.dim() + 1), 0) | ||
return base_samples.reshape( | ||
*peshape[:-2], | ||
peshape[-1], | ||
peshape[-2], | ||
*sample_shape, | ||
) | ||
|
||
|
||
def sample_cached_cholesky( | ||
posterior: GPyTorchPosterior, | ||
baseline_L: Tensor, | ||
q: int, | ||
base_samples: Tensor, | ||
sample_shape: torch.Size, | ||
max_tries: int = 6, | ||
) -> Tensor: | ||
r"""Get posterior samples at the `q` new points from the joint multi-output posterior. | ||
Args: | ||
posterior: The joint posterior is over (X_baseline, X). | ||
baseline_L: The baseline lower triangular cholesky factor. | ||
q: The number of new points in X. | ||
base_samples: The base samples. | ||
sample_shape: The sample shape. | ||
max_tries: The number of tries for computing the Cholesky | ||
decomposition with increasing jitter. | ||
Returns: | ||
A `sample_shape x batch_shape x q x m`-dim tensor of posterior | ||
samples at the new points. | ||
""" | ||
# compute bottom left covariance block | ||
if isinstance(posterior.mvn, MultitaskMultivariateNormal): | ||
lazy_covar = extract_batch_covar(mt_mvn=posterior.mvn) | ||
else: | ||
lazy_covar = posterior.mvn.lazy_covariance_matrix | ||
# Get the `q` new rows of the batched covariance matrix | ||
bottom_rows = lazy_covar[..., -q:, :].evaluate() | ||
# The covariance in block form is: | ||
# [K(X_baseline, X_baseline), K(X_baseline, X)] | ||
# [K(X, X_baseline), K(X, X)] | ||
# bl := K(X, X_baseline) | ||
# br := K(X, X) | ||
# Get bottom right block of new covariance | ||
bl, br = torch.split(bottom_rows, bottom_rows.shape[-1] - q, dim=-1) | ||
# Solve Ax = b | ||
# where A = K(X_baseline, X_baseline) and b = K(X, X_baseline)^T | ||
# and bl_chol := x^T | ||
# bl_chol is the new `(batch_shape) x q x n`-dim bottom left block | ||
# of the cholesky decomposition | ||
bl_chol = torch.triangular_solve( | ||
bl.transpose(-2, -1), baseline_L, upper=False | ||
).solution.transpose(-2, -1) | ||
# Compute the new bottom right block of the Cholesky | ||
# decomposition via: | ||
# Cholesky(K(X, X) - bl_chol @ bl_chol^T) | ||
br_to_chol = br - bl_chol @ bl_chol.transpose(-2, -1) | ||
# TODO: technically we should make sure that we add a | ||
# consistent nugget to the cached covariance and the new block | ||
br_chol = psd_safe_cholesky(br_to_chol, max_tries=max_tries) | ||
# Create a `(batch_shape) x q x (n+q)`-dim tensor containing the | ||
# `q` new bottom rows of the Cholesky decomposition | ||
new_Lq = torch.cat([bl_chol, br_chol], dim=-1) | ||
mean = posterior.mvn.mean | ||
base_samples = _reshape_base_samples( | ||
base_samples=base_samples, sample_shape=sample_shape, posterior=posterior | ||
) | ||
if not isinstance(posterior.mvn, MultitaskMultivariateNormal): | ||
# add output dim | ||
mean = mean.unsqueeze(-1) | ||
# add batch dim corresponding to output dim | ||
new_Lq = new_Lq.unsqueeze(-3) | ||
new_mean = mean[..., -q:, :] | ||
res = ( | ||
new_Lq.matmul(base_samples) | ||
.permute(-1, *range(mean.dim() - 2), -2, -3) | ||
.contiguous() | ||
.add(new_mean) | ||
) | ||
contains_nans = torch.isnan(res).any() | ||
contains_infs = torch.isinf(res).any() | ||
if contains_nans or contains_infs: | ||
suffix_args = [] | ||
if contains_nans: | ||
suffix_args.append("nans") | ||
if contains_infs: | ||
suffix_args.append("infs") | ||
suffix = " and ".join(suffix_args) | ||
raise NanError(f"Samples contain {suffix}.") | ||
return res |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.