Skip to content

Commit

Permalink
Rename modeling layer components (#3280)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #3280

Rename `Model` -> `Generator`, `ModelSpec` -> `GeneratorSpec`, `Modelbridge` -> `Adapter`.

This also updates the decoders so that we can load objects stored with the previous names: e.g. decode `Models` as `Generators`.

Differential Revision: D68735059
  • Loading branch information
sdaulton authored and facebook-github-bot committed Jan 30, 2025
1 parent d059120 commit 76cfab0
Show file tree
Hide file tree
Showing 197 changed files with 7,399 additions and 7,290 deletions.
4 changes: 2 additions & 2 deletions ax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
SumConstraint,
Trial,
)
from ax.modelbridge import Models
from ax.modelbridge import Generators
from ax.service import OptimizationLoop, optimize
from ax.storage import json_load, json_save

Expand All @@ -52,7 +52,7 @@
"FixedParameter",
"GeneratorRun",
"Metric",
"Models",
"Generators",
"MultiObjective",
"MultiObjectiveOptimizationConfig",
"Objective",
Expand Down
4 changes: 2 additions & 2 deletions ax/analysis/healthcheck/constraints_feasibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.optimization_config import OptimizationConfig
from ax.exceptions.core import UserInputError
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.base import Adapter
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.transforms.derelativize import Derelativize
from pyre_extensions import assert_is_instance, none_throws
Expand Down Expand Up @@ -155,7 +155,7 @@ def compute(

def constraints_feasibility(
optimization_config: OptimizationConfig,
model: ModelBridge,
model: Adapter,
prob_threshold: float = 0.99,
) -> Tuple[bool, pd.DataFrame]:
r"""
Expand Down
4 changes: 2 additions & 2 deletions ax/analysis/healthcheck/regression_detection_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from ax.core.observation import observations_from_data

from ax.exceptions.core import DataRequiredError, UserInputError
from ax.modelbridge.discrete import DiscreteModelBridge
from ax.modelbridge.discrete import DiscreteAdapter
from ax.modelbridge.registry import rel_EB_ashr_trans
from ax.models.discrete.eb_ashr import EBAshr
from pyre_extensions import assert_is_instance
Expand Down Expand Up @@ -101,7 +101,7 @@ def compute_regression_probabilities_single_trial(

target_data = Data(df=data.df[data.df["metric_name"].isin(metric_names)])

modelbridge = DiscreteModelBridge(
modelbridge = DiscreteAdapter(
experiment=experiment,
search_space=experiment.search_space,
data=target_data,
Expand Down
8 changes: 4 additions & 4 deletions ax/analysis/healthcheck/tests/test_constraints_feasibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
from ax.modelbridge.factory import get_sobol
from ax.modelbridge.generation_node import GenerationNode
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.model_spec import ModelSpec
from ax.modelbridge.registry import Models
from ax.modelbridge.model_spec import GeneratorSpec
from ax.modelbridge.registry import Generators
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
get_branin_experiment,
Expand Down Expand Up @@ -96,8 +96,8 @@ def setUp(self) -> None:
GenerationNode(
node_name="gn",
model_specs=[
ModelSpec(
model_enum=Models.BOTORCH_MODULAR,
GeneratorSpec(
model_enum=Generators.BOTORCH_MODULAR,
)
],
)
Expand Down
16 changes: 8 additions & 8 deletions ax/analysis/plotly/arm_effects/insample_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,9 @@
from ax.core.generator_run import GeneratorRun
from ax.core.outcome_constraint import OutcomeConstraint
from ax.exceptions.core import DataRequiredError, UserInputError
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.base import Adapter
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.registry import Models
from ax.modelbridge.registry import Generators
from ax.modelbridge.transforms.derelativize import Derelativize
from ax.utils.common.logger import get_logger
from pyre_extensions import none_throws
Expand Down Expand Up @@ -157,7 +157,7 @@ def _plot_type_string(self) -> str:
return "Modeled" if self.use_modeled_effects else "Observed"


def _get_max_observed_trial_index(model: ModelBridge) -> int | None:
def _get_max_observed_trial_index(model: Adapter) -> int | None:
"""Returns the max observed trial index to appease multitask models for prediction
by giving fixed features. This is not necessarily accurate and should eventually
come from the generation strategy.
Expand All @@ -178,7 +178,7 @@ def _get_model(
use_modeled_effects: bool,
trial_index: int,
metric_name: str,
) -> ModelBridge:
) -> Adapter:
"""Get a model for predictions.
Args:
Expand Down Expand Up @@ -213,14 +213,14 @@ def _get_model(

if model is None or not is_predictive(model=model):
logger.info("Using empirical Bayes for predictions.")
return Models.EMPIRICAL_BAYES_THOMPSON(
return Generators.EMPIRICAL_BAYES_THOMPSON(
experiment=experiment, data=trial_data
)

return model
else:
# This model just predicts observed data
return Models.THOMPSON(
return Generators.THOMPSON(
data=trial_data,
search_space=experiment.search_space,
experiment=experiment,
Expand All @@ -229,7 +229,7 @@ def _get_model(

def _prepare_data(
experiment: Experiment,
model: ModelBridge,
model: Adapter,
outcome_constraints: list[OutcomeConstraint],
metric_name: str,
trial_index: int,
Expand All @@ -249,7 +249,7 @@ def _prepare_data(
Args:
experiment: Experiment to plot
model: ModelBridge being used for prediction
model: Adapter being used for prediction
outcome_constraints: Derelatives outcome constraints used for
assessing feasibility
metric_name: Name of metric to plot
Expand Down
6 changes: 3 additions & 3 deletions ax/analysis/plotly/arm_effects/predicted_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.exceptions.core import UserInputError
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.base import Adapter
from ax.modelbridge.generation_strategy import GenerationStrategy
from ax.modelbridge.transforms.derelativize import Derelativize
from pyre_extensions import assert_is_instance, none_throws
Expand Down Expand Up @@ -149,7 +149,7 @@ def compute(


def _prepare_data(
model: ModelBridge,
model: Adapter,
metric_name: str,
candidate_trial: BaseTrial,
outcome_constraints: list[OutcomeConstraint],
Expand All @@ -167,7 +167,7 @@ def _prepare_data(
candidate trial.
Args:
model: ModelBridge being used for prediction
model: Adapter being used for prediction
metric_name: Name of metric to plot
candidate_trial: Trial to plot candidates for by generator run
"""
Expand Down
6 changes: 3 additions & 3 deletions ax/analysis/plotly/arm_effects/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ax.core.outcome_constraint import OutcomeConstraint
from ax.core.types import TParameterization
from ax.exceptions.core import UserInputError
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.base import Adapter
from ax.modelbridge.prediction_utils import predict_at_point
from plotly import express as px, graph_objects as go
from pyre_extensions import none_throws
Expand Down Expand Up @@ -203,7 +203,7 @@ def _add_style_to_effects_by_arm_plot(
)


def _get_trial_index_for_predictions(model: ModelBridge) -> int | None:
def _get_trial_index_for_predictions(model: Adapter) -> int | None:
"""Returns status quo features index if defined on the model. Otherwise, returns
the max observed trial index to appease multitask models for prediction
by giving fixed features. The max index is not necessarily accurate and should
Expand All @@ -224,7 +224,7 @@ def _get_trial_index_for_predictions(model: ModelBridge) -> int | None:


def get_predictions_by_arm(
model: ModelBridge,
model: Adapter,
metric_name: str,
outcome_constraints: list[OutcomeConstraint],
gr: GeneratorRun | None = None,
Expand Down
2 changes: 1 addition & 1 deletion ax/analysis/plotly/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(
folds: Number of subsamples to partition observations into. Use -1 for
leave-one-out cross validation.
untransform: Whether to untransform the model predictions before cross
validating. Models are trained on transformed data, and candidate
validating. Generators are trained on transformed data, and candidate
generation is performed in the transformed space. Computing the model
quality metric based on the cross-validation results in the
untransformed space may not be representative of the model that
Expand Down
14 changes: 6 additions & 8 deletions ax/analysis/plotly/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
from ax.core.experiment import Experiment
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.exceptions.core import UserInputError
from ax.modelbridge.registry import Models
from ax.modelbridge.torch import TorchModelBridge
from ax.modelbridge.registry import Generators
from ax.modelbridge.torch import TorchAdapter
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.utils.common.logger import get_logger
from ax.utils.sensitivity.sobol_measures import ax_parameter_sens
Expand Down Expand Up @@ -261,9 +261,7 @@ def compute(
fig=fig,
)

def _get_oak_model(
self, experiment: Experiment, metric_name: str
) -> TorchModelBridge:
def _get_oak_model(self, experiment: Experiment, metric_name: str) -> TorchAdapter:
"""
Retrieves the modelbridge used for the analysis. The model uses an OAK
(Orthogonal Additive Kernel) with a sparsity-inducing prior,
Expand All @@ -275,7 +273,7 @@ def _get_oak_model(
lengthscales being fit.
"""
data = experiment.lookup_data().filter(metric_names=[metric_name])
model_bridge = Models.BOTORCH_MODULAR(
model_bridge = Generators.BOTORCH_MODULAR(
search_space=experiment.search_space,
experiment=experiment,
data=data,
Expand Down Expand Up @@ -304,12 +302,12 @@ def _get_oak_model(
),
)

return assert_is_instance(model_bridge, TorchModelBridge)
return assert_is_instance(model_bridge, TorchAdapter)


def _prepare_surface_plot(
experiment: Experiment,
model: TorchModelBridge,
model: TorchAdapter,
feature_name: str,
metric_name: str,
) -> go.Figure:
Expand Down
4 changes: 2 additions & 2 deletions ax/analysis/plotly/surface/contour.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.observation import ObservationFeatures
from ax.exceptions.core import UserInputError
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.base import Adapter
from ax.modelbridge.generation_strategy import GenerationStrategy
from plotly import graph_objects as go
from pyre_extensions import none_throws
Expand Down Expand Up @@ -113,7 +113,7 @@ def compute(

def _prepare_data(
experiment: Experiment,
model: ModelBridge,
model: Adapter,
x_parameter_name: str,
y_parameter_name: str,
metric_name: str,
Expand Down
4 changes: 2 additions & 2 deletions ax/analysis/plotly/surface/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from ax.core.generation_strategy_interface import GenerationStrategyInterface
from ax.core.observation import ObservationFeatures
from ax.exceptions.core import UserInputError
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.base import Adapter
from ax.modelbridge.generation_strategy import GenerationStrategy
from plotly import express as px, graph_objects as go
from pyre_extensions import none_throws
Expand Down Expand Up @@ -100,7 +100,7 @@ def compute(

def _prepare_data(
experiment: Experiment,
model: ModelBridge,
model: Adapter,
parameter_name: str,
metric_name: str,
) -> pd.DataFrame:
Expand Down
4 changes: 2 additions & 2 deletions ax/analysis/plotly/tests/test_predicted_effects.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from ax.exceptions.core import UserInputError
from ax.modelbridge.dispatch_utils import choose_generation_strategy
from ax.modelbridge.prediction_utils import predict_at_point
from ax.modelbridge.registry import Models
from ax.modelbridge.registry import Generators
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
get_branin_experiment,
Expand Down Expand Up @@ -315,7 +315,7 @@ def test_it_works_for_non_batch_experiments(self) -> None:
experiment=experiment,
)
# AND GIVEN we generate all Sobol trials and one GPEI trial
sobol_key = Models.SOBOL.value
sobol_key = Generators.SOBOL.value
last_model_key = sobol_key
while last_model_key == sobol_key:
trial = experiment.new_trial(
Expand Down
6 changes: 3 additions & 3 deletions ax/analysis/plotly/tests/test_scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ax.analysis.analysis import AnalysisCardLevel
from ax.analysis.plotly.scatter import _prepare_data, ScatterPlot
from ax.exceptions.core import DataRequiredError, UserInputError
from ax.modelbridge.registry import Models
from ax.modelbridge.registry import Generators
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import (
get_branin_experiment_with_multi_objective,
Expand Down Expand Up @@ -85,7 +85,7 @@ def test_prepare_data(self) -> None:
def test_it_only_has_observations_with_data_for_both_metrics(self) -> None:
# GIVEN an experiment with multiple trials and metrics
experiment = get_branin_experiment_with_multi_objective()
sobol = Models.SOBOL(search_space=experiment.search_space)
sobol = Generators.SOBOL(search_space=experiment.search_space)

t0 = experiment.new_batch_trial(generator_run=sobol.gen(3)).mark_completed(
unsafe=True
Expand Down Expand Up @@ -125,7 +125,7 @@ def test_it_only_has_observations_with_data_for_both_metrics(self) -> None:
def test_it_must_have_some_observations_with_data_for_both_metrics(self) -> None:
# GIVEN an experiment with multiple trials and metrics
experiment = get_branin_experiment_with_multi_objective()
sobol = Models.SOBOL(search_space=experiment.search_space)
sobol = Generators.SOBOL(search_space=experiment.search_space)

t0 = experiment.new_batch_trial(generator_run=sobol.gen(3)).mark_completed(
unsafe=True
Expand Down
4 changes: 2 additions & 2 deletions ax/analysis/plotly/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ax.core.objective import MultiObjective, ScalarizedObjective
from ax.core.outcome_constraint import ComparisonOp, OutcomeConstraint
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.modelbridge.base import ModelBridge
from ax.modelbridge.base import Adapter
from botorch.utils.probability.utils import compute_log_prob_feas_from_bounds
from numpy.typing import NDArray

Expand Down Expand Up @@ -126,7 +126,7 @@ def format_constraint_violated_probabilities(
return constraints_violated_str


def is_predictive(model: ModelBridge) -> bool:
def is_predictive(model: Adapter) -> bool:
"""Check if a model is predictive. Basically, we're checking if
predict() is implemented.
Expand Down
10 changes: 5 additions & 5 deletions ax/benchmark/benchmark_test_functions/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ax.benchmark.benchmark_test_function import BenchmarkTestFunction
from ax.core.observation import ObservationFeatures
from ax.core.types import TParamValue
from ax.modelbridge.torch import TorchModelBridge
from ax.modelbridge.torch import TorchAdapter
from ax.utils.common.base import Base
from ax.utils.common.equality import equality_typechecker
from pyre_extensions import none_throws
Expand All @@ -28,7 +28,7 @@ class SurrogateTestFunction(BenchmarkTestFunction):
name: The name of the runner.
outcome_names: Names of outcomes to return in `evaluate_true`, if the
surrogate produces more outcomes than are needed.
_surrogate: Either `None`, or a `TorchModelBridge` surrogate to use
_surrogate: Either `None`, or a `TorchAdapter` surrogate to use
for generating observations. If `None`, `get_surrogate`
must not be None and will be used to generate the surrogate when it
is needed.
Expand All @@ -39,8 +39,8 @@ class SurrogateTestFunction(BenchmarkTestFunction):

name: str
outcome_names: Sequence[str]
_surrogate: TorchModelBridge | None = None
get_surrogate: None | Callable[[], TorchModelBridge] = None
_surrogate: TorchAdapter | None = None
get_surrogate: None | Callable[[], TorchAdapter] = None

def __post_init__(self) -> None:
if self.get_surrogate is None and self._surrogate is None:
Expand All @@ -50,7 +50,7 @@ def __post_init__(self) -> None:
)

@property
def surrogate(self) -> TorchModelBridge:
def surrogate(self) -> TorchAdapter:
if self._surrogate is None:
self._surrogate = none_throws(self.get_surrogate)()
return none_throws(self._surrogate)
Expand Down
Loading

0 comments on commit 76cfab0

Please sign in to comment.