Skip to content

Commit 08ead8f

Browse files
sdaultonfacebook-github-bot
authored andcommitted
qLowerConfidenceBound (#2517)
Summary: Pull Request resolved: #2517 Implement a qLowerConfidence acquisition function for more confident/risk-averse candidate selection. Reviewed By: SebastianAment Differential Revision: D60624931
1 parent 33e11f4 commit 08ead8f

File tree

4 files changed

+138
-20
lines changed

4 files changed

+138
-20
lines changed

botorch/acquisition/input_constructors.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import inspect
1515
from collections.abc import Hashable, Iterable, Sequence
16-
from typing import Any, Callable, Optional, TypeVar, Union
16+
from typing import Any, Callable, List, Optional, TypeVar, Union
1717

1818
import torch
1919
from botorch.acquisition.acquisition import AcquisitionFunction
@@ -50,6 +50,7 @@
5050
)
5151
from botorch.acquisition.monte_carlo import (
5252
qExpectedImprovement,
53+
qLowerConfidenceBound,
5354
qNoisyExpectedImprovement,
5455
qProbabilityOfImprovement,
5556
qSimpleRegret,
@@ -767,13 +768,15 @@ def construct_inputs_qPI(
767768
}
768769

769770

770-
@acqf_input_constructor(qUpperConfidenceBound)
771+
@acqf_input_constructor(qLowerConfidenceBound, qUpperConfidenceBound)
771772
def construct_inputs_qUCB(
772773
model: Model,
773774
objective: Optional[MCAcquisitionObjective] = None,
774775
posterior_transform: Optional[PosteriorTransform] = None,
775776
X_pending: Optional[Tensor] = None,
776777
sampler: Optional[MCSampler] = None,
778+
X_baseline: Optional[Tensor] = None,
779+
constraints: Optional[List[Callable[[Tensor], Tensor]]] = None,
777780
beta: float = 0.2,
778781
) -> dict[str, Any]:
779782
r"""Construct kwargs for the `qUpperConfidenceBound` constructor.
@@ -788,11 +791,30 @@ def construct_inputs_qUCB(
788791
Concatenated into X upon forward call.
789792
sampler: The sampler used to draw base samples. If omitted, uses
790793
the acquisition functions's default sampler.
794+
X_baseline: A `batch_shape x r x d`-dim Tensor of `r` design points
795+
that have already been observed. These points are used to
796+
compute with infeasible cost when there are constraints.
797+
constraints: A list of constraint callables which map a Tensor of posterior
798+
samples of dimension `sample_shape x batch-shape x q x m`-dim to a
799+
`sample_shape x batch-shape x q`-dim Tensor. The associated constraints
800+
are considered satisfied if the output is less than zero.
791801
beta: Controls tradeoff between mean and standard deviation in UCB.
792802
793803
Returns:
794804
A dict mapping kwarg names of the constructor to values.
795805
"""
806+
if constraints is not None:
807+
if X_baseline is None:
808+
raise ValueError("Constraints require an X_baseline.")
809+
if objective is None:
810+
objective = IdentityMCObjective()
811+
objective = ConstrainedMCObjective(
812+
objective=objective,
813+
constraints=constraints,
814+
infeasible_cost=get_infeasible_cost(
815+
X=X_baseline, model=model, objective=objective
816+
),
817+
)
796818
return {
797819
"model": model,
798820
"objective": objective,

botorch/acquisition/monte_carlo.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -856,7 +856,10 @@ def __init__(
856856
posterior_transform=posterior_transform,
857857
X_pending=X_pending,
858858
)
859-
self.beta_prime = math.sqrt(beta * math.pi / 2)
859+
self.beta_prime = self._get_beta_prime(beta=beta)
860+
861+
def _get_beta_prime(self, beta: float) -> float:
862+
return math.sqrt(beta * math.pi / 2)
860863

861864
def _sample_forward(self, obj: Tensor) -> Tensor:
862865
r"""Evaluate qUpperConfidenceBound per sample on the candidate set `X`.
@@ -869,3 +872,17 @@ def _sample_forward(self, obj: Tensor) -> Tensor:
869872
"""
870873
mean = obj.mean(dim=0)
871874
return mean + self.beta_prime * (obj - mean).abs()
875+
876+
877+
class qLowerConfidenceBound(qUpperConfidenceBound):
878+
r"""MC-based batched lower confidence bound.
879+
880+
This acquisition function is useful for confident/risk-averse decision making.
881+
This acquisition function is intended to be maximized as with qUpperConfidenceBound,
882+
but the qLowerConfidenceBound will be pessimistic in the face of uncertainty and
883+
lead to conservative candidates.
884+
"""
885+
886+
def _get_beta_prime(self, beta: float) -> float:
887+
"""Multiply beta prime by -1 to get the lower confidence bound."""
888+
return -super()._get_beta_prime(beta=beta)

test/acquisition/test_input_constructors.py

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
)
6262
from botorch.acquisition.monte_carlo import (
6363
qExpectedImprovement,
64+
qLowerConfidenceBound,
6465
qNoisyExpectedImprovement,
6566
qProbabilityOfImprovement,
6667
qSimpleRegret,
@@ -86,6 +87,7 @@
8687
from botorch.acquisition.multi_objective.utils import get_default_partitioning_alpha
8788
from botorch.acquisition.objective import (
8889
ConstrainedMCObjective,
90+
IdentityMCObjective,
8991
LinearMCObjective,
9092
ScalarizedPosteriorTransform,
9193
)
@@ -754,34 +756,86 @@ def test_construct_inputs_qPI(self) -> None:
754756
self.assertIs(acqf.model, mock_model)
755757
self.assertIs(acqf.objective, objective)
756758

757-
def test_construct_inputs_qUCB(self) -> None:
758-
c = get_acqf_input_constructor(qUpperConfidenceBound)
759+
760+
class TestQUpperConfidenceBoundInputConstructor(InputConstructorBaseTestCase):
761+
acqf_class = qUpperConfidenceBound
762+
763+
def setUp(self, suppress_input_warnings: bool = True) -> None:
764+
super().setUp(suppress_input_warnings=suppress_input_warnings)
765+
self.c = get_acqf_input_constructor(self.acqf_class)
766+
767+
def test_confidence_bound(self) -> None:
759768
mock_model = self.mock_model
760-
kwargs = c(model=mock_model, training_data=self.blockX_blockY)
769+
kwargs = self.c(model=mock_model, training_data=self.blockX_blockY)
761770
self.assertEqual(kwargs["model"], mock_model)
762-
self.assertIsNone(kwargs["objective"])
763771
self.assertIsNone(kwargs["X_pending"])
764772
self.assertIsNone(kwargs["sampler"])
765773
self.assertEqual(kwargs["beta"], 0.2)
766-
acqf = qUpperConfidenceBound(**kwargs)
774+
acqf = self.acqf_class(**kwargs)
767775
self.assertIs(acqf.model, mock_model)
768776

777+
def test_confidence_bound_with_objective(self) -> None:
769778
X_pending = torch.rand(2, 2)
770779
objective = LinearMCObjective(torch.rand(2))
771-
kwargs = c(
772-
model=mock_model,
780+
kwargs = self.c(
781+
model=self.mock_model,
773782
training_data=self.blockX_blockY,
774783
objective=objective,
775784
X_pending=X_pending,
776785
beta=0.1,
777786
)
778-
self.assertEqual(kwargs["model"], mock_model)
787+
self.assertEqual(kwargs["model"], self.mock_model)
779788
self.assertTrue(torch.equal(kwargs["objective"].weights, objective.weights))
780789
self.assertTrue(torch.equal(kwargs["X_pending"], X_pending))
781790
self.assertIsNone(kwargs["sampler"])
782791
self.assertEqual(kwargs["beta"], 0.1)
783-
acqf = qUpperConfidenceBound(**kwargs)
784-
self.assertIs(acqf.model, mock_model)
792+
acqf = self.acqf_class(**kwargs)
793+
self.assertIs(acqf.model, self.mock_model)
794+
795+
def test_confidence_bound_with_constraints_error(self) -> None:
796+
with self.assertRaisesRegex(ValueError, "Constraints require an X_baseline."):
797+
self.c(
798+
model=self.mock_model,
799+
training_data=self.blockX_blockY,
800+
constraints=torch.rand(2, 2),
801+
)
802+
803+
def test_confidence_bound_with_constraints(self) -> None:
804+
# these are needed for computing the infeasible cost
805+
self.mock_model._posterior._mean = torch.zeros(2, 2)
806+
self.mock_model._posterior._variance = torch.ones(2, 2)
807+
808+
X_baseline = torch.rand(2, 2)
809+
outcome_constraints = (torch.tensor([[0.0, 1.0]]), torch.tensor([[0.5]]))
810+
constraints = get_outcome_constraint_transforms(
811+
outcome_constraints=outcome_constraints
812+
)
813+
for objective in (LinearMCObjective(torch.rand(2)), None):
814+
with self.subTest(objective=objective):
815+
kwargs = self.c(
816+
model=self.mock_model,
817+
training_data=self.blockX_blockY,
818+
objective=objective,
819+
constraints=constraints,
820+
X_baseline=X_baseline,
821+
)
822+
final_objective = kwargs["objective"]
823+
self.assertIsInstance(final_objective, ConstrainedMCObjective)
824+
if objective is None:
825+
self.assertIsInstance(
826+
final_objective.objective, IdentityMCObjective
827+
)
828+
else:
829+
self.assertIs(final_objective.objective, objective)
830+
self.assertIs(final_objective.constraints, constraints)
831+
# test that we can construct the acquisition function
832+
self.acqf_class(**kwargs)
833+
834+
835+
class TestQLowerConfidenceBoundInputConstructor(
836+
TestQUpperConfidenceBoundInputConstructor
837+
):
838+
acqf_class = qLowerConfidenceBound
785839

786840

787841
class TestMultiObjectiveAcquisitionFunctionInputConstructors(

test/acquisition/test_monte_carlo.py

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the MIT license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import math
78
import warnings
89
from copy import deepcopy
910
from functools import partial
@@ -17,6 +18,7 @@
1718
from botorch.acquisition.monte_carlo import (
1819
MCAcquisitionFunction,
1920
qExpectedImprovement,
21+
qLowerConfidenceBound,
2022
qNoisyExpectedImprovement,
2123
qProbabilityOfImprovement,
2224
qSimpleRegret,
@@ -871,7 +873,9 @@ def test_q_simple_regret_constraints(self):
871873

872874

873875
class TestQUpperConfidenceBound(BotorchTestCase):
874-
def test_q_upper_confidence_bound(self):
876+
acqf_class = qUpperConfidenceBound
877+
878+
def test_q_confidence_bound(self):
875879
for dtype in (torch.float, torch.double):
876880
# the event shape is `b x q x t` = 1 x 1 x 1
877881
samples = torch.zeros(1, 1, 1, device=self.device, dtype=dtype)
@@ -881,13 +885,13 @@ def test_q_upper_confidence_bound(self):
881885

882886
# basic test
883887
sampler = IIDNormalSampler(sample_shape=torch.Size([2]))
884-
acqf = qUpperConfidenceBound(model=mm, beta=0.5, sampler=sampler)
888+
acqf = self.acqf_class(model=mm, beta=0.5, sampler=sampler)
885889
res = acqf(X)
886890
self.assertEqual(res.item(), 0.0)
887891

888892
# basic test
889893
sampler = IIDNormalSampler(sample_shape=torch.Size([2]), seed=12345)
890-
acqf = qUpperConfidenceBound(model=mm, beta=0.5, sampler=sampler)
894+
acqf = self.acqf_class(model=mm, beta=0.5, sampler=sampler)
891895
res = acqf(X)
892896
self.assertEqual(res.item(), 0.0)
893897
self.assertEqual(acqf.sampler.base_samples.shape, torch.Size([2, 1, 1, 1]))
@@ -924,7 +928,7 @@ def test_q_upper_confidence_bound(self):
924928
sum(issubclass(w.category, BotorchWarning) for w in ws), 1
925929
)
926930

927-
def test_q_upper_confidence_bound_batch(self):
931+
def test_q_confidence_bound_batch(self):
928932
# TODO: T41739913 Implement tests for all MCAcquisitionFunctions
929933
for dtype in (torch.float, torch.double):
930934
samples = torch.zeros(2, 2, 1, device=self.device, dtype=dtype)
@@ -935,14 +939,14 @@ def test_q_upper_confidence_bound_batch(self):
935939

936940
# test batch mode
937941
sampler = IIDNormalSampler(sample_shape=torch.Size([2]))
938-
acqf = qUpperConfidenceBound(model=mm, beta=0.5, sampler=sampler)
942+
acqf = self.acqf_class(model=mm, beta=0.5, sampler=sampler)
939943
res = acqf(X)
940944
self.assertEqual(res[0].item(), 1.0)
941945
self.assertEqual(res[1].item(), 0.0)
942946

943947
# test batch mode
944948
sampler = IIDNormalSampler(sample_shape=torch.Size([2]), seed=12345)
945-
acqf = qUpperConfidenceBound(model=mm, beta=0.5, sampler=sampler)
949+
acqf = self.acqf_class(model=mm, beta=0.5, sampler=sampler)
946950
res = acqf(X) # 1-dim batch
947951
self.assertEqual(res[0].item(), 1.0)
948952
self.assertEqual(res[1].item(), 0.0)
@@ -961,7 +965,7 @@ def test_q_upper_confidence_bound_batch(self):
961965

962966
# test batch mode, qmc
963967
sampler = SobolQMCNormalSampler(sample_shape=torch.Size([2]))
964-
acqf = qUpperConfidenceBound(model=mm, beta=0.5, sampler=sampler)
968+
acqf = self.acqf_class(model=mm, beta=0.5, sampler=sampler)
965969
res = acqf(X)
966970
self.assertEqual(res[0].item(), 1.0)
967971
self.assertEqual(res[1].item(), 0.0)
@@ -991,9 +995,30 @@ def test_q_upper_confidence_bound_batch(self):
991995
sum(issubclass(w.category, BotorchWarning) for w in ws), 1
992996
)
993997

998+
def test_beta_prime(self, negate: bool = False) -> None:
999+
acqf = self.acqf_class(
1000+
model=MockModel(
1001+
posterior=MockPosterior(
1002+
samples=torch.zeros(2, 2, 1, device=self.device, dtype=torch.double)
1003+
)
1004+
),
1005+
beta=1.96,
1006+
)
1007+
expected_value = math.sqrt(1.96 * math.pi / 2)
1008+
if negate:
1009+
expected_value *= -1
1010+
self.assertEqual(acqf.beta_prime, expected_value)
1011+
9941012
# TODO: Test different objectives (incl. constraints)
9951013

9961014

1015+
class TestQLowerConfidenceBound(TestQUpperConfidenceBound):
1016+
acqf_class = qLowerConfidenceBound
1017+
1018+
def test_beta_prime(self):
1019+
super().test_beta_prime(negate=True)
1020+
1021+
9971022
class TestMCAcquisitionFunctionWithConstraints(BotorchTestCase):
9981023
def test_mc_acquisition_function_with_constraints(self):
9991024
for dtype in (torch.float, torch.double):

0 commit comments

Comments
 (0)