Hotfix: fix .diagonal() calls for keops kernel matrices.
[Fixes #2589]
gpleiss committed Sep 20, 2024
1 parent 8825cdd commit 6421dce
Showing 5 changed files with 51 additions and 13 deletions.
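
As context for the diff below, here is a minimal usage sketch of the call pattern from #2589 that this hotfix addresses: requesting the diagonal of a KeOps kernel matrix, either via diag=True or via .diagonal() on the returned operator. The sketch is not part of the commit; the RBF choice, shapes, and variable names are illustrative, and pykeops is assumed to be installed.

    import torch
    from gpytorch.kernels.keops import RBFKernel

    x = torch.randn(100, 3)
    kernel = RBFKernel()

    # Full covariance: a lazily evaluated, KeOps-backed linear operator.
    K = kernel(x, x)

    # Diagonal only: both paths reach the kernel's forward with diag=True,
    # which previously broke for the KeOps kernels and now takes the
    # diagonal of the KernelLinearOperator instead.
    d1 = kernel(x, x, diag=True)
    d2 = K.diagonal(dim1=-1, dim2=-2)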
4 changes: 4 additions & 0 deletions gpytorch/kernels/keops/keops_kernel.py
@@ -12,6 +12,8 @@
     import pykeops # noqa F401
     from pykeops.torch import LazyTensor
 
+    _Anysor = Union[Tensor, LazyTensor]
+
     def _lazify_and_expand_inputs(
         x1: Tensor, x2: Tensor
     ) -> Tuple[Union[Tensor, LazyTensor], Union[Tensor, LazyTensor]]:
@@ -49,6 +51,8 @@ def __call__(self, *args: Any, **kwargs: Any) -> Union[LinearOperator, Tensor, L
 
 except ImportError:
 
+    _Anysor = Tensor
+
     def _lazify_and_expand_inputs(x1: Tensor, x2: Tensor) -> Tuple[Tensor, Tensor]:
         x1_ = x1[..., :, None, :]
         x2_ = x2[..., None, :, :]
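
For readability, a consolidated sketch of the pattern the two hunks above introduce (assembled here for illustration, not an exact copy of keops_kernel.py): _Anysor aliases Union[Tensor, LazyTensor] when pykeops imports and plain Tensor otherwise, so the annotations added to the kernel files below stay valid even without KeOps installed.

    from typing import Union

    from torch import Tensor

    try:
        from pykeops.torch import LazyTensor

        _Anysor = Union[Tensor, LazyTensor]
    except ImportError:
        _Anysor = Tensor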
18 changes: 13 additions & 5 deletions gpytorch/kernels/keops/matern_kernel.py
@@ -2,11 +2,12 @@
 import math
 
 from linear_operator.operators import KernelLinearOperator
+from torch import Tensor
 
-from .keops_kernel import _lazify_and_expand_inputs, KeOpsKernel
+from .keops_kernel import _Anysor, _lazify_and_expand_inputs, KeOpsKernel
 
 
-def _covar_func(x1, x2, nu=2.5, **params):
+def _covar_func(x1: _Anysor, x2: _Anysor, nu: float = 2.5, **params) -> _Anysor:
     x1_, x2_ = _lazify_and_expand_inputs(x1, x2)
 
     sq_distance = ((x1_ - x2_) ** 2).sum(-1)
@@ -57,15 +58,22 @@ class MaternKernel(KeOpsKernel):
 
     has_lengthscale = True
 
-    def __init__(self, nu=2.5, **kwargs):
+    def __init__(self, nu: float = 2.5, **kwargs):
         if nu not in {0.5, 1.5, 2.5}:
             raise RuntimeError("nu expected to be 0.5, 1.5, or 2.5")
         super().__init__(**kwargs)
         self.nu = nu
 
-    def forward(self, x1, x2, **kwargs):
+    def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **kwargs) -> KernelLinearOperator:
         mean = x1.reshape(-1, x1.size(-1)).mean(0)[(None,) * (x1.dim() - 1)]
         x1_ = (x1 - mean) / self.lengthscale
         x2_ = (x2 - mean) / self.lengthscale
         # return KernelLinearOperator inst only when calculating the whole covariance matrix
-        return KernelLinearOperator(x1_, x2_, covar_func=_covar_func, nu=self.nu, **kwargs)
+        res = KernelLinearOperator(x1_, x2_, covar_func=_covar_func, nu=self.nu, **kwargs)
+
+        # TODO: diag=True mode will be removed with the GpyTorch 2.0 PR to remove LazyEvaluatedKernelTensor
+        # (it will be replaced by a `_symmetric_diag` method for quickly computing the diagonals of symmetric matrices)
+        if diag:
+            return res.diagonal(dim1=-1, dim2=-2)
+        else:
+            return res
16 changes: 12 additions & 4 deletions gpytorch/kernels/keops/periodic_kernel.py
@@ -3,12 +3,13 @@
 import math
 
 from linear_operator.operators import KernelLinearOperator
+from torch import Tensor
 
 from ..periodic_kernel import PeriodicKernel as GPeriodicKernel
-from .keops_kernel import _lazify_and_expand_inputs, KeOpsKernel
+from .keops_kernel import _Anysor, _lazify_and_expand_inputs, KeOpsKernel
 
 
-def _covar_func(x1, x2, lengthscale, **kwargs):
+def _covar_func(x1: _Anysor, x2: _Anysor, lengthscale: Tensor, **kwargs) -> _Anysor:
     x1_, x2_ = _lazify_and_expand_inputs(x1, x2)
     lengthscale = lengthscale[..., None, None, 0, :] # 1 x 1 x ndim
     # do not use .power(2.0) as it gives NaN values on cuda
@@ -56,9 +57,16 @@ class PeriodicKernel(KeOpsKernel, GPeriodicKernel):
 
     has_lengthscale = True
 
-    def forward(self, x1, x2, **kwargs):
+    def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **kwargs) -> KernelLinearOperator:
         x1_ = x1.div(self.period_length / math.pi)
         x2_ = x2.div(self.period_length / math.pi)
         # return KernelLinearOperator inst only when calculating the whole covariance matrix
         # pass any parameters which are used inside _covar_func as *args to get gradients computed for them
-        return KernelLinearOperator(x1_, x2_, lengthscale=self.lengthscale, covar_func=_covar_func, **kwargs)
+        res = KernelLinearOperator(x1_, x2_, lengthscale=self.lengthscale, covar_func=_covar_func, **kwargs)
+
+        # TODO: diag=True mode will be removed with the GpyTorch 2.0 PR to remove LazyEvaluatedKernelTensor
+        # (it will be replaced by a `_symmetric_diag` method for quickly computing the diagonals of symmetric matrices)
+        if diag:
+            return res.diagonal(dim1=-1, dim2=-2)
+        else:
+            return res
16 changes: 12 additions & 4 deletions gpytorch/kernels/keops/rbf_kernel.py
@@ -2,11 +2,12 @@
 
 # from linear_operator.operators import KeOpsLinearOperator
 from linear_operator.operators import KernelLinearOperator
+from torch import Tensor
 
-from .keops_kernel import _lazify_and_expand_inputs, KeOpsKernel
+from .keops_kernel import _Anysor, _lazify_and_expand_inputs, KeOpsKernel
 
 
-def _covar_func(x1, x2, **kwargs):
+def _covar_func(x1: _Anysor, x2: _Anysor, **kwargs) -> _Anysor:
     x1_, x2_ = _lazify_and_expand_inputs(x1, x2)
     K = (-((x1_ - x2_) ** 2).sum(-1) / 2).exp()
     return K
@@ -40,8 +41,15 @@ class RBFKernel(KeOpsKernel):
 
     has_lengthscale = True
 
-    def forward(self, x1, x2, **kwargs):
+    def forward(self, x1: Tensor, x2: Tensor, diag: bool = False, **kwargs) -> KernelLinearOperator:
         x1_ = x1 / self.lengthscale
         x2_ = x2 / self.lengthscale
         # return KernelLinearOperator inst only when calculating the whole covariance matrix
-        return KernelLinearOperator(x1_, x2_, covar_func=_covar_func, **kwargs)
+        res = KernelLinearOperator(x1_, x2_, covar_func=_covar_func, **kwargs)
+
+        # TODO: diag=True mode will be removed with the GpyTorch 2.0 PR to remove LazyEvaluatedKernelTensor
+        # (it will be replaced by a `_symmetric_diag` method for quickly computing the diagonals of symmetric matrices)
+        if diag:
+            return res.diagonal(dim1=-1, dim2=-2)
+        else:
+            return res
10 changes: 10 additions & 0 deletions gpytorch/test/base_keops_test_case.py
@@ -42,6 +42,11 @@ def test_forward_x1_eq_x2(self, ard=False, use_keops=True, **kwargs):
         k2 = kern2(x1, x1).to_dense()
         self.assertLess(torch.norm(k1 - k2), 1e-4)
 
+        # Test diagonal
+        d1 = kern1(x1, x1).diagonal(dim1=-1, dim2=-2)
+        d2 = kern2(x1, x1).diagonal(dim1=-1, dim2=-2)
+        self.assertLess(torch.norm(d1 - d2), 1e-4)
+
         if use_keops:
             self.assertTrue(keops_mock.called)
 
@@ -68,6 +73,11 @@ def test_forward_x1_neq_x2(self, use_keops=True, ard=False, **kwargs):
         k2 = kern2(x1, x2).to_dense()
         self.assertLess(torch.norm(k1 - k2), 1e-3)
 
+        # Test diagonal
+        d1 = kern1(x1, x1).diagonal(dim1=-1, dim2=-2)
+        d2 = kern2(x1, x1).diagonal(dim1=-1, dim2=-2)
+        self.assertLess(torch.norm(d1 - d2), 1e-4)
+
         if use_keops:
             self.assertTrue(keops_mock.called)
