From 422cebe015c386d122c0145b159dcfac8ee4e953 Mon Sep 17 00:00:00 2001 From: Carl Hvarfner Date: Fri, 13 Sep 2024 12:45:29 -0700 Subject: [PATCH] Support for priors in OAK Kernel (#2535) Summary: Pull Request resolved: https://github.com/pytorch/botorch/pull/2535 Added support for registering priors on the coefficients of the OrthogonalAdditiveKernel. Useful for incentivizing sparsity in the additive components and for improving identifiability between first- and second-order components. Differential Revision: D61730632 --- .../kernels/orthogonal_additive_kernel.py | 84 +++++++++++++ .../test_orthogonal_additive_kernel.py | 113 +++++++++++++++++- 2 files changed, 196 insertions(+), 1 deletion(-) diff --git a/botorch/models/kernels/orthogonal_additive_kernel.py b/botorch/models/kernels/orthogonal_additive_kernel.py index 37293c540e..73459f9f10 100644 --- a/botorch/models/kernels/orthogonal_additive_kernel.py +++ b/botorch/models/kernels/orthogonal_additive_kernel.py @@ -11,10 +11,15 @@ from botorch.exceptions.errors import UnsupportedError from gpytorch.constraints import Interval, Positive from gpytorch.kernels import Kernel +from gpytorch.priors import Prior from torch import nn, Tensor _positivity_constraint = Positive() +SECOND_ORDER_PRIOR_ERROR_MSG = ( + "Second order is disabled, but there is a prior on the second order coefficients. " + "Please remove the second order prior or enable second order terms." 
+) class OrthogonalAdditiveKernel(Kernel): @@ -40,6 +45,9 @@ def __init__( dtype: Optional[torch.dtype] = None, device: Optional[torch.device] = None, coeff_constraint: Interval = _positivity_constraint, + offset_prior: Optional[Prior] = None, + coeffs_1_prior: Optional[Prior] = None, + coeffs_2_prior: Optional[Prior] = None, ): """ Args: @@ -55,6 +63,9 @@ def __init__( """ super().__init__(batch_shape=batch_shape) self.base_kernel = base_kernel + if not second_order and coeffs_2_prior is not None: + raise AttributeError(SECOND_ORDER_PRIOR_ERROR_MSG) + # integration nodes, weights for [0, 1] tkwargs = {"dtype": dtype, "device": device} z, w = leggauss(deg=quad_deg, a=0, b=1, **tkwargs) @@ -82,6 +93,29 @@ def __init__( else None ), ) + if offset_prior is not None: + self.register_prior( + "offset_prior", + offset_prior, + self._offset_param, + self._offset_closure, + ) + if coeffs_1_prior is not None: + self.register_prior( + "coeffs_1_prior", + coeffs_1_prior, + self._coeffs_1_param, + self._coeffs_1_closure, + ) + if coeffs_2_prior is not None: + self.register_prior( + "coeffs_2_prior", + coeffs_2_prior, + self._coeffs_2_param, + self._coeffs_2_closure, + ) + + # for second order interactions, we only if second_order: self._rev_triu_indices = torch.tensor( _reverse_triu_indices(dim), @@ -140,6 +174,56 @@ def coeffs_2(self) -> Optional[Tensor]: else: return None + def _coeffs_1_param(self, m): + return m.coeffs_1 + + def _coeffs_2_param(self, m): + return m.coeffs_2 + + def _offset_param(self, m): + return m.offset + + def _coeffs_1_closure(self, m, v): + return m._set_coeffs_1(v) + + def _coeffs_2_closure(self, m, v): + return m._set_coeffs_2(v) + + def _offset_closure(self, m, v): + return m._set_offset(v) + + def _set_coeffs_1(self, value): + if not torch.is_tensor(value): + value = torch.as_tensor(value).to(self.raw_coeffs_1) + self.initialize(raw_coeffs_1=self.coeff_constraint.inverse_transform(value)) + + def _set_coeffs_2(self, value): + if not 
isinstance(value, Tensor) or value.tril().abs().sum() == 0: + raise ValueError( + "Second order coefficients must be provided as an " + "upper-triangular matrix." + ) + triu_indices = torch.triu_indices(self.dim, self.dim, offset=1) + value = value[..., triu_indices[0], triu_indices[1]].to(self.raw_coeffs_2) + self.initialize(raw_coeffs_2=self.coeff_constraint.inverse_transform(value)) + + def _set_offset(self, value): + if not torch.is_tensor(value): + value = torch.as_tensor(value).to(self.raw_offset) + self.initialize(raw_offset=self.coeff_constraint.inverse_transform(value)) + + @coeffs_1.setter + def coeffs_1(self, value): + self._set_coeffs_1(value) + + @coeffs_2.setter + def coeffs_2(self, value): + self._set_coeffs_2(value) + + @offset.setter # offsetter, lol + def offset(self, value): + self._set_offset(value) + def forward( self, x1: Tensor, diff --git a/test/models/kernels/test_orthogonal_additive_kernel.py b/test/models/kernels/test_orthogonal_additive_kernel.py index 7f378b0034..76a21edc6b 100644 --- a/test/models/kernels/test_orthogonal_additive_kernel.py +++ b/test/models/kernels/test_orthogonal_additive_kernel.py @@ -6,10 +6,15 @@ import torch from botorch.exceptions.errors import UnsupportedError -from botorch.models.kernels.orthogonal_additive_kernel import OrthogonalAdditiveKernel +from botorch.models.kernels.orthogonal_additive_kernel import ( + OrthogonalAdditiveKernel, + SECOND_ORDER_PRIOR_ERROR_MSG, +) from botorch.utils.testing import BotorchTestCase from gpytorch.kernels import MaternKernel, RBFKernel from gpytorch.lazy import LazyEvaluatedKernelTensor +from gpytorch.priors import LogNormalPrior +from gpytorch.priors.torch_priors import GammaPrior, HalfCauchyPrior from torch import nn, Tensor @@ -118,6 +123,112 @@ def test_kernel(self): tol = 1e-5 self.assertTrue(((K_ortho @ oak.w).squeeze(-1) < tol).all()) + def test_priors(self): + d = 5 + dtypes = [torch.float, torch.double] + batch_shapes = [(), (2,), (7, 2)] + for dtype in dtypes: + 
for batch_shape in batch_shapes: + # test with default args and batch_shape = None in second_order + offset_prior = HalfCauchyPrior(0.1) + coeffs_1_prior = LogNormalPrior(0, 1) + coeffs_2_prior = GammaPrior(3, 6) + oak = OrthogonalAdditiveKernel( + RBFKernel(), + dim=d, + batch_shape=None, + second_order=True, + offset_prior=offset_prior, + coeffs_1_prior=coeffs_1_prior, + coeffs_2_prior=coeffs_2_prior, + ) + + self.assertIsInstance(oak.offset_prior, HalfCauchyPrior) + self.assertIsInstance(oak.coeffs_1_prior, LogNormalPrior) + self.assertEqual(oak.coeffs_1_prior.scale, 1) + self.assertEqual(oak.coeffs_2_prior.concentration, 3) + + oak = OrthogonalAdditiveKernel( + RBFKernel(), + dim=d, + batch_shape=None, + second_order=True, + coeffs_1_prior=None, + coeffs_2_prior=coeffs_2_prior, + ) + self.assertEqual(oak.coeffs_2_prior.concentration, 3) + with self.assertRaisesRegex( + AttributeError, + "'OrthogonalAdditiveKernel' object has no attribute " + "'coeffs_1_prior", + ): + _ = oak.coeffs_1_prior + # test with batch_shape = None in second_order + oak = OrthogonalAdditiveKernel( + RBFKernel(), + dim=d, + batch_shape=batch_shape, + second_order=True, + coeffs_1_prior=coeffs_1_prior, + ) + with self.assertRaisesRegex(AttributeError, SECOND_ORDER_PRIOR_ERROR_MSG): + OrthogonalAdditiveKernel( + RBFKernel(), + dim=d, + batch_shape=None, + second_order=False, + coeffs_2_prior=GammaPrior(1, 1), + ) + # test no prior + oak = OrthogonalAdditiveKernel( + RBFKernel(), dim=d, batch_shape=None, second_order=True + ) + + def test_set_coeffs(self): + d = 5 + dtype = torch.double + oak = OrthogonalAdditiveKernel( + RBFKernel(), + dim=d, + batch_shape=None, + second_order=True, + dtype=dtype, + ) + constraint = oak.coeff_constraint + coeffs_1 = torch.arange(d).to(dtype) + coeffs_2 = torch.ones((d * d)).reshape(d, d).triu().to(dtype) + oak.coeffs_1 = coeffs_1 + oak.coeffs_2 = coeffs_2 + + self.assertAllClose( + oak.raw_coeffs_1, + constraint.inverse_transform(coeffs_1), + ) + # 
raw_coeffs_2 has length d * (d-1) / 2 + self.assertAllClose( + oak.raw_coeffs_2, constraint.inverse_transform(torch.ones(10).to(dtype)) + ) + + batch_shapes = torch.Size([2]), torch.Size([5, 2]) + for batch_shape in batch_shapes: + dtype = torch.double + oak = OrthogonalAdditiveKernel( + RBFKernel(), + dim=d, + batch_shape=batch_shape, + second_order=True, + dtype=dtype, + ) + constraint = oak.coeff_constraint + coeffs_1 = torch.arange(d).to(dtype) + coeffs_2 = torch.ones((d * d)).reshape(d, d).triu().to(dtype) + oak.coeffs_1 = coeffs_1 + oak.coeffs_2 = coeffs_2 + + self.assertEqual(oak.raw_coeffs_1.shape, batch_shape + torch.Size([5])) + # raw_coeffs_2 has length d * (d-1) / 2 + self.assertEqual(oak.raw_coeffs_2.shape, batch_shape + torch.Size([10])) + def isposdef(A: Tensor) -> bool: """Determines whether A is positive definite or not, by attempting a Cholesky