diff --git a/ax/__init__.py b/ax/__init__.py index d1f2e9e6f65..5786f1b359f 100644 --- a/ax/__init__.py +++ b/ax/__init__.py @@ -32,7 +32,7 @@ SumConstraint, Trial, ) -from ax.modelbridge import Models +from ax.modelbridge import Generators from ax.service import OptimizationLoop, optimize from ax.storage import json_load, json_save @@ -52,7 +52,7 @@ "FixedParameter", "GeneratorRun", "Metric", - "Models", + "Generators", "MultiObjective", "MultiObjectiveOptimizationConfig", "Objective", diff --git a/ax/analysis/healthcheck/constraints_feasibility.py b/ax/analysis/healthcheck/constraints_feasibility.py index 29df12e0f13..61b050ecd1f 100644 --- a/ax/analysis/healthcheck/constraints_feasibility.py +++ b/ax/analysis/healthcheck/constraints_feasibility.py @@ -23,7 +23,7 @@ from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.optimization_config import OptimizationConfig from ax.exceptions.core import UserInputError -from ax.modelbridge.base import ModelBridge +from ax.modelbridge.base import Adapter from ax.modelbridge.generation_strategy import GenerationStrategy from ax.modelbridge.transforms.derelativize import Derelativize from pyre_extensions import assert_is_instance, none_throws @@ -155,7 +155,7 @@ def compute( def constraints_feasibility( optimization_config: OptimizationConfig, - model: ModelBridge, + model: Adapter, prob_threshold: float = 0.99, ) -> Tuple[bool, pd.DataFrame]: r""" diff --git a/ax/analysis/healthcheck/regression_detection_utils.py b/ax/analysis/healthcheck/regression_detection_utils.py index 2092ee5c955..4ec4b84061b 100644 --- a/ax/analysis/healthcheck/regression_detection_utils.py +++ b/ax/analysis/healthcheck/regression_detection_utils.py @@ -15,7 +15,7 @@ from ax.core.observation import observations_from_data from ax.exceptions.core import DataRequiredError, UserInputError -from ax.modelbridge.discrete import DiscreteModelBridge +from ax.modelbridge.discrete import DiscreteAdapter from ax.modelbridge.registry import rel_EB_ashr_trans from ax.models.discrete.eb_ashr import EBAshr from pyre_extensions import assert_is_instance @@ -101,7 +101,7 @@ def compute_regression_probabilities_single_trial( target_data = Data(df=data.df[data.df["metric_name"].isin(metric_names)]) - modelbridge = DiscreteModelBridge( + modelbridge = DiscreteAdapter( experiment=experiment, search_space=experiment.search_space, data=target_data, diff --git a/ax/analysis/healthcheck/tests/test_constraints_feasibility.py b/ax/analysis/healthcheck/tests/test_constraints_feasibility.py index 6a8bb0bc2b1..a6e95609685 100644 --- a/ax/analysis/healthcheck/tests/test_constraints_feasibility.py +++ b/ax/analysis/healthcheck/tests/test_constraints_feasibility.py @@ -26,8 +26,8 @@ from ax.modelbridge.factory import get_sobol from ax.modelbridge.generation_node import GenerationNode from ax.modelbridge.generation_strategy import GenerationStrategy -from ax.modelbridge.model_spec import ModelSpec -from ax.modelbridge.registry import Models +from ax.modelbridge.model_spec import GeneratorSpec +from ax.modelbridge.registry import Generators from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( get_branin_experiment, @@ -96,8 +96,8 @@ def setUp(self) -> None: GenerationNode( node_name="gn", model_specs=[ - ModelSpec( - model_enum=Models.BOTORCH_MODULAR, + GeneratorSpec( + model_enum=Generators.BOTORCH_MODULAR, ) ], ) diff --git a/ax/analysis/plotly/arm_effects/insample_effects.py b/ax/analysis/plotly/arm_effects/insample_effects.py index 
4e1cca4b071..b556ec6a22d 100644 --- a/ax/analysis/plotly/arm_effects/insample_effects.py +++ b/ax/analysis/plotly/arm_effects/insample_effects.py @@ -22,9 +22,9 @@ from ax.core.generator_run import GeneratorRun from ax.core.outcome_constraint import OutcomeConstraint from ax.exceptions.core import DataRequiredError, UserInputError -from ax.modelbridge.base import ModelBridge +from ax.modelbridge.base import Adapter from ax.modelbridge.generation_strategy import GenerationStrategy -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.modelbridge.transforms.derelativize import Derelativize from ax.utils.common.logger import get_logger from pyre_extensions import none_throws @@ -157,7 +157,7 @@ def _plot_type_string(self) -> str: return "Modeled" if self.use_modeled_effects else "Observed" -def _get_max_observed_trial_index(model: ModelBridge) -> int | None: +def _get_max_observed_trial_index(model: Adapter) -> int | None: """Returns the max observed trial index to appease multitask models for prediction by giving fixed features. This is not necessarily accurate and should eventually come from the generation strategy. @@ -178,7 +178,7 @@ def _get_model( use_modeled_effects: bool, trial_index: int, metric_name: str, -) -> ModelBridge: +) -> Adapter: """Get a model for predictions. Args: @@ -213,14 +213,14 @@ def _get_model( if model is None or not is_predictive(model=model): logger.info("Using empirical Bayes for predictions.") - return Models.EMPIRICAL_BAYES_THOMPSON( + return Generators.EMPIRICAL_BAYES_THOMPSON( experiment=experiment, data=trial_data ) return model else: # This model just predicts observed data - return Models.THOMPSON( + return Generators.THOMPSON( data=trial_data, search_space=experiment.search_space, experiment=experiment, @@ -229,7 +229,7 @@ def _get_model( def _prepare_data( experiment: Experiment, - model: ModelBridge, + model: Adapter, outcome_constraints: list[OutcomeConstraint], metric_name: str, trial_index: int, @@ -249,7 +249,7 @@ def _prepare_data( Args: experiment: Experiment to plot - model: ModelBridge being used for prediction + model: Adapter being used for prediction outcome_constraints: Derelatives outcome constraints used for assessing feasibility metric_name: Name of metric to plot diff --git a/ax/analysis/plotly/arm_effects/predicted_effects.py b/ax/analysis/plotly/arm_effects/predicted_effects.py index 741c91cac5b..75155257dc6 100644 --- a/ax/analysis/plotly/arm_effects/predicted_effects.py +++ b/ax/analysis/plotly/arm_effects/predicted_effects.py @@ -23,7 +23,7 @@ from ax.core.experiment import Experiment from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.exceptions.core import UserInputError -from ax.modelbridge.base import ModelBridge +from ax.modelbridge.base import Adapter from ax.modelbridge.generation_strategy import GenerationStrategy from ax.modelbridge.transforms.derelativize import Derelativize from pyre_extensions import assert_is_instance, none_throws @@ -149,7 +149,7 @@ def compute( def _prepare_data( - model: ModelBridge, + model: Adapter, metric_name: str, candidate_trial: BaseTrial, outcome_constraints: list[OutcomeConstraint], @@ -167,7 +167,7 @@ def _prepare_data( candidate trial. 
Args: - model: ModelBridge being used for prediction + model: Adapter being used for prediction metric_name: Name of metric to plot candidate_trial: Trial to plot candidates for by generator run """ diff --git a/ax/analysis/plotly/arm_effects/utils.py b/ax/analysis/plotly/arm_effects/utils.py index 5923c05c5a4..61841e388e7 100644 --- a/ax/analysis/plotly/arm_effects/utils.py +++ b/ax/analysis/plotly/arm_effects/utils.py @@ -19,7 +19,7 @@ from ax.core.outcome_constraint import OutcomeConstraint from ax.core.types import TParameterization from ax.exceptions.core import UserInputError -from ax.modelbridge.base import ModelBridge +from ax.modelbridge.base import Adapter from ax.modelbridge.prediction_utils import predict_at_point from plotly import express as px, graph_objects as go from pyre_extensions import none_throws @@ -203,7 +203,7 @@ def _add_style_to_effects_by_arm_plot( ) -def _get_trial_index_for_predictions(model: ModelBridge) -> int | None: +def _get_trial_index_for_predictions(model: Adapter) -> int | None: """Returns status quo features index if defined on the model. Otherwise, returns the max observed trial index to appease multitask models for prediction by giving fixed features. The max index is not necessarily accurate and should @@ -224,7 +224,7 @@ def _get_trial_index_for_predictions(model: ModelBridge) -> int | None: def get_predictions_by_arm( - model: ModelBridge, + model: Adapter, metric_name: str, outcome_constraints: list[OutcomeConstraint], gr: GeneratorRun | None = None, diff --git a/ax/analysis/plotly/cross_validation.py b/ax/analysis/plotly/cross_validation.py index 4214a5156e0..75ba356a826 100644 --- a/ax/analysis/plotly/cross_validation.py +++ b/ax/analysis/plotly/cross_validation.py @@ -58,7 +58,7 @@ def __init__( folds: Number of subsamples to partition observations into. Use -1 for leave-one-out cross validation. untransform: Whether to untransform the model predictions before cross - validating. Models are trained on transformed data, and candidate + validating. Generators are trained on transformed data, and candidate generation is performed in the transformed space. Computing the model quality metric based on the cross-validation results in the untransformed space may not be representative of the model that diff --git a/ax/analysis/plotly/interaction.py b/ax/analysis/plotly/interaction.py index 91446720af6..164e5f91ee6 100644 --- a/ax/analysis/plotly/interaction.py +++ b/ax/analysis/plotly/interaction.py @@ -27,8 +27,8 @@ from ax.core.experiment import Experiment from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.exceptions.core import UserInputError -from ax.modelbridge.registry import Models -from ax.modelbridge.torch import TorchModelBridge +from ax.modelbridge.registry import Generators +from ax.modelbridge.torch import TorchAdapter from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.utils.common.logger import get_logger from ax.utils.sensitivity.sobol_measures import ax_parameter_sens @@ -261,9 +261,7 @@ def compute( fig=fig, ) - def _get_oak_model( - self, experiment: Experiment, metric_name: str - ) -> TorchModelBridge: + def _get_oak_model(self, experiment: Experiment, metric_name: str) -> TorchAdapter: """ Retrieves the modelbridge used for the analysis. The model uses an OAK (Orthogonal Additive Kernel) with a sparsity-inducing prior, @@ -275,7 +273,7 @@ def _get_oak_model( lengthscales being fit. 
""" data = experiment.lookup_data().filter(metric_names=[metric_name]) - model_bridge = Models.BOTORCH_MODULAR( + model_bridge = Generators.BOTORCH_MODULAR( search_space=experiment.search_space, experiment=experiment, data=data, @@ -304,12 +302,12 @@ def _get_oak_model( ), ) - return assert_is_instance(model_bridge, TorchModelBridge) + return assert_is_instance(model_bridge, TorchAdapter) def _prepare_surface_plot( experiment: Experiment, - model: TorchModelBridge, + model: TorchAdapter, feature_name: str, metric_name: str, ) -> go.Figure: diff --git a/ax/analysis/plotly/surface/contour.py b/ax/analysis/plotly/surface/contour.py index a43d2b6e861..3ce17133302 100644 --- a/ax/analysis/plotly/surface/contour.py +++ b/ax/analysis/plotly/surface/contour.py @@ -22,7 +22,7 @@ from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.observation import ObservationFeatures from ax.exceptions.core import UserInputError -from ax.modelbridge.base import ModelBridge +from ax.modelbridge.base import Adapter from ax.modelbridge.generation_strategy import GenerationStrategy from plotly import graph_objects as go from pyre_extensions import none_throws @@ -113,7 +113,7 @@ def compute( def _prepare_data( experiment: Experiment, - model: ModelBridge, + model: Adapter, x_parameter_name: str, y_parameter_name: str, metric_name: str, diff --git a/ax/analysis/plotly/surface/slice.py b/ax/analysis/plotly/surface/slice.py index 129f4275824..282c19c7c26 100644 --- a/ax/analysis/plotly/surface/slice.py +++ b/ax/analysis/plotly/surface/slice.py @@ -22,7 +22,7 @@ from ax.core.generation_strategy_interface import GenerationStrategyInterface from ax.core.observation import ObservationFeatures from ax.exceptions.core import UserInputError -from ax.modelbridge.base import ModelBridge +from ax.modelbridge.base import Adapter from ax.modelbridge.generation_strategy import GenerationStrategy from plotly import express as px, graph_objects as go from pyre_extensions import none_throws @@ -100,7 +100,7 @@ def compute( def _prepare_data( experiment: Experiment, - model: ModelBridge, + model: Adapter, parameter_name: str, metric_name: str, ) -> pd.DataFrame: diff --git a/ax/analysis/plotly/tests/test_predicted_effects.py b/ax/analysis/plotly/tests/test_predicted_effects.py index 6cf77d1b926..862274a3003 100644 --- a/ax/analysis/plotly/tests/test_predicted_effects.py +++ b/ax/analysis/plotly/tests/test_predicted_effects.py @@ -17,7 +17,7 @@ from ax.exceptions.core import UserInputError from ax.modelbridge.dispatch_utils import choose_generation_strategy from ax.modelbridge.prediction_utils import predict_at_point -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( get_branin_experiment, @@ -315,7 +315,7 @@ def test_it_works_for_non_batch_experiments(self) -> None: experiment=experiment, ) # AND GIVEN we generate all Sobol trials and one GPEI trial - sobol_key = Models.SOBOL.value + sobol_key = Generators.SOBOL.value last_model_key = sobol_key while last_model_key == sobol_key: trial = experiment.new_trial( diff --git a/ax/analysis/plotly/tests/test_scatter.py b/ax/analysis/plotly/tests/test_scatter.py index 18538b0b3b5..e1f86659aa1 100644 --- a/ax/analysis/plotly/tests/test_scatter.py +++ b/ax/analysis/plotly/tests/test_scatter.py @@ -8,7 +8,7 @@ from ax.analysis.analysis import AnalysisCardLevel from ax.analysis.plotly.scatter import _prepare_data, ScatterPlot from 
ax.exceptions.core import DataRequiredError, UserInputError -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( get_branin_experiment_with_multi_objective, @@ -85,7 +85,7 @@ def test_prepare_data(self) -> None: def test_it_only_has_observations_with_data_for_both_metrics(self) -> None: # GIVEN an experiment with multiple trials and metrics experiment = get_branin_experiment_with_multi_objective() - sobol = Models.SOBOL(search_space=experiment.search_space) + sobol = Generators.SOBOL(search_space=experiment.search_space) t0 = experiment.new_batch_trial(generator_run=sobol.gen(3)).mark_completed( unsafe=True @@ -125,7 +125,7 @@ def test_it_only_has_observations_with_data_for_both_metrics(self) -> None: def test_it_must_have_some_observations_with_data_for_both_metrics(self) -> None: # GIVEN an experiment with multiple trials and metrics experiment = get_branin_experiment_with_multi_objective() - sobol = Models.SOBOL(search_space=experiment.search_space) + sobol = Generators.SOBOL(search_space=experiment.search_space) t0 = experiment.new_batch_trial(generator_run=sobol.gen(3)).mark_completed( unsafe=True diff --git a/ax/analysis/plotly/utils.py b/ax/analysis/plotly/utils.py index dc75196b04d..2198e02d041 100644 --- a/ax/analysis/plotly/utils.py +++ b/ax/analysis/plotly/utils.py @@ -11,7 +11,7 @@ from ax.core.objective import MultiObjective, ScalarizedObjective from ax.core.outcome_constraint import ComparisonOp, OutcomeConstraint from ax.exceptions.core import UnsupportedError, UserInputError -from ax.modelbridge.base import ModelBridge +from ax.modelbridge.base import Adapter from botorch.utils.probability.utils import compute_log_prob_feas_from_bounds from numpy.typing import NDArray @@ -126,7 +126,7 @@ def format_constraint_violated_probabilities( return constraints_violated_str -def is_predictive(model: ModelBridge) -> bool: +def is_predictive(model: Adapter) -> bool: """Check if a model is predictive. Basically, we're checking if predict() is implemented. diff --git a/ax/benchmark/benchmark_test_functions/surrogate.py b/ax/benchmark/benchmark_test_functions/surrogate.py index 90832e033c2..d92e1528414 100644 --- a/ax/benchmark/benchmark_test_functions/surrogate.py +++ b/ax/benchmark/benchmark_test_functions/surrogate.py @@ -12,7 +12,7 @@ from ax.benchmark.benchmark_test_function import BenchmarkTestFunction from ax.core.observation import ObservationFeatures from ax.core.types import TParamValue -from ax.modelbridge.torch import TorchModelBridge +from ax.modelbridge.torch import TorchAdapter from ax.utils.common.base import Base from ax.utils.common.equality import equality_typechecker from pyre_extensions import none_throws @@ -28,7 +28,7 @@ class SurrogateTestFunction(BenchmarkTestFunction): name: The name of the runner. outcome_names: Names of outcomes to return in `evaluate_true`, if the surrogate produces more outcomes than are needed. - _surrogate: Either `None`, or a `TorchModelBridge` surrogate to use + _surrogate: Either `None`, or a `TorchAdapter` surrogate to use for generating observations. If `None`, `get_surrogate` must not be None and will be used to generate the surrogate when it is needed. 
@@ -39,8 +39,8 @@ class SurrogateTestFunction(BenchmarkTestFunction): name: str outcome_names: Sequence[str] - _surrogate: TorchModelBridge | None = None - get_surrogate: None | Callable[[], TorchModelBridge] = None + _surrogate: TorchAdapter | None = None + get_surrogate: None | Callable[[], TorchAdapter] = None def __post_init__(self) -> None: if self.get_surrogate is None and self._surrogate is None: @@ -50,7 +50,7 @@ def __post_init__(self) -> None: ) @property - def surrogate(self) -> TorchModelBridge: + def surrogate(self) -> TorchAdapter: if self._surrogate is None: self._surrogate = none_throws(self.get_surrogate)() return none_throws(self._surrogate) diff --git a/ax/benchmark/methods/modular_botorch.py b/ax/benchmark/methods/modular_botorch.py index d1f27ef17ad..693927fcee5 100644 --- a/ax/benchmark/methods/modular_botorch.py +++ b/ax/benchmark/methods/modular_botorch.py @@ -10,7 +10,7 @@ from ax.benchmark.benchmark_method import BenchmarkMethod from ax.modelbridge.generation_node import GenerationStep from ax.modelbridge.generation_strategy import GenerationStrategy -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.models.torch.botorch_modular.surrogate import SurrogateSpec from botorch.acquisition.acquisition import AcquisitionFunction from botorch.acquisition.analytic import LogExpectedImprovement @@ -90,12 +90,12 @@ def get_sobol_mbm_generation_strategy( name=name, steps=[ GenerationStep( - model=Models.SOBOL, + model=Generators.SOBOL, num_trials=num_sobol_trials, min_trials_observed=num_sobol_trials, ), GenerationStep( - model=Models.BOTORCH_MODULAR, + model=Generators.BOTORCH_MODULAR, num_trials=-1, model_kwargs=model_kwargs, model_gen_kwargs=model_gen_kwargs or {}, diff --git a/ax/benchmark/methods/sobol.py b/ax/benchmark/methods/sobol.py index 5c62ec6fc3c..666cccceecd 100644 --- a/ax/benchmark/methods/sobol.py +++ b/ax/benchmark/methods/sobol.py @@ -8,14 +8,14 @@ from ax.benchmark.benchmark_method import BenchmarkMethod from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators def get_sobol_generation_strategy() -> GenerationStrategy: return GenerationStrategy( name="Sobol", steps=[ - GenerationStep(model=Models.SOBOL, num_trials=-1), + GenerationStep(model=Generators.SOBOL, num_trials=-1), ], ) diff --git a/ax/benchmark/tests/benchmark_test_functions/test_surrogate_test_function.py b/ax/benchmark/tests/benchmark_test_functions/test_surrogate_test_function.py index 54e9ae96a8c..9dde2b593c2 100644 --- a/ax/benchmark/tests/benchmark_test_functions/test_surrogate_test_function.py +++ b/ax/benchmark/tests/benchmark_test_functions/test_surrogate_test_function.py @@ -9,7 +9,7 @@ import torch from ax.benchmark.benchmark_test_functions.surrogate import SurrogateTestFunction -from ax.modelbridge.torch import TorchModelBridge +from ax.modelbridge.torch import TorchAdapter from ax.utils.common.testutils import TestCase from ax.utils.testing.benchmark_stubs import get_soo_surrogate_test_function @@ -38,8 +38,8 @@ def test_lazy_instantiation(self) -> None: self.assertIsNone(test_function._surrogate) # Accessing `surrogate` sets datasets and surrogate - self.assertIsInstance(test_function.surrogate, TorchModelBridge) - self.assertIsInstance(test_function._surrogate, TorchModelBridge) + self.assertIsInstance(test_function.surrogate, TorchAdapter) + self.assertIsInstance(test_function._surrogate, TorchAdapter) with 
patch.object( test_function, diff --git a/ax/benchmark/tests/methods/test_methods.py b/ax/benchmark/tests/methods/test_methods.py index 877f9370521..971fd54fd19 100644 --- a/ax/benchmark/tests/methods/test_methods.py +++ b/ax/benchmark/tests/methods/test_methods.py @@ -18,7 +18,7 @@ from ax.benchmark.methods.sobol import get_sobol_benchmark_method from ax.benchmark.problems.registry import get_problem from ax.core.experiment import Experiment -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.service.scheduler import Scheduler from ax.service.utils.best_point import ( get_best_by_raw_objective_with_trial_index, @@ -50,7 +50,7 @@ def _test_mbm_acquisition(self, batch_size: int) -> None: self.assertEqual(method.name, expected_name) gs = method.generation_strategy sobol, kg = gs._steps - self.assertEqual(kg.model, Models.BOTORCH_MODULAR) + self.assertEqual(kg.model, Generators.BOTORCH_MODULAR) model_kwargs = none_throws(kg.model_kwargs) self.assertEqual(model_kwargs["botorch_acqf_class"], qKnowledgeGradient) surrogate_spec = model_kwargs["surrogate_spec"] @@ -114,7 +114,7 @@ def test_sobol(self) -> None: self.assertEqual(method.name, "Sobol") gs = method.generation_strategy self.assertEqual(len(gs._steps), 1) - self.assertEqual(gs._steps[0].model, Models.SOBOL) + self.assertEqual(gs._steps[0].model, Generators.SOBOL) problem = get_problem(problem_key="ackley4", num_trials=3) result = benchmark_replication(problem=problem, method=method, seed=0) self.assertTrue(np.isfinite(result.score_trace).all()) diff --git a/ax/benchmark/tests/test_benchmark.py b/ax/benchmark/tests/test_benchmark.py index 5fcfd770ca5..4f744976ea5 100644 --- a/ax/benchmark/tests/test_benchmark.py +++ b/ax/benchmark/tests/test_benchmark.py @@ -49,8 +49,8 @@ from ax.early_stopping.strategies.threshold import ThresholdEarlyStoppingStrategy from ax.modelbridge.external_generation_node import ExternalGenerationNode from ax.modelbridge.generation_strategy import GenerationNode, GenerationStrategy -from ax.modelbridge.model_spec import ModelSpec -from ax.modelbridge.registry import Models +from ax.modelbridge.model_spec import GeneratorSpec +from ax.modelbridge.registry import Generators from ax.service.utils.scheduler_options import TrialType from ax.storage.json_store.load import load_experiment from ax.storage.json_store.save import save_experiment @@ -769,7 +769,9 @@ def test_replication_with_generation_node(self) -> None: GenerationNode( node_name="Sobol", model_specs=[ - ModelSpec(Models.SOBOL, model_kwargs={"deduplicate": True}) + GeneratorSpec( + Generators.SOBOL, model_kwargs={"deduplicate": True} + ) ], ) ] diff --git a/ax/core/generation_strategy_interface.py b/ax/core/generation_strategy_interface.py index 7c4d14a8bb7..2486cd4756e 100644 --- a/ax/core/generation_strategy_interface.py +++ b/ax/core/generation_strategy_interface.py @@ -129,7 +129,7 @@ def _gen_multiple( resuggesting points that are currently being evaluated. model_gen_kwargs: Keyword arguments that are passed through to ``GenerationNode.gen``, which will pass them through to - ``ModelSpec.gen``, which will pass them to ``ModelBridge.gen``. + ``GeneratorSpec.gen``, which will pass them to ``Adapter.gen``. """ ... 
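The renames in this patch are mechanical (`Models` -> `Generators`, `ModelBridge` -> `Adapter`, `ModelSpec` -> `GeneratorSpec`) and leave call signatures unchanged. A minimal sketch of the post-rename usage, assuming the Branin stub experiment from `ax.utils.testing.core_stubs` and mirroring the call patterns in the updated tests in this diff:

    from ax.modelbridge.generation_node import GenerationNode
    from ax.modelbridge.model_spec import GeneratorSpec
    from ax.modelbridge.registry import Generators
    from ax.utils.testing.core_stubs import get_branin_experiment

    experiment = get_branin_experiment()

    # Renamed registry entry point (formerly `Models`), used exactly as before.
    sobol = Generators.SOBOL(search_space=experiment.search_space)
    generator_run = sobol.gen(n=3)

    # Renamed spec class (formerly `ModelSpec`) wrapped in a GenerationNode.
    node = GenerationNode(
        node_name="Sobol",
        model_specs=[
            GeneratorSpec(Generators.SOBOL, model_kwargs={"deduplicate": True})
        ],
    )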
diff --git a/ax/core/observation.py b/ax/core/observation.py index 635df7d62b4..45e6c186ea1 100644 --- a/ax/core/observation.py +++ b/ax/core/observation.py @@ -424,7 +424,7 @@ def get_feature_cols(data: Data, is_map_data: bool = False) -> list[str]: """ feature_cols = OBS_COLS.intersection(data.df.columns) # note we use this check, rather than isinstance, since - # only some Modelbridges (e.g. MapTorchModelBridge) + # only some Adapters (e.g. MapTorchAdapter) # use observations_from_map_data, which is required # to properly handle MapData features (e.g. fidelity). if is_map_data: @@ -441,7 +441,7 @@ def get_feature_cols(data: Data, is_map_data: bool = False) -> list[str]: feature_cols.discard(column) # NOTE: This ensures the order of feature_cols is deterministic so that the order # of lists of observations are deterministic, to avoid nondeterministic tests. - # Necessary for test_TorchModelBridge. + # Necessary for test_TorchAdapter. return sorted(feature_cols) diff --git a/ax/core/tests/test_experiment.py b/ax/core/tests/test_experiment.py index 1e23c79ccaa..3fb517b06aa 100644 --- a/ax/core/tests/test_experiment.py +++ b/ax/core/tests/test_experiment.py @@ -36,7 +36,7 @@ from ax.core.types import ComparisonOp from ax.exceptions.core import AxError, RunnerNotFoundError, UnsupportedError from ax.metrics.branin import BraninMetric -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.runners.synthetic import SyntheticRunner from ax.service.ax_client import AxClient from ax.service.utils.instantiation import ObjectiveProperties @@ -1537,7 +1537,7 @@ def test_WarmStartMapData(self) -> None: trial._properties["source"], "Warm start.*Experiment.*trial" ) self.assertEqual( - trial._properties["generation_model_key"], Models.SOBOL.value + trial._properties["generation_model_key"], Generators.SOBOL.value ) self.assertDictEqual(trial.run_metadata, DUMMY_RUN_METADATA) i_old_trial += 1 @@ -1565,15 +1565,17 @@ def test_batch_with_multiple_generator_runs(self) -> None: # set seed to avoid transient errors caused by duplicate arms, # which leads to fewer arms in the trial than expected. 
seed = 0 - sobol = Models.SOBOL(experiment=exp, search_space=exp.search_space, seed=seed) + sobol = Generators.SOBOL( + experiment=exp, search_space=exp.search_space, seed=seed + ) exp.new_batch_trial(generator_runs=[sobol.gen(n=7)]).run().complete() data = exp.fetch_data() set_rng_seed(seed) - gp = Models.BOTORCH_MODULAR( + gp = Generators.BOTORCH_MODULAR( experiment=exp, search_space=exp.search_space, data=data ) - ts = Models.EMPIRICAL_BAYES_THOMPSON( + ts = Generators.EMPIRICAL_BAYES_THOMPSON( experiment=exp, search_space=exp.search_space, data=data ) exp.new_batch_trial(generator_runs=[gp.gen(n=3), ts.gen(n=1)]).run().complete() @@ -1586,7 +1588,7 @@ def test_batch_with_multiple_generator_runs(self) -> None: def test_it_does_not_take_both_single_and_multiple_gr_ars(self) -> None: exp = get_branin_experiment() - sobol = Models.SOBOL(experiment=exp, search_space=exp.search_space) + sobol = Generators.SOBOL(experiment=exp, search_space=exp.search_space) gr1 = sobol.gen(n=7) gr2 = sobol.gen(n=7) with self.assertRaisesRegex( diff --git a/ax/early_stopping/strategies/base.py b/ax/early_stopping/strategies/base.py index 628a971b710..cc0222121f1 100644 --- a/ax/early_stopping/strategies/base.py +++ b/ax/early_stopping/strategies/base.py @@ -22,7 +22,7 @@ from ax.core.objective import MultiObjective from ax.early_stopping.utils import estimate_early_stopping_savings -from ax.modelbridge.map_torch import MapTorchModelBridge +from ax.modelbridge.map_torch import MapTorchAdapter from ax.modelbridge.modelbridge_utils import ( _unpack_observations, observation_data_to_array, @@ -32,7 +32,7 @@ from ax.modelbridge.transforms.base import Transform from ax.modelbridge.transforms.map_unit_x import MapUnitX -from ax.models.torch_base import TorchModel +from ax.models.torch_base import TorchGenerator from ax.utils.common.base import Base from ax.utils.common.logger import get_logger from pyre_extensions import assert_is_instance, none_throws @@ -568,9 +568,9 @@ def get_transform_helper_model( experiment: Experiment, data: Data, transforms: list[type[Transform]] | None = None, -) -> MapTorchModelBridge: +) -> MapTorchAdapter: """ - Constructs a TorchModelBridge, to be used as a helper for transforming parameters. + Constructs a TorchAdapter, to be used as a helper for transforming parameters. We perform the default `Cont_X_trans` for parameters but do not perform any transforms on the observations. @@ -582,11 +582,11 @@ def get_transform_helper_model( """ if transforms is None: transforms = Cont_X_trans + [MapUnitX] + Y_trans - return MapTorchModelBridge( + return MapTorchAdapter( experiment=experiment, search_space=experiment.search_space, data=data, - model=TorchModel(), + model=TorchGenerator(), transforms=transforms, fit_out_of_design=True, ) diff --git a/ax/exceptions/model.py b/ax/exceptions/model.py index 77a97944449..78194265743 100644 --- a/ax/exceptions/model.py +++ b/ax/exceptions/model.py @@ -23,10 +23,10 @@ class CVNotSupportedError(AxError): pass -class ModelBridgeMethodNotImplementedError(AxError, NotImplementedError): - """Raised when a ``ModelBridge`` method is not implemented by subclasses. +class AdapterMethodNotImplementedError(AxError, NotImplementedError): + """Raised when an ``Adapter`` method is not implemented by subclasses. - NOTE: ``ModelBridge`` may catch and silently discard this error. + NOTE: ``Adapter`` may catch and silently discard this error.
""" pass diff --git a/ax/modelbridge/__init__.py b/ax/modelbridge/__init__.py index e3a9371ba03..ae77e6cbc16 100644 --- a/ax/modelbridge/__init__.py +++ b/ax/modelbridge/__init__.py @@ -8,22 +8,22 @@ # flake8: noqa F401 from ax.modelbridge import transforms -from ax.modelbridge.base import ModelBridge +from ax.modelbridge.base import Adapter from ax.modelbridge.factory import ( + Generators, get_factorial, get_sobol, get_thompson, get_uniform, - Models, ) -from ax.modelbridge.map_torch import MapTorchModelBridge -from ax.modelbridge.torch import TorchModelBridge +from ax.modelbridge.map_torch import MapTorchAdapter +from ax.modelbridge.torch import TorchAdapter __all__ = [ - "MapTorchModelBridge", - "ModelBridge", - "Models", - "TorchModelBridge", + "MapTorchAdapter", + "Adapter", + "Generators", + "TorchAdapter", "get_factorial", "get_sobol", "get_thompson", diff --git a/ax/modelbridge/base.py b/ax/modelbridge/base.py index c4b477e5483..7c6179b73ad 100644 --- a/ax/modelbridge/base.py +++ b/ax/modelbridge/base.py @@ -42,7 +42,7 @@ TParameterization, ) from ax.exceptions.core import UnsupportedError, UserInputError -from ax.exceptions.model import ModelBridgeMethodNotImplementedError +from ax.exceptions.model import AdapterMethodNotImplementedError from ax.modelbridge.transforms.base import Transform from ax.modelbridge.transforms.cast import Cast from ax.modelbridge.transforms.fill_missing_parameters import FillMissingParameters @@ -70,10 +70,10 @@ class GenResults: gen_metadata: dict[str, Any] = field(default_factory=dict) -class ModelBridge(ABC): # noqa: B024 -- ModelBridge doesn't have any abstract methods. +class Adapter(ABC): # noqa: B024 -- Adapter doesn't have any abstract methods. """The main object for using models in Ax. - ModelBridge specifies 3 methods for using models: + Adapter specifies 3 methods for using models: - predict: Make model predictions. This method is not optimized for speed and so should be used primarily for plotting or similar tasks @@ -81,7 +81,7 @@ class ModelBridge(ABC): # noqa: B024 -- ModelBridge doesn't have any abstract m - gen: Use the model to generate new candidates. - cross_validate: Do cross validation to assess model predictions. - ModelBridge converts Ax types like Data and Arm to types that are + Adapter converts Ax types like Data and Arm to types that are meant to be consumed by the models. The data sent to the model will depend on the implementation of the subclass, which will specify the actual API for external model. @@ -262,7 +262,7 @@ def _fit_if_implemented( increment = time.monotonic() - t_fit_start + time_so_far self.fit_time += increment self.fit_time_since_gen += increment - except ModelBridgeMethodNotImplementedError: + except AdapterMethodNotImplementedError: pass def _process_and_transform_data( @@ -584,7 +584,7 @@ def _fit( observations: list[Observation], ) -> None: """Apply terminal transform and fit model.""" - raise ModelBridgeMethodNotImplementedError( + raise AdapterMethodNotImplementedError( f"{self.__class__.__name__} does not implement `_fit`." ) @@ -714,7 +714,7 @@ def _predict( """Apply terminal transform, predict, and reverse terminal transform on output. """ - raise ModelBridgeMethodNotImplementedError( + raise AdapterMethodNotImplementedError( f"{self.__class__.__name__} does not implement `_predict`." ) @@ -731,7 +731,7 @@ def update(self, new_data: Data, experiment: Experiment) -> None: `update`. experiment: Experiment, in which this data was obtained. 
""" - raise DeprecationWarning("ModelBridge.update is deprecated. Use `fit` instead.") + raise DeprecationWarning("Adapter.update is deprecated. Use `fit` instead.") def _get_transformed_gen_args( self, @@ -758,7 +758,7 @@ def _get_transformed_gen_args( raise UnsupportedError( "When fit_tracking_metrics is False, the optimization config " "can only include metrics that were included in the " - "optimization config used while initializing the ModelBridge. " + "optimization config used while initializing the Adapter. " f"Metrics {outcomes} is not a subset of {self.outcomes}." ) optimization_config = optimization_config.clone() @@ -800,7 +800,7 @@ def _validate_gen_inputs( fixed_features: ObservationFeatures | None = None, model_gen_options: TConfig | None = None, ) -> None: - """Validate inputs to `ModelBridge.gen`. + """Validate inputs to `Adapter.gen`. Currently, this is only used to ensure that `n` is a positive integer. """ @@ -956,7 +956,7 @@ def _gen( """Apply terminal transform, gen, and reverse terminal transform on output. """ - raise ModelBridgeMethodNotImplementedError( + raise AdapterMethodNotImplementedError( f"{self.__class__.__name__} does not implement `_gen`." ) @@ -1019,7 +1019,7 @@ def _cross_validate( """Apply the terminal transform, make predictions on the test points, and reverse terminal transform on the results. """ - raise ModelBridgeMethodNotImplementedError( + raise AdapterMethodNotImplementedError( f"{self.__class__.__name__} does not implement `_cross_validate`." ) @@ -1096,7 +1096,7 @@ def feature_importances(self, metric_name: str) -> dict[str, float]: importances. """ - raise ModelBridgeMethodNotImplementedError( + raise AdapterMethodNotImplementedError( f"{self.__class__.__name__} does not implement `feature_importances`." ) @@ -1110,7 +1110,7 @@ def transform_observations(self, observations: list[Observation]) -> Any: Returns: Transformed values. This could be e.g. a torch Tensor, depending - on the ModelBridge subclass. + on the Adapter subclass. """ observations = deepcopy(observations) for t in self.transforms.values(): @@ -1121,7 +1121,7 @@ def transform_observations(self, observations: list[Observation]) -> Any: # pyre-fixme[3]: Return annotation cannot be `Any`. def _transform_observations(self, observations: list[Observation]) -> Any: """Apply terminal transform to given observations and return result.""" - raise ModelBridgeMethodNotImplementedError( + raise AdapterMethodNotImplementedError( f"{self.__class__.__name__} does not implement `_transform_observations`." ) @@ -1137,7 +1137,7 @@ def transform_observation_features( Returns: Transformed values. This could be e.g. a torch Tensor, depending - on the ModelBridge subclass. + on the Adapter subclass. """ obsf = deepcopy(observation_features) for t in self.transforms.values(): @@ -1150,7 +1150,7 @@ def _transform_observation_features( self, observation_features: list[ObservationFeatures] ) -> Any: """Apply terminal transform to given observation features and return result.""" - raise ModelBridgeMethodNotImplementedError( + raise AdapterMethodNotImplementedError( f"{self.__class__.__name__} does not implement " "`_transform_observation_features`." 
) diff --git a/ax/modelbridge/best_model_selector.py b/ax/modelbridge/best_model_selector.py index 7100a241e79..8b3c5ab4876 100644 --- a/ax/modelbridge/best_model_selector.py +++ b/ax/modelbridge/best_model_selector.py @@ -17,7 +17,7 @@ import numpy as np import numpy.typing as npt from ax.exceptions.core import UserInputError -from ax.modelbridge.model_spec import ModelSpec +from ax.modelbridge.model_spec import GeneratorSpec from ax.utils.common.base import Base from pyre_extensions import none_throws @@ -27,12 +27,12 @@ class BestModelSelector(ABC, Base): @abstractmethod - def best_model(self, model_specs: list[ModelSpec]) -> ModelSpec: - """Return the best ``ModelSpec`` based on some criteria. + def best_model(self, model_specs: list[GeneratorSpec]) -> GeneratorSpec: + """Return the best ``GeneratorSpec`` based on some criteria. - NOTE: The returned ``ModelSpec`` may be a different object than + NOTE: The returned ``GeneratorSpec`` may be a different object than what was provided in the original list. It may be possible to - clone and modify the original ``ModelSpec`` to produce one that + clone and modify the original ``GeneratorSpec`` to produce one that performs better. """ @@ -59,7 +59,7 @@ def __call__(self, array_like: ARRAYLIKE) -> npt.NDArray: class SingleDiagnosticBestModelSelector(BestModelSelector): """Choose the best model using a single cross-validation diagnostic. - The input is a list of ``ModelSpec``, each corresponding to one model. + The input is a list of ``GeneratorSpec``, each corresponding to one model. The specified diagnostic is extracted from each of the models, its values (each of which corresponds to a separate metric) are aggregated with the aggregation function, the best one is determined @@ -109,14 +109,14 @@ def __init__( self.criterion = criterion self.model_cv_kwargs = model_cv_kwargs - def best_model(self, model_specs: list[ModelSpec]) -> ModelSpec: - """Return the best ``ModelSpec`` based on the specified diagnostic. + def best_model(self, model_specs: list[GeneratorSpec]) -> GeneratorSpec: + """Return the best ``GeneratorSpec`` based on the specified diagnostic. Args: - model_specs: List of ``ModelSpec`` to choose from. + model_specs: List of ``GeneratorSpec`` to choose from. Returns: - The best ``ModelSpec`` based on the specified diagnostic. + The best ``GeneratorSpec`` based on the specified diagnostic. """ for model_spec in model_specs: model_spec.cross_validate(model_cv_kwargs=self.model_cv_kwargs) diff --git a/ax/modelbridge/cross_validation.py b/ax/modelbridge/cross_validation.py index 318b5f7e3ca..4ffaffcadbf 100644 --- a/ax/modelbridge/cross_validation.py +++ b/ax/modelbridge/cross_validation.py @@ -20,7 +20,7 @@ import numpy.typing as npt from ax.core.observation import Observation, ObservationData, recombine_observations from ax.core.optimization_config import OptimizationConfig -from ax.modelbridge.base import ModelBridge, unwrap_observation_data +from ax.modelbridge.base import Adapter, unwrap_observation_data from ax.utils.common.logger import get_logger from ax.utils.stats.model_fit_stats import ( coefficient_of_determination, @@ -53,7 +53,7 @@ class AssessModelFitResult(NamedTuple): def cross_validate( - model: ModelBridge, + model: Adapter, folds: int = -1, test_selector: Callable | None = None, untransform: bool = True, @@ -76,13 +76,13 @@ def cross_validate( If not provided, all observations will be available for the test set. Args: - model: Fitted model (ModelBridge) to cross validate. 
+ model: Fitted model (Adapter) to cross validate. folds: Number of folds. Use -1 for leave-one-out, otherwise will be k-fold. test_selector: Function for selecting observations for the test set. untransform: Whether to untransform the model predictions before cross validating. - Models are trained on transformed data, and candidate generation + Generators are trained on transformed data, and candidate generation is performed in the transformed space. Computing the model quality metric based on the cross-validation results in the untransformed space may not be representative of the model that @@ -187,7 +187,7 @@ def cross_validate( def cross_validate_by_trial( - model: ModelBridge, trial: int = -1, use_posterior_predictive: bool = False + model: Adapter, trial: int = -1, use_posterior_predictive: bool = False ) -> list[CVResult]: """Cross validation for model predictions on a particular trial. @@ -195,7 +195,7 @@ def cross_validate_by_trial( arms that was launched in that trial. Defaults to the last trial. Args: - model: Fitted model (ModelBridge) to cross validate. + model: Fitted model (Adapter) to cross validate. trial: Trial for which predictions are evaluated. use_posterior_predictive: A boolean indicating if the predictions should be from the posterior predictive (i.e. including @@ -407,10 +407,10 @@ def _gen_train_test_split( def get_fit_and_std_quality_and_generalization_dict( - fitted_model_bridge: ModelBridge, + fitted_model_bridge: Adapter, ) -> dict[str, float | None]: """ - Get stats and gen from a fitted ModelBridge for analytics purposes. + Get stats and gen from a fitted Adapter for analytics purposes. """ try: model_fit_dict = compute_model_fit_metrics_from_modelbridge( @@ -446,17 +446,17 @@ def get_fit_and_std_quality_and_generalization_dict( def compute_model_fit_metrics_from_modelbridge( - model_bridge: ModelBridge, + model_bridge: Adapter, fit_metrics_dict: dict[str, ModelFitMetricProtocol] | None = None, generalization: bool = False, untransform: bool = False, ) -> dict[str, dict[str, float]]: - """Computes the model fit metrics given a ModelBridge and an Experiment. + """Computes the model fit metrics given an Adapter and an Experiment. Args: - model_bridge: The ModelBridge for which to compute the model fit metrics. + model_bridge: The Adapter for which to compute the model fit metrics. experiment: The experiment with whose data to compute the metrics if - generalization == False. Otherwise, the data is taken from the ModelBridge. + generalization == False. Otherwise, the data is taken from the Adapter. fit_metrics_dict: An optional dictionary with model fit metric functions, i.e. a ModelFitMetricProtocol, as values and their names as keys. generalization: Boolean indicating whether to compute the generalization @@ -536,21 +536,21 @@ def _model_std_quality(std: npt.NDArray) -> float: def _predict_on_training_data( - model_bridge: ModelBridge, + model_bridge: Adapter, untransform: bool = False, ) -> tuple[ dict[str, npt.NDArray], dict[str, npt.NDArray], dict[str, npt.NDArray], ]: - """Makes predictions on the training data of a given experiment using a ModelBridge + """Makes predictions on the training data of a given experiment using an Adapter and returning the observed values, and the corresponding predictive means and predictive standard deviations of the model, in transformed space. NOTE: This is a helper function for `compute_model_fit_metrics_from_modelbridge`. Args: - model_bridge: A ModelBridge object with which to make predictions.
+ model_bridge: An Adapter object with which to make predictions. untransform: Boolean indicating whether to untransform model predictions. Returns: @@ -600,7 +600,7 @@ def _predict_on_training_data( def _predict_on_cross_validation_data( - model_bridge: ModelBridge, + model_bridge: Adapter, untransform: bool = False, ) -> tuple[ dict[str, npt.NDArray], @@ -608,14 +608,14 @@ def _predict_on_cross_validation_data( dict[str, npt.NDArray], ]: """Makes leave-one-out cross-validation predictions on the training data of the - ModelBridge and returns the observed values, and the corresponding predictive means + Adapter and returns the observed values, and the corresponding predictive means and predictive standard deviations of the model as numpy arrays, in transformed space. NOTE: This is a helper function for `compute_model_fit_metrics_from_modelbridge`. Args: - model_bridge: A ModelBridge object with which to make predictions. + model_bridge: An Adapter object with which to make predictions. untransform: Boolean indicating whether to untransform model predictions before cross validating. False by default. diff --git a/ax/modelbridge/discrete.py b/ax/modelbridge/discrete.py index 761cd54da59..5573323974c 100644 --- a/ax/modelbridge/discrete.py +++ b/ax/modelbridge/discrete.py @@ -18,7 +18,7 @@ from ax.core.search_space import SearchSpace from ax.core.types import TParamValueList from ax.exceptions.core import UserInputError -from ax.modelbridge.base import GenResults, ModelBridge +from ax.modelbridge.base import Adapter, GenResults from ax.modelbridge.modelbridge_utils import ( array_to_observation_data, get_fixed_features, @@ -28,21 +28,21 @@ extract_outcome_constraints, validate_optimization_config, ) -from ax.models.discrete_base import DiscreteModel +from ax.models.discrete_base import DiscreteGenerator from ax.models.types import TConfig FIT_MODEL_ERROR = "Model must be fit before {action}." -class DiscreteModelBridge(ModelBridge): +class DiscreteAdapter(Adapter): """A model bridge for using models based on discrete parameters. Requires that all parameters have been transformed to ChoiceParameters. """ # pyre-fixme[13]: Attribute `model` is never initialized. - model: DiscreteModel + model: DiscreteGenerator # pyre-fixme[13]: Attribute `outcomes` is never initialized. outcomes: list[str] # pyre-fixme[13]: Attribute `parameters` is never initialized. @@ -52,7 +52,7 @@ class DiscreteModelBridge(ModelBridge): def _fit( self, - model: DiscreteModel, + model: DiscreteGenerator, search_space: SearchSpace, observations: list[Observation], ) -> None: @@ -102,7 +102,7 @@ def _validate_gen_inputs( fixed_features: ObservationFeatures | None = None, model_gen_options: TConfig | None = None, ) -> None: - """Validate inputs to `ModelBridge.gen`. + """Validate inputs to `Adapter.gen`. Currently, this is only used to ensure that `n` is a positive integer or -1.
""" diff --git a/ax/modelbridge/dispatch_utils.py b/ax/modelbridge/dispatch_utils.py index 548369a866c..6c951201bda 100644 --- a/ax/modelbridge/dispatch_utils.py +++ b/ax/modelbridge/dispatch_utils.py @@ -17,10 +17,16 @@ from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter from ax.core.search_space import SearchSpace from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy -from ax.modelbridge.registry import MODEL_KEY_TO_MODEL_SETUP, ModelRegistryBase, Models +from ax.modelbridge.registry import ( + Generators, + MODEL_KEY_TO_MODEL_SETUP, + ModelRegistryBase, +) from ax.modelbridge.transforms.base import Transform from ax.modelbridge.transforms.winsorize import Winsorize -from ax.models.torch.botorch_modular.model import BoTorchModel as ModularBoTorchModel +from ax.models.torch.botorch_modular.model import ( + BoTorchGenerator as ModularBoTorchGenerator, +) from ax.models.types import TConfig from ax.models.winsorization_config import WinsorizationConfig from ax.utils.common.deprecation import _validate_force_random_search @@ -55,7 +61,7 @@ def _make_sobol_step( ) -> GenerationStep: """Shortcut for creating a Sobol generation step.""" return GenerationStep( - model=Models.SOBOL, + model=Generators.SOBOL, num_trials=num_trials, # NOTE: ceil(-1 / 2) = 0, so this is safe to do when num trials is -1. min_trials_observed=min_trials_observed or ceil(num_trials / 2), @@ -71,7 +77,7 @@ def _make_botorch_step( min_trials_observed: int | None = None, enforce_num_trials: bool = True, max_parallelism: int | None = None, - model: ModelRegistryBase = Models.BOTORCH_MODULAR, + model: ModelRegistryBase = Generators.BOTORCH_MODULAR, model_kwargs: dict[str, Any] | None = None, winsorization_config: None | (WinsorizationConfig | dict[str, WinsorizationConfig]) = None, @@ -111,7 +117,7 @@ def _make_botorch_step( winsorization_transform_config ) - if MODEL_KEY_TO_MODEL_SETUP[model.value].model_class != ModularBoTorchModel: + if MODEL_KEY_TO_MODEL_SETUP[model.value].model_class != ModularBoTorchGenerator: if verbose is not None: model_kwargs.update({"verbose": verbose}) if disable_progbar is not None: @@ -122,7 +128,7 @@ def _make_botorch_step( # TODO[T164389105] Rewrite choose_generation_strategy to be MBM first logger.info( "`verbose`, `disable_progbar`, and `jit_compile` are not yet supported " - "when using `choose_generation_strategy` with ModularBoTorchModel, " + "when using `choose_generation_strategy` with ModularBoTorchGenerator, " "dropping these arguments." ) return GenerationStep( @@ -226,7 +232,7 @@ def _suggest_gp_model( ) if use_saasbo: logger.warning(SAASBO_INCOMPATIBLE_MESSAGE.format("`BO_MIXED`")) - return Models.BO_MIXED + return Generators.BO_MIXED if num_ordered_parameters >= num_unordered_choices or ( num_unordered_choices < MAX_ONE_HOT_ENCODINGS_CONTINUOUS_OPTIMIZATION @@ -235,7 +241,7 @@ def _suggest_gp_model( # These use one-hot encoding for unordered choice parameters, resulting in a # total of num_unordered_choices OHE parameters. # So, we do not want to use them when there are too many unordered choices. 
- method = Models.SAASBO if use_saasbo else Models.BOTORCH_MODULAR + method = Generators.SAASBO if use_saasbo else Generators.BOTORCH_MODULAR reason = ( ( "there are more ordered parameters than there are categories for the " @@ -493,7 +499,7 @@ def choose_generation_strategy( ) steps = [] # `verbose` and `disable_progbar` defaults and overrides - model_is_saasbo = suggested_model is Models.SAASBO + model_is_saasbo = suggested_model is Generators.SAASBO if verbose is None and model_is_saasbo: verbose = True elif verbose is not None and not model_is_saasbo: diff --git a/ax/modelbridge/external_generation_node.py b/ax/modelbridge/external_generation_node.py index 4f20ee7d2dd..3f05da84381 100644 --- a/ax/modelbridge/external_generation_node.py +++ b/ax/modelbridge/external_generation_node.py @@ -181,8 +181,9 @@ def _gen( pending_observations: A map from metric name to pending observations for that metric, used by some methods to avoid re-suggesting candidates that are currently being evaluated. - model_gen_kwargs: Keyword arguments, passed through to ``ModelSpec.gen``; - these override any pre-specified in ``ModelSpec.model_gen_kwargs``. + model_gen_kwargs: Keyword arguments, passed through to + ``GeneratorSpec.gen``; these override any pre-specified in + ``GeneratorSpec.model_gen_kwargs``. Returns: A ``GeneratorRun`` containing the newly generated candidates. @@ -207,7 +208,7 @@ def _gen( gen_time=time.monotonic() - t_gen_start, model_key=self.node_name, ) - # TODO: This shares the same bug as ModelBridge.gen. In both cases, after + # TODO: This shares the same bug as Adapter.gen. In both cases, after # deduplication, the generator run will record fit_time as 0. self.fit_time_since_gen = 0 return generator_run diff --git a/ax/modelbridge/factory.py b/ax/modelbridge/factory.py index bbc568207ce..409a5508393 100644 --- a/ax/modelbridge/factory.py +++ b/ax/modelbridge/factory.py @@ -13,10 +13,10 @@ from ax.core.experiment import Experiment from ax.core.optimization_config import OptimizationConfig from ax.core.search_space import SearchSpace -from ax.modelbridge.discrete import DiscreteModelBridge -from ax.modelbridge.random import RandomModelBridge -from ax.modelbridge.registry import Cont_X_trans, Models, Y_trans -from ax.modelbridge.torch import TorchModelBridge +from ax.modelbridge.discrete import DiscreteAdapter +from ax.modelbridge.random import RandomAdapter +from ax.modelbridge.registry import Cont_X_trans, Generators, Y_trans +from ax.modelbridge.torch import TorchAdapter from ax.modelbridge.transforms.base import Transform from ax.models.torch.botorch import ( TAcqfConstructor, @@ -61,7 +61,7 @@ def get_sobol( init_position: int = 0, scramble: bool = True, fallback_to_sample_polytope: bool = False, -) -> RandomModelBridge: +) -> RandomAdapter: """Instantiates a Sobol sequence quasi-random generator. Args: @@ -69,10 +69,10 @@ def get_sobol( kwargs: Custom args for sobol generator. Returns: - RandomModelBridge, with SobolGenerator as model. + RandomAdapter, with SobolGenerator as model. """ return assert_is_instance( - Models.SOBOL( + Generators.SOBOL( search_space=search_space, seed=seed, deduplicate=deduplicate, @@ -80,13 +80,13 @@ def get_sobol( scramble=scramble, fallback_to_sample_polytope=fallback_to_sample_polytope, ), - RandomModelBridge, + RandomAdapter, ) def get_uniform( search_space: SearchSpace, deduplicate: bool = False, seed: int | None = None -) -> RandomModelBridge: +) -> RandomAdapter: """Instantiate uniform generator. 
Args: @@ -94,11 +94,13 @@ def get_uniform( kwargs: Custom args for uniform generator. Returns: - RandomModelBridge, with UniformGenerator as model. + RandomAdapter, with UniformGenerator as model. """ return assert_is_instance( - Models.UNIFORM(search_space=search_space, seed=seed, deduplicate=deduplicate), - RandomModelBridge, + Generators.UNIFORM( + search_space=search_space, seed=seed, deduplicate=deduplicate + ), + RandomAdapter, ) @@ -116,12 +118,12 @@ def get_botorch( acqf_optimizer: TOptimizer = scipy_optimizer, # pyre-ignore[9] refit_on_cv: bool = False, optimization_config: OptimizationConfig | None = None, -) -> TorchModelBridge: - """Instantiates a BotorchModel.""" +) -> TorchAdapter: + """Instantiates a BotorchGenerator.""" if data.df.empty: - raise ValueError("`BotorchModel` requires non-empty data.") + raise ValueError("`BotorchGenerator` requires non-empty data.") return assert_is_instance( - Models.LEGACY_BOTORCH( + Generators.LEGACY_BOTORCH( experiment=experiment, data=data, search_space=search_space or experiment.search_space, @@ -136,15 +138,15 @@ def get_botorch( refit_on_cv=refit_on_cv, optimization_config=optimization_config, ), - TorchModelBridge, + TorchAdapter, ) -def get_factorial(search_space: SearchSpace) -> DiscreteModelBridge: +def get_factorial(search_space: SearchSpace) -> DiscreteAdapter: """Instantiates a factorial generator.""" return assert_is_instance( - Models.FACTORIAL(search_space=search_space, fit_out_of_design=True), - DiscreteModelBridge, + Generators.FACTORIAL(search_space=search_space, fit_out_of_design=True), + DiscreteAdapter, ) @@ -155,12 +157,12 @@ def get_empirical_bayes_thompson( num_samples: int = 10000, min_weight: float | None = None, uniform_weights: bool = False, -) -> DiscreteModelBridge: +) -> DiscreteAdapter: """Instantiates an empirical Bayes / Thompson sampling model.""" if data.df.empty: raise ValueError("Empirical Bayes Thompson sampler requires non-empty data.") return assert_is_instance( - Models.EMPIRICAL_BAYES_THOMPSON( + Generators.EMPIRICAL_BAYES_THOMPSON( experiment=experiment, data=data, search_space=search_space or experiment.search_space, @@ -169,7 +171,7 @@ def get_empirical_bayes_thompson( uniform_weights=uniform_weights, fit_out_of_design=True, ), - DiscreteModelBridge, + DiscreteAdapter, ) @@ -180,12 +182,12 @@ def get_thompson( num_samples: int = 10000, min_weight: float | None = None, uniform_weights: bool = False, -) -> DiscreteModelBridge: +) -> DiscreteAdapter: """Instantiates a Thompson sampling model.""" if data.df.empty: raise ValueError("Thompson sampler requires non-empty data.") return assert_is_instance( - Models.THOMPSON( + Generators.THOMPSON( experiment=experiment, data=data, search_space=search_space or experiment.search_space, @@ -194,5 +196,5 @@ def get_thompson( uniform_weights=uniform_weights, fit_out_of_design=True, ), - DiscreteModelBridge, + DiscreteAdapter, ) diff --git a/ax/modelbridge/generation_node.py b/ax/modelbridge/generation_node.py index 61bebbb8b75..4c3abc51226 100644 --- a/ax/modelbridge/generation_node.py +++ b/ax/modelbridge/generation_node.py @@ -26,10 +26,10 @@ from ax.core.search_space import SearchSpace from ax.exceptions.core import UserInputError from ax.exceptions.generation_strategy import GenerationStrategyRepeatedPoints -from ax.modelbridge.base import ModelBridge +from ax.modelbridge.base import Adapter from ax.modelbridge.best_model_selector import BestModelSelector -from ax.modelbridge.model_spec import FactoryFunctionModelSpec, ModelSpec +from 
ax.modelbridge.model_spec import FactoryFunctionGeneratorSpec, GeneratorSpec from ax.modelbridge.registry import _extract_model_state_after_gen, ModelRegistryBase from ax.modelbridge.transition_criterion import ( AutoTransitionAfterGen, @@ -47,11 +47,11 @@ logger: Logger = get_logger(__name__) -TModelFactory = Callable[..., ModelBridge] +TModelFactory = Callable[..., Adapter] MISSING_MODEL_SELECTOR_MESSAGE = ( "A `BestModelSelector` must be provided when using multiple " - "`ModelSpec`s in a `GenerationNode`. After fitting all `ModelSpec`s, " - "the `BestModelSelector` will be used to select the `ModelSpec` to " + "`GeneratorSpec`s in a `GenerationNode`. After fitting all `GeneratorSpec`s, " + "the `BestModelSelector` will be used to select the `GeneratorSpec` to " "use for candidate generation." ) MAX_GEN_DRAWS = 5 @@ -68,10 +68,11 @@ class GenerationNode(SerializationMixin, SortableBase): Args: node_name: A unique name for the GenerationNode. Used for storage purposes. - model_specs: A list of ModelSpecs to be selected from for generation in this + model_specs: A list of GeneratorSpecs to be selected from for generation in this GenerationNode. - best_model_selector: A ``BestModelSelector`` used to select the ``ModelSpec`` - to generate from in ``GenerationNode`` with multiple ``ModelSpec``s. + best_model_selector: A ``BestModelSelector`` used to select the + ``GeneratorSpec`` to generate from in ``GenerationNode`` with + multiple ``GeneratorSpec``s. should_deduplicate: Whether to deduplicate the parameters of proposed arms against those of previous arms via rejection sampling. If this is True, the GenerationStrategy will discard generator runs produced from the @@ -99,19 +100,19 @@ class GenerationNode(SerializationMixin, SortableBase): should_skip: Whether to skip this node during generation time. Defaults to False, and can only currently be set to True via ``NodeInputConstructors`` - Note for developers: by "model" here we really mean an Ax ModelBridge object, which + Note for developers: by "model" here we really mean an Ax Adapter object, which contains an Ax Model under the hood. We call it "model" here to simplify and focus on explaining the logic of GenerationStep and GenerationStrategy. """ # Required options: - model_specs: list[ModelSpec] - # TODO: Move `should_deduplicate` to `ModelSpec` if possible, and make optional + model_specs: list[GeneratorSpec] + # TODO: Move `should_deduplicate` to `GeneratorSpec` if possible, and make optional should_deduplicate: bool _node_name: str # Optional specifications - _model_spec_to_gen_from: ModelSpec | None = None + _model_spec_to_gen_from: GeneratorSpec | None = None # TODO: @mgarrard should this be a dict criterion_class name -> criterion mapping? 
_transition_criteria: Sequence[TransitionCriterion] _input_constructors: dict[ @@ -131,7 +132,7 @@ class GenerationNode(SerializationMixin, SortableBase): def __init__( self, node_name: str, - model_specs: list[ModelSpec], + model_specs: list[GeneratorSpec], best_model_selector: BestModelSelector | None = None, should_deduplicate: bool = False, transition_criteria: Sequence[TransitionCriterion] | None = None, @@ -181,7 +182,7 @@ def node_name(self) -> str: return self._node_name @property - def model_spec_to_gen_from(self) -> ModelSpec: + def model_spec_to_gen_from(self) -> GeneratorSpec: """Returns the cached `_model_spec_to_gen_from` or gets it from `_pick_fitted_model_to_gen_from` and then caches and returns it """ @@ -268,7 +269,7 @@ def _unique_id(self) -> str: return self.node_name @property - def _fitted_model(self) -> ModelBridge | None: + def _fitted_model(self) -> Adapter | None: """Private property to return optional fitted_model from self.model_spec_to_gen_from for convenience. If no model is fit, will return None. If using the non-private `fitted_model` property, @@ -296,7 +297,7 @@ def fit( optimization config. kwargs: Additional keyword arguments to pass to the model's ``fit`` method. NOTE: Local kwargs take precedence over the ones - stored in ``ModelSpec.model_kwargs``. + stored in ``GeneratorSpec.model_kwargs``. """ if not data.df.empty: trial_indices_in_data = sorted(data.df["trial_index"].unique()) @@ -322,11 +323,11 @@ def fit( ) def _get_model_state_from_last_generator_run( - self, model_spec: ModelSpec + self, model_spec: GeneratorSpec ) -> dict[str, Any]: """Get the fit args from the last generator run for the model being fit. - NOTE: This only works for the base ModelSpec class. Factory functions + NOTE: This only works for the base GeneratorSpec class. Factory functions are not supported and will return an empty dict. Args: @@ -337,7 +338,7 @@ def _get_model_state_from_last_generator_run( that was generated by the model being fit. """ if ( - isinstance(model_spec, FactoryFunctionModelSpec) + isinstance(model_spec, FactoryFunctionGeneratorSpec) or self._generation_strategy is None ): # We cannot extract the args for factory functions (which are to be @@ -375,7 +376,7 @@ def gen( """This method generates candidates using `self._gen` and handles deduplication of generated candidates if `self.should_deduplicate=True`. - NOTE: Models must have been fit prior to calling ``gen``. + NOTE: Generators must have been fit prior to calling ``gen``. NOTE: Some underlying models may ignore the ``n`` argument and produce a model-determined number of arms. In that case this method will also output a generator run with number of arms that may differ from ``n``. @@ -383,7 +384,7 @@ def gen( Args: n: Optional integer representing how many arms should be in the generator run produced by this method. When this is ``None``, ``n`` will be - determined by the ``ModelSpec`` that we are generating from. + determined by the ``GeneratorSpec`` that we are generating from. pending_observations: A map from metric name to pending observations for that metric, used by some models to avoid resuggesting points that are currently being evaluated. @@ -393,8 +394,9 @@ def gen( exception will be raised. arms_by_signature_for_deduplication: A dictionary mapping arm signatures to the arms, to be used for deduplicating newly generated arms. - model_gen_kwargs: Keyword arguments, passed through to ``ModelSpec.gen``; - these override any pre-specified in ``ModelSpec.model_gen_kwargs``. 
+ model_gen_kwargs: Keyword arguments, passed through to + ``GeneratorSpec.gen``; these override any pre-specified in + ``GeneratorSpec.model_gen_kwargs``. Returns: A ``GeneratorRun`` containing the newly generated candidates. @@ -454,19 +456,20 @@ def _gen( ) -> GeneratorRun: """Picks a fitted model, from which to generate candidates (via ``self._pick_fitted_model_to_gen_from``) and generates candidates - from it. Uses the ``model_gen_kwargs`` set on the selected ``ModelSpec`` + from it. Uses the ``model_gen_kwargs`` set on the selected ``GeneratorSpec`` alongside any kwargs passed in to this function (with local kwargs) taking precedent. Args: n: Optional integer representing how many arms should be in the generator run produced by this method. When this is ``None``, ``n`` will be - determined by the ``ModelSpec`` that we are generating from. + determined by the ``GeneratorSpec`` that we are generating from. pending_observations: A map from metric name to pending observations for that metric, used by some models to avoid resuggesting points that are currently being evaluated. - model_gen_kwargs: Keyword arguments, passed through to ``ModelSpec.gen``; - these override any pre-specified in ``ModelSpec.model_gen_kwargs``. + model_gen_kwargs: Keyword arguments, passed through to + ``GeneratorSpec.gen``;these override any pre-specified in + ``GeneratorSpec.model_gen_kwargs``. Returns: A ``GeneratorRun`` containing the newly generated candidates. @@ -487,7 +490,7 @@ def _gen( # ------------------------- Model selection logic helpers. ------------------------- - def _pick_fitted_model_to_gen_from(self) -> ModelSpec: + def _pick_fitted_model_to_gen_from(self) -> GeneratorSpec: """Select one model to generate from among the fitted models on this generation node. @@ -496,7 +499,7 @@ def _pick_fitted_model_to_gen_from(self) -> ModelSpec: use it to select one model to generate from among the fitted models on this generation node. 2. otherwise, ensure that this ``GenerationNode`` only contains one - `ModelSpec` and select it. + `GeneratorSpec` and select it. """ if self.best_model_selector is None: if len(self.model_specs) != 1: # pragma: no cover -- raised in __init__. @@ -687,13 +690,13 @@ class GenerationStep(GenerationNode, SortableBase): minimum number of observations is required to proceed to the next model, etc. NOTE: Model can be specified either from the model registry - (`ax.modelbridge.registry.Models` or using a callable model constructor. Only + (`ax.modelbridge.registry.Generators` or using a callable model constructor. Only models from the registry can be saved, and thus optimization can only be resumed if interrupted when using models from the registry. Args: - model: A member of `Models` enum or a callable returning an instance of - `ModelBridge` with an instantiated underlying `Model`. Refer to + model: A member of `Generators` enum or a callable returning an instance of + `Adapter` with an instantiated underlying `Model`. Refer to `ax/modelbridge/factory.py` for examples of such callables. num_trials: How many trials to generate with the model from this step. If set to -1, trials will continue to be generated from this model @@ -718,12 +721,13 @@ class GenerationStep(GenerationNode, SortableBase): trials` for it. Allows to avoid `DataRequiredError`, but delays proceeding to next generation step. model_kwargs: Dictionary of kwargs to pass into the model constructor on - instantiation. E.g. 
if `model` is `Models.SOBOL`, kwargs will be applied - as `Models.SOBOL(**model_kwargs)`; if `model` is `get_sobol`, `get_sobol( - **model_kwargs)`. NOTE: if generation strategy is interrupted and - resumed from a stored snapshot and its last used model has state saved on - its generator runs, `model_kwargs` is updated with the state dict of the - model, retrieved from the last generator run of this generation strategy. + instantiation. E.g. if `model` is `Generators.SOBOL`, kwargs will be applied + as `Generators.SOBOL(**model_kwargs)`; if `model` is `get_sobol`, + `get_sobol(**model_kwargs)`. NOTE: if generation strategy is + interrupted and resumed from a stored snapshot and its last used + model has state saved on its generator runs, `model_kwargs` is + updated with the state dict of the model, retrieved from the last + generator run of this generation strategy. model_gen_kwargs: Each call to `generation_strategy.gen` performs a call to the step's model's `gen` under the hood; `model_gen_kwargs` will be passed to the model's `gen` like so: `model.gen(**model_gen_kwargs)`. @@ -744,14 +748,14 @@ class GenerationStep(GenerationNode, SortableBase): model_name: Optional name of the model. If not specified, defaults to the model key of the model spec. - Note for developers: by "model" here we really mean an Ax ModelBridge object, which + Note for developers: by "model" here we really mean an Ax Adapter object, which contains an Ax Model under the hood. We call it "model" here to simplify and focus on explaining the logic of GenerationStep and GenerationStrategy. """ def __init__( self, - model: ModelRegistryBase | Callable[..., ModelBridge], + model: ModelRegistryBase | Callable[..., Adapter], num_trials: int, model_kwargs: dict[str, Any] | None = None, model_gen_kwargs: dict[str, Any] | None = None, @@ -806,7 +810,7 @@ def __init__( "enum subclass entry or a callable factory function returning a " "model bridge instance." ) - model_spec = FactoryFunctionModelSpec( + model_spec = FactoryFunctionGeneratorSpec( factory_function=self.model, # Only pass down the model name if it is not empty. 
model_key_override=model_name if model_name else None, @@ -814,7 +818,7 @@ def __init__( model_gen_kwargs=model_gen_kwargs, ) else: - model_spec = ModelSpec( + model_spec = GeneratorSpec( model_enum=self.model, model_kwargs=model_kwargs, model_gen_kwargs=model_gen_kwargs, @@ -873,16 +877,16 @@ def __init__( @property def model_kwargs(self) -> dict[str, Any]: - """Returns the model kwargs of the underlying ``ModelSpec``.""" + """Returns the model kwargs of the underlying ``GeneratorSpec``.""" return self.model_spec.model_kwargs @property def model_gen_kwargs(self) -> dict[str, Any]: - """Returns the model gen kwargs of the underlying ``ModelSpec``.""" + """Returns the model gen kwargs of the underlying ``GeneratorSpec``.""" return self.model_spec.model_gen_kwargs @property - def model_spec(self) -> ModelSpec: + def model_spec(self) -> GeneratorSpec: """Returns the first model_spec from the model_specs attribute.""" return self.model_specs[0] diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py index 22e4c717539..591a27bd4ad 100644 --- a/ax/modelbridge/generation_strategy.py +++ b/ax/modelbridge/generation_strategy.py @@ -28,10 +28,10 @@ GenerationStrategyCompleted, GenerationStrategyMisconfiguredException, ) -from ax.modelbridge.base import ModelBridge +from ax.modelbridge.base import Adapter from ax.modelbridge.generation_node import GenerationNode, GenerationStep from ax.modelbridge.generation_node_input_constructors import InputConstructorPurpose -from ax.modelbridge.model_spec import FactoryFunctionModelSpec +from ax.modelbridge.model_spec import FactoryFunctionGeneratorSpec from ax.modelbridge.transition_criterion import TrialBasedCriterion from ax.utils.common.logger import get_logger from ax.utils.common.typeutils import assert_is_instance_list @@ -93,7 +93,7 @@ class GenerationStrategy(GenerationStrategyInterface): _nodes: list[GenerationNode] _curr: GenerationNode # Current node in the strategy. - # Whether all models in this GS are in Models registry enum. + # Whether all models in this GS are in Generators registry enum. _uses_registered_models: bool # All generator runs created through this generation strategy, in chronological # order. @@ -101,7 +101,7 @@ class GenerationStrategy(GenerationStrategyInterface): # Experiment, for which this generation strategy has generated trials, if # it exists. _experiment: Experiment | None = None - _model: ModelBridge | None = None # Current model. + _model: Adapter | None = None # Current model. def __init__( self, @@ -135,7 +135,7 @@ def __init__( # Log warning if the GS uses a non-registered (factory function) model. self._uses_registered_models = not any( - isinstance(ms, FactoryFunctionModelSpec) + isinstance(ms, FactoryFunctionGeneratorSpec) for node in self._nodes for ms in node.model_specs ) @@ -228,7 +228,7 @@ def current_step_index(self) -> int: return node_names_for_all_steps.index(self._curr.node_name) @property - def model(self) -> ModelBridge | None: + def model(self) -> Adapter | None: """Current model in this strategy. Returns None if no model has been set yet (i.e., if no generator runs have been produced from this GS). """ @@ -856,7 +856,7 @@ def _gen_multiple( resuggesting points that are currently being evaluated. model_gen_kwargs: Keyword arguments that are passed through to ``GenerationNode.gen``, which will pass them through to - ``ModelSpec.gen``, which will pass them to ``ModelBridge.gen``. + ``GeneratorSpec.gen``, which will pass them to ``Adapter.gen``. 
status_quo_features: An ``ObservationFeature`` of the status quo arm, needed by some models during fit to accomadate relative constraints. Includes the status quo parameterization and target trial index. diff --git a/ax/modelbridge/map_torch.py b/ax/modelbridge/map_torch.py index 5ee71a88aa5..dce8ac0a947 100644 --- a/ax/modelbridge/map_torch.py +++ b/ax/modelbridge/map_torch.py @@ -32,9 +32,9 @@ observation_features_to_array, parse_observation_features, ) -from ax.modelbridge.torch import FIT_MODEL_ERROR, TorchModelBridge +from ax.modelbridge.torch import FIT_MODEL_ERROR, TorchAdapter from ax.modelbridge.transforms.base import Transform -from ax.models.torch_base import TorchModel +from ax.models.torch_base import TorchGenerator from ax.models.types import TConfig from ax.utils.common.constants import Keys from pyre_extensions import none_throws @@ -46,9 +46,9 @@ DEFAULT_TARGET_MAP_VALUES = {"steps": 1.0} -class MapTorchModelBridge(TorchModelBridge): +class MapTorchAdapter(TorchAdapter): """A model bridge for using torch-based models that fit on MapData. Most - of the `TorchModelBridge` functionality is retained, except that this + of the `TorchAdapter` functionality is retained, except that this class should be used in the case where `model` makes use of map_key values. For example, the use case of fitting a joint surrogate model on `(parameters, map_key)`, while candidate generation is only for `parameters`. @@ -59,7 +59,7 @@ def __init__( experiment: Experiment, search_space: SearchSpace, data: Data, - model: TorchModel, + model: TorchGenerator, transforms: list[type[Transform]], transform_configs: dict[str, TConfig] | None = None, torch_dtype: torch.dtype | None = None, @@ -117,12 +117,10 @@ def __init__( """ if not isinstance(data, MapData): - raise ValueError( - "`MapTorchModelBridge expects `MapData` instead of `Data`." - ) + raise ValueError("`MapTorchAdapter expects `MapData` instead of `Data`.") if any(isinstance(t, BatchTrial) for t in experiment.trials.values()): - raise ValueError("MapTorchModelBridge does not support batch trials.") + raise ValueError("MapTorchAdapter does not support batch trials.") # pyre-fixme[4]: Attribute must be annotated. self._map_key_features = data.map_keys self._map_data_limit_rows_per_metric = map_data_limit_rows_per_metric @@ -163,7 +161,7 @@ def parameters_with_map_keys(self) -> list[str]: def _predict( self, observation_features: list[ObservationFeatures] ) -> list[ObservationData]: - """This method is updated from `TorchModelBridge._predict(...) in that it + """This method is updated from `TorchAdapter._predict(...) in that it will accept observation features with or without map_keys. If observation features do not contain map_keys, it will insert them based on `target_map_values`. @@ -192,13 +190,13 @@ def _predict( def _fit( self, - model: TorchModel, + model: TorchGenerator, search_space: SearchSpace, observations: list[Observation], parameters: list[str] | None = None, **kwargs: Any, ) -> None: - """The difference from `TorchModelBridge._fit(...)` is that we use + """The difference from `TorchAdapter._fit(...)` is that we use `self.parameters_with_map_keys` instead of `self.parameters`. """ self.parameters = list(search_space.parameters.keys()) @@ -221,7 +219,7 @@ def _gen( model_gen_options: TConfig | None = None, optimization_config: OptimizationConfig | None = None, ) -> GenResults: - """An updated version of `TorchModelBridge._gen(...) that first injects + """An updated version of `TorchAdapter._gen(...) 
that first injects `map_dim_to_target` (e.g., `{-1: 1.0}`) into `model_gen_options` so that the target values of the map_keys are known during candidate generation. """ @@ -241,7 +239,7 @@ def _array_to_observation_features( candidate_metadata: Sequence[TCandidateMetadata] | None, ) -> list[ObservationFeatures]: """The difference b/t this method and - TorchModelBridge._array_to_observation_features(...) is + TorchAdapter._array_to_observation_features(...) is that this one makes use of `self.parameters_with_map_keys`. """ return parse_observation_features( @@ -253,7 +251,7 @@ def _array_to_observation_features( def _prepare_observations( self, experiment: Experiment | None, data: Data | None ) -> list[Observation]: - """The difference b/t this method and ModelBridge._prepare_observations(...) + """The difference b/t this method and Adapter._prepare_observations(...) is that this one uses `observations_from_map_data`. """ if experiment is None or data is None: @@ -271,7 +269,7 @@ def _prepare_observations( def _compute_in_design( self, search_space: SearchSpace, observations: list[Observation] ) -> list[bool]: - """The difference b/t this method and ModelBridge._compute_in_design(...) + """The difference b/t this method and Adapter._compute_in_design(...) is that this one correctly excludes map_keys when checking membership in search space (as map_keys are not explicitly in the search space). """ @@ -297,7 +295,7 @@ def _cross_validate( **kwargs: Any, ) -> list[ObservationData]: """Make predictions at cv_test_points using only the data in obs_feats - and obs_data. The difference from `TorchModelBridge._cross_validate` + and obs_data. The difference from `TorchAdapter._cross_validate` is that here we do cross validation on the parameters + map_keys. There is some extra logic to filter out out-of-design points in the map_key dimension. diff --git a/ax/modelbridge/model_spec.py b/ax/modelbridge/model_spec.py index 8b80ded568f..92dc4c24bdf 100644 --- a/ax/modelbridge/model_spec.py +++ b/ax/modelbridge/model_spec.py @@ -22,7 +22,7 @@ from ax.core.optimization_config import OptimizationConfig from ax.core.search_space import SearchSpace from ax.exceptions.core import AxWarning, UserInputError -from ax.modelbridge.base import ModelBridge +from ax.modelbridge.base import Adapter from ax.modelbridge.cross_validation import ( compute_diagnostics, cross_validate, @@ -41,11 +41,11 @@ from pyre_extensions import none_throws -TModelFactory = Callable[..., ModelBridge] +TModelFactory = Callable[..., Adapter] -class ModelSpecJSONEncoder(json.JSONEncoder): - """Generic encoder to avoid JSON errors in ModelSpec.__repr__""" +class GeneratorSpecJSONEncoder(json.JSONEncoder): + """Generic encoder to avoid JSON errors in GeneratorSpec.__repr__""" # pyre-fixme[2]: Parameter annotation cannot be `Any`. def default(self, o: Any) -> str: @@ -53,22 +53,22 @@ def default(self, o: Any) -> str: @dataclass -class ModelSpec(SortableBase, SerializationMixin): +class GeneratorSpec(SortableBase, SerializationMixin): model_enum: ModelRegistryBase - # Kwargs to pass into the `Model` + `ModelBridge` constructors in + # Kwargs to pass into the `Model` + `Adapter` constructors in # `ModelRegistryBase.__call__`. model_kwargs: dict[str, Any] = field(default_factory=dict) - # Kwargs to pass to `ModelBridge.gen`. + # Kwargs to pass to `Adapter.gen`. model_gen_kwargs: dict[str, Any] = field(default_factory=dict) # Kwargs to pass to `cross_validate`. 
model_cv_kwargs: dict[str, Any] = field(default_factory=dict) - # An optional override for the model key. Each `ModelSpec` in a + # An optional override for the model key. Each `GeneratorSpec` in a # `GenerationNode` must have a unique key to ensure identifiability. model_key_override: str | None = None # Fitted model, constructed using specified `model_kwargs` and `Data` - # on `ModelSpec.fit` - _fitted_model: ModelBridge | None = None + # on `GeneratorSpec.fit` + _fitted_model: Adapter | None = None # Stored cross validation results set in cross validate. _cv_results: list[CVResult] | None = None @@ -88,7 +88,7 @@ def __post_init__(self) -> None: self.model_cv_kwargs = self.model_cv_kwargs or {} @property - def fitted_model(self) -> ModelBridge: + def fitted_model(self) -> Adapter: """Returns the fitted Ax model, asserting fit() was called""" self._assert_fitted() return none_throws(self._fitted_model) @@ -109,7 +109,7 @@ def fixed_features(self, value: ObservationFeatures | None) -> None: @property def model_key(self) -> str: - """Key string to identify the model used by this ``ModelSpec``.""" + """Key string to identify the model used by this ``GeneratorSpec``.""" if self.model_key_override is not None: return self.model_key_override else: @@ -249,8 +249,8 @@ def gen(self, **model_gen_kwargs: Any) -> GeneratorRun: ) return generator_run - def copy(self) -> ModelSpec: - """`ModelSpec` is both a spec and an object that performs actions. + def copy(self) -> GeneratorSpec: + """`GeneratorSpec` is both a spec and an object that performs actions. Copying is useful to avoid changes to a singleton model spec. """ return self.__class__( @@ -270,7 +270,7 @@ def _safe_to_update( This is a cheap way of checking that we're attempting to re-fit the same model for the same experiment, which is a very reasonable expectation - since this all happens on the same `ModelSpec` instance. + since this all happens on the same `GeneratorSpec` instance. """ if self.model_key == "TRBO": # Temporary hack to unblock TRBO. 
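With the rename above, a `GeneratorSpec` (formerly `ModelSpec`) pairs a `Generators` registry member with the kwargs routed to the Generator/Adapter constructors and to `Adapter.gen`. A minimal usage sketch, assuming the Branin experiment stub from `ax.utils.testing.core_stubs`; the specific kwarg values are illustrative only:

from ax.modelbridge.model_spec import GeneratorSpec
from ax.modelbridge.registry import Generators
from ax.utils.testing.core_stubs import get_branin_experiment

experiment = get_branin_experiment(with_completed_trial=True)

spec = GeneratorSpec(
    model_enum=Generators.SOBOL,
    # Forwarded to the Generator / Adapter constructors (illustrative value).
    model_kwargs={"deduplicate": True},
    # Default kwargs forwarded to `Adapter.gen`; kwargs passed directly to
    # `GeneratorSpec.gen` take precedence over these.
    model_gen_kwargs={"n": 2},
)
spec.fit(experiment=experiment, data=experiment.lookup_data())
generator_run = spec.gen()  # Uses the pre-specified n=2.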
@@ -298,16 +298,16 @@ def _assert_fitted(self) -> None: def __repr__(self) -> str: model_kwargs = json.dumps( - self.model_kwargs, sort_keys=True, cls=ModelSpecJSONEncoder + self.model_kwargs, sort_keys=True, cls=GeneratorSpecJSONEncoder ) model_gen_kwargs = json.dumps( - self.model_gen_kwargs, sort_keys=True, cls=ModelSpecJSONEncoder + self.model_gen_kwargs, sort_keys=True, cls=GeneratorSpecJSONEncoder ) model_cv_kwargs = json.dumps( - self.model_cv_kwargs, sort_keys=True, cls=ModelSpecJSONEncoder + self.model_cv_kwargs, sort_keys=True, cls=GeneratorSpecJSONEncoder ) return ( - "ModelSpec(" + "GeneratorSpec(" f"\tmodel_enum={self.model_enum.value}, " f"\tmodel_kwargs={model_kwargs}, " f"\tmodel_gen_kwargs={model_gen_kwargs}, " @@ -319,7 +319,7 @@ def __repr__(self) -> str: def __hash__(self) -> int: return hash(repr(self)) - def __eq__(self, other: ModelSpec) -> bool: + def __eq__(self, other: GeneratorSpec) -> bool: return repr(self) == repr(other) @property @@ -330,23 +330,23 @@ def _unique_id(self) -> str: @dataclass -class FactoryFunctionModelSpec(ModelSpec): +class FactoryFunctionGeneratorSpec(GeneratorSpec): factory_function: TModelFactory | None = None - # pyre-ignore[15]: `ModelSpec` has this as non-optional + # pyre-ignore[15]: `GeneratorSpec` has this as non-optional model_enum: ModelRegistryBase | None = None def __post_init__(self) -> None: super().__post_init__() if self.model_enum is not None: raise UserInputError( - "Use regular `ModelSpec` when it's possible to describe the " + "Use regular `GeneratorSpec` when it's possible to describe the " "model as `ModelRegistryBase` subclass enum member." ) if self.factory_function is None: raise UserInputError( - "Please specify a valid function returning a `ModelBridge` instance " + "Please specify a valid function returning a `Adapter` instance " "as the required `factory_function` argument to " - "`FactoryFunctionModelSpec`." + "`FactoryFunctionGeneratorSpec`." ) if self.model_key_override is None: try: diff --git a/ax/modelbridge/modelbridge_utils.py b/ax/modelbridge/modelbridge_utils.py index e99be10f81d..9a34bc084f4 100644 --- a/ax/modelbridge/modelbridge_utils.py +++ b/ax/modelbridge/modelbridge_utils.py @@ -235,7 +235,7 @@ def extract_search_space_digest( if p.log_scale or p.logit_scale: raise UserInputError( "Log and Logit scale parameters must be transformed using the " - "corresponding transform within the `ModelBridge`. After applying " + "corresponding transform within the `Adapter`. After applying " f"the transforms, we have {p.log_scale=} and {p.logit_scale=}." ) if p.parameter_type == ParameterType.INT: @@ -683,7 +683,7 @@ def _roundtrip_transform(x: npt.NDArray) -> npt.NDArray: def get_pareto_frontier_and_configs( - modelbridge: modelbridge_module.torch.TorchModelBridge, + modelbridge: modelbridge_module.torch.TorchAdapter, observation_features: list[ObservationFeatures], observation_data: list[ObservationData] | None = None, objective_thresholds: TRefPoint | None = None, @@ -697,7 +697,7 @@ def get_pareto_frontier_and_configs( observations. Args: - modelbridge: ``Modelbridge`` used to predict metrics outcomes. + modelbridge: ``Adapter`` used to predict metrics outcomes. observation_features: Observation features to consider for the Pareto frontier. 
observation_data: Data for computing the Pareto front, unless @@ -827,7 +827,7 @@ def get_pareto_frontier_and_configs( def pareto_frontier( - modelbridge: modelbridge_module.torch.TorchModelBridge, + modelbridge: modelbridge_module.torch.TorchAdapter, observation_features: list[ObservationFeatures], observation_data: list[ObservationData] | None = None, objective_thresholds: TRefPoint | None = None, @@ -839,7 +839,7 @@ def pareto_frontier( in the untransformed search space. Args: - modelbridge: ``Modelbridge`` used to predict metrics outcomes. + modelbridge: ``Adapter`` used to predict metrics outcomes. observation_features: Observation features to consider for the Pareto frontier. observation_data: Data for computing the Pareto front, unless @@ -898,7 +898,7 @@ def pareto_frontier( def predicted_pareto_frontier( - modelbridge: modelbridge_module.torch.TorchModelBridge, + modelbridge: modelbridge_module.torch.TorchAdapter, objective_thresholds: TRefPoint | None = None, observation_features: list[ObservationFeatures] | None = None, optimization_config: MultiObjectiveOptimizationConfig | None = None, @@ -909,7 +909,7 @@ def predicted_pareto_frontier( which points lie on the Pareto frontier. Args: - modelbridge: ``Modelbridge`` used to predict metrics outcomes. + modelbridge: ``Adapter`` used to predict metrics outcomes. observation_features: Observation features to predict, if provided and ``use_model_predictions is True``. objective_thresholds: Metric values bounding the region of interest in @@ -943,7 +943,7 @@ def predicted_pareto_frontier( def observed_pareto_frontier( - modelbridge: modelbridge_module.torch.TorchModelBridge, + modelbridge: modelbridge_module.torch.TorchAdapter, objective_thresholds: TRefPoint | None = None, optimization_config: MultiObjectiveOptimizationConfig | None = None, ) -> list[Observation]: @@ -952,7 +952,7 @@ def observed_pareto_frontier( as `Observation`-s. Args: - modelbridge: ``Modelbridge`` that holds previous training data. + modelbridge: ``Adapter`` that holds previous training data. objective_thresholds: Metric values bounding the region of interest in the objective outcome space; used to override objective thresholds in the optimization config, if needed. @@ -979,7 +979,7 @@ def observed_pareto_frontier( def hypervolume( - modelbridge: modelbridge_module.torch.TorchModelBridge, + modelbridge: modelbridge_module.torch.TorchAdapter, observation_features: list[ObservationFeatures], objective_thresholds: TRefPoint | None = None, observation_data: list[ObservationData] | None = None, @@ -1054,7 +1054,7 @@ def hypervolume( def _get_multiobjective_optimization_config( - modelbridge: modelbridge_module.torch.TorchModelBridge, + modelbridge: modelbridge_module.torch.TorchAdapter, optimization_config: OptimizationConfig | None = None, objective_thresholds: TRefPoint | None = None, ) -> MultiObjectiveOptimizationConfig: @@ -1079,7 +1079,7 @@ def _get_multiobjective_optimization_config( def predicted_hypervolume( - modelbridge: modelbridge_module.torch.TorchModelBridge, + modelbridge: modelbridge_module.torch.TorchAdapter, objective_thresholds: TRefPoint | None = None, observation_features: list[ObservationFeatures] | None = None, optimization_config: MultiObjectiveOptimizationConfig | None = None, @@ -1092,7 +1092,7 @@ def predicted_hypervolume( frontier formed from their predicted outcomes. Args: - modelbridge: Modelbridge used to predict metrics outcomes. + modelbridge: Adapter used to predict metrics outcomes. 
objective_thresholds: point defining the origin of hyperrectangles that can contribute to hypervolume. observation_features: observation features to predict. Model's training @@ -1126,7 +1126,7 @@ def predicted_hypervolume( def observed_hypervolume( - modelbridge: modelbridge_module.torch.TorchModelBridge, + modelbridge: modelbridge_module.torch.TorchAdapter, objective_thresholds: TRefPoint | None = None, optimization_config: MultiObjectiveOptimizationConfig | None = None, selected_metrics: list[str] | None = None, @@ -1137,7 +1137,7 @@ def observed_hypervolume( those outcomes. Args: - modelbridge: Modelbridge that holds previous training data. + modelbridge: Adapter that holds previous training data. objective_thresholds: Point defining the origin of hyperrectangles that can contribute to hypervolume. Note that if this is None, `objective_thresholds` must be present on the @@ -1290,7 +1290,7 @@ def feasible_hypervolume( def _array_to_tensor( array: npt.NDArray | list[float], - modelbridge: modelbridge_module.base.ModelBridge | None = None, + modelbridge: modelbridge_module.base.Adapter | None = None, ) -> Tensor: if modelbridge and hasattr(modelbridge, "_array_to_tensor"): # pyre-ignore[16]: modelbridge does not have attribute `_array_to_tensor` @@ -1300,7 +1300,7 @@ def _array_to_tensor( def _get_modelbridge_training_data( - modelbridge: modelbridge_module.torch.TorchModelBridge, + modelbridge: modelbridge_module.torch.TorchAdapter, ) -> tuple[list[ObservationFeatures], list[ObservationData], list[str | None]]: obs = modelbridge.get_training_data() return _unpack_observations(obs=obs) diff --git a/ax/modelbridge/pairwise.py b/ax/modelbridge/pairwise.py index f177bd8e018..74936f6d9a1 100644 --- a/ax/modelbridge/pairwise.py +++ b/ax/modelbridge/pairwise.py @@ -13,7 +13,7 @@ from ax.core.observation import ObservationData, ObservationFeatures from ax.core.search_space import SearchSpaceDigest from ax.core.types import TCandidateMetadata -from ax.modelbridge.torch import TorchModelBridge +from ax.modelbridge.torch import TorchAdapter from ax.utils.common.constants import Keys from botorch.models.utils.assorted import consolidate_duplicates from botorch.utils.containers import DenseContainer, SliceContainer @@ -21,7 +21,7 @@ from torch import Tensor -class PairwiseModelBridge(TorchModelBridge): +class PairwiseAdapter(TorchAdapter): def _convert_observations( self, observation_data: list[ObservationData], diff --git a/ax/modelbridge/prediction_utils.py b/ax/modelbridge/prediction_utils.py index 832b65e9ec1..c06c457807f 100644 --- a/ax/modelbridge/prediction_utils.py +++ b/ax/modelbridge/prediction_utils.py @@ -12,11 +12,11 @@ import numpy as np from ax.core.observation import ObservationFeatures -from ax.modelbridge import ModelBridge +from ax.modelbridge import Adapter def predict_at_point( - model: ModelBridge, + model: Adapter, obsf: ObservationFeatures, metric_names: set[str], scalarized_metric_config: list[dict[str, Any]] | None = None, @@ -26,7 +26,7 @@ def predict_at_point( Returns mean and standard deviation in format expected by plotting. Args: - model: ModelBridge + model: Adapter obsf: ObservationFeatures for which to predict metric_names: Limit predictions to these metrics. 
scalarized_metric_config: An optional list of dicts specifying how to aggregate @@ -71,7 +71,7 @@ def predict_at_point( def predict_by_features( - model: ModelBridge, + model: Adapter, label_to_feature_dict: dict[int, ObservationFeatures], metric_names: set[str], ) -> dict[int, dict[str, tuple[float, float]]]: diff --git a/ax/modelbridge/random.py b/ax/modelbridge/random.py index 4e59dde0a79..637c8f5553d 100644 --- a/ax/modelbridge/random.py +++ b/ax/modelbridge/random.py @@ -15,7 +15,7 @@ from ax.core.observation import Observation, ObservationData, ObservationFeatures from ax.core.optimization_config import OptimizationConfig from ax.core.search_space import SearchSpace -from ax.modelbridge.base import GenResults, ModelBridge +from ax.modelbridge.base import Adapter, GenResults from ax.modelbridge.modelbridge_utils import ( extract_parameter_constraints, extract_search_space_digest, @@ -24,21 +24,21 @@ transform_callback, ) from ax.modelbridge.transforms.base import Transform -from ax.models.random.base import RandomModel +from ax.models.random.base import RandomGenerator from ax.models.types import TConfig FIT_MODEL_ERROR = "Model must be fit before {action}." -class RandomModelBridge(ModelBridge): +class RandomAdapter(Adapter): """A model bridge for using purely random 'models'. Data and optimization configs are not required. - This model bridge interfaces with RandomModel. + This model bridge interfaces with RandomGenerator. Attributes: - model: A RandomModel used to generate candidates + model: A RandomGenerator used to generate candidates (note: this an awkward use of the word 'model'). parameters: Params found in search space on modelbridge init. @@ -81,7 +81,7 @@ class RandomModelBridge(ModelBridge): """ # pyre-fixme[13]: Attribute `model` is never initialized. - model: RandomModel + model: RandomGenerator # pyre-fixme[13]: Attribute `parameters` is never initialized. parameters: list[str] @@ -121,7 +121,7 @@ def __init__( def _fit( self, - model: RandomModel, + model: RandomGenerator, search_space: SearchSpace, observations: list[Observation] | None = None, ) -> None: @@ -168,7 +168,7 @@ def _predict( """Apply terminal transform, predict, and reverse terminal transform on output. """ - raise NotImplementedError("RandomModelBridge does not support prediction.") + raise NotImplementedError("RandomAdapter does not support prediction.") def _cross_validate( self, diff --git a/ax/modelbridge/registry.py b/ax/modelbridge/registry.py index a07c905dc6d..7abbe8d1ef8 100644 --- a/ax/modelbridge/registry.py +++ b/ax/modelbridge/registry.py @@ -10,7 +10,7 @@ Module containing a registry of standard models (and generators, samplers etc.) such as Sobol generator, GP+EI, Thompson sampler, etc. -Use of `Models` enum allows for serialization and reinstantiation of models and +Use of `Generators` enum allows for serialization and reinstantiation of models and generation strategies from generator runs they produced. To reinstantiate a model from generator run, use `get_model_from_generator_run` utility from this module. 
""" @@ -26,10 +26,10 @@ from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun from ax.core.search_space import SearchSpace -from ax.modelbridge.base import ModelBridge -from ax.modelbridge.discrete import DiscreteModelBridge -from ax.modelbridge.random import RandomModelBridge -from ax.modelbridge.torch import TorchModelBridge +from ax.modelbridge.base import Adapter +from ax.modelbridge.discrete import DiscreteAdapter +from ax.modelbridge.random import RandomAdapter +from ax.modelbridge.torch import TorchAdapter from ax.modelbridge.transforms.base import Transform from ax.modelbridge.transforms.choice_encode import ( ChoiceToNumericChoice, @@ -55,15 +55,17 @@ from ax.modelbridge.transforms.transform_to_new_sq import TransformToNewSQ from ax.modelbridge.transforms.trial_as_task import TrialAsTask from ax.modelbridge.transforms.unit_x import UnitX -from ax.models.base import Model +from ax.models.base import Generator from ax.models.discrete.eb_ashr import EBAshr from ax.models.discrete.eb_thompson import EmpiricalBayesThompsonSampler from ax.models.discrete.full_factorial import FullFactorialGenerator from ax.models.discrete.thompson import ThompsonSampler from ax.models.random.sobol import SobolGenerator from ax.models.random.uniform import UniformGenerator -from ax.models.torch.botorch import BotorchModel -from ax.models.torch.botorch_modular.model import BoTorchModel as ModularBoTorchModel +from ax.models.torch.botorch import BotorchGenerator +from ax.models.torch.botorch_modular.model import ( + BoTorchGenerator as ModularBoTorchGenerator, +) from ax.models.torch.botorch_modular.surrogate import SurrogateSpec from ax.models.torch.cbo_sac import SACBO from ax.utils.common.kwargs import ( @@ -169,8 +171,8 @@ class ModelSetup(NamedTuple): such as BoTorch GP+EI, a Thompson sampler, or a Sobol quasirandom generator. 
""" - bridge_class: type[ModelBridge] - model_class: type[Model] + bridge_class: type[Adapter] + model_class: type[Generator] transforms: list[type[Transform]] default_model_kwargs: dict[str, Any] | None = None standard_bridge_kwargs: dict[str, Any] | None = None @@ -183,58 +185,58 @@ class ModelSetup(NamedTuple): """ MODEL_KEY_TO_MODEL_SETUP: dict[str, ModelSetup] = { "BoTorch": ModelSetup( - bridge_class=TorchModelBridge, - model_class=ModularBoTorchModel, + bridge_class=TorchAdapter, + model_class=ModularBoTorchGenerator, transforms=MBM_X_trans + Y_trans, ), "Legacy_GPEI": ModelSetup( - bridge_class=TorchModelBridge, - model_class=BotorchModel, + bridge_class=TorchAdapter, + model_class=BotorchGenerator, transforms=Cont_X_trans + Y_trans, ), "EB": ModelSetup( - bridge_class=DiscreteModelBridge, + bridge_class=DiscreteAdapter, model_class=EmpiricalBayesThompsonSampler, transforms=TS_trans, ), "EB_Ashr": ModelSetup( - bridge_class=DiscreteModelBridge, + bridge_class=DiscreteAdapter, model_class=EBAshr, transforms=EB_ashr_trans, ), "Factorial": ModelSetup( - bridge_class=DiscreteModelBridge, + bridge_class=DiscreteAdapter, model_class=FullFactorialGenerator, transforms=Discrete_X_trans, ), "Thompson": ModelSetup( - bridge_class=DiscreteModelBridge, + bridge_class=DiscreteAdapter, model_class=ThompsonSampler, transforms=TS_trans, ), "Sobol": ModelSetup( - bridge_class=RandomModelBridge, + bridge_class=RandomAdapter, model_class=SobolGenerator, transforms=Cont_X_trans, ), "Uniform": ModelSetup( - bridge_class=RandomModelBridge, + bridge_class=RandomAdapter, model_class=UniformGenerator, transforms=Cont_X_trans, ), "ST_MTGP": ModelSetup( - bridge_class=TorchModelBridge, - model_class=ModularBoTorchModel, + bridge_class=TorchAdapter, + model_class=ModularBoTorchGenerator, transforms=MBM_MTGP_trans, ), "BO_MIXED": ModelSetup( - bridge_class=TorchModelBridge, - model_class=ModularBoTorchModel, + bridge_class=TorchAdapter, + model_class=ModularBoTorchGenerator, transforms=Mixed_transforms + Y_trans, ), "SAASBO": ModelSetup( - bridge_class=TorchModelBridge, - model_class=ModularBoTorchModel, + bridge_class=TorchAdapter, + model_class=ModularBoTorchGenerator, transforms=MBM_X_trans + Y_trans, default_model_kwargs={ "surrogate_spec": SurrogateSpec( @@ -243,8 +245,8 @@ class ModelSetup(NamedTuple): }, ), "SAAS_MTGP": ModelSetup( - bridge_class=TorchModelBridge, - model_class=ModularBoTorchModel, + bridge_class=TorchAdapter, + model_class=ModularBoTorchGenerator, transforms=MBM_MTGP_trans, default_model_kwargs={ "surrogate_spec": SurrogateSpec( @@ -253,7 +255,7 @@ class ModelSetup(NamedTuple): }, ), "Contextual_SACBO": ModelSetup( - bridge_class=TorchModelBridge, + bridge_class=TorchAdapter, model_class=SACBO, transforms=Cont_X_trans + Y_trans, ), @@ -262,17 +264,17 @@ class ModelSetup(NamedTuple): class ModelRegistryBase(Enum): """Base enum that provides instrumentation of `__call__` on enum values, - for enums that link their values to `ModelSetup`-s like `Models`. + for enums that link their values to `ModelSetup`-s like `Generators`. 
""" @property - def model_class(self) -> type[Model]: + def model_class(self) -> type[Generator]: """Type of `Model` used for the given model+bridge setup.""" return MODEL_KEY_TO_MODEL_SETUP[self.value].model_class @property - def model_bridge_class(self) -> type[ModelBridge]: - """Type of `ModelBridge` used for the given model+bridge setup.""" + def model_bridge_class(self) -> type[Adapter]: + """Type of `Adapter` used for the given model+bridge setup.""" return MODEL_KEY_TO_MODEL_SETUP[self.value].bridge_class def __call__( @@ -282,7 +284,7 @@ def __call__( data: Data | None = None, silently_filter_kwargs: bool = False, **kwargs: Any, - ) -> ModelBridge: + ) -> Adapter: assert self.value in MODEL_KEY_TO_MODEL_SETUP, f"Unknown model {self.value}" # All model bridges require either a search space or an experiment. assert search_space or experiment, "Search space or experiment required." @@ -330,7 +332,7 @@ def __call__( ) model = model_class(**model_kwargs) - # Create `ModelBridge`: defaults + standard kwargs + passed in kwargs. + # Create `Adapter`: defaults + standard kwargs + passed in kwargs. bridge_kwargs = consolidate_kwargs( kwargs_iterable=[ get_function_default_arguments(bridge_class), @@ -366,7 +368,7 @@ def __call__( def view_defaults(self) -> tuple[dict[str, Any], dict[str, Any]]: """Obtains the default keyword arguments for the model and the modelbridge - specified through the Models enum, for ease of use in notebook environment, + specified through the Generators enum, for ease of use in notebook environment, since models and bridges cannot be inspected directly through the enum. Returns: @@ -380,7 +382,7 @@ def view_defaults(self) -> tuple[dict[str, Any], dict[str, Any]]: def view_kwargs(self) -> tuple[dict[str, Any], dict[str, Any]]: """Obtains annotated keyword arguments that the model and the modelbridge - (corresponding to a given member of the Models enum) constructors expect. + (corresponding to a given member of the Generators enum) constructors expect. Returns: A tuple of annotated keyword arguments for the model and the model bridge. @@ -418,20 +420,20 @@ def _get_bridge_kwargs( ) -class Models(ModelRegistryBase): +class Generators(ModelRegistryBase): """Registry of available models. Uses MODEL_KEY_TO_MODEL_SETUP to retrieve settings for model and model bridge, by the key stored in the enum value. To instantiate a model in this enum, simply call an enum member like so: - `Models.SOBOL(search_space=search_space)` or - `Models.BOTORCH(experiment=experiment, data=data)`. Keyword arguments + `Generators.SOBOL(search_space=search_space)` or + `Generators.BOTORCH(experiment=experiment, data=data)`. Keyword arguments specified to the call will be passed into the model or the model bridge constructors according to their keyword. - For instance, `Models.SOBOL(search_space=search_space, scramble=False)` - will instantiate a `RandomModelBridge(search_space=search_space)` + For instance, `Generators.SOBOL(search_space=search_space, scramble=False)` + will instantiate a `RandomAdapter(search_space=search_space)` with a `SobolGenerator(scramble=False)` underlying model. NOTE: If you deprecate a model, please add its replacement to @@ -454,29 +456,44 @@ class Models(ModelRegistryBase): CONTEXT_SACBO = "Contextual_SACBO" +class ModelsMetaClass(type): + """Metaclass to override `__getattr__` for the Models class.""" + + def __getattr__(self, name: str) -> None: + raise DeprecationWarning( + "Models is deprecated, use `ax.modelbridge.registry.Generators` instead." 
+ ) + + +class Models(metaclass=ModelsMetaClass): + """This is deprecated. Use Generators instead.""" + + pass + + def get_model_from_generator_run( generator_run: GeneratorRun, experiment: Experiment, data: Data, models_enum: type[ModelRegistryBase], after_gen: bool = True, -) -> ModelBridge: +) -> Adapter: """Reinstantiate a model from model key and kwargs stored on a given generator run, with the given experiment and the data to initialize the model with. Note: requires that the model that was used to get the generator run, is part - of the `Models` registry enum. + of the `Generators` registry enum. Args: generator_run: A `GeneratorRun` created by the model we are looking to reinstantiate. experiment: The experiment for which the model is reinstantiated. data: Data, with which to reinstantiate the model. - models_enum: Subclass of `Models` registry, from which to obtain + models_enum: Subclass of `Generators` registry, from which to obtain the settings of the model. Useful only if the generator run was created via a model that could not be included into the main registry, but can still be represented as a `ModelSetup` and was added to a - registry that extends `Models`. + registry that extends `Generators`. after_gen: Whether to reinstantiate the model in the state, in which it was after it created this generator run, as opposed to before. Defaults to True, useful when reinstantiating the model to resume @@ -512,7 +529,7 @@ def get_model_from_generator_run( def _combine_model_kwargs_and_state( generator_run: GeneratorRun, - model_class: type[Model], + model_class: type[Generator], model_kwargs: dict[str, Any] | None = None, ) -> dict[str, Any]: """Produces a combined dict of model kwargs and model state after gen, @@ -534,7 +551,7 @@ def _combine_model_kwargs_and_state( def _extract_model_state_after_gen( - generator_run: GeneratorRun, model_class: type[Model] + generator_run: GeneratorRun, model_class: type[Generator] ) -> dict[str, Any]: """Extracts serialized post-generation model state from a generator run and deserializes it. 
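After the registry rename, `Generators` members remain callable and return a constructed `Adapter`, while any attribute access on the retained `Models` shim raises a `DeprecationWarning` via the metaclass. A minimal sketch of the intended behavior, assuming the Branin experiment stub from `ax.utils.testing.core_stubs`:

from ax.modelbridge.registry import Generators, Models
from ax.utils.testing.core_stubs import get_branin_experiment

experiment = get_branin_experiment()

# New entry point: the enum member call returns a RandomAdapter
# wrapping a SobolGenerator.
sobol = Generators.SOBOL(search_space=experiment.search_space)
generator_run = sobol.gen(n=5)

# Deprecated entry point: the metaclass intercepts attribute lookup and raises.
try:
    Models.SOBOL
except DeprecationWarning:
    pass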
diff --git a/ax/modelbridge/tests/test_aepsych_criterion.py b/ax/modelbridge/tests/test_aepsych_criterion.py index eda6a4c6632..0d7e9f584c9 100644 --- a/ax/modelbridge/tests/test_aepsych_criterion.py +++ b/ax/modelbridge/tests/test_aepsych_criterion.py @@ -11,7 +11,7 @@ from ax.core.base_trial import TrialStatus from ax.core.data import Data from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.modelbridge.transition_criterion import MinimumPreferenceOccurances, MinTrials from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import get_experiment @@ -34,10 +34,12 @@ def test_single_criterion(self) -> None: name="SOBOL+MBM::default", steps=[ GenerationStep( - model=Models.SOBOL, num_trials=-1, completion_criteria=[criterion] + model=Generators.SOBOL, + num_trials=-1, + completion_criteria=[criterion], ), GenerationStep( - model=Models.BOTORCH_MODULAR, + model=Generators.BOTORCH_MODULAR, num_trials=-1, max_parallelism=1, ), @@ -78,7 +80,7 @@ def test_single_criterion(self) -> None: self.assertEqual( generation_strategy._curr.model_spec_to_gen_from.model_enum, - Models.BOTORCH_MODULAR, + Generators.BOTORCH_MODULAR, ) def test_many_criteria(self) -> None: @@ -93,10 +95,10 @@ def test_many_criteria(self) -> None: name="SOBOL+MBM::default", steps=[ GenerationStep( - model=Models.SOBOL, num_trials=-1, completion_criteria=criteria + model=Generators.SOBOL, num_trials=-1, completion_criteria=criteria ), GenerationStep( - model=Models.BOTORCH_MODULAR, + model=Generators.BOTORCH_MODULAR, num_trials=-1, max_parallelism=1, ), @@ -155,5 +157,5 @@ def test_many_criteria(self) -> None: self.assertEqual( generation_strategy._curr.model_spec_to_gen_from.model_enum, - Models.BOTORCH_MODULAR, + Generators.BOTORCH_MODULAR, ) diff --git a/ax/modelbridge/tests/test_base_modelbridge.py b/ax/modelbridge/tests/test_base_modelbridge.py index 1bddf21fa41..6cc13a12c4a 100644 --- a/ax/modelbridge/tests/test_base_modelbridge.py +++ b/ax/modelbridge/tests/test_base_modelbridge.py @@ -23,16 +23,16 @@ from ax.core.search_space import SearchSpace from ax.exceptions.core import UnsupportedError, UserInputError from ax.modelbridge.base import ( + Adapter, clamp_observation_features, gen_arms, GenResults, - ModelBridge, unwrap_observation_data, ) from ax.modelbridge.factory import get_sobol -from ax.modelbridge.registry import Models, Y_trans +from ax.modelbridge.registry import Generators, Y_trans from ax.modelbridge.transforms.fill_missing_parameters import FillMissingParameters -from ax.models.base import Model +from ax.models.base import Generator from ax.utils.common.constants import Keys from ax.utils.common.logger import get_logger from ax.utils.common.testutils import TestCase @@ -64,7 +64,7 @@ from pyre_extensions import none_throws -class BaseModelBridgeTest(TestCase): +class BaseAdapterTest(TestCase): @mock.patch( "ax.modelbridge.base.observations_from_data", autospec=True, @@ -75,17 +75,17 @@ class BaseModelBridgeTest(TestCase): autospec=True, return_value=([Arm(parameters={})], None), ) - @mock.patch("ax.modelbridge.base.ModelBridge._fit", autospec=True) - def test_ModelBridge( + @mock.patch("ax.modelbridge.base.Adapter._fit", autospec=True) + def test_Adapter( self, mock_fit: Mock, mock_gen_arms: Mock, mock_observations_from_data: Mock ) -> None: # Test that on init transforms are stored and applied in the correct order transforms = [transform_1, transform_2] 
exp = get_experiment_for_value() ss = get_search_space_for_value() - modelbridge = ModelBridge( + modelbridge = Adapter( search_space=ss, - model=Model(), + model=Generator(), # pyre-fixme[6]: For 3rd param expected # `Optional[List[Type[Transform]]]` but got `List[Type[Union[transform_1, # transform_2]]]`. @@ -106,7 +106,7 @@ def test_ModelBridge( self.assertTrue(mock_observations_from_data.called) # Test deprecation error on update. - with self.assertRaisesRegex(DeprecationWarning, "ModelBridge.update"): + with self.assertRaisesRegex(DeprecationWarning, "Adapter.update"): modelbridge.update(Mock(), Mock()) # Test prediction with arms. @@ -118,13 +118,13 @@ def test_ModelBridge( # Test prediction on out of design features. modelbridge._predict = mock.MagicMock( - "ax.modelbridge.base.ModelBridge._predict", + "ax.modelbridge.base.Adapter._predict", autospec=True, side_effect=ValueError("Out of Design"), ) # This point is in design, and thus failures in predict are legitimate. with mock.patch.object( - ModelBridge, "model_space", return_value=get_search_space_for_range_values + Adapter, "model_space", return_value=get_search_space_for_range_values ): with self.assertRaises(ValueError): modelbridge.predict([get_observation2().features]) @@ -135,7 +135,7 @@ def test_ModelBridge( # Now it's in the training data. with mock.patch.object( - ModelBridge, + Adapter, "get_training_data", return_value=[get_observation_status_quo0()], ): @@ -147,7 +147,7 @@ def test_ModelBridge( # Test that transforms are applied correctly on predict modelbridge._predict = mock.MagicMock( - "ax.modelbridge.base.ModelBridge._predict", + "ax.modelbridge.base.Adapter._predict", autospec=True, return_value=[get_observation2trans().data], ) @@ -163,7 +163,7 @@ def test_ModelBridge( # Test transforms applied on gen modelbridge._gen = mock.MagicMock( - "ax.modelbridge.base.ModelBridge._gen", + "ax.modelbridge.base.Adapter._gen", autospec=True, return_value=GenResults( observation_features=[get_observation1trans().features], weights=[2] @@ -244,7 +244,7 @@ def warn_and_return_mock_obs( return [get_observation1trans().data] modelbridge._cross_validate = mock.MagicMock( - "ax.modelbridge.base.ModelBridge._cross_validate", + "ax.modelbridge.base.Adapter._cross_validate", autospec=True, side_effect=warn_and_return_mock_obs, ) @@ -316,7 +316,7 @@ def warn_and_return_mock_obs( # Test transform observation features with mock.patch( - "ax.modelbridge.base.ModelBridge._transform_observation_features", + "ax.modelbridge.base.Adapter._transform_observation_features", autospec=True, ) as mock_tr: modelbridge.transform_observation_features([get_observation2().features]) @@ -324,9 +324,9 @@ def warn_and_return_mock_obs( # Test that fit is not called when fit_on_init = False. mock_fit.reset_mock() - modelbridge = ModelBridge( + modelbridge = Adapter( search_space=ss, - model=Model(), + model=Generator(), fit_on_init=False, ) self.assertEqual(mock_fit.call_count, 0) @@ -334,17 +334,17 @@ def warn_and_return_mock_obs( # Test error when fit_tracking_metrics is False and optimization # config is not specified. with self.assertRaisesRegex(UserInputError, "fit_tracking_metrics"): - ModelBridge( + Adapter( search_space=ss, - model=Model(), + model=Generator(), fit_tracking_metrics=False, ) # Test error when fit_tracking_metrics is False and optimization # config is updated to include new metrics. 
- modelbridge = ModelBridge( + modelbridge = Adapter( search_space=ss, - model=Model(), + model=Generator(), optimization_config=oc, fit_tracking_metrics=False, ) @@ -364,18 +364,18 @@ def warn_and_return_mock_obs( autospec=True, return_value=([Arm(parameters={})], None), ) - @mock.patch("ax.modelbridge.base.ModelBridge._fit", autospec=True) + @mock.patch("ax.modelbridge.base.Adapter._fit", autospec=True) def test_repeat_candidates( self, mock_fit: Mock, mock_gen_arms: Mock, mock_observations_from_data: Mock ) -> None: - modelbridge = ModelBridge( + modelbridge = Adapter( search_space=get_search_space_for_value(), - model=Model(), + model=Generator(), experiment=get_experiment_for_value(), ) # mock _gen to return 1 result modelbridge._gen = mock.MagicMock( - "ax.modelbridge.base.ModelBridge._gen", + "ax.modelbridge.base.Adapter._gen", autospec=True, return_value=GenResults( observation_features=[get_observation1trans().features], weights=[2] @@ -414,7 +414,7 @@ def test_repeat_candidates( autospec=True, return_value=([Arm(parameters={"x1": 0.0, "x2": 0.0})], None), ) - @mock.patch("ax.modelbridge.base.ModelBridge._fit", autospec=True) + @mock.patch("ax.modelbridge.base.Adapter._fit", autospec=True) def test_with_status_quo(self, mock_fit: Mock, mock_gen_arms: Mock) -> None: # Test init with a status quo. exp = get_branin_experiment( @@ -422,9 +422,9 @@ def test_with_status_quo(self, mock_fit: Mock, mock_gen_arms: Mock) -> None: with_status_quo=True, with_completed_trial=True, ) - modelbridge = ModelBridge( + modelbridge = Adapter( search_space=exp.search_space, - model=Model(), + model=Generator(), transforms=Y_trans, experiment=exp, data=exp.lookup_data(), @@ -434,12 +434,12 @@ def test_with_status_quo(self, mock_fit: Mock, mock_gen_arms: Mock) -> None: modelbridge.status_quo.features.parameters, {"x1": 0.0, "x2": 0.0} ) - @mock.patch("ax.modelbridge.base.ModelBridge._fit", autospec=True) - @mock.patch("ax.modelbridge.base.ModelBridge._gen", autospec=True) + @mock.patch("ax.modelbridge.base.Adapter._fit", autospec=True) + @mock.patch("ax.modelbridge.base.Adapter._gen", autospec=True) def test_timing(self, mock_fit: Mock, mock_gen: Mock) -> None: search_space = get_search_space_for_value() - modelbridge = ModelBridge( - search_space=search_space, model=Model(), fit_on_init=False + modelbridge = Adapter( + search_space=search_space, model=Generator(), fit_on_init=False ) self.assertEqual(modelbridge.fit_time, 0.0) modelbridge._fit_if_implemented( @@ -466,9 +466,9 @@ def test_ood_gen(self, _) -> None: # Test fit_out_of_design by returning OOD candidats exp = get_experiment_for_value() ss = SearchSpace([RangeParameter("x", ParameterType.FLOAT, 0.0, 1.0)]) - modelbridge = ModelBridge( + modelbridge = Adapter( search_space=ss, - model=Model(), + model=Generator(), transforms=[], experiment=exp, # pyre-fixme[6]: For 5th param expected `Optional[Data]` but got `int`. 
@@ -477,7 +477,7 @@ def test_ood_gen(self, _) -> None: ) obs = ObservationFeatures(parameters={"x": 3.0}) modelbridge._gen = mock.MagicMock( - "ax.modelbridge.base.ModelBridge._gen", + "ax.modelbridge.base.Adapter._gen", autospec=True, return_value=GenResults(observation_features=[obs], weights=[2]), ) @@ -485,9 +485,9 @@ def test_ood_gen(self, _) -> None: self.assertEqual(gr.arms[0].parameters, obs.parameters) # Test clamping arms by setting fit_out_of_design=False - modelbridge = ModelBridge( + modelbridge = Adapter( search_space=ss, - model=Model(), + model=Generator(), transforms=[], experiment=exp, # pyre-fixme[6]: For 5th param expected `Optional[Data]` but got `int`. @@ -496,7 +496,7 @@ def test_ood_gen(self, _) -> None: ) obs = ObservationFeatures(parameters={"x": 3.0}) modelbridge._gen = mock.MagicMock( - "ax.modelbridge.base.ModelBridge._gen", + "ax.modelbridge.base.Adapter._gen", autospec=True, return_value=GenResults(observation_features=[obs], weights=[2]), ) @@ -508,13 +508,13 @@ def test_ood_gen(self, _) -> None: autospec=True, return_value=([get_observation1()]), ) - @mock.patch("ax.modelbridge.base.ModelBridge._fit", autospec=True) + @mock.patch("ax.modelbridge.base.Adapter._fit", autospec=True) # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. def test_SetStatusQuo(self, mock_fit, mock_observations_from_data): # NOTE: If empty data object is not passed, observations are not # extracted, even with mock. - modelbridge = ModelBridge( + modelbridge = Adapter( search_space=get_search_space_for_value(), model=0, experiment=get_experiment_for_value(), @@ -525,7 +525,7 @@ def test_SetStatusQuo(self, mock_fit, mock_observations_from_data): self.assertEqual(modelbridge.status_quo_name, "1_1") # Alternatively, we can specify by features - modelbridge = ModelBridge( + modelbridge = Adapter( get_search_space_for_value(), 0, [], @@ -544,13 +544,13 @@ def test_SetStatusQuo(self, mock_fit, mock_observations_from_data): exp._status_quo = sq # Check that we set SQ to arm 1_1 # pyre-fixme[6]: For 5th param expected `Optional[Data]` but got `int`. - modelbridge = ModelBridge(get_search_space_for_value(), 0, [], exp, 0) + modelbridge = Adapter(get_search_space_for_value(), 0, [], exp, 0) self.assertEqual(modelbridge.status_quo, get_observation1()) self.assertEqual(modelbridge.status_quo_name, "1_1") # Errors if features and name both specified with self.assertRaises(ValueError): - modelbridge = ModelBridge( + modelbridge = Adapter( get_search_space_for_value(), 0, [], @@ -562,7 +562,7 @@ def test_SetStatusQuo(self, mock_fit, mock_observations_from_data): ) # Left as None if features or name don't exist - modelbridge = ModelBridge( + modelbridge = Adapter( get_search_space_for_value(), 0, [], @@ -573,7 +573,7 @@ def test_SetStatusQuo(self, mock_fit, mock_observations_from_data): ) self.assertIsNone(modelbridge.status_quo) self.assertIsNone(modelbridge.status_quo_name) - modelbridge = ModelBridge( + modelbridge = Adapter( get_search_space_for_value(), 0, [], @@ -585,7 +585,7 @@ def test_SetStatusQuo(self, mock_fit, mock_observations_from_data): self.assertIsNone(modelbridge.status_quo) @mock.patch( - "ax.modelbridge.base.ModelBridge._gen", + "ax.modelbridge.base.Adapter._gen", autospec=True, ) # pyre-fixme[3]: Return type must be annotated. 
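The test updates above track the rename in their `mock.patch` targets and constructor calls. A minimal sketch of the renamed pattern, assuming the Branin experiment stub; patching out `_fit` mirrors how these tests avoid real model fitting:

from unittest import mock

from ax.modelbridge.base import Adapter
from ax.models.base import Generator
from ax.utils.testing.core_stubs import get_branin_experiment

experiment = get_branin_experiment()

# The patch target follows the new class path;
# "ax.modelbridge.base.ModelBridge._fit" no longer resolves after the rename.
with mock.patch("ax.modelbridge.base.Adapter._fit", autospec=True):
    adapter = Adapter(
        search_space=experiment.search_space,
        model=Generator(),
        experiment=experiment,
    )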
@@ -601,7 +601,7 @@ def test_status_quo_for_non_monolithic_data(self, mock_gen): weights=[1] * 5, ) exp = get_branin_experiment_with_multi_objective(with_status_quo=True) - sobol = Models.SOBOL(search_space=exp.search_space) + sobol = Generators.SOBOL(search_space=exp.search_space) exp.new_batch_trial(sobol.gen(5)).set_status_quo_and_optimize_power( status_quo=exp.status_quo ).run() @@ -609,10 +609,10 @@ def test_status_quo_for_non_monolithic_data(self, mock_gen): # create data where metrics vary in start and end times data = get_non_monolithic_branin_moo_data() with warnings.catch_warnings(record=True) as ws: - bridge = ModelBridge( + bridge = Adapter( experiment=exp, data=data, - model=Model(), + model=Generator(), search_space=exp.search_space, ) # just testing it doesn't error @@ -634,7 +634,7 @@ def test_status_quo_for_non_monolithic_data(self, mock_gen): ] ), ) - @mock.patch("ax.modelbridge.base.ModelBridge._fit", autospec=True) + @mock.patch("ax.modelbridge.base.Adapter._fit", autospec=True) # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated. def test_SetStatusQuoMultipleObs(self, mock_fit, mock_observations_from_data): @@ -646,7 +646,7 @@ def test_SetStatusQuoMultipleObs(self, mock_fit, mock_observations_from_data): parameters=exp.trials[trial_index].status_quo.parameters, trial_index=trial_index, ) - modelbridge = ModelBridge( + modelbridge = Adapter( get_search_space_for_value(), 0, [], @@ -665,7 +665,7 @@ def test_transform_observations(self) -> None: This functionality is unused, even in the subclass where it is implemented. """ ss = get_search_space_for_value() - modelbridge = ModelBridge(search_space=ss, model=Model()) + modelbridge = Adapter(search_space=ss, model=Generator()) with self.assertRaises(NotImplementedError): modelbridge.transform_observations([]) with self.assertRaises(NotImplementedError): @@ -676,13 +676,13 @@ def test_transform_observations(self) -> None: autospec=True, return_value=([get_observation1(), get_observation1()]), ) - @mock.patch("ax.modelbridge.base.ModelBridge._fit", autospec=True) + @mock.patch("ax.modelbridge.base.Adapter._fit", autospec=True) def test_SetTrainingDataDupFeatures( self, mock_fit: Mock, mock_observations_from_data: Mock ) -> None: # Throws an error if repeated features in observations. 
with self.assertRaises(ValueError): - ModelBridge( + Adapter( get_search_space_for_value(), 0, [], @@ -741,21 +741,19 @@ def test_GenArms(self) -> None: ) @mock.patch( - "ax.modelbridge.base.ModelBridge._gen", + "ax.modelbridge.base.Adapter._gen", autospec=True, return_value=GenResults( observation_features=[get_observation1trans().features], weights=[2] ), ) - @mock.patch( - "ax.modelbridge.base.ModelBridge.predict", autospec=True, return_value=None - ) + @mock.patch("ax.modelbridge.base.Adapter.predict", autospec=True, return_value=None) def test_GenWithDefaults(self, _, mock_gen: Mock) -> None: exp = get_experiment_for_value() exp.optimization_config = get_optimization_config_no_constraints() ss = get_search_space_for_range_value() - modelbridge = ModelBridge( - search_space=ss, model=Model(), transforms=[], experiment=exp + modelbridge = Adapter( + search_space=ss, model=Generator(), transforms=[], experiment=exp ) modelbridge.gen(1) mock_gen.assert_called_with( @@ -772,23 +770,21 @@ def test_GenWithDefaults(self, _, mock_gen: Mock) -> None: ) @mock.patch( - "ax.modelbridge.base.ModelBridge._gen", + "ax.modelbridge.base.Adapter._gen", autospec=True, return_value=GenResults( observation_features=[get_observation1trans().features], weights=[2] ), ) - @mock.patch( - "ax.modelbridge.base.ModelBridge.predict", autospec=True, return_value=None - ) + @mock.patch("ax.modelbridge.base.Adapter.predict", autospec=True, return_value=None) # pyre-fixme[3]: Return type must be annotated. def test_gen_on_experiment_with_imm_ss_and_opt_conf(self, _, __): exp = get_experiment_for_value() exp._properties[Keys.IMMUTABLE_SEARCH_SPACE_AND_OPT_CONF] = True exp.optimization_config = get_optimization_config_no_constraints() ss = get_search_space_for_range_value() - modelbridge = ModelBridge( - search_space=ss, model=Model(), transforms=[], experiment=exp + modelbridge = Adapter( + search_space=ss, model=Generator(), transforms=[], experiment=exp ) self.assertTrue( modelbridge._experiment_has_immutable_search_space_and_opt_config @@ -805,10 +801,10 @@ def test_set_status_quo(self) -> None: num_batch_trial=1, with_completed_batch=True, ) - modelbridge = ModelBridge( + modelbridge = Adapter( search_space=exp.search_space, experiment=exp, - model=Model, + model=Generator, data=exp.lookup_data(), ) @@ -827,10 +823,10 @@ def test_set_status_quo(self) -> None: num_batch_trial=2, with_completed_batch=True, ) - modelbridge = ModelBridge( + modelbridge = Adapter( search_space=exp.search_space, experiment=exp, - model=Model, + model=Generator, data=exp.lookup_data(), ) # we are able to set status_quo_data_by_trial when multiple @@ -847,10 +843,10 @@ def test_set_status_quo(self) -> None: parameters=none_throws(exp.status_quo).parameters, trial_index=0, ) - modelbridge = ModelBridge( + modelbridge = Adapter( search_space=exp.search_space, experiment=exp, - model=Model, + model=Generator, data=exp.lookup_data(), status_quo_features=status_quo_features, ) @@ -906,7 +902,7 @@ def test_ClampObservationFeaturesNearBounds(self) -> None: actual_obs_ft = clamp_observation_features([obs_ft], search_space) self.assertEqual(actual_obs_ft[0], expected_obs_ft) - @mock.patch("ax.modelbridge.base.ModelBridge._fit", autospec=True) + @mock.patch("ax.modelbridge.base.Adapter._fit", autospec=True) def test_FillMissingParameters(self, mock_fit: Mock) -> None: # Create experiment with arms from two search spaces ss1 = SearchSpace( @@ -948,7 +944,7 @@ def test_FillMissingParameters(self, mock_fit: Mock) -> None: 
get_branin_data_batch(batch=trial, fill_vals=sq_vals) ) # Fit model without filling missing parameters - m = ModelBridge( + m = Adapter( search_space=ss1, model=None, experiment=experiment, @@ -965,7 +961,7 @@ def test_FillMissingParameters(self, mock_fit: Mock) -> None: set(ood_arms), {"status_quo", "1_0", "1_1", "1_2", "1_3", "1_4"} ) # Fit with filling missing parameters - m = ModelBridge( + m = Adapter( search_space=ss2, model=None, experiment=experiment, @@ -1007,7 +1003,7 @@ def test_SetModelSpace(self) -> None: data = experiment.lookup_data() # Check that SQ and custom are OOD - m = ModelBridge( + m = Adapter( search_space=experiment.search_space, model=None, experiment=experiment, @@ -1021,7 +1017,7 @@ def test_SetModelSpace(self) -> None: self.assertEqual(m.model_space.parameters["x2"].upper, 15.0) # pyre-ignore[16] # With expand model space, custom is not OOD, and model space is expanded - m = ModelBridge( + m = Adapter( search_space=experiment.search_space, model=None, experiment=experiment, @@ -1034,7 +1030,7 @@ def test_SetModelSpace(self) -> None: self.assertEqual(m.model_space.parameters["x2"].upper, 18.0) # With fill values, SQ is also in design, and x2 is further expanded - m = ModelBridge( + m = Adapter( search_space=experiment.search_space, model=None, experiment=experiment, diff --git a/ax/modelbridge/tests/test_best_model_selector.py b/ax/modelbridge/tests/test_best_model_selector.py index 6359996f18b..eac4ddf5ff1 100644 --- a/ax/modelbridge/tests/test_best_model_selector.py +++ b/ax/modelbridge/tests/test_best_model_selector.py @@ -13,8 +13,8 @@ ReductionCriterion, SingleDiagnosticBestModelSelector, ) -from ax.modelbridge.model_spec import ModelSpec -from ax.modelbridge.registry import Models +from ax.modelbridge.model_spec import GeneratorSpec +from ax.modelbridge.registry import Generators from ax.utils.common.testutils import TestCase @@ -30,7 +30,7 @@ def setUp(self) -> None: {"Fisher exact test p": {"y_a": 0.5, "y_b": 0.6}}, ] for diagnostics in self.diagnostics: - ms = ModelSpec(model_enum=Models.BOTORCH_MODULAR) + ms = GeneratorSpec(model_enum=Generators.BOTORCH_MODULAR) ms._cv_results = Mock() ms._diagnostics = diagnostics ms._last_cv_kwargs = {} diff --git a/ax/modelbridge/tests/test_cross_validation.py b/ax/modelbridge/tests/test_cross_validation.py index 4887ec82099..05bb2e58f47 100644 --- a/ax/modelbridge/tests/test_cross_validation.py +++ b/ax/modelbridge/tests/test_cross_validation.py @@ -30,7 +30,7 @@ CVResult, has_good_opt_config_model_fit, ) -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import get_branin_experiment from ax.utils.testing.mock import mock_botorch_optimize @@ -105,22 +105,22 @@ def test_CrossValidate(self) -> None: # Prepare input and output data ma = mock.MagicMock() ma.get_training_data = mock.MagicMock( - "ax.modelbridge.base.ModelBridge.get_training_data", + "ax.modelbridge.base.Adapter.get_training_data", autospec=True, return_value=self.training_data, ) ma.cross_validate = mock.MagicMock( - "ax.modelbridge.base.ModelBridge.cross_validate", + "ax.modelbridge.base.Adapter.cross_validate", autospec=True, return_value=self.observation_data, ) ma._transform_inputs_for_cv = mock.MagicMock( - "ax.modelbridge.base.ModelBridge._transform_inputs_for_cv", + "ax.modelbridge.base.Adapter._transform_inputs_for_cv", autospec=True, return_value=list(self.transformed_cv_input_dict.values()), ) ma._cross_validate = 
mock.MagicMock( - "ax.modelbridge.base.ModelBridge._cross_validate", + "ax.modelbridge.base.Adapter._cross_validate", autospec=True, return_value=self.observation_data_transformed_result, ) @@ -132,7 +132,7 @@ def test_CrossValidate(self) -> None: # First 2-fold result = cross_validate(model=ma, folds=2) self.assertEqual(len(result), 4) - # Check that ModelBridge.cross_validate was called correctly. + # Check that Adapter.cross_validate was called correctly. z = ma.cross_validate.mock_calls self.assertEqual(len(z), 2) train = [ @@ -175,7 +175,7 @@ def test_CrossValidate(self) -> None: self.assertEqual( result_predicted_obs_data, self.observation_data_transformed_result ) - # Check that ModelBridge._transform_inputs_for_cv was called correctly. + # Check that Adapter._transform_inputs_for_cv was called correctly. z = ma._transform_inputs_for_cv.mock_calls self.assertEqual(len(z), 3) train = [ @@ -192,7 +192,7 @@ def test_CrossValidate(self) -> None: self.assertTrue( np.array_equal(sorted(all_test), np.array([2.0, 2.0, 3.0, 4.0])) ) - # Test ModelBridge._cross_validate was called correctly. + # Test Adapter._cross_validate was called correctly. z = ma._cross_validate.mock_calls self.assertEqual(len(z), 3) ma._cross_validate.assert_called_with( @@ -232,7 +232,7 @@ def test_CrossValidateByTrial(self) -> None: # With only 1 trial ma = mock.MagicMock() ma.get_training_data = mock.MagicMock( - "ax.modelbridge.base.ModelBridge.get_training_data", + "ax.modelbridge.base.Adapter.get_training_data", autospec=True, return_value=self.training_data[1:3], ) @@ -241,12 +241,12 @@ def test_CrossValidateByTrial(self) -> None: # Prepare input and output data ma = mock.MagicMock() ma.get_training_data = mock.MagicMock( - "ax.modelbridge.base.ModelBridge.get_training_data", + "ax.modelbridge.base.Adapter.get_training_data", autospec=True, return_value=self.training_data, ) ma.cross_validate = mock.MagicMock( - "ax.modelbridge.base.ModelBridge.cross_validate", + "ax.modelbridge.base.Adapter.cross_validate", autospec=True, return_value=self.observation_data, ) @@ -258,7 +258,7 @@ def test_CrossValidateByTrial(self) -> None: result = cross_validate_by_trial(model=ma) self.assertEqual(len(result), 1) - # Check that ModelBridge.cross_validate was called correctly. + # Check that Adapter.cross_validate was called correctly. 
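# Each mock_calls entry is a (name, args, kwargs) tuple, so z[0][2] below reads the kwargs of the first recorded call.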
z = ma.cross_validate.mock_calls self.assertEqual(len(z), 1) train_trials = [obs.features.trial_index for obs in z[0][2]["cv_training_data"]] @@ -281,14 +281,14 @@ def test_CrossValidateByTrial(self) -> None: def test_cross_validate_gives_a_useful_error_for_model_with_no_data(self) -> None: exp = get_branin_experiment() - sobol = Models.SOBOL(experiment=exp, search_space=exp.search_space) + sobol = Generators.SOBOL(experiment=exp, search_space=exp.search_space) with self.assertRaisesRegex(ValueError, "no training data"): cross_validate(model=sobol) @mock_botorch_optimize def test_cross_validate_catches_warnings(self) -> None: exp = get_branin_experiment(with_batch=True, with_completed_batch=True) - model = Models.BOTORCH_MODULAR( + model = Generators.BOTORCH_MODULAR( experiment=exp, search_space=exp.search_space, data=exp.fetch_data() ) for untransform in [False, True]: @@ -301,7 +301,7 @@ def test_cross_validate_raises_not_implemented_error_for_non_cv_model_with_data( ) -> None: exp = get_branin_experiment(with_batch=True) exp.trials[0].run().complete() - sobol = Models.SOBOL( + sobol = Generators.SOBOL( experiment=exp, search_space=exp.search_space, data=exp.fetch_data() ) with self.assertRaises(NotImplementedError): diff --git a/ax/modelbridge/tests/test_discrete_modelbridge.py b/ax/modelbridge/tests/test_discrete_modelbridge.py index bb8b5250c5a..b8bb01ad2ad 100644 --- a/ax/modelbridge/tests/test_discrete_modelbridge.py +++ b/ax/modelbridge/tests/test_discrete_modelbridge.py @@ -23,12 +23,12 @@ ) from ax.core.search_space import SearchSpace from ax.exceptions.core import UserInputError -from ax.modelbridge.discrete import _get_parameter_values, DiscreteModelBridge -from ax.models.discrete_base import DiscreteModel +from ax.modelbridge.discrete import _get_parameter_values, DiscreteAdapter +from ax.models.discrete_base import DiscreteGenerator from ax.utils.common.testutils import TestCase -class DiscreteModelBridgeTest(TestCase): +class DiscreteAdapterTest(TestCase): def setUp(self) -> None: super().setUp() self.parameters = [ @@ -73,14 +73,12 @@ def setUp(self) -> None: } self.model_gen_options = {"option": "yes"} - @mock.patch( - "ax.modelbridge.discrete.DiscreteModelBridge.__init__", return_value=None - ) + @mock.patch("ax.modelbridge.discrete.DiscreteAdapter.__init__", return_value=None) def test_fit(self, mock_init: Mock) -> None: # pyre-fixme[20]: Argument `model` expected. - ma = DiscreteModelBridge() + ma = DiscreteAdapter() ma._training_data = self.observations - model = mock.create_autospec(DiscreteModel, instance=True) + model = mock.create_autospec(DiscreteGenerator, instance=True) ma._fit(model, self.search_space, self.observations) self.assertEqual(ma.parameters, ["x", "y", "z"]) self.assertEqual(sorted(ma.outcomes), ["a", "b"]) @@ -106,13 +104,11 @@ def test_fit(self, mock_init: Mock) -> None: with self.assertRaises(ValueError): ma._fit(model, self.search_space, self.observations + [sq_obs]) - @mock.patch( - "ax.modelbridge.discrete.DiscreteModelBridge.__init__", return_value=None - ) + @mock.patch("ax.modelbridge.discrete.DiscreteAdapter.__init__", return_value=None) def test_predict(self, mock_init: Mock) -> None: # pyre-fixme[20]: Argument `model` expected. 
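# __init__ is patched to a no-op above, so the adapter can be constructed without its required `model` argument.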
- ma = DiscreteModelBridge() - model = mock.MagicMock(DiscreteModel, autospec=True, instance=True) + ma = DiscreteAdapter() + model = mock.MagicMock(DiscreteGenerator, autospec=True, instance=True) model.predict.return_value = ( np.array([[1.0, -1], [2.0, -2]]), np.stack( @@ -128,9 +124,7 @@ def test_predict(self, mock_init: Mock) -> None: for i, od in enumerate(observation_data): self.assertEqual(od, self.observation_data[i]) - @mock.patch( - "ax.modelbridge.discrete.DiscreteModelBridge.__init__", return_value=None - ) + @mock.patch("ax.modelbridge.discrete.DiscreteAdapter.__init__", return_value=None) def test_gen(self, mock_init: Mock) -> None: # Test with constraints optimization_config = OptimizationConfig( @@ -140,13 +134,13 @@ def test_gen(self, mock_init: Mock) -> None: ], ) # pyre-fixme[20]: Argument `model` expected. - ma = DiscreteModelBridge() + ma = DiscreteAdapter() # Test validation. with self.assertRaisesRegex(UserInputError, "positive integer or -1."): ma._validate_gen_inputs(n=0) ma._validate_gen_inputs(n=-1) # Test rest of gen. - model = mock.MagicMock(DiscreteModel, autospec=True, instance=True) + model = mock.MagicMock(DiscreteGenerator, autospec=True, instance=True) best_x = [0.0, 2.0, 1.0] model.gen.return_value = ( [[0.0, 2.0, 3.0], [1.0, 1.0, 3.0]], @@ -232,13 +226,11 @@ def test_gen(self, mock_init: Mock) -> None: model_gen_options={}, ) - @mock.patch( - "ax.modelbridge.discrete.DiscreteModelBridge.__init__", return_value=None - ) + @mock.patch("ax.modelbridge.discrete.DiscreteAdapter.__init__", return_value=None) def test_cross_validate(self, mock_init: Mock) -> None: # pyre-fixme[20]: Argument `model` expected. - ma = DiscreteModelBridge() - model = mock.MagicMock(DiscreteModel, autospec=True, instance=True) + ma = DiscreteAdapter() + model = mock.MagicMock(DiscreteGenerator, autospec=True, instance=True) model.cross_validate.return_value = ( np.array([[1.0, -1], [2.0, -2]]), np.stack( diff --git a/ax/modelbridge/tests/test_dispatch_utils.py b/ax/modelbridge/tests/test_dispatch_utils.py index 2e6f8226927..4792ea96004 100644 --- a/ax/modelbridge/tests/test_dispatch_utils.py +++ b/ax/modelbridge/tests/test_dispatch_utils.py @@ -19,7 +19,7 @@ choose_generation_strategy, DEFAULT_BAYESIAN_PARALLELISM, ) -from ax.modelbridge.registry import MBM_X_trans, Mixed_transforms, Models, Y_trans +from ax.modelbridge.registry import Generators, MBM_X_trans, Mixed_transforms, Y_trans from ax.modelbridge.transforms.log_y import LogY from ax.modelbridge.transforms.winsorize import Winsorize from ax.models.winsorization_config import WinsorizationConfig @@ -52,9 +52,9 @@ def test_choose_generation_strategy(self) -> None: sobol_gpei = choose_generation_strategy( search_space=get_branin_search_space() ) - self.assertEqual(sobol_gpei._steps[0].model, Models.SOBOL) + self.assertEqual(sobol_gpei._steps[0].model, Generators.SOBOL) self.assertEqual(sobol_gpei._steps[0].num_trials, 5) - self.assertEqual(sobol_gpei._steps[1].model, Models.BOTORCH_MODULAR) + self.assertEqual(sobol_gpei._steps[1].model, Generators.BOTORCH_MODULAR) expected_model_kwargs: dict[str, Any] = { "torch_device": None, "transforms": expected_transforms, @@ -78,35 +78,35 @@ def test_choose_generation_strategy(self) -> None: search_space=get_branin_search_space(), max_initialization_trials=2, ) - self.assertEqual(sobol_gpei._steps[0].model, Models.SOBOL) + self.assertEqual(sobol_gpei._steps[0].model, Generators.SOBOL) self.assertEqual(sobol_gpei._steps[0].num_trials, 2) - self.assertEqual(sobol_gpei._steps[1].model, 
Models.BOTORCH_MODULAR) + self.assertEqual(sobol_gpei._steps[1].model, Generators.BOTORCH_MODULAR) with self.subTest("min sobol trials"): sobol_gpei = choose_generation_strategy( search_space=get_branin_search_space(), min_sobol_trials_observed=1, ) - self.assertEqual(sobol_gpei._steps[0].model, Models.SOBOL) + self.assertEqual(sobol_gpei._steps[0].model, Generators.SOBOL) self.assertEqual(sobol_gpei._steps[0].min_trials_observed, 1) - self.assertEqual(sobol_gpei._steps[1].model, Models.BOTORCH_MODULAR) + self.assertEqual(sobol_gpei._steps[1].model, Generators.BOTORCH_MODULAR) with self.subTest("num_initialization_trials > max_initialization_trials"): sobol_gpei = choose_generation_strategy( search_space=get_branin_search_space(), max_initialization_trials=2, num_initialization_trials=3, ) - self.assertEqual(sobol_gpei._steps[0].model, Models.SOBOL) + self.assertEqual(sobol_gpei._steps[0].model, Generators.SOBOL) self.assertEqual(sobol_gpei._steps[0].num_trials, 3) - self.assertEqual(sobol_gpei._steps[1].model, Models.BOTORCH_MODULAR) + self.assertEqual(sobol_gpei._steps[1].model, Generators.BOTORCH_MODULAR) with self.subTest("num_initialization_trials > max_initialization_trials"): sobol_gpei = choose_generation_strategy( search_space=get_branin_search_space(), max_initialization_trials=2, num_initialization_trials=3, ) - self.assertEqual(sobol_gpei._steps[0].model, Models.SOBOL) + self.assertEqual(sobol_gpei._steps[0].model, Generators.SOBOL) self.assertEqual(sobol_gpei._steps[0].num_trials, 3) - self.assertEqual(sobol_gpei._steps[1].model, Models.BOTORCH_MODULAR) + self.assertEqual(sobol_gpei._steps[1].model, Generators.BOTORCH_MODULAR) with self.subTest("MOO"): optimization_config = MultiObjectiveOptimizationConfig( objective=MultiObjective(objectives=[]) @@ -115,9 +115,9 @@ def test_choose_generation_strategy(self) -> None: search_space=get_branin_search_space(), optimization_config=optimization_config, ) - self.assertEqual(sobol_gpei._steps[0].model, Models.SOBOL) + self.assertEqual(sobol_gpei._steps[0].model, Generators.SOBOL) self.assertEqual(sobol_gpei._steps[0].num_trials, 5) - self.assertEqual(sobol_gpei._steps[1].model, Models.BOTORCH_MODULAR) + self.assertEqual(sobol_gpei._steps[1].model, Generators.BOTORCH_MODULAR) model_kwargs = none_throws(sobol_gpei._steps[1].model_kwargs) self.assertEqual( set(model_kwargs.keys()), @@ -133,13 +133,13 @@ def test_choose_generation_strategy(self) -> None: sobol = choose_generation_strategy( search_space=get_factorial_search_space(), num_trials=1000 ) - self.assertEqual(sobol._steps[0].model, Models.SOBOL) + self.assertEqual(sobol._steps[0].model, Generators.SOBOL) self.assertEqual(len(sobol._steps), 1) with self.subTest("Sobol (because of too many categories)"): sobol_large = choose_generation_strategy( search_space=get_large_factorial_search_space(), verbose=True ) - self.assertEqual(sobol_large._steps[0].model, Models.SOBOL) + self.assertEqual(sobol_large._steps[0].model, Generators.SOBOL) self.assertEqual(len(sobol_large._steps), 1) with self.subTest("Sobol (because of too many categories) with saasbo"): with self.assertLogs( @@ -157,7 +157,7 @@ def test_choose_generation_strategy(self) -> None: ), logger.output, ) - self.assertEqual(sobol_large._steps[0].model, Models.SOBOL) + self.assertEqual(sobol_large._steps[0].model, Generators.SOBOL) self.assertEqual(len(sobol_large._steps), 1) with self.subTest("SOBOL due to too many unordered choices"): # Search space with more unordered choices than ordered parameters. 
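These subtests repeat one pattern; a condensed sketch of it with the registry enum under its new name (module paths assumed to match the test imports).

from ax.modelbridge.dispatch_utils import choose_generation_strategy
from ax.modelbridge.registry import Generators
from ax.utils.testing.core_stubs import get_branin_search_space

gs = choose_generation_strategy(search_space=get_branin_search_space())
# Default strategy: a Sobol initialization step, then modular BoTorch.
assert gs._steps[0].model == Generators.SOBOL
assert gs._steps[0].num_trials == 5
assert gs._steps[1].model == Generators.BOTORCH_MODULAR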
@@ -167,7 +167,7 @@ def test_choose_generation_strategy(self) -> None: num_unordered_choices=100, ) ) - self.assertEqual(sobol._steps[0].model, Models.SOBOL) + self.assertEqual(sobol._steps[0].model, Generators.SOBOL) self.assertEqual(len(sobol._steps), 1) with self.subTest("GPEI with more unordered choices than ordered parameters"): # Search space with more unordered choices than ordered parameters. @@ -177,15 +177,15 @@ def test_choose_generation_strategy(self) -> None: num_unordered_choices=10, ) ) - self.assertEqual(sobol_gpei._steps[1].model, Models.BOTORCH_MODULAR) + self.assertEqual(sobol_gpei._steps[1].model, Generators.BOTORCH_MODULAR) with self.subTest("GPEI despite many unordered 2-value parameters"): gs = choose_generation_strategy( search_space=get_large_factorial_search_space( num_levels=2, num_parameters=10 ), ) - self.assertEqual(gs._steps[0].model, Models.SOBOL) - self.assertEqual(gs._steps[1].model, Models.BOTORCH_MODULAR) + self.assertEqual(gs._steps[0].model, Generators.SOBOL) + self.assertEqual(gs._steps[1].model, Generators.BOTORCH_MODULAR) with self.subTest("GPEI-Batched"): sobol_gpei_batched = choose_generation_strategy( search_space=get_branin_search_space(), @@ -196,9 +196,9 @@ def test_choose_generation_strategy(self) -> None: bo_mixed = choose_generation_strategy( search_space=get_factorial_search_space() ) - self.assertEqual(bo_mixed._steps[0].model, Models.SOBOL) + self.assertEqual(bo_mixed._steps[0].model, Generators.SOBOL) self.assertEqual(bo_mixed._steps[0].num_trials, 6) - self.assertEqual(bo_mixed._steps[1].model, Models.BO_MIXED) + self.assertEqual(bo_mixed._steps[1].model, Generators.BO_MIXED) expected_model_kwargs = { "torch_device": None, "transforms": [Winsorize] + Mixed_transforms + Y_trans, @@ -211,9 +211,9 @@ def test_choose_generation_strategy(self) -> None: # pyre-fixme[16]: `Parameter` has no attribute `_is_ordered`. 
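# Marking x2 as unordered forces the mixed-search-space path, hence the BO_MIXED step asserted below.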
ss.parameters["x2"]._is_ordered = False bo_mixed_2 = choose_generation_strategy(search_space=ss) - self.assertEqual(bo_mixed_2._steps[0].model, Models.SOBOL) + self.assertEqual(bo_mixed_2._steps[0].model, Generators.SOBOL) self.assertEqual(bo_mixed_2._steps[0].num_trials, 5) - self.assertEqual(bo_mixed_2._steps[1].model, Models.BO_MIXED) + self.assertEqual(bo_mixed_2._steps[1].model, Generators.BO_MIXED) expected_model_kwargs = { "torch_device": None, "transforms": [Winsorize] + Mixed_transforms + Y_trans, @@ -230,9 +230,9 @@ def test_choose_generation_strategy(self) -> None: moo_mixed = choose_generation_strategy( search_space=search_space, optimization_config=optimization_config ) - self.assertEqual(moo_mixed._steps[0].model, Models.SOBOL) + self.assertEqual(moo_mixed._steps[0].model, Generators.SOBOL) self.assertEqual(moo_mixed._steps[0].num_trials, 5) - self.assertEqual(moo_mixed._steps[1].model, Models.BO_MIXED) + self.assertEqual(moo_mixed._steps[1].model, Generators.BO_MIXED) model_kwargs = none_throws(moo_mixed._steps[1].model_kwargs) self.assertEqual( set(model_kwargs.keys()), @@ -251,9 +251,9 @@ def test_choose_generation_strategy(self) -> None: num_initialization_trials=3, use_saasbo=True, ) - self.assertEqual(sobol_fullybayesian._steps[0].model, Models.SOBOL) + self.assertEqual(sobol_fullybayesian._steps[0].model, Generators.SOBOL) self.assertEqual(sobol_fullybayesian._steps[0].num_trials, 3) - self.assertEqual(sobol_fullybayesian._steps[1].model, Models.SAASBO) + self.assertEqual(sobol_fullybayesian._steps[1].model, Generators.SAASBO) with self.subTest("SAASBO MOO"): sobol_fullybayesianmoo = choose_generation_strategy( search_space=get_branin_search_space(), @@ -264,11 +264,11 @@ def test_choose_generation_strategy(self) -> None: objective=MultiObjective(objectives=[]) ), ) - self.assertEqual(sobol_fullybayesianmoo._steps[0].model, Models.SOBOL) + self.assertEqual(sobol_fullybayesianmoo._steps[0].model, Generators.SOBOL) self.assertEqual(sobol_fullybayesianmoo._steps[0].num_trials, 3) self.assertEqual( sobol_fullybayesianmoo._steps[1].model, - Models.SAASBO, + Generators.SAASBO, ) with self.subTest("SAASBO"): sobol_fullybayesian_large = choose_generation_strategy( @@ -277,11 +277,13 @@ def test_choose_generation_strategy(self) -> None: ), use_saasbo=True, ) - self.assertEqual(sobol_fullybayesian_large._steps[0].model, Models.SOBOL) + self.assertEqual( + sobol_fullybayesian_large._steps[0].model, Generators.SOBOL + ) self.assertEqual(sobol_fullybayesian_large._steps[0].num_trials, 30) self.assertEqual( sobol_fullybayesian_large._steps[1].model, - Models.SAASBO, + Generators.SAASBO, ) with self.subTest("num_initialization_trials"): ss = get_large_factorial_search_space() @@ -291,32 +293,38 @@ def test_choose_generation_strategy(self) -> None: gs_12_init_trials = choose_generation_strategy( search_space=ss, num_trials=100 ) - self.assertEqual(gs_12_init_trials._steps[0].model, Models.SOBOL) + self.assertEqual(gs_12_init_trials._steps[0].model, Generators.SOBOL) self.assertEqual(gs_12_init_trials._steps[0].num_trials, 12) - self.assertEqual(gs_12_init_trials._steps[1].model, Models.BOTORCH_MODULAR) + self.assertEqual( + gs_12_init_trials._steps[1].model, Generators.BOTORCH_MODULAR + ) # at least 5 initialization trials are performed gs_5_init_trials = choose_generation_strategy(search_space=ss, num_trials=0) - self.assertEqual(gs_5_init_trials._steps[0].model, Models.SOBOL) + self.assertEqual(gs_5_init_trials._steps[0].model, Generators.SOBOL) 
self.assertEqual(gs_5_init_trials._steps[0].num_trials, 5) - self.assertEqual(gs_5_init_trials._steps[1].model, Models.BOTORCH_MODULAR) + self.assertEqual( + gs_5_init_trials._steps[1].model, Generators.BOTORCH_MODULAR + ) # avoid spending >20% of budget on initialization trials if there are # more than 5 initialization trials gs_6_init_trials = choose_generation_strategy( search_space=ss, num_trials=30 ) - self.assertEqual(gs_6_init_trials._steps[0].model, Models.SOBOL) + self.assertEqual(gs_6_init_trials._steps[0].model, Generators.SOBOL) self.assertEqual(gs_6_init_trials._steps[0].num_trials, 6) - self.assertEqual(gs_6_init_trials._steps[1].model, Models.BOTORCH_MODULAR) + self.assertEqual( + gs_6_init_trials._steps[1].model, Generators.BOTORCH_MODULAR + ) with self.subTest("suggested_model_override"): sobol_gpei = choose_generation_strategy( search_space=get_branin_search_space() ) - self.assertEqual(sobol_gpei._steps[1].model, Models.BOTORCH_MODULAR) + self.assertEqual(sobol_gpei._steps[1].model, Generators.BOTORCH_MODULAR) sobol_saasbo = choose_generation_strategy( search_space=get_branin_search_space(), - suggested_model_override=Models.SAASBO, + suggested_model_override=Generators.SAASBO, ) - self.assertEqual(sobol_saasbo._steps[1].model, Models.SAASBO) + self.assertEqual(sobol_saasbo._steps[1].model, Generators.SAASBO) def test_make_botorch_step_extra(self) -> None: # Test parts of _make_botorch_step that are not directly exposed in @@ -347,12 +355,12 @@ def test_disable_progbar(self) -> None: disable_progbar=disable_progbar, use_saasbo=True, ) - self.assertEqual(sobol_saasbo._steps[0].model, Models.SOBOL) + self.assertEqual(sobol_saasbo._steps[0].model, Generators.SOBOL) self.assertNotIn( "disable_progbar", none_throws(sobol_saasbo._steps[0].model_kwargs), ) - self.assertEqual(sobol_saasbo._steps[1].model, Models.SAASBO) + self.assertEqual(sobol_saasbo._steps[1].model, Generators.SAASBO) self.assertNotIn( "disable_progbar", none_throws(sobol_saasbo._steps[0].model_kwargs), @@ -378,12 +386,12 @@ def test_disable_progbar_for_non_saasbo_discards_the_model_kwarg(self) -> None: use_saasbo=False, ) self.assertEqual(len(gp_saasbo._steps), 2) - self.assertEqual(gp_saasbo._steps[0].model, Models.SOBOL) + self.assertEqual(gp_saasbo._steps[0].model, Generators.SOBOL) self.assertNotIn( "disable_progbar", none_throws(gp_saasbo._steps[0].model_kwargs), ) - self.assertEqual(gp_saasbo._steps[1].model, Models.BOTORCH_MODULAR) + self.assertEqual(gp_saasbo._steps[1].model, Generators.BOTORCH_MODULAR) self.assertNotIn( "disable_progbar", none_throws(gp_saasbo._steps[1].model_kwargs), @@ -547,11 +555,11 @@ def test_num_trials(self) -> None: "with budget that is lower than exhaustive, BayesOpt is used" ): sobol_gpei = choose_generation_strategy(search_space=ss, num_trials=23) - self.assertEqual(sobol_gpei._steps[0].model, Models.SOBOL) - self.assertEqual(sobol_gpei._steps[1].model, Models.BO_MIXED) + self.assertEqual(sobol_gpei._steps[0].model, Generators.SOBOL) + self.assertEqual(sobol_gpei._steps[1].model, Generators.BO_MIXED) with self.subTest("with budget that is exhaustive, Sobol is used"): sobol = choose_generation_strategy(search_space=ss, num_trials=36) - self.assertEqual(sobol._steps[0].model, Models.SOBOL) + self.assertEqual(sobol._steps[0].model, Generators.SOBOL) self.assertEqual(len(sobol._steps), 1) with self.subTest("with budget that is exhaustive and use_saasbo, it warns"): with self.assertLogs( @@ -569,7 +577,7 @@ def test_num_trials(self) -> None: ), logger.output, ) - 
self.assertEqual(sobol._steps[0].model, Models.SOBOL) + self.assertEqual(sobol._steps[0].model, Generators.SOBOL) self.assertEqual(len(sobol._steps), 1) def test_use_batch_trials(self) -> None: @@ -731,12 +739,12 @@ def test_jit_compile(self) -> None: jit_compile=jit_compile, use_saasbo=True, ) - self.assertEqual(sobol_saasbo._steps[0].model, Models.SOBOL) + self.assertEqual(sobol_saasbo._steps[0].model, Generators.SOBOL) self.assertNotIn( "jit_compile", none_throws(sobol_saasbo._steps[0].model_kwargs), ) - self.assertEqual(sobol_saasbo._steps[1].model, Models.SAASBO) + self.assertEqual(sobol_saasbo._steps[1].model, Generators.SAASBO) self.assertNotIn( "jit_compile", none_throws(sobol_saasbo._steps[0].model_kwargs), @@ -762,12 +770,12 @@ def test_jit_compile_for_non_saasbo_discards_the_model_kwarg(self) -> None: use_saasbo=False, ) self.assertEqual(len(gp_saasbo._steps), 2) - self.assertEqual(gp_saasbo._steps[0].model, Models.SOBOL) + self.assertEqual(gp_saasbo._steps[0].model, Generators.SOBOL) self.assertNotIn( "jit_compile", none_throws(gp_saasbo._steps[0].model_kwargs), ) - self.assertEqual(gp_saasbo._steps[1].model, Models.BOTORCH_MODULAR) + self.assertEqual(gp_saasbo._steps[1].model, Generators.BOTORCH_MODULAR) self.assertNotIn( "jit_compile", none_throws(gp_saasbo._steps[1].model_kwargs), diff --git a/ax/modelbridge/tests/test_external_generation_node.py b/ax/modelbridge/tests/test_external_generation_node.py index bfa446fcb07..caadf25475e 100644 --- a/ax/modelbridge/tests/test_external_generation_node.py +++ b/ax/modelbridge/tests/test_external_generation_node.py @@ -16,7 +16,7 @@ from ax.exceptions.core import UnsupportedError from ax.modelbridge.external_generation_node import ExternalGenerationNode from ax.modelbridge.generation_strategy import GenerationStrategy -from ax.modelbridge.random import RandomModelBridge +from ax.modelbridge.random import RandomAdapter from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( get_branin_data, @@ -31,7 +31,7 @@ def __init__(self) -> None: super().__init__(node_name="dummy") self.update_count = 0 self.gen_count = 0 - self.generator: RandomModelBridge | None = None + self.generator: RandomAdapter | None = None self.last_pending: list[TParameterization] = [] def update_generator_state(self, experiment: Experiment, data: Data) -> None: diff --git a/ax/modelbridge/tests/test_factory.py b/ax/modelbridge/tests/test_factory.py index dd770025990..6b0375eed5e 100644 --- a/ax/modelbridge/tests/test_factory.py +++ b/ax/modelbridge/tests/test_factory.py @@ -7,7 +7,7 @@ # pyre-strict from ax.core.outcome_constraint import ComparisonOp, ObjectiveThreshold -from ax.modelbridge.discrete import DiscreteModelBridge +from ax.modelbridge.discrete import DiscreteAdapter from ax.modelbridge.factory import ( get_empirical_bayes_thompson, get_factorial, @@ -15,7 +15,7 @@ get_thompson, get_uniform, ) -from ax.modelbridge.random import RandomModelBridge +from ax.modelbridge.random import RandomAdapter from ax.models.discrete.eb_thompson import EmpiricalBayesThompsonSampler from ax.models.discrete.thompson import ThompsonSampler from ax.utils.common.testutils import TestCase @@ -46,14 +46,14 @@ def get_multi_obj_exp_and_opt_config(): return multi_obj_exp, optimization_config -class ModelBridgeFactoryTestSingleObjective(TestCase): +class AdapterFactoryTestSingleObjective(TestCase): def test_model_kwargs(self) -> None: """Tests that model kwargs are passed correctly.""" exp = get_branin_experiment() sobol = get_sobol( 
search_space=exp.search_space, init_position=2, scramble=False, seed=239 ) - self.assertIsInstance(sobol, RandomModelBridge) + self.assertIsInstance(sobol, RandomAdapter) for _ in range(5): sobol_run = sobol.gen(1) exp.new_batch_trial().add_generator_run(sobol_run).run().mark_completed() @@ -65,7 +65,7 @@ def test_factorial(self) -> None: """Tests factorial instantiation.""" exp = get_factorial_experiment() factorial = get_factorial(exp.search_space) - self.assertIsInstance(factorial, DiscreteModelBridge) + self.assertIsInstance(factorial, DiscreteAdapter) factorial_run = factorial.gen(n=-1) self.assertEqual(len(factorial_run.arms), 24) @@ -73,14 +73,14 @@ def test_empirical_bayes_thompson(self) -> None: """Tests EB/TS instantiation.""" exp = get_factorial_experiment() factorial = get_factorial(exp.search_space) - self.assertIsInstance(factorial, DiscreteModelBridge) + self.assertIsInstance(factorial, DiscreteAdapter) factorial_run = factorial.gen(n=-1) exp.new_batch_trial().add_generator_run(factorial_run).run().mark_completed() data = exp.fetch_data() eb_thompson = get_empirical_bayes_thompson( experiment=exp, data=data, min_weight=0.0 ) - self.assertIsInstance(eb_thompson, DiscreteModelBridge) + self.assertIsInstance(eb_thompson, DiscreteAdapter) self.assertIsInstance(eb_thompson.model, EmpiricalBayesThompsonSampler) thompson_run = eb_thompson.gen(n=5) self.assertEqual(len(thompson_run.arms), 5) @@ -89,7 +89,7 @@ def test_thompson(self) -> None: """Tests TS instantiation.""" exp = get_factorial_experiment() factorial = get_factorial(exp.search_space) - self.assertIsInstance(factorial, DiscreteModelBridge) + self.assertIsInstance(factorial, DiscreteAdapter) factorial_run = factorial.gen(n=-1) exp.new_batch_trial().add_generator_run(factorial_run).run().mark_completed() data = exp.fetch_data() @@ -99,6 +99,6 @@ def test_thompson(self) -> None: def test_uniform(self) -> None: exp = get_branin_experiment() uniform = get_uniform(exp.search_space) - self.assertIsInstance(uniform, RandomModelBridge) + self.assertIsInstance(uniform, RandomAdapter) uniform_run = uniform.gen(n=5) self.assertEqual(len(uniform_run.arms), 5) diff --git a/ax/modelbridge/tests/test_generation_node.py b/ax/modelbridge/tests/test_generation_node.py index 2dc66717281..ccbb702ba70 100644 --- a/ax/modelbridge/tests/test_generation_node.py +++ b/ax/modelbridge/tests/test_generation_node.py @@ -28,8 +28,8 @@ InputConstructorPurpose, NodeInputConstructors, ) -from ax.modelbridge.model_spec import FactoryFunctionModelSpec, ModelSpec -from ax.modelbridge.registry import Models +from ax.modelbridge.model_spec import FactoryFunctionGeneratorSpec, GeneratorSpec +from ax.modelbridge.registry import Generators from ax.modelbridge.transition_criterion import MinTrials from ax.utils.common.constants import Keys from ax.utils.common.logger import get_logger @@ -44,13 +44,13 @@ class TestGenerationNode(TestCase): def setUp(self) -> None: super().setUp() - self.sobol_model_spec = ModelSpec( - model_enum=Models.SOBOL, + self.sobol_model_spec = GeneratorSpec( + model_enum=Generators.SOBOL, model_kwargs={"init_position": 3}, model_gen_kwargs={"some_gen_kwarg": "some_value"}, ) - self.mbm_model_spec = ModelSpec( - model_enum=Models.BOTORCH_MODULAR, + self.mbm_model_spec = GeneratorSpec( + model_enum=Generators.BOTORCH_MODULAR, model_kwargs={}, model_gen_kwargs={}, ) @@ -75,8 +75,10 @@ def test_init(self) -> None: model_specs=[self.sobol_model_spec, self.sobol_model_spec], ) mbm_specs = [ - ModelSpec(model_enum=Models.BOTORCH_MODULAR), - 
ModelSpec(model_enum=Models.BOTORCH_MODULAR, model_key_override="MBM v2"), + GeneratorSpec(model_enum=Generators.BOTORCH_MODULAR), + GeneratorSpec( + model_enum=Generators.BOTORCH_MODULAR, model_key_override="MBM v2" + ), ] with self.assertRaisesRegex(UserInputError, MISSING_MODEL_SELECTOR_MESSAGE): GenerationNode( @@ -173,8 +175,8 @@ def test_gen_with_trial_type(self) -> None: mbm_short = GenerationNode( node_name="test", model_specs=[ - ModelSpec( - model_enum=Models.BOTORCH_MODULAR, + GeneratorSpec( + model_enum=Generators.BOTORCH_MODULAR, model_kwargs={}, model_gen_kwargs={ "n": 1, @@ -212,8 +214,8 @@ def test_model_gen_kwargs_deepcopy(self) -> None: node = GenerationNode( node_name="test", model_specs=[ - ModelSpec( - model_enum=Models.BOTORCH_MODULAR, + GeneratorSpec( + model_enum=Generators.BOTORCH_MODULAR, model_kwargs={}, model_gen_kwargs={ "n": 1, @@ -246,8 +248,8 @@ def test_properties(self) -> None: node = GenerationNode( node_name="test", model_specs=[ - ModelSpec( - model_enum=Models.BOTORCH_MODULAR, + GeneratorSpec( + model_enum=Generators.BOTORCH_MODULAR, model_kwargs={}, model_gen_kwargs={ "n": 1, @@ -306,7 +308,7 @@ def test_node_string_representation(self) -> None: self.assertEqual( string_rep, - "GenerationNode(model_specs=[ModelSpec(model_enum=BoTorch," + "GenerationNode(model_specs=[GeneratorSpec(model_enum=BoTorch," " model_kwargs={}, model_gen_kwargs={}, model_cv_kwargs={}," " model_key_override=None)], node_name=test, " "transition_criteria=[MinTrials({'threshold': 5, " @@ -322,8 +324,8 @@ def test_single_fixed_features(self) -> None: node = GenerationNode( node_name="test", model_specs=[ - ModelSpec( - model_enum=Models.BOTORCH_MODULAR, + GeneratorSpec( + model_enum=Generators.BOTORCH_MODULAR, model_kwargs={}, model_gen_kwargs={ "n": 2, @@ -343,13 +345,13 @@ def setUp(self) -> None: super().setUp() self.model_kwargs = {"init_position": 5} self.sobol_generation_step = GenerationStep( - model=Models.SOBOL, + model=Generators.SOBOL, num_trials=5, model_kwargs=self.model_kwargs, ) - self.model_spec = ModelSpec( + self.model_spec = GeneratorSpec( # pyre-fixme[6]: For 1st param expected `ModelRegistryBase` but got - # `Union[typing.Callable[..., ModelBridge], ModelRegistryBase]`. + # `Union[typing.Callable[..., Adapter], ModelRegistryBase]`. 
model_enum=self.sobol_generation_step.model, model_kwargs=self.model_kwargs, ) @@ -362,7 +364,7 @@ def test_init(self) -> None: self.assertEqual(self.sobol_generation_step.model_name, "Sobol") named_generation_step = GenerationStep( - model=Models.SOBOL, + model=Generators.SOBOL, num_trials=5, model_kwargs=self.model_kwargs, model_name="Custom Sobol", @@ -372,7 +374,7 @@ def test_init(self) -> None: def test_min_trials_observed(self) -> None: with self.assertRaisesRegex(UserInputError, "min_trials_observed > num_trials"): GenerationStep( - model=Models.SOBOL, + model=Generators.SOBOL, num_trials=5, min_trials_observed=10, model_kwargs=self.model_kwargs, @@ -382,7 +384,7 @@ def test_init_factory_function(self) -> None: generation_step = GenerationStep(model=get_sobol, num_trials=-1) self.assertEqual( generation_step.model_specs, - [FactoryFunctionModelSpec(factory_function=get_sobol)], + [FactoryFunctionGeneratorSpec(factory_function=get_sobol)], ) generation_step = GenerationStep( model=get_sobol, num_trials=-1, model_name="test" @@ -390,7 +392,7 @@ def test_init_factory_function(self) -> None: self.assertEqual( generation_step.model_specs, [ - FactoryFunctionModelSpec( + FactoryFunctionGeneratorSpec( factory_function=get_sobol, model_key_override="test" ) ], @@ -415,8 +417,8 @@ def setUp(self) -> None: self.branin_experiment = get_branin_experiment( with_batch=True, with_completed_batch=True ) - ms_mixed = ModelSpec(model_enum=Models.BO_MIXED) - ms_botorch = ModelSpec(model_enum=Models.BOTORCH_MODULAR) + ms_mixed = GeneratorSpec(model_enum=Generators.BO_MIXED) + ms_botorch = GeneratorSpec(model_enum=Generators.BOTORCH_MODULAR) self.mock_aggregation = MagicMock( side_effect=ReductionCriterion.MEAN, spec=ReductionCriterion diff --git a/ax/modelbridge/tests/test_generation_node_input_constructors.py b/ax/modelbridge/tests/test_generation_node_input_constructors.py index 0ca22d7ae25..eccb6577ebb 100644 --- a/ax/modelbridge/tests/test_generation_node_input_constructors.py +++ b/ax/modelbridge/tests/test_generation_node_input_constructors.py @@ -22,8 +22,8 @@ NodeInputConstructors, ) from ax.modelbridge.generation_strategy import GenerationStrategy -from ax.modelbridge.model_spec import ModelSpec -from ax.modelbridge.registry import Models +from ax.modelbridge.model_spec import GeneratorSpec +from ax.modelbridge.registry import Generators from ax.utils.common.constants import Keys from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import get_branin_experiment @@ -55,8 +55,8 @@ class TestGenerationNodeInputConstructors(TestCase): def setUp(self) -> None: super().setUp() - self.sobol_model_spec = ModelSpec( - model_enum=Models.SOBOL, + self.sobol_model_spec = GeneratorSpec( + model_enum=Generators.SOBOL, model_kwargs={"init_position": 3}, model_gen_kwargs={"some_gen_kwarg": "some_value"}, ) diff --git a/ax/modelbridge/tests/test_generation_strategy.py b/ax/modelbridge/tests/test_generation_strategy.py index d8efac00486..5604b3f9df4 100644 --- a/ax/modelbridge/tests/test_generation_strategy.py +++ b/ax/modelbridge/tests/test_generation_strategy.py @@ -30,7 +30,7 @@ MaxParallelismReachedException, ) from ax.modelbridge.best_model_selector import SingleDiagnosticBestModelSelector -from ax.modelbridge.discrete import DiscreteModelBridge +from ax.modelbridge.discrete import DiscreteAdapter from ax.modelbridge.factory import get_sobol from ax.modelbridge.generation_node import GenerationNode from ax.modelbridge.generation_node_input_constructors import ( @@ -38,16 +38,16 @@ 
NodeInputConstructors, ) from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy -from ax.modelbridge.model_spec import ModelSpec -from ax.modelbridge.random import RandomModelBridge +from ax.modelbridge.model_spec import GeneratorSpec +from ax.modelbridge.random import RandomAdapter from ax.modelbridge.registry import ( _extract_model_state_after_gen, Cont_X_trans, + Generators, MBM_MTGP_trans, MODEL_KEY_TO_MODEL_SETUP, - Models, ) -from ax.modelbridge.torch import TorchModelBridge +from ax.modelbridge.torch import TorchAdapter from ax.modelbridge.transition_criterion import ( AutoTransitionAfterGen, MaxGenerationParallelism, @@ -70,7 +70,7 @@ from pyre_extensions import assert_is_instance, none_throws -class TestGenerationStrategyWithoutModelBridgeMocks(TestCase): +class TestGenerationStrategyWithoutAdapterMocks(TestCase): """The test class above heavily mocks the modelbridge. This makes it difficult to test certain aspects of the GS. This is an alternative test class that makes use of mocking rather sparingly. @@ -93,7 +93,7 @@ def test_with_model_selection(self, mock_model_state: Mock) -> None: nodes=[ GenerationNode( node_name="Sobol", - model_specs=[ModelSpec(model_enum=Models.SOBOL)], + model_specs=[GeneratorSpec(model_enum=Generators.SOBOL)], transition_criteria=[ MinTrials(threshold=2, transition_to="MBM/BO_MIXED") ], @@ -101,8 +101,8 @@ def test_with_model_selection(self, mock_model_state: Mock) -> None: GenerationNode( node_name="MBM/BO_MIXED", model_specs=[ - ModelSpec(model_enum=Models.BOTORCH_MODULAR), - ModelSpec(model_enum=Models.BO_MIXED), + GeneratorSpec(model_enum=Generators.BOTORCH_MODULAR), + GeneratorSpec(model_enum=Generators.BO_MIXED), ], best_model_selector=best_model_selector, ), @@ -150,7 +150,7 @@ def setUp(self) -> None: # Mock out slow model fitting. self.torch_model_bridge_patcher = patch( - f"{TorchModelBridge.__module__}.TorchModelBridge", spec=True + f"{TorchAdapter.__module__}.TorchAdapter", spec=True ) self.mock_torch_model_bridge = self.torch_model_bridge_patcher.start() mock_mb = self.mock_torch_model_bridge.return_value @@ -159,14 +159,14 @@ def setUp(self) -> None: # Mock out slow TS. self.discrete_model_bridge_patcher = patch( - f"{DiscreteModelBridge.__module__}.DiscreteModelBridge", spec=True + f"{DiscreteAdapter.__module__}.DiscreteAdapter", spec=True ) self.mock_discrete_model_bridge = self.discrete_model_bridge_patcher.start() self.mock_discrete_model_bridge.return_value.gen.return_value = self.gr - # Mock in `Models` registry. + # Mock in `Generators` registry. 
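# patch.dict swaps individual MODEL_KEY_TO_MODEL_SETUP entries so registry lookups resolve to the mocked bridge classes created above.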
self.registry_setup_dict_patcher = patch.dict( - f"{Models.__module__}.MODEL_KEY_TO_MODEL_SETUP", + f"{Generators.__module__}.MODEL_KEY_TO_MODEL_SETUP", { "Factorial": MODEL_KEY_TO_MODEL_SETUP["Factorial"]._replace( bridge_class=self.mock_discrete_model_bridge @@ -191,7 +191,7 @@ def setUp(self) -> None: self.sobol_GS = GenerationStrategy( steps=[ GenerationStep( - Models.SOBOL, + Generators.SOBOL, num_trials=-1, should_deduplicate=True, ) @@ -228,13 +228,13 @@ def setUp(self) -> None: only_in_statuses=[TrialStatus.RUNNING], ) ] - self.sobol_model_spec = ModelSpec( - model_enum=Models.SOBOL, + self.sobol_model_spec = GeneratorSpec( + model_enum=Generators.SOBOL, model_kwargs=self.step_model_kwargs, model_gen_kwargs={}, ) - self.mbm_model_spec = ModelSpec( - model_enum=Models.BOTORCH_MODULAR, + self.mbm_model_spec = GeneratorSpec( + model_enum=Generators.BOTORCH_MODULAR, model_kwargs=self.step_model_kwargs, model_gen_kwargs={}, ) @@ -355,12 +355,12 @@ def _get_sobol_mbm_step_gs( name="Sobol+MBM", steps=[ GenerationStep( - model=Models.SOBOL, + model=Generators.SOBOL, num_trials=num_sobol_trials, model_kwargs=self.step_model_kwargs, ), GenerationStep( - model=Models.BOTORCH_MODULAR, + model=Generators.BOTORCH_MODULAR, num_trials=num_mbm_trials, model_kwargs=self.step_model_kwargs, enforce_num_trials=True, @@ -391,8 +391,8 @@ def test_validation(self) -> None: with self.assertRaises(UserInputError): GenerationStrategy( steps=[ - GenerationStep(model=Models.SOBOL, num_trials=5), - GenerationStep(model=Models.BOTORCH_MODULAR, num_trials=-10), + GenerationStep(model=Generators.SOBOL, num_trials=5), + GenerationStep(model=Generators.BOTORCH_MODULAR, num_trials=-10), ] ) @@ -400,8 +400,8 @@ def test_validation(self) -> None: with self.assertRaises(UserInputError): GenerationStrategy( steps=[ - GenerationStep(model=Models.SOBOL, num_trials=-1), - GenerationStep(model=Models.BOTORCH_MODULAR, num_trials=10), + GenerationStep(model=Generators.SOBOL, num_trials=-1), + GenerationStep(model=Generators.BOTORCH_MODULAR, num_trials=10), ] ) @@ -410,8 +410,8 @@ def test_validation(self) -> None: ) factorial_thompson_generation_strategy = GenerationStrategy( steps=[ - GenerationStep(model=Models.FACTORIAL, num_trials=1), - GenerationStep(model=Models.THOMPSON, num_trials=2), + GenerationStep(model=Generators.FACTORIAL, num_trials=1), + GenerationStep(model=Generators.THOMPSON, num_trials=2), ] ) self.assertTrue(factorial_thompson_generation_strategy._uses_registered_models) @@ -425,9 +425,9 @@ def test_validation(self) -> None: GenerationStrategy( steps=[ GenerationStep( - model=Models.SOBOL, num_trials=5, max_parallelism=-1 + model=Generators.SOBOL, num_trials=5, max_parallelism=-1 ), - GenerationStep(model=Models.BOTORCH_MODULAR, num_trials=-1), + GenerationStep(model=Generators.BOTORCH_MODULAR, num_trials=-1), ] ) @@ -451,7 +451,7 @@ def test_string_representation(self) -> None: ), ) gs2 = GenerationStrategy( - steps=[GenerationStep(model=Models.SOBOL, num_trials=-1)] + steps=[GenerationStep(model=Generators.SOBOL, num_trials=-1)] ) self.assertEqual( str(gs2), "GenerationStrategy(name='Sobol', steps=[Sobol for all trials])" @@ -462,8 +462,8 @@ def test_string_representation(self) -> None: GenerationNode( node_name="test", model_specs=[ - ModelSpec( - model_enum=Models.SOBOL, + GeneratorSpec( + model_enum=Generators.SOBOL, model_kwargs={}, model_gen_kwargs={}, ), @@ -474,7 +474,7 @@ def test_string_representation(self) -> None: self.assertEqual( str(gs3), "GenerationStrategy(name='test', 
nodes=[GenerationNode(" - "model_specs=[ModelSpec(model_enum=Sobol, " + "model_specs=[GeneratorSpec(model_enum=Sobol, " "model_kwargs={}, model_gen_kwargs={}, model_cv_kwargs={}," " model_key_override=None)], node_name=test, " "transition_criteria=[])])", @@ -496,8 +496,10 @@ def test_min_observed(self) -> None: exp = get_branin_experiment(get_branin_experiment()) gs = GenerationStrategy( steps=[ - GenerationStep(model=Models.SOBOL, num_trials=5, min_trials_observed=5), - GenerationStep(model=Models.BOTORCH_MODULAR, num_trials=1), + GenerationStep( + model=Generators.SOBOL, num_trials=5, min_trials_observed=5 + ), + GenerationStep(model=Generators.BOTORCH_MODULAR, num_trials=1), ] ) self.assertFalse(gs.uses_non_registered_models) @@ -514,18 +516,18 @@ def test_do_not_enforce_min_observations(self) -> None: gs = GenerationStrategy( steps=[ GenerationStep( - model=Models.SOBOL, + model=Generators.SOBOL, num_trials=1, min_trials_observed=5, enforce_num_trials=False, ), - GenerationStep(model=Models.BOTORCH_MODULAR, num_trials=1), + GenerationStep(model=Generators.BOTORCH_MODULAR, num_trials=1), ] ) for _ in range(2): gs.gen(exp) # Make sure Sobol is used to generate the 6th point. - self.assertIsInstance(gs._model, RandomModelBridge) + self.assertIsInstance(gs._model, RandomAdapter) def test_sobol_MBM_strategy(self) -> None: exp = get_branin_experiment() @@ -597,14 +599,14 @@ def test_sobol_MBM_strategy_keep_generating(self) -> None: g = self.sobol_MBM_step_GS.gen(exp) exp.new_trial(generator_run=g).run() if i > 4: - self.assertIsInstance(self.sobol_MBM_step_GS.model, TorchModelBridge) + self.assertIsInstance(self.sobol_MBM_step_GS.model, TorchAdapter) def test_sobol_strategy(self) -> None: exp = get_branin_experiment() sobol_generation_strategy = GenerationStrategy( steps=[ GenerationStep( - model=Models.SOBOL, + model=Generators.SOBOL, num_trials=5, max_parallelism=10, enforce_num_trials=False, @@ -623,12 +625,12 @@ def test_factorial_thompson_strategy(self, _: MagicMock) -> None: factorial_thompson_gs = GenerationStrategy( steps=[ GenerationStep( - model=Models.FACTORIAL, + model=Generators.FACTORIAL, num_trials=1, model_kwargs=self.step_model_kwargs, ), GenerationStep( - model=Models.THOMPSON, + model=Generators.THOMPSON, num_trials=-1, model_kwargs=self.step_model_kwargs, ), @@ -658,8 +660,8 @@ def test_factorial_thompson_strategy(self, _: MagicMock) -> None: def test_clone_reset(self) -> None: ftgs = GenerationStrategy( steps=[ - GenerationStep(model=Models.FACTORIAL, num_trials=1), - GenerationStep(model=Models.THOMPSON, num_trials=2), + GenerationStep(model=Generators.FACTORIAL, num_trials=1), + GenerationStep(model=Generators.THOMPSON, num_trials=2), ] ) ftgs._curr = ftgs._steps[1] @@ -670,7 +672,9 @@ def test_kwargs_passed(self) -> None: gs = GenerationStrategy( steps=[ GenerationStep( - model=Models.SOBOL, num_trials=1, model_kwargs={"scramble": False} + model=Generators.SOBOL, + num_trials=1, + model_kwargs={"scramble": False}, ) ] ) @@ -703,14 +707,14 @@ def test_sobol_MBM_strategy_batches(self) -> None: else: grs_2 = sobol_MBM_generation_strategy._gen_with_multiple_nodes(exp, n=2) exp.new_batch_trial(generator_runs=grs_2).run() - self.assertIsInstance(sobol_MBM_generation_strategy.model, TorchModelBridge) + self.assertIsInstance(sobol_MBM_generation_strategy.model, TorchAdapter) def test_with_factory_function(self) -> None: """Checks that generation strategy works with custom factory functions. 
No information about the model should be saved on generator run.""" - def get_sobol(search_space: SearchSpace) -> RandomModelBridge: - return RandomModelBridge( + def get_sobol(search_space: SearchSpace) -> RandomAdapter: + return RandomAdapter( search_space=search_space, model=SobolGenerator(), transforms=Cont_X_trans, @@ -721,7 +725,7 @@ def get_sobol(search_space: SearchSpace) -> RandomModelBridge: steps=[GenerationStep(model=get_sobol, num_trials=5)] ) g = sobol_generation_strategy.gen(exp) - self.assertIsInstance(sobol_generation_strategy.model, RandomModelBridge) + self.assertIsInstance(sobol_generation_strategy.model, RandomAdapter) self.assertIsNone(g._model_key) self.assertIsNone(g._model_kwargs) self.assertIsNone(g._bridge_kwargs) @@ -729,7 +733,7 @@ def get_sobol(search_space: SearchSpace) -> RandomModelBridge: def test_store_experiment(self) -> None: exp = get_branin_experiment() sobol_generation_strategy = GenerationStrategy( - steps=[GenerationStep(model=Models.SOBOL, num_trials=5)] + steps=[GenerationStep(model=Generators.SOBOL, num_trials=5)] ) self.assertIsNone(sobol_generation_strategy._experiment) sobol_generation_strategy.gen(exp) @@ -739,8 +743,8 @@ def test_trials_as_df(self) -> None: exp = get_branin_experiment() sobol_generation_strategy = GenerationStrategy( steps=[ - GenerationStep(model=Models.SOBOL, num_trials=2), - GenerationStep(model=Models.SOBOL, num_trials=3), + GenerationStep(model=Generators.SOBOL, num_trials=2), + GenerationStep(model=Generators.SOBOL, num_trials=3), ] ) # No experiment attached to the GS, should be None. @@ -755,7 +759,9 @@ def test_trials_as_df(self) -> None: def test_max_parallelism_reached(self) -> None: exp = get_branin_experiment() sobol_generation_strategy = GenerationStrategy( - steps=[GenerationStep(model=Models.SOBOL, num_trials=5, max_parallelism=1)] + steps=[ + GenerationStep(model=Generators.SOBOL, num_trials=5, max_parallelism=1) + ] ) exp.new_trial( generator_run=sobol_generation_strategy.gen(experiment=exp) @@ -784,7 +790,7 @@ def test_deduplication(self) -> None: name="Sobol", steps=[ GenerationStep( - model=Models.SOBOL, + model=Generators.SOBOL, num_trials=-1, # Disable model-level deduplication. model_kwargs={"deduplicate": False}, @@ -815,12 +821,12 @@ def test_current_generator_run_limit(self) -> None: sobol_gs_with_parallelism_limits = GenerationStrategy( steps=[ GenerationStep( - model=Models.SOBOL, + model=Generators.SOBOL, num_trials=NUM_INIT_TRIALS, min_trials_observed=3, ), GenerationStep( - model=Models.SOBOL, + model=Generators.SOBOL, num_trials=(NUM_ROUNDS - 1) * SECOND_STEP_PARALLELISM, max_parallelism=SECOND_STEP_PARALLELISM, ), @@ -858,12 +864,12 @@ def test_current_generator_run_limit_unlimited_second_step(self) -> None: sobol_gs_with_parallelism_limits = GenerationStrategy( steps=[ GenerationStep( - model=Models.SOBOL, + model=Generators.SOBOL, num_trials=NUM_INIT_TRIALS, min_trials_observed=3, ), GenerationStep( - model=Models.SOBOL, + model=Generators.SOBOL, num_trials=-1, max_parallelism=SECOND_STEP_PARALLELISM, ), @@ -891,9 +897,9 @@ def test_hierarchical_search_space(self) -> None: for _ in range(10): # During each iteration, check that all transformed observation features # contain all parameters of the flat search space. 
- with patch.object( - RandomModelBridge, "_fit" - ) as mock_model_fit, patch.object(RandomModelBridge, "gen"): + with patch.object(RandomAdapter, "_fit") as mock_model_fit, patch.object( + RandomAdapter, "gen" + ): self.sobol_GS.gen(experiment=experiment) mock_model_fit.assert_called_once() observations = mock_model_fit.call_args[1].get("observations") @@ -933,11 +939,11 @@ def test_gen_multiple(self) -> None: sobol_MBM_gs = self.sobol_MBM_step_GS with mock_patch_method_original( - mock_path=f"{ModelSpec.__module__}.ModelSpec.gen", - original_method=ModelSpec.gen, + mock_path=f"{GeneratorSpec.__module__}.GeneratorSpec.gen", + original_method=GeneratorSpec.gen, ) as model_spec_gen_mock, mock_patch_method_original( - mock_path=f"{ModelSpec.__module__}.ModelSpec.fit", - original_method=ModelSpec.fit, + mock_path=f"{GeneratorSpec.__module__}.GeneratorSpec.fit", + original_method=GeneratorSpec.fit, ) as model_spec_fit_mock: # Generate first four Sobol GRs (one more to gen after that if # first four become trials. @@ -1018,8 +1024,8 @@ def test_gen_for_multiple_trials_with_multiple_models(self) -> None: sobol_MBM_gs = self.sobol_MBM_step_GS sobol_MBM_gs.experiment = exp with mock_patch_method_original( - mock_path=f"{ModelSpec.__module__}.ModelSpec.gen", - original_method=ModelSpec.gen, + mock_path=f"{GeneratorSpec.__module__}.GeneratorSpec.gen", + original_method=GeneratorSpec.gen, ) as model_spec_gen_mock: # Generate first four Sobol GRs (one more to gen after that if # first four become trials. @@ -1091,17 +1097,17 @@ def test_gen_for_multiple_trials_with_multiple_models_with_fixed_features( gs = GenerationStrategy( steps=[ GenerationStep( - model=Models.SOBOL, + model=Generators.SOBOL, num_trials=1, model_kwargs=self.step_model_kwargs, ), GenerationStep( - model=Models.BO_MIXED, + model=Generators.BO_MIXED, num_trials=1, model_kwargs=self.step_model_kwargs, ), GenerationStep( - model=Models.BOTORCH_MODULAR, + model=Generators.BOTORCH_MODULAR, model_kwargs={ # this will cause an error if the model # doesn't get fixed features @@ -1196,12 +1202,12 @@ def test_gs_setup_with_nodes(self) -> None: ], steps=[ GenerationStep( - model=Models.SOBOL, + model=Generators.SOBOL, num_trials=5, model_kwargs=self.step_model_kwargs, ), GenerationStep( - model=Models.BOTORCH_MODULAR, + model=Generators.BOTORCH_MODULAR, num_trials=-1, model_kwargs=self.step_model_kwargs, ), @@ -1254,7 +1260,7 @@ def test_gs_setup_with_nodes(self) -> None: nodes=[ node_1, GenerationStep( - model=Models.SOBOL, + model=Generators.SOBOL, num_trials=5, model_kwargs=self.step_model_kwargs, ), @@ -1376,8 +1382,8 @@ def test_gen_with_multiple_nodes_pending_points(self) -> None: "sobol_3": 3, } with mock_patch_method_original( - mock_path=f"{ModelSpec.__module__}.ModelSpec.gen", - original_method=ModelSpec.gen, + mock_path=f"{GeneratorSpec.__module__}.GeneratorSpec.gen", + original_method=GeneratorSpec.gen, ) as model_spec_gen_mock: # Generate a trial that should be composed of arms from 3 nodes grs = gs._gen_with_multiple_nodes( @@ -1880,8 +1886,8 @@ def test_gs_with_fixed_features_constructor(self) -> None: trial0.run() # necessary for transition criterion to be met with self.subTest("no passed fixed features gen_with_multiple_nodes"): with mock_patch_method_original( - mock_path=f"{ModelSpec.__module__}.ModelSpec.gen", - original_method=ModelSpec.gen, + mock_path=f"{GeneratorSpec.__module__}.GeneratorSpec.gen", + original_method=GeneratorSpec.gen, ) as model_spec_gen_mock: exp.new_batch_trial( 
generator_runs=gs._gen_with_multiple_nodes(exp, n=9) @@ -1897,8 +1903,8 @@ def test_gs_with_fixed_features_constructor(self) -> None: with self.subTest("passed fixed features gen_with_multiple_nodes"): with mock_patch_method_original( - mock_path=f"{ModelSpec.__module__}.ModelSpec.gen", - original_method=ModelSpec.gen, + mock_path=f"{GeneratorSpec.__module__}.GeneratorSpec.gen", + original_method=GeneratorSpec.gen, ) as model_spec_gen_mock: passed_fixed_features = ObservationFeatures( parameters={}, trial_index=4 @@ -1921,8 +1927,8 @@ def test_gs_with_fixed_features_constructor(self) -> None: "no passed fixed features gen_for_multiple_trials_with_multiple_nodes" ): with mock_patch_method_original( - mock_path=f"{ModelSpec.__module__}.ModelSpec.gen", - original_method=ModelSpec.gen, + mock_path=f"{GeneratorSpec.__module__}.GeneratorSpec.gen", + original_method=GeneratorSpec.gen, ) as model_spec_gen_mock: exp.new_batch_trial( generator_runs=gs.gen_for_multiple_trials_with_multiple_models( @@ -1942,8 +1948,8 @@ def test_gs_with_fixed_features_constructor(self) -> None: "passed fixed features gen_for_multiple_trials_with_multiple_nodes" ): with mock_patch_method_original( - mock_path=f"{ModelSpec.__module__}.ModelSpec.gen", - original_method=ModelSpec.gen, + mock_path=f"{GeneratorSpec.__module__}.GeneratorSpec.gen", + original_method=GeneratorSpec.gen, ) as model_spec_gen_mock: passed_fixed_features = ObservationFeatures( parameters={}, trial_index=4 diff --git a/ax/modelbridge/tests/test_hierarchical_search_space.py b/ax/modelbridge/tests/test_hierarchical_search_space.py index 4417d29572b..137798c5a26 100644 --- a/ax/modelbridge/tests/test_hierarchical_search_space.py +++ b/ax/modelbridge/tests/test_hierarchical_search_space.py @@ -24,7 +24,7 @@ from ax.core.trial import Trial from ax.metrics.noisy_function import GenericNoisyFunctionMetric from ax.modelbridge.cross_validation import cross_validate -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.runners.synthetic import SyntheticRunner from ax.utils.common.constants import Keys from ax.utils.common.testutils import TestCase @@ -152,13 +152,13 @@ def _test_gen_base( runner=SyntheticRunner(), ) - sobol = Models.SOBOL(search_space=hss) + sobol = Generators.SOBOL(search_space=hss) for _ in range(num_sobol_trials): trial = experiment.new_trial(generator_run=sobol.gen(n=1)) trial.run().mark_completed() for _ in range(num_bo_trials): - mbm = Models.BOTORCH_MODULAR( + mbm = Generators.BOTORCH_MODULAR( experiment=experiment, data=experiment.fetch_data() ) trial = experiment.new_trial(generator_run=mbm.gen(n=1)) @@ -190,7 +190,7 @@ def _base_test_predict_and_cv( `expect_errors_with_final_parameterization` arg is used to handle the `KeyError` that is expected (but should be fixed) in this setting. 
""" - mbm = Models.BOTORCH_MODULAR( + mbm = Generators.BOTORCH_MODULAR( experiment=experiment, data=experiment.fetch_data() ) for t in experiment.trials.values(): diff --git a/ax/modelbridge/tests/test_map_torch_modelbridge.py b/ax/modelbridge/tests/test_map_torch_modelbridge.py index 74462063e11..7ade1db5386 100644 --- a/ax/modelbridge/tests/test_map_torch_modelbridge.py +++ b/ax/modelbridge/tests/test_map_torch_modelbridge.py @@ -18,8 +18,8 @@ ObservationFeatures, recombine_observations, ) -from ax.modelbridge.map_torch import MapTorchModelBridge -from ax.models.torch_base import TorchGenResults, TorchModel +from ax.modelbridge.map_torch import MapTorchAdapter +from ax.models.torch_base import TorchGenerator, TorchGenResults from ax.utils.common.constants import Keys from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( @@ -28,8 +28,8 @@ ) -class MapTorchModelBridgeTest(TestCase): - def test_TorchModelBridge(self) -> None: +class MapTorchAdapterTest(TestCase): + def test_TorchAdapter(self) -> None: experiment = get_branin_experiment_with_timestamp_map_metric(rate=0.5) for i in range(3): trial = experiment.new_trial().add_arm(arm=get_branin_arms(n=1, seed=i)[0]) @@ -43,8 +43,8 @@ def test_TorchModelBridge(self) -> None: experiment.trials[i].mark_as(status=TrialStatus.COMPLETED) experiment.attach_data(data=experiment.fetch_data()) - model = mock.MagicMock(TorchModel, autospec=True, instance=True) - modelbridge = MapTorchModelBridge( + model = mock.MagicMock(TorchGenerator, autospec=True, instance=True) + modelbridge = MapTorchAdapter( experiment=experiment, search_space=experiment.search_space, data=experiment.lookup_data(), @@ -69,7 +69,7 @@ def test_TorchModelBridge(self) -> None: self.assertEqual(len(modelbridge.get_training_data()), len(objective_df)) # Test _gen - model = mock.MagicMock(TorchModel, autospec=True, instance=True) + model = mock.MagicMock(TorchGenerator, autospec=True, instance=True) model.gen.return_value = TorchGenResults( points=torch.tensor([[0.0, 0.0, 0.0]]), weights=torch.tensor([1.0]), @@ -99,7 +99,7 @@ def test_TorchModelBridge(self) -> None: ) # Test _predict - model = mock.MagicMock(TorchModel, autospec=True, instance=True) + model = mock.MagicMock(TorchGenerator, autospec=True, instance=True) model.predict.return_value = ( torch.tensor([[0.0, 0.0]]), torch.tensor([[[1.0, 0.0], [0.0, 1.0]]]), @@ -154,7 +154,7 @@ def test_TorchModelBridge(self) -> None: ] cv_training_data = recombine_observations(features, data) with mock.patch( - "ax.modelbridge.torch.TorchModelBridge._cross_validate", + "ax.modelbridge.torch.TorchAdapter._cross_validate", return_value=test_data, ): cv_obs_data = modelbridge._cross_validate( diff --git a/ax/modelbridge/tests/test_model_fit_metrics.py b/ax/modelbridge/tests/test_model_fit_metrics.py index 41e35be047d..c425bf925cc 100644 --- a/ax/modelbridge/tests/test_model_fit_metrics.py +++ b/ax/modelbridge/tests/test_model_fit_metrics.py @@ -22,7 +22,7 @@ get_fit_and_std_quality_and_generalization_dict, ) from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.runners.synthetic import SyntheticRunner from ax.service.scheduler import get_fitted_model_bridge, Scheduler, SchedulerOptions from ax.utils.common.constants import Keys @@ -33,7 +33,7 @@ NUM_SOBOL = 5 -class TestModelBridgeFitMetrics(TestCase): +class TestAdapterFitMetrics(TestCase): def setUp(self) -> None: super().setUp() # 
setting up experiment and generation strategy @@ -56,9 +56,11 @@ def setUp(self) -> None: self.generation_strategy = GenerationStrategy( steps=[ GenerationStep( - model=Models.SOBOL, num_trials=NUM_SOBOL, max_parallelism=NUM_SOBOL + model=Generators.SOBOL, + num_trials=NUM_SOBOL, + max_parallelism=NUM_SOBOL, ), - GenerationStep(model=Models.BOTORCH_MODULAR, num_trials=-1), + GenerationStep(model=Generators.BOTORCH_MODULAR, num_trials=-1), ] ) @@ -68,7 +70,7 @@ def test_model_fit_metrics(self) -> None: generation_strategy=self.generation_strategy, options=SchedulerOptions(), ) - # need to run some trials to initialize the ModelBridge + # need to run some trials to initialize the Adapter scheduler.run_n_trials(max_trials=NUM_SOBOL + 1) model_bridge = get_fitted_model_bridge(scheduler) @@ -148,7 +150,7 @@ class TestGetFitAndStdQualityAndGeneralizationDict(TestCase): def setUp(self) -> None: super().setUp() self.experiment = get_branin_experiment() - self.sobol = Models.SOBOL(search_space=self.experiment.search_space) + self.sobol = Generators.SOBOL(search_space=self.experiment.search_space) def test_it_returns_empty_data_for_sobol(self) -> None: results = get_fit_and_std_quality_and_generalization_dict( @@ -169,7 +171,7 @@ def test_it_returns_float_values_when_fit_can_be_evaluated(self) -> None: sobol_run ).run().mark_completed() data = self.experiment.fetch_data() - botorch_modelbridge = Models.BOTORCH_MODULAR( + botorch_modelbridge = Generators.BOTORCH_MODULAR( experiment=self.experiment, data=data ) diff --git a/ax/modelbridge/tests/test_model_spec.py b/ax/modelbridge/tests/test_model_spec.py index a8690d1d022..15911344919 100644 --- a/ax/modelbridge/tests/test_model_spec.py +++ b/ax/modelbridge/tests/test_model_spec.py @@ -13,20 +13,20 @@ from ax.core.observation import ObservationFeatures from ax.exceptions.core import UserInputError from ax.modelbridge.factory import get_sobol -from ax.modelbridge.model_spec import FactoryFunctionModelSpec, ModelSpec +from ax.modelbridge.model_spec import FactoryFunctionGeneratorSpec, GeneratorSpec from ax.modelbridge.modelbridge_utils import extract_search_space_digest -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import get_branin_experiment from ax.utils.testing.mock import mock_botorch_optimize from pyre_extensions import none_throws -class BaseModelSpecTest(TestCase): +class BaseGeneratorSpecTest(TestCase): def setUp(self) -> None: super().setUp() self.experiment = get_branin_experiment() - sobol = Models.SOBOL(search_space=self.experiment.search_space) + sobol = Generators.SOBOL(search_space=self.experiment.search_space) sobol_run = sobol.gen(n=20) self.experiment.new_batch_trial().add_generator_run( sobol_run @@ -34,10 +34,10 @@ def setUp(self) -> None: self.data = self.experiment.fetch_data() -class ModelSpecTest(BaseModelSpecTest): +class GeneratorSpecTest(BaseGeneratorSpecTest): @mock_botorch_optimize def test_construct(self) -> None: - ms = ModelSpec(model_enum=Models.BOTORCH_MODULAR) + ms = GeneratorSpec(model_enum=Generators.BOTORCH_MODULAR) with self.assertRaises(UserInputError): ms.gen(n=1) ms.fit(experiment=self.experiment, data=self.data) @@ -45,13 +45,13 @@ def test_construct(self) -> None: @mock_botorch_optimize # We can use `extract_search_space_digest` as a surrogate for executing - # the full TorchModelBridge._fit. + # the full TorchAdapter._fit. 
@mock.patch( "ax.modelbridge.torch.extract_search_space_digest", wraps=extract_search_space_digest, ) def test_fit(self, wrapped_extract_ssd: Mock) -> None: - ms = ModelSpec(model_enum=Models.BOTORCH_MODULAR) + ms = GeneratorSpec(model_enum=Generators.BOTORCH_MODULAR) # This should fit the model as usual. ms.fit(experiment=self.experiment, data=self.data) wrapped_extract_ssd.assert_called_once() @@ -67,15 +67,18 @@ def test_fit(self, wrapped_extract_ssd: Mock) -> None: wrapped_extract_ssd.assert_called_once() def test_model_key(self) -> None: - ms = ModelSpec(model_enum=Models.BOTORCH_MODULAR) + ms = GeneratorSpec(model_enum=Generators.BOTORCH_MODULAR) self.assertEqual(ms.model_key, "BoTorch") - ms = ModelSpec( - model_enum=Models.BOTORCH_MODULAR, model_key_override="MBM with defaults" + ms = GeneratorSpec( + model_enum=Generators.BOTORCH_MODULAR, + model_key_override="MBM with defaults", ) self.assertEqual(ms.model_key, "MBM with defaults") - @patch(f"{ModelSpec.__module__}.compute_diagnostics") - @patch(f"{ModelSpec.__module__}.cross_validate", return_value=["fake-cv-result"]) + @patch(f"{GeneratorSpec.__module__}.compute_diagnostics") + @patch( + f"{GeneratorSpec.__module__}.cross_validate", return_value=["fake-cv-result"] + ) def test_cross_validate_with_GP_model( self, mock_cv: Mock, mock_diagnostics: Mock ) -> None: @@ -83,7 +86,9 @@ def test_cross_validate_with_GP_model( fake_mb = MagicMock() fake_mb._process_and_transform_data = MagicMock(return_value=(None, None)) mock_enum.return_value = fake_mb - ms = ModelSpec(model_enum=mock_enum, model_cv_kwargs={"test_key": "test-value"}) + ms = GeneratorSpec( + model_enum=mock_enum, model_cv_kwargs={"test_key": "test-value"} + ) ms.fit( experiment=self.experiment, data=self.experiment.trials[0].fetch_data(), @@ -135,14 +140,18 @@ def test_cross_validate_with_GP_model( mock_cv.assert_called_with(model=fake_mb, test_key="test-value", test=1) self.assertEqual(ms._last_cv_kwargs, {"test": 1, "test_key": "test-value"}) - @patch(f"{ModelSpec.__module__}.compute_diagnostics") - @patch(f"{ModelSpec.__module__}.cross_validate", side_effect=NotImplementedError) + @patch(f"{GeneratorSpec.__module__}.compute_diagnostics") + @patch( + f"{GeneratorSpec.__module__}.cross_validate", side_effect=NotImplementedError + ) def test_cross_validate_with_non_GP_model( self, mock_cv: Mock, mock_diagnostics: Mock ) -> None: mock_enum = Mock() mock_enum.return_value = "fake-modelbridge" - ms = ModelSpec(model_enum=mock_enum, model_cv_kwargs={"test_key": "test-value"}) + ms = GeneratorSpec( + model_enum=mock_enum, model_cv_kwargs={"test_key": "test-value"} + ) ms.fit( experiment=self.experiment, data=self.experiment.trials[0].fetch_data(), @@ -159,7 +168,7 @@ def test_cross_validate_with_non_GP_model( mock_diagnostics.assert_not_called() def test_fixed_features(self) -> None: - ms = ModelSpec(model_enum=Models.BOTORCH_MODULAR) + ms = GeneratorSpec(model_enum=Generators.BOTORCH_MODULAR) self.assertIsNone(ms.fixed_features) new_features = ObservationFeatures(parameters={"a": 1.0}) ms.fixed_features = new_features @@ -167,7 +176,7 @@ def test_fixed_features(self) -> None: self.assertEqual(ms.model_gen_kwargs["fixed_features"], new_features) def test_gen_attaches_empty_model_fit_metadata_if_fit_not_applicable(self) -> None: - ms = ModelSpec(model_enum=Models.SOBOL) + ms = GeneratorSpec(model_enum=Generators.SOBOL) ms.fit(experiment=self.experiment, data=self.data) gr = ms.gen(n=1) gen_metadata = none_throws(gr.gen_metadata) @@ -177,7 +186,7 @@ def 
test_gen_attaches_empty_model_fit_metadata_if_fit_not_applicable(self) -> No self.assertEqual(gen_metadata["model_std_generalization"], None) def test_gen_attaches_model_fit_metadata_if_applicable(self) -> None: - ms = ModelSpec(model_enum=Models.BOTORCH_MODULAR) + ms = GeneratorSpec(model_enum=Generators.BOTORCH_MODULAR) ms.fit(experiment=self.experiment, data=self.data) gr = ms.gen(n=1) gen_metadata = none_throws(gr.gen_metadata) @@ -187,8 +196,8 @@ def test_gen_attaches_model_fit_metadata_if_applicable(self) -> None: self.assertIsInstance(gen_metadata["model_std_generalization"], float) def test_spec_string_representation(self) -> None: - ms = ModelSpec( - model_enum=Models.BOTORCH_MODULAR, + ms = GeneratorSpec( + model_enum=Generators.BOTORCH_MODULAR, model_kwargs={"test_model_kwargs": 1}, model_gen_kwargs={"test_gen_kwargs": 1}, model_cv_kwargs={"test_cv_kwargs": 1}, @@ -204,21 +213,21 @@ def test_spec_string_representation(self) -> None: self.assertIn("test_model_key_override", repr_str) -class FactoryFunctionModelSpecTest(BaseModelSpecTest): +class FactoryFunctionGeneratorSpecTest(BaseGeneratorSpecTest): def test_construct(self) -> None: - ms = FactoryFunctionModelSpec(factory_function=get_sobol) + ms = FactoryFunctionGeneratorSpec(factory_function=get_sobol) with self.assertRaises(UserInputError): ms.gen(n=1) ms.fit(experiment=self.experiment, data=self.data) ms.gen(n=1) def test_model_key(self) -> None: - ms = FactoryFunctionModelSpec(factory_function=get_sobol) + ms = FactoryFunctionGeneratorSpec(factory_function=get_sobol) self.assertEqual(ms.model_key, "get_sobol") with self.assertRaisesRegex(TypeError, "cannot extract name"): # pyre-ignore[6] - Invalid factory function for testing. - FactoryFunctionModelSpec(factory_function="test") - ms = FactoryFunctionModelSpec( + FactoryFunctionGeneratorSpec(factory_function="test") + ms = FactoryFunctionGeneratorSpec( factory_function=get_sobol, model_key_override="fancy sobol" ) self.assertEqual(ms.model_key, "fancy sobol") diff --git a/ax/modelbridge/tests/test_modelbridge_utils.py b/ax/modelbridge/tests/test_modelbridge_utils.py index 8222bc9b73a..410a28b94dc 100644 --- a/ax/modelbridge/tests/test_modelbridge_utils.py +++ b/ax/modelbridge/tests/test_modelbridge_utils.py @@ -38,16 +38,16 @@ from pyre_extensions import none_throws -class TestModelBridgeUtils(TestCase): +class TestAdapterUtils(TestCase): def test__array_to_tensor(self) -> None: - from ax.modelbridge import ModelBridge + from ax.modelbridge import Adapter @dataclass - class MockModelbridge(ModelBridge): + class MockAdapter(Adapter): def _array_to_tensor(self, array: npt.NDArray | list[float]): return _array_to_tensor(array=array) - mock_modelbridge = MockModelbridge() + mock_modelbridge = MockAdapter() arr = [0.0] res = _array_to_tensor(array=arr) self.assertEqual(len(res.size()), 1) @@ -351,7 +351,7 @@ def test_process_contextual_datasets(self) -> None: self.assertIsInstance(d, ContextualDataset) def test_extract_search_space_digest(self) -> None: - # This is also tested as part of broader TorchModelBridge tests. + # This is also tested as part of broader TorchAdapter tests. # Test log & logit scale parameters. 
for log_scale, logit_scale in [(True, False), (False, True)]: ss = SearchSpace( diff --git a/ax/modelbridge/tests/test_pairwise_modelbridge.py b/ax/modelbridge/tests/test_pairwise_modelbridge.py index c5633c0aa4d..3600cd05288 100644 --- a/ax/modelbridge/tests/test_pairwise_modelbridge.py +++ b/ax/modelbridge/tests/test_pairwise_modelbridge.py @@ -13,9 +13,9 @@ from ax.modelbridge.pairwise import ( _binary_pref_to_comp_pair, _consolidate_comparisons, - PairwiseModelBridge, + PairwiseAdapter, ) -from ax.models.torch.botorch_modular.model import BoTorchModel +from ax.models.torch.botorch_modular.model import BoTorchGenerator from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.utils.common.constants import Keys from ax.utils.common.testutils import TestCase @@ -31,7 +31,7 @@ from pyre_extensions import assert_is_instance -class PairwiseModelBridgeTest(TestCase): +class PairwiseAdapterTest(TestCase): def setUp(self) -> None: super().setUp() experiment = get_pbo_experiment() @@ -41,7 +41,7 @@ def setUp(self) -> None: @TestCase.ax_long_test( reason="TODO[T199510629] Fix: break up test into one test per case" ) - def test_PairwiseModelBridge(self) -> None: + def test_PairwiseAdapter(self) -> None: surrogate = Surrogate( botorch_model_class=PairwiseGP, mll_class=PairwiseLaplaceMarginalLogLikelihood, @@ -62,11 +62,11 @@ def test_PairwiseModelBridge(self) -> None: ), ] for botorch_acqf_class, model_gen_options, n in cases: - pmb = PairwiseModelBridge( + pmb = PairwiseAdapter( experiment=self.experiment, search_space=self.experiment.search_space, data=self.data, - model=BoTorchModel( + model=BoTorchGenerator( botorch_acqf_class=botorch_acqf_class, surrogate=surrogate, ), diff --git a/ax/modelbridge/tests/test_prediction_utils.py b/ax/modelbridge/tests/test_prediction_utils.py index 0354cdffe47..9b30fa7d47c 100644 --- a/ax/modelbridge/tests/test_prediction_utils.py +++ b/ax/modelbridge/tests/test_prediction_utils.py @@ -81,8 +81,8 @@ def test_predict_by_features(self) -> None: ) self.assertEqual(len(predictions_map), 3) - @mock.patch("ax.modelbridge.random.RandomModelBridge.predict") - @mock.patch("ax.modelbridge.random.RandomModelBridge") + @mock.patch("ax.modelbridge.random.RandomAdapter.predict") + @mock.patch("ax.modelbridge.random.RandomAdapter") # pyre-fixme[3]: Return type must be annotated. 
def test_predict_by_features_with_non_predicting_model( self, diff --git a/ax/modelbridge/tests/test_random_modelbridge.py b/ax/modelbridge/tests/test_random_modelbridge.py index b1c890a0796..c8ba5af13dc 100644 --- a/ax/modelbridge/tests/test_random_modelbridge.py +++ b/ax/modelbridge/tests/test_random_modelbridge.py @@ -22,15 +22,15 @@ ) from ax.core.search_space import SearchSpace from ax.exceptions.core import SearchSpaceExhausted -from ax.modelbridge.random import RandomModelBridge +from ax.modelbridge.random import RandomAdapter from ax.modelbridge.registry import Cont_X_trans -from ax.models.random.base import RandomModel +from ax.models.random.base import RandomGenerator from ax.models.random.sobol import SobolGenerator from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import get_data, get_small_discrete_search_space -class RandomModelBridgeTest(TestCase): +class RandomAdapterTest(TestCase): def setUp(self) -> None: super().setUp() x = RangeParameter("x", ParameterType.FLOAT, lower=0, upper=1) @@ -44,41 +44,41 @@ def setUp(self) -> None: self.search_space = SearchSpace(self.parameters, parameter_constraints) self.model_gen_options = {"option": "yes"} - @mock.patch("ax.modelbridge.random.RandomModelBridge.__init__", return_value=None) + @mock.patch("ax.modelbridge.random.RandomAdapter.__init__", return_value=None) def test_Fit(self, mock_init: mock.Mock) -> None: # pyre-fixme[20]: Argument `model` expected. - modelbridge = RandomModelBridge() - model = mock.create_autospec(RandomModel, instance=True) + modelbridge = RandomAdapter() + model = mock.create_autospec(RandomGenerator, instance=True) modelbridge._fit(model, self.search_space, None) self.assertEqual(modelbridge.parameters, ["x", "y", "z"]) - self.assertTrue(isinstance(modelbridge.model, RandomModel)) + self.assertTrue(isinstance(modelbridge.model, RandomGenerator)) - @mock.patch("ax.modelbridge.random.RandomModelBridge.__init__", return_value=None) + @mock.patch("ax.modelbridge.random.RandomAdapter.__init__", return_value=None) def test_Predict(self, mock_init: mock.Mock) -> None: # pyre-fixme[20]: Argument `model` expected. - modelbridge = RandomModelBridge() + modelbridge = RandomAdapter() modelbridge.transforms = OrderedDict() modelbridge.parameters = ["x", "y", "z"] with self.assertRaises(NotImplementedError): modelbridge._predict([]) - @mock.patch("ax.modelbridge.random.RandomModelBridge.__init__", return_value=None) + @mock.patch("ax.modelbridge.random.RandomAdapter.__init__", return_value=None) def test_CrossValidate(self, mock_init: mock.Mock) -> None: # pyre-fixme[20]: Argument `model` expected. - modelbridge = RandomModelBridge() + modelbridge = RandomAdapter() modelbridge.transforms = OrderedDict() modelbridge.parameters = ["x", "y", "z"] with self.assertRaises(NotImplementedError): modelbridge._cross_validate(self.search_space, [], []) - @mock.patch("ax.modelbridge.random.RandomModelBridge.__init__", return_value=None) + @mock.patch("ax.modelbridge.random.RandomAdapter.__init__", return_value=None) def test_Gen(self, mock_init: mock.Mock) -> None: # Test with constraints # pyre-fixme[20]: Argument `model` expected. 
- modelbridge = RandomModelBridge(model=RandomModel()) + modelbridge = RandomAdapter(model=RandomGenerator()) modelbridge.parameters = ["x", "y", "z"] modelbridge.transforms = OrderedDict() - modelbridge.model = RandomModel() + modelbridge.model = RandomGenerator() with mock.patch.object( modelbridge.model, "gen", @@ -144,7 +144,7 @@ def test_Gen(self, mock_init: mock.Mock) -> None: self.assertIsNone(gen_args["fixed_features"]) def test_deduplicate(self) -> None: - sobol = RandomModelBridge( + sobol = RandomAdapter( search_space=get_small_discrete_search_space(), model=SobolGenerator(deduplicate=True), transforms=Cont_X_trans, @@ -166,7 +166,7 @@ def test_search_space_not_expanded(self) -> None: trial.mark_running(no_runner_required=True) trial.mark_completed() experiment.add_tracking_metric(metric=Metric("ax_test_metric")) - sobol = RandomModelBridge( + sobol = RandomAdapter( search_space=self.search_space, model=SobolGenerator(), experiment=experiment, diff --git a/ax/modelbridge/tests/test_registry.py b/ax/modelbridge/tests/test_registry.py index d6afccebbd1..11a8e3fc6f6 100644 --- a/ax/modelbridge/tests/test_registry.py +++ b/ax/modelbridge/tests/test_registry.py @@ -10,24 +10,25 @@ from ax.core.observation import ObservationFeatures from ax.core.optimization_config import MultiObjectiveOptimizationConfig -from ax.modelbridge.discrete import DiscreteModelBridge -from ax.modelbridge.random import RandomModelBridge +from ax.modelbridge.discrete import DiscreteAdapter +from ax.modelbridge.random import RandomAdapter from ax.modelbridge.registry import ( _extract_model_state_after_gen, Cont_X_trans, + Generators, get_model_from_generator_run, MODEL_KEY_TO_MODEL_SETUP, Models, Y_trans, ) -from ax.modelbridge.torch import TorchModelBridge -from ax.models.base import Model +from ax.modelbridge.torch import TorchAdapter +from ax.models.base import Generator from ax.models.discrete.eb_thompson import EmpiricalBayesThompsonSampler from ax.models.discrete.thompson import ThompsonSampler from ax.models.random.sobol import SobolGenerator from ax.models.torch.botorch_modular.acquisition import Acquisition from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel -from ax.models.torch.botorch_modular.model import BoTorchModel +from ax.models.torch.botorch_modular.model import BoTorchGenerator from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.utils.common.kwargs import get_function_argument_names from ax.utils.common.testutils import TestCase @@ -62,7 +63,7 @@ def setUp(self) -> None: def test_botorch_modular(self) -> None: exp = get_branin_experiment(with_batch=True) exp.trials[0].run() - gpei = Models.BOTORCH_MODULAR( + gpei = Generators.BOTORCH_MODULAR( # Model kwargs acquisition_class=Acquisition, botorch_acqf_class=qExpectedImprovement, @@ -71,8 +72,8 @@ def test_botorch_modular(self) -> None: experiment=exp, data=exp.fetch_data(), ) - self.assertIsInstance(gpei, TorchModelBridge) - self.assertIsInstance(gpei.model, BoTorchModel) + self.assertIsInstance(gpei, TorchAdapter) + self.assertIsInstance(gpei.model, BoTorchGenerator) self.assertEqual(gpei.model.botorch_acqf_class, qExpectedImprovement) self.assertEqual(gpei.model.acquisition_class, Acquisition) self.assertEqual(gpei.model.acquisition_options, {"best_f": 0.0}) @@ -86,16 +87,16 @@ def test_botorch_modular(self) -> None: @mock_botorch_optimize def test_SAASBO(self) -> None: exp = get_branin_experiment() - sobol = Models.SOBOL(search_space=exp.search_space) - self.assertIsInstance(sobol, 
RandomModelBridge) + sobol = Generators.SOBOL(search_space=exp.search_space) + self.assertIsInstance(sobol, RandomAdapter) for _ in range(5): sobol_run = sobol.gen(n=1) self.assertEqual(sobol_run._model_key, "Sobol") exp.new_batch_trial().add_generator_run(sobol_run).run() - saasbo = Models.SAASBO(experiment=exp, data=exp.fetch_data()) - self.assertIsInstance(saasbo, TorchModelBridge) + saasbo = Generators.SAASBO(experiment=exp, data=exp.fetch_data()) + self.assertIsInstance(saasbo, TorchAdapter) self.assertEqual(saasbo._model_key, "SAASBO") - self.assertIsInstance(saasbo.model, BoTorchModel) + self.assertIsInstance(saasbo.model, BoTorchGenerator) surrogate_spec = saasbo.model.surrogate_spec self.assertEqual( surrogate_spec, @@ -108,19 +109,19 @@ def test_SAASBO(self) -> None: @mock_botorch_optimize def test_enum_sobol_legacy_GPEI(self) -> None: - """Tests Sobol and Legacy GPEI instantiation through the Models enum.""" + """Tests Sobol and Legacy GPEI instantiation through the Generators enum.""" exp = get_branin_experiment() # Check that factory generates a valid sobol modelbridge. - sobol = Models.SOBOL(search_space=exp.search_space) - self.assertIsInstance(sobol, RandomModelBridge) + sobol = Generators.SOBOL(search_space=exp.search_space) + self.assertIsInstance(sobol, RandomAdapter) for _ in range(5): sobol_run = sobol.gen(n=1) self.assertEqual(sobol_run._model_key, "Sobol") exp.new_batch_trial().add_generator_run(sobol_run).run() # Check that factory generates a valid GP+EI modelbridge. exp.optimization_config = get_branin_optimization_config() - gpei = Models.LEGACY_BOTORCH(experiment=exp, data=exp.fetch_data()) - self.assertIsInstance(gpei, TorchModelBridge) + gpei = Generators.LEGACY_BOTORCH(experiment=exp, data=exp.fetch_data()) + self.assertIsInstance(gpei, TorchAdapter) self.assertEqual(gpei._model_key, "Legacy_GPEI") botorch_defaults = "ax.models.torch.botorch_defaults" # Check that the callable kwargs and the torch kwargs were recorded. 
@@ -173,13 +174,13 @@ def test_enum_sobol_legacy_GPEI(self) -> None: }, ) prior_kwargs = {"lengthscale_prior": GammaPrior(6.0, 6.0)} - gpei = Models.LEGACY_BOTORCH( + gpei = Generators.LEGACY_BOTORCH( experiment=exp, data=exp.fetch_data(), search_space=exp.search_space, prior=prior_kwargs, ) - self.assertIsInstance(gpei, TorchModelBridge) + self.assertIsInstance(gpei, TorchAdapter) self.assertEqual( gpei._model_kwargs["prior"], # pyre-ignore prior_kwargs, @@ -187,56 +188,56 @@ def test_enum_sobol_legacy_GPEI(self) -> None: def test_enum_model_kwargs(self) -> None: """Tests that kwargs are passed correctly when instantiating through the - Models enum.""" + Generators enum.""" exp = get_branin_experiment() - sobol = Models.SOBOL( + sobol = Generators.SOBOL( search_space=exp.search_space, init_position=2, scramble=False, seed=239 ) - self.assertIsInstance(sobol, RandomModelBridge) + self.assertIsInstance(sobol, RandomAdapter) for _ in range(5): sobol_run = sobol.gen(1) exp.new_batch_trial().add_generator_run(sobol_run).run() def test_enum_factorial(self) -> None: - """Tests factorial instantiation through the Models enum.""" + """Tests factorial instantiation through the Generators enum.""" exp = get_factorial_experiment() - factorial = Models.FACTORIAL(exp.search_space) - self.assertIsInstance(factorial, DiscreteModelBridge) + factorial = Generators.FACTORIAL(exp.search_space) + self.assertIsInstance(factorial, DiscreteAdapter) factorial_run = factorial.gen(n=-1) self.assertEqual(len(factorial_run.arms), 24) def test_enum_empirical_bayes_thompson(self) -> None: - """Tests EB/TS instantiation through the Models enum.""" + """Tests EB/TS instantiation through the Generators enum.""" exp = get_factorial_experiment() - factorial = Models.FACTORIAL(exp.search_space) - self.assertIsInstance(factorial, DiscreteModelBridge) + factorial = Generators.FACTORIAL(exp.search_space) + self.assertIsInstance(factorial, DiscreteAdapter) factorial_run = factorial.gen(n=-1) exp.new_batch_trial().add_generator_run(factorial_run).run().mark_completed() data = exp.fetch_data() - eb_thompson = Models.EMPIRICAL_BAYES_THOMPSON( + eb_thompson = Generators.EMPIRICAL_BAYES_THOMPSON( experiment=exp, data=data, min_weight=0.0 ) - self.assertIsInstance(eb_thompson, DiscreteModelBridge) + self.assertIsInstance(eb_thompson, DiscreteAdapter) self.assertIsInstance(eb_thompson.model, EmpiricalBayesThompsonSampler) thompson_run = eb_thompson.gen(n=5) self.assertEqual(len(thompson_run.arms), 5) def test_enum_thompson(self) -> None: - """Tests TS instantiation through the Models enum.""" + """Tests TS instantiation through the Generators enum.""" exp = get_factorial_experiment() - factorial = Models.FACTORIAL(exp.search_space) - self.assertIsInstance(factorial, DiscreteModelBridge) + factorial = Generators.FACTORIAL(exp.search_space) + self.assertIsInstance(factorial, DiscreteAdapter) factorial_run = factorial.gen(n=-1) exp.new_batch_trial().add_generator_run(factorial_run).run().mark_completed() data = exp.fetch_data() - thompson = Models.THOMPSON(experiment=exp, data=data) + thompson = Generators.THOMPSON(experiment=exp, data=data) self.assertIsInstance(thompson.model, ThompsonSampler) def test_enum_uniform(self) -> None: - """Tests uniform random instantiation through the Models enum.""" + """Tests uniform random instantiation through the Generators enum.""" exp = get_branin_experiment() - uniform = Models.UNIFORM(exp.search_space) - self.assertIsInstance(uniform, RandomModelBridge) + uniform = 
Generators.UNIFORM(exp.search_space) + self.assertIsInstance(uniform, RandomAdapter) uniform_run = uniform.gen(n=5) self.assertEqual(len(uniform_run.arms), 5) @@ -244,7 +245,7 @@ def test_view_defaults(self) -> None: """Checks that kwargs are correctly constructed from default kwargs + standard kwargs.""" self.assertEqual( - Models.SOBOL.view_defaults(), + Generators.SOBOL.view_defaults(), ( { "seed": None, @@ -269,11 +270,11 @@ def test_view_defaults(self) -> None: ) self.assertTrue( all( - kw in Models.SOBOL.view_kwargs()[0] + kw in Generators.SOBOL.view_kwargs()[0] for kw in ["seed", "deduplicate", "init_position", "scramble"] ), all( - kw in Models.SOBOL.view_kwargs()[1] + kw in Generators.SOBOL.view_kwargs()[1] for kw in [ "search_space", "model", @@ -295,17 +296,17 @@ def test_view_defaults(self) -> None: @mock_botorch_optimize def test_get_model_from_generator_run(self) -> None: """Tests that it is possible to restore a model from a generator run it - produced, if `Models` registry was used. + produced, if `Generators` registry was used. """ exp = get_branin_experiment() - initial_sobol = Models.SOBOL(experiment=exp, seed=239) + initial_sobol = Generators.SOBOL(experiment=exp, seed=239) gr = initial_sobol.gen(n=1) # Restore the model as it was before generation. sobol = get_model_from_generator_run( generator_run=gr, experiment=exp, data=exp.fetch_data(), - models_enum=Models, + models_enum=Generators, after_gen=False, ) self.assertEqual(sobol.model.init_position, 0) @@ -315,14 +316,14 @@ def test_get_model_from_generator_run(self) -> None: generator_run=gr, experiment=exp, data=exp.fetch_data(), - models_enum=Models, + models_enum=Generators, ) self.assertEqual(sobol_after_gen.model.init_position, 1) self.assertEqual(sobol_after_gen.model.seed, 239) self.assertEqual(initial_sobol.gen(n=1).arms, sobol_after_gen.gen(n=1).arms) exp.new_trial(generator_run=gr) # Check restoration of GPEI, to ensure proper restoration of callable kwargs - gpei = Models.LEGACY_BOTORCH(experiment=exp, data=get_branin_data()) + gpei = Generators.LEGACY_BOTORCH(experiment=exp, data=get_branin_data()) # Punch GPEI model + bridge kwargs into the Sobol generator run, to avoid # a slow call to `gpei.gen`, and remove Sobol's model state. gr._model_key = "Legacy_GPEI" @@ -330,7 +331,7 @@ def test_get_model_from_generator_run(self) -> None: gr._bridge_kwargs = gpei._bridge_kwargs gr._model_state_after_gen = {} gpei_restored = get_model_from_generator_run( - gr, experiment=exp, data=get_branin_data(), models_enum=Models + gr, experiment=exp, data=get_branin_data(), models_enum=Generators ) for key in gpei.__dict__: self.assertIn(key, gpei_restored.__dict__) @@ -340,7 +341,7 @@ def test_get_model_from_generator_run(self) -> None: continue # Fit times are set in instantiation so won't be same. if isinstance(original, OrderedDict) and isinstance(restored, OrderedDict): original, restored = list(original.keys()), list(restored.keys()) - if isinstance(original, Model) and isinstance(restored, Model): + if isinstance(original, Generator) and isinstance(restored, Generator): continue # Model equality is tough to compare. 
self.assertEqual(original, restored) @@ -401,15 +402,15 @@ def test_ST_MTGP(self, use_saas: bool = False) -> None: ) for surrogate, default_model in zip(surrogates, (False, True)): - constructor = Models.SAAS_MTGP if use_saas else Models.ST_MTGP + constructor = Generators.SAAS_MTGP if use_saas else Generators.ST_MTGP mtgp = constructor( experiment=exp, data=exp.fetch_data(), status_quo_features=status_quo_features, surrogate=surrogate, ) - self.assertIsInstance(mtgp, TorchModelBridge) - self.assertIsInstance(mtgp.model, BoTorchModel) + self.assertIsInstance(mtgp, TorchAdapter) + self.assertIsInstance(mtgp.model, BoTorchGenerator) self.assertEqual(mtgp.model.acquisition_class, Acquisition) is_moo = isinstance( exp.optimization_config, MultiObjectiveOptimizationConfig @@ -451,7 +452,7 @@ def test_SAAS_MTGP(self) -> None: def test_extract_model_state_after_gen(self) -> None: # Test with actual state. exp = get_branin_experiment() - sobol = Models.SOBOL(search_space=exp.search_space) + sobol = Generators.SOBOL(search_space=exp.search_space) gr = sobol.gen(n=1) expected_state = sobol.model._get_state() self.assertEqual(gr._model_state_after_gen, expected_state) @@ -465,3 +466,12 @@ def test_extract_model_state_after_gen(self) -> None: generator_run=gr, model_class=SobolGenerator ) self.assertEqual(extracted, {}) + + def test_deprecation_warning(self) -> None: + """Tests deprecation warning""" + with self.assertRaisesRegex( + DeprecationWarning, + r"Models is deprecated, use \`ax.modelbridge.registry.Generators\`" + r" instead.", + ): + Models.BOTORCH_MODULAR diff --git a/ax/modelbridge/tests/test_robust_modelbridge.py b/ax/modelbridge/tests/test_robust_modelbridge.py index 925da78821b..cb2986ec95e 100644 --- a/ax/modelbridge/tests/test_robust_modelbridge.py +++ b/ax/modelbridge/tests/test_robust_modelbridge.py @@ -18,7 +18,7 @@ from ax.core.types import ComparisonOp from ax.exceptions.core import UnsupportedError from ax.metrics.branin import BraninMetric -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import get_robust_branin_experiment @@ -44,7 +44,7 @@ def test_robust( ) for _ in range(5): - modelbridge = Models.BOTORCH_MODULAR( + modelbridge = Generators.BOTORCH_MODULAR( experiment=exp, data=exp.fetch_data(), surrogate=Surrogate(botorch_model_class=SingleTaskGP), @@ -129,7 +129,7 @@ def test_mars(self) -> None: def test_unsupported_model(self) -> None: exp = get_robust_branin_experiment() with self.assertRaisesRegex(UnsupportedError, "support robust"): - Models.LEGACY_BOTORCH( + Generators.LEGACY_BOTORCH( experiment=exp, data=exp.fetch_data(), ).gen(n=1) diff --git a/ax/modelbridge/tests/test_torch_modelbridge.py b/ax/modelbridge/tests/test_torch_modelbridge.py index 7f12794fd20..df0b8ae14eb 100644 --- a/ax/modelbridge/tests/test_torch_modelbridge.py +++ b/ax/modelbridge/tests/test_torch_modelbridge.py @@ -29,14 +29,14 @@ from ax.core.parameter import ChoiceParameter, ParameterType, RangeParameter from ax.core.search_space import SearchSpace, SearchSpaceDigest from ax.core.types import ComparisonOp -from ax.modelbridge.base import ModelBridge +from ax.modelbridge.base import Adapter from ax.modelbridge.registry import MBM_X_trans -from ax.modelbridge.torch import TorchModelBridge +from ax.modelbridge.torch import TorchAdapter from ax.modelbridge.transforms.base import Transform from 
ax.modelbridge.transforms.standardize_y import StandardizeY from ax.modelbridge.transforms.unit_x import UnitX -from ax.models.torch.botorch_modular.model import BoTorchModel -from ax.models.torch_base import TorchGenResults, TorchModel +from ax.models.torch.botorch_modular.model import BoTorchGenerator +from ax.models.torch_base import TorchGenerator, TorchGenResults from ax.utils.common.constants import Keys from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( @@ -63,21 +63,21 @@ def _get_modelbridge_from_experiment( transforms: list[type[Transform]] | None = None, device: torch.device | None = None, fit_on_init: bool = True, -) -> TorchModelBridge: - return TorchModelBridge( +) -> TorchAdapter: + return TorchAdapter( experiment=experiment, search_space=experiment.search_space, data=experiment.lookup_data(), - model=BoTorchModel(), + model=BoTorchGenerator(), transforms=transforms or [], torch_device=device, fit_on_init=fit_on_init, ) -class TorchModelBridgeTest(TestCase): +class TorchAdapterTest(TestCase): @mock_botorch_optimize - def test_TorchModelBridge(self, device: torch.device | None = None) -> None: + def test_TorchAdapter(self, device: torch.device | None = None) -> None: feature_names = ["x1", "x2", "x3"] search_space = get_search_space_for_range_values( min=0.0, max=5.0, parameter_names=feature_names @@ -130,7 +130,7 @@ def test_TorchModelBridge(self, device: torch.device | None = None) -> None: ] observations = recombine_observations(observation_features, observation_data) - model = BoTorchModel() + model = BoTorchGenerator() with mock.patch.object(model, "fit", wraps=model.fit) as mock_fit: model_bridge._fit( model=model, search_space=search_space, observations=observations @@ -145,7 +145,7 @@ def test_TorchModelBridge(self, device: torch.device | None = None) -> None: self.assertIsNone(model_fit_args["candidate_metadata"]) self.assertEqual(model_bridge._last_observations, observations) - with mock.patch(f"{TorchModelBridge.__module__}.logger.debug") as mock_logger: + with mock.patch(f"{TorchAdapter.__module__}.logger.debug") as mock_logger: model_bridge._fit( model=model, search_space=search_space, @@ -201,14 +201,14 @@ def test_TorchModelBridge(self, device: torch.device | None = None) -> None: ) es.enter_context( mock.patch( - f"{TorchModelBridge.__module__}.TorchModelBridge." + f"{TorchAdapter.__module__}.TorchAdapter." 
"_array_callable_to_tensor_callable", return_value=torch.round, ) ) es.enter_context( # silence a warning about inability to generate unique candidates - mock.patch(f"{ModelBridge.__module__}.logger.warning") + mock.patch(f"{Adapter.__module__}.logger.warning") ) gen_run = model_bridge.gen( n=3, @@ -285,19 +285,19 @@ def test_TorchModelBridge(self, device: torch.device | None = None) -> None: X = model_bridge._transform_observation_features(observation_features=obsf) self.assertTrue(torch.equal(X, torch.tensor([[1.0, 2.0]], **tkwargs))) - def _test_TorchModelBridge_torch_dtype_deprecated( + def _test_TorchAdapter_torch_dtype_deprecated( self, torch_dtype: torch.dtype ) -> None: search_space = get_search_space_for_range_values( min=0.0, max=5.0, parameter_names=["x1", "x2", "x3"] ) - model = mock.MagicMock(TorchModel, autospec=True, instance=True) + model = mock.MagicMock(TorchGenerator, autospec=True, instance=True) experiment = Experiment(search_space=search_space, name="test") with self.assertWarnsRegex( DeprecationWarning, - "The `torch_dtype` argument to `TorchModelBridge` is deprecated", + "The `torch_dtype` argument to `TorchAdapter` is deprecated", ): - TorchModelBridge( + TorchAdapter( experiment=experiment, search_space=search_space, data=experiment.lookup_data(), @@ -307,15 +307,15 @@ def _test_TorchModelBridge_torch_dtype_deprecated( torch_dtype=torch_dtype, ) - def test_TorchModelBridge_float(self) -> None: - self._test_TorchModelBridge_torch_dtype_deprecated(torch_dtype=torch.float32) + def test_TorchAdapter_float(self) -> None: + self._test_TorchAdapter_torch_dtype_deprecated(torch_dtype=torch.float32) - def test_TorchModelBridge_float64(self) -> None: - self._test_TorchModelBridge_torch_dtype_deprecated(torch_dtype=torch.float64) + def test_TorchAdapter_float64(self) -> None: + self._test_TorchAdapter_torch_dtype_deprecated(torch_dtype=torch.float64) - def test_TorchModelBridge_cuda(self) -> None: + def test_TorchAdapter_cuda(self) -> None: if torch.cuda.is_available(): - self.test_TorchModelBridge(device=torch.device("cuda")) + self.test_TorchAdapter(device=torch.device("cuda")) @mock_botorch_optimize def test_evaluate_acquisition_function(self) -> None: @@ -406,9 +406,9 @@ def test_best_point(self) -> None: objective=Objective(metric=Metric("a"), minimize=False), outcome_constraints=[], ) - modelbridge = TorchModelBridge( + modelbridge = TorchAdapter( search_space=search_space, - model=TorchModel(), + model=TorchGenerator(), transforms=[transform_1, transform_2], experiment=exp, data=Data(), @@ -432,7 +432,7 @@ def test_best_point(self) -> None: points=torch.tensor([[1.0]]), weights=torch.tensor([1.0]) ) with mock.patch( - f"{TorchModel.__module__}.TorchModel.best_point", + f"{TorchGenerator.__module__}.TorchGenerator.best_point", return_value=torch.tensor([best_point_value]), autospec=True, ), mock.patch.object(modelbridge, "predict", return_value=predict_return_value): @@ -474,7 +474,7 @@ def test_best_point(self) -> None: ) with mock.patch( - f"{TorchModel.__module__}.TorchModel.best_point", + f"{TorchGenerator.__module__}.TorchGenerator.best_point", side_effect=NotImplementedError, autospec=True, ): @@ -506,9 +506,9 @@ def test_candidate_metadata_propagation(self) -> None: "preexisting_batch_cand_metadata": "some_value" } } - model = TorchModel() + model = TorchGenerator() with mock.patch.object(model, "fit", wraps=model.fit) as mock_model_fit: - modelbridge = TorchModelBridge( + modelbridge = TorchAdapter( experiment=exp, search_space=exp.search_space, model=model, @@ 
-562,13 +562,12 @@ def test_candidate_metadata_propagation(self) -> None: # Check that no candidate metadata is handled correctly. exp = get_branin_experiment(with_status_quo=True) - model = TorchModel() + model = TorchGenerator() with mock.patch( - f"{TorchModelBridge.__module__}." - "TorchModelBridge._validate_observation_data", + f"{TorchAdapter.__module__}." "TorchAdapter._validate_observation_data", autospec=True, ), mock.patch.object(model, "fit", wraps=model.fit) as mock_model_fit: - modelbridge = TorchModelBridge( + modelbridge = TorchAdapter( search_space=exp.search_space, experiment=exp, model=model, @@ -591,9 +590,9 @@ def test_fit_tracking_metrics(self) -> None: with_tracking_metrics=True, ) for fit_tracking_metrics in (True, False): - model = TorchModel() + model = TorchGenerator() with mock.patch.object(model, "fit", wraps=model.fit) as mock_model_fit: - modelbridge = TorchModelBridge( + modelbridge = TorchAdapter( experiment=exp, search_space=exp.search_space, data=exp.lookup_data(), @@ -807,8 +806,8 @@ def test_gen_metadata_untransform(self) -> None: experiment = get_experiment_with_observations( observations=[[0.0, 1.0], [2.0, 3.0]] ) - model = BoTorchModel() - mb = TorchModelBridge( + model = BoTorchGenerator() + mb = TorchAdapter( experiment=experiment, search_space=experiment.search_space, data=experiment.lookup_data(), diff --git a/ax/modelbridge/tests/test_torch_moo_modelbridge.py b/ax/modelbridge/tests/test_torch_moo_modelbridge.py index 8e6cdd2c4b6..5f68380f677 100644 --- a/ax/modelbridge/tests/test_torch_moo_modelbridge.py +++ b/ax/modelbridge/tests/test_torch_moo_modelbridge.py @@ -30,9 +30,9 @@ predicted_pareto_frontier, ) from ax.modelbridge.registry import Cont_X_trans, ST_MTGP_trans, Y_trans -from ax.modelbridge.torch import TorchModelBridge -from ax.models.torch.botorch_modular.model import BoTorchModel -from ax.models.torch.botorch_moo import MultiObjectiveBotorchModel +from ax.modelbridge.torch import TorchAdapter +from ax.models.torch.botorch_modular.model import BoTorchGenerator +from ax.models.torch.botorch_moo import MultiObjectiveBotorchGenerator from ax.models.torch.botorch_moo_defaults import ( infer_objective_thresholds, pareto_frontier_evaluator, @@ -58,7 +58,7 @@ STUBS_PATH: str = get_branin_experiment_with_multi_objective.__module__ -class MultiObjectiveTorchModelBridgeTest(TestCase): +class MultiObjectiveTorchAdapterTest(TestCase): @patch( # Mocking `BraninMetric` as not available while running, so it will # be grabbed from cache during `fetch_data`. 
@@ -104,9 +104,9 @@ def helper_test_pareto_frontier( trial_indices=exp.trials.keys(), num_objectives=n_outcomes ), ) - modelbridge = TorchModelBridge( + modelbridge = TorchAdapter( search_space=exp.search_space, - model=MultiObjectiveBotorchModel(), + model=MultiObjectiveBotorchGenerator(), optimization_config=exp.optimization_config, transforms=[transform_1, transform_2], experiment=exp, @@ -276,11 +276,11 @@ def test_get_pareto_frontier_and_configs_input_validation(self) -> None: trial_indices=exp.trials.keys(), num_objectives=2 ), ) - modelbridge = TorchModelBridge( + modelbridge = TorchAdapter( experiment=exp, search_space=exp.search_space, data=exp.fetch_data(), - model=MultiObjectiveBotorchModel(), + model=MultiObjectiveBotorchGenerator(), transforms=[], ) observation_features = [ @@ -357,9 +357,9 @@ def test_hypervolume(self, _, cuda: bool = False) -> None: trial_indices=exp.trials.keys(), num_objectives=num_objectives ) ) - modelbridge = TorchModelBridge( + modelbridge = TorchAdapter( search_space=exp.search_space, - model=MultiObjectiveBotorchModel(), + model=MultiObjectiveBotorchGenerator(), optimization_config=optimization_config, transforms=[], experiment=exp, @@ -443,9 +443,9 @@ def test_infer_objective_thresholds(self, _, cuda: bool = False) -> None: get_branin_data_multi_objective(trial_indices=exp.trials.keys()) ) data = exp.fetch_data() - modelbridge = TorchModelBridge( + modelbridge = TorchAdapter( search_space=exp.search_space, - model=MultiObjectiveBotorchModel(), + model=MultiObjectiveBotorchGenerator(), optimization_config=exp.optimization_config, transforms=Cont_X_trans + Y_trans, torch_device=torch.device("cuda" if cuda else "cpu"), @@ -573,9 +573,9 @@ def test_infer_objective_thresholds(self, _, cuda: bool = False) -> None: trial.mark_running(no_runner_required=True).mark_completed() data = exp.fetch_data() set_rng_seed(0) # make model fitting deterministic - modelbridge = TorchModelBridge( + modelbridge = TorchAdapter( search_space=exp.search_space, - model=MultiObjectiveBotorchModel(), + model=MultiObjectiveBotorchGenerator(), optimization_config=exp.optimization_config, transforms=ST_MTGP_trans, experiment=exp, @@ -635,9 +635,9 @@ def test_infer_objective_thresholds(self, _, cuda: bool = False) -> None: # Update trials to match the search space. 
exp._search_space = hss exp._trials = get_hss_trials_with_fixed_parameter(exp=exp) - modelbridge = TorchModelBridge( + modelbridge = TorchAdapter( search_space=hss, - model=MultiObjectiveBotorchModel(), + model=MultiObjectiveBotorchGenerator(), optimization_config=exp.optimization_config, transforms=Cont_X_trans + Y_trans, torch_device=torch.device("cuda" if cuda else "cpu"), @@ -682,9 +682,9 @@ def test_infer_objective_thresholds(self, _, cuda: bool = False) -> None: get_branin_data_multi_objective(trial_indices=exp.trials.keys()) ) data = exp.fetch_data() - modelbridge = TorchModelBridge( + modelbridge = TorchAdapter( search_space=exp.search_space, - model=BoTorchModel(), + model=BoTorchGenerator(), optimization_config=exp.optimization_config, transforms=Cont_X_trans + Y_trans, torch_device=torch.device("cuda" if cuda else "cpu"), @@ -719,9 +719,9 @@ def test_status_quo_for_non_monolithic_data(self) -> None: # create data where metrics vary in start and end times data = get_non_monolithic_branin_moo_data() - bridge = TorchModelBridge( + bridge = TorchAdapter( search_space=exp.search_space, - model=MultiObjectiveBotorchModel(), + model=MultiObjectiveBotorchGenerator(), optimization_config=exp.optimization_config, experiment=exp, data=data, @@ -739,9 +739,9 @@ def test_best_point(self) -> None: exp.attach_data( get_branin_data_multi_objective(trial_indices=exp.trials.keys()) ) - bridge = TorchModelBridge( + bridge = TorchAdapter( search_space=exp.search_space, - model=MultiObjectiveBotorchModel(), + model=MultiObjectiveBotorchGenerator(), optimization_config=exp.optimization_config, transforms=[], experiment=exp, diff --git a/ax/modelbridge/tests/test_transform_utils.py b/ax/modelbridge/tests/test_transform_utils.py index 9184a5096fe..7f104cfd0fc 100644 --- a/ax/modelbridge/tests/test_transform_utils.py +++ b/ax/modelbridge/tests/test_transform_utils.py @@ -14,7 +14,7 @@ from ax.core.observation import Observation, ObservationData, ObservationFeatures from ax.core.parameter import ParameterType, RangeParameter from ax.core.search_space import SearchSpace -from ax.modelbridge.base import ModelBridge +from ax.modelbridge.base import Adapter from ax.modelbridge.transforms.utils import ( ClosestLookupDict, derelativize_optimization_config_with_raw_status_quo, @@ -67,7 +67,7 @@ def test_derelativize_optimization_config_with_raw_status_quo(self, _) -> None: RangeParameter("y", ParameterType.FLOAT, 0, 20), ] ) - modelbridge = ModelBridge( + modelbridge = Adapter( search_space=dummy_search_space, model=None, transforms=[], diff --git a/ax/modelbridge/tests/test_transition_criterion.py b/ax/modelbridge/tests/test_transition_criterion.py index c3b4ba043a0..92b0709143e 100644 --- a/ax/modelbridge/tests/test_transition_criterion.py +++ b/ax/modelbridge/tests/test_transition_criterion.py @@ -20,8 +20,8 @@ GenerationStep, GenerationStrategy, ) -from ax.modelbridge.model_spec import ModelSpec -from ax.modelbridge.registry import Models +from ax.modelbridge.model_spec import GeneratorSpec +from ax.modelbridge.registry import Generators from ax.modelbridge.transition_criterion import ( AutoTransitionAfterGen, AuxiliaryExperimentCheck, @@ -51,8 +51,8 @@ class TestAuxiliaryExperimentPurpose(AuxiliaryExperimentPurpose): class TestTransitionCriterion(TestCase): def setUp(self) -> None: super().setUp() - self.sobol_model_spec = ModelSpec( - model_enum=Models.SOBOL, + self.sobol_model_spec = GeneratorSpec( + model_enum=Generators.SOBOL, model_kwargs={"init_position": 3}, model_gen_kwargs={"some_gen_kwarg": 
"some_value"}, ) @@ -66,12 +66,12 @@ def test_minimum_preference_criterion(self) -> None: name="SOBOL::default", steps=[ GenerationStep( - model=Models.SOBOL, + model=Generators.SOBOL, num_trials=-1, completion_criteria=[criterion], ), GenerationStep( - model=Models.BOTORCH_MODULAR, + model=Generators.BOTORCH_MODULAR, num_trials=-1, max_parallelism=1, ), @@ -106,7 +106,7 @@ def test_minimum_preference_criterion(self) -> None: ) self.assertEqual( generation_strategy._curr.model_spec_to_gen_from.model_enum, - Models.BOTORCH_MODULAR, + Generators.BOTORCH_MODULAR, ) def test_aux_experiment_check(self) -> None: @@ -197,18 +197,18 @@ def test_default_step_criterion_setup(self) -> None: name="SOBOL+MBM::default", steps=[ GenerationStep( - model=Models.SOBOL, + model=Generators.SOBOL, num_trials=3, ), GenerationStep( - model=Models.BOTORCH_MODULAR, + model=Generators.BOTORCH_MODULAR, num_trials=4, max_parallelism=1, min_trials_observed=2, enforce_num_trials=False, ), GenerationStep( - model=Models.BOTORCH_MODULAR, + model=Generators.BOTORCH_MODULAR, num_trials=-1, ), ], @@ -262,13 +262,13 @@ def test_min_trials_is_met(self) -> None: name="SOBOL::default", steps=[ GenerationStep( - model=Models.SOBOL, + model=Generators.SOBOL, num_trials=4, min_trials_observed=2, enforce_num_trials=True, ), GenerationStep( - Models.SOBOL, + Generators.SOBOL, num_trials=-1, max_parallelism=1, ), @@ -426,19 +426,19 @@ def test_max_trials_is_met(self) -> None: name="SOBOL::default", steps=[ GenerationStep( - model=Models.SOBOL, + model=Generators.SOBOL, num_trials=4, min_trials_observed=0, enforce_num_trials=True, ), GenerationStep( - model=Models.SOBOL, + model=Generators.SOBOL, num_trials=4, min_trials_observed=0, enforce_num_trials=False, ), GenerationStep( - model=Models.SOBOL, + model=Generators.SOBOL, num_trials=-1, max_parallelism=1, ), @@ -513,7 +513,7 @@ def test_trials_from_node_empty(self) -> None: name="SOBOL::default", steps=[ GenerationStep( - model=Models.SOBOL, + model=Generators.SOBOL, num_trials=4, min_trials_observed=2, enforce_num_trials=True, diff --git a/ax/modelbridge/tests/test_utils.py b/ax/modelbridge/tests/test_utils.py index 1be57ddff77..640ece86460 100644 --- a/ax/modelbridge/tests/test_utils.py +++ b/ax/modelbridge/tests/test_utils.py @@ -30,7 +30,7 @@ observation_data_to_array, pending_observations_as_array_list, ) -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.utils.common.constants import Keys from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( @@ -43,7 +43,7 @@ TEST_PARAMETERIZATON_LIST = ["5", "foo", "True", "5"] -class TestModelbridgeUtils(TestCase): +class TestAdapterUtils(TestCase): def setUp(self) -> None: super().setUp() self.experiment = get_experiment() @@ -56,7 +56,7 @@ def setUp(self) -> None: arm=self.trial.arm, trial_index=self.trial.index ) self.hss_exp = get_hierarchical_search_space_experiment() - self.hss_sobol = Models.SOBOL(search_space=self.hss_exp.search_space) + self.hss_sobol = Generators.SOBOL(search_space=self.hss_exp.search_space) self.hss_gr = self.hss_sobol.gen(n=1) self.hss_trial = self.hss_exp.new_trial(self.hss_gr) self.hss_arm = none_throws(self.hss_trial.arm) diff --git a/ax/modelbridge/torch.py b/ax/modelbridge/torch.py index efc6f730581..318c2532c0b 100644 --- a/ax/modelbridge/torch.py +++ b/ax/modelbridge/torch.py @@ -42,7 +42,7 @@ from ax.core.types import TCandidateMetadata, TModelPredictArm from ax.exceptions.core import DataRequiredError, UnsupportedError 
from ax.exceptions.generation_strategy import OptimizationConfigRequired -from ax.modelbridge.base import gen_arms, GenResults, ModelBridge +from ax.modelbridge.base import Adapter, gen_arms, GenResults from ax.modelbridge.modelbridge_utils import ( array_to_observation_data, extract_objective_thresholds, @@ -63,10 +63,10 @@ ) from ax.modelbridge.transforms.base import Transform from ax.modelbridge.transforms.cast import Cast -from ax.models.torch.botorch_modular.model import BoTorchModel -from ax.models.torch.botorch_moo import MultiObjectiveBotorchModel +from ax.models.torch.botorch_modular.model import BoTorchGenerator +from ax.models.torch.botorch_moo import MultiObjectiveBotorchGenerator from ax.models.torch.botorch_moo_defaults import infer_objective_thresholds -from ax.models.torch_base import TorchModel, TorchOptConfig +from ax.models.torch_base import TorchGenerator, TorchOptConfig from ax.models.types import TConfig from ax.utils.common.logger import get_logger from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset @@ -78,11 +78,11 @@ FIT_MODEL_ERROR = "Model must be fit before {action}." -class TorchModelBridge(ModelBridge): +class TorchAdapter(Adapter): """A model bridge for using torch-based models. - Specifies an interface that is implemented by TorchModel. In particular, - model should have methods fit, predict, and gen. See TorchModel for the + Specifies an interface that is implemented by TorchGenerator. In particular, + model should have methods fit, predict, and gen. See TorchGenerator for the API for each of these methods. Requires that all parameters have been transformed to RangeParameters @@ -92,7 +92,7 @@ class TorchModelBridge(ModelBridge): them to the model. """ - model: TorchModel | None = None + model: TorchGenerator | None = None # pyre-fixme[13]: Attribute `outcomes` is never initialized. outcomes: list[str] # pyre-ignore[13]: These are initialized in _fit. # pyre-fixme[13]: Attribute `parameters` is never initialized. @@ -105,7 +105,7 @@ def __init__( experiment: Experiment, search_space: SearchSpace, data: Data, - model: TorchModel, + model: TorchGenerator, transforms: list[type[Transform]], transform_configs: dict[str, TConfig] | None = None, torch_dtype: torch.dtype | None = None, @@ -123,10 +123,10 @@ def __init__( # This warning is being added while we are on 0.4.3, so it will be # released in 0.4.4 or 0.5.0. The `torch_dtype` argument can be removed # in the subsequent minor version. It should also be removed from - # `TorchModelBridge` subclasses. + # `TorchAdapter` subclasses. if torch_dtype is not None: warn( - "The `torch_dtype` argument to `TorchModelBridge` is deprecated" + "The `torch_dtype` argument to `TorchAdapter` is deprecated" " and will be ignored; data will be in double precision.", DeprecationWarning, ) @@ -135,7 +135,7 @@ def __init__( self.dtype: torch.dtype = torch.double self.device = torch_device # pyre-ignore [4]: Attribute `_default_model_gen_options` of class - # `TorchModelBridge` must have a type that does not contain `Any`. + # `TorchAdapter` must have a type that does not contain `Any`. self._default_model_gen_options = default_model_gen_options or {} # Handle init for multi-objective optimization. @@ -198,7 +198,7 @@ def infer_objective_thresholds( optimization_config=optimization_config, fixed_features=fixed_features, ) - # Get transformed args from TorchModelbridge. + # Get transformed args from TorchAdapter. 
search_space_digest, torch_opt_config = self._get_transformed_model_gen_args( search_space=base_gen_args.search_space, fixed_features=base_gen_args.fixed_features, @@ -210,16 +210,16 @@ def infer_objective_thresholds( "`infer_objective_thresholds` does not support risk measures." ) # Infer objective thresholds. - if isinstance(self.model, MultiObjectiveBotorchModel): + if isinstance(self.model, MultiObjectiveBotorchGenerator): model = self.model.model Xs = self.model.Xs - elif isinstance(self.model, BoTorchModel): + elif isinstance(self.model, BoTorchGenerator): model = self.model.surrogate.model Xs = self.model.surrogate.Xs else: raise UnsupportedError( - "Model must be a MultiObjectiveBotorchModel or an appropriate Modular " - "Botorch Model to infer_objective_thresholds. Found " + "Model must be a MultiObjectiveBotorchGenerator or an appropriate " + "Modular Botorch Model to infer_objective_thresholds. Found " f"{type(self.model)}." ) @@ -526,7 +526,7 @@ def evaluate_acquisition_function( if optimization_config is None: raise ValueError( "The `optimization_config` must be specified either while initializing " - "the ModelBridge or to the `evaluate_acquisition_function` call." + "the Adapter or to the `evaluate_acquisition_function` call." ) # pyre-ignore Incompatible parameter type [9] obs_feats: list[list[ObservationFeatures]] = deepcopy(observation_features) @@ -624,7 +624,7 @@ def _get_fit_args( observation_features, observation_data = separate_observations(observations) # Only update outcomes if fitting a model on tracking metrics. Otherwise, # we will only fit models to the outcomes that are extracted from optimization - # config in ModelBridge.__init__. + # config in Adapter.__init__. if update_outcomes_and_parameters and self._fit_tracking_metrics: for od in observation_data: all_metric_names.update(od.metric_names) @@ -651,7 +651,7 @@ def _get_fit_args( def _fit( self, - model: TorchModel, + model: TorchGenerator, search_space: SearchSpace, observations: list[Observation], parameters: list[str] | None = None, @@ -777,7 +777,7 @@ def _predict( f"predictions of shape {f.shape} for inputs of shape {X.shape}. " "This was likely due to the use of one-to-many input transforms -- " "typically used for robust optimization -- which are not supported in" - "TorchModelBridge.predict." + "TorchAdapter.predict." ) # Convert resulting arrays to observations return array_to_observation_data(f=f, cov=cov, outcomes=self.outcomes) diff --git a/ax/modelbridge/transforms/base.py b/ax/modelbridge/transforms/base.py index 01ebe75a21f..caa5b2c9817 100644 --- a/ax/modelbridge/transforms/base.py +++ b/ax/modelbridge/transforms/base.py @@ -34,7 +34,7 @@ class Transform: Transforms are used to adapt the search space and data into the types and structures expected by the model. When Transforms are used (for - instance, in ModelBridge), it is always assumed that they may potentially + instance, in Adapter), it is always assumed that they may potentially mutate the transformed object in-place. Forward transforms are defined for all four of those quantities. 
Reverse @@ -53,13 +53,13 @@ class Transform: """ config: TConfig - modelbridge: modelbridge_module.base.ModelBridge | None + modelbridge: modelbridge_module.base.Adapter | None def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: modelbridge_module.base.ModelBridge | None = None, + modelbridge: modelbridge_module.base.Adapter | None = None, config: TConfig | None = None, ) -> None: """Do any initial computations for preparing the transform. @@ -69,7 +69,7 @@ def __init__( Args: search_space: The search space observations: Observations - modelbridge: ModelBridge for referencing experiment, status quo, etc... + modelbridge: Adapter for referencing experiment, status quo, etc... config: A dictionary of options specific to each transform """ if config is None: @@ -109,7 +109,7 @@ def _transform_search_space(self, search_space: SearchSpace) -> SearchSpace: def transform_optimization_config( self, optimization_config: OptimizationConfig, - modelbridge: modelbridge_module.base.ModelBridge | None = None, + modelbridge: modelbridge_module.base.Adapter | None = None, fixed_features: ObservationFeatures | None = None, ) -> OptimizationConfig: """Transform optimization config. diff --git a/ax/modelbridge/transforms/cast.py b/ax/modelbridge/transforms/cast.py index 53baa0f10e4..7bae8ca19a8 100644 --- a/ax/modelbridge/transforms/cast.py +++ b/ax/modelbridge/transforms/cast.py @@ -44,7 +44,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + modelbridge: Optional["modelbridge_module.base.Adapter"] = None, config: TConfig | None = None, ) -> None: self.search_space: SearchSpace = none_throws(search_space).clone() diff --git a/ax/modelbridge/transforms/choice_encode.py b/ax/modelbridge/transforms/choice_encode.py index 2aaa19de870..5f30d5c5261 100644 --- a/ax/modelbridge/transforms/choice_encode.py +++ b/ax/modelbridge/transforms/choice_encode.py @@ -55,7 +55,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + modelbridge: Optional["modelbridge_module.base.Adapter"] = None, config: TConfig | None = None, ) -> None: assert search_space is not None, "ChoiceToNumericChoice requires search space" @@ -152,7 +152,7 @@ def __init__( self, search_space: SearchSpace, observations: list[Observation], - modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + modelbridge: Optional["modelbridge_module.base.Adapter"] = None, config: TConfig | None = None, ) -> None: # Identify parameters that should be transformed diff --git a/ax/modelbridge/transforms/convert_metric_names.py b/ax/modelbridge/transforms/convert_metric_names.py index 1fd875ad423..cb2e3dd7b08 100644 --- a/ax/modelbridge/transforms/convert_metric_names.py +++ b/ax/modelbridge/transforms/convert_metric_names.py @@ -41,7 +41,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + modelbridge: Optional["modelbridge_module.base.Adapter"] = None, config: TConfig | None = None, ) -> None: assert observations is not None, "ConvertMetricNames requires observations" diff --git a/ax/modelbridge/transforms/derelativize.py b/ax/modelbridge/transforms/derelativize.py index 
a197ce96609..efb1a5591b1 100644 --- a/ax/modelbridge/transforms/derelativize.py +++ b/ax/modelbridge/transforms/derelativize.py @@ -47,7 +47,7 @@ class Derelativize(Transform): def transform_optimization_config( self, optimization_config: OptimizationConfig, - modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + modelbridge: Optional["modelbridge_module.base.Adapter"] = None, fixed_features: ObservationFeatures | None = None, ) -> OptimizationConfig: use_raw_sq = self.config.get("use_raw_status_quo", False) @@ -59,7 +59,7 @@ def transform_optimization_config( # Else, we have at least one relative constraint. # Estimate the value at the status quo. if modelbridge is None: - raise ValueError("ModelBridge not supplied to transform.") + raise ValueError("Adapter not supplied to transform.") # Unobserved status quo corresponds to a modelbridge.status_quo of None. if modelbridge.status_quo is None: raise DataRequiredError( diff --git a/ax/modelbridge/transforms/fill_missing_parameters.py b/ax/modelbridge/transforms/fill_missing_parameters.py index ac09b812d8d..79b8d120cd7 100644 --- a/ax/modelbridge/transforms/fill_missing_parameters.py +++ b/ax/modelbridge/transforms/fill_missing_parameters.py @@ -36,7 +36,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + modelbridge: Optional["modelbridge_module.base.Adapter"] = None, config: TConfig | None = None, ) -> None: config = config or {} diff --git a/ax/modelbridge/transforms/int_range_to_choice.py b/ax/modelbridge/transforms/int_range_to_choice.py index 1ec3959f984..64868dce6aa 100644 --- a/ax/modelbridge/transforms/int_range_to_choice.py +++ b/ax/modelbridge/transforms/int_range_to_choice.py @@ -31,7 +31,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + modelbridge: Optional["modelbridge_module.base.Adapter"] = None, config: TConfig | None = None, ) -> None: assert search_space is not None, "IntRangeToChoice requires search space" diff --git a/ax/modelbridge/transforms/int_to_float.py b/ax/modelbridge/transforms/int_to_float.py index d4495f5549a..2651f02ffd5 100644 --- a/ax/modelbridge/transforms/int_to_float.py +++ b/ax/modelbridge/transforms/int_to_float.py @@ -51,7 +51,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + modelbridge: Optional["modelbridge_module.base.Adapter"] = None, config: TConfig | None = None, ) -> None: self.search_space: SearchSpace = none_throws( @@ -200,7 +200,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + modelbridge: Optional["modelbridge_module.base.Adapter"] = None, config: TConfig | None = None, ) -> None: if config is not None and "min_choices" in config: diff --git a/ax/modelbridge/transforms/log.py b/ax/modelbridge/transforms/log.py index ee8043cdb6e..f87e0cc2b77 100644 --- a/ax/modelbridge/transforms/log.py +++ b/ax/modelbridge/transforms/log.py @@ -30,7 +30,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + modelbridge: 
Optional["modelbridge_module.base.Adapter"] = None, config: TConfig | None = None, ) -> None: assert search_space is not None, "Log requires search space" diff --git a/ax/modelbridge/transforms/log_y.py b/ax/modelbridge/transforms/log_y.py index a0731aec93a..1376b270262 100644 --- a/ax/modelbridge/transforms/log_y.py +++ b/ax/modelbridge/transforms/log_y.py @@ -51,7 +51,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: base_modelbridge.ModelBridge | None = None, + modelbridge: base_modelbridge.Adapter | None = None, config: TConfig | None = None, ) -> None: if config is None: @@ -83,7 +83,7 @@ def __init__( def transform_optimization_config( self, optimization_config: OptimizationConfig, - modelbridge: base_modelbridge.ModelBridge | None = None, + modelbridge: base_modelbridge.Adapter | None = None, fixed_features: ObservationFeatures | None = None, ) -> OptimizationConfig: for c in optimization_config.all_constraints: diff --git a/ax/modelbridge/transforms/logit.py b/ax/modelbridge/transforms/logit.py index 3c3e4ddb416..d9e6ed674b9 100644 --- a/ax/modelbridge/transforms/logit.py +++ b/ax/modelbridge/transforms/logit.py @@ -30,7 +30,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + modelbridge: Optional["modelbridge_module.base.Adapter"] = None, config: TConfig | None = None, ) -> None: assert search_space is not None, "Logit requires search space" diff --git a/ax/modelbridge/transforms/map_key_to_float.py b/ax/modelbridge/transforms/map_key_to_float.py index 759eee41777..b16bdbd2e9e 100644 --- a/ax/modelbridge/transforms/map_key_to_float.py +++ b/ax/modelbridge/transforms/map_key_to_float.py @@ -44,7 +44,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + modelbridge: Optional["modelbridge_module.base.Adapter"] = None, config: TConfig | None = None, ) -> None: config = config or {} diff --git a/ax/modelbridge/transforms/map_unit_x.py b/ax/modelbridge/transforms/map_unit_x.py index 881e44c23cb..732502ed464 100644 --- a/ax/modelbridge/transforms/map_unit_x.py +++ b/ax/modelbridge/transforms/map_unit_x.py @@ -36,7 +36,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: modelbridge_module.base.ModelBridge | None = None, + modelbridge: modelbridge_module.base.Adapter | None = None, config: TConfig | None = None, ) -> None: assert observations is not None, "MapUnitX requires observations" diff --git a/ax/modelbridge/transforms/merge_repeated_measurements.py b/ax/modelbridge/transforms/merge_repeated_measurements.py index f406520a800..af2a3d39246 100644 --- a/ax/modelbridge/transforms/merge_repeated_measurements.py +++ b/ax/modelbridge/transforms/merge_repeated_measurements.py @@ -15,7 +15,7 @@ from ax.core.arm import Arm from ax.core.observation import Observation, ObservationData, separate_observations from ax.core.search_space import SearchSpace -from ax.modelbridge.base import ModelBridge +from ax.modelbridge.base import Adapter from ax.modelbridge.transforms.base import Transform from ax.models.types import TConfig @@ -36,7 +36,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: ModelBridge | None = 
None, + modelbridge: Adapter | None = None, config: TConfig | None = None, ) -> None: if observations is None: diff --git a/ax/modelbridge/transforms/metadata_to_float.py b/ax/modelbridge/transforms/metadata_to_float.py index 9ff66f5cd73..0e5ed2d585b 100644 --- a/ax/modelbridge/transforms/metadata_to_float.py +++ b/ax/modelbridge/transforms/metadata_to_float.py @@ -51,7 +51,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + modelbridge: Optional["modelbridge_module.base.Adapter"] = None, config: TConfig | None = None, ) -> None: if observations is None or not observations: diff --git a/ax/modelbridge/transforms/metrics_as_task.py b/ax/modelbridge/transforms/metrics_as_task.py index 3e162d6cc99..fc10cfe7738 100644 --- a/ax/modelbridge/transforms/metrics_as_task.py +++ b/ax/modelbridge/transforms/metrics_as_task.py @@ -43,7 +43,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + modelbridge: Optional["modelbridge_module.base.Adapter"] = None, config: TConfig | None = None, ) -> None: # Use config to specify metric task map diff --git a/ax/modelbridge/transforms/one_hot.py b/ax/modelbridge/transforms/one_hot.py index 13812806b20..5c6a0d9fcd4 100644 --- a/ax/modelbridge/transforms/one_hot.py +++ b/ax/modelbridge/transforms/one_hot.py @@ -89,7 +89,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + modelbridge: Optional["modelbridge_module.base.Adapter"] = None, config: TConfig | None = None, ) -> None: assert search_space is not None, "OneHot requires search space" diff --git a/ax/modelbridge/transforms/power_transform_y.py b/ax/modelbridge/transforms/power_transform_y.py index 3049092a48e..9e73e5ff790 100644 --- a/ax/modelbridge/transforms/power_transform_y.py +++ b/ax/modelbridge/transforms/power_transform_y.py @@ -56,7 +56,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: modelbridge_module.base.ModelBridge | None = None, + modelbridge: modelbridge_module.base.Adapter | None = None, config: TConfig | None = None, ) -> None: """Initialize the ``PowerTransformY`` transform. @@ -64,7 +64,7 @@ def __init__( Args: search_space: The search space of the experiment. Unused. observations: A list of observations from the experiment. - modelbridge: The `ModelBridge` within which the transform is used. Unused. + modelbridge: The `Adapter` within which the transform is used. Unused. config: A dictionary of options to control the behavior of the transform. Can contain the following keys: - "metrics": A list of metric names to apply the transform to. 
If @@ -132,7 +132,7 @@ def _untransform_observation_data( def transform_optimization_config( self, optimization_config: OptimizationConfig, - modelbridge: modelbridge_module.base.ModelBridge | None = None, + modelbridge: modelbridge_module.base.Adapter | None = None, fixed_features: ObservationFeatures | None = None, ) -> OptimizationConfig: for c in optimization_config.all_constraints: diff --git a/ax/modelbridge/transforms/relativize.py b/ax/modelbridge/transforms/relativize.py index bed943d81f1..b644649e062 100644 --- a/ax/modelbridge/transforms/relativize.py +++ b/ax/modelbridge/transforms/relativize.py @@ -23,7 +23,7 @@ ) from ax.core.outcome_constraint import OutcomeConstraint from ax.core.search_space import SearchSpace -from ax.modelbridge import ModelBridge +from ax.modelbridge import Adapter from ax.modelbridge.transforms.base import Transform from ax.models.types import TConfig from ax.utils.stats.statstools import relativize, unrelativize @@ -37,7 +37,7 @@ class BaseRelativize(Transform, ABC): """ Change the relative flag of the given relative optimization configuration - to False. This is needed in order for the new opt config to pass ModelBridge + to False. This is needed in order for the new opt config to pass Adapter that requires non-relativized opt config. Also transforms absolute data and opt configs to relative. @@ -53,7 +53,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: modelbridge_module.base.ModelBridge | None = None, + modelbridge: modelbridge_module.base.Adapter | None = None, config: TConfig | None = None, ) -> None: cls_name = self.__class__.__name__ @@ -65,7 +65,7 @@ def __init__( config=config, ) # self.modelbridge should NOT be modified - self.modelbridge: ModelBridge = none_throws( + self.modelbridge: Adapter = none_throws( modelbridge, f"{cls_name} transform requires a modelbridge" ) @@ -86,12 +86,12 @@ def control_as_constant(self) -> bool: def transform_optimization_config( self, optimization_config: OptimizationConfig, - modelbridge: modelbridge_module.base.ModelBridge | None = None, + modelbridge: modelbridge_module.base.Adapter | None = None, fixed_features: ObservationFeatures | None = None, ) -> OptimizationConfig: r""" Change the relative flag of the given relative optimization configuration - to False. This is needed in order for the new opt config to pass ModelBridge + to False. This is needed in order for the new opt config to pass Adapter that requires non-relativized opt config. 
Args: diff --git a/ax/modelbridge/transforms/remove_fixed.py b/ax/modelbridge/transforms/remove_fixed.py index acb0c7aebbc..51d3a256ca8 100644 --- a/ax/modelbridge/transforms/remove_fixed.py +++ b/ax/modelbridge/transforms/remove_fixed.py @@ -33,7 +33,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + modelbridge: Optional["modelbridge_module.base.Adapter"] = None, config: TConfig | None = None, ) -> None: assert search_space is not None, "RemoveFixed requires search space" diff --git a/ax/modelbridge/transforms/search_space_to_choice.py b/ax/modelbridge/transforms/search_space_to_choice.py index db1549fa6ef..10bf8e81e1d 100644 --- a/ax/modelbridge/transforms/search_space_to_choice.py +++ b/ax/modelbridge/transforms/search_space_to_choice.py @@ -38,7 +38,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + modelbridge: Optional["modelbridge_module.base.Adapter"] = None, config: TConfig | None = None, ) -> None: assert search_space is not None, "SearchSpaceToChoice requires search space" diff --git a/ax/modelbridge/transforms/standardize_y.py b/ax/modelbridge/transforms/standardize_y.py index 88261c672cb..ae384eb34c0 100644 --- a/ax/modelbridge/transforms/standardize_y.py +++ b/ax/modelbridge/transforms/standardize_y.py @@ -41,7 +41,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: Optional["base_modelbridge.ModelBridge"] = None, + modelbridge: Optional["base_modelbridge.Adapter"] = None, config: TConfig | None = None, ) -> None: if observations is None or len(observations) == 0: @@ -68,7 +68,7 @@ def _transform_observation_data( def transform_optimization_config( self, optimization_config: OptimizationConfig, - modelbridge: Optional["base_modelbridge.ModelBridge"] = None, + modelbridge: Optional["base_modelbridge.Adapter"] = None, fixed_features: ObservationFeatures | None = None, ) -> OptimizationConfig: for c in optimization_config.all_constraints: diff --git a/ax/modelbridge/transforms/stratified_standardize_y.py b/ax/modelbridge/transforms/stratified_standardize_y.py index ea6d1b31a84..76fd1e776c6 100644 --- a/ax/modelbridge/transforms/stratified_standardize_y.py +++ b/ax/modelbridge/transforms/stratified_standardize_y.py @@ -51,7 +51,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + modelbridge: Optional["modelbridge_module.base.Adapter"] = None, config: TConfig | None = None, ) -> None: """Initialize StratifiedStandardizeY. 
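Note on the transform constructors above: the keyword stays `modelbridge`; only its annotated type moves from `ModelBridge` to `Adapter`. As a minimal illustration (not part of this diff), a data-only transform such as `StandardizeY` can be used without an adapter at all. Here `observations` is a placeholder for a non-empty list of `ax.core.observation.Observation`, and `transform_observations` is assumed to be part of the shared `Transform` API:

```python
from ax.modelbridge.transforms.standardize_y import StandardizeY

# `observations` is assumed to exist and be non-empty; the constructor
# computes per-metric means and standard deviations from it.
t = StandardizeY(search_space=None, observations=observations)
# Transforms may mutate their inputs in place, per the Transform docstring.
standardized = t.transform_observations(observations)
```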
@@ -148,7 +148,7 @@ def transform_observations( def transform_optimization_config( self, optimization_config: OptimizationConfig, - modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + modelbridge: Optional["modelbridge_module.base.Adapter"] = None, fixed_features: ObservationFeatures | None = None, ) -> OptimizationConfig: if len(optimization_config.all_constraints) == 0: diff --git a/ax/modelbridge/transforms/task_encode.py b/ax/modelbridge/transforms/task_encode.py index 2b43a952a69..3ad85835e94 100644 --- a/ax/modelbridge/transforms/task_encode.py +++ b/ax/modelbridge/transforms/task_encode.py @@ -41,7 +41,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + modelbridge: Optional["modelbridge_module.base.Adapter"] = None, config: TConfig | None = None, ) -> None: assert ( diff --git a/ax/modelbridge/transforms/tests/test_derelativize_transform.py b/ax/modelbridge/transforms/tests/test_derelativize_transform.py index ae876398641..20dc614ccd7 100644 --- a/ax/modelbridge/transforms/tests/test_derelativize_transform.py +++ b/ax/modelbridge/transforms/tests/test_derelativize_transform.py @@ -23,7 +23,7 @@ from ax.core.search_space import SearchSpace from ax.core.types import ComparisonOp from ax.exceptions.core import DataRequiredError -from ax.modelbridge.base import ModelBridge +from ax.modelbridge.base import Adapter from ax.modelbridge.transforms.derelativize import Derelativize from ax.utils.common.testutils import TestCase @@ -31,7 +31,7 @@ class DerelativizeTransformTest(TestCase): def setUp(self) -> None: super().setUp() - m = mock.patch.object(ModelBridge, "__abstractmethods__", frozenset()) + m = mock.patch.object(Adapter, "__abstractmethods__", frozenset()) self.addCleanup(m.stop) m.start() @@ -84,13 +84,13 @@ def test_DerelativizeTransform(self) -> None: with ExitStack() as es: mock_predict = es.enter_context( mock.patch( - "ax.modelbridge.base.ModelBridge._predict", + "ax.modelbridge.base.Adapter._predict", autospec=True, return_value=predict_return_value, ) ) mock_fit = es.enter_context( - mock.patch("ax.modelbridge.base.ModelBridge._fit", autospec=True) + mock.patch("ax.modelbridge.base.Adapter._fit", autospec=True) ) mock_observations_from_data = es.enter_context( mock.patch( @@ -117,14 +117,14 @@ def _test_DerelativizeTransform( ) -> None: t = Derelativize(search_space=None, observations=[]) - # ModelBridge with in-design status quo + # Adapter with in-design status quo search_space = SearchSpace( parameters=[ RangeParameter("x", ParameterType.FLOAT, 0, 20), RangeParameter("y", ParameterType.FLOAT, 0, 20), ] ) - g = ModelBridge( + g = Adapter( search_space=search_space, model=None, transforms=[], @@ -207,7 +207,7 @@ def _test_DerelativizeTransform( # Test with relative constraint, out-of-design status quo mock_predict.side_effect = RuntimeError() - g = ModelBridge( + g = Adapter( search_space=search_space, model=None, transforms=[], @@ -257,7 +257,7 @@ def _test_DerelativizeTransform( self.assertEqual(mock_predict.call_count, 1) # Raises error if predict fails with in-design status quo - g = ModelBridge( + g = Adapter( search_space=search_space, model=None, transforms=[], @@ -317,7 +317,7 @@ def _test_DerelativizeTransform( t2.transform_optimization_config(deepcopy(oc_scalarized_only), g, None) # Raises error with relative constraint, no status quo. 
- g = ModelBridge( + g = Adapter( search_space=search_space, model=None, transforms=[], @@ -345,7 +345,7 @@ def test_Errors(self) -> None: search_space = SearchSpace( parameters=[RangeParameter("x", ParameterType.FLOAT, 0, 20)] ) - g = ModelBridge(search_space, None, []) + g = Adapter(search_space, None, []) with self.assertRaises(ValueError): t.transform_optimization_config(oc, None, None) with self.assertRaises(DataRequiredError): diff --git a/ax/modelbridge/transforms/tests/test_relativize_transform.py b/ax/modelbridge/transforms/tests/test_relativize_transform.py index a3dc4210ae3..8ef808060fc 100644 --- a/ax/modelbridge/transforms/tests/test_relativize_transform.py +++ b/ax/modelbridge/transforms/tests/test_relativize_transform.py @@ -21,15 +21,15 @@ from ax.core.outcome_constraint import OutcomeConstraint from ax.core.types import ComparisonOp from ax.metrics.branin import BraninMetric -from ax.modelbridge import ModelBridge -from ax.modelbridge.registry import Models +from ax.modelbridge import Adapter +from ax.modelbridge.registry import Generators from ax.modelbridge.transforms.base import Transform from ax.modelbridge.transforms.relativize import ( BaseRelativize, Relativize, RelativizeWithConstantControl, ) -from ax.models.base import Model +from ax.models.base import Generator from ax.utils.common.testutils import TestCase from ax.utils.stats.statstools import relativize_data from ax.utils.testing.core_stubs import ( @@ -92,7 +92,7 @@ def test_relativize_transform_requires_a_modelbridge_to_have_status_quo_data( ) -> None: for relativize_cls in self.relativize_classes: # modelbridge has no status quo - sobol = Models.SOBOL(search_space=get_search_space()) + sobol = Generators.SOBOL(search_space=get_search_space()) self.assertIsNone(sobol.status_quo) with self.assertRaisesRegex( AssertionError, f"{relativize_cls.__name__} requires status quo data." 
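A sketch of the behavior these derelativize tests exercise, reflecting the renamed error message in `derelativize.py`. The optimization config below is illustrative; any config with at least one relative constraint follows the same path:

```python
from ax.core.metric import Metric
from ax.core.objective import Objective
from ax.core.optimization_config import OptimizationConfig
from ax.core.outcome_constraint import OutcomeConstraint
from ax.core.types import ComparisonOp
from ax.modelbridge.transforms.derelativize import Derelativize

optimization_config = OptimizationConfig(
    objective=Objective(metric=Metric(name="a"), minimize=False),
    outcome_constraints=[
        OutcomeConstraint(Metric(name="b"), ComparisonOp.LEQ, bound=10.0, relative=True)
    ],
)
t = Derelativize(search_space=None, observations=[])
try:
    # No adapter is supplied, so the relative constraint cannot be derelativized.
    t.transform_optimization_config(optimization_config, None, None)
except ValueError as err:
    assert "Adapter not supplied to transform." in str(err)
```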
@@ -129,9 +129,9 @@ def test_relativize_transform_requires_a_modelbridge_to_have_status_quo_data( ) t.mark_completed() data = exp.fetch_data() - modelbridge = ModelBridge( + modelbridge = Adapter( search_space=exp.search_space, - model=Model(), + model=Generator(), transforms=[relativize_cls], experiment=exp, data=data, @@ -146,9 +146,9 @@ def test_relativize_transform_requires_a_modelbridge_to_have_status_quo_data( ) # reset SQ none_throws(exp._status_quo)._parameters["x1"] = 0.0 - modelbridge = ModelBridge( + modelbridge = Adapter( search_space=exp.search_space, - model=Model(), + model=Generator(), transforms=[relativize_cls], experiment=exp, data=data, @@ -467,7 +467,7 @@ class RelativizeDataOptConfigTest(TestCase): def setUp(self) -> None: super().setUp() search_space = get_search_space() - gr = Models.SOBOL(search_space=search_space).gen(n=1) + gr = Generators.SOBOL(search_space=search_space).gen(n=1) self.model = Mock( search_space=search_space, status_quo=Mock( diff --git a/ax/modelbridge/transforms/tests/test_transform_to_new_sq.py b/ax/modelbridge/transforms/tests/test_transform_to_new_sq.py index 2cdc6ec44fd..a6fc7daad5b 100644 --- a/ax/modelbridge/transforms/tests/test_transform_to_new_sq.py +++ b/ax/modelbridge/transforms/tests/test_transform_to_new_sq.py @@ -12,11 +12,11 @@ import numpy.typing as npt from ax.core.batch_trial import BatchTrial from ax.core.observation import observations_from_data -from ax.modelbridge import ModelBridge +from ax.modelbridge import Adapter from ax.modelbridge.transforms.base import Transform from ax.modelbridge.transforms.tests.test_relativize_transform import RelativizeDataTest from ax.modelbridge.transforms.transform_to_new_sq import TransformToNewSQ -from ax.models.base import Model +from ax.models.base import Generator from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import ( get_branin_data_batch, @@ -73,9 +73,9 @@ def setUp(self) -> None: self._refresh_modelbridge() def _refresh_modelbridge(self) -> None: - self.modelbridge = ModelBridge( + self.modelbridge = Adapter( search_space=self.exp.search_space, - model=Model(), + model=Generator(), experiment=self.exp, data=self.exp.lookup_data(), status_quo_name="status_quo", diff --git a/ax/modelbridge/transforms/tests/test_winsorize_transform.py b/ax/modelbridge/transforms/tests/test_winsorize_transform.py index 205fdf83bd5..20045b8ed76 100644 --- a/ax/modelbridge/transforms/tests/test_winsorize_transform.py +++ b/ax/modelbridge/transforms/tests/test_winsorize_transform.py @@ -30,7 +30,7 @@ from ax.core.parameter import ParameterType, RangeParameter from ax.core.search_space import SearchSpace from ax.exceptions.core import DataRequiredError, UnsupportedError, UserInputError -from ax.modelbridge.base import ModelBridge +from ax.modelbridge.base import Adapter from ax.modelbridge.transforms.winsorize import ( _get_auto_winsorization_cutoffs_outcome_constraint, _get_auto_winsorization_cutoffs_single_objective, @@ -572,7 +572,7 @@ def test_relative_constraints( self, mock_observations_from_data: mock.Mock, ) -> None: - # ModelBridge with in-design status quo + # Adapter with in-design status quo search_space = SearchSpace( parameters=[ RangeParameter("x", ParameterType.FLOAT, 0, 20), @@ -600,7 +600,7 @@ def test_relative_constraints( ), ], ) - modelbridge = ModelBridge( + modelbridge = Adapter( search_space=search_space, model=None, transforms=[], @@ -618,7 +618,7 @@ def test_relative_constraints( config={"derelativize_with_raw_status_quo": True}, ) - modelbridge = 
ModelBridge( + modelbridge = Adapter( search_space=search_space, model=None, transforms=[], @@ -706,8 +706,8 @@ def get_default_transform_cutoffs( def _wrap_optimization_config_in_modelbridge( optimization_config: OptimizationConfig, -) -> ModelBridge: - return ModelBridge( +) -> Adapter: + return Adapter( search_space=SearchSpace(parameters=[]), model=1, optimization_config=optimization_config, diff --git a/ax/modelbridge/transforms/time_as_feature.py b/ax/modelbridge/transforms/time_as_feature.py index 669bc2f0728..cb1053dc51d 100644 --- a/ax/modelbridge/transforms/time_as_feature.py +++ b/ax/modelbridge/transforms/time_as_feature.py @@ -48,7 +48,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + modelbridge: Optional["modelbridge_module.base.Adapter"] = None, config: TConfig | None = None, ) -> None: assert observations is not None, "TimeAsFeature requires observations" diff --git a/ax/modelbridge/transforms/transform_to_new_sq.py b/ax/modelbridge/transforms/transform_to_new_sq.py index cfa3dae5ce9..b9ea3b76c9e 100644 --- a/ax/modelbridge/transforms/transform_to_new_sq.py +++ b/ax/modelbridge/transforms/transform_to_new_sq.py @@ -37,7 +37,7 @@ class TransformToNewSQ(BaseRelativize): """Map relative values of one batch to SQ of another. Will compute the relative metrics for each arm in each batch, and will then turn - those back into raw metrics but using the status quo values set on the Modelbridge. + those back into raw metrics but using the status quo values set on the Adapter. This is useful if batches are comparable on a relative scale, but have offset in their status quo. This is often approximately true for online @@ -50,7 +50,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: modelbridge_module.base.ModelBridge | None = None, + modelbridge: modelbridge_module.base.Adapter | None = None, config: TConfig | None = None, ) -> None: super().__init__( @@ -98,7 +98,7 @@ def control_as_constant(self) -> bool: def transform_optimization_config( self, optimization_config: OptimizationConfig, - modelbridge: modelbridge_module.base.ModelBridge | None = None, + modelbridge: modelbridge_module.base.Adapter | None = None, fixed_features: ObservationFeatures | None = None, ) -> OptimizationConfig: return optimization_config diff --git a/ax/modelbridge/transforms/trial_as_task.py b/ax/modelbridge/transforms/trial_as_task.py index 24c115b2574..638dfe7ece2 100644 --- a/ax/modelbridge/transforms/trial_as_task.py +++ b/ax/modelbridge/transforms/trial_as_task.py @@ -56,7 +56,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + modelbridge: Optional["modelbridge_module.base.Adapter"] = None, config: TConfig | None = None, ) -> None: assert observations is not None, "TrialAsTask requires observations" diff --git a/ax/modelbridge/transforms/unit_x.py b/ax/modelbridge/transforms/unit_x.py index 8ce91d91b32..fbfb63248ef 100644 --- a/ax/modelbridge/transforms/unit_x.py +++ b/ax/modelbridge/transforms/unit_x.py @@ -40,7 +40,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + modelbridge: Optional["modelbridge_module.base.Adapter"] = None, 
config: TConfig | None = None, ) -> None: assert search_space is not None, "UnitX requires search space" diff --git a/ax/modelbridge/transforms/utils.py b/ax/modelbridge/transforms/utils.py index e3db346eecb..3d79a59f75d 100644 --- a/ax/modelbridge/transforms/utils.py +++ b/ax/modelbridge/transforms/utils.py @@ -162,7 +162,7 @@ def construct_new_search_space( def derelativize_optimization_config_with_raw_status_quo( optimization_config: OptimizationConfig, - modelbridge: modelbridge_module.base.ModelBridge, + modelbridge: modelbridge_module.base.Adapter, observations: list[Observation] | None, ) -> OptimizationConfig: """Derelativize optimization_config using raw status-quo values""" diff --git a/ax/modelbridge/transforms/winsorize.py b/ax/modelbridge/transforms/winsorize.py index e4cd851ad9a..fd340dd895b 100644 --- a/ax/modelbridge/transforms/winsorize.py +++ b/ax/modelbridge/transforms/winsorize.py @@ -95,7 +95,7 @@ def __init__( self, search_space: SearchSpace | None = None, observations: list[Observation] | None = None, - modelbridge: Optional["modelbridge_module.base.ModelBridge"] = None, + modelbridge: Optional["modelbridge_module.base.Adapter"] = None, config: TConfig | None = None, ) -> None: if observations is None or len(observations) == 0: @@ -173,7 +173,7 @@ def _get_cutoffs( metric_name: str, metric_values: list[float], winsorization_config: WinsorizationConfig | dict[str, WinsorizationConfig], - modelbridge: Optional["modelbridge_module.base.ModelBridge"], + modelbridge: Optional["modelbridge_module.base.Adapter"], observations: list[Observation] | None, optimization_config: OptimizationConfig | None, use_raw_sq: bool, diff --git a/ax/models/base.py b/ax/models/base.py index 735fe730e51..63174288138 100644 --- a/ax/models/base.py +++ b/ax/models/base.py @@ -9,7 +9,7 @@ from typing import Any -class Model: +class Generator: """Base class for an Ax model. 
Note: the core methods each model has: `fit`, `predict`, `gen`, diff --git a/ax/models/discrete/eb_ashr.py b/ax/models/discrete/eb_ashr.py index 6909b15b7d9..14200c1f29b 100644 --- a/ax/models/discrete/eb_ashr.py +++ b/ax/models/discrete/eb_ashr.py @@ -15,7 +15,7 @@ from ax.core.types import TGenMetadata, TParamValue, TParamValueList from ax.models.discrete.ashr_utils import Ashr, GaussianMixture from ax.models.discrete.thompson import ThompsonSampler -from ax.models.discrete_base import DiscreteModel +from ax.models.discrete_base import DiscreteGenerator from ax.models.types import TConfig from ax.utils.common.docutils import copy_doc from pyre_extensions import none_throws @@ -272,7 +272,7 @@ def _get_success_measurement(self, objective_weights: npt.NDArray) -> npt.NDArra success = np.dot(posterior_means, objective_weights) return success - @copy_doc(DiscreteModel.gen) + @copy_doc(DiscreteGenerator.gen) def gen( self, n: int, diff --git a/ax/models/discrete/full_factorial.py b/ax/models/discrete/full_factorial.py index 5c9c4384d64..4a1bb346d47 100644 --- a/ax/models/discrete/full_factorial.py +++ b/ax/models/discrete/full_factorial.py @@ -14,7 +14,7 @@ import numpy.typing as npt from ax.core.types import TGenMetadata, TParamValue, TParamValueList -from ax.models.discrete_base import DiscreteModel +from ax.models.discrete_base import DiscreteGenerator from ax.models.types import TConfig from ax.utils.common.docutils import copy_doc from ax.utils.common.logger import get_logger @@ -23,7 +23,7 @@ logger: logging.Logger = get_logger(__name__) -class FullFactorialGenerator(DiscreteModel): +class FullFactorialGenerator(DiscreteGenerator): """Generator for full factorial designs. Generates arms for all possible combinations of parameter values, @@ -48,7 +48,7 @@ def __init__( self.max_cardinality = max_cardinality self.check_cardinality = check_cardinality - @copy_doc(DiscreteModel.gen) + @copy_doc(DiscreteGenerator.gen) # pyre-fixme[15]: Inconsistent override in return def gen( self, diff --git a/ax/models/discrete/thompson.py b/ax/models/discrete/thompson.py index d5a079d1154..1db31fcdec3 100644 --- a/ax/models/discrete/thompson.py +++ b/ax/models/discrete/thompson.py @@ -16,14 +16,14 @@ from ax.core.types import TGenMetadata, TParamValue, TParamValueList from ax.exceptions.constants import TS_MIN_WEIGHT_ERROR, TS_NO_FEASIBLE_ARMS_ERROR from ax.exceptions.model import ModelError -from ax.models.discrete_base import DiscreteModel +from ax.models.discrete_base import DiscreteGenerator from ax.models.types import TConfig from ax.utils.common.docutils import copy_doc from pyre_extensions import assert_is_instance, none_throws -class ThompsonSampler(DiscreteModel): +class ThompsonSampler(DiscreteGenerator): """Generator for Thompson sampling. The generator performs Thompson sampling on the data passed in via `fit`. 
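Apart from the `DiscreteModel` to `DiscreteGenerator` base-class rename, the discrete generators keep their `fit`/`predict`/`gen` interface. A rough sketch (not part of this diff), which assumes `n=-1` asks the full-factorial generator to enumerate every combination and that the generated points are the first element of `gen`'s return value:

```python
import numpy as np

from ax.models.discrete.full_factorial import FullFactorialGenerator

generator = FullFactorialGenerator()
# Enumerate every combination of the supplied parameter values.
points = generator.gen(
    n=-1,
    parameter_values=[[0, 1], ["a", "b"]],
    objective_weights=np.array([1.0]),  # ignored by the full-factorial generator
)[0]
assert len(points) == 4  # 2 values for the first parameter x 2 for the second
```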
@@ -57,7 +57,7 @@ def __init__( list[dict[TParamValueList, tuple[float, float]]] | None ) = None - @copy_doc(DiscreteModel.fit) + @copy_doc(DiscreteGenerator.fit) def fit( self, Xs: Sequence[Sequence[Sequence[TParamValue]]], @@ -76,7 +76,7 @@ def fit( Yvars=none_throws(self.Yvars), ) - @copy_doc(DiscreteModel.gen) + @copy_doc(DiscreteGenerator.gen) def gen( self, n: int, @@ -140,7 +140,7 @@ def gen( }, ) - @copy_doc(DiscreteModel.predict) + @copy_doc(DiscreteGenerator.predict) def predict( self, X: Sequence[Sequence[TParamValue]] ) -> tuple[npt.NDArray, npt.NDArray]: diff --git a/ax/models/discrete_base.py b/ax/models/discrete_base.py index 565ae612241..8e3e6f88d86 100644 --- a/ax/models/discrete_base.py +++ b/ax/models/discrete_base.py @@ -10,11 +10,11 @@ import numpy.typing as npt from ax.core.types import TGenMetadata, TParamValue, TParamValueList -from ax.models.base import Model +from ax.models.base import Generator from ax.models.types import TConfig -class DiscreteModel(Model): +class DiscreteGenerator(Generator): """This class specifies the interface for a model based on discrete parameters. These methods should be implemented to have access to all of the features diff --git a/ax/models/model_utils.py b/ax/models/model_utils.py index 664a6e99ea5..04bcee34230 100644 --- a/ax/models/model_utils.py +++ b/ax/models/model_utils.py @@ -30,7 +30,7 @@ TTensoray = TypeVar("TTensoray", bound=Tensoray) -class TorchModelLike(Protocol): +class TorchGeneratorLike(Protocol): """A protocol that stands in for ``TorchModel`` like objects that have a ``predict`` method. """ @@ -68,13 +68,13 @@ def rejection_sample( """Rejection sample in parameter space. Parameter space is typically [0, 1] for all tunable parameters. - Models must implement a `gen_unconstrained` method in order to support + Generators must implement a `gen_unconstrained` method in order to support rejection sampling via this utility. Args: gen_unconstrained: A callable that generates unconstrained points in the parameter space. This is typically the `_gen_unconstrained` method - of a `RandomModel`. + of a `RandomGenerator`. n: Number of samples to generate. d: Dimensionality of the parameter space. tunable_feature_indices: Indices of the tunable features in the @@ -269,13 +269,13 @@ def validate_bounds( if bound[0] != 0 or bound[1] != 1: raise ValueError( "This generator operates on [0,1]^d. Please make use " - "of the UnitX transform in the ModelBridge, and ensure " + "of the UnitX transform in the Adapter, and ensure " "task features are fixed." 
) def best_observed_point( - model: TorchModelLike, + model: TorchGeneratorLike, bounds: Sequence[tuple[float, float]], objective_weights: TTensoray | None, outcome_constraints: tuple[TTensoray, TTensoray] | None = None, @@ -351,7 +351,7 @@ def best_observed_point( def best_in_sample_point( Xs: Sequence[TTensoray], - model: TorchModelLike, + model: TorchGeneratorLike, bounds: Sequence[tuple[float, float]], objective_weights: TTensoray | None, outcome_constraints: tuple[TTensoray, TTensoray] | None = None, diff --git a/ax/models/random/base.py b/ax/models/random/base.py index 06c5bc53be0..e06219d6c61 100644 --- a/ax/models/random/base.py +++ b/ax/models/random/base.py @@ -14,7 +14,7 @@ import numpy.typing as npt import torch from ax.exceptions.core import SearchSpaceExhausted -from ax.models.base import Model +from ax.models.base import Generator from ax.models.model_utils import ( add_fixed_features, rejection_sample, @@ -33,7 +33,7 @@ logger: Logger = get_logger(__name__) -class RandomModel(Model): +class RandomGenerator(Generator): """This class specifies the basic skeleton for a random model. As random generators do not make use of models, they do not implement @@ -186,7 +186,7 @@ def gen( self.generated_points = np.vstack([self.generated_points, points]) return points, np.ones(len(points)) - @copy_doc(Model._get_state) + @copy_doc(Generator._get_state) def _get_state(self) -> dict[str, Any]: state = super()._get_state() state.update( @@ -237,7 +237,7 @@ def _gen_samples(self, n: int, tunable_d: int) -> npt.NDArray: (n x d) array of generated points. """ - raise NotImplementedError("Base RandomModel can't generate samples.") + raise NotImplementedError("Base RandomGenerator can't generate samples.") def _convert_inequality_constraints( self, diff --git a/ax/models/random/sobol.py b/ax/models/random/sobol.py index c8a54821006..bc73ea9a10e 100644 --- a/ax/models/random/sobol.py +++ b/ax/models/random/sobol.py @@ -12,13 +12,13 @@ import numpy.typing as npt import torch from ax.models.model_utils import tunable_feature_indices -from ax.models.random.base import RandomModel +from ax.models.random.base import RandomGenerator from ax.models.types import TConfig from pyre_extensions import none_throws from torch.quasirandom import SobolEngine -class SobolGenerator(RandomModel): +class SobolGenerator(RandomGenerator): """This class specifies the generation algorithm for a Sobol generator. As Sobol does not make use of a model, it does not implement @@ -27,7 +27,7 @@ class SobolGenerator(RandomModel): Attributes: scramble: If True, permutes the parameter values among the elements of the Sobol sequence. Default is True. - See base `RandomModel` for a description of remaining attributes. + See base `RandomGenerator` for a description of remaining attributes. """ def __init__( diff --git a/ax/models/random/uniform.py b/ax/models/random/uniform.py index ceb4631527d..d86a3e68faa 100644 --- a/ax/models/random/uniform.py +++ b/ax/models/random/uniform.py @@ -9,16 +9,16 @@ import numpy as np import numpy.typing as npt -from ax.models.random.base import RandomModel +from ax.models.random.base import RandomGenerator -class UniformGenerator(RandomModel): +class UniformGenerator(RandomGenerator): """This class specifies a uniform random generation algorithm. As a uniform generator does not make use of a model, it does not implement the fit or predict methods. - See base `RandomModel` for a description of model attributes. + See base `RandomGenerator` for a description of model attributes. 
""" def __init__( diff --git a/ax/models/tests/test_base.py b/ax/models/tests/test_base.py index 334f7f54b83..57cb059a536 100644 --- a/ax/models/tests/test_base.py +++ b/ax/models/tests/test_base.py @@ -6,13 +6,13 @@ # pyre-strict -from ax.models.base import Model +from ax.models.base import Generator from ax.utils.common.testutils import TestCase class BaseModelTest(TestCase): def test_base_model(self) -> None: - model = Model() + model = Generator() raw_state = {"foo": "bar", "two": 3.0} self.assertEqual(model.serialize_state(raw_state), raw_state) self.assertEqual(model.deserialize_state(raw_state), raw_state) diff --git a/ax/models/tests/test_botorch_model.py b/ax/models/tests/test_botorch_model.py index be28b262b13..2b90d70d22b 100644 --- a/ax/models/tests/test_botorch_model.py +++ b/ax/models/tests/test_botorch_model.py @@ -16,7 +16,7 @@ from ax.core.search_space import SearchSpaceDigest from ax.exceptions.core import DataRequiredError from ax.models.torch.botorch import ( - BotorchModel, + BotorchGenerator, get_feature_importances_from_botorch_model, get_rounding_func, ) @@ -56,9 +56,9 @@ def dummy_func(X: torch.Tensor) -> torch.Tensor: return X -class BotorchModelTest(TestCase): +class BotorchGeneratorTest(TestCase): @mock_botorch_optimize - def test_fixed_rank_BotorchModel( + def test_fixed_rank_BotorchGenerator( self, dtype: torch.dtype = torch.float, cuda: bool = False ) -> None: Xs1, Ys1, Yvars1, bounds, _, feature_names, __ = get_torch_test_data( @@ -67,7 +67,7 @@ def test_fixed_rank_BotorchModel( Xs2, Ys2, Yvars2, _, _, _, _ = get_torch_test_data( dtype=dtype, cuda=cuda, constant_noise=True ) - model = BotorchModel(multitask_gp_ranks={"y": 2, "w": 1}) + model = BotorchGenerator(multitask_gp_ranks={"y": 2, "w": 1}) datasets = [ SupervisedDataset( X=Xs1[0], @@ -114,7 +114,7 @@ def test_fixed_rank_BotorchModel( self.assertEqual(model_list[1]._rank, 1) @mock_botorch_optimize - def test_fixed_prior_BotorchModel( + def test_fixed_prior_BotorchGenerator( self, dtype: torch.dtype = torch.float, cuda: bool = False ) -> None: Xs1, Ys1, Yvars1, bounds, _, feature_names, metric_names = get_torch_test_data( @@ -134,7 +134,7 @@ def test_fixed_prior_BotorchModel( "eta": 0.6, } } - model = BotorchModel(**kwargs) + model = BotorchGenerator(**kwargs) datasets = [ SupervisedDataset( X=Xs1[0], @@ -194,7 +194,7 @@ def test_fixed_prior_BotorchModel( ) @mock_botorch_optimize - def test_BotorchModel( + def test_BotorchGenerator( self, dtype: torch.dtype = torch.float, cuda: bool = False ) -> None: ( @@ -211,7 +211,7 @@ def test_BotorchModel( ) for use_input_warping in (True, False): for use_loocv_pseudo_likelihood in (True, False): - model = BotorchModel( + model = BotorchGenerator( use_input_warping=use_input_warping, use_loocv_pseudo_likelihood=use_loocv_pseudo_likelihood, ) @@ -576,7 +576,7 @@ def test_BotorchModel( self.assertEqual(importances.shape, torch.Size([2, 1, 3])) # test unfit model CV and feature_importances - unfit_model = BotorchModel() + unfit_model = BotorchGenerator() with self.assertRaisesRegex( RuntimeError, r"Cannot cross-validate model that has not been fitted" ): @@ -633,18 +633,18 @@ def test_BotorchModel( ) ) - def test_BotorchModel_cuda(self) -> None: + def test_BotorchGenerator_cuda(self) -> None: if torch.cuda.is_available(): - self.test_BotorchModel(cuda=True) + self.test_BotorchGenerator(cuda=True) - def test_BotorchModel_double(self) -> None: - self.test_BotorchModel(dtype=torch.double) + def test_BotorchGenerator_double(self) -> None: + 
self.test_BotorchGenerator(dtype=torch.double) - def test_BotorchModel_double_cuda(self) -> None: + def test_BotorchGenerator_double_cuda(self) -> None: if torch.cuda.is_available(): - self.test_BotorchModel(dtype=torch.double, cuda=True) + self.test_BotorchGenerator(dtype=torch.double, cuda=True) - def test_BotorchModelOneOutcome(self) -> None: + def test_BotorchGeneratorOneOutcome(self) -> None: ( Xs1, Ys1, @@ -657,7 +657,7 @@ def test_BotorchModelOneOutcome(self) -> None: for use_input_warping, use_loocv_pseudo_likelihood in product( (True, False), (True, False) ): - model = BotorchModel( + model = BotorchGenerator( use_input_warping=use_input_warping, use_loocv_pseudo_likelihood=use_loocv_pseudo_likelihood, ) @@ -697,7 +697,7 @@ def test_BotorchModelOneOutcome(self) -> None: else: self.assertFalse(hasattr(model.model, "input_transform")) - def test_BotorchModelConstraints(self) -> None: + def test_BotorchGeneratorConstraints(self) -> None: ( Xs1, Ys1, @@ -716,7 +716,7 @@ def test_BotorchModelConstraints(self) -> None: [-1.0, 1.0], dtype=torch.float, device=torch.device("cpu") ) n = 3 - model = BotorchModel() + model = BotorchGenerator() search_space_digest = SearchSpaceDigest( feature_names=feature_names, bounds=bounds, @@ -761,9 +761,9 @@ def test_botorchmodel_raises_when_no_data(self) -> None: bounds=bounds, task_features=tfs, ) - model = BotorchModel() + model = BotorchGenerator() with self.assertRaisesRegex( - DataRequiredError, "BotorchModel.fit requires non-empty data sets." + DataRequiredError, "BotorchGenerator.fit requires non-empty data sets." ): model.fit( datasets=[], diff --git a/ax/models/tests/test_botorch_moo_defaults.py b/ax/models/tests/test_botorch_moo_defaults.py index 11cda031ced..11fcb038354 100644 --- a/ax/models/tests/test_botorch_moo_defaults.py +++ b/ax/models/tests/test_botorch_moo_defaults.py @@ -15,8 +15,8 @@ import torch from ax.core.search_space import SearchSpaceDigest from ax.models.torch.botorch_defaults import NO_OBSERVED_POINTS_MESSAGE -from ax.models.torch.botorch_modular.model import BoTorchModel -from ax.models.torch.botorch_moo import MultiObjectiveBotorchModel +from ax.models.torch.botorch_modular.model import BoTorchGenerator +from ax.models.torch.botorch_moo import MultiObjectiveBotorchGenerator from ax.models.torch.botorch_moo_defaults import ( get_outcome_constraint_transforms, get_qLogEHVI, @@ -26,7 +26,7 @@ pareto_frontier_evaluator, ) from ax.models.torch.utils import _get_X_pending_and_observed -from ax.models.torch_base import TorchModel +from ax.models.torch_base import TorchGenerator from ax.utils.common.random import with_rng_seed from ax.utils.common.testutils import TestCase from ax.utils.testing.mock import mock_botorch_optimize_context_manager @@ -48,7 +48,7 @@ def _fit_model( - model: TorchModel, X: torch.Tensor, Y: torch.Tensor, Yvar: torch.Tensor + model: TorchGenerator, X: torch.Tensor, Y: torch.Tensor, Yvar: torch.Tensor ) -> None: bounds = [(0.0, 4.0), (0.0, 4.0)] datasets = [ @@ -88,14 +88,14 @@ def setUp(self) -> None: def test_pareto_frontier_raise_error_when_missing_data(self) -> None: with self.assertRaises(ValueError): pareto_frontier_evaluator( - model=MultiObjectiveBotorchModel(), + model=MultiObjectiveBotorchGenerator(), objective_thresholds=self.objective_thresholds, objective_weights=self.objective_weights, Yvar=self.Yvar, ) def test_pareto_frontier_evaluator_raw(self) -> None: - model = BoTorchModel() + model = BoTorchGenerator() _fit_model(model=model, X=self.X, Y=self.Y, Yvar=self.Yvar) Yvar = 
torch.diag_embed(self.Yvar) Y, cov, indx = pareto_frontier_evaluator( @@ -153,7 +153,7 @@ def test_pareto_frontier_evaluator_raw(self) -> None: def test_pareto_frontier_evaluator_predict(self) -> None: def dummy_predict( - model: MultiObjectiveBotorchModel, + model: MultiObjectiveBotorchGenerator, X: Tensor, use_posterior_predictive: bool = False, ) -> tuple[Tensor, Tensor]: @@ -163,10 +163,10 @@ def dummy_predict( return mean, cov # pyre-fixme: Incompatible parameter type [6]: In call - # `MultiObjectiveBotorchModel.__init__`, for argument `model_predictor`, + # `MultiObjectiveBotorchGenerator.__init__`, for argument `model_predictor`, # expected `typing.Callable[[Model, Tensor, bool], Tuple[Tensor, # Tensor]]` but got named arguments - model = MultiObjectiveBotorchModel(model_predictor=dummy_predict) + model = MultiObjectiveBotorchGenerator(model_predictor=dummy_predict) _fit_model(model=model, X=self.X, Y=self.Y, Yvar=self.Yvar) Y, _, indx = pareto_frontier_evaluator( @@ -182,7 +182,7 @@ def dummy_predict( self.assertTrue(torch.equal(torch.arange(2, 4), indx)) def test_pareto_frontier_evaluator_with_outcome_constraints(self) -> None: - model = MultiObjectiveBotorchModel() + model = MultiObjectiveBotorchGenerator() Y, _, indx = pareto_frontier_evaluator( model=model, objective_weights=self.objective_weights, diff --git a/ax/models/tests/test_botorch_moo_model.py b/ax/models/tests/test_botorch_moo_model.py index b6d875d4ca3..296afa776c1 100644 --- a/ax/models/tests/test_botorch_moo_model.py +++ b/ax/models/tests/test_botorch_moo_model.py @@ -18,7 +18,7 @@ from ax.core.search_space import SearchSpaceDigest from ax.exceptions.core import AxError from ax.models.torch.botorch_defaults import get_qLogNEI -from ax.models.torch.botorch_moo import MultiObjectiveBotorchModel +from ax.models.torch.botorch_moo import MultiObjectiveBotorchGenerator from ax.models.torch.botorch_moo_defaults import ( get_EHVI, get_NEHVI, @@ -180,7 +180,7 @@ def test_BotorchMOOModel_with_random_scalarization( bounds=bounds, task_features=tfs, ) - model = MultiObjectiveBotorchModel(acqf_constructor=get_qLogNEI) + model = MultiObjectiveBotorchGenerator(acqf_constructor=get_qLogNEI) with mock.patch(FIT_MODEL_MO_PATH) as _mock_fit_model: model.fit( datasets=training_data, @@ -234,7 +234,7 @@ def test_BotorchMOOModel_with_random_scalarization( # test input warping self.assertFalse(model.use_input_warping) - model = MultiObjectiveBotorchModel( + model = MultiObjectiveBotorchGenerator( acqf_constructor=get_qLogNEI, use_input_warping=True, ) @@ -252,7 +252,7 @@ def test_BotorchMOOModel_with_random_scalarization( # test loocv pseudo likelihood self.assertFalse(model.use_loocv_pseudo_likelihood) - model = MultiObjectiveBotorchModel( + model = MultiObjectiveBotorchGenerator( acqf_constructor=get_qLogNEI, use_loocv_pseudo_likelihood=True, ) @@ -310,7 +310,7 @@ def test_BotorchMOOModel_with_chebyshev_scalarization( bounds=bounds, task_features=tfs, ) - model = MultiObjectiveBotorchModel(acqf_constructor=get_qLogNEI) + model = MultiObjectiveBotorchGenerator(acqf_constructor=get_qLogNEI) with mock.patch(FIT_MODEL_MO_PATH) as _mock_fit_model: model.fit( datasets=training_data, @@ -421,7 +421,7 @@ def test_BotorchMOOModel_with_qehvi( objective_weights = torch.tensor([1.0, 1.0, 0.0], **tkwargs) obj_t = torch.tensor([1.0, 1.0, float("nan")], **tkwargs) # pyre-fixme[6]: For 1st param expected `(Model, Tensor, Optional[Tuple[Tenso... 
- model = MultiObjectiveBotorchModel(acqf_constructor=acqf_constructor) + model = MultiObjectiveBotorchGenerator(acqf_constructor=acqf_constructor) X_dummy = torch.tensor([[[1.0, 2.0, 3.0]]], **tkwargs) acqfv_dummy = torch.tensor([[[1.0, 2.0, 3.0]]], **tkwargs) @@ -723,7 +723,7 @@ def test_BotorchMOOModel_with_random_scalarization_and_outcome_constraints( n = 2 objective_weights = torch.tensor([1.0, 1.0], **tkwargs) obj_t = torch.tensor([1.0, 1.0], **tkwargs) - model = MultiObjectiveBotorchModel(acqf_constructor=get_qLogNEI) + model = MultiObjectiveBotorchGenerator(acqf_constructor=get_qLogNEI) search_space_digest = SearchSpaceDigest( feature_names=feature_names, @@ -801,7 +801,7 @@ def test_BotorchMOOModel_with_chebyshev_scalarization_and_outcome_constraints( n = 2 objective_weights = torch.tensor([1.0, 1.0], **tkwargs) obj_t = torch.tensor([1.0, 1.0], **tkwargs) - model = MultiObjectiveBotorchModel(acqf_constructor=get_qLogNEI) + model = MultiObjectiveBotorchGenerator(acqf_constructor=get_qLogNEI) search_space_digest = SearchSpaceDigest( feature_names=feature_names, @@ -905,7 +905,7 @@ def test_BotorchMOOModel_with_qehvi_and_outcome_constraints( objective_weights = torch.tensor([1.0, 1.0, 0.0], **tkwargs) obj_t = torch.tensor([1.0, 1.0, 1.0], **tkwargs) # pyre-fixme[6]: For 1st param expected `(Model, Tensor, Optional[Tuple[Tenso... - model = MultiObjectiveBotorchModel(acqf_constructor=acqf_constructor) + model = MultiObjectiveBotorchGenerator(acqf_constructor=acqf_constructor) search_space_digest = SearchSpaceDigest( feature_names=feature_names, diff --git a/ax/models/tests/test_discrete.py b/ax/models/tests/test_discrete.py index bc4d7869db4..20ee7ed0f56 100644 --- a/ax/models/tests/test_discrete.py +++ b/ax/models/tests/test_discrete.py @@ -7,22 +7,22 @@ # pyre-strict import numpy as np -from ax.models.discrete_base import DiscreteModel +from ax.models.discrete_base import DiscreteGenerator from ax.utils.common.testutils import TestCase -class DiscreteModelTest(TestCase): +class DiscreteGeneratorTest(TestCase): def test_discrete_model_get_state(self) -> None: - discrete_model = DiscreteModel() + discrete_model = DiscreteGenerator() self.assertEqual(discrete_model._get_state(), {}) def test_discrete_model_feature_importances(self) -> None: - discrete_model = DiscreteModel() + discrete_model = DiscreteGenerator() with self.assertRaises(NotImplementedError): discrete_model.feature_importances() - def test_DiscreteModelFit(self) -> None: - discrete_model = DiscreteModel() + def test_DiscreteGeneratorFit(self) -> None: + discrete_model = DiscreteGenerator() discrete_model.fit( Xs=[[[0]]], Ys=[[0]], @@ -32,19 +32,19 @@ def test_DiscreteModelFit(self) -> None: ) def test_discreteModelPredict(self) -> None: - discrete_model = DiscreteModel() + discrete_model = DiscreteGenerator() with self.assertRaises(NotImplementedError): discrete_model.predict([[0]]) def test_discreteModelGen(self) -> None: - discrete_model = DiscreteModel() + discrete_model = DiscreteGenerator() with self.assertRaises(NotImplementedError): discrete_model.gen( n=1, parameter_values=[[0, 1]], objective_weights=np.array([1]) ) def test_discreteModelCrossValidate(self) -> None: - discrete_model = DiscreteModel() + discrete_model = DiscreteGenerator() with self.assertRaises(NotImplementedError): discrete_model.cross_validate( Xs_train=[[[0]]], Ys_train=[[1]], Yvars_train=[[1]], X_test=[[1]] diff --git a/ax/models/tests/test_random.py b/ax/models/tests/test_random.py index af76bcda4db..3ece5bcb4a3 100644 --- 
a/ax/models/tests/test_random.py +++ b/ax/models/tests/test_random.py @@ -8,34 +8,34 @@ import numpy as np import torch -from ax.models.random.base import RandomModel +from ax.models.random.base import RandomGenerator from ax.utils.common.testutils import TestCase from pyre_extensions import none_throws -class RandomModelTest(TestCase): +class RandomGeneratorTest(TestCase): def setUp(self) -> None: super().setUp() - self.random_model = RandomModel() + self.random_model = RandomGenerator() def test_seed(self) -> None: # With manual seed. - random_model = RandomModel(seed=5) + random_model = RandomGenerator(seed=5) self.assertEqual(random_model.seed, 5) # With no seed. self.assertIsInstance(self.random_model.seed, int) def test_state(self) -> None: - for model in (self.random_model, RandomModel(seed=5)): + for model in (self.random_model, RandomGenerator(seed=5)): state = model._get_state() self.assertEqual(state["seed"], model.seed) self.assertEqual(state["generated_points"], model.generated_points) - def test_RandomModelGenSamples(self) -> None: + def test_RandomGeneratorGenSamples(self) -> None: with self.assertRaises(NotImplementedError): self.random_model._gen_samples(n=1, tunable_d=1) - def test_RandomModelGenUnconstrained(self) -> None: + def test_RandomGeneratorGenUnconstrained(self) -> None: with self.assertRaises(NotImplementedError): self.random_model._gen_unconstrained( n=1, d=2, tunable_feature_indices=np.array([]) @@ -82,8 +82,8 @@ def test_ConvertBounds(self) -> None: def test_GetLastPoint(self) -> None: generated_points = np.array([[1, 2, 3], [4, 5, 6]]) - RandomModelWithPoints = RandomModel(generated_points=generated_points) - result = RandomModelWithPoints._get_last_point() + RandomGeneratorWithPoints = RandomGenerator(generated_points=generated_points) + result = RandomGeneratorWithPoints._get_last_point() expected = torch.tensor([[4], [5], [6]]) comparison = result == expected # pyre-fixme[16]: `bool` has no attribute `any`. 
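# Illustrative sketch, not part of the patch: how calling code adapts to the
# base-class renames exercised in the tests above. Imports and calls mirror the
# updated test files; nothing beyond them is assumed.
import numpy as np

from ax.models.discrete_base import DiscreteGenerator
from ax.models.random.base import RandomGenerator

random_generator = RandomGenerator(seed=5)  # formerly RandomModel(seed=5)
state = random_generator._get_state()       # contains "seed" and "generated_points"

discrete_generator = DiscreteGenerator()    # formerly DiscreteModel()
try:
    # The abstract base still raises NotImplementedError for predict/gen/cross_validate.
    discrete_generator.gen(
        n=1, parameter_values=[[0, 1]], objective_weights=np.array([1])
    )
except NotImplementedError:
    pass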
diff --git a/ax/models/tests/test_torch.py b/ax/models/tests/test_torch.py index 8c921d56a7c..0aed5671e28 100644 --- a/ax/models/tests/test_torch.py +++ b/ax/models/tests/test_torch.py @@ -8,12 +8,12 @@ import torch from ax.core.search_space import SearchSpaceDigest -from ax.models.torch_base import TorchModel, TorchOptConfig +from ax.models.torch_base import TorchGenerator, TorchOptConfig from ax.utils.common.testutils import TestCase from botorch.utils.datasets import SupervisedDataset -class TorchModelTest(TestCase): +class TorchGeneratorTest(TestCase): def setUp(self) -> None: super().setUp() self.dataset = SupervisedDataset( @@ -30,7 +30,7 @@ def setUp(self) -> None: self.torch_opt_config = TorchOptConfig(objective_weights=torch.ones(1)) def test_TorchModelFit(self) -> None: - torch_model = TorchModel() + torch_model = TorchGenerator() torch_model.fit( datasets=[self.dataset], search_space_digest=SearchSpaceDigest( @@ -40,12 +40,12 @@ def test_TorchModelFit(self) -> None: ) def test_TorchModelPredict(self) -> None: - torch_model = TorchModel() + torch_model = TorchGenerator() with self.assertRaises(NotImplementedError): torch_model.predict(torch.zeros(1)) def test_TorchModelGen(self) -> None: - torch_model = TorchModel() + torch_model = TorchGenerator() with self.assertRaises(NotImplementedError): torch_model.gen( n=1, @@ -54,7 +54,7 @@ def test_TorchModelGen(self) -> None: ) def test_NumpyTorchBestPoint(self) -> None: - torch_model = TorchModel() + torch_model = TorchGenerator() x = torch_model.best_point( search_space_digest=self.search_space_digest, torch_opt_config=self.torch_opt_config, @@ -62,7 +62,7 @@ def test_NumpyTorchBestPoint(self) -> None: self.assertIsNone(x) def test_TorchModelCrossValidate(self) -> None: - torch_model = TorchModel() + torch_model = TorchGenerator() with self.assertRaises(NotImplementedError): torch_model.cross_validate( datasets=[self.dataset], diff --git a/ax/models/torch/botorch.py b/ax/models/torch/botorch.py index 133dbe22c9d..ca1e031594f 100644 --- a/ax/models/torch/botorch.py +++ b/ax/models/torch/botorch.py @@ -34,7 +34,7 @@ predict_from_model, subset_model, ) -from ax.models.torch_base import TorchGenResults, TorchModel, TorchOptConfig +from ax.models.torch_base import TorchGenerator, TorchGenResults, TorchOptConfig from ax.models.types import TConfig from ax.utils.common.constants import Keys from ax.utils.common.docutils import copy_doc @@ -84,7 +84,7 @@ ] TBestPointRecommender = Callable[ [ - TorchModel, + TorchGenerator, list[tuple[float, float]], Tensor, Optional[tuple[Tensor, Tensor]], @@ -97,7 +97,7 @@ ] -class BotorchModel(TorchModel): +class BotorchGenerator(TorchGenerator): r""" Customizable botorch model. @@ -158,7 +158,7 @@ class BotorchModel(TorchModel): `fidelity_features` is a list of ints that specify the positions of fidelity parameters in 'Xs', `metric_names` provides the names of each `Y` in `Ys`, `state_dict` is a pytorch module state dict, and `model` is a BoTorch `Model`. - Optional kwargs are being passed through from the `BotorchModel` constructor. + Optional kwargs are being passed through from the `BotorchGenerator` constructor. This callable is assumed to return a fitted BoTorch model that has the same dtype and lives on the same device as the input tensors. 
@@ -222,7 +222,7 @@ class BotorchModel(TorchModel): target_fidelities, ) -> candidates - Here `model` is a TorchModel, `bounds` is a list of tuples containing bounds + Here `model` is a TorchGenerator, `bounds` is a list of tuples containing bounds on the parameters, `objective_weights` is a tensor of weights for the model outputs, `outcome_constraints` is a tuple of tensors describing the (linear) outcome constraints, `linear_constraints` is a tuple of tensors describing constraints @@ -257,12 +257,12 @@ def __init__( **kwargs: Any, ) -> None: warnings.warn( - "The legacy `BotorchModel` and its subclasses, including the current" + "The legacy `BotorchGenerator` and its subclasses, including the current" f"class `{self.__class__.__name__}`, slated for deprecation. " "These models will not be supported going forward and may be " "fully removed in a future release. Please consider using the " - "Modular BoTorch Model (MBM) setup (ax/models/torch/botorch_modular) " - "instead. If you run into a use case that is not supported by MBM, " + "Modular BoTorch Generator (MBG) setup (ax/models/torch/botorch_modular) " + "instead. If you run into a use case that is not supported by MBG, " "please raise this with an issue at https://github.com/facebook/Ax", DeprecationWarning, stacklevel=2, @@ -289,7 +289,7 @@ def __init__( self.fidelity_features: list[int] = [] self.metric_names: list[str] = [] - @copy_doc(TorchModel.fit) + @copy_doc(TorchGenerator.fit) def fit( self, datasets: list[SupervisedDataset], @@ -297,7 +297,9 @@ def fit( candidate_metadata: list[list[TCandidateMetadata]] | None = None, ) -> None: if len(datasets) == 0: - raise DataRequiredError("BotorchModel.fit requires non-empty data sets.") + raise DataRequiredError( + "BotorchGenerator.fit requires non-empty data sets." + ) self.Xs, self.Ys, self.Yvars = _datasets_to_legacy_inputs(datasets=datasets) self.metric_names = sum((ds.outcome_names for ds in datasets), []) # Store search space info for later use (e.g. during generation) @@ -324,11 +326,11 @@ def fit( **self._kwargs, ) - @copy_doc(TorchModel.predict) + @copy_doc(TorchGenerator.predict) def predict(self, X: Tensor) -> tuple[Tensor, Tensor]: return self.model_predictor(model=self.model, X=X) # pyre-ignore [28] - @copy_doc(TorchModel.gen) + @copy_doc(TorchGenerator.gen) def gen( self, n: int, @@ -341,7 +343,7 @@ def gen( if search_space_digest.fidelity_features: raise NotImplementedError( - "Base BotorchModel does not support fidelity_features." + "Base BotorchGenerator does not support fidelity_features." 
) X_pending, X_observed = _get_X_pending_and_observed( Xs=self.Xs, @@ -439,7 +441,7 @@ def make_and_optimize_acqf(override_qmc: bool = False) -> tuple[Tensor, Tensor]: gen_metadata=gen_metadata, ) - @copy_doc(TorchModel.best_point) + @copy_doc(TorchGenerator.best_point) def best_point( self, search_space_digest: SearchSpaceDigest, @@ -465,7 +467,7 @@ def best_point( target_fidelities=target_fidelities, ) - @copy_doc(TorchModel.cross_validate) + @copy_doc(TorchGenerator.cross_validate) def cross_validate( # pyre-ignore [14]: `search_space_digest` arg not needed here self, datasets: list[SupervisedDataset], diff --git a/ax/models/torch/botorch_defaults.py b/ax/models/torch/botorch_defaults.py index d93d81161df..159f0e6266f 100644 --- a/ax/models/torch/botorch_defaults.py +++ b/ax/models/torch/botorch_defaults.py @@ -14,7 +14,7 @@ import torch from ax.models.model_utils import best_observed_point -from ax.models.torch_base import TorchModel +from ax.models.torch_base import TorchGenerator from ax.models.types import TConfig from botorch.acquisition import get_acquisition_function from botorch.acquisition.acquisition import AcquisitionFunction @@ -520,7 +520,7 @@ def scipy_optimizer( def recommend_best_observed_point( - model: TorchModel, + model: TorchGenerator, bounds: list[tuple[float, float]], objective_weights: Tensor, outcome_constraints: tuple[Tensor, Tensor] | None = None, @@ -530,12 +530,12 @@ def recommend_best_observed_point( target_fidelities: dict[int, float] | None = None, ) -> Tensor | None: """ - A wrapper around `ax.models.model_utils.best_observed_point` for TorchModel + A wrapper around `ax.models.model_utils.best_observed_point` for TorchGenerator that recommends a best point from previously observed points using either a "max_utility" or "feasible_threshold" strategy. Args: - model: A TorchModel. + model: A TorchGenerator. bounds: A list of (lower, upper) tuples for each column of X. objective_weights: The objective is to maximize a weighted sum of the columns of f(x). These are the weights. @@ -558,7 +558,7 @@ def recommend_best_observed_point( """ if target_fidelities: raise NotImplementedError( - "target_fidelities not implemented for base BotorchModel" + "target_fidelities not implemented for base BotorchGenerator" ) x_best = best_observed_point( diff --git a/ax/models/torch/botorch_modular/acquisition.py b/ax/models/torch/botorch_modular/acquisition.py index 789fcd0c4e3..dfc9d758003 100644 --- a/ax/models/torch/botorch_modular/acquisition.py +++ b/ax/models/torch/botorch_modular/acquisition.py @@ -71,7 +71,7 @@ class Acquisition(Base): versions only.** Ax wrapper for BoTorch `AcquisitionFunction`, subcomponent - of `BoTorchModel` and is not meant to be used outside of it. + of `BoTorchGenerator` and is not meant to be used outside of it. Args: surrogate: The Surrogate model, with which this acquisition @@ -275,7 +275,7 @@ def optimize( should be fixed to a particular value during generation. rounding_func: A function that post-processes an optimization result appropriately. This is typically passed down from - `ModelBridge` to ensure compatibility of the candidates with + `Adapter` to ensure compatibility of the candidates with with Ax transforms. For additional post processing, use `post_processing_func` option in `optimizer_options`. optimizer_options: Options for the optimizer function, e.g. 
``sequential`` diff --git a/ax/models/torch/botorch_modular/model.py b/ax/models/torch/botorch_modular/model.py index 060cd9cb144..0166ac58888 100644 --- a/ax/models/torch/botorch_modular/model.py +++ b/ax/models/torch/botorch_modular/model.py @@ -30,7 +30,7 @@ ModelConfig, ) from ax.models.torch.utils import _to_inequality_constraints -from ax.models.torch_base import TorchGenResults, TorchModel, TorchOptConfig +from ax.models.torch_base import TorchGenerator, TorchGenResults, TorchOptConfig from ax.utils.common.base import Base from ax.utils.common.constants import Keys from ax.utils.common.docutils import copy_doc @@ -41,7 +41,7 @@ from torch import Tensor -class BoTorchModel(TorchModel, Base): +class BoTorchGenerator(TorchGenerator, Base): """**All classes in 'botorch_modular' directory are under construction, incomplete, and should be treated as alpha versions only.** @@ -171,7 +171,7 @@ def fit( candidate_metadata: Model-produced metadata for candidates, in the order corresponding to the Xs. state_dict: An optional model statedict for the underlying ``Surrogate``. - Primarily used in ``BoTorchModel.cross_validate``. + Primarily used in ``BoTorchGenerator.cross_validate``. refit: Whether to re-optimize model parameters. additional_model_inputs: Additional kwargs to pass to the model input constructor in ``Surrogate.fit``. @@ -228,7 +228,7 @@ def predict( X=X, use_posterior_predictive=use_posterior_predictive ) - @copy_doc(TorchModel.gen) + @copy_doc(TorchGenerator.gen) def gen( self, n: int, @@ -306,7 +306,7 @@ def _get_gen_metadata_from_acqf( gen_metadata["outcome_model_fixed_draw_weights"] = outcome_model.w return gen_metadata - @copy_doc(TorchModel.best_point) + @copy_doc(TorchGenerator.best_point) def best_point( self, search_space_digest: SearchSpaceDigest, @@ -320,7 +320,7 @@ def best_point( except ValueError: return None - @copy_doc(TorchModel.evaluate_acquisition_function) + @copy_doc(TorchGenerator.evaluate_acquisition_function) def evaluate_acquisition_function( self, X: Tensor, @@ -335,7 +335,7 @@ def evaluate_acquisition_function( ) return acqf.evaluate(X=X) - @copy_doc(TorchModel.cross_validate) + @copy_doc(TorchGenerator.cross_validate) def cross_validate( self, datasets: Sequence[SupervisedDataset], diff --git a/ax/models/torch/botorch_modular/surrogate.py b/ax/models/torch/botorch_modular/surrogate.py index 7e7109ca1de..30f4ef30200 100644 --- a/ax/models/torch/botorch_modular/surrogate.py +++ b/ax/models/torch/botorch_modular/surrogate.py @@ -441,7 +441,7 @@ class SurrogateSpec: Fields in the SurrogateSpec dataclass correspond to arguments in ``Surrogate.__init__``, except for ``outcomes`` which is used to specify which outcomes the Surrogate is responsible for modeling. - When ``BotorchModel.fit`` is called, these fields will be used to construct the + When ``BotorchGenerator.fit`` is called, these fields will be used to construct the requisite Surrogate objects. If ``outcomes`` is left empty then no outcomes will be fit to the Surrogate. @@ -454,7 +454,7 @@ class SurrogateSpec: model_options: Dictionary of options / kwargs for the BoTorch ``Model`` constructed during ``Surrogate.fit``. Note that the corresponding attribute will later be updated to include any - additional kwargs passed into ``BoTorchModel.fit``. + additional kwargs passed into ``BoTorchGenerator.fit``. This argument is deprecated in favor of model_configs. mll_class: ``MarginalLogLikelihood`` class to use for model-fitting. This argument is deprecated in favor of model_configs. 
@@ -608,7 +608,7 @@ class Surrogate(Base): construction, incomplete, and should be treated as alpha versions only.** - Ax wrapper for BoTorch ``Model``, subcomponent of ``BoTorchModel`` + Ax wrapper for BoTorch ``Model``, subcomponent of ``BoTorchGenerator`` and is not meant to be used outside of it. Args: @@ -620,7 +620,7 @@ class Surrogate(Base): model_options: Dictionary of options / kwargs for the BoTorch ``Model`` constructed during ``Surrogate.fit``. Note that the corresponding attribute will later be updated to include any - additional kwargs passed into ``BoTorchModel.fit``. + additional kwargs passed into ``BoTorchGenerator.fit``. This argument is deprecated in favor of model_configs. mll_class: ``MarginalLogLikelihood`` class to use for model-fitting. This argument is deprecated in favor of model_configs. @@ -674,7 +674,7 @@ class string names and the values are dictionaries of input transform Set to false to fit individual models to each metric in a loop. refit_on_cv: Whether to refit the model on the cross-validation folds. metric_to_best_model_config: Dictionary mapping a metric name to the best - model config. This is only used by BotorchModel.cross_validate and for + model config. This is only used by BotorchGenerator.cross_validate and for logging what model was used. """ @@ -773,7 +773,7 @@ def model(self) -> Model: if self._model is None: raise ValueError( "BoTorch `Model` has not yet been constructed, please fit the " - "surrogate first (done via `BoTorchModel.fit`)." + "surrogate first (done via `BoTorchGenerator.fit`)." ) return self._model @@ -1015,7 +1015,7 @@ def model_selection( based on the SurrogateSpec's eval_criteria. The eval_criteria is computed using LOOCV on the provided dataset. The best model config is saved in self.metric_to_best_model_config for future use (e.g. for using cross- - validation at the Modelbridge level). + validation at the Adapter level). Args: dataset: Training data for the model diff --git a/ax/models/torch/botorch_modular/utils.py b/ax/models/torch/botorch_modular/utils.py index 5b8f04acf1b..2a13ca7e2e2 100644 --- a/ax/models/torch/botorch_modular/utils.py +++ b/ax/models/torch/botorch_modular/utils.py @@ -62,7 +62,7 @@ class ModelConfig: model_options: Dictionary of options / kwargs for the BoTorch ``Model`` constructed during ``Surrogate.fit``. Note that the corresponding attribute will later be updated to include any - additional kwargs passed into ``BoTorchModel.fit``. + additional kwargs passed into ``BoTorchGenerator.fit``. mll_class: ``MarginalLogLikelihood`` class to use for model-fitting. mll_options: Dictionary of options / kwargs for the MLL. outcome_transform_classes: List of BoTorch outcome transforms classes. 
Passed diff --git a/ax/models/torch/botorch_moo.py b/ax/models/torch/botorch_moo.py index 4524b33cb48..fa273389726 100644 --- a/ax/models/torch/botorch_moo.py +++ b/ax/models/torch/botorch_moo.py @@ -14,7 +14,7 @@ from ax.core.search_space import SearchSpaceDigest from ax.exceptions.core import AxError from ax.models.torch.botorch import ( - BotorchModel, + BotorchGenerator, get_rounding_func, TBestPointRecommender, TModelConstructor, @@ -41,7 +41,7 @@ randomize_objective_weights, subset_model, ) -from ax.models.torch_base import TorchGenResults, TorchModel, TorchOptConfig +from ax.models.torch_base import TorchGenerator, TorchGenResults, TorchOptConfig from ax.utils.common.constants import Keys from ax.utils.common.docutils import copy_doc from ax.utils.common.logger import get_logger @@ -67,7 +67,7 @@ ] -class MultiObjectiveBotorchModel(BotorchModel): +class MultiObjectiveBotorchGenerator(BotorchGenerator): r""" Customizable multi-objective model. @@ -113,7 +113,7 @@ class MultiObjectiveBotorchModel(BotorchModel): `fidelity_features` is a list of ints that specify the positions of fidelity parameters in 'Xs', `metric_names` provides the names of each `Y` in `Ys`, `state_dict` is a pytorch module state dict, and `model` is a BoTorch `Model`. - Optional kwargs are being passed through from the `BotorchModel` constructor. + Optional kwargs are being passed through from the `BotorchGenerator` constructor. This callable is assumed to return a fitted BoTorch model that has the same dtype and lives on the same device as the input tensors. @@ -236,7 +236,7 @@ def __init__( self.fidelity_features: list[int] = [] self.metric_names: list[str] = [] - @copy_doc(TorchModel.gen) + @copy_doc(TorchGenerator.gen) def gen( self, n: int, @@ -249,7 +249,7 @@ def gen( if search_space_digest.fidelity_features: # untested raise NotImplementedError( - "fidelity_features not implemented for base BotorchModel" + "fidelity_features not implemented for base BotorchGenerator" ) if ( torch_opt_config.objective_thresholds is not None diff --git a/ax/models/torch/botorch_moo_defaults.py b/ax/models/torch/botorch_moo_defaults.py index 90543e5bf2b..345af487af4 100644 --- a/ax/models/torch/botorch_moo_defaults.py +++ b/ax/models/torch/botorch_moo_defaults.py @@ -39,7 +39,7 @@ get_outcome_constraint_transforms, subset_model, ) -from ax.models.torch_base import TorchModel +from ax.models.torch_base import TorchGenerator from botorch.acquisition import get_acquisition_function from botorch.acquisition.acquisition import AcquisitionFunction from botorch.acquisition.multi_objective.logei import ( @@ -70,7 +70,7 @@ # along with their covariances and their index in the input observations. 
TFrontierEvaluator = Callable[ [ - TorchModel, + TorchGenerator, Tensor, Optional[Tensor], Optional[Tensor], @@ -550,7 +550,7 @@ def scipy_optimizer_list( def pareto_frontier_evaluator( - model: TorchModel | None, + model: TorchGenerator | None, objective_weights: Tensor, objective_thresholds: Tensor | None = None, X: Tensor | None = None, diff --git a/ax/models/torch/cbo_lcea.py b/ax/models/torch/cbo_lcea.py index c3d2f667847..3d045776162 100644 --- a/ax/models/torch/cbo_lcea.py +++ b/ax/models/torch/cbo_lcea.py @@ -11,10 +11,10 @@ from ax.core.search_space import SearchSpaceDigest from ax.core.types import TCandidateMetadata -from ax.models.torch.botorch import BotorchModel +from ax.models.torch.botorch import BotorchGenerator from ax.models.torch.botorch_defaults import get_qLogNEI from ax.models.torch.cbo_sac import generate_model_space_decomposition -from ax.models.torch_base import TorchModel, TorchOptConfig +from ax.models.torch_base import TorchGenerator, TorchOptConfig from ax.utils.common.docutils import copy_doc from ax.utils.common.logger import get_logger from botorch.fit import fit_gpytorch_mll @@ -66,7 +66,7 @@ def get_map_model( return model, mll -class LCEABO(BotorchModel): +class LCEABO(BotorchGenerator): r"""Does Bayesian optimization with Latent Context Embedding Additive (LCE-A) GP. The parameter space decomposition must be provided. @@ -114,7 +114,7 @@ def __init__( model_constructor=self.get_and_fit_model, acqf_constructor=get_qLogNEI ) - @copy_doc(TorchModel.fit) + @copy_doc(TorchGenerator.fit) def fit( self, datasets: list[SupervisedDataset], @@ -129,7 +129,7 @@ def fit( search_space_digest=search_space_digest, ) - @copy_doc(TorchModel.best_point) + @copy_doc(TorchGenerator.best_point) def best_point( self, search_space_digest: SearchSpaceDigest, diff --git a/ax/models/torch/cbo_lcem.py b/ax/models/torch/cbo_lcem.py index eb320c9ddc6..0157776328b 100644 --- a/ax/models/torch/cbo_lcem.py +++ b/ax/models/torch/cbo_lcem.py @@ -9,7 +9,7 @@ from typing import Any import torch -from ax.models.torch.botorch import BotorchModel +from ax.models.torch.botorch import BotorchGenerator from botorch.fit import fit_gpytorch_mll from botorch.models.contextual_multioutput import LCEMGP from botorch.models.model_list_gp_regression import ModelListGP @@ -20,7 +20,7 @@ MIN_OBSERVED_NOISE_LEVEL = 1e-7 -class LCEMBO(BotorchModel): +class LCEMBO(BotorchGenerator): r"""Does Bayesian optimization with LCE-M GP.""" def __init__( diff --git a/ax/models/torch/cbo_sac.py b/ax/models/torch/cbo_sac.py index e0276369a0d..43d7efb65c8 100644 --- a/ax/models/torch/cbo_sac.py +++ b/ax/models/torch/cbo_sac.py @@ -11,8 +11,8 @@ from ax.core.search_space import SearchSpaceDigest from ax.core.types import TCandidateMetadata -from ax.models.torch.botorch import BotorchModel -from ax.models.torch_base import TorchModel +from ax.models.torch.botorch import BotorchGenerator +from ax.models.torch_base import TorchGenerator from ax.utils.common.docutils import copy_doc from ax.utils.common.logger import get_logger from botorch.fit import fit_gpytorch_mll @@ -28,7 +28,7 @@ logger: Logger = get_logger(__name__) -class SACBO(BotorchModel): +class SACBO(BotorchGenerator): """Does Bayesian optimization with structural additive contextual GP (SACGP). The parameter space decomposition must be provided. 
@@ -48,7 +48,7 @@ def __init__(self, decomposition: dict[str, list[str]]) -> None: self.feature_names: list[str] = [] super().__init__(model_constructor=self.get_and_fit_model) - @copy_doc(TorchModel.fit) + @copy_doc(TorchGenerator.fit) def fit( self, datasets: list[SupervisedDataset], diff --git a/ax/models/torch/randomforest.py b/ax/models/torch/randomforest.py index 4ac405bffcb..233fbd7a815 100644 --- a/ax/models/torch/randomforest.py +++ b/ax/models/torch/randomforest.py @@ -14,7 +14,7 @@ from ax.core.search_space import SearchSpaceDigest from ax.core.types import TCandidateMetadata from ax.models.torch.utils import _datasets_to_legacy_inputs -from ax.models.torch_base import TorchModel +from ax.models.torch_base import TorchGenerator from ax.utils.common.docutils import copy_doc from botorch.utils.datasets import SupervisedDataset from sklearn.ensemble import RandomForestRegressor @@ -22,7 +22,7 @@ from torch import Tensor -class RandomForest(TorchModel): +class RandomForest(TorchGenerator): """A Random Forest model. Uses a parametric bootstrap to handle uncertainty in Y. @@ -42,7 +42,7 @@ def __init__(self, max_features: str | None = "sqrt", num_trees: int = 500) -> N self.num_trees = num_trees self.models: list[RandomForestRegressor] = [] - @copy_doc(TorchModel.fit) + @copy_doc(TorchGenerator.fit) def fit( self, datasets: list[SupervisedDataset], @@ -61,11 +61,11 @@ def fit( ) ) - @copy_doc(TorchModel.predict) + @copy_doc(TorchGenerator.predict) def predict(self, X: Tensor) -> tuple[Tensor, Tensor]: return _rf_predict(self.models, X) - @copy_doc(TorchModel.cross_validate) + @copy_doc(TorchGenerator.cross_validate) def cross_validate( # pyre-ignore [14]: not using metric_names or ssd self, datasets: list[SupervisedDataset], diff --git a/ax/models/torch/tests/test_model.py b/ax/models/torch/tests/test_model.py index ca002b2bed4..f65cb96ba84 100644 --- a/ax/models/torch/tests/test_model.py +++ b/ax/models/torch/tests/test_model.py @@ -20,7 +20,7 @@ from ax.models.torch.botorch_modular.acquisition import Acquisition from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel from ax.models.torch.botorch_modular.model import ( - BoTorchModel, + BoTorchGenerator, choose_botorch_acqf_class, ) from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec @@ -61,7 +61,7 @@ CURRENT_PATH: str = __name__ -MODEL_PATH: str = BoTorchModel.__module__ +MODEL_PATH: str = BoTorchGenerator.__module__ SURROGATE_PATH: str = Surrogate.__module__ ACQUISITION_PATH: str = Acquisition.__module__ @@ -70,7 +70,7 @@ } -class BoTorchModelTest(TestCase): +class BoTorchGeneratorTest(TestCase): def setUp(self) -> None: super().setUp() self.botorch_model_class = SingleTaskGP @@ -78,7 +78,7 @@ def setUp(self) -> None: self.acquisition_class = Acquisition self.botorch_acqf_class = qExpectedImprovement self.acquisition_options = ACQ_OPTIONS - self.model = BoTorchModel( + self.model = BoTorchGenerator( surrogate=self.surrogate, acquisition_class=self.acquisition_class, botorch_acqf_class=self.botorch_acqf_class, @@ -182,10 +182,10 @@ def setUp(self) -> None: def test_init(self) -> None: # Default model with no specifications. - model = BoTorchModel() + model = BoTorchGenerator() self.assertEqual(model.acquisition_class, Acquisition) # Model that specifies `botorch_acqf_class`. 
- model = BoTorchModel(botorch_acqf_class=qExpectedImprovement) + model = BoTorchGenerator(botorch_acqf_class=qExpectedImprovement) self.assertEqual(model.acquisition_class, Acquisition) self.assertEqual(model.botorch_acqf_class, qExpectedImprovement) @@ -194,7 +194,7 @@ def test_init(self) -> None: self.assertTrue(model.warm_start_refit) # Check setting non-default refitting settings - mdl2 = BoTorchModel( + mdl2 = BoTorchGenerator( surrogate=self.surrogate, acquisition_class=self.acquisition_class, acquisition_options=self.acquisition_options, @@ -374,10 +374,10 @@ def test_with_surrogate_specs_input(self) -> None: ), } with self.assertRaisesRegex(DeprecationWarning, "Support for multiple"): - BoTorchModel(surrogate_specs=surrogate_specs) + BoTorchGenerator(surrogate_specs=surrogate_specs) with self.assertWarnsRegex(DeprecationWarning, "surrogate_specs"): - model = BoTorchModel(surrogate_specs={"s": spec1}) + model = BoTorchGenerator(surrogate_specs={"s": spec1}) self.assertIs(model.surrogate_spec, spec1) @mock_botorch_optimize @@ -452,7 +452,7 @@ def test_cross_validate_multiple_configs(self) -> None: """Test cross-validation with multiple configs.""" for refit_on_cv in (True, False): with self.subTest(refit_on_cv=refit_on_cv): - self.model = BoTorchModel( + self.model = BoTorchGenerator( surrogate_spec=SurrogateSpec( model_configs=[ ModelConfig(), @@ -517,7 +517,7 @@ def _test_gen( torch.tensor([1.0]), ) surrogate = Surrogate(botorch_model_class=botorch_model_class) - model = BoTorchModel( + model = BoTorchGenerator( surrogate=surrogate, acquisition_class=Acquisition, acquisition_options=self.acquisition_options, @@ -536,7 +536,7 @@ def _test_gen( with ExitStack() as es: mock_init_acqf = es.enter_context( mock.patch.object( - BoTorchModel, + BoTorchGenerator, "_instantiate_acquisition", wraps=model._instantiate_acquisition, ) @@ -657,7 +657,7 @@ def test_gen_SingleTaskMultiFidelityGP(self) -> None: def test_feature_importances(self) -> None: for botorch_model_class in [SingleTaskGP, SaasFullyBayesianSingleTaskGP]: surrogate = Surrogate(botorch_model_class=botorch_model_class) - model = BoTorchModel( + model = BoTorchGenerator( surrogate=surrogate, acquisition_class=Acquisition, acquisition_options=self.acquisition_options, @@ -756,7 +756,7 @@ def test_best_point(self) -> None: @mock_botorch_optimize def test_evaluate_acquisition_function(self) -> None: - model = BoTorchModel( + model = BoTorchGenerator( surrogate=self.surrogate, acquisition_class=Acquisition, acquisition_options=self.acquisition_options, @@ -778,7 +778,7 @@ def test_evaluate_acquisition_function(self) -> None: @mock_botorch_optimize def test_surrogate_model_options_propagation(self) -> None: surrogate_spec = SurrogateSpec() - model = BoTorchModel(surrogate_spec=surrogate_spec) + model = BoTorchGenerator(surrogate_spec=surrogate_spec) with mock.patch(f"{MODEL_PATH}.Surrogate", wraps=Surrogate) as mock_init: model.fit( datasets=self.non_block_design_training_data, @@ -790,7 +790,7 @@ def test_surrogate_model_options_propagation(self) -> None: @mock_botorch_optimize def test_surrogate_options_propagation(self) -> None: surrogate_spec = SurrogateSpec(allow_batched_models=False) - model = BoTorchModel(surrogate_spec=surrogate_spec) + model = BoTorchGenerator(surrogate_spec=surrogate_spec) with mock.patch(f"{MODEL_PATH}.Surrogate", wraps=Surrogate) as mock_init: model.fit( datasets=self.non_block_design_training_data, @@ -801,7 +801,7 @@ def test_surrogate_options_propagation(self) -> None: @mock_botorch_optimize def 
test_model_list_choice(self) -> None: - model = BoTorchModel() + model = BoTorchGenerator() model.fit( datasets=self.non_block_design_training_data, search_space_digest=self.mf_search_space_digest, @@ -828,7 +828,7 @@ def test_MOO(self) -> None: input_constructor=mock_input_constructor, ) - model = BoTorchModel() + model = BoTorchGenerator() model.fit( datasets=self.moo_training_data, search_space_digest=self.search_space_digest, diff --git a/ax/models/torch/tests/test_utils.py b/ax/models/torch/tests/test_utils.py index cae5903c8e9..035b5c52e34 100644 --- a/ax/models/torch/tests/test_utils.py +++ b/ax/models/torch/tests/test_utils.py @@ -44,7 +44,7 @@ from pyre_extensions import assert_is_instance, none_throws -class BoTorchModelUtilsTest(TestCase): +class BoTorchGeneratorUtilsTest(TestCase): def setUp(self) -> None: super().setUp() self.dtype = torch.float diff --git a/ax/models/torch_base.py b/ax/models/torch_base.py index 32e94e393b1..7cc3ff23d00 100644 --- a/ax/models/torch_base.py +++ b/ax/models/torch_base.py @@ -17,7 +17,7 @@ from ax.core.metric import Metric from ax.core.search_space import SearchSpaceDigest from ax.core.types import TCandidateMetadata -from ax.models.base import Model as BaseModel +from ax.models.base import Generator as BaseGenerator from ax.models.types import TConfig from botorch.acquisition.risk_measures import RiskMeasureMCObjective from botorch.utils.datasets import SupervisedDataset @@ -110,7 +110,7 @@ class TorchGenResults: candidate_metadata: Sequence[TCandidateMetadata] | None = None -class TorchModel(BaseModel): +class TorchGenerator(BaseGenerator): """This class specifies the interface for a torch-based model. These methods should be implemented to have access to all of the features diff --git a/ax/plot/contour.py b/ax/plot/contour.py index 9970b635a3c..c5987170da0 100644 --- a/ax/plot/contour.py +++ b/ax/plot/contour.py @@ -14,7 +14,7 @@ import numpy.typing as npt import plotly.graph_objs as go from ax.core.observation import ObservationFeatures -from ax.modelbridge.base import ModelBridge +from ax.modelbridge.base import Adapter from ax.plot.base import AxPlotConfig, AxPlotTypes, PlotData from ax.plot.color import BLUE_SCALE, GREEN_PINK_SCALE, GREEN_SCALE from ax.plot.helper import ( @@ -54,7 +54,7 @@ def short_name(param_name: str) -> str: def _get_contour_predictions( - model: ModelBridge, + model: Adapter, x_param_name: str, y_param_name: str, metric: str, @@ -109,7 +109,7 @@ def _get_contour_predictions( def plot_contour_plotly( - model: ModelBridge, + model: Adapter, param_x: str, param_y: str, metric_name: str, @@ -124,7 +124,7 @@ def plot_contour_plotly( """Plot predictions for a 2-d slice of the parameter space. Args: - model: ModelBridge that contains model for predictions + model: Adapter that contains model for predictions param_x: Name of parameter that will be sliced on x-axis param_y: Name of parameter that will be sliced on y-axis metric_name: Name of metric to plot @@ -282,7 +282,7 @@ def plot_contour_plotly( def plot_contour( - model: ModelBridge, + model: Adapter, param_x: str, param_y: str, metric_name: str, @@ -297,7 +297,7 @@ def plot_contour( """Plot predictions for a 2-d slice of the parameter space. 
Args: - model: ModelBridge that contains model for predictions + model: Adapter that contains model for predictions param_x: Name of parameter that will be sliced on x-axis param_y: Name of parameter that will be sliced on y-axis metric_name: Name of metric to plot @@ -338,7 +338,7 @@ def plot_contour( def interact_contour_plotly( - model: ModelBridge, + model: Adapter, metric_name: str, generator_runs_dict: TNullableGeneratorRunsDict = None, relative: bool = False, @@ -353,7 +353,7 @@ def interact_contour_plotly( space. Args: - model: ModelBridge that contains model for predictions + model: Adapter that contains model for predictions metric_name: Name of metric to plot generator_runs_dict: A dictionary {name: generator run} of generator runs whose arms will be plotted, if they lie in the slice. @@ -912,7 +912,7 @@ def interact_contour_plotly( def interact_contour( - model: ModelBridge, + model: Adapter, metric_name: str, generator_runs_dict: TNullableGeneratorRunsDict = None, relative: bool = False, @@ -927,7 +927,7 @@ def interact_contour( space. Args: - model: ModelBridge that contains model for predictions + model: Adapter that contains model for predictions metric_name: Name of metric to plot generator_runs_dict: A dictionary {name: generator run} of generator runs whose arms will be plotted, if they lie in the slice. diff --git a/ax/plot/feature_importances.py b/ax/plot/feature_importances.py index fe2c0503f2e..1d06dd1f746 100644 --- a/ax/plot/feature_importances.py +++ b/ax/plot/feature_importances.py @@ -15,7 +15,7 @@ import plotly.graph_objs as go from ax.core.parameter import ChoiceParameter from ax.exceptions.core import NoDataError -from ax.modelbridge import ModelBridge +from ax.modelbridge import Adapter from ax.plot.base import AxPlotConfig, AxPlotTypes from ax.plot.helper import compose_annotation from ax.utils.common.logger import get_logger @@ -61,7 +61,7 @@ def plot_feature_importance(df: pd.DataFrame, title: str) -> AxPlotConfig: ) -def plot_feature_importance_by_metric_plotly(model: ModelBridge) -> go.Figure: +def plot_feature_importance_by_metric_plotly(model: Adapter) -> go.Figure: """One plot per feature, showing importances by metric.""" importances = [] for metric_name in sorted(model.metric_names): @@ -90,7 +90,7 @@ def plot_feature_importance_by_metric_plotly(model: ModelBridge) -> go.Figure: return plot_fi -def plot_feature_importance_by_metric(model: ModelBridge) -> AxPlotConfig: +def plot_feature_importance_by_metric(model: Adapter) -> AxPlotConfig: """Wrapper method to convert plot_feature_importance_by_metric_plotly to AxPlotConfig""" return AxPlotConfig( @@ -102,7 +102,7 @@ def plot_feature_importance_by_metric(model: ModelBridge) -> AxPlotConfig: def plot_feature_importance_by_feature_plotly( - model: ModelBridge | None = None, + model: Adapter | None = None, sensitivity_values: dict[str, dict[str, float | npt.NDArray]] | None = None, relative: bool = False, caption: str = "", @@ -277,7 +277,7 @@ def plot_feature_importance_by_feature_plotly( def plot_feature_importance_by_feature( - model: ModelBridge | None = None, + model: Adapter | None = None, sensitivity_values: dict[str, dict[str, float | npt.NDArray]] | None = None, relative: bool = False, caption: str = "", @@ -301,7 +301,7 @@ def plot_feature_importance_by_feature( ) -def plot_relative_feature_importance_plotly(model: ModelBridge) -> go.Figure: +def plot_relative_feature_importance_plotly(model: Adapter) -> go.Figure: """Create a stacked bar chart of feature importances per metric""" importances 
= [] for metric_name in sorted(model.metric_names): @@ -331,7 +331,7 @@ def plot_relative_feature_importance_plotly(model: ModelBridge) -> go.Figure: return go.Figure(data=data, layout=layout) -def plot_relative_feature_importance(model: ModelBridge) -> AxPlotConfig: +def plot_relative_feature_importance(model: Adapter) -> AxPlotConfig: """Wrapper method to convert plot_relative_feature_importance_plotly to AxPlotConfig""" return AxPlotConfig( diff --git a/ax/plot/helper.py b/ax/plot/helper.py index 6a50ca8e8bc..113697625f4 100644 --- a/ax/plot/helper.py +++ b/ax/plot/helper.py @@ -19,7 +19,7 @@ from ax.core.observation import Observation, ObservationFeatures from ax.core.parameter import ChoiceParameter, FixedParameter, Parameter, RangeParameter from ax.core.types import TParameterization -from ax.modelbridge.base import ModelBridge +from ax.modelbridge.base import Adapter from ax.modelbridge.prediction_utils import ( _compute_scalarized_outcome, predict_at_point, @@ -149,7 +149,7 @@ def _filter_dict( def _get_in_sample_arms( - model: ModelBridge, + model: Adapter, metric_names: set[str], fixed_features: ObservationFeatures | None = None, data_selector: Callable[[Observation], bool] | None = None, @@ -282,7 +282,7 @@ def _get_in_sample_arms( def _get_out_of_sample_arms( - model: ModelBridge, + model: Adapter, generator_runs_dict: dict[str, GeneratorRun], metric_names: set[str], fixed_features: ObservationFeatures | None = None, @@ -336,7 +336,7 @@ def _get_out_of_sample_arms( def get_plot_data( - model: ModelBridge, + model: Adapter, generator_runs_dict: dict[str, GeneratorRun], metric_names: set[str] | None = None, fixed_features: ObservationFeatures | None = None, @@ -401,7 +401,7 @@ def get_plot_data( return plot_data, raw_data, cond_name_to_parameters -def get_range_parameter(model: ModelBridge, param_name: str) -> RangeParameter: +def get_range_parameter(model: Adapter, param_name: str) -> RangeParameter: """ Get the range parameter with the given name from the model. @@ -444,7 +444,7 @@ def get_range_parameters_from_list( def get_range_parameters( - model: ModelBridge, min_num_values: int = 0 + model: Adapter, min_num_values: int = 0 ) -> list[RangeParameter]: """ Get a list of range parameters from a model. @@ -482,7 +482,7 @@ def get_grid_for_parameter(parameter: RangeParameter, density: int) -> npt.NDArr def get_fixed_values( - model: ModelBridge, + model: Adapter, slice_values: dict[str, Any] | None = None, trial_index: int | None = None, ) -> TParameterization: @@ -494,7 +494,7 @@ def get_fixed_values( Any value in slice_values will override the above. Args: - model: ModelBridge being used for plotting + model: Adapter being used for plotting slice_values: Map from parameter name to value at which is should be fixed. @@ -774,7 +774,7 @@ def rgb(arr: list[int]) -> str: def infer_is_relative( - model: ModelBridge, metrics: list[str], non_constraint_rel: bool + model: Adapter, metrics: list[str], non_constraint_rel: bool ) -> dict[str, bool]: """Determine whether or not to relativize a metric. 
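# Illustrative sketch, not part of the patch: the plotting helpers in this diff now
# accept an `Adapter` built through the `Generators` registry. The calls below mirror
# the updated contour-plot test further down in this diff; only calls that appear
# there are used.
from ax.modelbridge.registry import Generators
from ax.plot.contour import plot_contour_plotly
from ax.utils.testing.core_stubs import get_branin_experiment

exp = get_branin_experiment(with_batch=True)
exp.trials[0].run()
model = Generators.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data())  # an Adapter
plot = plot_contour_plotly(
    model,
    model.parameters[0],  # pyre flags `parameters` on `Adapter`, as noted in the tests
    model.parameters[1],
    list(model.metric_names)[0],
)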
diff --git a/ax/plot/marginal_effects.py b/ax/plot/marginal_effects.py index 84d439bc789..a52c1002efe 100644 --- a/ax/plot/marginal_effects.py +++ b/ax/plot/marginal_effects.py @@ -10,14 +10,14 @@ import pandas as pd import plotly.graph_objs as go -from ax.modelbridge.base import ModelBridge +from ax.modelbridge.base import Adapter from ax.plot.base import AxPlotConfig, AxPlotTypes, DECIMALS from ax.plot.helper import get_plot_data from ax.utils.stats.statstools import marginal_effects from plotly import subplots -def plot_marginal_effects(model: ModelBridge, metric: str) -> AxPlotConfig: +def plot_marginal_effects(model: Adapter, metric: str) -> AxPlotConfig: """ Calculates and plots the marginal effects -- the effect of changing one factor away from the randomized distribution of the experiment and fixing it diff --git a/ax/plot/pareto_utils.py b/ax/plot/pareto_utils.py index 0115d96361a..34cf2669a2d 100644 --- a/ax/plot/pareto_utils.py +++ b/ax/plot/pareto_utils.py @@ -34,11 +34,11 @@ get_pareto_frontier_and_configs, observed_pareto_frontier, ) -from ax.modelbridge.registry import Models -from ax.modelbridge.torch import TorchModelBridge +from ax.modelbridge.registry import Generators +from ax.modelbridge.torch import TorchAdapter from ax.modelbridge.transforms.derelativize import derelativize_bound from ax.modelbridge.transforms.search_space_to_float import SearchSpaceToFloat -from ax.models.torch_base import TorchModel +from ax.models.torch_base import TorchGenerator from ax.utils.common.logger import get_logger from ax.utils.stats.statstools import relativize from botorch.acquisition.monte_carlo import qSimpleRegret @@ -313,7 +313,7 @@ def to_nonrobust_search_space(search_space: SearchSpace) -> SearchSpace: return search_space -def get_tensor_converter_model(experiment: Experiment, data: Data) -> TorchModelBridge: +def get_tensor_converter_model(experiment: Experiment, data: Data) -> TorchAdapter: """ Constructs a minimal model for converting things to tensors. @@ -331,11 +331,11 @@ def get_tensor_converter_model(experiment: Experiment, data: Data) -> TorchModel """ # Transforms is the minimal set that will work for converting any search # space to tensors. 
- return TorchModelBridge( + return TorchAdapter( experiment=experiment, search_space=to_nonrobust_search_space(experiment.search_space), data=data, - model=TorchModel(), + model=TorchGenerator(), transforms=[SearchSpaceToFloat], fit_out_of_design=True, ) @@ -422,7 +422,7 @@ def compute_posterior_pareto_frontier( secondary_objective=secondary_objective, outcome_constraints=outcome_constraints, ) - model = Models.BOTORCH_MODULAR( + model = Generators.BOTORCH_MODULAR( experiment=experiment, data=data, optimization_config=oc, diff --git a/ax/plot/scatter.py b/ax/plot/scatter.py index f9b9db4c44c..9d6b4c38388 100644 --- a/ax/plot/scatter.py +++ b/ax/plot/scatter.py @@ -19,8 +19,8 @@ from ax.core.data import Data from ax.core.experiment import Experiment from ax.core.observation import Observation, ObservationFeatures -from ax.modelbridge.base import ModelBridge -from ax.modelbridge.registry import Models +from ax.modelbridge.base import Adapter +from ax.modelbridge.registry import Generators from ax.plot.base import ( AxPlotConfig, AxPlotTypes, @@ -299,7 +299,7 @@ def _error_scatter_trace( def _multiple_metric_traces( - model: ModelBridge, + model: Adapter, metric_x: str, metric_y: str, generator_runs_dict: TNullableGeneratorRunsDict, @@ -384,7 +384,7 @@ def _multiple_metric_traces( def plot_multiple_metrics( - model: ModelBridge, + model: Adapter, metric_x: str, metric_y: str, generator_runs_dict: TNullableGeneratorRunsDict = None, @@ -541,7 +541,7 @@ def plot_multiple_metrics( def plot_objective_vs_constraints( - model: ModelBridge, + model: Adapter, objective: str, subset_metrics: list[str] | None = None, generator_runs_dict: TNullableGeneratorRunsDict = None, @@ -807,7 +807,7 @@ def _check_label_lengths(labels: list[str]) -> None: def lattice_multiple_metrics( - model: ModelBridge, + model: Adapter, generator_runs_dict: TNullableGeneratorRunsDict = None, rel: bool = True, show_arm_details_on_hover: bool = False, @@ -1077,7 +1077,7 @@ def lattice_multiple_metrics( # Single metric fitted values def _single_metric_traces( - model: ModelBridge, + model: Adapter, metric: str, generator_runs_dict: TNullableGeneratorRunsDict, rel: bool, @@ -1164,7 +1164,7 @@ def _single_metric_traces( def plot_fitted( - model: ModelBridge, + model: Adapter, metric: str, generator_runs_dict: TNullableGeneratorRunsDict = None, rel: bool = True, @@ -1300,7 +1300,7 @@ def plot_fitted( def tile_fitted( - model: ModelBridge, + model: Adapter, generator_runs_dict: TNullableGeneratorRunsDict = None, rel: bool = True, show_arm_details_on_hover: bool = False, @@ -1478,7 +1478,7 @@ def tile_fitted( def interact_fitted_plotly( - model: ModelBridge, + model: Adapter, generator_runs_dict: TNullableGeneratorRunsDict = None, rel: bool = True, show_arm_details_on_hover: bool = True, @@ -1626,7 +1626,7 @@ def interact_fitted_plotly( def interact_fitted( - model: ModelBridge, + model: Adapter, generator_runs_dict: TNullableGeneratorRunsDict = None, rel: bool = True, show_arm_details_on_hover: bool = True, @@ -1715,7 +1715,7 @@ def tile_observations( data = experiment.fetch_data() if arm_names is not None: data = Data(data.df[data.df["arm_name"].isin(arm_names)]) - m_ts = Models.THOMPSON( + m_ts = Generators.THOMPSON( data=data, search_space=experiment.search_space, experiment=experiment, diff --git a/ax/plot/slice.py b/ax/plot/slice.py index cb3c7623108..5015bdca7a6 100644 --- a/ax/plot/slice.py +++ b/ax/plot/slice.py @@ -12,7 +12,7 @@ import numpy as np import numpy.typing as npt from ax.core.observation import ObservationFeatures 
-from ax.modelbridge.base import ModelBridge +from ax.modelbridge.base import Adapter from ax.plot.base import AxPlotConfig, AxPlotTypes, PlotData from ax.plot.helper import ( axis_range, @@ -46,7 +46,7 @@ def _get_slice_predictions( - model: ModelBridge, + model: Adapter, param_name: str, metric_name: str, generator_runs_dict: TNullableGeneratorRunsDict = None, @@ -59,7 +59,7 @@ def _get_slice_predictions( """Computes slice prediction configuration values for a single metric name. Args: - model: ModelBridge that contains model for predictions + model: Adapter that contains model for predictions param_name: Name of parameter that will be sliced metric_name: Name of metric to plot generator_runs_dict: A dictionary {name: generator run} of generator runs @@ -128,7 +128,7 @@ def _get_slice_predictions( def plot_slice_plotly( - model: ModelBridge, + model: Adapter, param_name: str, metric_name: str, generator_runs_dict: TNullableGeneratorRunsDict = None, @@ -141,7 +141,7 @@ def plot_slice_plotly( """Plot predictions for a 1-d slice of the parameter space. Args: - model: ModelBridge that contains model for predictions + model: Adapter that contains model for predictions param_name: Name of parameter that will be sliced metric_name: Name of metric to plot generator_runs_dict: A dictionary {name: generator run} of generator runs @@ -242,7 +242,7 @@ def plot_slice_plotly( def plot_slice( - model: ModelBridge, + model: Adapter, param_name: str, metric_name: str, generator_runs_dict: TNullableGeneratorRunsDict = None, @@ -255,7 +255,7 @@ def plot_slice( """Plot predictions for a 1-d slice of the parameter space. Args: - model: ModelBridge that contains model for predictions + model: Adapter that contains model for predictions param_name: Name of parameter that will be sliced metric_name: Name of metric to plot generator_runs_dict: A dictionary {name: generator run} of generator runs @@ -293,7 +293,7 @@ def plot_slice( def interact_slice_plotly( - model: ModelBridge, + model: Adapter, generator_runs_dict: TNullableGeneratorRunsDict = None, relative: bool = False, density: int = 50, @@ -305,7 +305,7 @@ def interact_slice_plotly( space. Args: - model: ModelBridge that contains model for predictions + model: Adapter that contains model for predictions generator_runs_dict: A dictionary {name: generator run} of generator runs whose arms will be plotted, if they lie in the slice. relative: Predictions relative to status quo @@ -548,7 +548,7 @@ def interact_slice_plotly( def interact_slice( - model: ModelBridge, + model: Adapter, generator_runs_dict: TNullableGeneratorRunsDict = None, relative: bool = False, density: int = 50, @@ -560,7 +560,7 @@ def interact_slice( space. Args: - model: ModelBridge that contains model for predictions + model: Adapter that contains model for predictions generator_runs_dict: A dictionary {name: generator run} of generator runs whose arms will be plotted, if they lie in the slice. 
relative: Predictions relative to status quo diff --git a/ax/plot/tests/long_running/test_pareto_utils.py b/ax/plot/tests/long_running/test_pareto_utils.py index 3b8854c7251..1456ff7be35 100644 --- a/ax/plot/tests/long_running/test_pareto_utils.py +++ b/ax/plot/tests/long_running/test_pareto_utils.py @@ -8,7 +8,7 @@ from ax.exceptions.core import UnsupportedError from ax.metrics.branin import BraninMetric -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.plot.pareto_utils import compute_posterior_pareto_frontier from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import get_branin_experiment @@ -26,7 +26,7 @@ def setUp(self) -> None: experiment.add_tracking_metric( BraninMetric(name="m2", param_names=["x1", "x2"]) ) - sobol = Models.SOBOL(experiment.search_space) + sobol = Generators.SOBOL(experiment.search_space) a = sobol.gen(5) experiment.new_batch_trial(generator_run=a).run() self.experiment = experiment diff --git a/ax/plot/tests/test_contours.py b/ax/plot/tests/test_contours.py index d854a71476e..79485473b8b 100644 --- a/ax/plot/tests/test_contours.py +++ b/ax/plot/tests/test_contours.py @@ -7,7 +7,7 @@ # pyre-strict import plotly.graph_objects as go -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.plot.base import AxPlotConfig from ax.plot.contour import ( interact_contour, @@ -28,7 +28,7 @@ class ContoursTest(TestCase): def test_Contours(self) -> None: exp = get_branin_experiment(with_str_choice_param=True, with_batch=True) exp.trials[0].run() - model = Models.BOTORCH_MODULAR( + model = Generators.BOTORCH_MODULAR( # Model bridge kwargs experiment=exp, data=exp.fetch_data(), @@ -36,7 +36,7 @@ def test_Contours(self) -> None: # Assert that each type of plot can be constructed successfully plot = plot_contour_plotly( model, - # pyre-fixme[16]: `ModelBridge` has no attribute `parameters`. + # pyre-fixme[16]: `Adapter` has no attribute `parameters`. 
model.parameters[0], model.parameters[1], list(model.metric_names)[0], @@ -63,7 +63,7 @@ def test_Contours(self) -> None: exp = get_high_dimensional_branin_experiment(with_batch=True) exp.trials[0].run() - model = Models.BOTORCH_MODULAR( + model = Generators.BOTORCH_MODULAR( experiment=exp, data=exp.fetch_data(), ) diff --git a/ax/plot/tests/test_diagnostic.py b/ax/plot/tests/test_diagnostic.py index 622c68eb0d4..4a5c3e14e03 100644 --- a/ax/plot/tests/test_diagnostic.py +++ b/ax/plot/tests/test_diagnostic.py @@ -8,7 +8,7 @@ import plotly.graph_objects as go from ax.modelbridge.cross_validation import cross_validate -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.plot.base import AxPlotConfig from ax.plot.diagnostic import ( interact_cross_validation, @@ -25,7 +25,7 @@ def setUp(self) -> None: super().setUp() exp = get_branin_experiment(with_batch=True) exp.trials[0].run() - self.model = Models.BOTORCH_MODULAR( + self.model = Generators.BOTORCH_MODULAR( # Model bridge kwargs experiment=exp, data=exp.fetch_data(), diff --git a/ax/plot/tests/test_feature_importances.py b/ax/plot/tests/test_feature_importances.py index f8697cadf0a..69aa5a7c508 100644 --- a/ax/plot/tests/test_feature_importances.py +++ b/ax/plot/tests/test_feature_importances.py @@ -9,8 +9,8 @@ import json import torch -from ax.modelbridge.base import ModelBridge -from ax.modelbridge.registry import Models +from ax.modelbridge.base import Adapter +from ax.modelbridge.registry import Generators from ax.plot.base import AxPlotConfig from ax.plot.feature_importances import ( plot_feature_importance_by_feature, @@ -28,10 +28,10 @@ DUMMY_CAPTION = "test_caption" -def get_modelbridge() -> ModelBridge: +def get_modelbridge() -> Adapter: exp = get_branin_experiment(with_batch=True) exp.trials[0].run() - return Models.LEGACY_BOTORCH( + return Generators.LEGACY_BOTORCH( # Model bridge kwargs experiment=exp, data=exp.fetch_data(), @@ -40,7 +40,7 @@ def get_modelbridge() -> ModelBridge: # pyre-fixme[24]: Generic type `dict` expects 2 type parameters, use `typing.Dict` # to avoid runtime subscripting errors. -def get_sensitivity_values(ax_model: ModelBridge) -> dict: +def get_sensitivity_values(ax_model: Adapter) -> dict: """ Compute lengscale sensitivity value for on an ax model. @@ -58,7 +58,7 @@ def get_sensitivity_values(ax_model: ModelBridge) -> dict: res = {} for metric_name in ax_model.outcomes: importances_arr = importances_dict[metric_name].numpy() - # pyre-fixme[16]: `ModelBridge` has no attribute `parameters`. + # pyre-fixme[16]: `Adapter` has no attribute `parameters`. 
res[metric_name] = dict(zip(ax_model.parameters, importances_arr)) return res diff --git a/ax/plot/tests/test_fitted_scatter.py b/ax/plot/tests/test_fitted_scatter.py index 1a6db33e8f5..fa44d0a3718 100644 --- a/ax/plot/tests/test_fitted_scatter.py +++ b/ax/plot/tests/test_fitted_scatter.py @@ -10,7 +10,7 @@ import plotly.graph_objects as go from ax.core.data import Data -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.plot.base import AxPlotConfig from ax.plot.scatter import interact_fitted, interact_fitted_plotly from ax.utils.common.testutils import TestCase @@ -29,7 +29,7 @@ def test_fitted_scatter(self) -> None: df["metric_name"] = "branin_dup" exp.add_tracking_metric(get_branin_metric(name="branin_dup")) - model = Models.BOTORCH_MODULAR( + model = Generators.BOTORCH_MODULAR( # Model bridge kwargs experiment=exp, data=Data.from_multiple_data([data, Data(df)]), diff --git a/ax/plot/tests/test_pareto_utils.py b/ax/plot/tests/test_pareto_utils.py index 08d425a171f..dabae1947d4 100644 --- a/ax/plot/tests/test_pareto_utils.py +++ b/ax/plot/tests/test_pareto_utils.py @@ -21,7 +21,7 @@ from ax.core.types import ComparisonOp from ax.exceptions.core import UserInputError from ax.metrics.branin import BraninMetric, NegativeBraninMetric -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.plot.pareto_frontier import ( interact_multiple_pareto_frontier, interact_pareto_frontier, @@ -51,7 +51,7 @@ def setUp(self) -> None: experiment.add_tracking_metric( BraninMetric(name="m2", param_names=["x1", "x2"]) ) - sobol = Models.SOBOL(experiment.search_space) + sobol = Generators.SOBOL(experiment.search_space) a = sobol.gen(5) experiment.new_batch_trial(generator_run=a).run() self.experiment = experiment @@ -179,7 +179,7 @@ def test_PlotParetoFrontiers(self) -> None: experiment = get_branin_experiment_with_multi_objective( has_objective_thresholds=True, ) - sobol = Models.SOBOL(experiment.search_space) + sobol = Generators.SOBOL(experiment.search_space) a = sobol.gen(5) experiment.new_batch_trial(generator_run=a).run() experiment.fetch_data() diff --git a/ax/plot/tests/test_slices.py b/ax/plot/tests/test_slices.py index 68d3d3e5c55..561e6a159b7 100644 --- a/ax/plot/tests/test_slices.py +++ b/ax/plot/tests/test_slices.py @@ -7,7 +7,7 @@ # pyre-strict import plotly.graph_objects as go -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.plot.base import AxPlotConfig from ax.plot.slice import ( interact_slice, @@ -25,7 +25,7 @@ class SlicesTest(TestCase): def test_Slices(self) -> None: exp = get_branin_experiment(with_batch=True) exp.trials[0].run() - model = Models.BOTORCH_MODULAR( + model = Generators.BOTORCH_MODULAR( # Model bridge kwargs experiment=exp, data=exp.fetch_data(), @@ -33,7 +33,7 @@ def test_Slices(self) -> None: # Assert that each type of plot can be constructed successfully plot = plot_slice_plotly( model, - # pyre-fixme[16]: `ModelBridge` has no attribute `parameters`. + # pyre-fixme[16]: `Adapter` has no attribute `parameters`. 
model.parameters[0], list(model.metric_names)[0], ) diff --git a/ax/plot/tests/test_tile_fitted.py b/ax/plot/tests/test_tile_fitted.py index 885d3b6dc44..3397e7074bb 100644 --- a/ax/plot/tests/test_tile_fitted.py +++ b/ax/plot/tests/test_tile_fitted.py @@ -11,7 +11,7 @@ from ax.core.arm import Arm from ax.core.metric import Metric from ax.core.search_space import SearchSpace -from ax.modelbridge.base import ModelBridge +from ax.modelbridge.base import Adapter from ax.models.discrete.full_factorial import FullFactorialGenerator from ax.plot.scatter import tile_fitted, tile_observations from ax.utils.common.testutils import TestCase @@ -38,9 +38,9 @@ def get_modelbridge( # pyre-fixme[2]: Parameter must be annotated. mock_observations_from_data, status_quo_name: str | None = None, -) -> ModelBridge: +) -> Adapter: exp = get_experiment() - modelbridge = ModelBridge( + modelbridge = Adapter( search_space=get_search_space(), model=FullFactorialGenerator(), experiment=exp, @@ -48,7 +48,7 @@ def get_modelbridge( status_quo_name=status_quo_name, ) modelbridge._predict = mock.MagicMock( - "ax.modelbridge.base.ModelBridge._predict", + "ax.modelbridge.base.Adapter._predict", autospec=True, return_value=[get_observation().data], ) diff --git a/ax/plot/tests/test_traces.py b/ax/plot/tests/test_traces.py index c69c7b9417a..edd7b240a02 100644 --- a/ax/plot/tests/test_traces.py +++ b/ax/plot/tests/test_traces.py @@ -8,7 +8,7 @@ import numpy as np import plotly.graph_objects as go -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.plot.base import AxPlotConfig from ax.plot.trace import ( optimization_trace_single_method, @@ -27,7 +27,7 @@ def setUp(self) -> None: super().setUp() self.exp = get_branin_experiment(with_batch=True) self.exp.trials[0].run() - self.model = Models.BOTORCH_MODULAR( + self.model = Generators.BOTORCH_MODULAR( # Model bridge kwargs experiment=self.exp, data=self.exp.fetch_data(), @@ -73,11 +73,11 @@ def test_plot_objective_value_vs_trial_index(self) -> None: # Generate some trials with different model types, including batch trial. 
exp = get_branin_experiment(with_batch=True) exp.trials[0].mark_completed(unsafe=True) - sobol = Models.SOBOL(search_space=exp.search_space) + sobol = Generators.SOBOL(search_space=exp.search_space) for _ in range(2): t = exp.new_trial(sobol.gen(1)).run() t.mark_completed() - model = Models.BOTORCH_MODULAR( + model = Generators.BOTORCH_MODULAR( experiment=exp, data=exp.fetch_data(), ) diff --git a/ax/preview/modelbridge/dispatch_utils.py b/ax/preview/modelbridge/dispatch_utils.py index e99b0cc39bb..47e743cd3ed 100644 --- a/ax/preview/modelbridge/dispatch_utils.py +++ b/ax/preview/modelbridge/dispatch_utils.py @@ -10,8 +10,8 @@ from ax.core.base_trial import TrialStatus from ax.exceptions.core import UnsupportedError from ax.modelbridge.generation_strategy import GenerationNode, GenerationStrategy -from ax.modelbridge.model_spec import ModelSpec -from ax.modelbridge.registry import Models +from ax.modelbridge.model_spec import GeneratorSpec +from ax.modelbridge.registry import Generators from ax.modelbridge.transition_criterion import MinTrials from ax.models.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec from ax.preview.api.configs import GenerationMethod, GenerationStrategyConfig @@ -72,8 +72,8 @@ def _get_sobol_node( return GenerationNode( node_name="Sobol", model_specs=[ - ModelSpec( - model_enum=Models.SOBOL, + GeneratorSpec( + model_enum=Generators.SOBOL, model_kwargs={"seed": gs_config.initialization_random_seed}, ) ], @@ -113,8 +113,8 @@ def _get_mbm_node( return GenerationNode( node_name="MBM", model_specs=[ - ModelSpec( - model_enum=Models.BOTORCH_MODULAR, + GeneratorSpec( + model_enum=Generators.BOTORCH_MODULAR, model_kwargs={ "surrogate_spec": SurrogateSpec(model_configs=model_configs), "torch_device": torch_device, @@ -149,8 +149,8 @@ def choose_generation_strategy( GenerationNode( node_name="Sobol", model_specs=[ - ModelSpec( - model_enum=Models.SOBOL, + GeneratorSpec( + model_enum=Generators.SOBOL, model_kwargs={"seed": gs_config.initialization_random_seed}, ) ], diff --git a/ax/preview/modelbridge/tests/test_dispatch_utils.py b/ax/preview/modelbridge/tests/test_dispatch_utils.py index 805a5d8aa74..fb13853d0a3 100644 --- a/ax/preview/modelbridge/tests/test_dispatch_utils.py +++ b/ax/preview/modelbridge/tests/test_dispatch_utils.py @@ -8,7 +8,7 @@ import torch from ax.core.base_trial import TrialStatus from ax.core.trial import Trial -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.modelbridge.transition_criterion import MinTrials from ax.models.torch.botorch_modular.surrogate import ModelConfig, SurrogateSpec from ax.preview.api.configs import GenerationMethod, GenerationStrategyConfig @@ -35,7 +35,7 @@ def test_choose_gs_random_search(self) -> None: sobol_node = gs._nodes[0] self.assertEqual(len(sobol_node.model_specs), 1) sobol_spec = sobol_node.model_specs[0] - self.assertEqual(sobol_spec.model_enum, Models.SOBOL) + self.assertEqual(sobol_spec.model_enum, Generators.SOBOL) self.assertEqual(sobol_spec.model_kwargs, {"seed": None}) self.assertEqual(sobol_node._transition_criteria, []) # Make sure it generates. 
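For user code that assembles node-based generation strategies directly, the rename in these hunks means `ModelSpec(model_enum=Models...)` becomes `GeneratorSpec(model_enum=Generators...)`. A minimal sketch of the new spelling, using only constructors that appear in the hunks above (the node name and seed value are illustrative):

from ax.modelbridge.generation_strategy import GenerationNode, GenerationStrategy
from ax.modelbridge.model_spec import GeneratorSpec
from ax.modelbridge.registry import Generators

# A single quasi-random node, written with the renamed spec and registry classes.
sobol_node = GenerationNode(
    node_name="Sobol",
    model_specs=[
        GeneratorSpec(model_enum=Generators.SOBOL, model_kwargs={"seed": 0}),
    ],
)
gs = GenerationStrategy(nodes=[sobol_node])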
@@ -59,7 +59,7 @@ def test_choose_gs_fast_with_options(self) -> None: self.assertTrue(sobol_node.should_deduplicate) self.assertEqual(len(sobol_node.model_specs), 1) sobol_spec = sobol_node.model_specs[0] - self.assertEqual(sobol_spec.model_enum, Models.SOBOL) + self.assertEqual(sobol_spec.model_enum, Generators.SOBOL) self.assertEqual(sobol_spec.model_kwargs, {"seed": 0}) expected_tc = [ MinTrials( @@ -85,7 +85,7 @@ def test_choose_gs_fast_with_options(self) -> None: self.assertTrue(mbm_node.should_deduplicate) self.assertEqual(len(mbm_node.model_specs), 1) mbm_spec = mbm_node.model_specs[0] - self.assertEqual(mbm_spec.model_enum, Models.BOTORCH_MODULAR) + self.assertEqual(mbm_spec.model_enum, Generators.BOTORCH_MODULAR) expected_ss = SurrogateSpec(model_configs=[ModelConfig(name="MBM defaults")]) self.assertEqual( mbm_spec.model_kwargs, @@ -122,7 +122,7 @@ def test_choose_gs_defaults(self) -> None: self.assertTrue(sobol_node.should_deduplicate) self.assertEqual(len(sobol_node.model_specs), 1) sobol_spec = sobol_node.model_specs[0] - self.assertEqual(sobol_spec.model_enum, Models.SOBOL) + self.assertEqual(sobol_spec.model_enum, Generators.SOBOL) self.assertEqual(sobol_spec.model_kwargs, {"seed": None}) expected_tc = [ MinTrials( @@ -148,7 +148,7 @@ def test_choose_gs_defaults(self) -> None: self.assertTrue(mbm_node.should_deduplicate) self.assertEqual(len(mbm_node.model_specs), 1) mbm_spec = mbm_node.model_specs[0] - self.assertEqual(mbm_spec.model_enum, Models.BOTORCH_MODULAR) + self.assertEqual(mbm_spec.model_enum, Generators.BOTORCH_MODULAR) expected_ss = SurrogateSpec( model_configs=[ ModelConfig(name="MBM defaults"), diff --git a/ax/service/managed_loop.py b/ax/service/managed_loop.py index c6ba280bbeb..ee1f9733fe4 100644 --- a/ax/service/managed_loop.py +++ b/ax/service/managed_loop.py @@ -27,10 +27,10 @@ from ax.core.utils import get_pending_observation_features from ax.exceptions.constants import CHOLESKY_ERROR_ANNOTATION from ax.exceptions.core import SearchSpaceExhausted, UserInputError -from ax.modelbridge.base import ModelBridge +from ax.modelbridge.base import Adapter from ax.modelbridge.dispatch_utils import choose_generation_strategy from ax.modelbridge.generation_strategy import GenerationStrategy -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.service.utils.best_point import ( get_best_parameters_from_model_predictions_with_trial_index, get_best_raw_objective_point, @@ -251,7 +251,7 @@ def get_best_point(self) -> tuple[TParameterization, TModelPredictArm | None]: of this optimization.""" # Find latest trial which has a generator_run attached and get its predictions best_point = get_best_parameters_from_model_predictions_with_trial_index( - experiment=self.experiment, models_enum=Models + experiment=self.experiment, models_enum=Generators ) if best_point is not None: _, parameterizations, predictions = best_point @@ -270,7 +270,7 @@ def get_best_point(self) -> tuple[TParameterization, TModelPredictArm | None]: ), ) - def get_current_model(self) -> ModelBridge | None: + def get_current_model(self) -> Adapter | None: """Obtain the most recently used model in optimization.""" return self.generation_strategy.model @@ -287,7 +287,7 @@ def optimize( arms_per_trial: int = 1, random_seed: int | None = None, generation_strategy: GenerationStrategy | None = None, -) -> tuple[TParameterization, TModelPredictArm | None, Experiment, ModelBridge | None]: +) -> tuple[TParameterization, TModelPredictArm | None, Experiment, Adapter | 
None]: """Construct and run a full optimization loop.""" loop = OptimizationLoop.with_evaluation_function( parameters=parameters, diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py index de879ce6d42..df1fe924441 100644 --- a/ax/service/scheduler.py +++ b/ax/service/scheduler.py @@ -48,7 +48,7 @@ MaxParallelismReachedException, OptimizationConfigRequired, ) -from ax.modelbridge.base import ModelBridge +from ax.modelbridge.base import Adapter from ax.modelbridge.generation_strategy import GenerationStrategy from ax.modelbridge.modelbridge_utils import get_fixed_features_from_experiment from ax.service.utils.analysis_base import AnalysisBase @@ -2142,10 +2142,8 @@ def _should_stop_due_to_total_trials(self) -> tuple[bool, str]: ) -def get_fitted_model_bridge( - scheduler: Scheduler, force_refit: bool = False -) -> ModelBridge: - """Returns a fitted ModelBridge object. If the model is fit already, directly +def get_fitted_model_bridge(scheduler: Scheduler, force_refit: bool = False) -> Adapter: + """Returns a fitted Adapter object. If the model is fit already, directly returns the already fitted model. Otherwise, fits and returns a new one. Args: @@ -2153,11 +2151,11 @@ def get_fitted_model_bridge( force_refit: If True, will force a data lookup and a refit of the model. Returns: - A ModelBridge object fitted to the observations of the scheduler's experiment. + An Adapter object fitted to the observations of the scheduler's experiment. """ gs = scheduler.standard_generation_strategy - model_bridge = gs.model # Optional[ModelBridge] + model_bridge = gs.model # Optional[Adapter] if model_bridge is None or force_refit: # Need to re-fit the model. gs._fit_current_model(data=None) # Will lookup_data if none is provided. - model_bridge = cast(ModelBridge, gs.model) + model_bridge = cast(Adapter, gs.model) return model_bridge diff --git a/ax/service/tests/scheduler_test_utils.py b/ax/service/tests/scheduler_test_utils.py index 388e59e008e..2c758413ac1 100644 --- a/ax/service/tests/scheduler_test_utils.py +++ b/ax/service/tests/scheduler_test_utils.py @@ -49,7 +49,7 @@ from ax.modelbridge.cross_validation import compute_model_fit_metrics_from_modelbridge from ax.modelbridge.dispatch_utils import choose_generation_strategy from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy -from ax.modelbridge.registry import MBM_MTGP_trans, Models +from ax.modelbridge.registry import Generators, MBM_MTGP_trans from ax.runners.single_running_trial_mixin import SingleRunningTrialMixin from ax.runners.synthetic import SyntheticRunner from ax.service.scheduler import ( @@ -360,17 +360,21 @@ def setUp(self) -> None: self.two_sobol_steps_GS = GenerationStrategy( # Contrived GS to ensure steps=[ # that `DataRequiredError` is property handled in scheduler. GenerationStep( # This error is raised when not enough trials - model=Models.SOBOL, # have been observed to proceed to next + model=Generators.SOBOL, # have been observed to proceed to next num_trials=5, # geneneration step. min_trials_observed=3, max_parallelism=2, ), - GenerationStep(model=Models.SOBOL, num_trials=-1, max_parallelism=3), + GenerationStep( + model=Generators.SOBOL, num_trials=-1, max_parallelism=3 + ), ] ) # GS to force the scheduler to poll completed trials after each ran trial.
self.sobol_GS_no_parallelism = GenerationStrategy( - steps=[GenerationStep(model=Models.SOBOL, num_trials=-1, max_parallelism=1)] + steps=[ + GenerationStep(model=Generators.SOBOL, num_trials=-1, max_parallelism=1) + ] ) self.scheduler_options_kwargs = {} @@ -2167,9 +2171,11 @@ def test_get_fitted_model_bridge(self) -> None: generation_strategy = GenerationStrategy( steps=[ GenerationStep( - model=Models.SOBOL, num_trials=NUM_SOBOL, max_parallelism=NUM_SOBOL + model=Generators.SOBOL, + num_trials=NUM_SOBOL, + max_parallelism=NUM_SOBOL, ), - GenerationStep(model=Models.BOTORCH_MODULAR, num_trials=-1), + GenerationStep(model=Generators.BOTORCH_MODULAR, num_trials=-1), ] ) gs = self._get_generation_strategy_strategy_for_test( @@ -2184,7 +2190,7 @@ def test_get_fitted_model_bridge(self) -> None: ), db_settings=self.db_settings_if_always_needed, ) - # need to run some trials to initialize the ModelBridge + # need to run some trials to initialize the Adapter scheduler.run_n_trials(max_trials=NUM_SOBOL + 1) self._helper_path_that_refits_the_model_if_it_is_not_already_initialized( scheduler=scheduler, @@ -2475,15 +2481,15 @@ def test_it_works_with_multitask_models( generation_strategy=GenerationStrategy( steps=[ GenerationStep( - model=Models.SOBOL, + model=Generators.SOBOL, num_trials=1, ), GenerationStep( - model=Models.BOTORCH_MODULAR, + model=Generators.BOTORCH_MODULAR, num_trials=1, ), GenerationStep( - model=Models.BOTORCH_MODULAR, + model=Generators.BOTORCH_MODULAR, model_kwargs={ # this will cause and error if the model # doesn't get fixed features @@ -2589,7 +2595,7 @@ def test_generate_candidates_works_for_sobol(self) -> None: self.assertEqual(len(candidate_trial.generator_runs), 1) self.assertEqual( candidate_trial.generator_runs[0]._model_key, - Models.SOBOL.value, + Generators.SOBOL.value, ) self.assertEqual( len(candidate_trial.arms), @@ -2823,7 +2829,7 @@ def test_generate_candidates_works_for_iteration(self) -> None: self.assertEqual(len(candidate_trial.generator_runs), 1) self.assertEqual( candidate_trial.generator_runs[0]._model_key, - Models.BOTORCH_MODULAR.value, + Generators.BOTORCH_MODULAR.value, ) # MBM may generate less than the requested batch size. self.assertLessEqual( diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index 52cffb01008..2d217a5c096 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -55,9 +55,9 @@ GenerationStep, GenerationStrategy, ) -from ax.modelbridge.model_spec import ModelSpec -from ax.modelbridge.random import RandomModelBridge -from ax.modelbridge.registry import Cont_X_trans, Models +from ax.modelbridge.model_spec import GeneratorSpec +from ax.modelbridge.random import RandomAdapter +from ax.modelbridge.registry import Cont_X_trans, Generators from ax.runners.synthetic import SyntheticRunner from ax.service.ax_client import AxClient, ObjectiveProperties from ax.service.utils.best_point import ( @@ -217,9 +217,9 @@ def get_client_with_simple_discrete_moo_problem( ) -> AxClient: gs = GenerationStrategy( steps=[ - GenerationStep(model=Models.SOBOL, num_trials=3), + GenerationStep(model=Generators.SOBOL, num_trials=3), GenerationStep( - model=Models.BOTORCH_MODULAR, + model=Generators.BOTORCH_MODULAR, num_trials=-1, model_kwargs={ # To avoid search space exhausted errors. 
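Step-based strategies migrate the same way: each `GenerationStep(model=Models...)` becomes `GenerationStep(model=Generators...)`, and the resulting strategy is handed to `AxClient` unchanged. A hedged sketch of the caller-side pattern (the trial counts, parameter names, and objective are illustrative and not taken from the tests above):

from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy
from ax.modelbridge.registry import Generators
from ax.service.ax_client import AxClient, ObjectiveProperties

gs = GenerationStrategy(
    steps=[
        # Quasi-random initialization, then hand off to modular BoTorch.
        GenerationStep(model=Generators.SOBOL, num_trials=5, min_trials_observed=3),
        GenerationStep(model=Generators.BOTORCH_MODULAR, num_trials=-1),
    ]
)

ax_client = AxClient(generation_strategy=gs)
ax_client.create_experiment(
    name="branin_demo",  # illustrative experiment name
    parameters=[
        {"name": "x1", "type": "range", "bounds": [-5.0, 10.0]},
        {"name": "x2", "type": "range", "bounds": [0.0, 15.0]},
    ],
    objectives={"branin": ObjectiveProperties(minimize=True)},
)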
@@ -484,17 +484,17 @@ def test_set_optimization_config_without_objectives_raises_error(self) -> None: return_value=([get_observation1(first_metric_name="branin")]), ) @patch( - "ax.modelbridge.random.RandomModelBridge.get_training_data", + "ax.modelbridge.random.RandomAdapter.get_training_data", autospec=True, return_value=([get_observation1(first_metric_name="branin")]), ) @patch( - "ax.modelbridge.random.RandomModelBridge._predict", + "ax.modelbridge.random.RandomAdapter._predict", autospec=True, return_value=[get_observation1trans(first_metric_name="branin").data], ) @patch( - "ax.modelbridge.random.RandomModelBridge.feature_importances", + "ax.modelbridge.random.RandomAdapter.feature_importances", autospec=True, return_value={"x": 0.9, "y": 1.1}, ) @@ -506,7 +506,7 @@ def test_default_generation_strategy_continuous(self, _a, _b, _c, _d) -> None: ax_client = get_branin_optimization() self.assertEqual( [s.model for s in none_throws(ax_client.generation_strategy)._steps], - [Models.SOBOL, Models.BOTORCH_MODULAR], + [Generators.SOBOL, Generators.BOTORCH_MODULAR], ) with self.assertRaisesRegex(ValueError, ".* no trials"): ax_client.get_optimization_trace(objective_optimum=branin.fmin) @@ -605,7 +605,7 @@ def test_optimization_complete(self, _mock_gen) -> None: def test_sobol_generation_strategy_completion(self) -> None: ax_client = get_branin_optimization( generation_strategy=GenerationStrategy( - [GenerationStep(Models.SOBOL, num_trials=3)] + [GenerationStep(Generators.SOBOL, num_trials=3)] ) ) # All Sobol trials should be able to be generated at once and optimization @@ -625,7 +625,7 @@ def test_save_and_load_generation_strategy(self) -> None: decoder = Decoder(config=config) db_settings = DBSettings(encoder=encoder, decoder=decoder) generation_strategy = GenerationStrategy( - [GenerationStep(Models.SOBOL, num_trials=3)] + [GenerationStep(Generators.SOBOL, num_trials=3)] ) ax_client = AxClient( db_settings=db_settings, generation_strategy=generation_strategy @@ -692,17 +692,17 @@ def test_db_write_failure_on_create_experiment(self, _mock_save_experiment) -> N return_value=([get_observation1(first_metric_name="branin")]), ) @patch( - "ax.modelbridge.random.RandomModelBridge.get_training_data", + "ax.modelbridge.random.RandomAdapter.get_training_data", autospec=True, return_value=([get_observation1(first_metric_name="branin")]), ) @patch( - "ax.modelbridge.random.RandomModelBridge._predict", + "ax.modelbridge.random.RandomAdapter._predict", autospec=True, return_value=[get_observation1trans(first_metric_name="branin").data], ) @patch( - "ax.modelbridge.random.RandomModelBridge.feature_importances", + "ax.modelbridge.random.RandomAdapter.feature_importances", autospec=True, return_value={"x": 0.9, "y": 1.1}, ) @@ -724,7 +724,7 @@ def test_default_generation_strategy_continuous_for_moo( ) self.assertEqual( [s.model for s in none_throws(ax_client.generation_strategy)._steps], - [Models.SOBOL, Models.BOTORCH_MODULAR], + [Generators.SOBOL, Generators.BOTORCH_MODULAR], ) with self.assertRaisesRegex(ValueError, ".* no trials"): ax_client.get_optimization_trace(objective_optimum=branin.fmin) @@ -789,7 +789,7 @@ def test_create_experiment(self) -> None: """Test basic experiment creation.""" ax_client = AxClient( GenerationStrategy( - steps=[GenerationStep(model=Models.SOBOL, num_trials=30)] + steps=[GenerationStep(model=Generators.SOBOL, num_trials=30)] ) ) with self.assertRaisesRegex(AssertionError, "Experiment not set on Ax client"): @@ -929,7 +929,7 @@ def test_create_multitype_experiment(self) 
-> None: """ ax_client = AxClient( GenerationStrategy( - steps=[GenerationStep(model=Models.SOBOL, num_trials=30)] + steps=[GenerationStep(model=Generators.SOBOL, num_trials=30)] ) ) ax_client.create_experiment( @@ -1026,7 +1026,7 @@ def test_create_multitype_experiment(self) -> None: def test_create_single_objective_experiment_with_objectives_dict(self) -> None: ax_client = AxClient( GenerationStrategy( - steps=[GenerationStep(model=Models.SOBOL, num_trials=30)] + steps=[GenerationStep(model=Generators.SOBOL, num_trials=30)] ) ) with self.assertRaisesRegex(AssertionError, "Experiment not set on Ax client"): @@ -1354,7 +1354,7 @@ def test_create_moo_experiment(self) -> None: """Test basic experiment creation.""" ax_client = AxClient( GenerationStrategy( - steps=[GenerationStep(model=Models.SOBOL, num_trials=30)] + steps=[GenerationStep(model=Generators.SOBOL, num_trials=30)] ) ) with self.assertRaisesRegex(AssertionError, "Experiment not set on Ax client"): @@ -1518,7 +1518,7 @@ def test_constraint_same_as_objective(self) -> None: """Check that we do not allow constraints on the objective metric.""" ax_client = AxClient( GenerationStrategy( - steps=[GenerationStep(model=Models.SOBOL, num_trials=30)] + steps=[GenerationStep(model=Generators.SOBOL, num_trials=30)] ) ) with self.assertRaises(ValueError): @@ -1775,7 +1775,7 @@ def test_trial_completion_with_metadata_with_iso_times(self) -> None: }, ) with patch.object( - RandomModelBridge, "_fit", autospec=True, side_effect=RandomModelBridge._fit + RandomAdapter, "_fit", autospec=True, side_effect=RandomAdapter._fit ) as mock_fit: ax_client.get_next_trial() mock_fit.assert_called_once() @@ -1797,7 +1797,7 @@ def test_trial_completion_with_metadata_millisecond_times(self) -> None: }, ) with patch.object( - RandomModelBridge, "_fit", autospec=True, side_effect=RandomModelBridge._fit + RandomAdapter, "_fit", autospec=True, side_effect=RandomAdapter._fit ) as mock_fit: ax_client.get_next_trial() mock_fit.assert_called_once() @@ -2283,7 +2283,7 @@ def test_unnamed_experiment_snapshot(self) -> None: self.assertIsNone(ax_client.experiment._name) @patch( - "ax.modelbridge.random.RandomModelBridge._predict", + "ax.modelbridge.random.RandomAdapter._predict", autospec=True, return_value=[get_observation1trans(first_metric_name="branin").data], ) @@ -2804,7 +2804,7 @@ def test_get_hypervolume(self) -> None: ) # Cannot get predicted hypervolume with sobol model - with self.assertRaisesRegex(ValueError, "is not of type TorchModelBridge"): + with self.assertRaisesRegex(ValueError, "is not of type TorchAdapter"): ax_client.get_hypervolume(use_model_predictions=True) # Run one more trial and check predicted hypervolume gets returned @@ -2979,7 +2979,7 @@ def test_torch_device(self) -> None: with self.assertWarnsRegex(RuntimeWarning, "a `torch_device` were specified."): AxClient( generation_strategy=GenerationStrategy( - [GenerationStep(Models.SOBOL, num_trials=3)] + [GenerationStep(Generators.SOBOL, num_trials=3)] ), torch_device=device, ) @@ -3080,7 +3080,7 @@ def test_with_node_based_gs(self) -> None: nodes=[ GenerationNode( node_name="Sobol", - model_specs=[ModelSpec(model_enum=Models.SOBOL)], + model_specs=[GeneratorSpec(model_enum=Generators.SOBOL)], ) ], ) diff --git a/ax/service/tests/test_best_point_utils.py b/ax/service/tests/test_best_point_utils.py index fb4d40b49aa..8ce7d3cfe73 100644 --- a/ax/service/tests/test_best_point_utils.py +++ b/ax/service/tests/test_best_point_utils.py @@ -25,8 +25,8 @@ from ax.core.types import ComparisonOp from 
ax.exceptions.core import UserInputError from ax.modelbridge.cross_validation import AssessModelFitResult -from ax.modelbridge.registry import Models -from ax.modelbridge.torch import TorchModelBridge +from ax.modelbridge.registry import Generators +from ax.modelbridge.torch import TorchAdapter from ax.plot.pareto_utils import get_tensor_converter_model from ax.service.ax_client import AxClient from ax.service.utils.best_point import ( @@ -64,21 +64,21 @@ def test_best_from_model_prediction(self) -> None: exp = get_branin_experiment() for _ in range(3): - sobol = Models.SOBOL(search_space=exp.search_space) + sobol = Generators.SOBOL(search_space=exp.search_space) generator_run = sobol.gen(n=1) trial = exp.new_trial(generator_run=generator_run) trial.run() trial.mark_completed() exp.attach_data(exp.fetch_data()) - model = Models.BOTORCH_MODULAR(experiment=exp, data=exp.lookup_data()) + model = Generators.BOTORCH_MODULAR(experiment=exp, data=exp.lookup_data()) generator_run = model.gen(n=1) trial = exp.new_trial(generator_run=generator_run) trial.run() trial.mark_completed() with patch.object( - TorchModelBridge, + TorchAdapter, "model_best_point", return_value=( ( @@ -105,7 +105,7 @@ def test_best_from_model_prediction(self) -> None: }, ), ): - self.assertIsNotNone(get_best_parameters(exp, Models)) + self.assertIsNotNone(get_best_parameters(exp, Generators)) self.assertTrue( any("Model fit is poor" in warning for warning in lg.output), msg=lg.output, @@ -122,24 +122,24 @@ def test_best_from_model_prediction(self) -> None: bad_fit_metrics_to_fisher_score={}, ), ): - self.assertIsNotNone(get_best_parameters(exp, Models)) + self.assertIsNotNone(get_best_parameters(exp, Generators)) mock_model_best_point.assert_called() # Assert the non-mocked method works correctly as well - best_params = get_best_parameters(exp, Models) + best_params = get_best_parameters(exp, Generators) self.assertIsNotNone(best_params) # It works even when there are no predictions already stored on the # GeneratorRun for trial in exp.trials.values(): trial.generator_run._best_arm_predictions = None - best_params_no_gr = get_best_parameters(exp, Models) + best_params_no_gr = get_best_parameters(exp, Generators) self.assertEqual(best_params, best_params_no_gr) def test_best_raw_objective_point(self) -> None: exp = get_branin_experiment() with self.assertRaisesRegex(ValueError, "Cannot identify best "): get_best_raw_objective_point(exp) - self.assertEqual(get_best_parameters(exp, Models), None) + self.assertEqual(get_best_parameters(exp, Generators), None) exp.new_trial( generator_run=GeneratorRun(arms=[Arm(parameters={"x1": 5.0, "x2": 5.0})]) ).run().complete() @@ -157,7 +157,7 @@ def test_best_raw_objective_point(self) -> None: constrained=True, minimize=False, ) - _, best_prediction = none_throws(get_best_parameters(exp, Models)) + _, best_prediction = none_throws(get_best_parameters(exp, Generators)) best_metrics = none_throws(best_prediction)[0] self.assertDictEqual(best_metrics, {"m1": 3.0, "m2": 4.0}) @@ -166,7 +166,7 @@ def test_best_raw_objective_point(self) -> None: # pyre-fixme[8]: Attribute `bound` declared in class `OutcomeConstraint` # has type `float` but is used as type `Tensor`. 
constraint.bound = torch.tensor(constraint.bound) - _, best_prediction = none_throws(get_best_parameters(exp, Models)) + _, best_prediction = none_throws(get_best_parameters(exp, Generators)) best_metrics = none_throws(best_prediction)[0] self.assertDictEqual(best_metrics, {"m1": 3.0, "m2": 4.0}) @@ -219,7 +219,7 @@ def test_best_raw_objective_point_scalarized(self) -> None: ) with self.assertRaisesRegex(ValueError, "Cannot identify best "): get_best_raw_objective_point(exp) - self.assertEqual(get_best_parameters(exp, Models), None) + self.assertEqual(get_best_parameters(exp, Generators), None) exp.new_trial( generator_run=GeneratorRun(arms=[Arm(parameters={"x1": 5.0, "x2": 5.0})]) ).run().complete() @@ -237,7 +237,7 @@ def test_best_raw_objective_point_scalarized_multi(self) -> None: ) with self.assertRaisesRegex(ValueError, "Cannot identify best "): get_best_raw_objective_point(exp) - self.assertEqual(get_best_parameters(exp, Models), None) + self.assertEqual(get_best_parameters(exp, Generators), None) exp.new_trial( generator_run=GeneratorRun(arms=[Arm(parameters={"x1": 5.0, "x2": 5.0})]) ).run().complete() @@ -267,7 +267,7 @@ def test_derel_opt_config_wrapper(self, mock_derelativize: MagicMock) -> None: # Check errors. with self.assertRaisesRegex( ValueError, - "Must specify ModelBridge or Experiment when calling " + "Must specify Adapter or Experiment when calling " "`_derel_opt_config_wrapper`.", ): _derel_opt_config_wrapper(optimization_config=input_optimization_config) @@ -283,7 +283,7 @@ def test_derel_opt_config_wrapper(self, mock_derelativize: MagicMock) -> None: # Set status quo. exp.status_quo = exp.trials[0].arms[0] - # ModelBridges will have specific addresses and so must be self-same to + # Adapters will have specific addresses and so must be self-same to # pass equality checks. test_modelbridge_1 = get_tensor_converter_model( experiment=none_throws(exp), @@ -304,7 +304,7 @@ def test_derel_opt_config_wrapper(self, mock_derelativize: MagicMock) -> None: f"{best_point_module}.get_tensor_converter_model", return_value=test_modelbridge_1, ), patch( - f"{best_point_module}.ModelBridge.get_training_data", + f"{best_point_module}.Adapter.get_training_data", return_value=test_observations_1, ): returned_value = _derel_opt_config_wrapper( @@ -317,7 +317,7 @@ def test_derel_opt_config_wrapper(self, mock_derelativize: MagicMock) -> None: observations=test_observations_1, ) - # Observations and ModelBridge are not constructed from other inputs when + # Observations and Adapter are not constructed from other inputs when # provided. test_modelbridge_2 = get_tensor_converter_model( experiment=none_throws(exp), @@ -328,7 +328,7 @@ def test_derel_opt_config_wrapper(self, mock_derelativize: MagicMock) -> None: f"{best_point_module}.get_tensor_converter_model", return_value=test_modelbridge_2, ), patch( - f"{best_point_module}.ModelBridge.get_training_data", + f"{best_point_module}.Adapter.get_training_data", return_value=test_observations_2, ): returned_value = _derel_opt_config_wrapper( @@ -339,7 +339,7 @@ def test_derel_opt_config_wrapper(self, mock_derelativize: MagicMock) -> None: ) self.assertTrue( any( - "ModelBridge and Experiment provided to " + "Adapter and Experiment provided to " "`_derel_opt_config_wrapper`. Ignoring the latter." 
in warning for warning in lg.output ), diff --git a/ax/service/tests/test_interactive_loop.py b/ax/service/tests/test_interactive_loop.py index 78e0516f559..8a734ebf8f5 100644 --- a/ax/service/tests/test_interactive_loop.py +++ b/ax/service/tests/test_interactive_loop.py @@ -17,7 +17,7 @@ from ax.core.types import TEvaluationOutcome, TParameterization from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.service.ax_client import AxClient from ax.service.interactive_loop import ( ax_client_data_attacher, @@ -32,7 +32,9 @@ class TestInteractiveLoop(TestCase): def setUp(self) -> None: generation_strategy = GenerationStrategy( - steps=[GenerationStep(model=Models.SOBOL, max_parallelism=1, num_trials=-1)] + steps=[ + GenerationStep(model=Generators.SOBOL, max_parallelism=1, num_trials=-1) + ] ) self.ax_client = AxClient(generation_strategy=generation_strategy) self.ax_client.create_experiment( diff --git a/ax/service/tests/test_managed_loop.py b/ax/service/tests/test_managed_loop.py index 99896cf138c..7a0cc14d82f 100644 --- a/ax/service/tests/test_managed_loop.py +++ b/ax/service/tests/test_managed_loop.py @@ -13,7 +13,7 @@ from ax.exceptions.core import UserInputError from ax.metrics.branin import branin from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.service.managed_loop import OptimizationLoop, optimize from ax.utils.common.testutils import TestCase from ax.utils.testing.mock import mock_botorch_optimize @@ -373,7 +373,7 @@ def test_optimize_search_space_exhausted(self) -> None: def test_custom_gs(self) -> None: """Managed loop with custom generation strategy""" strategy0 = GenerationStrategy( - name="Sobol", steps=[GenerationStep(model=Models.SOBOL, num_trials=-1)] + name="Sobol", steps=[GenerationStep(model=Generators.SOBOL, num_trials=-1)] ) loop = OptimizationLoop.with_evaluation_function( parameters=[ @@ -414,7 +414,8 @@ def test_optimize_graceful_exit_on_exception(self) -> None: minimize=True, total_trials=6, generation_strategy=GenerationStrategy( - name="Sobol", steps=[GenerationStep(model=Models.SOBOL, num_trials=3)] + name="Sobol", + steps=[GenerationStep(model=Generators.SOBOL, num_trials=3)], ), ) self.assertEqual(len(exp.trials), 3) # Check that we stopped at 3 trials. @@ -436,7 +437,7 @@ def test_optimize_graceful_exit_on_exception(self) -> None: # pyre-fixme[3]: Return type must be annotated. 
def test_annotate_exception(self, _): strategy0 = GenerationStrategy( - name="Sobol", steps=[GenerationStep(model=Models.SOBOL, num_trials=-1)] + name="Sobol", steps=[GenerationStep(model=Generators.SOBOL, num_trials=-1)] ) loop = OptimizationLoop.with_evaluation_function( parameters=[ diff --git a/ax/service/tests/test_report_utils.py b/ax/service/tests/test_report_utils.py index b09ad2a241e..6245fd6ff7b 100644 --- a/ax/service/tests/test_report_utils.py +++ b/ax/service/tests/test_report_utils.py @@ -24,7 +24,7 @@ from ax.core.types import ComparisonOp from ax.modelbridge.generation_node import GenerationStep from ax.modelbridge.generation_strategy import GenerationStrategy -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.service.scheduler import Scheduler from ax.service.utils.report_utils import ( _format_comparison_string, @@ -380,7 +380,7 @@ def test_get_standard_plots(self) -> None: exp = get_branin_experiment(with_batch=True, minimize=True) exp.trials[0].run() exp.trials[0].mark_completed() - model = Models.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data()) + model = Generators.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data()) for gsa, true_objective_metric_name in itertools.product( [False, True], ["branin", None] ): @@ -440,7 +440,7 @@ def test_get_standard_plots_moo(self) -> None: with self.assertLogs(logger="ax", level=INFO) as log: plots = get_standard_plots( experiment=exp, - model=Models.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data()), + model=Generators.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data()), ) self.assertEqual(len(log.output), 3) self.assertIn( @@ -479,7 +479,7 @@ def test_get_standard_plots_moo_relative_constraints(self) -> None: ot.relative = False plots = get_standard_plots( experiment=exp, - model=Models.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data()), + model=Generators.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data()), ) self.assertEqual(len(plots), 8) @@ -491,7 +491,7 @@ def test_get_standard_plots_moo_no_objective_thresholds(self) -> None: exp.trials[0].run() plots = get_standard_plots( experiment=exp, - model=Models.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data()), + model=Generators.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data()), ) self.assertEqual(len(plots), 8) @@ -501,14 +501,14 @@ def test_get_standard_plots_map_data(self) -> None: exp.new_trial().add_arm(exp.status_quo) exp.trials[0].run() exp.new_trial( - generator_run=Models.SOBOL(search_space=exp.search_space).gen(n=1) + generator_run=Generators.SOBOL(search_space=exp.search_space).gen(n=1) ) exp.trials[1].run() for t in exp.trials.values(): t.mark_completed() plots = get_standard_plots( experiment=exp, - model=Models.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data()), + model=Generators.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data()), true_objective_metric_name="branin", ) @@ -524,7 +524,7 @@ def test_get_standard_plots_map_data(self) -> None: ): plots = get_standard_plots( experiment=exp, - model=Models.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data()), + model=Generators.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data()), true_objective_metric_name="not_present", ) @@ -532,10 +532,10 @@ def test_get_standard_plots_map_data(self) -> None: def test_skip_contour_high_dimensional(self) -> None: exp = get_high_dimensional_branin_experiment() # Initial Sobol points - sobol = Models.SOBOL(search_space=exp.search_space) + sobol = Generators.SOBOL(search_space=exp.search_space) for _ 
in range(1): exp.new_trial(sobol.gen(1)).run() - model = Models.BOTORCH_MODULAR( + model = Generators.BOTORCH_MODULAR( experiment=exp, data=exp.fetch_data(), ) @@ -1241,13 +1241,13 @@ def test_warn_if_unpredictable_metrics(self) -> None: gs = GenerationStrategy( steps=[ GenerationStep( - model=Models.SOBOL, + model=Generators.SOBOL, num_trials=3, min_trials_observed=3, max_parallelism=3, ), GenerationStep( - model=Models.BOTORCH_MODULAR, num_trials=-1, max_parallelism=3 + model=Generators.BOTORCH_MODULAR, num_trials=-1, max_parallelism=3 ), ] ) diff --git a/ax/service/tests/test_scheduler.py b/ax/service/tests/test_scheduler.py index a9bb36f2922..95cbd3cb916 100644 --- a/ax/service/tests/test_scheduler.py +++ b/ax/service/tests/test_scheduler.py @@ -12,7 +12,7 @@ from ax.metrics.branin import BraninMetric from ax.modelbridge.dispatch_utils import choose_generation_strategy from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.service.tests.scheduler_test_utils import ( AxSchedulerTestCase, BrokenRunnerRuntimeError, @@ -106,17 +106,21 @@ def setUp(self) -> None: self.two_sobol_steps_GS = GenerationStrategy( # Contrived GS to ensure steps=[ # that `DataRequiredError` is property handled in scheduler. GenerationStep( # This error is raised when not enough trials - model=Models.SOBOL, # have been observed to proceed to next + model=Generators.SOBOL, # have been observed to proceed to next num_trials=5, # geneneration step. min_trials_observed=3, max_parallelism=2, ), - GenerationStep(model=Models.SOBOL, num_trials=-1, max_parallelism=3), + GenerationStep( + model=Generators.SOBOL, num_trials=-1, max_parallelism=3 + ), ] ) # GS to force the scheduler to poll completed trials after each ran trial. 
self.sobol_GS_no_parallelism = GenerationStrategy( - steps=[GenerationStep(model=Models.SOBOL, num_trials=-1, max_parallelism=1)] + steps=[ + GenerationStep(model=Generators.SOBOL, num_trials=-1, max_parallelism=1) + ] ) self.scheduler_options_kwargs: dict[str, str | None] = { "mt_experiment_trial_type": "type1" diff --git a/ax/service/utils/best_point.py b/ax/service/utils/best_point.py index 28057692f3e..d50aa1fbfc3 100644 --- a/ax/service/utils/best_point.py +++ b/ax/service/utils/best_point.py @@ -30,7 +30,7 @@ from ax.core.trial import Trial from ax.core.types import ComparisonOp, TModelPredictArm, TParameterization from ax.exceptions.core import UnsupportedError, UserInputError -from ax.modelbridge.base import ModelBridge +from ax.modelbridge.base import Adapter from ax.modelbridge.cross_validation import ( assess_model_fit, compute_diagnostics, @@ -42,11 +42,11 @@ predicted_pareto_frontier as predicted_pareto, ) from ax.modelbridge.registry import ( + Generators, get_model_from_generator_run, ModelRegistryBase, - Models, ) -from ax.modelbridge.torch import TorchModelBridge +from ax.modelbridge.torch import TorchAdapter from ax.modelbridge.transforms.utils import ( derelativize_optimization_config_with_raw_status_quo, ) @@ -257,9 +257,9 @@ def get_best_parameters_from_model_predictions_with_trial_index( except ValueError: return _extract_best_arm_from_gr(gr=gr, trials=experiment.trials) - # If model is not TorchModelBridge, just use the best arm from the + # If model is not TorchAdapter, just use the best arm from the # last good generator run - if not isinstance(model, TorchModelBridge): + if not isinstance(model, TorchAdapter): return _extract_best_arm_from_gr(gr=gr, trials=experiment.trials) # Check to see if the model is worth using @@ -543,16 +543,16 @@ def get_pareto_optimal_parameters( modelbridge = generation_strategy.model is_moo_modelbridge = ( modelbridge - and isinstance(modelbridge, TorchModelBridge) + and isinstance(modelbridge, TorchAdapter) and assert_is_instance( modelbridge, - TorchModelBridge, + TorchAdapter, ).is_moo_problem ) if is_moo_modelbridge: generation_strategy._fit_current_model(data=None) else: - modelbridge = Models.BOTORCH_MODULAR( + modelbridge = Generators.BOTORCH_MODULAR( experiment=experiment, data=assert_is_instance( experiment.lookup_data(trial_indices=trial_indices), @@ -561,7 +561,7 @@ def get_pareto_optimal_parameters( ) modelbridge = assert_is_instance( modelbridge, - TorchModelBridge, + TorchAdapter, ) # If objective thresholds are not specified in optimization config, extract @@ -739,7 +739,7 @@ def _is_all_noiseless(df: pd.DataFrame, metric_name: str) -> bool: def _derel_opt_config_wrapper( optimization_config: OptimizationConfig, - modelbridge: ModelBridge | None = None, + modelbridge: Adapter | None = None, experiment: Experiment | None = None, observations: list[Observation] | None = None, ) -> OptimizationConfig: @@ -751,7 +751,7 @@ def _derel_opt_config_wrapper( if modelbridge is None and experiment is None: raise ValueError( - "Must specify ModelBridge or Experiment when calling " + "Must specify Adapter or Experiment when calling " "`_derel_opt_config_wrapper`." ) elif not modelbridge: @@ -761,7 +761,7 @@ def _derel_opt_config_wrapper( ) else: # Both modelbridge and experiment specified. logger.warning( - "ModelBridge and Experiment provided to `_derel_opt_config_wrapper`. " + "Adapter and Experiment provided to `_derel_opt_config_wrapper`. " "Ignoring the latter." 
) if not modelbridge.status_quo: diff --git a/ax/service/utils/best_point_mixin.py b/ax/service/utils/best_point_mixin.py index f2ab5b65491..4090e026e8d 100644 --- a/ax/service/utils/best_point_mixin.py +++ b/ax/service/utils/best_point_mixin.py @@ -33,7 +33,7 @@ validate_and_apply_final_transform, ) from ax.modelbridge.registry import ModelRegistryBase -from ax.modelbridge.torch import TorchModelBridge +from ax.modelbridge.torch import TorchAdapter from ax.modelbridge.transforms.derelativize import Derelativize from ax.models.torch.botorch_moo_defaults import ( get_outcome_constraint_transforms, @@ -273,8 +273,8 @@ def _get_best_trial( # calculation of best parameters. if use_model_predictions: current_model = generation_strategy._curr.model_spec_to_gen_from.model_enum - # Cover for the case where source of `self._curr.model` was not a `Models` - # enum but a factory function, in which case we cannot do + # Cover for the case where source of `self._curr.model` was not a + # `Generators` enum but a factory function, in which case we cannot do # `get_model_from_generator_run` (since we don't have model type and inputs # recorded on the generator run. models_enum = ( @@ -391,9 +391,9 @@ def _get_hypervolume( # this should be a no-op. generation_strategy._fit_current_model(data=None) model = generation_strategy.model - if not isinstance(model, TorchModelBridge): + if not isinstance(model, TorchAdapter): raise ValueError( - f"Model {model} is not of type TorchModelBridge, cannot " + f"Model {model} is not of type TorchAdapter, cannot " "calculate predicted hypervolume." ) return predicted_hypervolume( diff --git a/ax/service/utils/report_utils.py b/ax/service/utils/report_utils.py index 63a5412af02..a9e7cb6b47c 100644 --- a/ax/service/utils/report_utils.py +++ b/ax/service/utils/report_utils.py @@ -35,14 +35,14 @@ from ax.core.trial import BaseTrial from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy from ax.exceptions.core import DataRequiredError, UserInputError -from ax.modelbridge import ModelBridge +from ax.modelbridge import Adapter from ax.modelbridge.cross_validation import ( compute_model_fit_metrics_from_modelbridge, cross_validate, ) from ax.modelbridge.generation_strategy import GenerationStrategy -from ax.modelbridge.random import RandomModelBridge -from ax.modelbridge.torch import TorchModelBridge +from ax.modelbridge.random import RandomAdapter +from ax.modelbridge.torch import TorchAdapter from ax.plot.contour import interact_contour_plotly from ax.plot.diagnostic import interact_cross_validation_plotly from ax.plot.feature_importances import plot_feature_importance_by_feature_plotly @@ -91,7 +91,7 @@ ) -def _get_cross_validation_plots(model: ModelBridge) -> list[go.Figure]: +def _get_cross_validation_plots(model: Adapter) -> list[go.Figure]: cv = cross_validate(model=model) return [ interact_cross_validation_plotly( @@ -149,7 +149,7 @@ def _get_objective_trace_plot( def _get_objective_v_param_plots( experiment: Experiment, - model: ModelBridge, + model: Adapter, importance: None | (dict[str, dict[str, npt.NDArray]] | dict[str, dict[str, float]]) = None, # Chosen to take ~1min on local benchmarks. 
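The report and plot helpers below now annotate their model argument as `Adapter` rather than `ModelBridge`, but the calling pattern is unchanged: fit through the `Generators` registry and pass the result in. A rough sketch using the Branin test stub these tests already rely on (trial handling is simplified):

from ax.modelbridge.registry import Generators
from ax.service.utils.report_utils import get_standard_plots
from ax.utils.testing.core_stubs import get_branin_experiment

exp = get_branin_experiment(with_batch=True)
exp.trials[0].run()
exp.trials[0].mark_completed()

# Generators.BOTORCH_MODULAR(...) returns an Adapter (formerly a ModelBridge).
model = Generators.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data())
plots = get_standard_plots(experiment=exp, model=model)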
@@ -316,7 +316,7 @@ def _get_shortest_unique_suffix_dict( def get_standard_plots( experiment: Experiment, - model: ModelBridge | None, + model: Adapter | None, data: Data | None = None, true_objective_metric_name: str | None = None, early_stopping_strategy: BaseEarlyStoppingStrategy | None = None, @@ -325,14 +325,14 @@ def get_standard_plots( ) -> list[go.Figure]: """Extract standard plots for single-objective optimization. - Extracts a list of plots from an ``Experiment`` and ``ModelBridge`` of general + Extracts a list of plots from an ``Experiment`` and ``Adapter`` of general interest to an Ax user. Currently not supported are - TODO: multi-objective optimization - TODO: ChoiceParameter plots Args: - experiment: The ``Experiment`` from which to obtain standard plots. - - model: The ``ModelBridge`` used to suggest trial parameters. + - model: The ``Adapter`` used to suggest trial parameters. - true_objective_metric_name: Name of the metric to use as the true objective. - early_stopping_strategy: Early stopping strategy used throughout the experiment; used for visualizing when curves are stopped. @@ -394,8 +394,8 @@ def get_standard_plots( # Objective vs. parameter plot requires a `Model`, so add it only if model # is alrady available. In cases where initially custom trials are attached, # model might not yet be set on the generation strategy. Additionally, if - # the model is a RandomModelBridge, skip plots that require predictions. - if model is not None and not isinstance(model, RandomModelBridge): + # the model is a RandomAdapter, skip plots that require predictions. + if model is not None and not isinstance(model, RandomAdapter): try: if true_objective_metric_name is not None: logger.debug("Starting objective vs. true objective scatter plot.") @@ -414,7 +414,7 @@ def get_standard_plots( # features to plot. sens = None importance_measure = "" - if global_sensitivity_analysis and isinstance(model, TorchModelBridge): + if global_sensitivity_analysis and isinstance(model, TorchAdapter): try: logger.debug("Starting global sensitivity analysis.") sens = ax_parameter_sens(model, order="total") @@ -1163,7 +1163,7 @@ def pareto_frontier_scatter_2d_plotly( def _objective_vs_true_objective_scatter( - model: ModelBridge, + model: Adapter, objective_metric_name: str, true_objective_metric_name: str, ) -> go.Figure: @@ -1502,13 +1502,13 @@ def warn_if_unpredictable_metrics( A string warning the user about unpredictable metrics, if applicable. """ # Get fit quality dict. - model_bridge = generation_strategy.model # Optional[ModelBridge] + model_bridge = generation_strategy.model # Optional[Adapter] if model_bridge is None: # Need to re-fit the model. generation_strategy._fit_current_model(data=None) - model_bridge = cast(ModelBridge, generation_strategy.model) - if isinstance(model_bridge, RandomModelBridge): + model_bridge = cast(Adapter, generation_strategy.model) + if isinstance(model_bridge, RandomAdapter): logger.debug( - "Current modelbridge on GenerationStrategy is RandomModelBridge. " + "Current modelbridge on GenerationStrategy is RandomAdapter. " "Not checking metric predictability." 
) return None diff --git a/ax/storage/json_store/decoder.py b/ax/storage/json_store/decoder.py index f2af62835a7..fa5d1e69471 100644 --- a/ax/storage/json_store/decoder.py +++ b/ax/storage/json_store/decoder.py @@ -39,7 +39,7 @@ GenerationStep, GenerationStrategy, ) -from ax.modelbridge.model_spec import ModelSpec +from ax.modelbridge.model_spec import GeneratorSpec from ax.modelbridge.registry import _decode_callables_from_references, ModelRegistryBase from ax.modelbridge.transition_criterion import ( AuxiliaryExperimentCheck, @@ -183,7 +183,7 @@ def object_from_json( return generation_node_from_json( generation_node_json=object_json, **vars(registry_kwargs) ) - elif _class == ModelSpec: + elif _class == GeneratorSpec: return model_spec_from_json( model_spec_json=object_json, **vars(registry_kwargs) ) @@ -696,7 +696,7 @@ def _extract_surrogate_spec_from_surrogate_specs( key with the value of that element. This helper will keep deserialization of MBM models backwards compatible - even after we remove the ``surrogate_specs`` argument from ``BoTorchModel``. + even after we remove the ``surrogate_specs`` argument from ``BoTorchGenerator``. Args: model_kwargs: A dictionary of model kwargs to update. @@ -783,14 +783,14 @@ def model_spec_from_json( model_spec_json: dict[str, Any], decoder_registry: TDecoderRegistry = CORE_DECODER_REGISTRY, class_decoder_registry: TClassDecoderRegistry = CORE_CLASS_DECODER_REGISTRY, -) -> ModelSpec: - """Load ModelSpec from JSON.""" +) -> GeneratorSpec: + """Load GeneratorSpec from JSON.""" kwargs = model_spec_json.pop("model_kwargs", None) kwargs.pop("fit_on_update", None) # Remove deprecated fit_on_update. if kwargs is not None: kwargs = _extract_surrogate_spec_from_surrogate_specs(kwargs) gen_kwargs = model_spec_json.pop("model_gen_kwargs", None) - return ModelSpec( + return GeneratorSpec( model_enum=object_from_json( model_spec_json.pop("model_enum"), decoder_registry=decoder_registry, @@ -1086,15 +1086,15 @@ def _update_deprecated_model_registry(name: str) -> str: will error out while looking it up in the corresponding enum. Args: - name: The name of the ``Models`` enum. + name: The name of the ``Generators`` enum. Returns: - Either the given name or the name of a replacement ``Models`` enum. + Either the given name or the name of a replacement ``Generators`` enum. """ if name in _DEPRECATED_MODEL_TO_REPLACEMENT: new_name = _DEPRECATED_MODEL_TO_REPLACEMENT[name] logger.exception( - f"{name} model is deprecated and replaced by Models.{new_name}. " + f"{name} model is deprecated and replaced by Generators.{new_name}. " f"Please use {new_name} in the future. Note that this warning only " "enables deserialization of experiments with deprecated models. " "Model fitting with the loaded experiment may still fail. 
" diff --git a/ax/storage/json_store/encoders.py b/ax/storage/json_store/encoders.py index 85d43fba87c..7279021cc4c 100644 --- a/ax/storage/json_store/encoders.py +++ b/ax/storage/json_store/encoders.py @@ -47,11 +47,11 @@ from ax.modelbridge.best_model_selector import BestModelSelector from ax.modelbridge.generation_node import GenerationNode from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy -from ax.modelbridge.model_spec import FactoryFunctionModelSpec, ModelSpec +from ax.modelbridge.model_spec import FactoryFunctionGeneratorSpec, GeneratorSpec from ax.modelbridge.registry import _encode_callables_as_references from ax.modelbridge.transforms.base import Transform from ax.modelbridge.transition_criterion import TransitionCriterion -from ax.models.torch.botorch_modular.model import BoTorchModel +from ax.models.torch.botorch_modular.model import BoTorchGenerator from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.models.winsorization_config import WinsorizationConfig from ax.storage.botorch_modular_registry import CLASS_TO_REGISTRY @@ -490,12 +490,12 @@ def transition_criterion_to_dict(criterion: TransitionCriterion) -> dict[str, An return properties -def model_spec_to_dict(model_spec: ModelSpec) -> dict[str, Any]: +def model_spec_to_dict(model_spec: GeneratorSpec) -> dict[str, Any]: """Convert Ax model spec to a dictionary.""" - if isinstance(model_spec, FactoryFunctionModelSpec): + if isinstance(model_spec, FactoryFunctionGeneratorSpec): raise NotImplementedError( f"JSON serialization not yet implemented for model spec: {model_spec}" - " because it leverages a factory function instead of `Models` registry." + " because it leverages a factory function instead of `Generators` registry." ) return { "__type": model_spec.__class__.__name__, @@ -528,7 +528,7 @@ def observation_features_to_dict(obs_features: ObservationFeatures) -> dict[str, } -def botorch_model_to_dict(model: BoTorchModel) -> dict[str, Any]: +def botorch_model_to_dict(model: BoTorchGenerator) -> dict[str, Any]: """Convert Ax model to a dictionary.""" return { "__type": model.__class__.__name__, diff --git a/ax/storage/json_store/registry.py b/ax/storage/json_store/registry.py index c89e0ca7cf2..149a3712dbb 100644 --- a/ax/storage/json_store/registry.py +++ b/ax/storage/json_store/registry.py @@ -80,14 +80,14 @@ ReductionCriterion, SingleDiagnosticBestModelSelector, ) -from ax.modelbridge.factory import Models +from ax.modelbridge.factory import Generators from ax.modelbridge.generation_node import GenerationNode, GenerationStep from ax.modelbridge.generation_node_input_constructors import ( InputConstructorPurpose, NodeInputConstructors, ) from ax.modelbridge.generation_strategy import GenerationStrategy -from ax.modelbridge.model_spec import ModelSpec +from ax.modelbridge.model_spec import GeneratorSpec from ax.modelbridge.registry import ModelRegistryBase from ax.modelbridge.transforms.base import Transform from ax.modelbridge.transition_criterion import ( @@ -102,7 +102,7 @@ TransitionCriterion, ) from ax.models.torch.botorch_modular.acquisition import Acquisition -from ax.models.torch.botorch_modular.model import BoTorchModel +from ax.models.torch.botorch_modular.model import BoTorchGenerator from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.models.torch.botorch_modular.utils import ModelConfig from ax.models.winsorization_config import WinsorizationConfig @@ -199,7 +199,7 @@ BenchmarkMapMetric: metric_to_dict, 
BenchmarkTimeVaryingMetric: metric_to_dict, BenchmarkMapUnavailableWhileRunningMetric: metric_to_dict, - BoTorchModel: botorch_model_to_dict, + BoTorchGenerator: botorch_model_to_dict, BraninMetric: metric_to_dict, BraninTimestampMapMetric: metric_to_dict, ChainedInputTransform: botorch_component_to_dict, @@ -229,7 +229,7 @@ MinimumTrialsInStatus: transition_criterion_to_dict, MinimumPreferenceOccurances: transition_criterion_to_dict, AuxiliaryExperimentCheck: transition_criterion_to_dict, - ModelSpec: model_spec_to_dict, + GeneratorSpec: model_spec_to_dict, MultiObjective: multi_objective_to_dict, MultiObjectiveOptimizationConfig: multi_objective_optimization_config_to_dict, MultiTypeExperiment: multi_type_experiment_to_dict, @@ -297,6 +297,7 @@ "AndEarlyStoppingStrategy": AndEarlyStoppingStrategy, "AutoTransitionAfterGen": AutoTransitionAfterGen, "AuxiliaryExperiment": AuxiliaryExperiment, + "AuxiliaryExperimentCheck": AuxiliaryExperimentCheck, "AuxiliaryExperimentPurpose": AuxiliaryExperimentPurpose, "Arm": Arm, "AggregatedBenchmarkResult": AggregatedBenchmarkResult, @@ -312,7 +313,8 @@ ), "BenchmarkResult": BenchmarkResult, "BenchmarkTrialMetadata": BenchmarkTrialMetadata, - "BoTorchModel": BoTorchModel, + "BoTorchGenerator": BoTorchGenerator, + "BoTorchModel": BoTorchGenerator, "BraninMetric": BraninMetric, "BraninTimestampMapMetric": BraninTimestampMapMetric, "ChainedInputTransform": ChainedInputTransform, @@ -332,6 +334,8 @@ "GenerationStep": GenerationStep, "GeneratorRun": GeneratorRun, "GeneratorRunStruct": GeneratorRunStruct, + "Generators": Generators, + "GeneratorSpec": GeneratorSpec, "Hartmann6Metric": Hartmann6Metric, "HierarchicalSearchSpace": HierarchicalSearchSpace, "ImprovementGlobalStoppingStrategy": ImprovementGlobalStoppingStrategy, @@ -352,11 +356,10 @@ "MinTrials": MinTrials, "MinimumTrialsInStatus": MinimumTrialsInStatus, "MinimumPreferenceOccurances": MinimumPreferenceOccurances, - "AuxiliaryExperimentCheck": AuxiliaryExperimentCheck, - "Models": Models, "ModelRegistryBase": ModelRegistryBase, "ModelConfig": ModelConfig, - "ModelSpec": ModelSpec, + "Models": Generators, + "ModelSpec": GeneratorSpec, "MultiObjective": MultiObjective, "MultiObjectiveOptimizationConfig": MultiObjectiveOptimizationConfig, "MultiTypeExperiment": MultiTypeExperiment, diff --git a/ax/storage/json_store/tests/test_json_store.py b/ax/storage/json_store/tests/test_json_store.py index eab276439c0..295ff949365 100644 --- a/ax/storage/json_store/tests/test_json_store.py +++ b/ax/storage/json_store/tests/test_json_store.py @@ -21,7 +21,7 @@ from ax.exceptions.storage import JSONDecodeError, JSONEncodeError from ax.modelbridge.generation_node import GenerationStep from ax.modelbridge.generation_strategy import GenerationStrategy -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel from ax.models.torch.botorch_modular.surrogate import SurrogateSpec from ax.models.torch.botorch_modular.utils import ModelConfig @@ -162,10 +162,10 @@ get_benchmark_map_unavailable_while_running_metric, ), ("BenchmarkResult", get_benchmark_result), - ("BoTorchModel", get_botorch_model), - ("BoTorchModel", get_botorch_model_with_default_acquisition_class), - ("BoTorchModel", get_botorch_model_with_surrogate_spec), - ("BoTorchModel", get_botorch_model_with_surrogate_specs), + ("BoTorchGenerator", get_botorch_model), + ("BoTorchGenerator", get_botorch_model_with_default_acquisition_class), + ("BoTorchGenerator", 
get_botorch_model_with_surrogate_spec), + ("BoTorchGenerator", get_botorch_model_with_surrogate_specs), ("BraninMetric", get_branin_metric), ("ChainedInputTransform", get_chained_input_transform), ("ChoiceParameter", get_choice_parameter), @@ -523,7 +523,7 @@ def test_DecodeGenerationStrategy(self) -> None: generation_strategy._unset_non_persistent_state_fields() self.assertEqual(generation_strategy, new_generation_strategy) self.assertGreater(len(new_generation_strategy._steps), 0) - self.assertIsInstance(new_generation_strategy._steps[0].model, Models) + self.assertIsInstance(new_generation_strategy._steps[0].model, Generators) # Model has not yet been initialized on this GS since it hasn't generated # anything yet. self.assertIsNone(new_generation_strategy.model) @@ -549,7 +549,7 @@ def test_DecodeGenerationStrategy(self) -> None: # well. generation_strategy._unset_non_persistent_state_fields() self.assertEqual(generation_strategy, new_generation_strategy) - self.assertIsInstance(new_generation_strategy._steps[0].model, Models) + self.assertIsInstance(new_generation_strategy._steps[0].model, Generators) # Check that we can encode and decode the generation strategy after # it has generated some trials and been updated with some data. @@ -572,7 +572,7 @@ def test_DecodeGenerationStrategy(self) -> None: # well. generation_strategy._unset_non_persistent_state_fields() self.assertEqual(generation_strategy, new_generation_strategy) - self.assertIsInstance(new_generation_strategy._steps[0].model, Models) + self.assertIsInstance(new_generation_strategy._steps[0].model, Generators) def test_EncodeDecodeNumpy(self) -> None: arr = np.array([[1, 2, 3], [4, 5, 6]]) @@ -747,7 +747,7 @@ def test_DecodeUnknownClassFromJson(self) -> None: class_from_json({"index": 0, "class": "unknown_path"}) def test_unregistered_model_not_supported_in_nodes(self) -> None: - """Support for callables within model kwargs on ModelSpecs stored on + """Support for callables within model kwargs on GeneratorSpecs stored on GenerationNodes is currently not supported. This is supported for GenerationSteps due to legacy compatibility. """ @@ -797,7 +797,7 @@ def test_generation_step_backwards_compatibility(self) -> None: # Test that we can load a generation step with fit_on_update. json = { "__type": "GenerationStep", - "model": {"__type": "Models", "name": "BOTORCH_MODULAR"}, + "model": {"__type": "Generators", "name": "BOTORCH_MODULAR"}, "num_trials": 5, "min_trials_observed": 0, "completion_criteria": [], @@ -970,11 +970,11 @@ def test_model_registry_backwards_compatibility(self) -> None: # Check for models with listed replacements. for name, replacement in _DEPRECATED_MODEL_TO_REPLACEMENT.items(): with self.assertLogs(logger="ax", level="ERROR"): - from_json = object_from_json({"__type": "Models", "name": name}) - self.assertEqual(from_json, Models[replacement]) + from_json = object_from_json({"__type": "Generators", "name": name}) + self.assertEqual(from_json, Generators[replacement]) # Check for non-deprecated models. from_json = object_from_json({"__type": "Models", "name": "BO_MIXED"}) - self.assertEqual(from_json, Models.BO_MIXED) + self.assertEqual(from_json, Generators.BO_MIXED) # Check for models with no replacement. 
with self.assertRaisesRegex(KeyError, "nonexistent"): object_from_json({"__type": "Models", "name": "nonexistent_model"}) diff --git a/ax/storage/sqa_store/tests/test_sqa_store.py b/ax/storage/sqa_store/tests/test_sqa_store.py index 8da1a6412b6..d98bf51c1e2 100644 --- a/ax/storage/sqa_store/tests/test_sqa_store.py +++ b/ax/storage/sqa_store/tests/test_sqa_store.py @@ -35,7 +35,7 @@ from ax.exceptions.storage import JSONDecodeError, SQADecodeError, SQAEncodeError from ax.metrics.branin import BraninMetric from ax.modelbridge.dispatch_utils import choose_generation_strategy -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.runners.synthetic import SyntheticRunner from ax.storage.metric_registry import CORE_METRIC_REGISTRY, register_metrics @@ -216,7 +216,7 @@ def test_SaveExperimentWithSurrogateAsModelKwarg(self) -> None: experiment = get_branin_experiment( with_batch=True, num_batch_trial=1, with_completed_batch=True ) - model = Models.BOTORCH_MODULAR( + model = Generators.BOTORCH_MODULAR( experiment=experiment, data=experiment.lookup_data(), surrogate=Surrogate(surrogate_spec=SurrogateSpec()), @@ -499,7 +499,7 @@ def test_ExperimentSaveAndLoadReducedState( # 3. Try case with model state and search space + opt.config on a # generator run in the experiment. - gr = Models.SOBOL(experiment=exp).gen(1) + gr = Generators.SOBOL(experiment=exp).gen(1) # Expecting model kwargs to have 6 fields (seed, deduplicate, init_position, # scramble, generated_points, fallback_to_sample_polytope) # and the rest of model-state info on generator run to have values too. @@ -564,7 +564,7 @@ def test_load_and_save_reduced_state_does_not_lose_abandoned_arms(self) -> None: def test_ExperimentSaveAndLoadGRWithOptConfig(self) -> None: exp = get_experiment_with_batch_trial(constrain_search_space=False) - gr = Models.SOBOL(experiment=exp).gen( + gr = Generators.SOBOL(experiment=exp).gen( n=1, optimization_config=exp.optimization_config ) exp.new_trial(generator_run=gr) @@ -715,7 +715,7 @@ def test_ExperimentSaveAndUpdateTrials(self) -> None: exp = get_branin_experiment_with_timestamp_map_metric() save_experiment(exp) - generator_run = Models.SOBOL(search_space=exp.search_space).gen(n=1) + generator_run = Generators.SOBOL(search_space=exp.search_space).gen(n=1) trial = exp.new_trial(generator_run=generator_run) exp.attach_data(trial.run().fetch_data()) save_or_update_trials( @@ -1568,7 +1568,7 @@ def test_EncodeDecodeGenerationStrategy(self) -> None: # well. 
generation_strategy._unset_non_persistent_state_fields() self.assertEqual(generation_strategy, new_generation_strategy) - self.assertIsInstance(new_generation_strategy._steps[0].model, Models) + self.assertIsInstance(new_generation_strategy._steps[0].model, Generators) self.assertEqual(len(new_generation_strategy._generator_runs), 2) self.assertEqual( none_throws(new_generation_strategy._experiment)._name, experiment._name @@ -1626,7 +1626,8 @@ def test_EncodeDecodeGenerationNodeGSWithAdvancedSettings(self) -> None: generation_strategy._unset_non_persistent_state_fields() self.assertEqual(generation_strategy, new_generation_strategy) self.assertIsInstance( - new_generation_strategy._nodes[0].model_spec_to_gen_from.model_enum, Models + new_generation_strategy._nodes[0].model_spec_to_gen_from.model_enum, + Generators, ) self.assertEqual(len(new_generation_strategy._generator_runs), 2) self.assertEqual( @@ -1683,7 +1684,8 @@ def test_EncodeDecodeGenerationNodeBasedGenerationStrategy(self) -> None: generation_strategy._unset_non_persistent_state_fields() self.assertEqual(generation_strategy, new_generation_strategy) self.assertIsInstance( - new_generation_strategy._nodes[0].model_spec_to_gen_from.model_enum, Models + new_generation_strategy._nodes[0].model_spec_to_gen_from.model_enum, + Generators, ) self.assertEqual(len(new_generation_strategy._generator_runs), 2) self.assertEqual( @@ -1731,7 +1733,7 @@ def test_EncodeDecodeGenerationStrategyReducedState(self) -> None: self.assertEqual(new_generation_strategy, generation_strategy) # Model should be successfully restored in generation strategy even with # the reduced state. - self.assertIsInstance(new_generation_strategy._steps[0].model, Models) + self.assertIsInstance(new_generation_strategy._steps[0].model, Generators) self.assertEqual(len(new_generation_strategy._generator_runs), 2) self.assertEqual( none_throws(new_generation_strategy._experiment)._name, experiment._name @@ -1791,7 +1793,7 @@ def test_EncodeDecodeGenerationStrategyReducedStateLoadExperiment(self) -> None: self.assertEqual(new_generation_strategy, generation_strategy) # Model should be successfully restored in generation strategy even with # the reduced state. - self.assertIsInstance(new_generation_strategy._steps[0].model, Models) + self.assertIsInstance(new_generation_strategy._steps[0].model, Generators) self.assertEqual(len(new_generation_strategy._generator_runs), 2) self.assertEqual( none_throws(new_generation_strategy._experiment)._name, experiment._name diff --git a/ax/utils/common/equality.py b/ax/utils/common/equality.py index ce4956117aa..3f8f5facdab 100644 --- a/ax/utils/common/equality.py +++ b/ax/utils/common/equality.py @@ -214,7 +214,7 @@ def object_attribute_dicts_find_unequal_fields( elif field == "_db_id": equal = skip_db_id_check or one_val == other_val elif field == "_model": - # TODO[T52643706]: replace with per-`ModelBridge` method like + # TODO[T52643706]: replace with per-`Adapter` method like # `equivalent_models`, to compare models more meaningfully. 
if not hasattr(one_val, "model") or not hasattr(other_val, "model"): equal = not hasattr(other_val, "model") and not hasattr( diff --git a/ax/utils/sensitivity/sobol_measures.py b/ax/utils/sensitivity/sobol_measures.py index 65126cfe4ff..aa22a6707dc 100644 --- a/ax/utils/sensitivity/sobol_measures.py +++ b/ax/utils/sensitivity/sobol_measures.py @@ -12,9 +12,11 @@ import numpy.typing as npt import torch -from ax.modelbridge.torch import TorchModelBridge -from ax.models.torch.botorch import BotorchModel -from ax.models.torch.botorch_modular.model import BoTorchModel as ModularBoTorchModel +from ax.modelbridge.torch import TorchAdapter +from ax.models.torch.botorch import BotorchGenerator +from ax.models.torch.botorch_modular.model import ( +    BoTorchGenerator as ModularBoTorchGenerator, +) from ax.utils.sensitivity.derivative_measures import ( compute_derivatives_from_model_list, sample_discrete_parameters, @@ -834,14 +836,14 @@ def compute_sobol_indices_from_model_list( def ax_parameter_sens( -    model_bridge: TorchModelBridge, +    model_bridge: TorchAdapter, metrics: list[str] | None = None, order: str = "first", signed: bool = True, **sobol_kwargs: Any, ) -> dict[str, dict[str, npt.NDArray]]: """ -    Compute sensitivity for all metrics on an TorchModelBridge. +    Compute sensitivity for all metrics on a TorchAdapter. Sobol measures are always positive regardless of the direction in which the parameter influences f. If `signed` is set to True, then the Sobol measure for each will have values close to 0. Args: -        model_bridge: A ModelBridge object with models that were fit. +        model_bridge: An Adapter object with models that were fit. metrics: The names of the metrics and outcomes for which to compute sensitivities. This should preferably be metrics with a good model fit. Defaults to model_bridge.outcomes. @@ -938,32 +940,35 @@ def ax_parameter_sens( def _get_torch_model( -    model_bridge: TorchModelBridge, -) -> BotorchModel | ModularBoTorchModel: -    """Returns the TorchModel of the model_bridge, if it is a type that stores -    SearchSpaceDigest during model fitting. At this point, this is BotorchModel, and -    ModularBoTorchModel. +    model_bridge: TorchAdapter, +) -> BotorchGenerator | ModularBoTorchGenerator: +    """Returns the TorchGenerator of the model_bridge, if it is a type that stores +    SearchSpaceDigest during model fitting. At this point, this is BotorchGenerator, and +    ModularBoTorchGenerator. """ -    if not isinstance(model_bridge, TorchModelBridge): +    if not isinstance(model_bridge, TorchAdapter): raise NotImplementedError( -            f"{type(model_bridge)=}, but only TorchModelBridge is supported." +            f"{type(model_bridge)=}, but only TorchAdapter is supported." ) -    model = model_bridge.model  # should be of type TorchModel -    if not (isinstance(model, BotorchModel) or isinstance(model, ModularBoTorchModel)): +    model = model_bridge.model  # should be of type TorchGenerator +    if not ( +        isinstance(model, BotorchGenerator) +        or isinstance(model, ModularBoTorchGenerator) +    ): raise NotImplementedError( f"{type(model_bridge.model)=}, but only " -            "Union[BotorchModel, ModularBoTorchModel] is supported." +            "Union[BotorchGenerator, ModularBoTorchGenerator] is supported."
) return model def _get_model_per_metric( - model: BotorchModel | ModularBoTorchModel, metrics: list[str] + model: BotorchGenerator | ModularBoTorchGenerator, metrics: list[str] ) -> list[Model]: - """For a given TorchModel model, returns a list of botorch.models.model.Model + """For a given TorchGenerator model, returns a list of botorch.models.model.Model objects corresponding to - and in the same order as - the given metrics. """ - if isinstance(model, BotorchModel): + if isinstance(model, BotorchGenerator): # guaranteed not to be None after accessing search_space_digest gp_model = model.model model_idx = [model.metric_names.index(m) for m in metrics] @@ -975,14 +980,14 @@ def _get_model_per_metric( "but only ModelList is supported." ) return [gp_model.models[i] for i in model_idx] - else: # isinstance(model, ModularBoTorchModel): + else: # isinstance(model, ModularBoTorchGenerator): surrogate = model.surrogate outcomes = surrogate.outcomes model_list = [] for m in metrics: # for each metric, find a corresponding surrogate i = outcomes.index(m) metric_model = surrogate.model - # since model is a ModularBoTorchModel, metric_model will be a + # since model is a ModularBoTorchGenerator, metric_model will be a # `botorch.models.model.Model` object, which have the `num_outputs` # property and `subset_outputs` method. if metric_model.num_outputs > 1: # subset to relevant output diff --git a/ax/utils/sensitivity/tests/test_sensitivity.py b/ax/utils/sensitivity/tests/test_sensitivity.py index e6c3baf7e80..435e4f28823 100644 --- a/ax/utils/sensitivity/tests/test_sensitivity.py +++ b/ax/utils/sensitivity/tests/test_sensitivity.py @@ -12,10 +12,10 @@ from unittest.mock import patch, PropertyMock import torch -from ax.modelbridge.base import ModelBridge -from ax.modelbridge.registry import Models -from ax.modelbridge.torch import TorchModelBridge -from ax.models.torch.botorch import BotorchModel +from ax.modelbridge.base import Adapter +from ax.modelbridge.registry import Generators +from ax.modelbridge.torch import TorchAdapter +from ax.models.torch.botorch import BotorchGenerator from ax.utils.common.random import set_rng_seed from ax.utils.common.testutils import TestCase from ax.utils.sensitivity.derivative_gp import posterior_derivative @@ -44,21 +44,21 @@ @mock_botorch_optimize -def get_modelbridge(modular: bool = False, saasbo: bool = False) -> ModelBridge: +def get_modelbridge(modular: bool = False, saasbo: bool = False) -> Adapter: exp = get_branin_experiment(with_batch=True) exp.trials[0].run() if modular: - return Models.BOTORCH_MODULAR( + return Generators.BOTORCH_MODULAR( experiment=exp, data=exp.fetch_data(), ) if saasbo: - return Models.SAASBO( + return Generators.SAASBO( experiment=exp, data=exp.fetch_data(), ) else: - return Models.LEGACY_BOTORCH(experiment=exp, data=exp.fetch_data()) + return Generators.LEGACY_BOTORCH(experiment=exp, data=exp.fetch_data()) class SensitivityAnalysisTest(TestCase): @@ -241,12 +241,12 @@ def test_SobolGPMean_SAASBO(self) -> None: ) # testing ax sensitivity utils - # model_bridge = cast(TorchModelBridge, get_modelbridge()) + # model_bridge = cast(TorchAdapter, get_modelbridge()) for modular in [False, True]: - model_bridge = cast(TorchModelBridge, get_modelbridge(modular=modular)) + model_bridge = cast(TorchAdapter, get_modelbridge(modular=modular)) with self.assertRaisesRegex( NotImplementedError, - "but only TorchModelBridge is supported", + "but only TorchAdapter is supported", ): # pyre-ignore ax_parameter_sens(1, model_bridge.outcomes) @@ -254,11 
+254,14 @@ def test_SobolGPMean_SAASBO(self) -> None: with patch.object(model_bridge, "model", return_value=None): with self.assertRaisesRegex( NotImplementedError, - r"but only Union\[BotorchModel, ModularBoTorchModel\] is supported", + ( + r"but only Union\[BotorchGenerator, ModularBoTorchGenerator\] " + r"is supported" + ), ): ax_parameter_sens(model_bridge, model_bridge.outcomes) - torch_model = cast(BotorchModel, model_bridge.model) + torch_model = cast(BotorchGenerator, model_bridge.model) if not modular: with self.assertRaisesRegex( NotImplementedError, @@ -273,7 +276,7 @@ def test_SobolGPMean_SAASBO(self) -> None: mock.return_value = 2 ax_parameter_sens(model_bridge, model_bridge.outcomes) - # since only ModelList is supported for BotorchModel: + # since only ModelList is supported for BotorchGenerator: gpytorch_model = ModelListGP(cast(GPyTorchModel, torch_model.model)) torch_model.model = gpytorch_model diff --git a/ax/utils/testing/benchmark_stubs.py b/ax/utils/testing/benchmark_stubs.py index e50dcfa0e3c..faf5a0053ae 100644 --- a/ax/utils/testing/benchmark_stubs.py +++ b/ax/utils/testing/benchmark_stubs.py @@ -40,8 +40,8 @@ from ax.early_stopping.strategies.base import BaseEarlyStoppingStrategy from ax.modelbridge.external_generation_node import ExternalGenerationNode from ax.modelbridge.generation_strategy import GenerationStrategy -from ax.modelbridge.torch import TorchModelBridge -from ax.models.torch.botorch_modular.model import BoTorchModel +from ax.modelbridge.torch import TorchAdapter +from ax.models.torch.botorch_modular.model import BoTorchGenerator from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.utils.testing.core_stubs import ( get_branin_experiment, @@ -92,10 +92,10 @@ def get_multi_objective_benchmark_problem( def get_soo_surrogate_test_function(lazy: bool = True) -> SurrogateTestFunction: experiment = get_branin_experiment(with_completed_trial=True) - surrogate = TorchModelBridge( + surrogate = TorchAdapter( experiment=experiment, search_space=experiment.search_space, - model=BoTorchModel(surrogate=Surrogate(botorch_model_class=SingleTaskGP)), + model=BoTorchGenerator(surrogate=Surrogate(botorch_model_class=SingleTaskGP)), data=experiment.lookup_data(), transforms=[], ) @@ -134,10 +134,10 @@ def get_soo_surrogate() -> BenchmarkProblem: def get_moo_surrogate() -> BenchmarkProblem: experiment = get_branin_experiment_with_multi_objective(with_completed_trial=True) - surrogate = TorchModelBridge( + surrogate = TorchAdapter( experiment=experiment, search_space=experiment.search_space, - model=BoTorchModel(surrogate=Surrogate(botorch_model_class=SingleTaskGP)), + model=BoTorchGenerator(surrogate=Surrogate(botorch_model_class=SingleTaskGP)), data=experiment.lookup_data(), transforms=[], ) diff --git a/ax/utils/testing/core_stubs.py b/ax/utils/testing/core_stubs.py index 85f9201c354..922693619d1 100644 --- a/ax/utils/testing/core_stubs.py +++ b/ax/utils/testing/core_stubs.py @@ -84,13 +84,13 @@ from ax.metrics.branin_map import BraninTimestampMapMetric from ax.metrics.factorial import FactorialMetric from ax.metrics.hartmann6 import Hartmann6Metric -from ax.modelbridge.factory import Cont_X_trans, get_factorial, get_sobol, Models +from ax.modelbridge.factory import Cont_X_trans, Generators, get_factorial, get_sobol from ax.modelbridge.generation_node_input_constructors import ( InputConstructorPurpose, NodeInputConstructors, ) from ax.modelbridge.generation_strategy import GenerationNode, GenerationStrategy -from ax.modelbridge.model_spec import 
ModelSpec +from ax.modelbridge.model_spec import GeneratorSpec from ax.modelbridge.transition_criterion import ( MaxGenerationParallelism, MinTrials, @@ -98,7 +98,7 @@ ) from ax.models.torch.botorch_modular.acquisition import Acquisition from ax.models.torch.botorch_modular.kernels import ScaleMaternKernel -from ax.models.torch.botorch_modular.model import BoTorchModel +from ax.models.torch.botorch_modular.model import BoTorchGenerator from ax.models.torch.botorch_modular.sebo import SEBOAcquisition from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec from ax.models.winsorization_config import WinsorizationConfig @@ -2238,30 +2238,30 @@ def get_model_predictions_per_arm() -> dict[str, TModelPredictArm]: ############################## -def get_botorch_model() -> BoTorchModel: - return BoTorchModel( +def get_botorch_model() -> BoTorchGenerator: + return BoTorchGenerator( surrogate=get_surrogate(), acquisition_class=get_acquisition_type() ) -def get_botorch_model_with_default_acquisition_class() -> BoTorchModel: - return BoTorchModel( +def get_botorch_model_with_default_acquisition_class() -> BoTorchGenerator: + return BoTorchGenerator( surrogate=get_surrogate(), acquisition_class=Acquisition, botorch_acqf_class=get_acquisition_function_type(), ) -def get_botorch_model_with_surrogate_specs() -> BoTorchModel: - return BoTorchModel( +def get_botorch_model_with_surrogate_specs() -> BoTorchGenerator: + return BoTorchGenerator( surrogate_specs={ "name": SurrogateSpec(botorch_model_kwargs={"some_option": "some_value"}) } ) -def get_botorch_model_with_surrogate_spec() -> BoTorchModel: - return BoTorchModel( +def get_botorch_model_with_surrogate_spec() -> BoTorchGenerator: + return BoTorchGenerator( surrogate_spec=SurrogateSpec(botorch_model_kwargs={"some_option": "some_value"}) ) @@ -2449,13 +2449,13 @@ def get_online_sobol_mbm_generation_strategy( ], ), ] - sobol_model_spec = ModelSpec( - model_enum=Models.SOBOL, + sobol_model_spec = GeneratorSpec( + model_enum=Generators.SOBOL, model_kwargs=step_model_kwargs, model_gen_kwargs={}, ) - mbm_model_spec = ModelSpec( - model_enum=Models.BOTORCH_MODULAR, + mbm_model_spec = GeneratorSpec( + model_enum=Generators.BOTORCH_MODULAR, model_kwargs=step_model_kwargs, model_gen_kwargs={}, ) diff --git a/ax/utils/testing/modeling_stubs.py b/ax/utils/testing/modeling_stubs.py index 5689c5030ce..d117afa0265 100644 --- a/ax/utils/testing/modeling_stubs.py +++ b/ax/utils/testing/modeling_stubs.py @@ -17,7 +17,7 @@ from ax.core.parameter import FixedParameter, RangeParameter from ax.core.search_space import SearchSpace from ax.exceptions.core import UserInputError -from ax.modelbridge.base import ModelBridge +from ax.modelbridge.base import Adapter from ax.modelbridge.best_model_selector import ( ReductionCriterion, SingleDiagnosticBestModelSelector, @@ -32,8 +32,8 @@ NodeInputConstructors, ) from ax.modelbridge.generation_strategy import GenerationStep, GenerationStrategy -from ax.modelbridge.model_spec import ModelSpec -from ax.modelbridge.registry import Models +from ax.modelbridge.model_spec import GeneratorSpec +from ax.modelbridge.registry import Generators from ax.modelbridge.transforms.base import Transform from ax.modelbridge.transforms.int_to_float import IntToFloat from ax.modelbridge.transforms.transform_to_new_sq import TransformToNewSQ @@ -189,8 +189,8 @@ def get_generation_strategy( ) -> GenerationStrategy: if with_generation_nodes: gs = sobol_gpei_generation_node_gs() - gs._nodes[0]._model_spec_to_gen_from = ModelSpec( - 
model_enum=Models.SOBOL, + gs._nodes[0]._model_spec_to_gen_from = GeneratorSpec( + model_enum=Generators.SOBOL, model_kwargs={"init_position": 3}, model_gen_kwargs={"some_gen_kwarg": "some_value"}, ) @@ -235,7 +235,7 @@ def sobol_gpei_generation_node_gs( """Returns a basic SOBOL+MBM GS using GenerationNodes for testing. Args: - with_model_selection: If True, will add a second ModelSpec in the MBM node. + with_model_selection: If True, will add a second GeneratorSpec in the MBM node. This can be used for testing model selection. """ if sum([with_auto_transition, with_unlimited_gen_mbm, with_is_SOO_transition]) > 1: @@ -298,14 +298,14 @@ def sobol_gpei_generation_node_gs( auto_mbm_criterion = [AutoTransitionAfterGen(transition_to="MBM_node")] is_SOO_mbm_criterion = [IsSingleObjective(transition_to="MBM_node")] step_model_kwargs = {"silently_filter_kwargs": True} - sobol_model_spec = ModelSpec( - model_enum=Models.SOBOL, + sobol_model_spec = GeneratorSpec( + model_enum=Generators.SOBOL, model_kwargs=step_model_kwargs, model_gen_kwargs={}, ) mbm_model_specs = [ - ModelSpec( - model_enum=Models.BOTORCH_MODULAR, + GeneratorSpec( + model_enum=Generators.BOTORCH_MODULAR, model_kwargs=step_model_kwargs, model_gen_kwargs={}, ) @@ -317,7 +317,7 @@ def sobol_gpei_generation_node_gs( ) if with_model_selection: # This is just MBM with different transforms. - mbm_model_specs.append(ModelSpec(model_enum=Models.BO_MIXED)) + mbm_model_specs.append(GeneratorSpec(model_enum=Generators.BO_MIXED)) best_model_selector = SingleDiagnosticBestModelSelector( diagnostic=FISHER_EXACT_TEST_P, metric_aggregation=ReductionCriterion.MEAN, @@ -403,7 +403,7 @@ def get_sobol_MBM_MTGP_gs() -> GenerationStrategy: nodes=[ GenerationNode( node_name="Sobol", - model_specs=[ModelSpec(model_enum=Models.SOBOL)], + model_specs=[GeneratorSpec(model_enum=Generators.SOBOL)], transition_criteria=[ MinTrials( threshold=1, @@ -414,8 +414,8 @@ def get_sobol_MBM_MTGP_gs() -> GenerationStrategy: GenerationNode( node_name="MBM", model_specs=[ - ModelSpec( - model_enum=Models.BOTORCH_MODULAR, + GeneratorSpec( + model_enum=Generators.BOTORCH_MODULAR, ), ], transition_criteria=[ @@ -433,8 +433,8 @@ def get_sobol_MBM_MTGP_gs() -> GenerationStrategy: GenerationNode( node_name="MTGP", model_specs=[ - ModelSpec( - model_enum=Models.ST_MTGP, + GeneratorSpec( + model_enum=Generators.ST_MTGP, ), ], ), @@ -471,7 +471,7 @@ def get_legacy_list_surrogate_generation_step_as_dict() -> dict[str, Any]: # before new multi-Surrogate Model and new Surrogate diffs D42013742 return { "__type": "GenerationStep", - "model": {"__type": "Models", "name": "BOTORCH_MODULAR"}, + "model": {"__type": "Generators", "name": "BOTORCH_MODULAR"}, "num_trials": -1, "min_trials_observed": 0, "completion_criteria": [], @@ -558,7 +558,7 @@ def get_legacy_list_surrogate_generation_step_as_dict() -> dict[str, Any]: def get_surrogate_generation_step() -> GenerationStep: return GenerationStep( - model=Models.BOTORCH_MODULAR, + model=Generators.BOTORCH_MODULAR, num_trials=-1, max_parallelism=1, model_kwargs={ @@ -691,7 +691,7 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace: def transform_optimization_config( self, optimization_config: OptimizationConfig, - modelbridge: ModelBridge | None, + modelbridge: Adapter | None, fixed_features: ObservationFeatures | None, ) -> OptimizationConfig: return ( # pyre-ignore[7]: pyre is right, this is a hack for testing. 
@@ -750,7 +750,7 @@ def transform_search_space(self, search_space: SearchSpace) -> SearchSpace: def transform_optimization_config( self, optimization_config: OptimizationConfig, - modelbridge: ModelBridge | None, + modelbridge: Adapter | None, fixed_features: ObservationFeatures | None, ) -> OptimizationConfig: return ( diff --git a/ax/utils/testing/tests/test_mock.py b/ax/utils/testing/tests/test_mock.py index 2c84e03375e..5766cb16812 100644 --- a/ax/utils/testing/tests/test_mock.py +++ b/ax/utils/testing/tests/test_mock.py @@ -9,7 +9,7 @@ from unittest.mock import patch import torch -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators from ax.modelbridge.transforms.choice_encode import OrderedChoiceToIntegerRange from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import get_branin_experiment @@ -42,7 +42,7 @@ def test_fully_bayesian_mocks(self) -> None: experiment = get_branin_experiment(with_completed_batch=True) with patch("botorch.fit.MCMC", wraps=MCMC) as mock_mcmc: with mock_botorch_optimize_context_manager(): - Models.SAASBO(experiment=experiment, data=experiment.lookup_data()) + Generators.SAASBO(experiment=experiment, data=experiment.lookup_data()) mock_mcmc.assert_called_once() kwargs = mock_mcmc.call_args.kwargs self.assertEqual(kwargs["num_samples"], 16) @@ -57,7 +57,7 @@ def test_mixed_optimizer_mocks(self) -> None: wraps=generate_starting_points, ) as mock_gen: with mock_botorch_optimize_context_manager(): - Models.BOTORCH_MODULAR( + Generators.BOTORCH_MODULAR( experiment=experiment, data=experiment.lookup_data(), transforms=[OrderedChoiceToIntegerRange], diff --git a/ax/utils/testing/tests/test_utils.py b/ax/utils/testing/tests/test_utils.py index 4f7f7ee4a60..edf936545ac 100644 --- a/ax/utils/testing/tests/test_utils.py +++ b/ax/utils/testing/tests/test_utils.py @@ -9,8 +9,8 @@ import numpy as np import torch from ax.modelbridge.generation_strategy import GenerationNode, GenerationStrategy -from ax.modelbridge.model_spec import ModelSpec -from ax.modelbridge.registry import Models +from ax.modelbridge.model_spec import GeneratorSpec +from ax.modelbridge.registry import Generators from ax.utils.common.testutils import TestCase from ax.utils.testing.core_stubs import get_experiment_with_observations from ax.utils.testing.mock import mock_botorch_optimize @@ -64,8 +64,8 @@ def test_run_trials_with_gs(self) -> None: GenerationNode( node_name="MBM", model_specs=[ - ModelSpec( - model_enum=Models.BOTORCH_MODULAR, + GeneratorSpec( + model_enum=Generators.BOTORCH_MODULAR, ) ], ) diff --git a/docs/api.md b/docs/api.md index 973408c88a9..d5e3a79e34c 100644 --- a/docs/api.md +++ b/docs/api.md @@ -158,7 +158,7 @@ exp = Experiment( runner=MockRunner(), ) -sobol = Models.SOBOL(exp.search_space) +sobol = Generators.SOBOL(exp.search_space) for i in range(5): trial = exp.new_trial(generator_run=sobol.gen(1)) trial.run() @@ -166,7 +166,7 @@ for i in range(5): best_arm = None for i in range(15): - gpei = Models.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data()) + gpei = Generators.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data()) generator_run = gpei.gen(1) best_arm, _ = generator_run.best_arm_predictions trial = exp.new_trial(generator_run=generator_run) diff --git a/docs/glossary.md b/docs/glossary.md index 3cfa7b6eeea..39a451021a7 100644 --- a/docs/glossary.md +++ b/docs/glossary.md @@ -45,7 +45,7 @@ Algorithm that can be used to generate new points in a [search space](glossary.m ### Model bridge -Adapter for 
interactions with a [model](glossary.md#model) within the Ax ecosystem. [`[ModelBridge]`](https://ax.readthedocs.io/en/latest/modelbridge.html) +Adapter for interactions with a [model](glossary.md#model) within the Ax ecosystem. [`[Adapter]`](https://ax.readthedocs.io/en/latest/modelbridge.html) ### Objective diff --git a/docs/models.md b/docs/models.md index 800d88eec7e..cef513f978e 100644 --- a/docs/models.md +++ b/docs/models.md @@ -1,14 +1,14 @@ --- id: models -title: Models +title: Generators --- ## Using models in Ax -In the optimization algorithms implemented by Ax, models predict the outcomes of metrics within an experiment evaluated at a parameterization, and are used to predict metrics or suggest new parameterizations for trials. Models in Ax are created using factory functions from the [`ax.modelbridge.factory`](https://ax.readthedocs.io/en/latest/modelbridge.html#module-ax.modelbridge.factory). All of these models share a common API with [`predict()`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.ModelBridge.predict) to make predictions at new points and [`gen()`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.ModelBridge.gen) to generate new candidates to be tested. There are a variety of models available in the factory; here we describe the usage patterns for the primary model types and show how the various Ax utilities can be used with models. +In the optimization algorithms implemented by Ax, models predict the outcomes of metrics within an experiment evaluated at a parameterization, and are used to predict metrics or suggest new parameterizations for trials. Generators in Ax are created using factory functions from the [`ax.modelbridge.factory`](https://ax.readthedocs.io/en/latest/modelbridge.html#module-ax.modelbridge.factory). All of these models share a common API with [`predict()`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.Adapter.predict) to make predictions at new points and [`gen()`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.Adapter.gen) to generate new candidates to be tested. There are a variety of models available in the factory; here we describe the usage patterns for the primary model types and show how the various Ax utilities can be used with models. #### Sobol sequence -The [`get_sobol`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.factory.get_sobol) function is used to construct a model that produces a quasirandom Sobol sequence when[`gen`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.ModelBridge.gen) is called. This code generates a scrambled Sobol sequence of 10 points: +The [`get_sobol`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.factory.get_sobol) function is used to construct a model that produces a quasirandom Sobol sequence when[`gen`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.Adapter.gen) is called. This code generates a scrambled Sobol sequence of 10 points: ```python from ax.modelbridge.factory import get_sobol @@ -17,19 +17,19 @@ m = get_sobol(search_space) gr = m.gen(n=10) ``` -The output of [`gen`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.ModelBridge.gen) is a [`GeneratorRun`](https://ax.readthedocs.io/en/latest/core.html#ax.core.generator_run.GeneratorRun) object that contains the generated points, along with metadata about the generation process. 
The generated arms can be accessed at [`GeneratorRun.arms`](https://ax.readthedocs.io/en/latest/core.html#ax.core.generator_run.GeneratorRun.arms). +The output of [`gen`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.Adapter.gen) is a [`GeneratorRun`](https://ax.readthedocs.io/en/latest/core.html#ax.core.generator_run.GeneratorRun) object that contains the generated points, along with metadata about the generation process. The generated arms can be accessed at [`GeneratorRun.arms`](https://ax.readthedocs.io/en/latest/core.html#ax.core.generator_run.GeneratorRun.arms). Additional arguments can be passed to [`get_sobol`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.factory.get_sobol) such as `scramble=False` to disable scrambling, and `seed` to set a seed (see [model API](https://ax.readthedocs.io/en/latest/models.html#ax.models.random.sobol.SobolGenerator)). -Sobol sequences are typically used to select initialization points, and this model does not implement [`predict`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.ModelBridge.predict). It can be used on search spaces with any combination of discrete and continuous parameters. +Sobol sequences are typically used to select initialization points, and this model does not implement [`predict`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.Adapter.predict). It can be used on search spaces with any combination of discrete and continuous parameters. #### Gaussian Process with EI -Gaussian Processes (GPs) are used for [Bayesian Optimization](bayesopt.md) in Ax, the [`Models.BOTORCH_MODULAR`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.registry.Models) registry entry constructs a modular BoTorch model that fits a GP to the data, and uses qLogNEI (or qLogNEHVI for MOO) acquisition function to generate new points on calls to [`gen`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.ModelBridge.gen). This code fits a GP and generates a batch of 5 points which maximizes EI: +Gaussian Processes (GPs) are used for [Bayesian Optimization](bayesopt.md) in Ax, the [`Generators.BOTORCH_MODULAR`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.registry.Generators) registry entry constructs a modular BoTorch model that fits a GP to the data, and uses qLogNEI (or qLogNEHVI for MOO) acquisition function to generate new points on calls to [`gen`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.Adapter.gen). This code fits a GP and generates a batch of 5 points which maximizes EI: ```Python -from ax.modelbridge.registry import Models +from ax.modelbridge.registry import Generators -m = Models.BOTORCH_MODULAR(experiment=experiment, data=data) +m = Generators.BOTORCH_MODULAR(experiment=experiment, data=data) gr = m.gen(n=5, optimization_config=optimization_config) ``` @@ -45,9 +45,9 @@ obs_feats = [ f, cov = m.predict(obs_feats) ``` -The output of [`predict`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.ModelBridge.predict) is the mean estimate of each metric and the covariance (across metrics) for each point. +The output of [`predict`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.Adapter.predict) is the mean estimate of each metric and the covariance (across metrics) for each point. 
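To make the structure of that output concrete, here is a minimal sketch of reading the prediction for one metric; the fitted adapter `m` and the `obs_feats` list are assumed from the example above, and `"branin"` stands in for whatever metric name the experiment actually uses:

```python
# Minimal sketch (assumed names): `m` is a fitted Adapter, `obs_feats` the list of
# ObservationFeatures from the example above, and "branin" a metric in the experiment.
f, cov = m.predict(obs_feats)

# `f` maps each metric name to a list of posterior means, one entry per point.
mean_at_first_point = f["branin"][0]

# `cov` is nested as cov[metric_a][metric_b][i]; the diagonal entry is the
# predictive variance of that metric at point i.
variance_at_first_point = cov["branin"]["branin"][0]
print(mean_at_first_point, variance_at_first_point)
```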
-All Ax models that implement [`predict`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.ModelBridge.predict) can be used with the built-in plotting utilities, which can produce plots of model predictions on 1-d or 2-d slices of the parameter space: +All Ax models that implement [`predict`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.Adapter.predict) can be used with the built-in plotting utilities, which can produce plots of model predictions on 1-d or 2-d slices of the parameter space: ```python from ax.plot.slice import plot_slice @@ -96,7 +96,7 @@ render(interact_cross_validation(cv))
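For reference, a minimal sketch of the cross-validation flow behind the `interact_cross_validation` call shown in the hunk above; the fitted adapter `m` from the earlier examples is assumed:

```python
from ax.modelbridge.cross_validation import cross_validate
from ax.plot.diagnostic import interact_cross_validation
from ax.utils.notebook.plotting import render

# Leave-one-out cross-validation on the fitted adapter, then render the
# interactive observed-vs-predicted plot discussed below.
cv = cross_validate(m)
render(interact_cross_validation(cv))
```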
-If the model fits the data well, the values will lie along the diagonal. Poor GP fits tend to produce cross validation plots that are flat with high predictive uncertainty - such fits are unlikely to produce good candidates in [`gen`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.ModelBridge.gen). +If the model fits the data well, the values will lie along the diagonal. Poor GP fits tend to produce cross validation plots that are flat with high predictive uncertainty - such fits are unlikely to produce good candidates in [`gen`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.Adapter.gen). By default, this model will apply a number of transformations to the feature space, such as one-hot encoding of [`ChoiceParameters`](https://ax.readthedocs.io/en/latest/core.html#ax.core.parameter.ChoiceParameter) and log transformation of [`RangeParameters`](https://ax.readthedocs.io/en/latest/core.html#ax.core.parameter.RangeParameter) which have `log_scale` set to `True`. Transforms are also applied to the observed outcomes, such as standardizing the data for each metric. See [the section below on Transforms](/docs/models#transforms) for a description of the default transforms, and how new transforms can be implemented and included. @@ -127,7 +127,7 @@ gr = m.gen(n=10, optimization_config=optimization_config) The arms and their corresponding weights can be accessed as `gr.arm_weights`. -As with the GP, we can use [`predict`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.ModelBridge.predict) to evaluate the model at points of our choosing. However, because this is a purely in-sample model, those points should correspond to arms that were in the data. The model prediction will return the estimate at that point after applying the empirical Bayes shrinkage: +As with the GP, we can use [`predict`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.Adapter.predict) to evaluate the model at points of our choosing. However, because this is a purely in-sample model, those points should correspond to arms that were in the data. The model prediction will return the estimate at that point after applying the empirical Bayes shrinkage: ```python f, cov = m.predict([ObservationFeatures(parameters={'x1': 3.14, 'x2': 2.72})]) @@ -158,25 +158,25 @@ Like the Sobol sequence, the factorial model is only used to generate points and ## Deeper dive: organization of the modeling stack -Ax uses a bridge design to provide a unified interface for models, while still allowing for modularity in how different types of models are implemented. The modeling stack consists of two layers: the [`ModelBridge`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.ModelBridge) and the Model. +Ax uses a bridge design to provide a unified interface for models, while still allowing for modularity in how different types of models are implemented. The modeling stack consists of two layers: the [`Adapter`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.Adapter) and the Model. 
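To make the two layers concrete, here is a rough sketch of wiring them together by hand; it mirrors the construction used in `ax/utils/testing/benchmark_stubs.py` elsewhere in this diff, and assumes an `experiment` that already has data attached:

```python
from ax.modelbridge.torch import TorchAdapter
from ax.models.torch.botorch_modular.model import BoTorchGenerator

# Adapter layer: consumes Ax objects (Experiment, SearchSpace, Data) and applies
# transforms. Generator layer: works on plain tensors. Registry entries such as
# Generators.BOTORCH_MODULAR assemble this pairing (plus default transforms) for you.
adapter = TorchAdapter(
    experiment=experiment,
    search_space=experiment.search_space,
    data=experiment.lookup_data(),
    model=BoTorchGenerator(),
    transforms=[],  # the registry supplies a default transform list instead
)
```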
-The [`ModelBridge`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.ModelBridge) is the object that is directly used in Ax: model factories return [`ModelBridge`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.ModelBridge) objects, and plotting and cross validation tools operate on a [`ModelBridge`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.ModelBridge). The [`ModelBridge`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.ModelBridge) defines a unified API for all of the models used in Ax via methods like [`predict`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.ModelBridge.predict) and [`gen`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.ModelBridge.gen). Internally, it is responsible for transforming Ax objects like [`Arm`](https://ax.readthedocs.io/en/latest/core.html#ax.core.arm.Arm) and [`Data`](https://ax.readthedocs.io/en/latest/core.html#ax.core.data.Data) into objects which are then consumed downstream by a Model. +The [`Adapter`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.Adapter) is the object that is directly used in Ax: model factories return [`Adapter`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.Adapter) objects, and plotting and cross validation tools operate on a [`Adapter`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.Adapter). The [`Adapter`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.Adapter) defines a unified API for all of the models used in Ax via methods like [`predict`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.Adapter.predict) and [`gen`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.Adapter.gen). Internally, it is responsible for transforming Ax objects like [`Arm`](https://ax.readthedocs.io/en/latest/core.html#ax.core.arm.Arm) and [`Data`](https://ax.readthedocs.io/en/latest/core.html#ax.core.data.Data) into objects which are then consumed downstream by a Model. -Model objects are only used in Ax via a [`ModelBridge`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.ModelBridge). Each Model object defines an API which does not use Ax objects, allowing for modularity of different model types and making it easy to implement new models. For example, the TorchModel defines an API for a model that operates on torch tensors. There is a 1-to-1 link between [`ModelBridge`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.ModelBridge) objects and Model objects. For instance, the TorchModelBridge takes in Ax objects, converts them to torch tensors, and sends them along to the TorchModel. Similar pairings exist for all of the different model types: +Model objects are only used in Ax via a [`Adapter`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.Adapter). Each Model object defines an API which does not use Ax objects, allowing for modularity of different model types and making it easy to implement new models. For example, the TorchGenerator defines an API for a model that operates on torch tensors. There is a 1-to-1 link between [`Adapter`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.Adapter) objects and Model objects. 
For instance, the TorchAdapter takes in Ax objects, converts them to torch tensors, and sends them along to the TorchGenerator. Similar pairings exist for all of the different model types: -| ModelBridge | Model | Example implementation | | +| Adapter | Model | Example implementation | | | -------------------------------------------------------------------------------------- | ---------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------ | - | -| [`TorchModelBridge`](https://ax.readthedocs.io/en/latest/modelbridge.html#module-ax.modelbridge.torch) | [`TorchModel`](https://ax.readthedocs.io/en/latest/models.html#ax.models.torch_base.TorchModel) | [`BotorchModel`](https://ax.readthedocs.io/en/latest/models.html#ax.models.torch.botorch.BotorchModel) | | -| [`DiscreteModelBridge`](https://ax.readthedocs.io/en/latest/modelbridge.html#module-ax.modelbridge.discrete) | [`DiscreteModel`](https://ax.readthedocs.io/en/latest/models.html#ax.models.discrete_base.DiscreteModel) | [`ThompsonSampler`](https://ax.readthedocs.io/en/latest/models.html#ax.models.discrete.thompson.ThompsonSampler) | | -| [`RandomModelBridge`](https://ax.readthedocs.io/en/latest/modelbridge.html#module-ax.modelbridge.random) | [`RandomModel`](https://ax.readthedocs.io/en/latest/models.html#ax.models.random.base.RandomModel) | [`SobolGenerator`](https://ax.readthedocs.io/en/latest/models.html#ax.models.random.sobol.SobolGenerator) | | +| [`TorchAdapter`](https://ax.readthedocs.io/en/latest/modelbridge.html#module-ax.modelbridge.torch) | [`TorchGenerator`](https://ax.readthedocs.io/en/latest/models.html#ax.models.torch_base.TorchModel) | [`BotorchGenerator`](https://ax.readthedocs.io/en/latest/models.html#ax.models.torch.botorch.BotorchModel) | | +| [`DiscreteAdapter](https://ax.readthedocs.io/en/latest/modelbridge.html#module-ax.modelbridge.discrete) | [`DiscreteGenerator`](https://ax.readthedocs.io/en/latest/models.html#ax.models.discrete_base.DiscreteModel) | [`ThompsonSampler`](https://ax.readthedocs.io/en/latest/models.html#ax.models.discrete.thompson.ThompsonSampler) | | +| [`RandomAdapter`](https://ax.readthedocs.io/en/latest/modelbridge.html#module-ax.modelbridge.random) | [`RandomGenerator`](https://ax.readthedocs.io/en/latest/models.html#ax.models.random.base.RandomModel) | [`SobolGenerator`](https://ax.readthedocs.io/en/latest/models.html#ax.models.random.sobol.SobolGenerator) | | -This structure allows for different models like the GP in BotorchModel and the Random Forest in RandomForest to share an interface and use common plotting tools at the level of the ModelBridge, while each is implemented using its own torch or numpy structures. +This structure allows for different models like the GP in BotorchGenerator and the Random Forest in RandomForest to share an interface and use common plotting tools at the level of the Adapter, while each is implemented using its own torch or numpy structures. -The primary role of the [`ModelBridge`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.ModelBridge) is to act as a transformation layer. This includes transformations to the data, search space, and optimization config such as standardization and log transforms, as well as the final transform from Ax objects into the objects consumed by the Model. We now describe how transforms are implemented and used in the ModelBridge. 
+The primary role of the [`Adapter`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.Adapter) is to act as a transformation layer. This includes transformations to the data, search space, and optimization config such as standardization and log transforms, as well as the final transform from Ax objects into the objects consumed by the Model. We now describe how transforms are implemented and used in the Adapter. ## Transforms -The transformations in the [`ModelBridge`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.ModelBridge) are done by chaining together a set of individual Transform objects. For continuous space models obtained via factory functions ([`get_sobol`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.factory.get_sobol) and [`Models.BOTORCH_MODULAR`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.registry.Models)), the following transforms will be applied by default, in this sequence: +The transformations in the [`Adapter`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.Adapter) are done by chaining together a set of individual Transform objects. For continuous space models obtained via factory functions ([`get_sobol`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.factory.get_sobol) and [`Generators.BOTORCH_MODULAR`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.registry.Generators)), the following transforms will be applied by default, in this sequence: * [`RemoveFixed`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.transforms.remove_fixed.RemoveFixed): Remove [`FixedParameters`](https://ax.readthedocs.io/en/latest/core.html#ax.core.parameter.FixedParameter) from the search space. * [`OrderedChoiceEncode`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.transforms.choice_encode.OrderedChoiceEncode): [`ChoiceParameters`](https://ax.readthedocs.io/en/latest/core.html#ax.core.parameter.ChoiceParameter) with `is_ordered` set to `True` are encoded as a sequence of integers. * [`OneHot`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.transforms.one_hot.OneHot): [`ChoiceParameters`](https://ax.readthedocs.io/en/latest/core.html#ax.core.parameter.ChoiceParameter) with `is_ordered` set to `False` are one-hot encoded. @@ -186,11 +186,11 @@ The transformations in the [`ModelBridge`](https://ax.readthedocs.io/en/latest/m * [`Derelativize`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.transforms.derelativize.Derelativize): Constraints relative to status quo are converted to constraints on raw values. * [`StandardizeY`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.transforms.standardize_y.StandardizeY): The Y values for each metric are standardized (subtract mean, divide by standard deviation). -Each transform defines both a forward and backwards transform. Arm parameters are passed through the forward transform before being sent along to the Model. The Model works entirely in the transformed space, and when new candidates are generated, they are passed through all of the backwards transforms so the [`ModelBridge`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.ModelBridge) returns points in the original space. +Each transform defines both a forward and backwards transform. Arm parameters are passed through the forward transform before being sent along to the Model. 
The Model works entirely in the transformed space, and when new candidates are generated, they are passed through all of the backwards transforms so the [`Adapter`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.Adapter) returns points in the original space. New transforms can be implemented by creating a subclass of [`Transform`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.transforms.base.Transform), which defines the interface for all transforms. There are separate methods for transforming the search space, optimization config, observation features, and observation data. Transforms that operate on only some aspects of the problem do not need to implement all methods, for instance, [`Log`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.transforms.log.Log) implements only [`transform_observation_features`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.transforms.log.Log.transform_observation_features) (to log transform the parameters), [`transform_search_space`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.transforms.log.Log.transform_search_space) (to log transform the search space bounds), and [`untransform_observation_features`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.transforms.log.Log.untransform_observation_features) (to apply the inverse transform). -The (ordered) list of transforms to apply is an input to the ModelBridge, and so can easily be altered to add new transforms. It is important that transforms be applied in the right order. For instance, `StandardizeY` and `Winsorize` both transform the observed metric values. Applying them in the order `[StandardizeY, Winsorize]` could produce very different results than `[Winsorize, StandardizeY]`. In the former case, outliers would have already been included in the standardization (a procedure sensitive to outliers), and so the second approach that winsorizes first is preferred. +The (ordered) list of transforms to apply is an input to the Adapter, and so can easily be altered to add new transforms. It is important that transforms be applied in the right order. For instance, `StandardizeY` and `Winsorize` both transform the observed metric values. Applying them in the order `[StandardizeY, Winsorize]` could produce very different results than `[Winsorize, StandardizeY]`. In the former case, outliers would have already been included in the standardization (a procedure sensitive to outliers), and so the second approach that winsorizes first is preferred. See [the API reference](https://ax.readthedocs.io/en/latest/modelbridge.html#transforms) for the full collection of implemented transforms. @@ -200,13 +200,13 @@ The structure of the modeling stack makes it easy to implement new models and us ### Using an existing Model interface -The easiest way to implement a new model is if it can be adapted to one of the existing Model interfaces: ([`TorchModel`](https://ax.readthedocs.io/en/latest/models.html#ax.models.torch_base.TorchModel), [`DiscreteModel`](https://ax.readthedocs.io/en/latest/models.html#ax.models.discrete_base.DiscreteModel), or [`RandomModel`](https://ax.readthedocs.io/en/latest/models.html#ax.models.random.base.RandomModel)). The class definition provides the interface for each of the methods that should be implemented in order for Ax to be able to fully use the new model. Note however that not all methods must need be implemented to use some Ax functionality. 
For instance, an implementation of [`TorchModel`](https://ax.readthedocs.io/en/latest/models.html#ax.models.torch_base.TorchModel) that implements only [`fit`](https://ax.readthedocs.io/en/latest/models.html#ax.models.torch_base.TorchModel.fit) and [`predict`](https://ax.readthedocs.io/en/latest/models.html#ax.models.torch_base.TorchModel.predict) can be used to fit data and make plots in Ax; however, it will not be able to generate new candidates (requires implementing [`gen`](https://ax.readthedocs.io/en/latest/models.html#ax.models.torch_base.TorchModel.gen)) or be used with Ax's cross validation utility (requires implementing [`cross_validate`](https://ax.readthedocs.io/en/latest/models.html#ax.models.torch_base.TorchModel.cross_validate)). +The easiest way to implement a new model is to adapt it to one of the existing Model interfaces: ([`TorchModel`](https://ax.readthedocs.io/en/latest/models.html#ax.models.torch_base.TorchModel), [`DiscreteGenerator`](https://ax.readthedocs.io/en/latest/models.html#ax.models.discrete_base.DiscreteGenerator), or [`RandomGenerator`](https://ax.readthedocs.io/en/latest/models.html#ax.models.random.base.RandomGenerator)). The class definition provides the interface for each of the methods that should be implemented in order for Ax to be able to fully use the new model. Note however that not all methods need to be implemented to use some Ax functionality. For instance, an implementation of [`TorchModel`](https://ax.readthedocs.io/en/latest/models.html#ax.models.torch_base.TorchModel) that implements only [`fit`](https://ax.readthedocs.io/en/latest/models.html#ax.models.torch_base.TorchModel.fit) and [`predict`](https://ax.readthedocs.io/en/latest/models.html#ax.models.torch_base.TorchModel.predict) can be used to fit data and make plots in Ax; however, it will not be able to generate new candidates (requires implementing [`gen`](https://ax.readthedocs.io/en/latest/models.html#ax.models.torch_base.TorchModel.gen)) or be used with Ax's cross validation utility (requires implementing [`cross_validate`](https://ax.readthedocs.io/en/latest/models.html#ax.models.torch_base.TorchModel.cross_validate)). -Once the new model has been implemented, it can be used in Ax with the corresponding [`ModelBridge`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.ModelBridge) from the table above. For instance, suppose a new torch-based model was implemented as a subclass of [`TorchModel`](https://ax.readthedocs.io/en/latest/models.html#ax.models.torch_base.TorchModel). We can use that model in Ax like: +Once the new model has been implemented, it can be used in Ax with the corresponding [`Adapter`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.Adapter) from the table above. For instance, suppose a new torch-based model was implemented as a subclass of [`TorchModel`](https://ax.readthedocs.io/en/latest/models.html#ax.models.torch_base.TorchModel). We can use that model in Ax like: ```python new_model_obj = NewModel(init_args) # An instance of the new model class -m = TorchModelBridge( +m = TorchAdapter( experiment=experiment, search_space=search_space, data=data, @@ -215,11 +215,11 @@ ) ``` -The [`ModelBridge`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.ModelBridge) object `m` can then be used with plotting and cross validation utilities exactly the same way as the built-in models.
+The [`Adapter`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.Adapter) object `m` can then be used with plotting and cross validation utilities exactly the same way as the built-in models. ### Creating a new Model interface -If none of the existing Model interfaces work are suitable for the new model type, then a new interface will have to be created. This involves two steps: creating the new model interface and creating the new model bridge. The new model bridge must be a subclass of [`ModelBridge`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.ModelBridge) that implements `ModelBridge._fit`, `ModelBridge._predict`, `ModelBridge._gen`, and `ModelBridge._cross_validate`. The implementation of each of these methods will transform the Ax objects in the inputs into objects required for the interface with the new model type. The model bridge will then call out to the new model interface to do the actual modeling work. All of the ModelBridge/Model pairs in the table above provide examples of how this interface can be defined. The main key is that the inputs on the [`ModelBridge`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.ModelBridge) side are fixed, but those inputs can then be transformed in whatever way is desired for the downstream Model interface to be that which is most convenient for implementing the model. +If none of the existing Model interfaces are suitable for the new model type, then a new interface will have to be created. This involves two steps: creating the new model interface and creating the new model bridge. The new model bridge must be a subclass of [`Adapter`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.Adapter) that implements `Adapter._fit`, `Adapter._predict`, `Adapter._gen`, and `Adapter._cross_validate`. The implementation of each of these methods will transform the Ax objects in the inputs into objects required for the interface with the new model type. The model bridge will then call out to the new model interface to do the actual modeling work. All of the Adapter/Model pairs in the table above provide examples of how this interface can be defined. The key point is that the inputs on the [`Adapter`](https://ax.readthedocs.io/en/latest/modelbridge.html#ax.modelbridge.base.Adapter) side are fixed, but those inputs can then be transformed in whatever way is most convenient for implementing the downstream Model interface. diff --git a/docs/trial-evaluation.md b/docs/trial-evaluation.md index 1b64cd7bf7e..5a9960a747a 100644 --- a/docs/trial-evaluation.md +++ b/docs/trial-evaluation.md @@ -114,14 +114,14 @@ and `mark_complete` methods. ```python ... -sobol = Models.SOBOL(exp.search_space) +sobol = Generators.SOBOL(exp.search_space) for i in range(5): trial = exp.new_trial(generator_run=sobol.gen(1)) trial.run() trial.mark_completed() for i in range(15): -    gpei = Models.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data()) +    gpei = Generators.BOTORCH_MODULAR(experiment=exp, data=exp.fetch_data()) generator_run = gpei.gen(1) trial = exp.new_trial(generator_run=generator_run) trial.run() diff --git a/sphinx/source/modelbridge.rst b/sphinx/source/modelbridge.rst index c35831f380a..9c647a4caf1 100644 --- a/sphinx/source/modelbridge.rst +++ b/sphinx/source/modelbridge.rst @@ -58,7 +58,7 @@ Factory :undoc-members: :show-inheritance: -ModelSpec +GeneratorSpec ~~~~~~~~~ ..
automodule:: ax.modelbridge.model_spec :members: diff --git a/sphinx/source/models.rst b/sphinx/source/models.rst index 113632d5bfe..e7cb41d595e 100644 --- a/sphinx/source/models.rst +++ b/sphinx/source/models.rst @@ -8,7 +8,7 @@ ax.models .. currentmodule:: ax.models -Base Models & Utilities +Base Generators & Utilities ----------------------- ax.models.base @@ -60,7 +60,7 @@ ax.models.winsorization\_config module :show-inheritance: -Discrete Models +Discrete Generators --------------- ax.models.discrete.eb\_thompson module @@ -104,7 +104,7 @@ ax.models.discrete.ashr\_utils module :show-inheritance: -Random Models +Random Generators ------------- ax.models.random.base module @@ -131,7 +131,7 @@ ax.models.random.sobol module :undoc-members: :show-inheritance: -Torch Models & Utilities +Torch Generators & Utilities ------------------------ ax.models.torch.botorch module diff --git a/sphinx/source/preview.rst b/sphinx/source/preview.rst index 333876c2dc8..fcd1deb41e4 100644 --- a/sphinx/source/preview.rst +++ b/sphinx/source/preview.rst @@ -79,7 +79,7 @@ From String :show-inheritance: -ModelBridge +Adapter ~~~~~~~~~~~ .. automodule:: ax.preview.modelbridge diff --git a/tutorials/external_generation_node/external_generation_node.ipynb b/tutorials/external_generation_node/external_generation_node.ipynb index 2e9b6690937..c580ad891ad 100644 --- a/tutorials/external_generation_node/external_generation_node.ipynb +++ b/tutorials/external_generation_node/external_generation_node.ipynb @@ -63,8 +63,8 @@ "from ax.modelbridge.external_generation_node import ExternalGenerationNode\n", "from ax.modelbridge.generation_node import GenerationNode\n", "from ax.modelbridge.generation_strategy import GenerationStrategy\n", - "from ax.modelbridge.model_spec import ModelSpec\n", - "from ax.modelbridge.registry import Models\n", + "from ax.modelbridge.model_spec import GeneratorSpec\n", + "from ax.modelbridge.registry import Generators\n", "from ax.modelbridge.transition_criterion import MaxTrials\n", "from ax.plot.trace import plot_objective_value_vs_trial_index\n", "from ax.service.ax_client import AxClient, ObjectiveProperties\n", @@ -233,10 +233,10 @@ " nodes=[\n", " GenerationNode(\n", " node_name=\"Sobol\",\n", - " model_specs=[ModelSpec(Models.SOBOL)],\n", + " model_specs=[GeneratorSpec(Generators.SOBOL)],\n", " transition_criteria=[\n", " MaxTrials(\n", - " # This specifies the maximum number of trials to generate from this node, \n", + " # This specifies the maximum number of trials to generate from this node,\n", " # and the next node in the strategy.\n", " threshold=5,\n", " block_transition_if_unmet=True,\n", @@ -403,7 +403,5 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3" } - }, - "nbformat": 4, - "nbformat_minor": 2 + } } diff --git a/tutorials/factorial/factorial.ipynb b/tutorials/factorial/factorial.ipynb index 29f54eff6bc..a67cf1a3e97 100644 --- a/tutorials/factorial/factorial.ipynb +++ b/tutorials/factorial/factorial.ipynb @@ -60,7 +60,7 @@ "from ax import (\n", " Arm,\n", " ChoiceParameter,\n", - " Models,\n", + " Generators,\n", " ParameterType,\n", " SearchSpace,\n", " Experiment,\n", @@ -324,7 +324,7 @@ }, "outputs": [], "source": [ - "factorial = Models.FACTORIAL(search_space=exp.search_space)\n", + "factorial = Generators.FACTORIAL(search_space=exp.search_space)\n", "factorial_run = factorial.gen(\n", " n=-1\n", ") # Number of arms to generate is derived from the search space.\n", @@ -425,7 +425,7 @@ " print(f\"Running trial {i+1}...\")\n", " trial.run()\n", " 
trial.mark_completed()\n", - " thompson = Models.THOMPSON(experiment=exp, data=trial.fetch_data(), min_weight=0.01)\n", + " thompson = Generators.THOMPSON(experiment=exp, data=trial.fetch_data(), min_weight=0.01)\n", " models.append(thompson)\n", " thompson_run = thompson.gen(n=-1)\n", " trial = exp.new_batch_trial(optimize_for_power=True).add_generator_run(thompson_run)" @@ -624,6 +624,9 @@ } ], "metadata": { + "fileHeader": "", + "fileUid": "cc062d11-ffb1-4f11-8391-855d8000639b", + "isAdHoc": false, "kernelspec": { "display_name": "python3", "language": "python", @@ -641,7 +644,5 @@ "pygments_lexer": "ipython3", "version": "3.9.15" } - }, - "nbformat": 4, - "nbformat_minor": 2 + } } diff --git a/tutorials/human_in_the_loop/human_in_the_loop.ipynb b/tutorials/human_in_the_loop/human_in_the_loop.ipynb index 1f0d274c1ec..921577d84ee 100644 --- a/tutorials/human_in_the_loop/human_in_the_loop.ipynb +++ b/tutorials/human_in_the_loop/human_in_the_loop.ipynb @@ -69,7 +69,7 @@ " json_load,\n", ")\n", "from ax.modelbridge.cross_validation import cross_validate\n", - "from ax.modelbridge.registry import Models\n", + "from ax.modelbridge.registry import Generators\n", "from ax.plot.diagnostic import tile_cross_validation\n", "from ax.plot.scatter import plot_multiple_metrics, tile_fitted\n", "from ax.utils.notebook.plotting import render, init_notebook_plotting\n", @@ -112,7 +112,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Bayesian Optimization experiments almost always begin with a set of random points. In this experiment, these points were chosen via a Sobol sequence, accessible via the `ModelBridge` factory.\n", + "Bayesian Optimization experiments almost always begin with a set of random points. In this experiment, these points were chosen via a Sobol sequence, accessible via the `Adapter` factory.\n", "\n", "A collection of points run and analyzed together form a `BatchTrial`. A `Trial` object provides metadata pertaining to the deployment of these points, including details such as when they were deployed, and the current status of their experiment. \n", "\n", @@ -281,7 +281,7 @@ "### Model Fit\n", "\n", "Fitting a Modular BoTorch Model will allow us to predict new candidates based on our first Sobol batch. \n", - "Here, we make use of the default settings for `BOTORCH_MODULAR` defined in the ModelBridge registry (uses BoTorch's `SingleTaskGP` and `qLogNoisyExpectedImprovement` by default for single objective optimization)." + "Here, we make use of the default settings for `BOTORCH_MODULAR` defined in the Adapter registry (uses BoTorch's `SingleTaskGP` and `qLogNoisyExpectedImprovement` by default for single objective optimization)." 
] }, { @@ -290,7 +290,7 @@ "metadata": {}, "outputs": [], "source": [ - "gp = Models.BOTORCH_MODULAR(\n", + "gp = Generators.BOTORCH_MODULAR(\n", " search_space=experiment.search_space,\n", " experiment=experiment,\n", " data=data,\n", diff --git a/tutorials/modular_botax/modular_botax.ipynb b/tutorials/modular_botax/modular_botax.ipynb index 444c47fb1f7..fa37c3e2989 100644 --- a/tutorials/modular_botax/modular_botax.ipynb +++ b/tutorials/modular_botax/modular_botax.ipynb @@ -52,13 +52,13 @@ "source": [ "from typing import Any, Dict, Optional, Tuple, Type\n", "\n", - "from ax.modelbridge.registry import Models\n", + "from ax.modelbridge.registry import Generators\n", "\n", "# Ax data tranformation layer\n", "from ax.models.torch.botorch_modular.acquisition import Acquisition\n", "\n", "# Ax wrappers for BoTorch components\n", - "from ax.models.torch.botorch_modular.model import BoTorchModel\n", + "from ax.models.torch.botorch_modular.model import BoTorchGenerator\n", "from ax.models.torch.botorch_modular.surrogate import Surrogate, SurrogateSpec\n", "from ax.models.torch.botorch_modular.utils import ModelConfig\n", "\n", @@ -97,22 +97,22 @@ "showInput": false }, "source": [ - "# Setup and Usage of BoTorch Models in Ax\n", + "# Setup and Usage of BoTorch Generators in Ax\n", "\n", - "Ax provides a set of flexible wrapper abstractions to mix-and-match BoTorch components like `Model` and `AcquisitionFunction` and combine them into a single `Model` object in Ax. The wrapper abstractions: `Surrogate`, `Acquisition`, and `BoTorchModel` – are located in `ax/models/torch/botorch_modular` directory and aim to encapsulate boilerplate code that interfaces between Ax and BoTorch. This functionality is in beta-release and still evolving.\n", + "Ax provides a set of flexible wrapper abstractions to mix-and-match BoTorch components like `Model` and `AcquisitionFunction` and combine them into a single `Model` object in Ax. The wrapper abstractions (`Surrogate`, `Acquisition`, and `BoTorchGenerator`) are located in the `ax/models/torch/botorch_modular` directory and aim to encapsulate boilerplate code that interfaces between Ax and BoTorch. This functionality is in beta-release and still evolving.\n", "\n", "This tutorial walks through setting up a custom combination of BoTorch components in Ax in following steps:\n", "\n", - "1. **Quick-start example of `BoTorchModel` use**\n", - "1. **`BoTorchModel` = `Surrogate` + `Acquisition` (overview)**\n", + "1. **Quick-start example of `BoTorchGenerator` use**\n", + "1. **`BoTorchGenerator` = `Surrogate` + `Acquisition` (overview)**\n", " 1. Example with minimal options that uses the defaults\n", " 2. Example showing all possible options\n", " 3. Surrogate and Acquisition Q&A\n", "2. **I know which Botorch Model and AcquisitionFunction I'd like to combine in Ax. How do set this up?**\n", " 1. Making a `Surrogate` from BoTorch `Model`\n", " 2. Using an arbitrary BoTorch `AcquisitionFunction` in Ax\n", - "3. **Using `Models.BOTORCH_MODULAR`** (convenience wrapper that enables storage and resumability)\n", - "4. **Utilizing `BoTorchModel` in generation strategies** (abstraction that allows to chain models together and use them in Ax Service API etc.)\n", + "3. **Using `Generators.BOTORCH_MODULAR`** (convenience wrapper that enables storage and resumability)\n", + "4. **Utilizing `BoTorchGenerator` in generation strategies** (abstraction that allows chaining models together and using them in the Ax Service API, etc.)\n", " 1. 
Specifying `pending_observations` to avoid the model re-suggesting points that are part of `RUNNING` or `ABANDONED` trials.\n", "5. **Customizing a `Surrogate` or `Acquisition`** (for cases where existing subcomponent classes are not sufficient)" ] @@ -133,7 +133,7 @@ "source": [ "## 1. Quick-start example\n", "\n", - "Here we set up a `BoTorchModel` with `SingleTaskGP` with `qLogNoisyExpectedImprovement`, one of the most popular combinations in Ax:" + "Here we set up a `BoTorchGenerator` with `SingleTaskGP` and `qLogNoisyExpectedImprovement`, one of the most popular combinations in Ax:" ] }, { @@ -197,9 +197,9 @@ } ], "source": [ - "# `Models` automatically selects a model + model bridge combination.\n", - "# For `BOTORCH_MODULAR`, it will select `BoTorchModel` and `TorchModelBridge`.\n", - "model_bridge_with_GPEI = Models.BOTORCH_MODULAR(\n", + "# `Generators` automatically selects a model + adapter combination.\n", + "# For `BOTORCH_MODULAR`, it will select `BoTorchGenerator` and `TorchAdapter`.\n", + "model_bridge_with_GPEI = Generators.BOTORCH_MODULAR(\n", " experiment=experiment,\n", " data=data,\n", " surrogate_spec=SurrogateSpec(\n", @@ -278,7 +278,7 @@ "-----\n", "Before you read the rest of this tutorial:\n", "\n", - "- Note that the concept of ‘model’ is Ax is somewhat a misnomer; we use ['model'](https://ax.dev/docs/glossary.html#model) to refer to an optimization setup capable of producing candidate points for optimization (and often capable of being fit to data, with exception for quasi-random generators). See [Models documentation page](https://ax.dev/docs/models.html) for more information.\n", + "- Note that the concept of ‘model’ in Ax is somewhat a misnomer; we use ['model'](https://ax.dev/docs/glossary.html#model) to refer to an optimization setup capable of producing candidate points for optimization (and often capable of being fit to data, with the exception of quasi-random generators). See [Generators documentation page](https://ax.dev/docs/models.html) for more information.\n", "- Learn about `ModelBridge` in Ax, as users should rarely be interacting with a `Model` object directly (more about ModelBridge, a data transformation layer in Ax, [here](https://ax.dev/docs/models.html#deeper-dive-organization-of-the-modeling-stack))." ] }, @@ -296,9 +296,9 @@ "showInput": false }, "source": [ - "## 2. BoTorchModel = Surrogate + Acquisition\n", + "## 2. BoTorchGenerator = Surrogate + Acquisition\n", "\n", - "A `BoTorchModel` in Ax consists of two main subcomponents: a surrogate model and an acquisition function. A surrogate model is represented as an instance of Ax’s `Surrogate` class, which is a wrapper around BoTorch's `Model` class. The Surrogate is defined by a `SurrogateSpec`. The acquisition function is represented as an instance of Ax’s `Acquisition` class, a wrapper around BoTorch's `AcquisitionFunction` class." + "A `BoTorchGenerator` in Ax consists of two main subcomponents: a surrogate model and an acquisition function. A surrogate model is represented as an instance of Ax’s `Surrogate` class, which is a wrapper around BoTorch's `Model` class. The Surrogate is defined by a `SurrogateSpec`. The acquisition function is represented as an instance of Ax’s `Acquisition` class, a wrapper around BoTorch's `AcquisitionFunction` class." ] }, { @@ -317,7 +317,7 @@ "source": [ "### 2A. Example that uses defaults and requires no options\n", "\n", - "BoTorchModel does not always require surrogate and acquisition specification. 
If instantiated without one or both components specified, defaults are selected based on properties of experiment and data (see Appendix 2 for auto-selection logic)." + "BoTorchGenerator does not always require surrogate and acquisition specification. If instantiated without one or both components specified, defaults are selected based on properties of the experiment and data (see Appendix 2 for auto-selection logic)." ] }, { @@ -342,18 +342,18 @@ "source": [ "# The surrogate is not specified, so it will be auto-selected\n", "# during `model.fit`.\n", - "GPEI_model = BoTorchModel(botorch_acqf_class=qLogExpectedImprovement)\n", + "GPEI_model = BoTorchGenerator(botorch_acqf_class=qLogExpectedImprovement)\n", "\n", "# The acquisition class is not specified, so it will be\n", "# auto-selected during `model.gen` or `model.evaluate_acquisition`\n", - "GPEI_model = BoTorchModel(\n", + "GPEI_model = BoTorchGenerator(\n", " surrogate_spec=SurrogateSpec(\n", " model_configs=[ModelConfig(botorch_model_class=SingleTaskGP)]\n", " )\n", ")\n", "\n", "# Both the surrogate and acquisition class will be auto-selected.\n", - "GPEI_model = BoTorchModel()" + "GPEI_model = BoTorchGenerator()" ] }, { @@ -371,7 +371,7 @@ }, "source": [ "### 2B. Example with all the options\n", - "Below are the full set of configurable settings of a `BoTorchModel` with their descriptions:" + "Below is the full set of configurable settings of a `BoTorchGenerator` with their descriptions:" ] }, { @@ -394,7 +394,7 @@ }, "outputs": [], "source": [ - "model = BoTorchModel(\n", + "model = BoTorchGenerator(\n", " # Optional `Surrogate` specification to use instead of default\n", " surrogate_spec=SurrogateSpec(\n", " model_configs=[\n", @@ -418,7 +418,7 @@ " # `AcquisitionFunction` requires one, which is rare)\n", " acquisition_class=None,\n", " # Less common model settings shown with default values, refer\n", - " # to `BoTorchModel` documentation for detail\n", + " # to `BoTorchGenerator` documentation for details\n", " refit_on_cv=False,\n", " warm_start_refit=True,\n", ")" @@ -440,7 +440,7 @@ "source": [ "## 2C. `Surrogate` and `Acquisition` Q&A\n", "\n", - "**Why is the `surrogate` argument expected to be an instance, but `botorch_acqf_class` –– a class?** Because a BoTorch `AcquisitionFunction` object (and therefore its Ax wrapper, `Acquisition`) is ephemeral: it is constructed, immediately used, and destroyed during `BoTorchModel.gen`, so there is no reason to keep around an `Acquisition` instance. A `Surrogate`, on another hand, is kept in memory as long as its parent `BoTorchModel` is.\n", + "**Why is the `surrogate` argument expected to be an instance, but `botorch_acqf_class` –– a class?** Because a BoTorch `AcquisitionFunction` object (and therefore its Ax wrapper, `Acquisition`) is ephemeral: it is constructed, immediately used, and destroyed during `BoTorchGenerator.gen`, so there is no reason to keep around an `Acquisition` instance. A `Surrogate`, on the other hand, is kept in memory as long as its parent `BoTorchGenerator` is.\n", "\n", "**How to know when to use specify acquisition_class (and thereby a non-default Acquisition type) instead of just passing in botorch_acqf_class?** In short, custom `Acquisition` subclasses are needed when a given `AcquisitionFunction` in BoTorch needs some non-standard subcomponents or inputs (e.g. a custom BoTorch `MCAcquisitionObjective`).