diff --git a/botorch/acquisition/penalized.py b/botorch/acquisition/penalized.py new file mode 100644 index 0000000000..dcc9cddb7b --- /dev/null +++ b/botorch/acquisition/penalized.py @@ -0,0 +1,172 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +r""" +Modules to add regularization to acquisition functions. +""" + +from __future__ import annotations + +import math +from typing import List, Optional + +import torch +from botorch.acquisition.acquisition import AcquisitionFunction +from botorch.acquisition.analytic import AnalyticAcquisitionFunction +from botorch.exceptions import UnsupportedError +from torch import Tensor + + +class L2Penalty(torch.nn.Module): + r"""L2 penalty class to be added to any arbitrary acquisition function.""" + + def __init__(self, init_point: Tensor): + r"""Initializing L2 regularization. + + Args: + init_point: The "1 x dim" reference point against which + we want to regularize. + """ + super().__init__() + self.init_point = init_point + + def forward(self, X: Tensor) -> Tensor: + r""" + Args: + X: A "batch_shape x q x dim" representing the points to be evaluated. + + Returns: + A tensor of size "batch_shape" representing the acqfn for each q-batch. + """ + regularization_term = ( + torch.norm((X - self.init_point), p=2, dim=-1).max(dim=-1).values ** 2 + ) + return regularization_term + + +class GaussianPenalty(torch.nn.Module): + r"""Gaussian penalty class to be added to any arbitrary acquisition function.""" + + def __init__(self, init_point: Tensor, sigma: float): + r"""Initializing Gaussian regularization. + + Args: + init_point: The "1 x dim" reference point against which + we want to regularize. + sigma: The parameter used in gaussian function. + """ + super().__init__() + self.init_point = init_point + self.sigma = sigma + + def forward(self, X: Tensor) -> Tensor: + r""" + Args: + X: A "batch_shape x q x dim" representing the points to be evaluated. + + Returns: + A tensor of size "batch_shape" representing the acqfn for each q-batch. + """ + sq_diff = torch.norm((X - self.init_point), p=2, dim=-1) ** 2 + pdf = torch.exp(sq_diff / 2 / self.sigma ** 2) + regularization_term = pdf.max(dim=-1).values + return regularization_term + + +class GroupLassoPenalty(torch.nn.Module): + r"""Group lasso penalty class to be added to any arbitrary acquisition function.""" + + def __init__(self, init_point: Tensor, groups: List[List[int]]): + r"""Initializing Group-Lasso regularization. + + Args: + init_point: The "1 x dim" reference point against which we want + to regularize. + groups: Groups of indices used in group lasso. + """ + super().__init__() + self.init_point = init_point + self.groups = groups + + def forward(self, X: Tensor) -> Tensor: + r""" + X should be batch_shape x 1 x dim tensor. Evaluation for q-batch is not + implemented yet. + """ + if X.shape[-2] != 1: + raise NotImplementedError( + "group-lasso has not been implemented for q>1 yet." + ) + + regularization_term = group_lasso_regularizer( + X=X.squeeze(-2) - self.init_point, groups=self.groups + ) + return regularization_term + + +class PenalizedAcquisitionFunction(AcquisitionFunction): + r"""Single-outcome acquisition function regularized by the given penalty. + + The usage is similar to: + raw_acqf = NoisyExpectedImprovement(...) + penalty = GroupLassoPenalty(...) + acqf = PenalizedAcquisitionFunction(raw_acqf, penalty) + """ + + def __init__( + self, + raw_acqf: AcquisitionFunction, + penalty_func: torch.nn.Module, + regularization_parameter: float, + ) -> None: + r"""Initializing Group-Lasso regularization. + + Args: + raw_acqf: The raw acquisition function that is going to be regularized. + penalty_func: The regularization function. + regularization_parameter: Regularization parameter used in optimization. + """ + super().__init__(model=raw_acqf.model) + self.raw_acqf = raw_acqf + self.penalty_func = penalty_func + self.regularization_parameter = regularization_parameter + + def forward(self, X: Tensor) -> Tensor: + raw_value = self.raw_acqf(X=X) + penalty_term = self.penalty_func(X) + return raw_value - self.regularization_parameter * penalty_term + + @property + def X_pending(self) -> Optional[Tensor]: + return self.raw_acqf.X_pending + + def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None: + if not isinstance(self.raw_acqf, AnalyticAcquisitionFunction): + self.raw_acqf.set_X_pending(X_pending=X_pending) + else: + raise UnsupportedError( + "The raw acquisition function is Analytic and does not account " + "for X_pending yet." + ) + + +def group_lasso_regularizer(X: Tensor, groups: List[List[int]]) -> Tensor: + r"""Computes the group lasso regularization function for the given point. + + Args: + X: A bxd tensor representing the points to evaluate the regularization at. + groups: List of indices of different groups. + + Returns: + Computed group lasso norm of at the given points. + """ + return torch.sum( + torch.stack( + [math.sqrt(len(g)) * torch.norm(X[..., g], p=2, dim=-1) for g in groups], + dim=-1, + ), + dim=-1, + ) diff --git a/sphinx/source/acquisition.rst b/sphinx/source/acquisition.rst index a3d220351d..a92cd702ed 100644 --- a/sphinx/source/acquisition.rst +++ b/sphinx/source/acquisition.rst @@ -110,6 +110,11 @@ Fixed Feature Acquisition Function .. automodule:: botorch.acquisition.fixed_feature :members: +Penalized Acquisition Function Wrapper +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.acquisition.penalized + :members: + General Utilities for Acquisition Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ .. automodule:: botorch.acquisition.utils diff --git a/test/acquisition/test_penalized.py b/test/acquisition/test_penalized.py new file mode 100644 index 0000000000..3f97784d92 --- /dev/null +++ b/test/acquisition/test_penalized.py @@ -0,0 +1,131 @@ +#!/usr/bin/env python3 +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from botorch.acquisition.analytic import ExpectedImprovement +from botorch.acquisition.monte_carlo import qExpectedImprovement +from botorch.acquisition.penalized import ( + GaussianPenalty, + GroupLassoPenalty, + L2Penalty, + PenalizedAcquisitionFunction, + group_lasso_regularizer, +) +from botorch.exceptions import UnsupportedError +from botorch.sampling.samplers import IIDNormalSampler +from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior + + +class TestL2Penalty(BotorchTestCase): + def test_gaussian_penalty(self): + for dtype in (torch.float, torch.double): + init_point = torch.tensor([1.0, 1.0, 1.0], device=self.device, dtype=dtype) + l2_module = L2Penalty(init_point=init_point) + + # testing a batch of two points + sample_point = torch.tensor( + [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], device=self.device, dtype=dtype + ) + + diff_norm_squared = ( + torch.norm((sample_point - init_point), p=2, dim=-1) ** 2 + ) + real_value = diff_norm_squared.max(dim=-1).values + computed_value = l2_module(sample_point) + self.assertEqual(computed_value.item(), real_value.item()) + + +class TestGaussianPenalty(BotorchTestCase): + def test_gaussian_penalty(self): + for dtype in (torch.float, torch.double): + init_point = torch.tensor([1.0, 1.0, 1.0], device=self.device, dtype=dtype) + sigma = 0.1 + gaussian_module = GaussianPenalty(init_point=init_point, sigma=sigma) + + # testing a batch of two points + sample_point = torch.tensor( + [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], device=self.device, dtype=dtype + ) + + diff_norm_squared = ( + torch.norm((sample_point - init_point), p=2, dim=-1) ** 2 + ) + max_l2_distance = diff_norm_squared.max(dim=-1).values + real_value = torch.exp(max_l2_distance / 2 / sigma ** 2) + computed_value = gaussian_module(sample_point) + self.assertEqual(computed_value.item(), real_value.item()) + + +class TestGroupLassoPenalty(BotorchTestCase): + def test_group_lasso_penalty(self): + for dtype in (torch.float, torch.double): + init_point = torch.tensor([0.5, 0.5, 0.5], device=self.device, dtype=dtype) + groups = [[0, 2], [1]] + group_lasso_module = GroupLassoPenalty(init_point=init_point, groups=groups) + + # testing a single point + sample_point = torch.tensor( + [[1.0, 2.0, 3.0]], device=self.device, dtype=dtype + ) + real_value = group_lasso_regularizer( + sample_point - init_point, groups + ) # torch.tensor([5.105551242828369], device=self.device, dtype=dtype) + computed_value = group_lasso_module(sample_point) + self.assertEqual(computed_value.item(), real_value.item()) + + # testing unsupported input dim: X.shape[-2] > 1 + sample_point_2 = torch.tensor( + [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], device=self.device, dtype=dtype + ) + with self.assertRaises(NotImplementedError): + group_lasso_module(sample_point_2) + + +class TestPenalizedAcquisitionFunction(BotorchTestCase): + def test_penalized_acquisition_function(self): + for dtype in (torch.float, torch.double): + mock_model = MockModel( + MockPosterior(mean=torch.tensor([1.0]), variance=torch.tensor([1.0])) + ) + init_point = torch.tensor([0.5, 0.5, 0.5], device=self.device, dtype=dtype) + groups = [[0, 2], [1]] + raw_acqf = ExpectedImprovement(model=mock_model, best_f=1.0) + penalty = GroupLassoPenalty(init_point=init_point, groups=groups) + lmbda = 0.1 + acqf = PenalizedAcquisitionFunction( + raw_acqf=raw_acqf, penalty_func=penalty, regularization_parameter=lmbda + ) + + sample_point = torch.tensor( + [[1.0, 2.0, 3.0]], device=self.device, dtype=dtype + ) + raw_value = raw_acqf(sample_point) + penalty_value = penalty(sample_point) + real_value = raw_value - lmbda * penalty_value + computed_value = acqf(sample_point) + self.assertTrue(torch.equal(real_value, computed_value)) + + # testing X_pending for analytic raw_acqfn (EI) + X_pending = torch.tensor([0.1, 0.2, 0.3], device=self.device, dtype=dtype) + with self.assertRaises(UnsupportedError): + acqf.set_X_pending(X_pending) + + # testing X_pending for non-analytic raw_acqfn (EI) + sampler = IIDNormalSampler(num_samples=2) + raw_acqf_2 = qExpectedImprovement( + model=mock_model, best_f=0, sampler=sampler + ) + init_point = torch.tensor([1.0, 1.0, 1.0], device=self.device, dtype=dtype) + l2_module = L2Penalty(init_point=init_point) + acqf_2 = PenalizedAcquisitionFunction( + raw_acqf=raw_acqf_2, + penalty_func=l2_module, + regularization_parameter=lmbda, + ) + + X_pending = torch.tensor([0.1, 0.2, 0.3], device=self.device, dtype=dtype) + acqf_2.set_X_pending(X_pending) + self.assertTrue(torch.equal(acqf_2.X_pending, X_pending))