forked from pytorch/botorch
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Introduce
fast_optimize
context manager to speed up testing (pytorc…
…h#2563) Summary: Context: Many BoTorch tests take a while because they run optimization code. But it's good that this code runs rather than being avoided or mocked out, becuase the tests are ensuring that things work end-to-end. Borrowing a page from Ax's `fast_botorch_optimize`, this commit introduces the same thing to BoTorch, with the exception of `fit_fully_bayesian_model_nuts`. A future commit to Ax can remove that functionality from Ax in favor of importing it from BoTorch, but we might not want to do it right way because then Ax won't work with older versions of BoTorch. This PR: * Introduces `fast_optimize`, which is the same as Ax's `fast_botorch_optimize`, but with different import paths. * Applies it to a slow test, reducing runtime to 2s from 6s-10s. Differential Revision: D63838626
- Loading branch information
1 parent
a0a2c05
commit 91307aa
Showing
4 changed files
with
206 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
""" | ||
test_utils has its own top-level directory to avoid circular dependencies: | ||
Anything in 'tests/' can depend on anything in 'test_utils/', and anything in | ||
'test_utils/' can depend on anything in 'botorch/'. | ||
""" | ||
|
||
from botorch.test_utils.mock import fast_optimize | ||
|
||
__all__ = ["fast_optimize"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
""" | ||
Utilities for speeding up optimization in tests. | ||
""" | ||
|
||
from collections.abc import Generator | ||
from contextlib import contextmanager, ExitStack | ||
from functools import wraps | ||
from typing import Any, Callable | ||
from unittest import mock | ||
|
||
from botorch.optim.initializers import ( | ||
gen_batch_initial_conditions, | ||
gen_one_shot_kg_initial_conditions, | ||
) | ||
|
||
from botorch.optim.utils.timeout import minimize_with_timeout | ||
from scipy.optimize import OptimizeResult | ||
from torch import Tensor | ||
|
||
|
||
@contextmanager | ||
def fast_optimize_context_manager( | ||
force: bool = False, | ||
) -> Generator[None, None, None]: | ||
"""A context manager to force botorch to speed up optimization. Currently, the | ||
primary tactic is to force the underlying scipy methods to stop after just one | ||
iteration. | ||
force: If True will not raise an AssertionError if no mocks are called. | ||
USE RESPONSIBLY. | ||
""" | ||
|
||
def one_iteration_minimize(*args: Any, **kwargs: Any) -> OptimizeResult: | ||
if kwargs["options"] is None: | ||
kwargs["options"] = {} | ||
|
||
kwargs["options"]["maxiter"] = 1 | ||
return minimize_with_timeout(*args, **kwargs) | ||
|
||
def minimal_gen_ics(*args: Any, **kwargs: Any) -> Tensor: | ||
kwargs["num_restarts"] = 2 | ||
kwargs["raw_samples"] = 4 | ||
|
||
return gen_batch_initial_conditions(*args, **kwargs) | ||
|
||
def minimal_gen_os_ics(*args: Any, **kwargs: Any) -> Tensor | None: | ||
kwargs["num_restarts"] = 2 | ||
kwargs["raw_samples"] = 4 | ||
|
||
return gen_one_shot_kg_initial_conditions(*args, **kwargs) | ||
|
||
with ExitStack() as es: | ||
# Note this `minimize_with_timeout` is defined in optim.utils.timeout; | ||
# this mock only has an effect when calling a function used in | ||
# `botorch.generation.gen`, such as `gen_candidates_scipy`. | ||
mock_generation = es.enter_context( | ||
mock.patch( | ||
"botorch.generation.gen.minimize_with_timeout", | ||
wraps=one_iteration_minimize, | ||
) | ||
) | ||
|
||
# Similarly, works when using calling a function defined in | ||
# `optim.core`, such as `scipy_minimize` and `torch_minimize`. | ||
mock_fit = es.enter_context( | ||
mock.patch( | ||
"botorch.optim.core.minimize_with_timeout", | ||
wraps=one_iteration_minimize, | ||
) | ||
) | ||
|
||
# Works when calling a function in `optim.optimize` such as | ||
# `optimize_acqf` | ||
mock_gen_ics = es.enter_context( | ||
mock.patch( | ||
"botorch.optim.optimize.gen_batch_initial_conditions", | ||
wraps=minimal_gen_ics, | ||
) | ||
) | ||
|
||
# Works when calling a function in `optim.optimize` such as | ||
# `optimize_acqf` | ||
mock_gen_os_ics = es.enter_context( | ||
mock.patch( | ||
"botorch.optim.optimize.gen_one_shot_kg_initial_conditions", | ||
wraps=minimal_gen_os_ics, | ||
) | ||
) | ||
|
||
yield | ||
|
||
if (not force) and all( | ||
mock_.call_count < 1 | ||
for mock_ in [ | ||
mock_generation, | ||
mock_fit, | ||
mock_gen_ics, | ||
mock_gen_os_ics, | ||
] | ||
): | ||
raise AssertionError( | ||
"No mocks were called in the context manager. Please remove unused " | ||
"fast_botorch_optimize_context_manager()." | ||
) | ||
|
||
|
||
def fast_optimize(f: Callable) -> Callable: | ||
"""Wraps f in the fast_botorch_optimize_context_manager for use as a decorator.""" | ||
|
||
@wraps(f) | ||
# pyre-fixme[3]: Return type must be annotated. | ||
def inner(*args: Any, **kwargs: Any): | ||
with fast_optimize_context_manager(): | ||
return f(*args, **kwargs) | ||
|
||
return inner |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
.. role:: hidden | ||
:class: hidden-section | ||
|
||
|
||
botorch.test_utils | ||
======================================================== | ||
.. automodule:: botorch.test_utils | ||
|
||
Mock | ||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ | ||
.. automodule:: botorch.test_utils.mock | ||
:members: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
#!/usr/bin/env python3 | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
import torch | ||
from botorch.generation.gen import gen_candidates_scipy | ||
from botorch.optim.core import scipy_minimize | ||
from botorch.optim.optimize import optimize_acqf | ||
|
||
from botorch.test_utils.mock import fast_optimize | ||
from botorch.utils.testing import BotorchTestCase, MockAcquisitionFunction | ||
|
||
|
||
class SinAcqusitionFunction(MockAcquisitionFunction): | ||
def __init__(self, *args, **kwargs): | ||
return | ||
|
||
def __call__(self, X): | ||
return torch.sin(X[..., 0].max(dim=-1).values) | ||
|
||
|
||
class TestMock(BotorchTestCase): | ||
@fast_optimize | ||
def test_fast_optimize(self): | ||
with self.subTest("gen_candidates_scipy"): | ||
cand, value = gen_candidates_scipy( | ||
initial_conditions=torch.tensor([[0.0]]), | ||
acquisition_function=SinAcqusitionFunction(), | ||
) | ||
# When not using `fast_optimize`, the value is 1.0. With it, the value is | ||
# around 0.84 | ||
self.assertLess(value.item(), 0.99) | ||
|
||
with self.subTest("scipy_minimize"): | ||
x = torch.tensor([0.0]) | ||
|
||
def closure(): | ||
return torch.sin(x), [torch.cos(x)] | ||
|
||
result = scipy_minimize(closure=closure, parameters={"x": x}) | ||
self.assertEqual( | ||
result.message, "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT" | ||
) | ||
|
||
with self.subTest("optimize_acqf"): | ||
cand, value = optimize_acqf( | ||
acq_function=SinAcqusitionFunction(), | ||
bounds=torch.tensor([[-2.0], [2.0]]), | ||
q=1, | ||
num_restarts=32, | ||
batch_initial_conditions=torch.tensor([[0.0]]), | ||
) | ||
self.assertLess(value.item(), 0.99) |