Allow for dimension-specific variance in linear kernel #2593

Merged: 2 commits, Sep 27, 2024
22 changes: 9 additions & 13 deletions gpytorch/kernels/linear_kernel.py
@@ -1,6 +1,5 @@
 #!/usr/bin/env python3
 
-import warnings
 from typing import Optional, Union
 
 import torch
@@ -40,31 +39,28 @@ class LinearKernel(Kernel):
     \top} \mathbf v)`, where the base multiply :math:`\mathbf X \mathbf v`
     takes only :math:`\mathcal O(ND)` time and space.
 
+    :param ard_num_dims: Set this if you want a separate variance prior for each weight. (Default: `None`.)
     :param variance_prior: Prior over the variance parameter. (Default: `None`.)
     :param variance_constraint: Constraint to place on variance parameter. (Default: `Positive`.)
-    :param active_dims: List of data dimensions to operate on. `len(active_dims)` should equal `num_dimensions`.
+    :param active_dims: List of data dimensions to operate on.
     """

     def __init__(
         self,
-        num_dimensions: Optional[int] = None,
-        offset_prior: Optional[Prior] = None,
+        ard_num_dims: Optional[int] = None,
         variance_prior: Optional[Prior] = None,
         variance_constraint: Optional[Interval] = None,
         **kwargs,
     ):
         super(LinearKernel, self).__init__(**kwargs)
         if variance_constraint is None:
             variance_constraint = Positive()
 
-        if num_dimensions is not None:
-            # Remove after 1.0
-            warnings.warn("The `num_dimensions` argument is deprecated and no longer used.", DeprecationWarning)
-            self.register_parameter(name="offset", parameter=torch.nn.Parameter(torch.zeros(1, 1, num_dimensions)))
-        if offset_prior is not None:
-            # Remove after 1.0
-            warnings.warn("The `offset_prior` argument is deprecated and no longer used.", DeprecationWarning)
-        self.register_parameter(name="raw_variance", parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1, 1)))
+        self.register_parameter(
+            name="raw_variance",
+            parameter=torch.nn.Parameter(
+                torch.zeros(*self.batch_shape, 1, 1 if ard_num_dims is None else ard_num_dims)
+            ),
+        )
         if variance_prior is not None:
             if not isinstance(variance_prior, Prior):
                 raise TypeError("Expected gpytorch.priors.Prior but got " + type(variance_prior).__name__)
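For readers skimming the diff, here is a minimal usage sketch of the new argument (not part of the PR; it assumes the kernel's usual bilinear form k(x, x') = x diag(v) x'^T, so each input dimension d gets its own scale v_d):

    import torch
    from gpytorch.kernels import LinearKernel

    # With ard_num_dims=2 the kernel keeps one variance per input
    # dimension: the variance parameter has shape (1, 2) rather than (1, 1).
    kernel = LinearKernel(ard_num_dims=2)
    kernel.eval()

    x = torch.tensor([[4.0, 1.0], [2.0, 0.0], [8.0, 3.0]])
    # Dense equivalent of the ARD linear kernel: k(x, x') = x diag(v) x'^T.
    expected = (x * kernel.variance) @ x.t()
    res = kernel(x, x).to_dense()
    assert torch.allclose(res, expected)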
25 changes: 21 additions & 4 deletions test/kernels/test_linear_kernel.py
@@ -10,14 +10,16 @@


 class TestLinearKernel(unittest.TestCase, BaseKernelTestCase):
+    kernel_kwargs = {}
+
     def create_kernel_no_ard(self, **kwargs):
-        return LinearKernel(**kwargs)
+        return LinearKernel(**kwargs, **self.kernel_kwargs)
 
     def test_computes_linear_function_rectangular(self):
         a = torch.tensor([4, 2, 8], dtype=torch.float).view(3, 1)
         b = torch.tensor([0, 2, 1], dtype=torch.float).view(3, 1)
 
-        kernel = LinearKernel().initialize(variance=1.0)
+        kernel = self.create_kernel_no_ard().initialize(variance=1.0)
         kernel.eval()
         actual = torch.matmul(a, b.t())
         res = kernel(a, b).to_dense()
@@ -31,7 +33,7 @@ def test_computes_linear_function_rectangular(self):
     def test_computes_linear_function_square(self):
         a = torch.tensor([[4, 1], [2, 0], [8, 3]], dtype=torch.float)
 
-        kernel = LinearKernel().initialize(variance=3.14)
+        kernel = self.create_kernel_no_ard().initialize(variance=3.14)
         kernel.eval()
         actual = torch.matmul(a, a.t()) * 3.14
         res = kernel(a, a).to_dense()
@@ -57,7 +59,7 @@ def test_computes_linear_function_square(self):
     def test_computes_linear_function_square_batch(self):
         a = torch.tensor([[[4, 1], [2, 0], [8, 3]], [[1, 1], [2, 1], [1, 3]]], dtype=torch.float)
 
-        kernel = LinearKernel().initialize(variance=1.0)
+        kernel = self.create_kernel_no_ard().initialize(variance=1.0)
         kernel.eval()
         actual = torch.matmul(a, a.transpose(-1, -2))
         res = kernel(a, a).to_dense()
@@ -92,5 +94,20 @@ def test_prior_type(self):
         self.assertRaises(TypeError, self.create_kernel_with_prior, 1)
 
 
+class TestLinearKernelARD(TestLinearKernel):
+    def test_kernel_ard(self) -> None:
+        self.kernel_kwargs = {"ard_num_dims": 2}
+        kernel = self.create_kernel_no_ard()
+        self.assertEqual(kernel.variance.shape, torch.Size([1, 2]))
+
+    def test_computes_linear_function_rectangular(self):
+        self.kernel_kwargs = {"ard_num_dims": 1}
+        super().test_computes_linear_function_rectangular()
+
+    def test_computes_linear_function_square_batch(self):
+        self.kernel_kwargs = {"ard_num_dims": 2}
+        super().test_computes_linear_function_square_batch()
+
+
 if __name__ == "__main__":
     unittest.main()
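A note on the test pattern: `TestLinearKernelARD` reuses the parent tests by swapping `ard_num_dims` into `kernel_kwargs` before delegating to `super()`, so the same numerical checks run with a per-dimension variance. The batch case it exercises can be reproduced by hand in this minimal sketch (not part of the diff, and under the same bilinear-form assumption as the sketch above):

    import torch
    from gpytorch.kernels import LinearKernel

    kernel = LinearKernel(ard_num_dims=2)
    kernel.eval()

    # Batched input, mirroring test_computes_linear_function_square_batch.
    a = torch.tensor(
        [[[4.0, 1.0], [2.0, 0.0], [8.0, 3.0]], [[1.0, 1.0], [2.0, 1.0], [1.0, 3.0]]]
    )
    res = kernel(a, a).to_dense()  # shape (2, 3, 3)
    # The (1, 2) variance broadcasts over both the batch and data dimensions.
    expected = (a * kernel.variance) @ a.transpose(-1, -2)
    assert torch.allclose(res, expected)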