Skip to content

Commit

Permalink
Support for priors in OAK Kernel (pytorch#2535)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#2535

Added support for registering priors to the coefficients of the OrthogonalAdditiveKernel. Useful for incentivizing sparsity in the additive components, and improve identifiability between first- and second-order components.

Differential Revision: D61730632
  • Loading branch information
Carl Hvarfner authored and facebook-github-bot committed Sep 17, 2024
1 parent e9ce11f commit b1ef113
Show file tree
Hide file tree
Showing 2 changed files with 271 additions and 2 deletions.
87 changes: 86 additions & 1 deletion botorch/models/kernels/orthogonal_additive_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,16 @@
from botorch.exceptions.errors import UnsupportedError
from gpytorch.constraints import Interval, Positive
from gpytorch.kernels import Kernel
from gpytorch.module import Module
from gpytorch.priors import Prior

from torch import nn, Tensor

_positivity_constraint = Positive()
SECOND_ORDER_PRIOR_ERROR_MSG = (
"Second order is disabled, but there is a prior on the second order coefficients. "
"Please remove the second order prior or enable second order terms."
)


class OrthogonalAdditiveKernel(Kernel):
Expand All @@ -40,6 +46,9 @@ def __init__(
dtype: Optional[torch.dtype] = None,
device: Optional[torch.device] = None,
coeff_constraint: Interval = _positivity_constraint,
offset_prior: Optional[Prior] = None,
coeffs_1_prior: Optional[Prior] = None,
coeffs_2_prior: Optional[Prior] = None,
):
"""
Args:
Expand All @@ -55,6 +64,9 @@ def __init__(
"""
super().__init__(batch_shape=batch_shape)
self.base_kernel = base_kernel
if not second_order and coeffs_2_prior is not None:
raise AttributeError(SECOND_ORDER_PRIOR_ERROR_MSG)

# integration nodes, weights for [0, 1]
tkwargs = {"dtype": dtype, "device": device}
z, w = leggauss(deg=quad_deg, a=0, b=1, **tkwargs)
Expand Down Expand Up @@ -82,6 +94,29 @@ def __init__(
else None
),
)
if offset_prior is not None:
self.register_prior(
name="offset_prior",
prior=offset_prior,
param_or_closure=self._offset_param,
setting_closure=self._offset_closure,
)
if coeffs_1_prior is not None:
self.register_prior(
name="coeffs_1_prior",
prior=coeffs_1_prior,
param_or_closure=self._coeffs_1_param,
setting_closure=self._coeffs_1_closure,
)
if coeffs_2_prior is not None:
self.register_prior(
name="coeffs_2_prior",
prior=coeffs_2_prior,
param_or_closure=self._coeffs_2_param,
setting_closure=self._coeffs_2_closure,
)

# for second order interactions, we only
if second_order:
self._rev_triu_indices = torch.tensor(
_reverse_triu_indices(dim),
Expand All @@ -95,7 +130,7 @@ def __init__(
self.coeff_constraint = coeff_constraint
self.dim = dim

def k(self, x1, x2) -> Tensor:
def k(self, x1: Tensor, x2: Tensor) -> Tensor:
"""Evaluates the kernel matrix base_kernel(x1, x2) on each input dimension
independently.
Expand Down Expand Up @@ -140,6 +175,56 @@ def coeffs_2(self) -> Optional[Tensor]:
else:
return None

def _coeffs_1_param(self, m: Module):
return m.coeffs_1

def _coeffs_2_param(self, m: Module):
return m.coeffs_2

def _offset_param(self, m: Module):
return m.offset

def _coeffs_1_closure(self, m: Module, v: Tensor):
return m._set_coeffs_1(v)

def _coeffs_2_closure(self, m: Module, v: Tensor):
return m._set_coeffs_2(v)

def _offset_closure(self, m: Module, v: Tensor):
return m._set_offset(v)

def _set_coeffs_1(self, value: Tensor):
if not torch.is_tensor(value):
value = torch.as_tensor(value).to(self.raw_coeffs_1)
value = value.expand(*self.batch_shape, self.dim)
self.initialize(raw_coeffs_1=self.coeff_constraint.inverse_transform(value))

def _set_coeffs_2(self, value: Tensor):
if not torch.is_tensor(value):
value = torch.as_tensor(value).to(self.raw_coeffs_1)

value = value.expand(*self.batch_shape, self.dim, self.dim)
triu_indices = torch.triu_indices(self.dim, self.dim, offset=1)
value = value[..., triu_indices[0], triu_indices[1]].to(self.raw_coeffs_2)
self.initialize(raw_coeffs_2=self.coeff_constraint.inverse_transform(value))

def _set_offset(self, value: Tensor):
if not torch.is_tensor(value):
value = torch.as_tensor(value).to(self.raw_offset)
self.initialize(raw_offset=self.coeff_constraint.inverse_transform(value))

@coeffs_1.setter
def coeffs_1(self, value):
self._set_coeffs_1(value)

@coeffs_2.setter
def coeffs_2(self, value):
self._set_coeffs_2(value)

@offset.setter # offsetter, lol
def offset(self, value):
self._set_offset(value)

def forward(
self,
x1: Tensor,
Expand Down
186 changes: 185 additions & 1 deletion test/models/kernels/test_orthogonal_additive_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,19 @@

import torch
from botorch.exceptions.errors import UnsupportedError
from botorch.models.kernels.orthogonal_additive_kernel import OrthogonalAdditiveKernel
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from botorch.models.kernels.orthogonal_additive_kernel import (
OrthogonalAdditiveKernel,
SECOND_ORDER_PRIOR_ERROR_MSG,
)
from botorch.utils.testing import BotorchTestCase
from gpytorch.constraints import Positive
from gpytorch.kernels import MaternKernel, RBFKernel
from gpytorch.lazy import LazyEvaluatedKernelTensor
from gpytorch.mlls.exact_marginal_log_likelihood import ExactMarginalLogLikelihood
from gpytorch.priors import LogNormalPrior
from gpytorch.priors.torch_priors import GammaPrior, HalfCauchyPrior, UniformPrior
from torch import nn, Tensor


Expand Down Expand Up @@ -118,6 +127,181 @@ def test_kernel(self):
tol = 1e-5
self.assertTrue(((K_ortho @ oak.w).squeeze(-1) < tol).all())

def test_priors(self):
d = 5
dtypes = [torch.float, torch.double]
batch_shapes = [(), (2,), (7, 2)]

# test no prior
oak = OrthogonalAdditiveKernel(
RBFKernel(), dim=d, batch_shape=None, second_order=True
)
for dtype in dtypes:
for batch_shape in batch_shapes:
# test with default args and batch_shape = None in second_order
tkwargs = {"dtype": dtype, "device": self.device}
offset_prior = HalfCauchyPrior(0.1).to(**tkwargs)
coeffs_1_prior = LogNormalPrior(0, 1).to(**tkwargs)
coeffs_2_prior = GammaPrior(3, 6).to(**tkwargs)
oak = OrthogonalAdditiveKernel(
RBFKernel(),
dim=d,
second_order=True,
offset_prior=offset_prior,
coeffs_1_prior=coeffs_1_prior,
coeffs_2_prior=coeffs_2_prior,
batch_shape=batch_shape,
**tkwargs,
)

self.assertIsInstance(oak.offset_prior, HalfCauchyPrior)
self.assertIsInstance(oak.coeffs_1_prior, LogNormalPrior)
self.assertEqual(oak.coeffs_1_prior.scale, 1)
self.assertEqual(oak.coeffs_2_prior.concentration, 3)

oak = OrthogonalAdditiveKernel(
RBFKernel(),
dim=d,
second_order=True,
coeffs_1_prior=None,
coeffs_2_prior=coeffs_2_prior,
batch_shape=batch_shape,
**tkwargs,
)
self.assertEqual(oak.coeffs_2_prior.concentration, 3)
with self.assertRaisesRegex(
AttributeError,
"'OrthogonalAdditiveKernel' object has no attribute "
"'coeffs_1_prior",
):
_ = oak.coeffs_1_prior
# test with batch_shape = None in second_order
oak = OrthogonalAdditiveKernel(
RBFKernel(),
dim=d,
second_order=True,
coeffs_1_prior=coeffs_1_prior,
batch_shape=batch_shape,
**tkwargs,
)
with self.assertRaisesRegex(AttributeError, SECOND_ORDER_PRIOR_ERROR_MSG):
OrthogonalAdditiveKernel(
RBFKernel(),
dim=d,
batch_shape=None,
second_order=False,
coeffs_2_prior=GammaPrior(1, 1),
)

# train the model to ensure that param setters are called
train_X = torch.rand(5, d, dtype=dtype, device=self.device)
train_Y = torch.randn(5, 1, dtype=dtype, device=self.device)

oak = OrthogonalAdditiveKernel(
RBFKernel(),
dim=d,
batch_shape=None,
second_order=True,
offset_prior=offset_prior,
coeffs_1_prior=coeffs_1_prior,
coeffs_2_prior=coeffs_2_prior,
**tkwargs,
)
model = SingleTaskGP(train_X=train_X, train_Y=train_Y, covar_module=oak)
mll = ExactMarginalLogLikelihood(model.likelihood, model)
fit_gpytorch_mll(mll, optimizer_kwargs={"options": {"maxiter": 2}})

unif_prior = UniformPrior(10, 11)
# coeff_constraint is not enforced so that we can check the raw parameter
# values and not the reshaped (triu transformed) ones
oak_for_sample = OrthogonalAdditiveKernel(
RBFKernel(),
dim=d,
batch_shape=None,
second_order=True,
offset_prior=unif_prior,
coeffs_1_prior=unif_prior,
coeffs_2_prior=unif_prior,
coeff_constraint=Positive(transform=None, inv_transform=None),
**tkwargs,
)
oak_for_sample.sample_from_prior("offset_prior")
oak_for_sample.sample_from_prior("coeffs_1_prior")
oak_for_sample.sample_from_prior("coeffs_2_prior")

# check that all sampled values are within the bounds set by the priors
self.assertTrue(torch.all(10 <= oak_for_sample.raw_offset <= 11))
self.assertTrue(
torch.all(
(10 <= oak_for_sample.raw_coeffs_1)
* (oak_for_sample.raw_coeffs_1 <= 11)
)
)
self.assertTrue(
torch.all(
(10 <= oak_for_sample.raw_coeffs_2)
* (oak_for_sample.raw_coeffs_2 <= 11)
)
)

def test_set_coeffs(self):
d = 5
dtype = torch.double
oak = OrthogonalAdditiveKernel(
RBFKernel(),
dim=d,
batch_shape=None,
second_order=True,
dtype=dtype,
)
constraint = oak.coeff_constraint
coeffs_1 = torch.arange(d).to(dtype)
coeffs_2 = torch.ones((d * d)).reshape(d, d).triu().to(dtype)
oak.coeffs_1 = coeffs_1
oak.coeffs_2 = coeffs_2

self.assertAllClose(
oak.raw_coeffs_1,
constraint.inverse_transform(coeffs_1),
)
# raw_coeffs_2 has length d * (d-1) / 2
self.assertAllClose(
oak.raw_coeffs_2, constraint.inverse_transform(torch.ones(10).to(dtype))
)

batch_shapes = torch.Size([2]), torch.Size([5, 2])
for batch_shape in batch_shapes:
dtype = torch.double
oak = OrthogonalAdditiveKernel(
RBFKernel(),
dim=d,
batch_shape=batch_shape,
second_order=True,
dtype=dtype,
)
constraint = oak.coeff_constraint
coeffs_1 = torch.arange(d).to(dtype)
coeffs_2 = torch.ones((d * d)).reshape(d, d).triu().to(dtype)
oak.coeffs_1 = coeffs_1
oak.coeffs_2 = coeffs_2

self.assertEqual(oak.raw_coeffs_1.shape, batch_shape + torch.Size([5]))
# raw_coeffs_2 has length d * (d-1) / 2
self.assertEqual(oak.raw_coeffs_2.shape, batch_shape + torch.Size([10]))

# test setting value as float
oak.offset = 0.5
self.assertAllClose(oak.offset, 0.5 * torch.ones_like(oak.offset))
# raw_coeffs_2 has length d * (d-1) / 2
oak.coeffs_1 = 0.2
self.assertAllClose(oak.coeffs_1, 0.2 * torch.ones_like(oak.raw_coeffs_1))
with self.assertRaisesRegex(
ValueError,
"Second order coefficients must be provided as an "
"upper-triangular matrix.",
):
oak.coeffs_2 = 0.3


def isposdef(A: Tensor) -> bool:
"""Determines whether A is positive definite or not, by attempting a Cholesky
Expand Down

0 comments on commit b1ef113

Please sign in to comment.