Commit f3baebc

Add unit test, and correct bugs found

1 parent 68ba15d
2 files changed: +31 -10 lines changed

botorch/posteriors/multitask.py

Lines changed: 9 additions & 8 deletions
@@ -228,7 +228,9 @@ def rsample_from_base_samples(
             train_diff.reshape(*train_diff.shape[:-2], -1) - updated_obs_samples
         )
         train_covar_plus_noise = self.train_train_covar + self.train_noise
-        obs_solve = _permute_solve(train_covar_plus_noise, obs_minus_samples)
+        obs_solve = _permute_solve(
+            train_covar_plus_noise, obs_minus_samples.unsqueeze(-1)
+        )
 
         # and multiply the test-observed matrix against the result of the solve
         updated_samples = self.test_train_covar.matmul(obs_solve).squeeze(-1)
@@ -290,7 +292,7 @@ def _draw_from_base_covar(
         return res.squeeze(-1)
 
 
-def _permute_solve(A: LinearOperator, b: LinearOperator) -> LinearOperator:
+def _permute_solve(A: LinearOperator, b: Tensor) -> LinearOperator:
     r"""Solve the batched linear system AX = b, where b is a batched column
     vector. The solve is carried out after permuting the largest batch
     dimension of b to the final position, which results in a more efficient
@@ -302,24 +304,23 @@ def _permute_solve(A: LinearOperator, b: LinearOperator) -> LinearOperator:
 
     Args:
         A: LinearOperator of shape (n, n)
-        b: LinearOperator of shape (..., n, 1)
+        b: Tensor of shape (..., n, 1)
 
     Returns:
         LinearOperator of shape (..., n, 1)
     """
     # permute dimensions to move largest batch dimension to the end (more efficient
     # than unsqueezing)
-    largest_batch_dim, _ = max(enumerate(b.shape[:-1]), key=lambda t: t[0])
     perm = list(range(b.ndim))
-    perm.remove(largest_batch_dim)
-    perm.append(largest_batch_dim)
+    if b.ndim > 2:
+        largest_batch_dim, _ = max(enumerate(b.shape[:-2]), key=lambda t: t[1])
+        perm[-1], perm[largest_batch_dim] = perm[largest_batch_dim], perm[-1]
     b_p = b.permute(*perm)
 
-    # solve
     x_p = A.solve(b_p)
 
     # Undo permutation
     inverse_perm = torch.argsort(torch.tensor(perm))
-    x = x_p.permute(*inverse_perm).unsqueeze(-1)
+    x = x_p.permute(*inverse_perm)
 
     return x
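For reference, assembling the hunks above gives the post-commit version of _permute_solve, sketched below. The surrounding imports (torch, Tensor, LinearOperator) are assumed to match the module's existing header; this is a reconstruction for readability, not additional changed code.

import torch
from torch import Tensor
from linear_operator.operators import LinearOperator


def _permute_solve(A: LinearOperator, b: Tensor) -> LinearOperator:
    r"""Solve the batched linear system AX = b, where b is a batched column
    vector, after permuting the largest batch dimension of b to the final
    position, which results in a more efficient A.solve call.
    """
    # Swap the largest batch dimension (by size) with the trailing singleton
    # column dimension; only needed when b actually has batch dimensions.
    perm = list(range(b.ndim))
    if b.ndim > 2:
        largest_batch_dim, _ = max(enumerate(b.shape[:-2]), key=lambda t: t[1])
        perm[-1], perm[largest_batch_dim] = perm[largest_batch_dim], perm[-1]
    b_p = b.permute(*perm)

    # Solve against the permuted right-hand side.
    x_p = A.solve(b_p)

    # Undo the permutation so the result has shape (..., n, 1) again.
    inverse_perm = torch.argsort(torch.tensor(perm))
    return x_p.permute(*inverse_perm)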

test/posteriors/test_multitask.py

Lines changed: 22 additions & 2 deletions
@@ -8,9 +8,10 @@
 import torch
 from botorch.exceptions.errors import BotorchTensorDimensionError
 from botorch.models.multitask import KroneckerMultiTaskGP
-from botorch.posteriors.multitask import MultitaskGPPosterior
+from botorch.posteriors.multitask import MultitaskGPPosterior, _permute_solve
 from botorch.sampling.normal import IIDNormalSampler
 from botorch.utils.testing import BotorchTestCase
+from linear_operator.operators import to_linear_operator
 
 
 def get_posterior_test_cases(
@@ -41,7 +42,6 @@ def get_posterior_test_cases(
 
 
 class TestMultitaskGPPosterior(BotorchTestCase):
-
     def _test_MultitaskGPPosterior(self, dtype: torch.dtype) -> None:
         post_list = get_posterior_test_cases(device=self.device, dtype=dtype)
         sample_shaping = torch.Size([5, 3])
@@ -189,3 +189,23 @@ def test_draw_from_base_covar(self):
         base_samples = torch.randn(4, 10, 1, device=self.device)
         with self.assertRaises(RuntimeError):
             res = posterior._draw_from_base_covar(sym_mat, base_samples)
+
+
+class TestPermuteSolve(BotorchTestCase):
+    def test_permute_solve_tensor(self):
+        # Random PSD matrix
+        a = torch.randn(32, 32, device=self.device, dtype=torch.float64)
+        A = torch.mm(a, a.t())
+
+        # Random batched column vector
+        b = torch.randn(4, 1, 32, 1, device=self.device, dtype=torch.float64)
+
+        # Compare results of permuted and standard solve
+        x_1 = _permute_solve(to_linear_operator(A), b)
+        x_2 = torch.linalg.solve(A, b)
+        self.assertAllClose(x_1, x_2)
+
+        # Ensure also works if b is not batched
+        x_1 = _permute_solve(to_linear_operator(A), b[0, 0, :, :])
+        x_2 = torch.linalg.solve(A, b[0, 0, :, :])
+        self.assertAllClose(x_1, x_2)
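As a quick illustration of what the corrected permutation does for the shapes used in the new test (a sketch, not part of the commit), note that the replaced line selected a dimension by index (key=lambda t: t[0]) over b.shape[:-1], whereas the fixed code selects the largest batch dimension by size over b.shape[:-2]:

import torch

b = torch.randn(4, 1, 32, 1)            # two batch dims, then the (n, 1) column

perm = list(range(b.ndim))               # [0, 1, 2, 3]
# Pick the largest batch dim by *size* (key=lambda t: t[1]) over b.shape[:-2],
# i.e. excluding the (n, 1) matrix dimensions: here dim 0 with size 4.
largest_batch_dim, _ = max(enumerate(b.shape[:-2]), key=lambda t: t[1])
perm[-1], perm[largest_batch_dim] = perm[largest_batch_dim], perm[-1]   # [3, 1, 2, 0]

b_p = b.permute(*perm)
print(b_p.shape)                         # torch.Size([1, 1, 32, 4])

# Undoing the permutation restores the original (..., n, 1) layout.
inverse_perm = torch.argsort(torch.tensor(perm))
print(b_p.permute(*inverse_perm.tolist()).shape)   # torch.Size([4, 1, 32, 1])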
