Return weights from Acquisition.optimize (#2314)
Summary:

Return weights from `Acquisition.optimize` alongside the candidates and acquisition values. This allows for controlling the relative allocation per arm within the `Acquisition`.

Differential Revision: D54912788
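As context for the diff below, here is a minimal sketch (not part of this commit) of how a subclass could use the new three-element return value to assign non-uniform per-arm weights. The subclass name `WeightedAcquisition` and the linearly decaying weighting scheme are illustrative assumptions; only the `(candidates, acqf_values, arm_weights)` signature, the uniform `torch.ones(n)` default, and the `TorchGenResults` weights plumbing come from this change.

```python
# Hypothetical sketch -- not from this commit. Assumes Ax's modular BoTorch
# Acquisition class with the new three-element return from `optimize`.
from typing import Any, Callable, Dict, List, Optional, Tuple

import torch
from torch import Tensor

from ax.core.search_space import SearchSpaceDigest
from ax.models.torch.botorch_modular.acquisition import Acquisition


class WeightedAcquisition(Acquisition):
    """Illustrative subclass that down-weights later candidates."""

    def optimize(
        self,
        n: int,
        search_space_digest: SearchSpaceDigest,
        inequality_constraints: Optional[List[Tuple[Tensor, Tensor, float]]] = None,
        fixed_features: Optional[Dict[int, float]] = None,
        rounding_func: Optional[Callable[[Tensor], Tensor]] = None,
        optimizer_options: Optional[Dict[str, Any]] = None,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        # Delegate candidate generation to the base class, which now returns
        # uniform weights (torch.ones(n)) by default.
        candidates, acqf_values, _ = super().optimize(
            n=n,
            search_space_digest=search_space_digest,
            inequality_constraints=inequality_constraints,
            fixed_features=fixed_features,
            rounding_func=rounding_func,
            optimizer_options=optimizer_options,
        )
        # Replace the uniform default with a custom allocation, e.g. linearly
        # decaying weights (an arbitrary choice for illustration).
        arm_weights = torch.linspace(1.0, 0.5, n, dtype=candidates.dtype)
        return candidates, acqf_values, arm_weights
```

On the model side, `BoTorchModel.gen` now forwards these weights into `TorchGenResults` instead of hard-coding `torch.ones(n)`, so a custom allocation propagates to the generated arms.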
sdaulton authored and facebook-github-bot committed Apr 4, 2024
1 parent 032a4f0 commit 60be1fe
Showing 8 changed files with 71 additions and 36 deletions.
23 changes: 15 additions & 8 deletions ax/models/torch/botorch_modular/acquisition.py
@@ -382,7 +382,7 @@ def optimize(
fixed_features: Optional[Dict[int, float]] = None,
rounding_func: Optional[Callable[[Tensor], Tensor]] = None,
optimizer_options: Optional[Dict[str, Any]] = None,
-) -> Tuple[Tensor, Tensor]:
+) -> Tuple[Tensor, Tensor, Tensor]:
"""Generate a set of candidates via multi-start optimization. Obtains
candidates and their associated acquisition function values.
@@ -409,8 +409,9 @@ def optimize(
down these options while constructing a generation strategy.
Returns:
-A two-element tuple containing an `n x d`-dim tensor of generated candidates
-and a tensor with the associated acquisition value.
+A three-element tuple containing an `n x d`-dim tensor of generated
+candidates, a tensor with the associated acquisition values, and a tensor
+with the weight for each candidate.
"""
# NOTE: Could make use of `optimizer_class` when it's added to BoTorch
# instead of calling `optimize_acqf` or `optimize_acqf_discrete` etc.
@@ -434,13 +435,15 @@ def optimize(
for i in fixed_features:
if not 0 <= i < len(ssd.feature_names):
raise ValueError(f"Invalid fixed_feature index: {i}")

+# Return a weight of 1 for each arm by default. This can be
+# customized in subclasses if necessary.
+arm_weights = torch.ones(n, dtype=self.dtype)
# 1. Handle the fully continuous search space.
if (
optimizer_options_with_defaults.pop("force_use_optimize_acqf", False)
or not discrete_features
):
-return optimize_acqf(
+candidates, acqf_values = optimize_acqf(
acq_function=self.acqf,
bounds=bounds,
q=n,
@@ -449,6 +452,7 @@
post_processing_func=post_processing_func,
**optimizer_options_with_defaults,
)
+return candidates, acqf_values, arm_weights

# 2. Handle search spaces with discrete features.
discrete_choices = mk_discrete_choices(ssd=ssd, fixed_features=fixed_features)
@@ -468,14 +472,15 @@
torch.tensor(c, device=self.device, dtype=self.dtype)
for c in discrete_choices.values()
]
-return optimize_acqf_discrete_local_search(
+candidates, acqf_values = optimize_acqf_discrete_local_search(
acq_function=self.acqf,
q=n,
discrete_choices=discrete_choices,
inequality_constraints=inequality_constraints,
X_avoid=X_observed,
**optimizer_options_with_defaults,
)
+return candidates, acqf_values, arm_weights

# Enumerate all possible choices
all_choices = (discrete_choices[i] for i in range(len(discrete_choices)))
@@ -520,12 +525,13 @@ def optimize(
optimizer_options=optimizer_options,
optimizer_is_discrete=True,
)
-return optimize_acqf_discrete(
+candidates, acqf_values = optimize_acqf_discrete(
acq_function=self.acqf, q=n, choices=all_choices, **discrete_opt_options
)
+return candidates, acqf_values, arm_weights

# 2b. Handle mixed search spaces that have discrete and continuous features.
-return optimize_acqf_mixed(
+candidates, acqf_values = optimize_acqf_mixed(
acq_function=self.acqf,
bounds=bounds,
q=n,
Expand All @@ -539,6 +545,7 @@ def optimize(
post_processing_func=post_processing_func,
**optimizer_options_with_defaults,
)
+return candidates, acqf_values, arm_weights

def evaluate(self, X: Tensor) -> Tensor:
"""Evaluate the acquisition function on the candidate set `X`.
4 changes: 2 additions & 2 deletions ax/models/torch/botorch_modular/model.py
@@ -427,7 +427,7 @@ def gen(
acq_options=acq_options,
)
botorch_rounding_func = get_rounding_func(torch_opt_config.rounding_func)
-candidates, expected_acquisition_value = acqf.optimize(
+candidates, expected_acquisition_value, weights = acqf.optimize(
n=n,
search_space_digest=search_space_digest,
inequality_constraints=_to_inequality_constraints(
@@ -444,7 +444,7 @@ def gen(
)
return TorchGenResults(
points=candidates.detach().cpu(),
-weights=torch.ones(n, dtype=self.dtype),
+weights=weights,
gen_metadata=gen_metadata,
)

33 changes: 22 additions & 11 deletions ax/models/torch/botorch_modular/sebo.py
@@ -214,7 +214,7 @@ def optimize(
fixed_features: Optional[Dict[int, float]] = None,
rounding_func: Optional[Callable[[Tensor], Tensor]] = None,
optimizer_options: Optional[Dict[str, Any]] = None,
-) -> Tuple[Tensor, Tensor]:
+) -> Tuple[Tensor, Tensor, Tensor]:
"""Generate a set of candidates via multi-start optimization. Obtains
candidates and their associated acquisition function values.
@@ -232,23 +232,30 @@
transformations).
optimizer_options: Options for the optimizer function, e.g. ``sequential``
or ``raw_samples``.
+Returns:
+A three-element tuple containing an `n x d`-dim tensor of generated
+candidates, a tensor with the associated acquisition values, and a tensor
+with the weight for each candidate.
"""
if self.penalty_name == "L0_norm":
if inequality_constraints is not None:
raise NotImplementedError(
"Homotopy does not support optimization with inequality "
+ "constraints. Use L1 penalty norm instead."
)
-candidates, expected_acquisition_value = self._optimize_with_homotopy(
-n=n,
-search_space_digest=search_space_digest,
-fixed_features=fixed_features,
-rounding_func=rounding_func,
-optimizer_options=optimizer_options,
+candidates, expected_acquisition_value, weights = (
+self._optimize_with_homotopy(
+n=n,
+search_space_digest=search_space_digest,
+fixed_features=fixed_features,
+rounding_func=rounding_func,
+optimizer_options=optimizer_options,
+)
)
else:
# if L1 norm use standard moo-opt
-candidates, expected_acquisition_value = super().optimize(
+candidates, expected_acquisition_value, weights = super().optimize(
n=n,
search_space_digest=search_space_digest,
inequality_constraints=inequality_constraints,
@@ -265,7 +272,7 @@
device=self.device,
dtype=self.dtype,
)
-return candidates, expected_acquisition_value
+return candidates, expected_acquisition_value, weights

def _optimize_with_homotopy(
self,
@@ -274,7 +281,7 @@
fixed_features: Optional[Dict[int, float]] = None,
rounding_func: Optional[Callable[[Tensor], Tensor]] = None,
optimizer_options: Optional[Dict[str, Any]] = None,
-) -> Tuple[Tensor, Tensor]:
+) -> Tuple[Tensor, Tensor, Tensor]:
"""Optimize SEBO ACQF with L0 norm using homotopy."""
# TODO: extend to support a fixed or no-homotopy schedule
_tensorize = partial(torch.tensor, dtype=self.dtype, device=self.device)
@@ -346,7 +353,11 @@ def callback():  # pyre-ignore
batch_initial_conditions=batch_initial_conditions,
)

-return candidates, expected_acquisition_value
+return (
+candidates,
+expected_acquisition_value,
+torch.ones(n, dtype=candidates.dtype),
+)


def L1_norm_func(X: Tensor, init_point: Tensor) -> Tensor:
2 changes: 1 addition & 1 deletion ax/models/torch/botorch_modular/surrogate.py
@@ -670,7 +670,7 @@ def best_out_of_sample_point(
torch_opt_config=torch_opt_config,
options=acqf_options,
)
-candidates, acqf_values = acqf.optimize(
+candidates, acqf_values, _ = acqf.optimize(
n=1,
search_space_digest=search_space_digest,
inequality_constraints=_to_inequality_constraints(
27 changes: 19 additions & 8 deletions ax/models/torch/tests/test_acquisition.py
@@ -299,7 +299,7 @@ def test_init_with_subset_model_false(
outcome_constraints=self.outcome_constraints
)

@mock.patch(f"{ACQUISITION_PATH}.optimize_acqf")
@mock.patch(f"{ACQUISITION_PATH}.optimize_acqf", return_value=(Mock(), Mock()))
def test_optimize(self, mock_optimize_acqf: Mock) -> None:
acquisition = self.get_acquisition_function(fixed_features=self.fixed_features)
acquisition.optimize(
@@ -421,7 +421,7 @@ def test_optimize_discrete(self) -> None:
# 2 candidates have acqf value 8, but [1, 3, 4] is pending and thus should
# not be selected. [2, 3, 4] is the best point, but has already been picked
acquisition = self.get_acquisition_function()
-X_selected, _ = acquisition.optimize(
+X_selected, _, weights = acquisition.optimize(
n=2,
search_space_digest=ssd1,
rounding_func=self.rounding_func,
@@ -431,6 +431,7 @@ def test_optimize_discrete(self) -> None:
self.assertTrue(
all((x.unsqueeze(0) == expected).all(dim=-1).any() for x in X_selected)
)
+self.assertTrue(torch.equal(weights, torch.ones(2)))
# check with fixed feature
# Since parameter 1 is fixed to 2, the best 3 candidates are
# [4, 2, 4], [3, 2, 4], [4, 2, 3]
@@ -444,7 +445,7 @@ def test_optimize_discrete(self) -> None:
# int]]]` but got `Dict[int, List[int]]`.
discrete_choices={k: [0, 1, 2, 3, 4] for k in range(3)},
)
-X_selected, _ = acquisition.optimize(
+X_selected, _, weights = acquisition.optimize(
n=3,
search_space_digest=ssd2,
fixed_features=self.fixed_features,
@@ -455,9 +456,10 @@ def test_optimize_discrete(self) -> None:
self.assertTrue(
all((x.unsqueeze(0) == expected).all(dim=-1).any() for x in X_selected)
)
+self.assertTrue(torch.equal(weights, torch.ones(3)))
# check with a constraint that -1 * x[0] -1 * x[1] >= 0 which should make
# [0, 0, 4] the best candidate.
-X_selected, _ = acquisition.optimize(
+X_selected, _, weights = acquisition.optimize(
n=1,
search_space_digest=ssd2,
rounding_func=self.rounding_func,
@@ -467,8 +469,9 @@ def test_optimize_discrete(self) -> None:
)
expected = torch.tensor([[0, 0, 4]]).to(self.X)
self.assertTrue(torch.equal(expected, X_selected))
+self.assertTrue(torch.equal(weights, torch.tensor([1.0], dtype=self.X.dtype)))
# Same thing but use two constraints instead
-X_selected, _ = acquisition.optimize(
+X_selected, _, weights = acquisition.optimize(
n=1,
search_space_digest=ssd2,
rounding_func=self.rounding_func,
@@ -479,8 +482,12 @@ def test_optimize_discrete(self) -> None:
)
expected = torch.tensor([[0, 0, 4]]).to(self.X)
self.assertTrue(torch.equal(expected, X_selected))
+self.assertTrue(torch.equal(weights, torch.tensor([1.0])))

@mock.patch(f"{ACQUISITION_PATH}.optimize_acqf_discrete_local_search")
@mock.patch(
f"{ACQUISITION_PATH}.optimize_acqf_discrete_local_search",
return_value=(Mock(), Mock()),
)
def test_optimize_acqf_discrete_local_search(
self,
mock_optimize_acqf_discrete_local_search: Mock,
@@ -524,7 +531,9 @@ def test_optimize_acqf_discrete_local_search(
all((X_avoid_true == x).all(dim=-1).any().item() for x in kwargs["X_avoid"])
)

@mock.patch(f"{ACQUISITION_PATH}.optimize_acqf_mixed")
@mock.patch(
f"{ACQUISITION_PATH}.optimize_acqf_mixed", return_value=(Mock(), Mock())
)
def test_optimize_mixed(self, mock_optimize_acqf_mixed: Mock) -> None:
tkwargs = {"dtype": self.X.dtype, "device": self.X.device}
ssd = SearchSpaceDigest(
@@ -564,7 +573,9 @@ def test_optimize_mixed(self, mock_optimize_acqf_mixed: Mock) -> None:
mock_optimize_acqf_mixed.reset_mock()
optimizer_options = self.optimizer_options.copy()
optimizer_options["force_use_optimize_acqf"] = True
with mock.patch(f"{ACQUISITION_PATH}.optimize_acqf") as mock_optimize_acqf:
with mock.patch(
f"{ACQUISITION_PATH}.optimize_acqf", return_value=(Mock(), Mock())
) as mock_optimize_acqf:
acquisition.optimize(
n=3,
search_space_digest=ssd,
7 changes: 4 additions & 3 deletions ax/models/torch/tests/test_model.py
@@ -482,8 +482,9 @@ def _test_gen(
input_constructor=mock_input_constructor,
)
mock_optimize.return_value = (
-torch.tensor([1.0]),
+torch.tensor([[1.0]]),
torch.tensor([2.0]),
+torch.tensor([1.0]),
)
surrogate = Surrogate(botorch_model_class=botorch_model_class)
model = BoTorchModel(
@@ -824,8 +825,8 @@ def test_model_list_choice(self, _) -> None:  # , mock_extract_training_data):

@mock.patch(
f"{ACQUISITION_PATH}.Acquisition.optimize",
-# Dummy candidates and acquisition function value.
-return_value=(torch.tensor([[2.0]]), torch.tensor([1.0])),
+# Dummy candidates, acquisition value, and weights.
+return_value=(torch.tensor([[2.0]]), torch.tensor([1.0]), torch.tensor([1.0])),
)
def test_MOO(self, _) -> None:
# Add mock for qLogNEHVI input constructor to catch arguments passed to it.
3 changes: 2 additions & 1 deletion ax/models/torch/tests/test_sebo.py
@@ -244,7 +244,7 @@ def test_optimize_l0_homotopy(
feature_names=["a"],
bounds=[(-10.0, 5.0)],
)
-candidate, acqf_val = acquisition._optimize_with_homotopy(
+candidate, acqf_val, weights = acquisition._optimize_with_homotopy(
n=1,
search_space_digest=search_space_digest,
optimizer_options={
@@ -255,6 +255,7 @@
)
self.assertEqual(candidate, torch.zeros(1, **tkwargs))
self.assertEqual(acqf_val, 5 * torch.ones(1, **tkwargs))
+self.assertEqual(weights, torch.ones(1, **tkwargs))

@mock.patch(f"{SEBOACQUISITION_PATH}.optimize_acqf_homotopy")
def test_optimize_l0(self, mock_optimize_acqf_homotopy: Mock) -> None:
8 changes: 6 additions & 2 deletions ax/models/torch/tests/test_surrogate.py
@@ -574,7 +574,11 @@ def test_best_in_sample_point(self) -> None:
@patch(f"{ACQUISITION_PATH}.Acquisition.__init__", return_value=None)
@patch(
f"{ACQUISITION_PATH}.Acquisition.optimize",
-return_value=([torch.tensor([0.0])], [torch.tensor([1.0])]),
+return_value=(
+torch.tensor([[0.0]]),
+torch.tensor([1.0]),
+torch.tensor([1.0]),
+),
)
@patch(
f"{SURROGATE_PATH}.pick_best_out_of_sample_point_acqf_class",
@@ -615,7 +619,7 @@ def test_best_out_of_sample_point(
options={Keys.SAMPLER: SobolQMCNormalSampler},
)
self.assertTrue(torch.equal(candidate, torch.tensor([0.0])))
-self.assertTrue(torch.equal(acqf_value, torch.tensor([1.0])))
+self.assertTrue(torch.equal(acqf_value, torch.tensor(1.0)))

def test_serialize_attributes_as_kwargs(self) -> None:
for botorch_model_class in [SaasFullyBayesianSingleTaskGP, SingleTaskGP]:
