-
Notifications
You must be signed in to change notification settings - Fork 415
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Moving penalized acqfn from botorch_fb to botorch (#585)
Summary: Pull Request resolved: #585 Just moved the development for penalized acqfn from botorch_fb to botorch to push it to the OSS. Reviewed By: Balandat Differential Revision: D24508442 fbshipit-source-id: 54f0884e8e5a86296c6d0e58cc913bbf46323dbb
- Loading branch information
1 parent
4033c16
commit c5ec613
Showing
3 changed files
with
308 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
r""" | ||
Modules to add regularization to acquisition functions. | ||
""" | ||
|
||
from __future__ import annotations | ||
|
||
import math | ||
from typing import List, Optional | ||
|
||
import torch | ||
from botorch.acquisition.acquisition import AcquisitionFunction | ||
from botorch.acquisition.analytic import AnalyticAcquisitionFunction | ||
from botorch.exceptions import UnsupportedError | ||
from torch import Tensor | ||
|
||
|
||
class L2Penalty(torch.nn.Module):
    r"""Squared L2-distance penalty that can be paired with an acquisition function.

    The penalty of a q-batch is the squared Euclidean distance between the
    reference point and the member of the q-batch that is farthest from it.
    """

    def __init__(self, init_point: Tensor):
        r"""Set up the L2 penalty.

        Args:
            init_point: The "1 x dim" reference point against which
                candidates are regularized.
        """
        super().__init__()
        self.init_point = init_point

    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the penalty on a set of candidates.

        Args:
            X: A "batch_shape x q x dim" tensor of points to be evaluated.

        Returns:
            A tensor of size "batch_shape" holding the penalty value of
            each q-batch.
        """
        offsets = X - self.init_point
        distances = torch.norm(offsets, p=2, dim=-1)
        # Only the farthest point of each q-batch determines the penalty.
        farthest = distances.max(dim=-1).values
        return farthest ** 2
|
||
|
||
class GaussianPenalty(torch.nn.Module):
    r"""Gaussian-style penalty that can be paired with an acquisition function.

    For every point the term `exp(||x - init_point||^2 / (2 * sigma^2))` is
    computed; the penalty of a q-batch is the largest such term, so it grows
    with the distance from the reference point.
    """

    def __init__(self, init_point: Tensor, sigma: float):
        r"""Set up the Gaussian penalty.

        Args:
            init_point: The "1 x dim" reference point against which
                candidates are regularized.
            sigma: The scale parameter used in the exponent.
        """
        super().__init__()
        self.sigma = sigma
        self.init_point = init_point

    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the penalty on a set of candidates.

        Args:
            X: A "batch_shape x q x dim" tensor of points to be evaluated.

        Returns:
            A tensor of size "batch_shape" holding the penalty value of
            each q-batch.
        """
        sq_dist = torch.norm(X - self.init_point, p=2, dim=-1) ** 2
        exponent = sq_dist / 2 / self.sigma ** 2
        # Reduce over the q dimension: the largest term is the batch penalty.
        return torch.exp(exponent).max(dim=-1).values
|
||
|
||
class GroupLassoPenalty(torch.nn.Module):
    r"""Group lasso penalty that can be paired with an acquisition function."""

    def __init__(self, init_point: Tensor, groups: List[List[int]]):
        r"""Set up the group lasso penalty.

        Args:
            init_point: The "1 x dim" reference point against which
                candidates are regularized.
            groups: Groups of indices used in group lasso.
        """
        super().__init__()
        self.groups = groups
        self.init_point = init_point

    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the group lasso penalty of `X - init_point`.

        Args:
            X: A "batch_shape x 1 x dim" tensor of points to be evaluated.

        Returns:
            The value of `group_lasso_regularizer` applied to the offsets
            from the reference point, with the q dimension squeezed out.

        Raises:
            NotImplementedError: If the q dimension (`X.shape[-2]`) is not 1.
        """
        q = X.shape[-2]
        if q != 1:
            raise NotImplementedError(
                "group-lasso has not been implemented for q>1 yet."
            )

        # Drop the singleton q dimension before measuring the offsets.
        offsets = X.squeeze(-2) - self.init_point
        return group_lasso_regularizer(X=offsets, groups=self.groups)
|
||
|
||
class PenalizedAcquisitionFunction(AcquisitionFunction):
    r"""Single-outcome acquisition function regularized by the given penalty.

    The usage is similar to:
        raw_acqf = NoisyExpectedImprovement(...)
        penalty = GroupLassoPenalty(...)
        acqf = PenalizedAcquisitionFunction(raw_acqf, penalty, lmbda)
    """

    def __init__(
        self,
        raw_acqf: AcquisitionFunction,
        penalty_func: torch.nn.Module,
        regularization_parameter: float,
    ) -> None:
        r"""Set up the penalized acquisition function.

        Args:
            raw_acqf: The raw acquisition function that is going to be regularized.
            penalty_func: The regularization function.
            regularization_parameter: Multiplier applied to the penalty before
                it is subtracted from the raw acquisition value.
        """
        super().__init__(model=raw_acqf.model)
        self.raw_acqf = raw_acqf
        self.penalty_func = penalty_func
        self.regularization_parameter = regularization_parameter

    def forward(self, X: Tensor) -> Tensor:
        r"""Evaluate the raw acquisition value minus the weighted penalty.

        Args:
            X: The candidate points, passed unchanged to both the raw
                acquisition function and the penalty function.

        Returns:
            `raw_acqf(X) - regularization_parameter * penalty_func(X)`.
        """
        return self.raw_acqf(X=X) - self.regularization_parameter * self.penalty_func(X)

    @property
    def X_pending(self) -> Optional[Tensor]:
        r"""The pending points of the wrapped raw acquisition function."""
        return self.raw_acqf.X_pending

    def set_X_pending(self, X_pending: Optional[Tensor] = None) -> None:
        r"""Forward the pending points to the wrapped raw acquisition function.

        Args:
            X_pending: The pending points to set on the raw acquisition function.

        Raises:
            UnsupportedError: If the raw acquisition function is an
                `AnalyticAcquisitionFunction`.
        """
        if isinstance(self.raw_acqf, AnalyticAcquisitionFunction):
            raise UnsupportedError(
                "The raw acquisition function is Analytic and does not account "
                "for X_pending yet."
            )
        self.raw_acqf.set_X_pending(X_pending=X_pending)
|
||
|
||
def group_lasso_regularizer(X: Tensor, groups: List[List[int]]) -> Tensor:
    r"""Computes the group lasso regularization function for the given points.

    Each group contributes the L2 norm of its coordinates weighted by the
    square root of the group size; the contributions are summed over groups.

    Args:
        X: A "b x d" tensor representing the points to evaluate the
            regularization at.
        groups: List of indices of different groups.

    Returns:
        The group lasso norm at the given points, i.e. a tensor with the last
        dimension of `X` reduced away. If `groups` is empty, a tensor of zeros
        with the same dtype and device as `X` is returned.
    """
    if not groups:
        # torch.stack rejects an empty list of tensors; the empty sum is zero.
        return X.new_zeros(X.shape[:-1])
    per_group = [
        math.sqrt(len(g)) * torch.norm(X[..., g], p=2, dim=-1) for g in groups
    ]
    return torch.sum(torch.stack(per_group, dim=-1), dim=-1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Facebook, Inc. and its affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
from botorch.acquisition.analytic import ExpectedImprovement | ||
from botorch.acquisition.monte_carlo import qExpectedImprovement | ||
from botorch.acquisition.penalized import ( | ||
GaussianPenalty, | ||
GroupLassoPenalty, | ||
L2Penalty, | ||
PenalizedAcquisitionFunction, | ||
group_lasso_regularizer, | ||
) | ||
from botorch.exceptions import UnsupportedError | ||
from botorch.sampling.samplers import IIDNormalSampler | ||
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior | ||
|
||
|
||
class TestL2Penalty(BotorchTestCase):
    def test_l2_penalty(self):
        # Checks L2Penalty against a directly computed squared L2 distance.
        for dtype in (torch.float, torch.double):
            init_point = torch.tensor([1.0, 1.0, 1.0], device=self.device, dtype=dtype)
            l2_module = L2Penalty(init_point=init_point)

            # testing a batch of two points
            sample_point = torch.tensor(
                [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], device=self.device, dtype=dtype
            )

            # Expected penalty: the largest squared distance to init_point.
            diff_norm_squared = (
                torch.norm((sample_point - init_point), p=2, dim=-1) ** 2
            )
            real_value = diff_norm_squared.max(dim=-1).values
            computed_value = l2_module(sample_point)
            self.assertEqual(computed_value.item(), real_value.item())
|
||
|
||
class TestGaussianPenalty(BotorchTestCase):
    def test_gaussian_penalty(self):
        # Checks GaussianPenalty against a directly computed reference value.
        for dtype in (torch.float, torch.double):
            init_point = torch.tensor([1.0, 1.0, 1.0], device=self.device, dtype=dtype)
            # sigma is chosen so that exp(d^2 / (2 * sigma^2)) stays finite in
            # single precision. A much smaller sigma (e.g. 0.1) gives exp(700)
            # here, which overflows to inf in float32 and would make the
            # comparison pass regardless of the implementation.
            sigma = 1.0
            gaussian_module = GaussianPenalty(init_point=init_point, sigma=sigma)

            # testing a batch of two points
            sample_point = torch.tensor(
                [[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], device=self.device, dtype=dtype
            )

            diff_norm_squared = (
                torch.norm((sample_point - init_point), p=2, dim=-1) ** 2
            )
            max_sq_distance = diff_norm_squared.max(dim=-1).values
            real_value = torch.exp(max_sq_distance / 2 / sigma ** 2)
            computed_value = gaussian_module(sample_point)
            # Guard against the reference value degenerating to inf.
            self.assertTrue(torch.isfinite(real_value).item())
            self.assertTrue(torch.allclose(computed_value, real_value))
|
||
|
||
class TestGroupLassoPenalty(BotorchTestCase):
    def test_group_lasso_penalty(self):
        # Checks the module against the functional form and the q>1 guard.
        for dtype in (torch.float, torch.double):
            tensor_kwargs = {"device": self.device, "dtype": dtype}
            reference = torch.tensor([0.5, 0.5, 0.5], **tensor_kwargs)
            index_groups = [[0, 2], [1]]
            module = GroupLassoPenalty(init_point=reference, groups=index_groups)

            # a single point: the module must agree with the functional form
            # (the value is approximately 5.105551242828369)
            single = torch.tensor([[1.0, 2.0, 3.0]], **tensor_kwargs)
            expected = group_lasso_regularizer(single - reference, index_groups)
            actual = module(single)
            self.assertEqual(actual.item(), expected.item())

            # unsupported input: X.shape[-2] > 1
            pair = torch.tensor([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], **tensor_kwargs)
            with self.assertRaises(NotImplementedError):
                module(pair)
|
||
|
||
class TestPenalizedAcquisitionFunction(BotorchTestCase):
    def test_penalized_acquisition_function(self):
        # Checks the penalized value and the X_pending handling for both an
        # analytic (EI) and a non-analytic (qEI) raw acquisition function.
        for dtype in (torch.float, torch.double):
            tkwargs = {"device": self.device, "dtype": dtype}
            # The mock posterior tensors follow the device/dtype under test so
            # that they match the candidate points passed to the acqf.
            mock_model = MockModel(
                MockPosterior(
                    mean=torch.tensor([1.0], **tkwargs),
                    variance=torch.tensor([1.0], **tkwargs),
                )
            )
            init_point = torch.tensor([0.5, 0.5, 0.5], **tkwargs)
            groups = [[0, 2], [1]]
            raw_acqf = ExpectedImprovement(model=mock_model, best_f=1.0)
            penalty = GroupLassoPenalty(init_point=init_point, groups=groups)
            lmbda = 0.1
            acqf = PenalizedAcquisitionFunction(
                raw_acqf=raw_acqf, penalty_func=penalty, regularization_parameter=lmbda
            )

            sample_point = torch.tensor([[1.0, 2.0, 3.0]], **tkwargs)
            raw_value = raw_acqf(sample_point)
            penalty_value = penalty(sample_point)
            real_value = raw_value - lmbda * penalty_value
            computed_value = acqf(sample_point)
            self.assertTrue(torch.equal(real_value, computed_value))

            # testing X_pending for analytic raw_acqfn (EI)
            X_pending = torch.tensor([0.1, 0.2, 0.3], **tkwargs)
            with self.assertRaises(UnsupportedError):
                acqf.set_X_pending(X_pending)

            # testing X_pending for non-analytic raw_acqfn (qEI)
            sampler = IIDNormalSampler(num_samples=2)
            raw_acqf_2 = qExpectedImprovement(
                model=mock_model, best_f=0, sampler=sampler
            )
            init_point = torch.tensor([1.0, 1.0, 1.0], **tkwargs)
            l2_module = L2Penalty(init_point=init_point)
            acqf_2 = PenalizedAcquisitionFunction(
                raw_acqf=raw_acqf_2,
                penalty_func=l2_module,
                regularization_parameter=lmbda,
            )

            X_pending = torch.tensor([0.1, 0.2, 0.3], **tkwargs)
            acqf_2.set_X_pending(X_pending)
            self.assertTrue(torch.equal(acqf_2.X_pending, X_pending))