diff --git a/ax/modelbridge/generation_strategy.py b/ax/modelbridge/generation_strategy.py index b9e18af41ce..bfb98fb9841 100644 --- a/ax/modelbridge/generation_strategy.py +++ b/ax/modelbridge/generation_strategy.py @@ -185,24 +185,12 @@ def name(self, name: str) -> None: @property @step_based_gs_only def model_transitions(self) -> List[int]: - """List of trial indices where a transition happened from one model to + """[DEPRECATED]List of trial indices where a transition happened from one model to another.""" - # TODO @mgarrard to support GenerationNodes here, which is non-trival - # since nodes are dynamic and may only support past model_transitions - gen_changes = [] - for node in self._nodes: - for criterion in node.transition_criteria: - if ( - isinstance(criterion, TrialBasedCriterion) - and criterion.criterion_class == "MaxTrials" - ): - gen_changes.append(criterion.threshold) - - # if the last node has unlimited generation, do not remeove the last - # transition point in the list - if self._nodes[-1].gen_unlimited_trials: - return [sum(gen_changes[: i + 1]) for i in range(len(gen_changes))] - return [sum(gen_changes[: i + 1]) for i in range(len(gen_changes))][:-1] + raise DeprecationWarning( + "`model_transitions` is no longer supported. Please refer to `model_key` " + "field on generator runs for similar information if needed." + ) @property def current_step(self) -> GenerationStep: diff --git a/ax/modelbridge/tests/test_generation_strategy.py b/ax/modelbridge/tests/test_generation_strategy.py index 2c799e83ad5..47cc140ce11 100644 --- a/ax/modelbridge/tests/test_generation_strategy.py +++ b/ax/modelbridge/tests/test_generation_strategy.py @@ -370,7 +370,6 @@ def test_do_not_enforce_min_observations(self) -> None: def test_sobol_GPEI_strategy(self) -> None: exp = get_branin_experiment() self.assertEqual(self.sobol_GPEI_GS.name, "Sobol+GPEI") - self.assertEqual(self.sobol_GPEI_GS.model_transitions, [5]) for i in range(7): g = self.sobol_GPEI_GS.gen(exp) exp.new_trial(generator_run=g).run() @@ -443,7 +442,6 @@ def test_sobol_GPEI_strategy_keep_generating(self) -> None: ] ) self.assertEqual(sobol_GPEI_generation_strategy.name, "Sobol+GPEI") - self.assertEqual(sobol_GPEI_generation_strategy.model_transitions, [5]) exp.new_trial(generator_run=sobol_GPEI_generation_strategy.gen(exp)).run() for i in range(1, 15): g = sobol_GPEI_generation_strategy.gen(exp) @@ -491,7 +489,6 @@ def test_factorial_thompson_strategy(self, _: MagicMock) -> None: self.assertEqual( factorial_thompson_generation_strategy.name, "Factorial+Thompson" ) - self.assertEqual(factorial_thompson_generation_strategy.model_transitions, [1]) mock_model_bridge = self.mock_discrete_model_bridge.return_value # Initial factorial batch. @@ -551,7 +548,6 @@ def test_sobol_GPEI_strategy_batches(self) -> None: ], ) self.assertEqual(sobol_GPEI_generation_strategy.name, "Sobol+GPEI") - self.assertEqual(sobol_GPEI_generation_strategy.model_transitions, [1]) gr = sobol_GPEI_generation_strategy.gen(exp, n=2) exp.new_batch_trial(generator_run=gr).run() for i in range(1, 8): diff --git a/ax/modelbridge/tests/test_rembo_strategy.py b/ax/modelbridge/tests/test_rembo_strategy.py index 5116ba92180..53d570cd63e 100644 --- a/ax/modelbridge/tests/test_rembo_strategy.py +++ b/ax/modelbridge/tests/test_rembo_strategy.py @@ -97,7 +97,6 @@ def test_REMBOStrategy(self) -> None: if i >= 2: self.assertFalse(any(len(x) < 4 for x in gs.arms_by_proj.values())) - self.assertTrue(len(gs.model_transitions) > 0) gs2 = gs.clone_reset() self.assertEqual(gs2.D, 20) self.assertEqual(gs2.d, 6) diff --git a/ax/plot/trace.py b/ax/plot/trace.py index fb97cdc4328..a15548a3922 100644 --- a/ax/plot/trace.py +++ b/ax/plot/trace.py @@ -345,40 +345,6 @@ def optimum_objective_scatter( ) -def model_transitions_scatter( - model_transitions: List[int], - y_range: List[float], - generator_change_color: Tuple[int] = COLORS.TEAL.value, -) -> List[go.Scatter]: - """Creates a graph object for the line(s) representing generator changes. - - Args: - model_transitions: iterations, before which generators - changed - y_range: upper and lower values of the y-range of the plot - generator_change_color: tuple of 3 int values representing - an RGB color. Defaults to orange. - - Returns: - go.Scatter: plotly graph objects for the lines representing generator - changes - """ - if len(y_range) != 2: - raise ValueError("y_range should have two values, lower and upper.") - data: List[go.Scatter] = [] - for change in model_transitions: - data.append( - go.Scatter( - x=[change] * 2, - y=y_range, - mode="lines", - line={"dash": "dash", "color": rgba(generator_change_color)}, - name="model change", - ) - ) - return data - - def optimization_trace_single_method_plotly( y: np.ndarray, optimum: Optional[float] = None, diff --git a/ax/service/ax_client.py b/ax/service/ax_client.py index 0a3e5a2dbeb..527a5bf7bc4 100644 --- a/ax/service/ax_client.py +++ b/ax/service/ax_client.py @@ -957,10 +957,9 @@ def _constrained_trial_objective_mean(trial: BaseTrial) -> float: else np.maximum.accumulate(best_objectives, axis=1) ), optimum=objective_optimum, - title="Model performance vs. # of iterations", + title="Best objective found vs. # of iterations", ylabel=objective_name.capitalize(), hover_labels=hover_labels, - model_transitions=model_transitions, ) def get_contour_plot( diff --git a/ax/service/tests/test_ax_client.py b/ax/service/tests/test_ax_client.py index 219f9b74db7..d8902c52703 100644 --- a/ax/service/tests/test_ax_client.py +++ b/ax/service/tests/test_ax_client.py @@ -2938,8 +2938,6 @@ def test_with_node_based_gs(self) -> None: ) as mock_plot: ax_client.get_optimization_trace() mock_plot.assert_called_once() - call_kwargs = mock_plot.call_args.kwargs - self.assertIsNone(call_kwargs["model_transitions"]) # Utility functions for testing get_model_predictions without calling diff --git a/ax/service/utils/report_utils.py b/ax/service/utils/report_utils.py index ea604932118..5d3c9d5ccff 100644 --- a/ax/service/utils/report_utils.py +++ b/ax/service/utils/report_utils.py @@ -110,7 +110,6 @@ def _get_cross_validation_plots(model: ModelBridge) -> List[go.Figure]: def _get_objective_trace_plot( experiment: Experiment, data: Data, - model_transitions: List[int], true_objective_metric_name: Optional[str] = None, ) -> Iterable[go.Figure]: if experiment.is_moo_problem: @@ -295,7 +294,6 @@ def get_standard_plots( experiment: Experiment, model: Optional[ModelBridge], data: Optional[Data] = None, - model_transitions: Optional[List[int]] = None, true_objective_metric_name: Optional[str] = None, early_stopping_strategy: Optional[BaseEarlyStoppingStrategy] = None, limit_points_per_plot: Optional[int] = None, @@ -311,9 +309,6 @@ def get_standard_plots( Args: - experiment: The ``Experiment`` from which to obtain standard plots. - model: The ``ModelBridge`` used to suggest trial parameters. - - data: If specified, data, to which to fit the model before generating plots. - - model_transitions: The arm numbers at which shifts in generation_strategy - occur. - true_objective_metric_name: Name of the metric to use as the true objective. - early_stopping_strategy: Early stopping strategy used throughout the experiment; used for visualizing when curves are stopped. @@ -365,9 +360,6 @@ def get_standard_plots( _get_objective_trace_plot( experiment=experiment, data=data, - model_transitions=( - model_transitions if model_transitions is not None else [] - ), true_objective_metric_name=true_objective_metric_name, ) )