diff --git a/botorch/posteriors/multitask.py b/botorch/posteriors/multitask.py
index 03a6267dbc..76a2df43d4 100644
--- a/botorch/posteriors/multitask.py
+++ b/botorch/posteriors/multitask.py
@@ -36,9 +36,11 @@ def __init__(
             distribution: Posterior multivariate normal distribution.
             joint_covariance_matrix: Joint test train covariance matrix over the
                 entire tensor.
-            train_train_covar: Covariance matrix of train points in the data space.
-            test_obs_covar: Covariance matrix of test x train points in the data space.
+            test_train_covar: Covariance matrix of test x train points in the data
+                space.
             train_diff: Difference between train mean and train responses.
+            test_mean: Test mean response.
+            train_train_covar: Covariance matrix of train points in the data space.
             train_noise: Training noise covariance.
             test_noise: Only used if posterior should contain observation noise.
                 Testing noise covariance.
@@ -226,7 +228,9 @@ def rsample_from_base_samples(
             train_diff.reshape(*train_diff.shape[:-2], -1) - updated_obs_samples
         )
         train_covar_plus_noise = self.train_train_covar + self.train_noise
-        obs_solve = train_covar_plus_noise.solve(obs_minus_samples.unsqueeze(-1))
+        obs_solve = _permute_solve(
+            train_covar_plus_noise, obs_minus_samples.unsqueeze(-1)
+        )
 
         # and multiply the test-observed matrix against the result of the solve
         updated_samples = self.test_train_covar.matmul(obs_solve).squeeze(-1)
@@ -286,3 +290,35 @@ def _draw_from_base_covar(
             res = covar_root.matmul(base_samples)
 
         return res.squeeze(-1)
+
+
+def _permute_solve(A: LinearOperator, b: Tensor) -> Tensor:
+    r"""Solve the batched linear system AX = b, where b is a batched column
+    vector. The solve is carried out after permuting the largest batch
+    dimension of b to the final position, which results in a more efficient
+    matrix-matrix solve.
+
+    This ideally should be handled upstream (in GPyTorch, linear_operator or
+    PyTorch), after which any uses of this method can be replaced with
+    `A.solve(b)`.
+
+    Args:
+        A: LinearOperator of shape (n, n)
+        b: Tensor of shape (..., n, 1)
+
+    Returns:
+        Tensor of shape (..., n, 1)
+    """
+    # permute dimensions to move largest batch dimension to the end (more efficient
+    # than unsqueezing)
+    perm = list(range(b.ndim))
+    if b.ndim > 2:
+        largest_batch_dim, _ = max(enumerate(b.shape[:-2]), key=lambda t: t[1])
+        perm[-1], perm[largest_batch_dim] = perm[largest_batch_dim], perm[-1]
+    b_p = b.permute(*perm)
+
+    x_p = A.solve(b_p)
+
+    # Undo the permutation: `perm` is either the identity or a single swap of two
+    # dimensions, so it is its own inverse and can simply be applied again.
+    return x_p.permute(*perm)
diff --git a/test/posteriors/test_multitask.py b/test/posteriors/test_multitask.py
index 42913d3ef7..1b6c6b7dbc 100644
--- a/test/posteriors/test_multitask.py
+++ b/test/posteriors/test_multitask.py
@@ -8,9 +8,10 @@
 import torch
 from botorch.exceptions.errors import BotorchTensorDimensionError
 from botorch.models.multitask import KroneckerMultiTaskGP
-from botorch.posteriors.multitask import MultitaskGPPosterior
+from botorch.posteriors.multitask import _permute_solve, MultitaskGPPosterior
 from botorch.sampling.normal import IIDNormalSampler
 from botorch.utils.testing import BotorchTestCase
+from linear_operator.operators import to_linear_operator
 
 
 def get_posterior_test_cases(
@@ -41,7 +42,6 @@ def get_posterior_test_cases(
 
 
 class TestMultitaskGPPosterior(BotorchTestCase):
-
     def _test_MultitaskGPPosterior(self, dtype: torch.dtype) -> None:
         post_list = get_posterior_test_cases(device=self.device, dtype=dtype)
         sample_shaping = torch.Size([5, 3])
@@ -189,3 +189,23 @@ def test_draw_from_base_covar(self):
         base_samples = torch.randn(4, 10, 1, device=self.device)
         with self.assertRaises(RuntimeError):
             res = posterior._draw_from_base_covar(sym_mat, base_samples)
+
+
+class TestPermuteSolve(BotorchTestCase):
+    def test_permute_solve_tensor(self):
+        # Random PSD matrix
+        a = torch.randn(32, 32, device=self.device, dtype=torch.float64)
+        A = torch.mm(a, a.t())
+
+        # Random batched column vector
+        b = torch.randn(4, 1, 32, 1, device=self.device, dtype=torch.float64)
+
+        # Compare results of permuted and standard solve
+        x_1 = _permute_solve(to_linear_operator(A), b)
+        x_2 = torch.linalg.solve(A, b)
+        self.assertAllClose(x_1, x_2)
+
+        # Ensure also works if b is not batched
+        x_1 = _permute_solve(to_linear_operator(A), b[0, 0, :, :])
+        x_2 = torch.linalg.solve(A, b[0, 0, :, :])
+        self.assertAllClose(x_1, x_2)