Skip to content

Commit c296802

Browse files
sdaultonfacebook-github-bot
authored andcommitted
pass gen_candidates callable in optimize_acqf
Summary: see title. This will support using stochastic optimization Differential Revision: D41629164 fbshipit-source-id: 8c6e6f46f7f605ba5bb18291880915fe23dc2f1a
1 parent 44e51c6 commit c296802

File tree

2 files changed

+20
-13
lines changed

2 files changed

+20
-13
lines changed

botorch/generation/gen.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636

3737
logger = _get_logger()
3838

39+
TGenCandidates = Callable[[Tensor, AcquisitionFunction, Any], Tuple[Tensor, Tensor]]
40+
3941

4042
def gen_candidates_scipy(
4143
initial_conditions: Tensor,
@@ -47,6 +49,7 @@ def gen_candidates_scipy(
4749
nonlinear_inequality_constraints: Optional[List[Callable]] = None,
4850
options: Optional[Dict[str, Any]] = None,
4951
fixed_features: Optional[Dict[int, Optional[float]]] = None,
52+
**kwargs,
5053
) -> Tuple[Tensor, Tensor]:
5154
r"""Generate a set of candidates using `scipy.optimize.minimize`.
5255
@@ -253,6 +256,7 @@ def gen_candidates_torch(
253256
options: Optional[Dict[str, Union[float, str]]] = None,
254257
callback: Optional[Callable[[int, Tensor, Tensor], NoReturn]] = None,
255258
fixed_features: Optional[Dict[int, Optional[float]]] = None,
259+
**kwargs,
256260
) -> Tuple[Tensor, Tensor]:
257261
r"""Generate a set of candidates using a `torch.optim` optimizer.
258262

botorch/optim/optimize.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
2323
from botorch.exceptions import InputDataError, UnsupportedError
2424
from botorch.exceptions.warnings import OptimizationWarning
25-
from botorch.generation.gen import gen_candidates_scipy
25+
from botorch.generation.gen import gen_candidates_scipy, TGenCandidates
2626
from botorch.logging import logger
2727
from botorch.optim.initializers import (
2828
gen_batch_initial_conditions,
@@ -64,6 +64,7 @@ def optimize_acqf(
6464
post_processing_func: Optional[Callable[[Tensor], Tensor]] = None,
6565
batch_initial_conditions: Optional[Tensor] = None,
6666
return_best_only: bool = True,
67+
gen_candidates: TGenCandidates = gen_candidates_scipy,
6768
sequential: bool = False,
6869
**kwargs: Any,
6970
) -> Tuple[Tensor, Tensor]:
@@ -103,6 +104,8 @@ def optimize_acqf(
103104
this if you do not want to use default initialization strategy.
104105
return_best_only: If False, outputs the solutions corresponding to all
105106
random restart initializations of the optimization.
107+
gen_candidates: A callable for generating candidates given initial
108+
conditions. Default: `gen_candidates_scipy`
106109
sequential: If False, uses joint optimization, otherwise uses sequential
107110
optimization.
108111
kwargs: Additonal keyword arguments.
@@ -258,23 +261,23 @@ def _optimize_batch_candidates() -> Tuple[Tensor, Tensor, List[Warning]]:
258261
batched_ics = batch_initial_conditions.split(batch_limit)
259262
opt_warnings = []
260263

261-
scipy_kws = dict(
262-
acquisition_function=acq_function,
263-
lower_bounds=None if bounds[0].isinf().all() else bounds[0],
264-
upper_bounds=None if bounds[1].isinf().all() else bounds[1],
265-
options={k: v for k, v in options.items() if k not in INIT_OPTION_KEYS},
266-
inequality_constraints=inequality_constraints,
267-
equality_constraints=equality_constraints,
268-
nonlinear_inequality_constraints=nonlinear_inequality_constraints,
269-
fixed_features=fixed_features,
270-
)
264+
gen_kwargs = {
265+
"acquisition_function": acq_function,
266+
"lower_bounds": None if bounds[0].isinf().all() else bounds[0],
267+
"upper_bounds": None if bounds[1].isinf().all() else bounds[1],
268+
"options": {k: v for k, v in options.items() if k not in INIT_OPTION_KEYS},
269+
"fixed_features": fixed_features,
270+
"inequality_constraints": inequality_constraints,
271+
"equality_constraints": equality_constraints,
272+
"nonlinear_inequality_constraints": nonlinear_inequality_constraints,
273+
}
271274

272275
for i, batched_ics_ in enumerate(batched_ics):
273276
# optimize using random restart optimization
274277
with warnings.catch_warnings(record=True) as ws:
275278
warnings.simplefilter("always", category=OptimizationWarning)
276-
batch_candidates_curr, batch_acq_values_curr = gen_candidates_scipy(
277-
initial_conditions=batched_ics_, **scipy_kws
279+
batch_candidates_curr, batch_acq_values_curr = gen_candidates(
280+
initial_conditions=batched_ics_, **gen_kwargs
278281
)
279282
opt_warnings += ws
280283
batch_candidates_list.append(batch_candidates_curr)

0 commit comments

Comments
 (0)