diff --git a/botorch/posteriors/multitask.py b/botorch/posteriors/multitask.py
index 0f5d9adce1..76a2df43d4 100644
--- a/botorch/posteriors/multitask.py
+++ b/botorch/posteriors/multitask.py
@@ -228,7 +228,9 @@ def rsample_from_base_samples(
             train_diff.reshape(*train_diff.shape[:-2], -1) - updated_obs_samples
         )
         train_covar_plus_noise = self.train_train_covar + self.train_noise
-        obs_solve = _permute_solve(train_covar_plus_noise, obs_minus_samples)
+        obs_solve = _permute_solve(
+            train_covar_plus_noise, obs_minus_samples.unsqueeze(-1)
+        )
 
         # and multiply the test-observed matrix against the result of the solve
         updated_samples = self.test_train_covar.matmul(obs_solve).squeeze(-1)
@@ -290,7 +292,7 @@ def _draw_from_base_covar(
         return res.squeeze(-1)
 
 
-def _permute_solve(A: LinearOperator, b: LinearOperator) -> LinearOperator:
+def _permute_solve(A: LinearOperator, b: Tensor) -> LinearOperator:
     r"""Solve the batched linear system AX = b, where b is a batched column
     vector. The solve is carried out after permuting the largest batch
     dimension of b to the final position, which results in a more efficient
@@ -302,24 +304,23 @@
 
     Args:
         A: LinearOperator of shape (n, n)
-        b: LinearOperator of shape (..., n, 1)
+        b: Tensor of shape (..., n, 1)
 
     Returns:
        LinearOperator of shape (..., n, 1)
    """
    # permute dimensions to move largest batch dimension to the end (more efficient
    # than unsqueezing)
-    largest_batch_dim, _ = max(enumerate(b.shape[:-1]), key=lambda t: t[0])
    perm = list(range(b.ndim))
-    perm.remove(largest_batch_dim)
-    perm.append(largest_batch_dim)
+    if b.ndim > 2:
+        largest_batch_dim, _ = max(enumerate(b.shape[:-2]), key=lambda t: t[1])
+        perm[-1], perm[largest_batch_dim] = perm[largest_batch_dim], perm[-1]
    b_p = b.permute(*perm)
 
-    # solve
    x_p = A.solve(b_p)
 
    # Undo permutation
    inverse_perm = torch.argsort(torch.tensor(perm))
-    x = x_p.permute(*inverse_perm).unsqueeze(-1)
+    x = x_p.permute(*inverse_perm)
 
    return x
diff --git a/test/posteriors/test_multitask.py b/test/posteriors/test_multitask.py
index 42913d3ef7..d5add4ca12 100644
--- a/test/posteriors/test_multitask.py
+++ b/test/posteriors/test_multitask.py
@@ -8,9 +8,10 @@
 import torch
 from botorch.exceptions.errors import BotorchTensorDimensionError
 from botorch.models.multitask import KroneckerMultiTaskGP
-from botorch.posteriors.multitask import MultitaskGPPosterior
+from botorch.posteriors.multitask import MultitaskGPPosterior, _permute_solve
 from botorch.sampling.normal import IIDNormalSampler
 from botorch.utils.testing import BotorchTestCase
+from linear_operator.operators import to_linear_operator
 
 
 def get_posterior_test_cases(
@@ -41,7 +42,6 @@ def get_posterior_test_cases(
 
 
 class TestMultitaskGPPosterior(BotorchTestCase):
-
     def _test_MultitaskGPPosterior(self, dtype: torch.dtype) -> None:
         post_list = get_posterior_test_cases(device=self.device, dtype=dtype)
         sample_shaping = torch.Size([5, 3])
@@ -189,3 +189,23 @@ def test_draw_from_base_covar(self):
         base_samples = torch.randn(4, 10, 1, device=self.device)
         with self.assertRaises(RuntimeError):
             res = posterior._draw_from_base_covar(sym_mat, base_samples)
+
+
+class TestPermuteSolve(BotorchTestCase):
+    def test_permute_solve_tensor(self):
+        # Random PSD matrix
+        a = torch.randn(32, 32, device=self.device, dtype=torch.float64)
+        A = torch.mm(a, a.t())
+
+        # Random batched column vector
+        b = torch.randn(4, 1, 32, 1, device=self.device, dtype=torch.float64)
+
+        # Compare results of permuted and standard solve
+        x_1 = _permute_solve(to_linear_operator(A), b)
+        x_2 = torch.linalg.solve(A, b)
+        self.assertAllClose(x_1, x_2)
+
+        # Ensure also works if b is not batched
+        x_1 = _permute_solve(to_linear_operator(A), b[0, 0, :, :])
+        x_2 = torch.linalg.solve(A, b[0, 0, :, :])
+        self.assertAllClose(x_1, x_2)
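
For reference, a minimal standalone sketch of the permutation trick that _permute_solve relies on, using a dense matrix and torch.linalg.solve in place of LinearOperator.solve; the shapes and the matrix below are illustrative assumptions, not taken from the patch:

    import torch

    # Well-conditioned symmetric positive-definite matrix (illustrative only).
    a = torch.randn(8, 8, dtype=torch.float64)
    A = a @ a.T + 8 * torch.eye(8, dtype=torch.float64)

    # Batched column vector whose largest batch dimension (size 64) comes first.
    b = torch.randn(64, 1, 8, 1, dtype=torch.float64)

    # Move the largest batch dimension to the last position, as _permute_solve
    # does, so the solve broadcasts over a single trailing dimension.
    perm = list(range(b.ndim))
    if b.ndim > 2:
        largest_batch_dim, _ = max(enumerate(b.shape[:-2]), key=lambda t: t[1])
        perm[-1], perm[largest_batch_dim] = perm[largest_batch_dim], perm[-1]
    b_p = b.permute(*perm)

    # Dense solve stands in for LinearOperator.solve in this sketch.
    x_p = torch.linalg.solve(A, b_p)

    # Undo the permutation; the result matches a direct batched solve.
    inverse_perm = torch.argsort(torch.tensor(perm)).tolist()
    x = x_p.permute(*inverse_perm)
    assert torch.allclose(x, torch.linalg.solve(A, b))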