Add unit test, and correct bugs found
slishak-PX committed Oct 1, 2024
1 parent 68ba15d commit f3baebc
Showing 2 changed files with 31 additions and 10 deletions.
17 changes: 9 additions & 8 deletions botorch/posteriors/multitask.py
@@ -228,7 +228,9 @@ def rsample_from_base_samples(
             train_diff.reshape(*train_diff.shape[:-2], -1) - updated_obs_samples
         )
         train_covar_plus_noise = self.train_train_covar + self.train_noise
-        obs_solve = _permute_solve(train_covar_plus_noise, obs_minus_samples)
+        obs_solve = _permute_solve(
+            train_covar_plus_noise, obs_minus_samples.unsqueeze(-1)
+        )

         # and multiply the test-observed matrix against the result of the solve
         updated_samples = self.test_train_covar.matmul(obs_solve).squeeze(-1)
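Note on the call-site change above: obs_minus_samples has shape (..., n) after the reshape and subtraction, while _permute_solve documents its right-hand side as a column vector of shape (..., n, 1). Previously the caller passed the squeezed tensor and _permute_solve patched the result with an unsqueeze(-1); the change moves the unsqueeze to the caller so the helper matches its documented contract. A minimal shape sketch (hypothetical sizes, not part of the commit):

    import torch

    obs_minus_samples = torch.randn(4, 15)  # (batch, n), sizes chosen arbitrarily
    rhs = obs_minus_samples.unsqueeze(-1)   # (batch, n, 1): the documented solve input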
@@ -290,7 +292,7 @@ def _draw_from_base_covar(
         return res.squeeze(-1)


-def _permute_solve(A: LinearOperator, b: LinearOperator) -> LinearOperator:
+def _permute_solve(A: LinearOperator, b: Tensor) -> LinearOperator:
     r"""Solve the batched linear system AX = b, where b is a batched column
     vector. The solve is carried out after permuting the largest batch
     dimension of b to the final position, which results in a more efficient
@@ -302,24 +304,23 @@ def _permute_solve(A: LinearOperator, b: LinearOperator) -> LinearOperator:

     Args:
         A: LinearOperator of shape (n, n)
-        b: LinearOperator of shape (..., n, 1)
+        b: Tensor of shape (..., n, 1)

     Returns:
         LinearOperator of shape (..., n, 1)
     """
     # permute dimensions to move largest batch dimension to the end (more efficient
     # than unsqueezing)
-    largest_batch_dim, _ = max(enumerate(b.shape[:-1]), key=lambda t: t[0])
     perm = list(range(b.ndim))
-    perm.remove(largest_batch_dim)
-    perm.append(largest_batch_dim)
+    if b.ndim > 2:
+        largest_batch_dim, _ = max(enumerate(b.shape[:-2]), key=lambda t: t[1])
+        perm[-1], perm[largest_batch_dim] = perm[largest_batch_dim], perm[-1]
     b_p = b.permute(*perm)

     # solve
     x_p = A.solve(b_p)

     # Undo permutation
     inverse_perm = torch.argsort(torch.tensor(perm))
-    x = x_p.permute(*inverse_perm).unsqueeze(-1)
+    x = x_p.permute(*inverse_perm)

     return x
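The docstring states the permutation trick only briefly; the sketch below (assumed shapes, plain torch instead of linear_operator, not part of the commit) illustrates why it helps. With b of shape (batch, n, 1), a naive batched solve broadcasts A across the batch and solves batch separate (n, 1) systems; swapping the largest batch dimension into the trailing position turns the same data into a single (n, batch) multi-right-hand-side solve.

    import torch

    n, batch = 32, 4096  # assumed sizes
    a = torch.randn(n, n, dtype=torch.float64)
    A = a @ a.t()  # random PSD matrix, as in the unit test below

    b = torch.randn(batch, n, 1, dtype=torch.float64)

    # Naive: A is broadcast and batch separate (n, 1) systems are solved
    x_naive = torch.linalg.solve(A, b)

    # Permuted: swap the batch dim (dim 0) with the trailing singleton, solve
    # once with batch right-hand sides, then undo the permutation
    perm = [2, 1, 0]
    inverse_perm = torch.argsort(torch.tensor(perm)).tolist()
    x_perm = torch.linalg.solve(A, b.permute(*perm)).permute(*inverse_perm)

    assert torch.allclose(x_naive, x_perm)

The two results agree; only the layout of the solve changes, which is where the efficiency gain described in the docstring comes from.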
24 changes: 22 additions & 2 deletions test/posteriors/test_multitask.py
@@ -8,9 +8,10 @@
 import torch
 from botorch.exceptions.errors import BotorchTensorDimensionError
 from botorch.models.multitask import KroneckerMultiTaskGP
-from botorch.posteriors.multitask import MultitaskGPPosterior
+from botorch.posteriors.multitask import MultitaskGPPosterior, _permute_solve
 from botorch.sampling.normal import IIDNormalSampler
 from botorch.utils.testing import BotorchTestCase
+from linear_operator.operators import to_linear_operator


 def get_posterior_test_cases(
@@ -41,7 +42,6 @@ def get_posterior_test_cases(


 class TestMultitaskGPPosterior(BotorchTestCase):
-
     def _test_MultitaskGPPosterior(self, dtype: torch.dtype) -> None:
         post_list = get_posterior_test_cases(device=self.device, dtype=dtype)
         sample_shaping = torch.Size([5, 3])
@@ -189,3 +189,23 @@ def test_draw_from_base_covar(self):
         base_samples = torch.randn(4, 10, 1, device=self.device)
         with self.assertRaises(RuntimeError):
             res = posterior._draw_from_base_covar(sym_mat, base_samples)
+
+
+class TestPermuteSolve(BotorchTestCase):
+    def test_permute_solve_tensor(self):
+        # Random PSD matrix
+        a = torch.randn(32, 32, device=self.device, dtype=torch.float64)
+        A = torch.mm(a, a.t())
+
+        # Random batched column vector
+        b = torch.randn(4, 1, 32, 1, device=self.device, dtype=torch.float64)
+
+        # Compare results of permuted and standard solve
+        x_1 = _permute_solve(to_linear_operator(A), b)
+        x_2 = torch.linalg.solve(A, b)
+        self.assertAllClose(x_1, x_2)
+
+        # Ensure also works if b is not batched
+        x_1 = _permute_solve(to_linear_operator(A), b[0, 0, :, :])
+        x_2 = torch.linalg.solve(A, b[0, 0, :, :])
+        self.assertAllClose(x_1, x_2)
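The test's right-hand side b = torch.randn(4, 1, 32, 1, ...) is shaped so that the largest batch dimension is dim 0, which is exactly the case the old selection logic handled incorrectly: key=lambda t: t[0] maximized the index rather than the dimension size, and the slice b.shape[:-1] still included the n dimension. A minimal demonstration (hypothetical, not part of the test):

    b_shape = (4, 1, 32, 1)  # (batch1, batch2, n, 1)

    # Old logic: returns (2, 32), i.e. the n dimension, since t[0] is the index
    max(enumerate(b_shape[:-1]), key=lambda t: t[0])

    # Fixed logic: returns (0, 4), the genuinely largest batch dimension
    max(enumerate(b_shape[:-2]), key=lambda t: t[1])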
