Reap get_GPMES & Models.GPMES (#2317)
Summary:
Pull Request resolved: #2317

Removes the unused factory function that builds a legacy max-value entropy search (MES) model, along with its `Models.GPMES` registry entry. Any interested user can still utilize MES through the Modular BoTorch Model (MBM); for example:
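
A minimal sketch of the MBM route (not part of this change; it assumes an existing Ax `experiment` with non-empty attached `data`, and that the installed BoTorch exposes `qMaxValueEntropy`):

# MES via the Modular BoTorch Model (MBM): a sketch, not this diff's code.
# Assumes an existing Ax `experiment` with non-empty observation `data`.
from ax.modelbridge.registry import Models
from botorch.acquisition.max_value_entropy_search import qMaxValueEntropy

# BOTORCH_MODULAR wraps the modular BoTorchModel in a TorchModelBridge;
# the acquisition class is passed via `botorch_acqf_class`, replacing the
# dedicated GPMES factory/registry entry removed here.
model_bridge = Models.BOTORCH_MODULAR(
    experiment=experiment,
    data=data,
    botorch_acqf_class=qMaxValueEntropy,
)
generator_run = model_bridge.gen(n=1)  # one MES-selected candidate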

Reviewed By: Balandat

Differential Revision: D55709258

fbshipit-source-id: 8eb6ba9f739276d66ab0aa109cd715016e6db2a1
saitcakmak authored and facebook-github-bot committed Apr 4, 2024
1 parent 5a4a0cd commit f14f873
Showing 3 changed files with 1 addition and 92 deletions.
35 changes: 1 addition & 34 deletions ax/modelbridge/factory.py
@@ -7,7 +7,7 @@
 # pyre-strict
 
 from logging import Logger
-from typing import Any, Dict, List, Optional, Type
+from typing import Dict, List, Optional, Type
 
 import torch
 from ax.core.data import Data
@@ -423,39 +423,6 @@ def get_thompson(
     )
 
 
-def get_GPMES(
-    experiment: Experiment,
-    data: Data,
-    search_space: Optional[SearchSpace] = None,
-    cost_intercept: float = 0.01,
-    dtype: torch.dtype = torch.double,
-    device: torch.device = DEFAULT_TORCH_DEVICE,
-    transforms: List[Type[Transform]] = Cont_X_trans + Y_trans,
-    transform_configs: Optional[Dict[str, TConfig]] = None,
-    **kwargs: Any,
-) -> TorchModelBridge:
-    """Instantiates a GP model that generates points with MES."""
-    if search_space is None:
-        search_space = experiment.search_space
-    if data.df.empty:
-        raise ValueError("GP + MES BotorchModel requires non-empty data.")
-
-    inputs = {
-        "search_space": search_space,
-        "experiment": experiment,
-        "data": data,
-        "cost_intercept": cost_intercept,
-        "torch_dtype": dtype,
-        "torch_device": device,
-        "transforms": transforms,
-        "transform_configs": transform_configs,
-    }
-
-    if any(p.is_fidelity for k, p in experiment.parameters.items()):
-        inputs["linear_truncated"] = kwargs.get("linear_truncated", True)
-    return checked_cast(TorchModelBridge, Models.GPMES(**inputs))  # pyre-ignore: [16]
-
-
 def get_MOO_EHVI(
     experiment: Experiment,
     data: Data,
8 changes: 0 additions & 8 deletions ax/modelbridge/registry.py
@@ -61,7 +61,6 @@
 from ax.models.random.uniform import UniformGenerator
 from ax.models.torch.alebo import ALEBO
 from ax.models.torch.botorch import BotorchModel
-from ax.models.torch.botorch_mes import MaxValueEntropySearch
 from ax.models.torch.botorch_modular.model import (
     BoTorchModel as ModularBoTorchModel,
     SurrogateSpec,
@@ -177,12 +176,6 @@ class ModelSetup(NamedTuple):
         transforms=Cont_X_trans + Y_trans,
         standard_bridge_kwargs=STANDARD_TORCH_BRIDGE_KWARGS,
     ),
-    "GPMES": ModelSetup(
-        bridge_class=TorchModelBridge,
-        model_class=MaxValueEntropySearch,
-        transforms=Cont_X_trans + Y_trans,
-        standard_bridge_kwargs=STANDARD_TORCH_BRIDGE_KWARGS,
-    ),
     "EB": ModelSetup(
         bridge_class=DiscreteModelBridge,
         model_class=EmpiricalBayesThompsonSampler,
@@ -465,7 +458,6 @@ class Models(ModelRegistryBase):
 
     SOBOL = "Sobol"
     GPEI = "GPEI"
-    GPMES = "GPMES"
     FACTORIAL = "Factorial"
     SAASBO = "SAASBO"
     FULLYBAYESIAN = "FullyBayesian"
50 changes: 0 additions & 50 deletions ax/modelbridge/tests/test_factory.py
@@ -16,14 +16,12 @@
 from ax.core.observation import ObservationFeatures
 from ax.core.optimization_config import MultiObjectiveOptimizationConfig
 from ax.core.outcome_constraint import ComparisonOp, ObjectiveThreshold
-from ax.core.parameter import RangeParameter
 from ax.modelbridge.discrete import DiscreteModelBridge
 from ax.modelbridge.factory import (
     get_botorch,
     get_empirical_bayes_thompson,
     get_factorial,
     get_GPEI,
-    get_GPMES,
     get_MOO_EHVI,
     get_MOO_NEHVI,
     get_MOO_PAREGO,
@@ -39,7 +37,6 @@
 from ax.modelbridge.torch import TorchModelBridge
 from ax.models.discrete.eb_thompson import EmpiricalBayesThompsonSampler
 from ax.models.discrete.thompson import ThompsonSampler
-from ax.models.winsorization_config import WinsorizationConfig
 from ax.utils.common.testutils import TestCase
 from ax.utils.testing.core_stubs import (
     get_branin_experiment,
@@ -136,53 +133,6 @@ def test_MTGP_LEGACY(self) -> None:
         with self.assertRaises(ValueError):
             get_MTGP_LEGACY(experiment=exp, data=exp.fetch_data(), trial_index=0)
 
-    @fast_botorch_optimize
-    def test_GPMES(self) -> None:
-        """Tests GPMES instantiation."""
-        exp = get_branin_experiment(with_batch=True)
-        with self.assertRaises(ValueError):
-            get_GPMES(experiment=exp, data=exp.fetch_data())
-        exp.trials[0].run()
-        gpmes = get_GPMES(experiment=exp, data=exp.fetch_data())
-        self.assertIsInstance(gpmes, TorchModelBridge)
-
-        # Check that .gen returns without failure
-        gr = gpmes.gen(n=1)
-        self.assertEqual(len(gr.arms), 1)
-
-        # test transform_configs with winsorization
-        configs = {
-            "Winsorize": {
-                "winsorization_config": WinsorizationConfig(
-                    lower_quantile_margin=0.1,
-                    upper_quantile_margin=0.1,
-                )
-            }
-        }
-        gpmes_win = get_GPMES(
-            experiment=exp,
-            data=exp.fetch_data(),
-            # pyre-fixme[6]: For 3rd param expected `Optional[Dict[str, Dict[str,
-            #  Union[None, Dict[str, typing.Any], OptimizationConfig,
-            #  AcquisitionFunction, float, int, str]]]]` but got `Dict[str, Dict[str,
-            #  WinsorizationConfig]]`.
-            transform_configs=configs,
-        )
-        self.assertIsInstance(gpmes_win, TorchModelBridge)
-        self.assertEqual(gpmes_win._transform_configs, configs)
-
-        # test multi-fidelity optimization
-        exp.parameters["x2"] = RangeParameter(
-            name="x2",
-            parameter_type=exp.parameters["x2"].parameter_type,
-            lower=-5.0,
-            upper=10.0,
-            is_fidelity=True,
-            target_value=10.0,
-        )
-        gpmes_mf = get_GPMES(experiment=exp, data=exp.fetch_data())
-        self.assertIsInstance(gpmes_mf, TorchModelBridge)
-
     def test_model_kwargs(self) -> None:
         """Tests that model kwargs are passed correctly."""
         exp = get_branin_experiment()
