Introduce fast_optimize context manager to speed up testing
Summary:
Context: Many BoTorch tests take a while because they run optimization code. It's good that this code actually runs rather than being avoided or mocked out, because the tests ensure that things work end-to-end. Borrowing a page from Ax's `fast_botorch_optimize`, this commit introduces the same thing to BoTorch, with the exception of the `fit_fully_bayesian_model_nuts` mock.

A future commit to Ax can remove that functionality from Ax in favor of importing it from BoTorch, but we might not want to do that right away, since then Ax wouldn't work with older versions of BoTorch.


This PR:
* Introduces `fast_optimize`, which is the same as Ax's `fast_botorch_optimize` but with different import paths (see the usage sketch after this list).
* Applies it to a slow test, reducing its runtime from 6-10s to about 2s.
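
As a usage sketch, adapted from the new test in `test/test_utils/test_mock.py` below (the `SinAcquisitionFunction` helper and the standalone-script framing are illustrative only, not part of the public API):

import torch
from botorch.generation.gen import gen_candidates_scipy
from botorch.test_utils import fast_optimize
from botorch.utils.testing import MockAcquisitionFunction


class SinAcquisitionFunction(MockAcquisitionFunction):
    """Toy acquisition function whose maximum of 1.0 is at x = pi / 2."""

    def __init__(self, *args, **kwargs):
        return

    def __call__(self, X):
        return torch.sin(X[..., 0].max(dim=-1).values)


@fast_optimize
def generate():
    # With the mocks active, scipy stops after a single iteration, so the
    # optimizer moves toward the maximum without reaching it.
    return gen_candidates_scipy(
        initial_conditions=torch.tensor([[0.0]]),
        acquisition_function=SinAcquisitionFunction(),
    )


cand, value = generate()
assert value.item() < 1.0  # roughly 0.84 rather than the true optimum

`fast_optimize_context_manager` offers the same behavior as a context manager, for wrapping only the slow portion of a test.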

Differential Revision: D63838626
esantorella authored and facebook-github-bot committed Oct 3, 2024
1 parent a0a2c05 commit b8d497c
Showing 6 changed files with 347 additions and 1 deletion.
15 changes: 15 additions & 0 deletions botorch/test_utils/__init__.py
@@ -0,0 +1,15 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
test_utils has its own top-level directory to avoid circular dependencies:
Anything in 'tests/' can depend on anything in 'test_utils/', and anything in
'test_utils/' can depend on anything in 'botorch/'.
"""

from botorch.test_utils.mock import fast_optimize

__all__ = ["fast_optimize"]
123 changes: 123 additions & 0 deletions botorch/test_utils/mock.py
@@ -0,0 +1,123 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Utilities for speeding up optimization in tests.
"""

from collections.abc import Generator
from contextlib import contextmanager, ExitStack
from functools import wraps
from typing import Any, Callable
from unittest import mock

from botorch.optim.initializers import (
    gen_batch_initial_conditions,
    gen_one_shot_kg_initial_conditions,
)

from botorch.optim.utils.timeout import minimize_with_timeout
from scipy.optimize import OptimizeResult
from torch import Tensor


@contextmanager
def fast_optimize_context_manager(
    force: bool = False,
) -> Generator[None, None, None]:
    """A context manager that forces BoTorch to speed up optimization.

    Currently, the primary tactic is to force the underlying scipy methods
    to stop after just one iteration; initial conditions for acquisition
    optimization are also generated from fewer restarts and raw samples.

    Args:
        force: If True, do not raise an AssertionError if no mocks are
            called. USE RESPONSIBLY.
    """

    def one_iteration_minimize(*args: Any, **kwargs: Any) -> OptimizeResult:
        if kwargs["options"] is None:
            kwargs["options"] = {}

        kwargs["options"]["maxiter"] = 1
        return minimize_with_timeout(*args, **kwargs)

    def minimal_gen_ics(*args: Any, **kwargs: Any) -> Tensor:
        kwargs["num_restarts"] = 2
        kwargs["raw_samples"] = 4

        return gen_batch_initial_conditions(*args, **kwargs)

    def minimal_gen_os_ics(*args: Any, **kwargs: Any) -> Tensor | None:
        kwargs["num_restarts"] = 2
        kwargs["raw_samples"] = 4

        return gen_one_shot_kg_initial_conditions(*args, **kwargs)

    with ExitStack() as es:
        # Note: `minimize_with_timeout` is defined in `optim.utils.timeout`;
        # this mock only has an effect when calling a function used in
        # `botorch.generation.gen`, such as `gen_candidates_scipy`.
        mock_generation = es.enter_context(
            mock.patch(
                "botorch.generation.gen.minimize_with_timeout",
                wraps=one_iteration_minimize,
            )
        )

        # Similarly, this works when calling a function defined in
        # `optim.core`, such as `scipy_minimize` or `torch_minimize`.
        mock_fit = es.enter_context(
            mock.patch(
                "botorch.optim.core.minimize_with_timeout",
                wraps=one_iteration_minimize,
            )
        )

        # Works when calling a function in `optim.optimize`, such as
        # `optimize_acqf`.
        mock_gen_ics = es.enter_context(
            mock.patch(
                "botorch.optim.optimize.gen_batch_initial_conditions",
                wraps=minimal_gen_ics,
            )
        )

        # Works when calling a function in `optim.optimize`, such as
        # `optimize_acqf`.
        mock_gen_os_ics = es.enter_context(
            mock.patch(
                "botorch.optim.optimize.gen_one_shot_kg_initial_conditions",
                wraps=minimal_gen_os_ics,
            )
        )

        yield

    if (not force) and all(
        mock_.call_count < 1
        for mock_ in [
            mock_generation,
            mock_fit,
            mock_gen_ics,
            mock_gen_os_ics,
        ]
    ):
        raise AssertionError(
            "No mocks were called in the context manager. Please remove unused "
            "fast_optimize_context_manager()."
        )


def fast_optimize(f: Callable) -> Callable:
    """Wraps `f` in `fast_optimize_context_manager` for use as a decorator."""

    @wraps(f)
    # pyre-fixme[3]: Return type must be annotated.
    def inner(*args: Any, **kwargs: Any):
        with fast_optimize_context_manager():
            return f(*args, **kwargs)

    return inner
135 changes: 135 additions & 0 deletions botorch/utils/mock.py
@@ -0,0 +1,135 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Utilities for speeding up optimization in tests.
This lives here rather than in botorch/utils to avoid circular dependencies.
"""

from collections.abc import Generator
from contextlib import contextmanager, ExitStack
from functools import wraps
from typing import Any, Callable
from unittest import mock

from botorch.fit import fit_fully_bayesian_model_nuts

from botorch.generation.gen import minimize_with_timeout
from botorch.optim.initializers import (
    gen_batch_initial_conditions,
    gen_one_shot_kg_initial_conditions,
)
from scipy.optimize import OptimizeResult
from torch import Tensor


def _get_minimal_mcmc_kwargs(**kwargs: Any) -> dict[str, Any]:
    kwargs["warmup_steps"] = 0
    # Keep as many retained samples as would otherwise be expected: with the
    # defaults, 256 // 16 = 16 samples, now drawn without thinning.
    kwargs["num_samples"] = kwargs.get("num_samples", 256) // kwargs.get(
        "thinning", 16
    )
    kwargs["thinning"] = 1
    return kwargs


@contextmanager
def fast_optimize_context_manager(
    force: bool = False,
) -> Generator[None, None, None]:
    """A context manager that forces BoTorch to speed up optimization.

    Currently, the primary tactic is to force the underlying scipy methods
    to stop after just one iteration; NUTS model fitting is likewise run
    with minimal warmup, sampling, and thinning.

    Args:
        force: If True, do not raise an AssertionError if no mocks are
            called. USE RESPONSIBLY.
    """

    def one_iteration_minimize(*args: Any, **kwargs: Any) -> OptimizeResult:
        if kwargs["options"] is None:
            kwargs["options"] = {}

        kwargs["options"]["maxiter"] = 1
        return minimize_with_timeout(*args, **kwargs)

    def minimal_gen_ics(*args: Any, **kwargs: Any) -> Tensor:
        kwargs["num_restarts"] = 2
        kwargs["raw_samples"] = 4

        return gen_batch_initial_conditions(*args, **kwargs)

    def minimal_gen_os_ics(*args: Any, **kwargs: Any) -> Tensor | None:
        kwargs["num_restarts"] = 2
        kwargs["raw_samples"] = 4

        return gen_one_shot_kg_initial_conditions(*args, **kwargs)

    def minimal_fit_fully_bayesian(*args: Any, **kwargs: Any) -> None:
        fit_fully_bayesian_model_nuts(*args, **_get_minimal_mcmc_kwargs(**kwargs))

    with ExitStack() as es:
        mock_generation = es.enter_context(
            mock.patch(
                "botorch.generation.gen.minimize_with_timeout",
                wraps=one_iteration_minimize,
            )
        )

        mock_fit = es.enter_context(
            mock.patch(
                "botorch.optim.core.minimize_with_timeout",
                wraps=one_iteration_minimize,
            )
        )

        mock_gen_ics = es.enter_context(
            mock.patch(
                "botorch.optim.optimize.gen_batch_initial_conditions",
                wraps=minimal_gen_ics,
            )
        )

        mock_gen_os_ics = es.enter_context(
            mock.patch(
                "botorch.optim.optimize.gen_one_shot_kg_initial_conditions",
                wraps=minimal_gen_os_ics,
            )
        )

        mock_mcmc_mbm = es.enter_context(
            mock.patch(
                "botorch.fit.fit_fully_bayesian_model_nuts",
                wraps=minimal_fit_fully_bayesian,
            )
        )

        yield

    if (not force) and all(
        mock_.call_count < 1
        for mock_ in [
            mock_generation,
            mock_fit,
            mock_gen_ics,
            mock_gen_os_ics,
            mock_mcmc_mbm,
        ]
    ):
        raise AssertionError(
            "No mocks were called in the context manager. Please remove unused "
            "fast_optimize_context_manager()."
        )


def fast_optimize(f: Callable) -> Callable:
    """Wraps `f` in `fast_optimize_context_manager` for use as a decorator."""

    @wraps(f)
    # pyre-fixme[3]: Return type must be annotated.
    def inner(*args: Any, **kwargs: Any):
        with fast_optimize_context_manager():
            return f(*args, **kwargs)

    return inner
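
To make the MCMC shortcut concrete: assuming the fallbacks in `_get_minimal_mcmc_kwargs` reflect the usual defaults (`num_samples=256`, `thinning=16`), a call with no overrides keeps the same number of retained samples while skipping warmup and thinning. A small illustrative sketch (it imports a private helper, so it is for exposition only):

from botorch.utils.mock import _get_minimal_mcmc_kwargs

kwargs = _get_minimal_mcmc_kwargs()
# 256 // 16 = 16 retained samples, drawn without warmup or thinning.
assert kwargs == {"warmup_steps": 0, "num_samples": 16, "thinning": 1}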
12 changes: 12 additions & 0 deletions sphinx/source/test_utils.rst
@@ -0,0 +1,12 @@
.. role:: hidden
    :class: hidden-section


botorch.test_utils
========================================================
.. automodule:: botorch.test_utils

Mock
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. automodule:: botorch.test_utils.mock
    :members:
7 changes: 6 additions & 1 deletion test/acquisition/test_input_constructors.py
@@ -14,6 +14,7 @@
import math
from collections.abc import Callable
from functools import reduce
from unittest import mock
from unittest.mock import MagicMock

@@ -111,6 +112,7 @@
from botorch.sampling.normal import IIDNormalSampler, SobolQMCNormalSampler
from botorch.utils.constraints import get_outcome_constraint_transforms
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.mock import fast_optimize
from botorch.utils.multi_objective.box_decompositions.non_dominated import (
FastNondominatedPartitioning,
NondominatedPartitioning,
@@ -1845,7 +1847,9 @@ def setUp(self, suppress_input_warnings: bool = True) -> None:
            },
        )

    @fast_optimize
    def test_constructors_can_instantiate(self) -> None:
        for key, (classes, input_constructor_kwargs) in self.cases.items():
            with self.subTest(
                key, classes=classes, input_constructor_kwargs=input_constructor_kwargs
@@ -1856,6 +1860,7 @@ def test_constructors_can_instantiate(self) -> None:
                )
                # no assertions; we are just testing that this doesn't error
                cls_(**acqf_kwargs)

    def test_all_cases_covered(self) -> None:
        all_classes_tested = reduce(
56 changes: 56 additions & 0 deletions test/test_utils/test_mock.py
@@ -0,0 +1,56 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import torch
from botorch.generation.gen import gen_candidates_scipy
from botorch.optim.core import scipy_minimize
from botorch.optim.optimize import optimize_acqf

from botorch.test_utils.mock import fast_optimize
from botorch.utils.testing import BotorchTestCase, MockAcquisitionFunction


class SinAcquisitionFunction(MockAcquisitionFunction):
    """Toy acquisition function whose maximum of 1.0 is at x = pi / 2."""

    def __init__(self, *args, **kwargs):
        return

    def __call__(self, X):
        return torch.sin(X[..., 0].max(dim=-1).values)


class TestMock(BotorchTestCase):
    @fast_optimize
    def test_fast_optimize(self):
        with self.subTest("gen_candidates_scipy"):
            cand, value = gen_candidates_scipy(
                initial_conditions=torch.tensor([[0.0]]),
                acquisition_function=SinAcquisitionFunction(),
            )
            # When not using `fast_optimize`, the value is 1.0; with it, the
            # value is around 0.84, since scipy stops after one iteration.
            self.assertLess(value.item(), 0.99)

        with self.subTest("scipy_minimize"):
            x = torch.tensor([0.0])

            def closure():
                return torch.sin(x), [torch.cos(x)]

            result = scipy_minimize(closure=closure, parameters={"x": x})
            self.assertEqual(
                result.message, "STOP: TOTAL NO. of ITERATIONS REACHED LIMIT"
            )

        with self.subTest("optimize_acqf"):
            cand, value = optimize_acqf(
                acq_function=SinAcquisitionFunction(),
                bounds=torch.tensor([[-2.0], [2.0]]),
                q=1,
                num_restarts=32,
                batch_initial_conditions=torch.tensor([[0.0]]),
            )
            self.assertLess(value.item(), 0.99)
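
Relatedly, the guard at the end of `fast_optimize_context_manager` fails loudly when the speedup is applied but never exercised. A minimal sketch of that behavior (a hypothetical standalone snippet, not part of this diff):

from botorch.test_utils.mock import fast_optimize_context_manager

# No mocked optimization path runs inside the block, so exiting it raises
# an AssertionError asking for the unused context manager to be removed.
try:
    with fast_optimize_context_manager():
        pass
except AssertionError as err:
    print(err)

# Passing force=True suppresses the check. USE RESPONSIBLY.
with fast_optimize_context_manager(force=True):
    pass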
