SEBO bug fixes (#2935)
Summary:
Pull Request resolved: #2935

This fixes a number of bugs in SEBO, including: switching the acquisition function to qLogNoisyExpectedHypervolumeImprovement, clamping near-target training data and generated candidates to the target point, overriding cache_root=True (with a warning), adjusting the homotopy schedule, and simplifying how batch initial conditions for the homotopy optimization are generated.

Reviewed By: bletham

Differential Revision: D60089765

fbshipit-source-id: 231db9b67a0bce7ab9a99a2c03e8c617b488a408
David Eriksson authored and facebook-github-bot committed Oct 23, 2024
1 parent 63dedd6 commit 02c0ce5
Showing 2 changed files with 187 additions and 199 deletions.
226 changes: 96 additions & 130 deletions ax/models/torch/botorch_modular/sebo.py
@@ -7,22 +7,26 @@
# pyre-strict

import functools
import warnings
from collections.abc import Callable
from copy import deepcopy
from functools import partial
from logging import Logger
from typing import Any

import torch
from ax.core.search_space import SearchSpaceDigest
from ax.exceptions.core import AxWarning
from ax.models.torch.botorch_modular.acquisition import Acquisition
from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.models.torch_base import TorchOptConfig
from ax.utils.common.constants import Keys
from ax.utils.common.logger import get_logger
from ax.utils.common.typeutils import not_none
from botorch.acquisition.acquisition import AcquisitionFunction
from botorch.acquisition.multi_objective.monte_carlo import (
qExpectedHypervolumeImprovement,
from botorch.acquisition.multi_objective.logei import (
qLogNoisyExpectedHypervolumeImprovement,
)
from botorch.acquisition.penalized import L0Approximation
from botorch.models.deterministic import GenericDeterministicModel
@@ -34,11 +38,12 @@
optimize_acqf_homotopy,
)
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.multi_objective.pareto import is_non_dominated
from botorch.utils.transforms import unnormalize
from torch import Tensor
from torch.quasirandom import SobolEngine

CLAMP_TOL = 0.01
CLAMP_TOL = 1e-2
logger: Logger = get_logger(__name__)


class SEBOAcquisition(Acquisition):
@@ -76,53 +81,58 @@ def __init__(
# construct deterministic model for penalty term
# pyre-fixme[4]: Attribute must be annotated.
self.deterministic_model = self._construct_penalty()

surrogate_f = deepcopy(surrogate)
# we need to clamp the training data to the target point here as it may
# be slightly off due to numerical issues.
X_sparse = clamp_to_target(
X=surrogate_f.Xs[0].clone(),
target_point=self.target_point,
clamp_tol=CLAMP_TOL,
)
# update the training data in new surrogate
not_none(surrogate_f._training_data).append(
SupervisedDataset(
surrogate_f.Xs[0],
self.deterministic_model(surrogate_f.Xs[0]),
# append Yvar as zero for penalty term
Yvar=torch.zeros(surrogate_f.Xs[0].shape[0], 1, **tkwargs),
X=X_sparse,
Y=self.deterministic_model(X_sparse),
Yvar=torch.zeros(X_sparse.shape[0], 1, **tkwargs), # noiseless
feature_names=surrogate_f.training_data[0].feature_names,
outcome_names=[self.penalty_name],
)
)
# update the model in new surrogate
surrogate_f._model = ModelList(surrogate.model, self.deterministic_model)
self.det_metric_indx = -1

# update objective weights and thresholds in the torch config
# update objective weights and thresholds in the torch config
torch_opt_config_sebo = self._transform_torch_config(
torch_opt_config, **tkwargs
torch_opt_config=torch_opt_config, **tkwargs
)

# instantiate botorch_acqf_class
if not issubclass(botorch_acqf_class, qExpectedHypervolumeImprovement):
raise ValueError("botorch_acqf_class must be qEHVI to use SEBO")
# Change some options (note: we do not want to do this in-place)
if options.get("cache_root", False):
warnings.warn(
"SEBO doesn't support `cache_root=True`. Changing it to `False`.",
AxWarning,
stacklevel=3,
)
options = {**options, "cache_root": False}

# Instantiate the `botorch_acqf_class`. We need to modify `a` before doing this
# (as it controls the L0 norm approximation) since the baseline will be pruned
# when the acquisition function is created. With a=1e-6 the deterministic model
# will be numerically close to the true L0 norm and we will select the
# baseline according to the last homotopy step.
if self.penalty_name == "L0_norm":
self.deterministic_model._f.a.fill_(1e-6)
super().__init__(
surrogates={"sebo": surrogate_f},
search_space_digest=search_space_digest,
torch_opt_config=torch_opt_config_sebo,
botorch_acqf_class=botorch_acqf_class,
botorch_acqf_class=qLogNoisyExpectedHypervolumeImprovement,
options=options,
)
if not isinstance(self.acqf, qExpectedHypervolumeImprovement):
raise ValueError("botorch_acqf_class must be qEHVI to use SEBO")

# update objective threshold for deterministic model (penalty term)
self.acqf.ref_point[-1] = self.sparsity_threshold * -1
# pyre-ignore
self._objective_thresholds[-1] = self.sparsity_threshold

Y_pareto = torch.cat(
[d.Y for d in self.surrogates["sebo"].training_data],
dim=-1,
)
ind_pareto = is_non_dominated(Y_pareto * self._full_objective_weights)
# pyre-ignore
self.X_pareto = self.surrogates["sebo"].Xs[0][ind_pareto].clone()
self._objective_thresholds[-1] = self.sparsity_threshold # pyre-ignore

def _construct_penalty(self) -> GenericDeterministicModel:
"""Construct a penalty term as deterministic model to be included in
@@ -150,41 +160,29 @@ def _transform_torch_config(
"""Transform torch config to include penalty term (deterministic model) as
an additional outcome in the BoTorch model.
"""
# update objective weights by appending the weight (-1) of penalty term
# at the end
ow_sebo = torch.cat(
[torch_opt_config.objective_weights, torch.tensor([-1], **tkwargs)]
# update objective weights by appending the weight -1 for sparsity objective.
objective_weights_sebo = torch.cat(
[torch_opt_config.objective_weights, -torch.ones(1, **tkwargs)]
)
if torch_opt_config.outcome_constraints is not None:
# update the shape of A matrix in outcome_constraints
oc_sebo = (
torch.cat(
[
torch_opt_config.outcome_constraints[0],
torch.zeros(
# pyre-ignore
torch_opt_config.outcome_constraints[0].shape[0],
1,
**tkwargs,
),
],
dim=1,
),
torch_opt_config.outcome_constraints[1],
A, b = not_none(torch_opt_config.outcome_constraints)
outcome_constraints_sebo = (
torch.cat([A, torch.zeros(A.shape[0], 1, **tkwargs)], dim=1),
b,
)
else:
oc_sebo = None
outcome_constraints_sebo = None
if torch_opt_config.objective_thresholds is not None:
# append the sparsity threshold at the end if objective_thresholds
# is not None
ot_sebo = torch.cat(
objective_thresholds_sebo = torch.cat(
[
torch_opt_config.objective_thresholds,
torch.tensor([self.sparsity_threshold], **tkwargs),
]
)
else:
ot_sebo = None
# NOTE: The reference point will be inferred in the base class.
objective_thresholds_sebo = None

# update pending observations (if not none) by appending an obs for
# the new penalty outcome
Expand All @@ -195,9 +193,9 @@ def _transform_torch_config(
]

return TorchOptConfig(
objective_weights=ow_sebo,
outcome_constraints=oc_sebo,
objective_thresholds=ot_sebo,
objective_weights=objective_weights_sebo,
outcome_constraints=outcome_constraints_sebo,
objective_thresholds=objective_thresholds_sebo,
linear_constraints=torch_opt_config.linear_constraints,
fixed_features=torch_opt_config.fixed_features,
pending_observations=pending_observations,
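
As a side note, the tensor bookkeeping in _transform_torch_config above can be summarized in a few lines of plain torch. This is an illustrative sketch with made-up shapes and values, not Ax API: the sparsity outcome gets objective weight -1, existing outcome constraints gain a zero column so they ignore the new outcome, and the sparsity threshold is appended as the reference-point entry for the penalty.

    import torch

    # Hypothetical problem: 2 metrics, 1 outcome constraint (values are illustrative).
    objective_weights = torch.tensor([1.0, -1.0])
    A = torch.tensor([[0.0, 1.0]])                 # constraint row over the 2 original metrics
    b = torch.tensor([[10.0]])
    objective_thresholds = torch.tensor([0.5, 2.0])
    sparsity_threshold = 8.0                       # e.g. the search-space dimension

    # Append weight -1 for the sparsity (penalty) outcome.
    objective_weights_sebo = torch.cat([objective_weights, -torch.ones(1)])
    # Pad A with a zero column so existing constraints ignore the new outcome.
    outcome_constraints_sebo = (torch.cat([A, torch.zeros(A.shape[0], 1)], dim=1), b)
    # Append the sparsity threshold for the new outcome.
    objective_thresholds_sebo = torch.cat(
        [objective_thresholds, torch.tensor([sparsity_threshold])]
    )
    print(objective_weights_sebo)                  # tensor([ 1., -1., -1.])
    print(outcome_constraints_sebo[0])             # tensor([[0., 1., 0.]])
    print(objective_thresholds_sebo)               # tensor([0.5000, 2.0000, 8.0000])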
@@ -266,12 +264,8 @@ def optimize(
)

# similarly, make sure this applies to the sparse dimensions only
candidates = clamp_candidates(
X=candidates,
target_point=self.target_point,
clamp_tol=CLAMP_TOL,
device=self.device,
dtype=self.dtype,
candidates = clamp_to_target(
X=candidates, target_point=self.target_point, clamp_tol=CLAMP_TOL
)
return candidates, expected_acquisition_value, weights

@@ -288,8 +282,7 @@ def _optimize_with_homotopy(
_tensorize = partial(torch.tensor, dtype=self.dtype, device=self.device)
ssd = search_space_digest
bounds = _tensorize(ssd.bounds).t()

homotopy_schedule = LogLinearHomotopySchedule(start=0.1, end=1e-3, num_steps=30)
homotopy_schedule = LogLinearHomotopySchedule(start=0.2, end=1e-3, num_steps=30)

# Prepare arguments for optimizer
optimizer_options_with_defaults = optimizer_argparse(
@@ -300,48 +293,21 @@
optimizer="optimize_acqf_homotopy",
)

def callback() -> None:
if (
self.acqf.cache_pending
): # If true, pending points are concatenated with X_baseline
if self.acqf._max_iep != 0:
raise ValueError(
"The maximum number of pending points (max_iep) must be 0"
)
X_baseline = self.acqf._X_baseline_and_pending.clone()
self.acqf.__init__( # pyre-ignore
X_baseline=X_baseline,
model=self.surrogates["sebo"].model,
ref_point=self.acqf.ref_point,
objective=self.acqf.objective,
)
else: # We can directly get the pending points here
X_pending = self.acqf.X_pending
self.acqf.__init__( # pyre-ignore
X_baseline=self.X_observed,
model=self.surrogates["sebo"].model,
ref_point=self.acqf.ref_point,
objective=self.acqf.objective,
)
self.acqf.set_X_pending(X_pending)

homotopy = Homotopy(
homotopy_parameters=[
HomotopyParameter(
parameter=self.deterministic_model._f.a,
schedule=homotopy_schedule,
)
],
callbacks=[callback],
)
# need to know sparse dimensions
batch_initial_conditions = get_batch_initial_conditions(
acq_function=self.acqf,
raw_samples=optimizer_options_with_defaults["raw_samples"],
X_pareto=self.X_pareto,
X_pareto=self.acqf.X_baseline,
target_point=self.target_point,
bounds=bounds,
num_restarts=optimizer_options_with_defaults["num_restarts"],
**{"device": self.device, "dtype": self.dtype},
)
candidates, expected_acquisition_value = optimize_acqf_homotopy(
q=n,
@@ -354,11 +320,10 @@ def callback() -> None:
fixed_features=fixed_features,
batch_initial_conditions=batch_initial_conditions,
)

return (
candidates,
expected_acquisition_value,
torch.ones(n, dtype=candidates.dtype),
torch.ones(n, device=candidates.device, dtype=candidates.dtype),
)
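
For intuition on the homotopy above: the parameter a of the deterministic L0 penalty controls how sharply it approximates the true L0 norm, and the log-linear schedule anneals it from 0.2 down to 1e-3 over 30 steps. The sketch below uses a Gaussian-kernel smoothed L0 count to illustrate the limit a -> 0; smoothed_l0 is an illustrative helper, and the exact expression inside BoTorch's L0Approximation may differ.

    import math

    import torch

    def smoothed_l0(X: torch.Tensor, target_point: torch.Tensor, a: float) -> torch.Tensor:
        # Differentiable stand-in for ||X - target_point||_0: each coordinate
        # contributes ~0 at the target and ~1 far from it; a -> 0 recovers the exact count.
        d = X - target_point
        return (1 - torch.exp(-(d**2) / (2 * a**2))).sum(dim=-1, keepdim=True)

    target = torch.zeros(4)
    X = torch.tensor([[0.0, 0.1, 0.0, 1.0]])    # two non-target coordinates
    print(smoothed_l0(X, target, a=0.2))        # ~1.1, a loose count at the start of the schedule
    print(smoothed_l0(X, target, a=1e-3))       # ~2.0, essentially the true L0 norm

    # The schedule roughly corresponds to 30 log-spaced values between 0.2 and 1e-3.
    schedule = torch.logspace(math.log10(0.2), math.log10(1e-3), steps=30)
    print(schedule[0], schedule[-1])            # tensor(0.2000) tensor(0.0010)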


@@ -370,15 +335,17 @@ def L1_norm_func(X: Tensor, init_point: Tensor) -> Tensor:
return torch.linalg.norm((X - init_point), ord=1, dim=-1, keepdim=True)


def clamp_candidates(
X: Tensor, target_point: Tensor, clamp_tol: float, **tkwargs: Any
) -> Tensor:
"""Clamp generated candidates within the given ranges to the target point."""
clamp_mask = (X - target_point).abs() < clamp_tol
clamp_mask = clamp_mask
X[clamp_mask] = (
target_point.clone().repeat(*X.shape[:-1], 1).to(**tkwargs)[clamp_mask]
)
def clamp_to_target(X: Tensor, target_point: Tensor, clamp_tol: float) -> Tensor:
"""Clamp generated candidates within the given ranges to the target point.
Args:
X: A `batch_shape x n x d`-dim input tensor `X`.
target_point: A tensor of size `d` corresponding to the target point.
clamp_tol: The clamping tolerance. Any value within `clamp_tol` of the
`target_point` will be clamped to the `target_point`.
"""
clamp_mask = (X - target_point).abs() <= clamp_tol
X[clamp_mask] = target_point.clone().repeat(*X.shape[:-1], 1)[clamp_mask]
return X
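
A quick usage sketch of clamp_to_target as defined above, with made-up values: coordinates within clamp_tol of the target point are snapped exactly onto it and everything else is left untouched (note that the tensor is modified in place and also returned).

    import torch

    from ax.models.torch.botorch_modular.sebo import clamp_to_target

    target_point = torch.zeros(3)
    X = torch.tensor([[0.005, 0.80, -0.002],
                      [0.500, 0.00,  0.030]])
    print(clamp_to_target(X=X, target_point=target_point, clamp_tol=1e-2))
    # tensor([[0.0000, 0.8000, 0.0000],
    #         [0.5000, 0.0000, 0.0300]])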


@@ -387,36 +354,35 @@ def get_batch_initial_conditions(
raw_samples: int,
X_pareto: Tensor,
target_point: Tensor,
bounds: Tensor,
num_restarts: int = 20,
**tkwargs: Any,
) -> Tensor:
"""Generate starting points for the SEBO acquisition function optimization."""
tkwargs: dict[str, Any] = {"device": X_pareto.device, "dtype": X_pareto.dtype}
dim = X_pareto.shape[-1] # dimension
# (1) Global Sobol points
X_cand1 = SobolEngine(dimension=dim, scramble=True).draw(raw_samples).to(**tkwargs)
X_cand1 = X_cand1[
acq_function(X_cand1.unsqueeze(1)).topk(num_restarts // 5).indices
]
# (2) Global Sobol points with a Bernoulli mask
X_cand2 = SobolEngine(dimension=dim, scramble=True).draw(raw_samples).to(**tkwargs)
mask = torch.rand(X_cand2.shape, **tkwargs) < 0.5
X_cand2[mask] = target_point.repeat(len(X_cand2), 1).to(**tkwargs)[mask]
X_cand2 = X_cand2[
acq_function(X_cand2.unsqueeze(1)).topk(num_restarts // 5).indices
]
# (3) Perturbations of points on the Pareto frontier (done by TuRBO and Spearmint)
X_cand3 = X_pareto.clone()[torch.randint(high=len(X_pareto), size=(raw_samples,))]
mask = X_cand3 != target_point
X_cand3[mask] += 0.2 * torch.randn(*X_cand3.shape, **tkwargs)[mask]
X_cand3 = torch.clamp(X_cand3, min=0.0, max=1.0)
X_cand3 = X_cand3[
acq_function(X_cand3.unsqueeze(1)).topk(num_restarts // 5).indices
num_sobol, num_local = num_restarts // 2, num_restarts - num_restarts // 2
# (1) Global sparse Sobol points
X_cand_sobol = (
SobolEngine(dimension=dim, scramble=True)
.draw(raw_samples, dtype=tkwargs["dtype"])
.to(**tkwargs)
)
X_cand_sobol = unnormalize(X_cand_sobol, bounds=bounds)
acq_vals = acq_function(X_cand_sobol.unsqueeze(1))
if len(X_pareto) == 0:
return X_cand_sobol[acq_vals.topk(num_restarts).indices]

X_cand_sobol = X_cand_sobol[acq_vals.topk(num_sobol).indices]
# (2) Perturbations of points on the Pareto frontier (done by TuRBO/Spearmint)
X_cand_local = X_pareto.clone()[
torch.randint(high=len(X_pareto), size=(raw_samples,))
]
# (4) Apply a Bernoulli mask to points on the Pareto frontier
X_cand4 = X_pareto.clone()[torch.randint(high=len(X_pareto), size=(raw_samples,))]
mask = torch.rand(X_cand4.shape, **tkwargs) < 0.5
X_cand4[mask] = target_point.repeat(len(X_cand4), 1).to(**tkwargs)[mask].clone()
X_cand4 = X_cand4[
acq_function(X_cand4.unsqueeze(1)).topk(num_restarts // 5).indices
mask = X_cand_local != target_point
X_cand_local[mask] += (
0.2 * ((bounds[1] - bounds[0]) * torch.randn_like(X_cand_local))[mask]
)
X_cand_local = torch.clamp(X_cand_local, min=bounds[0], max=bounds[1])
X_cand_local = X_cand_local[
acq_function(X_cand_local.unsqueeze(1)).topk(num_local).indices
]
return torch.cat((X_cand1, X_cand2, X_cand3, X_cand4), dim=0).unsqueeze(1)
return torch.cat((X_cand_sobol, X_cand_local), dim=0).unsqueeze(1)
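
The initialization strategy above boils down to: draw Sobol points, map them from the unit cube to the search-space bounds, score them with the acquisition function and keep the top scorers, then (if Pareto-optimal baseline points exist) add locally perturbed copies of those, leaving coordinates that already sit at the target point untouched. A stripped-down sketch with a toy acquisition function (toy_acqf and all values are illustrative, not Ax API):

    import torch
    from botorch.utils.transforms import unnormalize
    from torch.quasirandom import SobolEngine

    def toy_acqf(X: torch.Tensor) -> torch.Tensor:
        # Stand-in for a real q=1 acquisition function: higher is better.
        return -(X.squeeze(1) ** 2).sum(dim=-1)

    dim, raw_samples, num_restarts = 4, 128, 8
    bounds = torch.stack([torch.zeros(dim), 2 * torch.ones(dim)])   # 2 x d
    target_point = torch.zeros(dim)
    X_pareto = 2 * torch.rand(5, dim)                               # pretend baseline points

    num_sobol = num_restarts // 2
    num_local = num_restarts - num_sobol

    # (1) Global Sobol candidates, scored by the acquisition function.
    X_sobol = unnormalize(SobolEngine(dimension=dim, scramble=True).draw(raw_samples), bounds=bounds)
    X_sobol = X_sobol[toy_acqf(X_sobol.unsqueeze(1)).topk(num_sobol).indices]

    # (2) Local perturbations of baseline points, skipping coordinates already at the target.
    X_local = X_pareto[torch.randint(high=len(X_pareto), size=(raw_samples,))].clone()
    mask = X_local != target_point
    X_local[mask] += (0.2 * (bounds[1] - bounds[0]) * torch.randn_like(X_local))[mask]
    X_local = torch.clamp(X_local, min=bounds[0], max=bounds[1])
    X_local = X_local[toy_acqf(X_local.unsqueeze(1)).topk(num_local).indices]

    batch_initial_conditions = torch.cat((X_sobol, X_local), dim=0).unsqueeze(1)
    print(batch_initial_conditions.shape)                           # torch.Size([8, 1, 4])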
