Skip to content

Commit

Permalink
pass gen_candidates callable in optimize_acqf (pytorch#1655)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1655

see title. This will support using stochastic optimization

Differential Revision: https://internalfb.com/D41629164

fbshipit-source-id: 6db9499cff12e54393968246c7344af0820e5a40
  • Loading branch information
sdaulton authored and facebook-github-bot committed Feb 8, 2023
1 parent 58090d3 commit 4d4b47d
Show file tree
Hide file tree
Showing 4 changed files with 358 additions and 231 deletions.
4 changes: 2 additions & 2 deletions botorch/generation/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@

logger = _get_logger()

TGenCandidates = Callable[[Tensor, AcquisitionFunction, Any], Tuple[Tensor, Tensor]]


def gen_candidates_scipy(
initial_conditions: Tensor,
Expand Down Expand Up @@ -152,7 +154,6 @@ def gen_candidates_scipy(
clamped_candidates
)
return clamped_candidates, batch_acquisition

clamped_candidates = columnwise_clamp(
X=initial_conditions, lower=lower_bounds, upper=upper_bounds
)
Expand Down Expand Up @@ -360,7 +361,6 @@ def gen_candidates_torch(
clamped_candidates
)
return clamped_candidates, batch_acquisition

_clamp = partial(columnwise_clamp, lower=lower_bounds, upper=upper_bounds)
clamped_candidates = _clamp(initial_conditions).requires_grad_(True)
_optimizer = optimizer(params=[clamped_candidates], lr=options.get("lr", 0.025))
Expand Down
43 changes: 36 additions & 7 deletions botorch/optim/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
from botorch.exceptions import InputDataError, UnsupportedError
from botorch.exceptions.warnings import OptimizationWarning
from botorch.generation.gen import gen_candidates_scipy
from botorch.generation.gen import gen_candidates_scipy, TGenCandidates
from botorch.logging import logger
from botorch.optim.initializers import (
gen_batch_initial_conditions,
gen_one_shot_kg_initial_conditions,
)
from botorch.optim.stopping import ExpMAStoppingCriterion
from botorch.optim.utils import _filter_kwargs
from torch import Tensor

INIT_OPTION_KEYS = {
Expand Down Expand Up @@ -64,6 +65,7 @@ def optimize_acqf(
post_processing_func: Optional[Callable[[Tensor], Tensor]] = None,
batch_initial_conditions: Optional[Tensor] = None,
return_best_only: bool = True,
gen_candidates: Optional[TGenCandidates] = None,
sequential: bool = False,
**kwargs: Any,
) -> Tuple[Tensor, Tensor]:
Expand Down Expand Up @@ -103,6 +105,12 @@ def optimize_acqf(
this if you do not want to use default initialization strategy.
return_best_only: If False, outputs the solutions corresponding to all
random restart initializations of the optimization.
gen_candidates: A callable for generating candidates (and their associated
acquisition values) given a tensor of initial conditions and an
acquisition function. Other common inputs include lower and upper bounds
and a dictionary of options, but refer to the documentation of specific
generation functions (e.g gen_candidates_scipy and gen_candidates_torch)
for method-specific inputs. Default: `gen_candidates_scipy`
sequential: If False, uses joint optimization, otherwise uses sequential
optimization.
kwargs: Additonal keyword arguments.
Expand Down Expand Up @@ -134,6 +142,9 @@ def optimize_acqf(
"""
start_time: float = time.monotonic()
timeout_sec = kwargs.pop("timeout_sec", None)
# using a default of None simplifies unit testing
if gen_candidates is None:
gen_candidates = gen_candidates_scipy

if inequality_constraints is None:
if not (bounds.ndim == 2 and bounds.shape[0] == 2):
Expand Down Expand Up @@ -229,6 +240,7 @@ def optimize_acqf(
sequential=False,
ic_generator=ic_gen,
timeout_sec=timeout_sec,
gen_candidates=gen_candidates,
)

candidate_list.append(candidate)
Expand Down Expand Up @@ -277,6 +289,11 @@ def optimize_acqf(
batch_limit: int = options.get(
"batch_limit", num_restarts if not nonlinear_inequality_constraints else 1
)
has_parameter_constraints = (
inequality_constraints is not None
or equality_constraints is not None
or nonlinear_inequality_constraints is not None
)

def _optimize_batch_candidates(
timeout_sec: Optional[float],
Expand All @@ -288,24 +305,36 @@ def _optimize_batch_candidates(
if timeout_sec is not None:
timeout_sec = (timeout_sec - start_time) / len(batched_ics)

scipy_kws = {
gen_kwargs = {
"acquisition_function": acq_function,
"lower_bounds": None if bounds[0].isinf().all() else bounds[0],
"upper_bounds": None if bounds[1].isinf().all() else bounds[1],
"options": {k: v for k, v in options.items() if k not in INIT_OPTION_KEYS},
"inequality_constraints": inequality_constraints,
"equality_constraints": equality_constraints,
"nonlinear_inequality_constraints": nonlinear_inequality_constraints,
"fixed_features": fixed_features,
"timeout_sec": timeout_sec,
}

if has_parameter_constraints:
# only add parameter constraints to gen_kwargs if they are specified
# to avoid unnecessary warnings in _filter_kwargs
gen_kwargs.update(
{
"inequality_constraints": inequality_constraints,
"equality_constraints": equality_constraints,
# the line is too long
"nonlinear_inequality_constraints": (
nonlinear_inequality_constraints
),
}
)
filtered_gen_kwargs = _filter_kwargs(gen_candidates, **gen_kwargs)

for i, batched_ics_ in enumerate(batched_ics):
# optimize using random restart optimization
with warnings.catch_warnings(record=True) as ws:
warnings.simplefilter("always", category=OptimizationWarning)
batch_candidates_curr, batch_acq_values_curr = gen_candidates_scipy(
initial_conditions=batched_ics_, **scipy_kws
batch_candidates_curr, batch_acq_values_curr = gen_candidates(
initial_conditions=batched_ics_, **filtered_gen_kwargs
)
opt_warnings += ws
batch_candidates_list.append(batch_candidates_curr)
Expand Down
8 changes: 8 additions & 0 deletions test/acquisition/test_knowledge_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
)
from botorch.acquisition.utils import project_to_sample_points
from botorch.exceptions.errors import UnsupportedError
from botorch.generation.gen import gen_candidates_scipy
from botorch.models import SingleTaskGP
from botorch.optim.optimize import optimize_acqf
from botorch.optim.utils import _filter_kwargs
from botorch.posteriors.gpytorch import GPyTorchPosterior
from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
Expand Down Expand Up @@ -593,7 +595,13 @@ def test_optimize_w_posterior_transform(self):
torch.zeros(2, n_f + 1, 2, **tkwargs),
torch.zeros(2, **tkwargs),
),
), mock.patch(
f"{optimize_acqf.__module__}._filter_kwargs",
wraps=lambda f, **kwargs: _filter_kwargs(
function=gen_candidates_scipy, **kwargs
),
):

candidate, value = optimize_acqf(
acq_function=kg,
bounds=bounds,
Expand Down
Loading

0 comments on commit 4d4b47d

Please sign in to comment.