Skip to content

Commit

Permalink
implement pairwisekernel (#385)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #385

This implements the pairwise kernel of Houlsby et al. (2011), which wraps any latent kernel so that a model built on it becomes a pairwise one.

Reviewed By: JasonKChow

Differential Revision: D63402856

fbshipit-source-id: 1dd8f80e425abd3cc4b8b4614ed6f52ace2ba1b9
  • Loading branch information
crasanders authored and facebook-github-bot committed Sep 25, 2024
1 parent be2db33 commit 017e168
Show file tree
Hide file tree
Showing 4 changed files with 242 additions and 1 deletion.
5 changes: 5 additions & 0 deletions aepsych/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,8 @@

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from .pairwisekernel import PairwiseKernel
from .rbf_partial_grad import RBFKernelPartialObsGrad

__all__ = ["PairwiseKernel", "RBFKernelPartialObsGrad"]
85 changes: 85 additions & 0 deletions aepsych/kernels/pairwisekernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import torch
from gpytorch.kernels import Kernel
from linear_operator import to_linear_operator


class PairwiseKernel(Kernel):
    """
    Wrapper to convert a kernel K on R^k to a kernel K' on R^{2k}, modeling
    functions of the form g(a, b) = f(a) - f(b), where f ~ GP(mu, K).
    Since g is a linear combination of Gaussians, it follows that g ~ GP(0, K')
    where K'((a,b), (c,d)) = K(a,c) - K(a, d) - K(b, c) + K(b, d).
    """

    def __init__(self, latent_kernel, is_partial_obs=False, **kwargs):
        """
        Args:
            latent_kernel: kernel K evaluated on each half of the paired input.
            is_partial_obs (bool): if True, the last input column is not part
                of the pair; it is split off and appended to both halves
                before they are handed to ``latent_kernel``.
            **kwargs: forwarded unchanged to the ``Kernel`` base constructor.
        """
        super().__init__(**kwargs)

        self.latent_kernel = latent_kernel
        self.is_partial_obs = is_partial_obs

    def forward(self, x1, x2, diag=False, **params):
        r"""
        TODO: make last_batch_dim work properly
        d must be 2*k for integer k, k is the dimension of the latent space
        (when ``is_partial_obs`` is set, d excludes the extra trailing column).
        Args:
            :attr:`x1` (Tensor `n x d` or `b x n x d`):
                First set of data
            :attr:`x2` (Tensor `m x d` or `b x m x d`):
                Second set of data
            :attr:`diag` (bool):
                Should the Kernel compute the whole kernel, or just the diag?
            :attr:`params`:
                Extra keyword arguments passed through to the latent kernel.
        Returns:
            For ``diag=False``, the sum of the four latent-kernel terms, each
            wrapped with ``to_linear_operator``; for ``diag=True``, the plain
            sum of the four latent-kernel diagonals.
            The exact size depends on the kernel's evaluation mode:
            * `full_covar`: `n x m` or `b x n x m`
            * `diag`: `n` or `b x n`
        Raises:
            AssertionError: if x1 and x2 disagree on the last dimension, or
                if the paired part of that dimension is not even.
        """
        # In partial-obs mode one trailing column is metadata, not pair data.
        n_extra = 1 if self.is_partial_obs else 0
        width = x1.shape[-1] - n_extra

        # Raised explicitly instead of via ``assert`` so the checks are not
        # stripped under ``python -O``; the exception type is unchanged.
        if width != x2.shape[-1] - n_extra:
            raise AssertionError("tensors not the same dimension")
        if width % 2 != 0:
            raise AssertionError("dimension must be even")

        k = width // 2

        if self.is_partial_obs:
            # special handling for kernels that (also) do funky
            # things with the input dimension: keep the trailing column
            # (presumably a derivative index, going by the original naming)
            # attached to each half. Slicing with ``-1:`` and concatenating
            # on the last axis keeps this correct for batched inputs too.
            deriv_idx_1 = x1[..., -1:]
            deriv_idx_2 = x2[..., -1:]

            a = torch.cat((x1[..., :k], deriv_idx_1), dim=-1)
            b = torch.cat((x1[..., k:-1], deriv_idx_1), dim=-1)
            c = torch.cat((x2[..., :k], deriv_idx_2), dim=-1)
            d = torch.cat((x2[..., k:-1], deriv_idx_2), dim=-1)
        else:
            a = x1[..., :k]
            b = x1[..., k:]
            c = x2[..., :k]
            d = x2[..., k:]

        latent = self.latent_kernel

        if diag:
            # Diagonal mode: combine the latent kernel outputs directly.
            return (
                latent(a, c, diag=diag, **params)
                + latent(b, d, diag=diag, **params)
                - latent(b, c, diag=diag, **params)
                - latent(a, d, diag=diag, **params)
            )

        # Full covariance: K(a,c) + K(b,d) - K(b,c) - K(a,d).
        return (
            to_linear_operator(latent(a, c, diag=diag, **params))
            + to_linear_operator(latent(b, d, diag=diag, **params))
            - to_linear_operator(latent(b, c, diag=diag, **params))
            - to_linear_operator(latent(a, d, diag=diag, **params))
        )
2 changes: 1 addition & 1 deletion aepsych/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
"GPClassificationModel",
"MonotonicRejectionGP",
"GPRegressionModel",
"PairwiseProbitModel",
"OrdinalGPModel",
"MonotonicProjectionGP",
"MultitaskGPRModel",
Expand All @@ -35,6 +34,7 @@
"SemiParametricGPModel",
"semi_p_posterior_transform",
"GPBetaRegressionModel",
"PairwiseProbitModel",
]

Config.register_module(sys.modules[__name__])
151 changes: 151 additions & 0 deletions tests/test_pairwise_kernel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
#!/usr/bin/env python3
import unittest

import numpy as np
import numpy.testing as npt
import torch
from aepsych.kernels.pairwisekernel import PairwiseKernel
from aepsych.kernels.rbf_partial_grad import RBFKernelPartialObsGrad
from gpytorch.kernels import RBFKernel


class PairwiseKernelTest(unittest.TestCase):
    """
    Sanity checks for PairwiseKernel: input validation, covariance values,
    and output shapes.
    """

    def setUp(self):
        # A plain RBF latent kernel wrapped by the kernel under test.
        self.latent_kernel = RBFKernel()
        self.kernel = PairwiseKernel(self.latent_kernel)

    @staticmethod
    def _reference_covar(latent, a, b, c, d):
        """Evaluate K(a,c) - K(a,d) - K(b,c) + K(b,d) by hand as a numpy array."""
        combined = (
            latent.forward(a, c)
            - latent.forward(a, d)
            - latent.forward(b, c)
            + latent.forward(b, d)
        )
        return combined.detach().numpy()

    def test_kernelgrad_pairwise(self):
        """
        The partial-obs variant must match the hand-computed pairwise sum.
        """
        kernel = PairwiseKernel(RBFKernelPartialObsGrad(), is_partial_obs=True)

        # Random pair data with an all-zero trailing column appended.
        x1 = torch.cat((torch.rand(2, 4), torch.zeros(2, 1)), dim=1)
        x2 = torch.cat((torch.rand(2, 4), torch.zeros(2, 1)), dim=1)

        tail_1 = x1[:, -1:]
        tail_2 = x2[:, -1:]

        a = torch.cat((x1[:, :2], tail_1), dim=1)
        b = torch.cat((x1[:, 2:-1], tail_1), dim=1)
        c = torch.cat((x2[:, :2], tail_2), dim=1)
        d = torch.cat((x2[:, 2:-1], tail_2), dim=1)

        actual = kernel.forward(x1, x2).evaluate().detach().numpy()
        expected = self._reference_covar(kernel.latent_kernel, a, b, c, d)
        npt.assert_allclose(actual, expected, atol=1e-6)

    def test_dim_check(self):
        """
        Test that we get expected errors.
        """
        odd_1 = torch.zeros(3)
        odd_2 = torch.zeros(3)
        short = torch.zeros(2)
        long = torch.zeros(4)

        # Odd input dimension is rejected.
        with self.assertRaises(AssertionError):
            self.kernel.forward(x1=odd_1, x2=odd_2)

        # Mismatched input dimensions are rejected.
        with self.assertRaises(AssertionError):
            self.kernel.forward(x1=short, x2=long)

    def test_covar(self):
        """
        Test that we get expected covariances
        """
        np.random.seed(1)
        torch.manual_seed(1)

        # (rows of first input, rows of second input); both inputs are 4-wide.
        for n_rows, m_rows in ((2, 2), (3, 6)):
            left = torch.rand(n_rows, 4)
            right = torch.rand(m_rows, 4)

            actual = self.kernel.forward(left, right).evaluate().detach().numpy()
            expected = self._reference_covar(
                self.latent_kernel,
                left[..., :2],
                left[..., 2:],
                right[..., :2],
                right[..., 2:],
            )
            npt.assert_allclose(actual, expected, atol=1e-6)

            npt.assert_equal(np.array(actual.shape), np.array([n_rows, m_rows]))

    def test_latent_diag(self):
        """
        g(a, a) = 0 for all a, so K((a, a), (a, a)) = 0
        """
        np.random.seed(1)
        torch.manual_seed(1)
        a = torch.rand(2, 2)

        # Pairs of the form (a, a) must have zero covariance everywhere.
        paired = torch.cat((a, a), dim=1)
        covar = self.kernel.forward(paired, paired).evaluate().detach().numpy()
        npt.assert_allclose(covar, 0.0)

    def test_diag(self):
        """
        make sure the diagonal is the right shape
        """
        np.random.seed(1)
        torch.manual_seed(1)

        # Leading shape of the inputs == expected shape of the diagonal;
        # covers a batched case first, then an unbatched one.
        for lead_shape in ((2, 2), (2,)):
            x1 = torch.rand(*lead_shape, 4)
            x2 = torch.rand(*lead_shape, 4)

            diag = self.kernel(x1, x2, diag=True)
            npt.assert_equal(np.array(diag.shape), np.array(lead_shape))


# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()

0 comments on commit 017e168

Please sign in to comment.