Skip to content

Commit

Permalink
Add optimize_acqf_mixed_alternating to `mock_botorch_optimize_conte…
Browse files Browse the repository at this point in the history
…xt_manager` & reduce duplication with `mock_optimize_context_manager` (#2973)

Summary:

A previous diff added mixed optimizer to MBM. This diff adds it to optimizer mocks.

`mock_botorch_optimize_context_manager` had a good bit of overlap with BoTorch's `mock_optimize_context_manager`, which is also cleaned up in this diff. It now uses `mock_optimize_context_manager` and adds additional mocks on top of that.

Differential Revision: D65067691
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Oct 30, 2024
1 parent 049de5f commit d224b01
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 60 deletions.
2 changes: 1 addition & 1 deletion ax/utils/testing/core_stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def get_branin_experiment(
status_quo=Arm(parameters={"x1": 0.0, "x2": 0.0}) if with_status_quo else None,
)

if with_batch:
if with_batch or with_completed_batch:
for _ in range(num_batch_trial):
sobol_generator = get_sobol(search_space=exp.search_space)
sobol_run = sobol_generator.gen(n=15)
Expand Down
111 changes: 52 additions & 59 deletions ax/utils/testing/mock.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,8 @@
from unittest import mock

from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.generation.gen import minimize_with_timeout
from botorch.optim.initializers import (
gen_batch_initial_conditions,
gen_one_shot_kg_initial_conditions,
)
from scipy.optimize import OptimizeResult
from botorch.optim.optimize_mixed import optimize_acqf_mixed_alternating
from botorch.test_utils.mock import mock_optimize_context_manager
from torch import Tensor


Expand All @@ -29,80 +25,77 @@ def mock_botorch_optimize_context_manager(
Currently, the primary tactic is to force the underlying scipy methods to
stop after just one iteration.
This context manager uses BoTorch's `mock_optimize_context_manager`, and
adds some additional mocks that are not possible to cover in BoTorch due to
the need to mock the functions where they are used.
Args:
force: If True will not raise an AssertionError if no mocks are called.
USE RESPONSIBLY.
"""

def one_iteration_minimize(*args: Any, **kwargs: Any) -> OptimizeResult:
if kwargs["options"] is None:
kwargs["options"] = {}

kwargs["options"]["maxiter"] = 1
return minimize_with_timeout(*args, **kwargs)

def minimal_gen_ics(*args: Any, **kwargs: Any) -> Tensor:
kwargs["num_restarts"] = 2
kwargs["raw_samples"] = 4

return gen_batch_initial_conditions(*args, **kwargs)

def minimal_gen_os_ics(*args: Any, **kwargs: Any) -> Tensor | None:
kwargs["num_restarts"] = 2
kwargs["raw_samples"] = 4

return gen_one_shot_kg_initial_conditions(*args, **kwargs)

def minimal_fit_fully_bayesian(*args: Any, **kwargs: Any) -> None:
fit_fully_bayesian_model_nuts(*args, **_get_minimal_mcmc_kwargs(**kwargs))

with ExitStack() as es:
mock_generation = es.enter_context(
mock.patch(
"botorch.generation.gen.minimize_with_timeout",
wraps=one_iteration_minimize,
)
def minimal_mixed_optimizer(*args: Any, **kwargs: Any) -> tuple[Tensor, Tensor]:
# BoTorch's `mock_optimize_context_manager` also has some mocks for this,
# but the full set of mocks applied here cannot be covered by that.
kwargs["raw_samples"] = 2
kwargs["num_restarts"] = 1
kwargs["options"].update(
{
"maxiter_alternating": 1,
"maxiter_continuous": 1,
"maxiter_init": 1,
"maxiter_discrete": 1,
}
)
return optimize_acqf_mixed_alternating(*args, **kwargs)

mock_fit = es.enter_context(
mock.patch(
"botorch.optim.core.minimize_with_timeout",
wraps=one_iteration_minimize,
)
)

mock_gen_ics = es.enter_context(
with ExitStack() as es:
mock_mcmc_mbm = es.enter_context(
mock.patch(
"botorch.optim.optimize.gen_batch_initial_conditions",
wraps=minimal_gen_ics,
"ax.models.torch.botorch_modular.utils.fit_fully_bayesian_model_nuts",
wraps=minimal_fit_fully_bayesian,
)
)

mock_gen_os_ics = es.enter_context(
mock_mixed_optimizer = es.enter_context(
mock.patch(
"botorch.optim.optimize.gen_one_shot_kg_initial_conditions",
wraps=minimal_gen_os_ics,
"ax.models.torch.botorch_modular.acquisition."
"optimize_acqf_mixed_alternating",
wraps=minimal_mixed_optimizer,
)
)

mock_mcmc_mbm = es.enter_context(
mock.patch(
"ax.models.torch.botorch_modular.utils.fit_fully_bayesian_model_nuts",
wraps=minimal_fit_fully_bayesian,
)
)
es.enter_context(mock_optimize_context_manager())

yield

if (not force) and all(
mock_.call_count < 1
for mock_ in [
mock_generation,
mock_fit,
mock_gen_ics,
mock_gen_os_ics,
mock_mcmc_mbm,
]
# Only raise if none of the BoTorch or Ax side mocks were called.
# We do this by catching the error that could be raised by the BoTorch
# context manager, and combining it with the signals from Ax side mocks.
try:
es.close()
except AssertionError as e:
# Check if the error is due to no BoTorch mocks being called.
if "No mocks were called" in str(e):
botorch_mocks_called = False
else:
raise
else:
botorch_mocks_called = True

if (
not force
and all(
mock_.call_count < 1
for mock_ in [
mock_mcmc_mbm,
mock_mixed_optimizer,
]
)
and botorch_mocks_called is False
):
raise AssertionError(
"No mocks were called in the context manager. Please remove unused "
Expand Down
79 changes: 79 additions & 0 deletions ax/utils/testing/tests/test_mock.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from unittest.mock import patch

import torch
from ax.modelbridge.registry import Models
from ax.modelbridge.transforms.choice_encode import OrderedChoiceToIntegerRange
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_branin_experiment
from ax.utils.testing.mock import mock_botorch_optimize_context_manager
from botorch.generation.gen import gen_candidates_scipy
from botorch.optim.optimize_mixed import generate_starting_points
from botorch.utils.testing import MockAcquisitionFunction
from pyro.infer import MCMC


class TestMock(TestCase):
def test_no_mocks_called(self) -> None:
# Should raise by default if no mocks are called.
with self.assertRaisesRegex(AssertionError, "No mocks were called"):
with mock_botorch_optimize_context_manager():
pass
# Doesn't raise when force=True.
with mock_botorch_optimize_context_manager(force=True):
pass

def test_botorch_mocks(self) -> None:
# Should not raise when BoTorch mocks are called.
with mock_botorch_optimize_context_manager():
gen_candidates_scipy(
initial_conditions=torch.tensor([[0.0]]),
acquisition_function=MockAcquisitionFunction(), # pyre-ignore [6]
)

def test_fully_bayesian_mocks(self) -> None:
experiment = get_branin_experiment(with_completed_batch=True)
with patch("botorch.fit.MCMC", wraps=MCMC) as mock_mcmc:
with mock_botorch_optimize_context_manager():
Models.SAASBO(experiment=experiment, data=experiment.lookup_data())
mock_mcmc.assert_called_once()
kwargs = mock_mcmc.call_args.kwargs
self.assertEqual(kwargs["num_samples"], 16)
self.assertEqual(kwargs["warmup_steps"], 0)

def test_mixed_optimizer_mocks(self) -> None:
experiment = get_branin_experiment(
with_completed_batch=True, with_choice_parameter=True
)
with patch(
"botorch.optim.optimize_mixed.generate_starting_points",
wraps=generate_starting_points,
) as mock_gen:
with mock_botorch_optimize_context_manager():
Models.BOTORCH_MODULAR(
experiment=experiment,
data=experiment.lookup_data(),
transforms=[OrderedChoiceToIntegerRange],
).gen(n=1)
mock_gen.assert_called_once()
opt_inputs = mock_gen.call_args.kwargs["opt_inputs"]
self.assertEqual(opt_inputs.raw_samples, 2)
self.assertEqual(opt_inputs.num_restarts, 1)
self.assertEqual(
opt_inputs.options,
{
"init_batch_limit": 32,
"batch_limit": 5,
"maxiter_alternating": 1,
"maxiter_continuous": 1,
"maxiter_init": 1,
"maxiter_discrete": 1,
},
)

0 comments on commit d224b01

Please sign in to comment.