diff --git a/ax/core/search_space.py b/ax/core/search_space.py
index 7e14056be67..d31b1bda78c 100644
--- a/ax/core/search_space.py
+++ b/ax/core/search_space.py
@@ -15,6 +15,7 @@
 from functools import reduce
 from logging import Logger
 from random import choice, uniform
+from typing import Sequence
 
 import numpy as np
 import pandas as pd
@@ -714,7 +715,7 @@ def _find_applicable_parameters(root: Parameter) -> set[str]:
         ):
             raise RuntimeError(
                 error_msg_prefix
-                + f"Parameters {applicable_paramers- set(parameters.keys())} are"
+                + f"Parameters {applicable_paramers - set(parameters.keys())} are"
                 " missing."
             )
 
@@ -1074,7 +1075,7 @@ class SearchSpaceDigest:
     bounds: list[tuple[int | float, int | float]]
     ordinal_features: list[int] = field(default_factory=list)
    categorical_features: list[int] = field(default_factory=list)
-    discrete_choices: Mapping[int, list[int | float]] = field(default_factory=dict)
+    discrete_choices: Mapping[int, Sequence[int | float]] = field(default_factory=dict)
     task_features: list[int] = field(default_factory=list)
     fidelity_features: list[int] = field(default_factory=list)
     target_values: dict[int, int | float] = field(default_factory=dict)
diff --git a/ax/models/model_utils.py b/ax/models/model_utils.py
index 61ca9f57932..e651d466792 100644
--- a/ax/models/model_utils.py
+++ b/ax/models/model_utils.py
@@ -11,7 +11,7 @@
 import itertools
 import warnings
 from collections.abc import Callable, Mapping
-from typing import Protocol, Union
+from typing import Protocol, Sequence, Union
 
 import numpy as np
 import torch
@@ -20,6 +20,7 @@
 from ax.models.types import TConfig
 from ax.utils.common.typeutils import checked_cast
 from botorch.acquisition.risk_measures import RiskMeasureMCObjective
+from botorch.exceptions.warnings import OptimizationWarning
 from torch import Tensor
 
 
@@ -605,8 +606,8 @@ def filter_constraints_and_fixed_features(
 
 def mk_discrete_choices(
     ssd: SearchSpaceDigest,
-    fixed_features: dict[int, float] | None = None,
-) -> Mapping[int, list[int | float]]:
+    fixed_features: Mapping[int, float] | None = None,
+) -> Mapping[int, Sequence[float]]:
     discrete_choices = ssd.discrete_choices
     # Add in fixed features.
     if fixed_features is not None:
@@ -619,13 +620,14 @@
 
 
 def enumerate_discrete_combinations(
-    discrete_choices: Mapping[int, list[int | float]],
-) -> list[dict[int, float | int]]:
+    discrete_choices: Mapping[int, Sequence[float]],
+) -> list[dict[int, float]]:
     n_combos = np.prod([len(v) for v in discrete_choices.values()])
     if n_combos > 50:
         warnings.warn(
             f"Enumerating {n_combos} combinations of discrete parameter values "
             "while optimizing over a mixed search space. This can be very slow.",
+            OptimizationWarning,
             stacklevel=2,
         )
     fixed_features_list = [
@@ -633,3 +635,22 @@
         for c in itertools.product(*discrete_choices.values())
     ]
     return fixed_features_list
+
+
+def all_ordinal_features_are_integer_valued(
+    ssd: SearchSpaceDigest,
+) -> bool:
+    """Check if all ordinal features are integer-valued.
+
+    Args:
+        ssd: A SearchSpaceDigest.
+
+    Returns:
+        True if all ordinal features are integer-valued, False otherwise.
+    """
+    for feature_idx in ssd.ordinal_features:
+        choices = ssd.discrete_choices[feature_idx]
+        int_choices = [int(c) for c in choices]
+        if choices != int_choices:
+            return False
+    return True
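For a sense of why the new `OptimizationWarning` matters: the number of enumerated combinations is the product of the per-feature choice counts, so it grows multiplicatively with the number of discrete features. A minimal standalone sketch of what `enumerate_discrete_combinations` produces, mirroring the expectations in the tests below:

```python
import itertools

# Cross product of per-feature discrete choices, keyed by feature index.
discrete_choices = {1: [0, 1, 2], 2: [3, 4]}
combos = [
    dict(zip(discrete_choices.keys(), values))
    for values in itertools.product(*discrete_choices.values())
]
assert combos == [
    {1: 0, 2: 3}, {1: 0, 2: 4},
    {1: 1, 2: 3}, {1: 1, 2: 4},
    {1: 2, 2: 3}, {1: 2, 2: 4},
]
# Ten features with 5 choices each would already yield 5 ** 10 = 9,765,625
# combinations, far past the warning threshold of 50.
```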
+ """ + for feature_idx in ssd.ordinal_features: + choices = ssd.discrete_choices[feature_idx] + int_choices = [int(c) for c in choices] + if choices != int_choices: + return False + return True diff --git a/ax/models/tests/test_model_utils.py b/ax/models/tests/test_model_utils.py index 2a7ae4b7de8..63b15696b58 100644 --- a/ax/models/tests/test_model_utils.py +++ b/ax/models/tests/test_model_utils.py @@ -12,6 +12,7 @@ import numpy as np from ax.core.search_space import SearchSpaceDigest from ax.models.model_utils import ( + all_ordinal_features_are_integer_valued, best_observed_point, check_duplicate, enumerate_discrete_combinations, @@ -172,13 +173,9 @@ def test_MkDiscreteChoices(self) -> None: def test_EnumerateDiscreteCombinations(self) -> None: dc1 = {1: [0, 1, 2]} - # pyre-fixme[6]: For 1st param expected `Dict[int, List[Union[float, int]]]` - # but got `Dict[int, List[int]]`. dc1_enum = enumerate_discrete_combinations(dc1) self.assertEqual(dc1_enum, [{1: 0}, {1: 1}, {1: 2}]) dc2 = {1: [0, 1, 2], 2: [3, 4]} - # pyre-fixme[6]: For 1st param expected `Dict[int, List[Union[float, int]]]` - # but got `Dict[int, List[int]]`. dc2_enum = enumerate_discrete_combinations(dc2) self.assertEqual( dc2_enum, @@ -191,3 +188,31 @@ def test_EnumerateDiscreteCombinations(self) -> None: {1: 2, 2: 4}, ], ) + + def test_all_ordinal_features_are_integer_valued(self) -> None: + # No ordinal features. + ssd = SearchSpaceDigest( + feature_names=["a", "b"], + bounds=[(0, 1), (0, 2)], + categorical_features=[0], + discrete_choices={0: [0, 1]}, + ) + self.assertTrue(all_ordinal_features_are_integer_valued(ssd=ssd)) + # Non-integer ordinal features. + ssd = SearchSpaceDigest( + feature_names=["a", "b"], + bounds=[(0, 1), (0, 2)], + categorical_features=[0], + ordinal_features=[1], + discrete_choices={0: [0, 1], 1: [0.5, 1.5]}, + ) + self.assertFalse(all_ordinal_features_are_integer_valued(ssd=ssd)) + # Integer ordinal features. 
diff --git a/ax/models/torch/botorch_modular/acquisition.py b/ax/models/torch/botorch_modular/acquisition.py
index 38e8a358fe7..aa576aa326f 100644
--- a/ax/models/torch/botorch_modular/acquisition.py
+++ b/ax/models/torch/botorch_modular/acquisition.py
@@ -8,6 +8,7 @@
 
 from __future__ import annotations
 
+import math
 import operator
 from collections.abc import Callable
 from functools import partial, reduce
@@ -17,8 +18,12 @@
 import torch
 from ax.core.search_space import SearchSpaceDigest
-from ax.exceptions.core import SearchSpaceExhausted
-from ax.models.model_utils import enumerate_discrete_combinations, mk_discrete_choices
+from ax.exceptions.core import AxError, SearchSpaceExhausted
+from ax.models.model_utils import (
+    all_ordinal_features_are_integer_valued,
+    enumerate_discrete_combinations,
+    mk_discrete_choices,
+)
 from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse
 from ax.models.torch.botorch_modular.surrogate import Surrogate
 from ax.models.torch.botorch_moo_defaults import infer_objective_thresholds
@@ -45,12 +50,14 @@
     optimize_acqf_discrete_local_search,
     optimize_acqf_mixed,
 )
+from botorch.optim.optimize_acqf_mixed import optimize_acqf_mixed_alternating
 from botorch.utils.constraints import get_outcome_constraint_transforms
 from pyre_extensions import none_throws
 from torch import Tensor
 
 
-MAX_CHOICES_ENUMERATE = 100_000
+MAX_CHOICES_ENUMERATE = 100_000  # For fully discrete search spaces.
+ALTERNATING_OPTIMIZER_THRESHOLD = 10  # For mixed search spaces.
 
 logger: Logger = get_logger(__name__)
 
@@ -293,6 +300,9 @@ def optimize(
         else:
             fully_discrete = len(discrete_choices) == len(ssd.feature_names)
             if fully_discrete:
+                # If there are at most `MAX_CHOICES_ENUMERATE` choices, we will
+                # evaluate all of them and pick the best. Otherwise, we will use
+                # local search.
                 total_discrete_choices = reduce(
                     operator.mul, [float(len(c)) for c in discrete_choices.values()]
                 )
@@ -306,7 +316,24 @@ def optimize(
                 if optimizer_options is not None:
                     optimizer_options.pop("raw_samples", None)
             else:
-                optimizer = "optimize_acqf_mixed"
+                n_combos = math.prod([len(v) for v in discrete_choices.values()])
+                # If there are
+                # - any categorical features (except for those handled by transforms),
+                # - any ordinal features with non-integer choices,
+                # - or at most `ALTERNATING_OPTIMIZER_THRESHOLD` combinations
+                # of discrete choices, we will use `optimize_acqf_mixed`, which
+                # enumerates all discrete combinations and optimizes the continuous
+                # features with the discrete features held fixed. Otherwise, we will
+                # use `optimize_acqf_mixed_alternating`, which alternates between
+                # continuous and discrete optimization steps.
+                if (
+                    n_combos <= ALTERNATING_OPTIMIZER_THRESHOLD
+                    or len(ssd.categorical_features) > 0
+                    or not all_ordinal_features_are_integer_valued(ssd=ssd)
+                ):
+                    optimizer = "optimize_acqf_mixed"
+                else:
+                    optimizer = "optimize_acqf_mixed_alternating"
 
         # Prepare arguments for optimizer
         optimizer_options_with_defaults = optimizer_argparse(
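To make the dispatch rule concrete: with a single ordinal feature taking the 26 integer values 0 through 25 (the setup used in the new test below), `n_combos = 26` exceeds `ALTERNATING_OPTIMIZER_THRESHOLD = 10`, there are no categorical features, and all ordinal choices are integer-valued, so the alternating optimizer is selected. A simplified standalone sketch of the branch (`choose_mixed_optimizer` is a hypothetical name, and for brevity the integer check is applied to all discrete features here rather than only the ordinal ones):

```python
import math

ALTERNATING_OPTIMIZER_THRESHOLD = 10


def choose_mixed_optimizer(
    discrete_choices: dict[int, list[float]],
    categorical_features: list[int],
) -> str:
    # Number of discrete assignments the enumerating optimizer would visit.
    n_combos = math.prod(len(v) for v in discrete_choices.values())
    all_integer = all(
        c == int(c) for choices in discrete_choices.values() for c in choices
    )
    if (
        n_combos <= ALTERNATING_OPTIMIZER_THRESHOLD
        or len(categorical_features) > 0
        or not all_integer
    ):
        return "optimize_acqf_mixed"
    return "optimize_acqf_mixed_alternating"


assert choose_mixed_optimizer({1: list(range(26))}, []) == "optimize_acqf_mixed_alternating"
assert choose_mixed_optimizer({1: [v + 0.5 for v in range(26)]}, []) == "optimize_acqf_mixed"
assert choose_mixed_optimizer({1: [0, 1]}, []) == "optimize_acqf_mixed"  # only 2 combos
```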
@@ -384,21 +411,33 @@ def optimize(
             return candidates, acqf_values, arm_weights
 
         # 3. Handle mixed search spaces that have discrete and continuous features.
-        # Only sequential optimization is supported for `optimize_acqf_mixed`.
-        candidates, acqf_values = optimize_acqf_mixed(
-            acq_function=self.acqf,
-            bounds=bounds,
-            q=n,
-            # For now we just enumerate all possible discrete combinations. This is not
-            # scalable and and only works for a reasonably small number of choices. A
-            # slowdown warning is logged in `enumerate_discrete_combinations` if needed.
-            fixed_features_list=enumerate_discrete_combinations(
-                discrete_choices=discrete_choices
-            ),
-            inequality_constraints=inequality_constraints,
-            post_processing_func=rounding_func,
-            **optimizer_options_with_defaults,
-        )
+        if optimizer == "optimize_acqf_mixed":
+            candidates, acqf_values = optimize_acqf_mixed(
+                acq_function=self.acqf,
+                bounds=bounds,
+                q=n,
+                fixed_features_list=enumerate_discrete_combinations(
+                    discrete_choices=discrete_choices
+                ),
+                inequality_constraints=inequality_constraints,
+                post_processing_func=rounding_func,
+                **optimizer_options_with_defaults,
+            )
+        elif optimizer == "optimize_acqf_mixed_alternating":
+            candidates, acqf_values = optimize_acqf_mixed_alternating(
+                acq_function=self.acqf,
+                bounds=bounds,
+                discrete_dims=search_space_digest.ordinal_features,
+                q=n,
+                post_processing_func=rounding_func,
+                fixed_features=fixed_features,
+                inequality_constraints=inequality_constraints,
+                **optimizer_options_with_defaults,
+            )
+        else:
+            raise AxError(  # pragma: no cover
+                f"Unknown optimizer: {optimizer}. This code should be unreachable."
+            )
         return candidates, acqf_values, arm_weights
 
     def evaluate(self, X: Tensor) -> Tensor:
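Note the different call shapes above: `optimize_acqf_mixed` receives a `fixed_features_list` with one entry per enumerated discrete assignment (built by `enumerate_discrete_combinations`), while `optimize_acqf_mixed_alternating` receives only the indices of the discrete dimensions via `discrete_dims`, plus a single optional `fixed_features` dict, and explores the discrete values internally. Schematically, for one ordinal feature at index 1 with choices 0 through 25 (toy values mirroring the test below):

```python
# What optimize_acqf_mixed would receive: 26 continuous subproblems,
# one per enumerated discrete assignment.
fixed_features_list = [{1: float(v)} for v in range(26)]
assert len(fixed_features_list) == 26

# What optimize_acqf_mixed_alternating would receive: only the discrete
# dimension indices; it alternates over their values itself.
discrete_dims = [1]
```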
diff --git a/ax/models/torch/tests/test_acquisition.py b/ax/models/torch/tests/test_acquisition.py
index 0c3ce9e70d9..df13d29a91b 100644
--- a/ax/models/torch/tests/test_acquisition.py
+++ b/ax/models/torch/tests/test_acquisition.py
@@ -54,6 +54,7 @@
     optimize_acqf_discrete,
     optimize_acqf_mixed,
 )
+from botorch.optim.optimize_acqf_mixed import optimize_acqf_mixed_alternating
 from botorch.utils.constraints import get_outcome_constraint_transforms
 from botorch.utils.datasets import SupervisedDataset
 from botorch.utils.testing import MockPosterior
@@ -80,9 +81,11 @@ def __init__(self, **kwargs: Any) -> None:
     def forward(self, X: torch.Tensor) -> None:
         # take the norm and sum over the q-batch dim
         if len(X.shape) > 2:
-            return torch.linalg.norm(X, dim=-1).sum(-1)
+            res = torch.linalg.norm(X, dim=-1).sum(-1)
         else:
-            return torch.linalg.norm(X, dim=-1).squeeze(-1)
+            res = torch.linalg.norm(X, dim=-1).squeeze(-1)
+        # At least 1d is required for sequential optimize_acqf.
+        return torch.atleast_1d(res)
 
 
 class DummyOneShotAcquisitionFunction(DummyAcquisitionFunction, qKnowledgeGradient):
@@ -106,12 +109,12 @@ def setUp(self) -> None:
             acqf_cls=DummyOneShotAcquisitionFunction,
             input_constructor=self.mock_input_constructor,
         )
-        tkwargs: dict[str, Any] = {"dtype": torch.double}
+        self.tkwargs: dict[str, Any] = {"dtype": torch.double}
         self.botorch_model_class = SingleTaskGP
         self.surrogate = Surrogate(botorch_model_class=self.botorch_model_class)
-        self.X = torch.tensor([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], **tkwargs)
-        self.Y = torch.tensor([[3.0], [4.0]], **tkwargs)
-        self.Yvar = torch.tensor([[0.0], [2.0]], **tkwargs)
+        self.X = torch.tensor([[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]], **self.tkwargs)
+        self.Y = torch.tensor([[3.0], [4.0]], **self.tkwargs)
+        self.Yvar = torch.tensor([[0.0], [2.0]], **self.tkwargs)
         self.fidelity_features = [2]
         self.feature_names = ["a", "b", "c"]
         self.metric_names = ["metric"]
@@ -141,10 +144,10 @@ def setUp(self) -> None:
         self.botorch_acqf_class = DummyAcquisitionFunction
         self.objective_weights = torch.tensor([1.0])
         self.objective_thresholds = None
-        self.pending_observations = [torch.tensor([[1.0, 3.0, 4.0]], **tkwargs)]
+        self.pending_observations = [torch.tensor([[1.0, 3.0, 4.0]], **self.tkwargs)]
         self.outcome_constraints = (
-            torch.tensor([[1.0]], **tkwargs),
-            torch.tensor([[0.5]], **tkwargs),
+            torch.tensor([[1.0]], **self.tkwargs),
+            torch.tensor([[0.5]], **self.tkwargs),
         )
         self.constraints = get_outcome_constraint_transforms(
             outcome_constraints=self.outcome_constraints
@@ -155,13 +158,12 @@ def setUp(self) -> None:
         self.inequality_constraints = [
             (
                 torch.tensor([0, 1], dtype=torch.int),
-                torch.tensor([-1.0, 1.0], **tkwargs),
+                torch.tensor([-1.0, 1.0], **self.tkwargs),
                 1,
             )
         ]
         self.rounding_func = lambda x: x
         self.optimizer_options = {Keys.NUM_RESTARTS: 20, Keys.RAW_SAMPLES: 1024}
-        self.tkwargs = tkwargs
         self.torch_opt_config = TorchOptConfig(
             objective_weights=self.objective_weights,
             objective_thresholds=self.objective_thresholds,
@@ -529,7 +531,6 @@ def test_optimize_acqf_discrete_local_search(
         self,
         mock_optimize_acqf_discrete_local_search: Mock,
     ) -> None:
-        tkwargs = {"dtype": self.X.dtype, "device": self.X.device}
         ssd = SearchSpaceDigest(
             feature_names=["a", "b", "c"],
             bounds=[(0, 1) for _ in range(3)],
@@ -579,7 +580,7 @@
         self.assertEqual(kwargs["raw_samples"], self.optimizer_options["raw_samples"])
         self.assertTrue(
             all(
-                torch.allclose(torch.linspace(0, 1, 30 * (k + 1), **tkwargs), c)
+                torch.allclose(torch.linspace(0, 1, 30 * (k + 1), **self.tkwargs), c)
                 for k, c in enumerate(kwargs["discrete_choices"])
             )
         )
@@ -591,7 +592,6 @@
 
     @fast_botorch_optimize
     def test_optimize_mixed(self) -> None:
-        tkwargs = {"dtype": self.X.dtype, "device": self.X.device}
         ssd = SearchSpaceDigest(
             feature_names=["a", "b"],
             bounds=[(0, 1), (0, 2)],
@@ -621,13 +621,70 @@ def test_optimize_mixed(self) -> None:
             **self.optimizer_options,
         )
         # can't use assert_called_with on bounds due to ambiguous bool comparison
-        expected_bounds = torch.tensor(ssd.bounds, **tkwargs).transpose(0, 1)
+        expected_bounds = torch.tensor(ssd.bounds, **self.tkwargs).transpose(0, 1)
         self.assertTrue(
             torch.equal(
                 mock_optimize_acqf_mixed.call_args[1]["bounds"], expected_bounds
             )
         )
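The new test below patches the optimizer with `unittest.mock.patch(..., wraps=...)`, so the real BoTorch routine still runs while the mock records how it was called. The same pattern in miniature (`scale` is a hypothetical function):

```python
from unittest import mock


def scale(x: float, factor: float = 2.0) -> float:
    return x * factor


with mock.patch(f"{__name__}.scale", wraps=scale) as mock_scale:
    # The wrapped call still computes the real result...
    assert scale(3.0) == 6.0
# ...and the mock records the invocation for later assertions.
mock_scale.assert_called_with(3.0)
```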
+
+    @fast_botorch_optimize
+    def test_optimize_acqf_mixed_alternating(self) -> None:
+        ssd = SearchSpaceDigest(
+            feature_names=["a", "b", "c"],
+            bounds=[(0, 1), (0, 25), (0, 5)],
+            ordinal_features=[1],
+            discrete_choices={1: list(range(26))},
+        )
+        acquisition = self.get_acquisition_function()
+        with mock.patch(
+            f"{ACQUISITION_PATH}.optimize_acqf_mixed_alternating",
+            wraps=optimize_acqf_mixed_alternating,
+        ) as mock_alternating:
+            acquisition.optimize(
+                n=3,
+                search_space_digest=ssd,
+                inequality_constraints=self.inequality_constraints,
+                fixed_features={0: 0.5},
+                rounding_func=self.rounding_func,
+                optimizer_options={
+                    "options": {"maxiter_alternating": 2},
+                    "num_restarts": 2,
+                    "raw_samples": 4,
+                },
+            )
+        mock_alternating.assert_called_with(
+            acq_function=acquisition.acqf,
+            bounds=mock.ANY,
+            discrete_dims=[1],
+            q=3,
+            options={
+                "init_batch_limit": 32,
+                "batch_limit": 5,
+                "maxiter_alternating": 2,
+            },
+            inequality_constraints=self.inequality_constraints,
+            fixed_features={0: 0.5},
+            post_processing_func=self.rounding_func,
+            num_restarts=2,
+            raw_samples=4,
+        )
+        # Check that it is not used if there are non-integer or categorical
+        # discrete dimensions.
+        ssd1 = dataclasses.replace(ssd, categorical_features=[0])
+        ssd2 = dataclasses.replace(
+            ssd,
+            ordinal_features=[0, 1],
+            discrete_choices={0: [0.1, 0.6], 1: list(range(26))},
+        )
+        for ssd in [ssd1, ssd2]:
+            with mock.patch(
+                f"{ACQUISITION_PATH}.optimize_acqf_mixed_alternating",
+                wraps=optimize_acqf_mixed_alternating,
+            ) as mock_alternating:
+                acquisition.optimize(n=3, search_space_digest=ssd)
+            mock_alternating.assert_not_called()
 
     @mock.patch(
         f"{DummyOneShotAcquisitionFunction.__module__}."
         "DummyOneShotAcquisitionFunction.evaluate",