Homotopy explicit all kwargs (pytorch#2588)
Summary:
This PR seeks to address pytorch#2579.

`optimize_acqf_homotopy` is a fairly minimal wrapper around `optimize_acqf`, but it did not expose all of the underlying constraint functionality.

This PR copies over all of the arguments that we could in principle want to use into `optimize_acqf_homotopy`. For the time being, `final_options` has been kept. The apparent bug of fixed features not being passed to the final optimization has been fixed.

A simple dict, rather than the `OptimizeAcqfInputs` dataclass, is used to store the shared parameters.
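
For illustration, here is a minimal usage sketch of the expanded signature. The model, homotopy schedule, and constraint values below are placeholder assumptions for this example, not code from the PR:

```python
import torch
from botorch.acquisition import PosteriorMean
from botorch.models import SingleTaskGP
from botorch.optim.homotopy import (
    Homotopy,
    HomotopyParameter,
    LinearHomotopySchedule,
)
from botorch.optim.optimize_homotopy import optimize_acqf_homotopy

# Toy model on 2-d inputs (placeholder data).
train_X = torch.rand(8, 2, dtype=torch.double)
train_Y = train_X.sum(dim=-1, keepdim=True)
model = SingleTaskGP(train_X, train_Y)

# A parameter annealed from 1.0 down to 0.01 over 5 homotopy steps.
hp = HomotopyParameter(
    parameter=torch.nn.Parameter(torch.tensor(1.0, dtype=torch.double)),
    schedule=LinearHomotopySchedule(start=1.0, end=0.01, num_steps=5),
)

candidate, acq_value = optimize_acqf_homotopy(
    acq_function=PosteriorMean(model=model),
    bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]], dtype=torch.double),
    q=1,
    num_restarts=2,
    homotopy=Homotopy(homotopy_parameters=[hp]),
    raw_samples=16,
    # Arguments that previously could not be passed through:
    inequality_constraints=[  # X[..., 0] >= 0.3
        (torch.tensor([0]), torch.ones(1, dtype=torch.double), 0.3)
    ],
    fixed_features={1: 0.5},  # now also honored in the final optimization
)
```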

## Related PRs

The original approach in pytorch#2580 made use of `**kwargs`, which was opposed as being less explicit.
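
Schematically, the difference between the two approaches looks like this (hypothetical, abbreviated signatures for illustration, not the actual diff of either PR):

```python
# pytorch#2580 (not merged): extra options are forwarded opaquely, so the
# set of supported arguments is invisible in the signature.
def optimize_acqf_homotopy(acq_function, bounds, q, homotopy, **kwargs):
    ...  # kwargs forwarded wholesale to optimize_acqf

# This PR: every forwarded argument is named explicitly, so readers, IDEs,
# and type checkers can see exactly what is accepted.
def optimize_acqf_homotopy(
    acq_function, bounds, q, num_restarts, homotopy,
    inequality_constraints=None, equality_constraints=None, fixed_features=None,
):
    ...  # arguments collected into a shared dict and passed to optimize_acqf
```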

Pull Request resolved: pytorch#2588

Reviewed By: Balandat

Differential Revision: D64694021

Pulled By: saitcakmak

fbshipit-source-id: a10f00f0d069e411e6f12e9eafaec0eba454493d
CompRhys authored and facebook-github-bot committed Oct 22, 2024
1 parent b9d863d commit 24f659c
Showing 3 changed files with 136 additions and 21 deletions.
6 changes: 4 additions & 2 deletions botorch/optim/optimize.py
@@ -489,7 +489,8 @@ def optimize_acqf(
functions and `gen_batch_initial_conditions` otherwise. Must be specified
for nonlinear inequality constraints.
timeout_sec: Max amount of time optimization can run for.
return_full_tree:
return_full_tree: Return the full tree of optimizers of the previous
iteration.
retry_on_optimization_warning: Whether to retry candidate generation with a new
set of initial conditions when it fails with an `OptimizationWarning`.
ic_gen_kwargs: Additional keyword arguments passed to function specified by
@@ -623,7 +624,8 @@ def optimize_acqf_cyclic(
functions and `gen_batch_initial_conditions` otherwise. Must be specified
for nonlinear inequality constraints.
timeout_sec: Max amount of time optimization can run for.
return_full_tree:
return_full_tree: Return the full tree of optimizers of the previous
iteration.
retry_on_optimization_warning: Whether to retry candidate generation with a new
set of initial conditions when it fails with an `OptimizationWarning`.
ic_gen_kwargs: Additional keyword arguments passed to function specified by
128 changes: 109 additions & 19 deletions botorch/optim/optimize_homotopy.py
@@ -3,11 +3,18 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from collections.abc import Callable

from typing import Any

import torch
from botorch.acquisition import AcquisitionFunction

from botorch.generation.gen import TGenCandidates
from botorch.optim.homotopy import Homotopy
from botorch.optim.initializers import TGenInitialConditions
from botorch.optim.optimize import optimize_acqf
from torch import Tensor

@@ -50,37 +57,121 @@ def optimize_acqf_homotopy(
acq_function: AcquisitionFunction,
bounds: Tensor,
q: int,
homotopy: Homotopy,
num_restarts: int,
homotopy: Homotopy,
prune_tolerance: float = 1e-4,
raw_samples: int | None = None,
fixed_features: dict[int, float] | None = None,
options: dict[str, bool | float | int | str] | None = None,
final_options: dict[str, bool | float | int | str] | None = None,
batch_initial_conditions: Tensor | None = None,
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
nonlinear_inequality_constraints: list[tuple[Callable, bool]] | None = None,
fixed_features: dict[int, float] | None = None,
post_processing_func: Callable[[Tensor], Tensor] | None = None,
prune_tolerance: float = 1e-4,
batch_initial_conditions: Tensor | None = None,
gen_candidates: TGenCandidates | None = None,
sequential: bool = False,
*,
ic_generator: TGenInitialConditions | None = None,
timeout_sec: float | None = None,
return_full_tree: bool = False,
retry_on_optimization_warning: bool = True,
**ic_gen_kwargs: Any,
) -> tuple[Tensor, Tensor]:
r"""Generate a set of candidates via multi-start optimization.
Args:
acq_function: An AcquisitionFunction.
bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`.
bounds: A `2 x d` tensor of lower and upper bounds for each column of `X`
(if inequality_constraints is provided, these bounds can be -inf and
+inf, respectively).
q: The number of candidates.
homotopy: Homotopy object that will make the necessary modifications to the
problem when calling `step()`.
prune_tolerance: The minimum distance to prune candidates.
num_restarts: The number of starting points for multistart acquisition
function optimization.
raw_samples: The number of samples for initialization. This is required
if `batch_initial_conditions` is not specified.
options: Options for candidate generation in the initial step of the homotopy.
final_options: Options for candidate generation in the final step of
the homotopy.
inequality_constraints: A list of tuples (indices, coefficients, rhs),
with each tuple encoding an inequality constraint of the form
`\sum_i (X[indices[i]] * coefficients[i]) >= rhs`. `indices` and
`coefficients` should be torch tensors. See the docstring of
`make_scipy_linear_constraints` for an example. When q=1, or when
applying the same constraint to each candidate in the batch
(intra-point constraint), `indices` should be a 1-d tensor.
For inter-point constraints, in which the constraint is applied to the
whole batch of candidates, `indices` must be a 2-d tensor, where
in each row `indices[i] = (k_i, l_i)` the first index `k_i` corresponds
to the `k_i`-th element of the `q`-batch and the second index `l_i`
corresponds to the `l_i`-th feature of that element.
equality_constraints: A list of tuples (indices, coefficients, rhs),
with each tuple encoding an equality constraint of the form
`\sum_i (X[indices[i]] * coefficients[i]) = rhs`. See the docstring of
`make_scipy_linear_constraints` for an example.
nonlinear_inequality_constraints: A list of tuples representing the nonlinear
inequality constraints. The first element in the tuple is a callable
representing a constraint of the form `callable(x) >= 0`. In case of an
intra-point constraint, `callable()` takes in a one-dimensional tensor of
shape `d` and returns a scalar. In case of an inter-point constraint,
`callable()` takes a two-dimensional tensor of shape `q x d` and again
returns a scalar. The second element is a boolean, indicating if it is an
intra-point or inter-point constraint (`True` for intra-point, `False` for
inter-point). For more information on intra-point vs inter-point
constraints, see the docstring of the `inequality_constraints` argument to
`optimize_acqf()`. The constraints will later be passed to the scipy
solver. You need to pass in `batch_initial_conditions` in this case.
Using non-linear inequality constraints also requires that `batch_limit`
is set to 1, which will be done automatically if not specified in
`options`.
fixed_features: A map `{feature_index: value}` for features that
should be fixed to a particular value during generation.
options: Options for candidate generation.
final_options: Options for candidate generation in the last homotopy step.
post_processing_func: A function that post-processes an optimization
result appropriately (i.e., according to `round-trip`
transformations).
batch_initial_conditions: A tensor to specify the initial conditions. Set
this if you do not want to use default initialization strategy.
post_processing_func: Post processing function (such as rounding or clamping)
that is applied before choosing the final candidate.
gen_candidates: A callable for generating candidates (and their associated
acquisition values) given a tensor of initial conditions and an
acquisition function. Other common inputs include lower and upper bounds
and a dictionary of options, but refer to the documentation of specific
generation functions (e.g. `gen_candidates_scipy` and `gen_candidates_torch`)
for method-specific inputs. Default: `gen_candidates_scipy`.
sequential: If False, uses joint optimization, otherwise uses sequential
optimization.
ic_generator: Function for generating initial conditions. Not needed when
`batch_initial_conditions` are provided. Defaults to
`gen_one_shot_kg_initial_conditions` for `qKnowledgeGradient` acquisition
functions and `gen_batch_initial_conditions` otherwise. Must be specified
for nonlinear inequality constraints.
timeout_sec: Max amount of time optimization can run for.
return_full_tree: Return the full tree of optimizers of the previous
iteration.
retry_on_optimization_warning: Whether to retry candidate generation with a new
set of initial conditions when it fails with an `OptimizationWarning`.
ic_gen_kwargs: Additional keyword arguments passed to function specified by
`ic_generator`
"""
shared_optimize_acqf_kwargs = {
"num_restarts": num_restarts,
"raw_samples": raw_samples,
"inequality_constraints": inequality_constraints,
"equality_constraints": equality_constraints,
"nonlinear_inequality_constraints": nonlinear_inequality_constraints,
"fixed_features": fixed_features,
"return_best_only": False, # False to make n_restarts persist through homotopy.
"gen_candidates": gen_candidates,
"sequential": sequential,
"ic_generator": ic_generator,
"timeout_sec": timeout_sec,
"return_full_tree": return_full_tree,
"retry_on_optimization_warning": retry_on_optimization_warning,
**ic_gen_kwargs,
}

candidate_list, acq_value_list = [], []
if q > 1:
base_X_pending = acq_function.X_pending
@@ -91,15 +182,12 @@ def optimize_acqf_homotopy(

while not homotopy.should_stop:
candidates, acq_values = optimize_acqf(
q=1,
acq_function=acq_function,
bounds=bounds,
num_restarts=num_restarts,
batch_initial_conditions=candidates,
raw_samples=raw_samples,
fixed_features=fixed_features,
return_best_only=False,
q=1,
options=options,
batch_initial_conditions=candidates,
**shared_optimize_acqf_kwargs,
)
homotopy.step()

Expand All @@ -112,26 +200,27 @@ def optimize_acqf_homotopy(

# Optimize one more time with the final options
candidates, acq_values = optimize_acqf(
q=1,
acq_function=acq_function,
bounds=bounds,
num_restarts=num_restarts,
batch_initial_conditions=candidates,
return_best_only=False,
q=1,
options=final_options,
batch_initial_conditions=candidates,
**shared_optimize_acqf_kwargs,
)

# Post-process the candidates and grab the best candidate
if post_processing_func is not None:
candidates = post_processing_func(candidates)
acq_values = acq_function(candidates)

best = torch.argmax(acq_values.view(-1), dim=0)
candidate, acq_value = candidates[best], acq_values[best]

# Keep the new candidate and update the pending points
candidate_list.append(candidate)
acq_value_list.append(acq_value)
selected_candidates = torch.cat(candidate_list, dim=-2)

if q > 1:
acq_function.set_X_pending(
torch.cat([base_X_pending, selected_candidates], dim=-2)
Expand All @@ -141,6 +230,7 @@ def optimize_acqf_homotopy(

if q > 1: # Reset acq_function to previous X_pending state
acq_function.set_X_pending(base_X_pending)

homotopy.reset() # Reset the homotopy parameters

return selected_candidates, torch.stack(acq_value_list)
23 changes: 23 additions & 0 deletions test/optim/test_homotopy.py
@@ -157,6 +157,28 @@ def test_optimize_acqf_homotopy(self):
self.assertEqual(candidate.shape, torch.Size([3, 2]))
self.assertEqual(acqf_val.shape, torch.Size([3]))

# with linear constraints
constraints = [
( # X[..., 0] + X[..., 1] >= 2.
torch.tensor([0, 1], device=self.device),
torch.ones(2, device=self.device, dtype=torch.double),
2.0,
)
]

acqf = PosteriorMean(model=model)
candidate, acqf_val = optimize_acqf_homotopy(
q=1,
acq_function=acqf,
bounds=torch.tensor([[-10, -10], [5, 5]]).to(**tkwargs),
homotopy=Homotopy(homotopy_parameters=[hp]),
num_restarts=2,
raw_samples=16,
inequality_constraints=constraints,
)
self.assertEqual(candidate.shape, torch.Size([1, 2]))
self.assertGreaterEqual(candidate.sum().item(), 2.0 - 1e-6)

def test_prune_candidates(self):
tkwargs = {"device": self.device, "dtype": torch.double}
# no pruning
@@ -210,6 +232,7 @@ def test_optimize_acqf_homotopy_pruning(self, prune_candidates_mock):
num_restarts=4,
raw_samples=16,
post_processing_func=lambda x: x.round(),
return_full_tree=True,
)
# First time we expect to call `prune_candidates` with 4 candidates
self.assertEqual(
