Add Inequality Constraints to SEBO #2938


Closed
wants to merge 11 commits into from
85 changes: 46 additions & 39 deletions ax/models/torch/botorch_modular/sebo.py
@@ -30,16 +30,15 @@
from botorch.models.deterministic import GenericDeterministicModel
from botorch.models.model import ModelList
from botorch.optim import (
gen_batch_initial_conditions,
Homotopy,
HomotopyParameter,
LogLinearHomotopySchedule,
optimize_acqf_homotopy,
)
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.transforms import unnormalize
from pyre_extensions import none_throws
from pyre_extensions import assert_is_instance, none_throws
from torch import Tensor
from torch.quasirandom import SobolEngine

CLAMP_TOL = 1e-2
logger: Logger = get_logger(__name__)
@@ -234,15 +233,11 @@ def optimize(
with the weight for each candidate.
"""
if self.penalty_name == "L0_norm":
if inequality_constraints is not None:
raise NotImplementedError(
"Homotopy does not support optimization with inequality "
+ "constraints. Use L1 penalty norm instead."
)
candidates, expected_acquisition_value, weights = (
self._optimize_with_homotopy(
n=n,
search_space_digest=search_space_digest,
inequality_constraints=inequality_constraints,
fixed_features=fixed_features,
rounding_func=rounding_func,
optimizer_options=optimizer_options,
@@ -269,6 +264,7 @@ def _optimize_with_homotopy(
self,
n: int,
search_space_digest: SearchSpaceDigest,
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
fixed_features: dict[int, float] | None = None,
rounding_func: Callable[[Tensor], Tensor] | None = None,
optimizer_options: dict[str, Any] | None = None,
@@ -277,8 +273,7 @@
optimizer_options = optimizer_options or {}
# TODO: extend to support a fixed (no-homotopy) schedule
_tensorize = partial(torch.tensor, dtype=self.dtype, device=self.device)
ssd = search_space_digest
bounds = _tensorize(ssd.bounds).t()
bounds = _tensorize(search_space_digest.bounds).t()
homotopy_schedule = LogLinearHomotopySchedule(
start=0.2,
end=1e-3,
@@ -300,26 +295,34 @@
)
],
)
batch_initial_conditions = get_batch_initial_conditions(
acq_function=self.acqf,
raw_samples=optimizer_options_with_defaults["raw_samples"],
# pyre-fixme[6]: For 3rd argument expected `Tensor` but got
# `Union[Tensor, Module]`.
X_pareto=self.acqf.X_baseline,
target_point=self.target_point,
bounds=bounds,
num_restarts=optimizer_options_with_defaults["num_restarts"],
)

if "batch_initial_conditions" not in optimizer_options_with_defaults:
optimizer_options_with_defaults["batch_initial_conditions"] = (
get_batch_initial_conditions(
acq_function=self.acqf,
raw_samples=optimizer_options_with_defaults["raw_samples"],
inequality_constraints=inequality_constraints,
fixed_features=fixed_features,
X_pareto=assert_is_instance(self.acqf.X_baseline, Tensor),
target_point=self.target_point,
bounds=bounds,
num_restarts=optimizer_options_with_defaults["num_restarts"],
)
)

candidates, expected_acquisition_value = optimize_acqf_homotopy(
q=n,
acq_function=self.acqf,
bounds=bounds,
homotopy=homotopy,
num_restarts=optimizer_options_with_defaults["num_restarts"],
raw_samples=optimizer_options_with_defaults["raw_samples"],
inequality_constraints=inequality_constraints,
post_processing_func=rounding_func,
fixed_features=fixed_features,
batch_initial_conditions=batch_initial_conditions,
batch_initial_conditions=optimizer_options_with_defaults[
"batch_initial_conditions"
],
)
return (
candidates,
@@ -357,23 +360,29 @@ def get_batch_initial_conditions(
target_point: Tensor,
bounds: Tensor,
num_restarts: int = 20,
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
fixed_features: dict[int, float] | None = None,
) -> Tensor:
"""Generate starting points for the SEBO acquisition function optimization."""
tkwargs: dict[str, Any] = {"device": X_pareto.device, "dtype": X_pareto.dtype}
dim = X_pareto.shape[-1] # dimension
num_sobol, num_local = num_restarts // 2, num_restarts - num_restarts // 2
# (1) Global sparse Sobol points
X_cand_sobol = (
SobolEngine(dimension=dim, scramble=True)
.draw(raw_samples, dtype=tkwargs["dtype"])
.to(**tkwargs)
)
X_cand_sobol = unnormalize(X_cand_sobol, bounds=bounds)
acq_vals = acq_function(X_cand_sobol.unsqueeze(1))
if len(X_pareto) == 0:
return X_cand_sobol[acq_vals.topk(num_restarts).indices]
num_rand = num_restarts if len(X_pareto) == 0 else num_restarts // 2
num_local = num_restarts - num_rand

# (1) Random points (Sobol if no constraints, otherwise uses hit-and-run)
X_cand_rand = gen_batch_initial_conditions(
acq_function=acq_function,
bounds=bounds,
q=1,
raw_samples=raw_samples,
num_restarts=num_rand,
options={"topn": True},
fixed_features=fixed_features,
inequality_constraints=inequality_constraints,
).to(**tkwargs)

if num_local == 0:
return X_cand_rand

X_cand_sobol = X_cand_sobol[acq_vals.topk(num_sobol).indices]
# (2) Perturbations of points on the Pareto frontier (done by TuRBO/Spearmint)
X_cand_local = X_pareto.clone()[
torch.randint(high=len(X_pareto), size=(raw_samples,))
@@ -382,8 +391,6 @@
X_cand_local[mask] += (
0.2 * ((bounds[1] - bounds[0]) * torch.randn_like(X_cand_local))[mask]
)
X_cand_local = torch.clamp(X_cand_local, min=bounds[0], max=bounds[1])
X_cand_local = X_cand_local[
acq_function(X_cand_local.unsqueeze(1)).topk(num_local).indices
]
return torch.cat((X_cand_sobol, X_cand_local), dim=0).unsqueeze(1)
X_cand_local = torch.clamp(X_cand_local.unsqueeze(1), min=bounds[0], max=bounds[1])
X_cand_local = X_cand_local[acq_function(X_cand_local).topk(num_local).indices]
return torch.cat((X_cand_rand, X_cand_local), dim=0)
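
For context, the `inequality_constraints` argument now threaded through `_optimize_with_homotopy` and `get_batch_initial_conditions` follows BoTorch's linear-constraint encoding: each tuple `(indices, coefficients, rhs)` is read as `sum_i coefficients[i] * x[indices[i]] >= rhs`. A minimal sketch, where the three-parameter space and the particular constraint are made up purely for illustration:

```python
import torch

# Illustrative constraint on a hypothetical 3-parameter space: x0 + x1 >= 1.0.
# Encoding follows BoTorch's convention: (indices, coefficients, rhs) means
# sum_i coefficients[i] * x[indices[i]] >= rhs.
inequality_constraints = [
    (
        torch.tensor([0, 1]),                          # parameter indices
        torch.tensor([1.0, 1.0], dtype=torch.double),  # coefficients
        1.0,                                           # right-hand side
    )
]

# With this change, such a list can be passed to SEBOAcquisition.optimize(...)
# even when penalty_name == "L0_norm"; it is forwarded to both
# get_batch_initial_conditions and optimize_acqf_homotopy.
```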
65 changes: 43 additions & 22 deletions ax/models/torch/tests/test_sebo.py
@@ -294,38 +294,19 @@ def test_optimize_l0(self, mock_optimize_acqf_homotopy: Mock) -> None:
acquisition2.optimize(
n=2,
search_space_digest=self.search_space_digest,
# does not support in homotopy now
# inequality_constraints=self.inequality_constraints,
inequality_constraints=self.inequality_constraints,
fixed_features=self.fixed_features,
rounding_func=self.rounding_func,
optimizer_options=self.optimizer_options,
)
args, kwargs = mock_optimize_acqf_homotopy.call_args
_args, kwargs = mock_optimize_acqf_homotopy.call_args
self.assertEqual(kwargs["acq_function"], acquisition2.acqf)
self.assertEqual(kwargs["q"], 2)
self.assertEqual(kwargs["inequality_constraints"], self.inequality_constraints)
self.assertEqual(kwargs["post_processing_func"], self.rounding_func)
self.assertEqual(kwargs["num_restarts"], self.optimizer_options["num_restarts"])
self.assertEqual(kwargs["raw_samples"], self.optimizer_options["raw_samples"])

# assert error raise with inequality_constraints input
acquisition = self.get_acquisition_function(
fixed_features=self.fixed_features,
options={"penalty": "L0_norm", "target_point": self.target_point},
)
with self.assertRaisesRegex(
NotImplementedError,
"Homotopy does not support optimization with inequality "
"constraints. Use L1 penalty norm instead.",
):
acquisition.optimize(
n=2,
search_space_digest=self.search_space_digest,
inequality_constraints=self.inequality_constraints,
fixed_features=self.fixed_features,
rounding_func=self.rounding_func,
optimizer_options=self.optimizer_options,
)

def test_optimization_result(self) -> None:
acquisition = self.get_acquisition_function(
options={"penalty": "L0_norm", "target_point": self.target_point},
@@ -399,3 +380,43 @@ def test_get_batch_initial_conditions(
self.assertEqual(batch_initial_conditions.shape, torch.Size([3, 1, 3]))
self.assertTrue(torch.all(batch_initial_conditions[:1] != 0.5))
self.assertTrue(torch.all(batch_initial_conditions[1:, :, 1] == 0.5))

@mock.patch(f"{SEBOACQUISITION_PATH}.optimize_acqf_homotopy")
@mock.patch(
f"{SEBOACQUISITION_PATH}.get_batch_initial_conditions",
wraps=get_batch_initial_conditions,
)
def test_optimize_with_provided_batch_initial_conditions(
self, mock_get_batch_initial_conditions: Mock, mock_optimize_acqf_homotopy: Mock
) -> None:
mock_optimize_acqf_homotopy.return_value = (
torch.tensor([[0.1, 0.1, 0.1]], dtype=torch.double),
torch.tensor([1.0], dtype=torch.double),
)

# Create batch initial conditions
batch_ics = torch.rand(3, 1, 3, dtype=torch.double)

acquisition = self.get_acquisition_function(
options={
"target_point": self.target_point,
"penalty": "L0_norm",
},
)

acquisition.optimize(
n=1,
search_space_digest=self.search_space_digest,
optimizer_options={
"batch_initial_conditions": batch_ics,
Keys.NUM_RESTARTS: 3,
Keys.RAW_SAMPLES: 32,
},
)

# Verify get_batch_initial_conditions was not called
mock_get_batch_initial_conditions.assert_not_called()

# Verify the batch_initial_conditions were passed to optimize_acqf_homotopy
call_kwargs = mock_optimize_acqf_homotopy.call_args[1]
self.assertTrue(torch.equal(call_kwargs["batch_initial_conditions"], batch_ics))
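
The new test above exercises the pass-through path: when the caller already supplies `batch_initial_conditions` in `optimizer_options`, the SEBO-specific initializer is skipped and the tensor is handed directly to `optimize_acqf_homotopy`. A usage sketch, with shapes and option values mirroring the test and meant only as illustration:

```python
import torch

# Illustrative: provide custom starting points of shape (num_restarts, q, d)
# so get_batch_initial_conditions is never called.
optimizer_options = {
    "num_restarts": 3,
    "raw_samples": 32,
    "batch_initial_conditions": torch.rand(3, 1, 3, dtype=torch.double),
}

# acquisition.optimize(
#     n=1,
#     search_space_digest=search_space_digest,
#     optimizer_options=optimizer_options,
# )
```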