|
22 | 22 | from botorch.acquisition.knowledge_gradient import qKnowledgeGradient
|
23 | 23 | from botorch.exceptions import InputDataError, UnsupportedError
|
24 | 24 | from botorch.exceptions.warnings import OptimizationWarning
|
25 |
| -from botorch.generation.gen import gen_candidates_scipy |
| 25 | +from botorch.generation.gen import gen_candidates_scipy, TGenCandidates |
26 | 26 | from botorch.logging import logger
|
27 | 27 | from botorch.optim.initializers import (
|
28 | 28 | gen_batch_initial_conditions,
|
@@ -64,6 +64,7 @@ def optimize_acqf(
|
64 | 64 | post_processing_func: Optional[Callable[[Tensor], Tensor]] = None,
|
65 | 65 | batch_initial_conditions: Optional[Tensor] = None,
|
66 | 66 | return_best_only: bool = True,
|
| 67 | + gen_candidates: TGenCandidates = gen_candidates_scipy, |
67 | 68 | sequential: bool = False,
|
68 | 69 | **kwargs: Any,
|
69 | 70 | ) -> Tuple[Tensor, Tensor]:
|
@@ -103,6 +104,8 @@ def optimize_acqf(
|
103 | 104 | this if you do not want to use default initialization strategy.
|
104 | 105 | return_best_only: If False, outputs the solutions corresponding to all
|
105 | 106 | random restart initializations of the optimization.
|
| 107 | + gen_candidates: A callable for generating candidates given initial |
| 108 | + conditions. Default: `gen_candidates_scipy` |
106 | 109 | sequential: If False, uses joint optimization, otherwise uses sequential
|
107 | 110 | optimization.
|
108 | 111 | kwargs: Additonal keyword arguments.
|
@@ -258,23 +261,23 @@ def _optimize_batch_candidates() -> Tuple[Tensor, Tensor, List[Warning]]:
|
258 | 261 | batched_ics = batch_initial_conditions.split(batch_limit)
|
259 | 262 | opt_warnings = []
|
260 | 263 |
|
261 |
| - scipy_kws = dict( |
262 |
| - acquisition_function=acq_function, |
263 |
| - lower_bounds=None if bounds[0].isinf().all() else bounds[0], |
264 |
| - upper_bounds=None if bounds[1].isinf().all() else bounds[1], |
265 |
| - options={k: v for k, v in options.items() if k not in INIT_OPTION_KEYS}, |
266 |
| - inequality_constraints=inequality_constraints, |
267 |
| - equality_constraints=equality_constraints, |
268 |
| - nonlinear_inequality_constraints=nonlinear_inequality_constraints, |
269 |
| - fixed_features=fixed_features, |
270 |
| - ) |
| 264 | + gen_kwargs = { |
| 265 | + "acquisition_function": acq_function, |
| 266 | + "lower_bounds": None if bounds[0].isinf().all() else bounds[0], |
| 267 | + "upper_bounds": None if bounds[1].isinf().all() else bounds[1], |
| 268 | + "options": {k: v for k, v in options.items() if k not in INIT_OPTION_KEYS}, |
| 269 | + "fixed_features": fixed_features, |
| 270 | + "inequality_constraints": inequality_constraints, |
| 271 | + "equality_constraints": equality_constraints, |
| 272 | + "nonlinear_inequality_constraints": nonlinear_inequality_constraints, |
| 273 | + } |
271 | 274 |
|
272 | 275 | for i, batched_ics_ in enumerate(batched_ics):
|
273 | 276 | # optimize using random restart optimization
|
274 | 277 | with warnings.catch_warnings(record=True) as ws:
|
275 | 278 | warnings.simplefilter("always", category=OptimizationWarning)
|
276 |
| - batch_candidates_curr, batch_acq_values_curr = gen_candidates_scipy( |
277 |
| - initial_conditions=batched_ics_, **scipy_kws |
| 279 | + batch_candidates_curr, batch_acq_values_curr = gen_candidates( |
| 280 | + initial_conditions=batched_ics_, **gen_kwargs |
278 | 281 | )
|
279 | 282 | opt_warnings += ws
|
280 | 283 | batch_candidates_list.append(batch_candidates_curr)
|
|
0 commit comments