diff --git a/botorch/test_utils/__init__.py b/botorch/test_utils/__init__.py new file mode 100644 index 0000000000..3a89f80f66 --- /dev/null +++ b/botorch/test_utils/__init__.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +test_utils has its own directory with 'botorch/' to avoid circular dependencies: +Anything in 'tests/' can depend on anything in 'botorch/test_utils/', and +anything in 'botorch/test_utils/' can depend on anything in the rest of +'botorch/'. +""" + +from botorch.test_utils.mock import fast_optimize + +__all__ = ["fast_optimize"] diff --git a/botorch/test_utils/mock.py b/botorch/test_utils/mock.py new file mode 100644 index 0000000000..b4b1147700 --- /dev/null +++ b/botorch/test_utils/mock.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Utilities for speeding up optimization in tests. + +""" + +from collections.abc import Generator +from contextlib import contextmanager, ExitStack +from functools import wraps +from typing import Any, Callable +from unittest import mock + +from botorch.optim.initializers import ( + gen_batch_initial_conditions, + gen_one_shot_kg_initial_conditions, +) + +from botorch.optim.utils.timeout import minimize_with_timeout +from scipy.optimize import OptimizeResult +from torch import Tensor + + +@contextmanager +def fast_optimize_context_manager( + force: bool = False, +) -> Generator[None, None, None]: + """A context manager to force botorch to speed up optimization. Currently, the + primary tactic is to force the underlying scipy methods to stop after just one + iteration. + + force: If True will not raise an AssertionError if no mocks are called. + USE RESPONSIBLY. + """ + + def one_iteration_minimize(*args: Any, **kwargs: Any) -> OptimizeResult: + if kwargs["options"] is None: + kwargs["options"] = {} + + kwargs["options"]["maxiter"] = 1 + return minimize_with_timeout(*args, **kwargs) + + def minimal_gen_ics(*args: Any, **kwargs: Any) -> Tensor: + kwargs["num_restarts"] = 2 + kwargs["raw_samples"] = 4 + + return gen_batch_initial_conditions(*args, **kwargs) + + def minimal_gen_os_ics(*args: Any, **kwargs: Any) -> Tensor | None: + kwargs["num_restarts"] = 2 + kwargs["raw_samples"] = 4 + + return gen_one_shot_kg_initial_conditions(*args, **kwargs) + + with ExitStack() as es: + # Note this `minimize_with_timeout` is defined in optim.utils.timeout; + # this mock only has an effect when calling a function used in + # `botorch.generation.gen`, such as `gen_candidates_scipy`. + mock_generation = es.enter_context( + mock.patch( + "botorch.generation.gen.minimize_with_timeout", + wraps=one_iteration_minimize, + ) + ) + + # Similarly, works when using calling a function defined in + # `optim.core`, such as `scipy_minimize` and `torch_minimize`. + mock_fit = es.enter_context( + mock.patch( + "botorch.optim.core.minimize_with_timeout", + wraps=one_iteration_minimize, + ) + ) + + # Works when calling a function in `optim.optimize` such as + # `optimize_acqf` + mock_gen_ics = es.enter_context( + mock.patch( + "botorch.optim.optimize.gen_batch_initial_conditions", + wraps=minimal_gen_ics, + ) + ) + + # Works when calling a function in `optim.optimize` such as + # `optimize_acqf` + mock_gen_os_ics = es.enter_context( + mock.patch( + "botorch.optim.optimize.gen_one_shot_kg_initial_conditions", + wraps=minimal_gen_os_ics, + ) + ) + + yield + + if (not force) and all( + mock_.call_count < 1 + for mock_ in [ + mock_generation, + mock_fit, + mock_gen_ics, + mock_gen_os_ics, + ] + ): + raise AssertionError( + "No mocks were called in the context manager. Please remove unused " + "fast_botorch_optimize_context_manager()." + ) + + +def fast_optimize(f: Callable) -> Callable: + """Wraps f in the fast_botorch_optimize_context_manager for use as a decorator.""" + + @wraps(f) + # pyre-fixme[3]: Return type must be annotated. + def inner(*args: Any, **kwargs: Any): + with fast_optimize_context_manager(): + return f(*args, **kwargs) + + return inner diff --git a/sphinx/source/index.rst b/sphinx/source/index.rst index dd819d2acb..f715932b9a 100644 --- a/sphinx/source/index.rst +++ b/sphinx/source/index.rst @@ -22,6 +22,7 @@ BoTorch API Reference settings logging test_functions + test_utils exceptions utils diff --git a/sphinx/source/test_utils.rst b/sphinx/source/test_utils.rst new file mode 100644 index 0000000000..71aa2d9214 --- /dev/null +++ b/sphinx/source/test_utils.rst @@ -0,0 +1,12 @@ +.. role:: hidden + :class: hidden-section + + +botorch.test_utils +======================================================== +.. automodule:: botorch.test_utils + +Mock +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. automodule:: botorch.test_utils.mock + :members: diff --git a/test/test_utils/test_mock.py b/test/test_utils/test_mock.py new file mode 100644 index 0000000000..243e970a38 --- /dev/null +++ b/test/test_utils/test_mock.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from botorch.generation.gen import gen_candidates_scipy +from botorch.optim.core import scipy_minimize +from botorch.optim.optimize import optimize_acqf + +from botorch.test_utils.mock import fast_optimize +from botorch.utils.testing import BotorchTestCase, MockAcquisitionFunction + + +class SinAcqusitionFunction(MockAcquisitionFunction): + """Simple acquisition function with known numerical properties.""" + + def __init__(self, *args, **kwargs): + return + + def __call__(self, X): + return torch.sin(X[..., 0].max(dim=-1).values) + + +class TestMock(BotorchTestCase): + @fast_optimize + def test_fast_optimize(self): + with self.subTest("gen_candidates_scipy"): + cand, value = gen_candidates_scipy( + initial_conditions=torch.tensor([[0.0]]), + acquisition_function=SinAcqusitionFunction(), + ) + # When not using `fast_optimize`, the value is 1.0. With it, the value is + # around 0.84 + self.assertLess(value.item(), 0.99) + + with self.subTest("scipy_minimize"): + x = torch.tensor([0.0]) + + def closure(): + return torch.sin(x), [torch.cos(x)] + + result = scipy_minimize(closure=closure, parameters={"x": x}) + self.assertEqual( + result.message, "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT" + ) + + with self.subTest("optimize_acqf"): + cand, value = optimize_acqf( + acq_function=SinAcqusitionFunction(), + bounds=torch.tensor([[-2.0], [2.0]]), + q=1, + num_restarts=32, + batch_initial_conditions=torch.tensor([[0.0]]), + ) + self.assertLess(value.item(), 0.99)