Use optimize_acqf_mixed_alternating in Acquisition.optimize (#2972)
Summary:
Pull Request resolved: #2972

This diff updates `Acquisition.optimize` to use `optimize_acqf_mixed_alternating` when
- there are no categorical features (except for those handled by transforms),
- all ordinal features are integer valued,
- and there are more than `ALTERNATING_OPTIMIZER_THRESHOLD` combinations of discrete choices.
`optimize_acqf_mixed_alternating` will be more efficient than `optimize_acqf_mixed` when there are many combinations of discrete choices. The other conditions are current limitations of the optimizer.

The current choice of `ALTERNATING_OPTIMIZER_THRESHOLD = 10` is somewhat arbitrary.
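
As a worked example of the threshold: two integer-valued parameters with 3 and 4 choices give 3 × 4 = 12 combinations, which exceeds 10, so the alternating optimizer is selected. Below is a minimal sketch of the dispatch rule, condensed from the diff (`choose_mixed_optimizer` is a hypothetical helper; the real code computes the choices via `mk_discrete_choices`, which also folds in fixed features):

    import math

    from ax.core.search_space import SearchSpaceDigest
    from ax.models.model_utils import all_ordinal_features_are_integer_valued

    ALTERNATING_OPTIMIZER_THRESHOLD = 10  # For mixed search spaces.

    def choose_mixed_optimizer(ssd: SearchSpaceDigest) -> str:
        # Hypothetical helper mirroring the branching added in Acquisition.optimize.
        # Total number of discrete combinations, e.g. 3 * 4 = 12 in the example above.
        n_combos = math.prod(len(v) for v in ssd.discrete_choices.values())
        if (
            n_combos <= ALTERNATING_OPTIMIZER_THRESHOLD  # cheap to enumerate
            or len(ssd.categorical_features) > 0  # not supported by the alternating optimizer
            or not all_ordinal_features_are_integer_valued(ssd=ssd)
        ):
            return "optimize_acqf_mixed"
        return "optimize_acqf_mixed_alternating"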

Reviewed By: susanxia1006

Differential Revision: D65066344

fbshipit-source-id: b2c24edd00755b210defe06ae9a065d617d6bdc0
saitcakmak authored and facebook-github-bot committed Oct 28, 2024
1 parent 8562ea9 commit d1e65c5
Showing 5 changed files with 188 additions and 45 deletions.
ax/core/search_space.py (3 additions, 2 deletions)
@@ -15,6 +15,7 @@
 from functools import reduce
 from logging import Logger
 from random import choice, uniform
+from typing import Sequence

 import numpy as np
 import pandas as pd
@@ -714,7 +715,7 @@ def _find_applicable_parameters(root: Parameter) -> set[str]:
         ):
             raise RuntimeError(
                 error_msg_prefix
-                + f"Parameters {applicable_paramers- set(parameters.keys())} are"
+                + f"Parameters {applicable_paramers - set(parameters.keys())} are"
                 " missing."
             )

@@ -1074,7 +1075,7 @@ class SearchSpaceDigest:
     bounds: list[tuple[int | float, int | float]]
     ordinal_features: list[int] = field(default_factory=list)
     categorical_features: list[int] = field(default_factory=list)
-    discrete_choices: Mapping[int, list[int | float]] = field(default_factory=dict)
+    discrete_choices: Mapping[int, Sequence[int | float]] = field(default_factory=dict)
     task_features: list[int] = field(default_factory=list)
     fidelity_features: list[int] = field(default_factory=list)
     target_values: dict[int, int | float] = field(default_factory=dict)
ax/models/model_utils.py (26 additions, 5 deletions)
@@ -11,7 +11,7 @@
 import itertools
 import warnings
 from collections.abc import Callable, Mapping
-from typing import Protocol, Union
+from typing import Protocol, Sequence, Union

 import numpy as np
 import torch
@@ -20,6 +20,7 @@
 from ax.models.types import TConfig
 from ax.utils.common.typeutils import checked_cast
 from botorch.acquisition.risk_measures import RiskMeasureMCObjective
+from botorch.exceptions.warnings import OptimizationWarning
 from torch import Tensor


@@ -605,8 +606,8 @@ def filter_constraints_and_fixed_features(

 def mk_discrete_choices(
     ssd: SearchSpaceDigest,
-    fixed_features: dict[int, float] | None = None,
-) -> Mapping[int, list[int | float]]:
+    fixed_features: Mapping[int, float] | None = None,
+) -> Mapping[int, Sequence[float]]:
     discrete_choices = ssd.discrete_choices
     # Add in fixed features.
     if fixed_features is not None:
@@ -619,17 +620,37 @@


 def enumerate_discrete_combinations(
-    discrete_choices: Mapping[int, list[int | float]],
-) -> list[dict[int, float | int]]:
+    discrete_choices: Mapping[int, Sequence[float]],
+) -> list[dict[int, float]]:
+    n_combos = np.prod([len(v) for v in discrete_choices.values()])
+    if n_combos > 50:
+        warnings.warn(
+            f"Enumerating {n_combos} combinations of discrete parameter values "
+            "while optimizing over a mixed search space. This can be very slow.",
+            OptimizationWarning,
+            stacklevel=2,
+        )
     fixed_features_list = [
         dict(zip(discrete_choices.keys(), c))
         for c in itertools.product(*discrete_choices.values())
     ]
     return fixed_features_list
+
+
+def all_ordinal_features_are_integer_valued(
+    ssd: SearchSpaceDigest,
+) -> bool:
+    """Check if all ordinal features are integer-valued.
+
+    Args:
+        ssd: A SearchSpaceDigest.
+
+    Returns:
+        True if all ordinal features are integer-valued, False otherwise.
+    """
+    for feature_idx in ssd.ordinal_features:
+        choices = ssd.discrete_choices[feature_idx]
+        int_choices = [int(c) for c in choices]
+        if choices != int_choices:
+            return False
+    return True
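
For a sense of scale, here is a hypothetical call to `enumerate_discrete_combinations` (a sketch, not part of this diff) showing how quickly enumeration grows in a mixed space:

    from ax.models.model_utils import enumerate_discrete_combinations

    # Three discrete features with 5, 4, and 3 choices: 5 * 4 * 3 = 60
    # combinations, enough to trigger the OptimizationWarning (> 50) added above.
    combos = enumerate_discrete_combinations(
        discrete_choices={0: [0, 1, 2, 3, 4], 1: [0, 1, 2, 3], 2: [0, 1, 2]}
    )
    print(len(combos))  # 60; `optimize_acqf_mixed` would run one continuous
    # optimization pass per combination, which is the cost the alternating
    # optimizer avoids.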
ax/models/tests/test_model_utils.py (29 additions, 4 deletions)
@@ -12,6 +12,7 @@
 import numpy as np
 from ax.core.search_space import SearchSpaceDigest
 from ax.models.model_utils import (
+    all_ordinal_features_are_integer_valued,
     best_observed_point,
     check_duplicate,
     enumerate_discrete_combinations,
@@ -172,13 +173,9 @@ def test_MkDiscreteChoices(self) -> None:

     def test_EnumerateDiscreteCombinations(self) -> None:
         dc1 = {1: [0, 1, 2]}
-        # pyre-fixme[6]: For 1st param expected `Dict[int, List[Union[float, int]]]`
-        #  but got `Dict[int, List[int]]`.
         dc1_enum = enumerate_discrete_combinations(dc1)
         self.assertEqual(dc1_enum, [{1: 0}, {1: 1}, {1: 2}])
         dc2 = {1: [0, 1, 2], 2: [3, 4]}
-        # pyre-fixme[6]: For 1st param expected `Dict[int, List[Union[float, int]]]`
-        #  but got `Dict[int, List[int]]`.
         dc2_enum = enumerate_discrete_combinations(dc2)
         self.assertEqual(
             dc2_enum,
@@ -191,3 +188,31 @@
                 {1: 2, 2: 4},
             ],
         )
+
+    def test_all_ordinal_features_are_integer_valued(self) -> None:
+        # No ordinal features.
+        ssd = SearchSpaceDigest(
+            feature_names=["a", "b"],
+            bounds=[(0, 1), (0, 2)],
+            categorical_features=[0],
+            discrete_choices={0: [0, 1]},
+        )
+        self.assertTrue(all_ordinal_features_are_integer_valued(ssd=ssd))
+        # Non-integer ordinal features.
+        ssd = SearchSpaceDigest(
+            feature_names=["a", "b"],
+            bounds=[(0, 1), (0, 2)],
+            categorical_features=[0],
+            ordinal_features=[1],
+            discrete_choices={0: [0, 1], 1: [0.5, 1.5]},
+        )
+        self.assertFalse(all_ordinal_features_are_integer_valued(ssd=ssd))
+        # Integer ordinal features.
+        ssd = SearchSpaceDigest(
+            feature_names=["a", "b"],
+            bounds=[(0, 1), (0, 2)],
+            categorical_features=[0],
+            ordinal_features=[1],
+            discrete_choices={0: [0.5, 1.0], 1: [0.0, 1.0]},
+        )
+        self.assertTrue(all_ordinal_features_are_integer_valued(ssd=ssd))
ax/models/torch/botorch_modular/acquisition.py (58 additions, 19 deletions)
@@ -8,6 +8,7 @@

 from __future__ import annotations

+import math
 import operator
 from collections.abc import Callable
 from functools import partial, reduce
@@ -17,8 +18,12 @@

 import torch
 from ax.core.search_space import SearchSpaceDigest
-from ax.exceptions.core import SearchSpaceExhausted
-from ax.models.model_utils import enumerate_discrete_combinations, mk_discrete_choices
+from ax.exceptions.core import AxError, SearchSpaceExhausted
+from ax.models.model_utils import (
+    all_ordinal_features_are_integer_valued,
+    enumerate_discrete_combinations,
+    mk_discrete_choices,
+)
 from ax.models.torch.botorch_modular.optimizer_argparse import optimizer_argparse
 from ax.models.torch.botorch_modular.surrogate import Surrogate
 from ax.models.torch.botorch_moo_defaults import infer_objective_thresholds
@@ -45,12 +50,14 @@
     optimize_acqf_discrete_local_search,
     optimize_acqf_mixed,
 )
+from botorch.optim.optimize_acqf_mixed import optimize_acqf_mixed_alternating
 from botorch.utils.constraints import get_outcome_constraint_transforms
 from pyre_extensions import none_throws
 from torch import Tensor


-MAX_CHOICES_ENUMERATE = 100_000
+MAX_CHOICES_ENUMERATE = 100_000  # For fully discrete search spaces.
+ALTERNATING_OPTIMIZER_THRESHOLD = 10  # For mixed search spaces.

 logger: Logger = get_logger(__name__)

@@ -293,6 +300,9 @@ def optimize(
         else:
             fully_discrete = len(discrete_choices) == len(ssd.feature_names)
             if fully_discrete:
+                # If there are less than `MAX_CHOICES_ENUMERATE` choices, we will
+                # evaluate all of them and pick the best. Otherwise, we will use
+                # local search.
                 total_discrete_choices = reduce(
                     operator.mul, [float(len(c)) for c in discrete_choices.values()]
                 )
@@ -306,7 +316,24 @@
                 if optimizer_options is not None:
                     optimizer_options.pop("raw_samples", None)
             else:
-                optimizer = "optimize_acqf_mixed"
+                n_combos = math.prod([len(v) for v in discrete_choices.values()])
+                # If there are
+                # - any categorical features (except for those handled by transforms),
+                # - any ordinal features with non-integer choices,
+                # - or less than `ALTERNATING_OPTIMIZER_THRESHOLD` combinations
+                # of discrete choices, we will use `optimize_acqf_mixed`, which
+                # enumerates all discrete combinations and optimizes the continuous
+                # features with discrete features being fixed. Otherwise, we will
+                # use `optimize_acqf_mixed_alternating`, which alternates between
+                # continuous and discrete optimization steps.
+                if (
+                    n_combos <= ALTERNATING_OPTIMIZER_THRESHOLD
+                    or len(ssd.categorical_features) > 0
+                    or not all_ordinal_features_are_integer_valued(ssd=ssd)
+                ):
+                    optimizer = "optimize_acqf_mixed"
+                else:
+                    optimizer = "optimize_acqf_mixed_alternating"

         # Prepare arguments for optimizer
         optimizer_options_with_defaults = optimizer_argparse(
@@ -384,21 +411,33 @@ def optimize(
             return candidates, acqf_values, arm_weights

         # 3. Handle mixed search spaces that have discrete and continuous features.
-        # Only sequential optimization is supported for `optimize_acqf_mixed`.
-        candidates, acqf_values = optimize_acqf_mixed(
-            acq_function=self.acqf,
-            bounds=bounds,
-            q=n,
-            # For now we just enumerate all possible discrete combinations. This is not
-            # scalable and and only works for a reasonably small number of choices. A
-            # slowdown warning is logged in `enumerate_discrete_combinations` if needed.
-            fixed_features_list=enumerate_discrete_combinations(
-                discrete_choices=discrete_choices
-            ),
-            inequality_constraints=inequality_constraints,
-            post_processing_func=rounding_func,
-            **optimizer_options_with_defaults,
-        )
+        if optimizer == "optimize_acqf_mixed":
+            candidates, acqf_values = optimize_acqf_mixed(
+                acq_function=self.acqf,
+                bounds=bounds,
+                q=n,
+                fixed_features_list=enumerate_discrete_combinations(
+                    discrete_choices=discrete_choices
+                ),
+                inequality_constraints=inequality_constraints,
+                post_processing_func=rounding_func,
+                **optimizer_options_with_defaults,
+            )
+        elif optimizer == "optimize_acqf_mixed_alternating":
+            candidates, acqf_values = optimize_acqf_mixed_alternating(
+                acq_function=self.acqf,
+                bounds=bounds,
+                discrete_dims=search_space_digest.ordinal_features,
+                q=n,
+                post_processing_func=rounding_func,
+                fixed_features=fixed_features,
+                inequality_constraints=inequality_constraints,
+                **optimizer_options_with_defaults,
+            )
+        else:
+            raise AxError(  # pragma: no cover
+                f"Unknown optimizer: {optimizer}. This code should be unreachable."
+            )
         return candidates, acqf_values, arm_weights

     def evaluate(self, X: Tensor) -> Tensor:
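
For reference, here is a standalone sketch of calling the alternating optimizer on a toy mixed space, mirroring the arguments used in the diff above. The model setup and data are illustrative assumptions, not part of the commit; only the `optimize_acqf_mixed_alternating` keywords shown in the diff are relied upon:

    import torch
    from botorch.acquisition import qLogNoisyExpectedImprovement
    from botorch.fit import fit_gpytorch_mll
    from botorch.models import SingleTaskGP
    from botorch.optim.optimize_acqf_mixed import optimize_acqf_mixed_alternating
    from gpytorch.mlls import ExactMarginalLogLikelihood

    # Toy problem: dims 0-1 continuous in [0, 1]; dims 2-4 integer in {0, ..., 5},
    # giving 6^3 = 216 discrete combinations, far above the threshold of 10.
    train_X = torch.rand(20, 5, dtype=torch.double)
    train_X[:, 2:] = (train_X[:, 2:] * 5).round()
    train_Y = -(train_X - 0.5).pow(2).sum(dim=-1, keepdim=True)

    gp = SingleTaskGP(train_X, train_Y)
    fit_gpytorch_mll(ExactMarginalLogLikelihood(gp.likelihood, gp))

    bounds = torch.tensor(
        [[0.0, 0.0, 0.0, 0.0, 0.0], [1.0, 1.0, 5.0, 5.0, 5.0]], dtype=torch.double
    )
    candidate, value = optimize_acqf_mixed_alternating(
        acq_function=qLogNoisyExpectedImprovement(model=gp, X_baseline=train_X),
        bounds=bounds,
        discrete_dims=[2, 3, 4],  # integer-valued ordinal features
        q=1,
    )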