Skip to content

Commit

Permalink
Delete model_transitions
Browse files Browse the repository at this point in the history
Reviewed By: bernardbeckerman

Differential Revision: D54702629
  • Loading branch information
Lena Kashtelyan authored and facebook-github-bot committed Mar 22, 2024
1 parent fbcad21 commit ad2e034
Show file tree
Hide file tree
Showing 7 changed files with 6 additions and 68 deletions.
22 changes: 5 additions & 17 deletions ax/modelbridge/generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,24 +185,12 @@ def name(self, name: str) -> None:
@property
@step_based_gs_only
def model_transitions(self) -> List[int]:
"""List of trial indices where a transition happened from one model to
"""[DEPRECATED]List of trial indices where a transition happened from one model to
another."""
# TODO @mgarrard to support GenerationNodes here, which is non-trival
# since nodes are dynamic and may only support past model_transitions
gen_changes = []
for node in self._nodes:
for criterion in node.transition_criteria:
if (
isinstance(criterion, TrialBasedCriterion)
and criterion.criterion_class == "MaxTrials"
):
gen_changes.append(criterion.threshold)

# if the last node has unlimited generation, do not remeove the last
# transition point in the list
if self._nodes[-1].gen_unlimited_trials:
return [sum(gen_changes[: i + 1]) for i in range(len(gen_changes))]
return [sum(gen_changes[: i + 1]) for i in range(len(gen_changes))][:-1]
raise DeprecationWarning(
"`model_transitions` is no longer supported. Please refer to `model_key` "
"field on generator runs for similar information if needed."
)

@property
def current_step(self) -> GenerationStep:
Expand Down
4 changes: 0 additions & 4 deletions ax/modelbridge/tests/test_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,7 +370,6 @@ def test_do_not_enforce_min_observations(self) -> None:
def test_sobol_GPEI_strategy(self) -> None:
exp = get_branin_experiment()
self.assertEqual(self.sobol_GPEI_GS.name, "Sobol+GPEI")
self.assertEqual(self.sobol_GPEI_GS.model_transitions, [5])
for i in range(7):
g = self.sobol_GPEI_GS.gen(exp)
exp.new_trial(generator_run=g).run()
Expand Down Expand Up @@ -443,7 +442,6 @@ def test_sobol_GPEI_strategy_keep_generating(self) -> None:
]
)
self.assertEqual(sobol_GPEI_generation_strategy.name, "Sobol+GPEI")
self.assertEqual(sobol_GPEI_generation_strategy.model_transitions, [5])
exp.new_trial(generator_run=sobol_GPEI_generation_strategy.gen(exp)).run()
for i in range(1, 15):
g = sobol_GPEI_generation_strategy.gen(exp)
Expand Down Expand Up @@ -491,7 +489,6 @@ def test_factorial_thompson_strategy(self, _: MagicMock) -> None:
self.assertEqual(
factorial_thompson_generation_strategy.name, "Factorial+Thompson"
)
self.assertEqual(factorial_thompson_generation_strategy.model_transitions, [1])
mock_model_bridge = self.mock_discrete_model_bridge.return_value

# Initial factorial batch.
Expand Down Expand Up @@ -551,7 +548,6 @@ def test_sobol_GPEI_strategy_batches(self) -> None:
],
)
self.assertEqual(sobol_GPEI_generation_strategy.name, "Sobol+GPEI")
self.assertEqual(sobol_GPEI_generation_strategy.model_transitions, [1])
gr = sobol_GPEI_generation_strategy.gen(exp, n=2)
exp.new_batch_trial(generator_run=gr).run()
for i in range(1, 8):
Expand Down
1 change: 0 additions & 1 deletion ax/modelbridge/tests/test_rembo_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def test_REMBOStrategy(self) -> None:
if i >= 2:
self.assertFalse(any(len(x) < 4 for x in gs.arms_by_proj.values()))

self.assertTrue(len(gs.model_transitions) > 0)
gs2 = gs.clone_reset()
self.assertEqual(gs2.D, 20)
self.assertEqual(gs2.d, 6)
Expand Down
34 changes: 0 additions & 34 deletions ax/plot/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,40 +345,6 @@ def optimum_objective_scatter(
)


def model_transitions_scatter(
model_transitions: List[int],
y_range: List[float],
generator_change_color: Tuple[int] = COLORS.TEAL.value,
) -> List[go.Scatter]:
"""Creates a graph object for the line(s) representing generator changes.
Args:
model_transitions: iterations, before which generators
changed
y_range: upper and lower values of the y-range of the plot
generator_change_color: tuple of 3 int values representing
an RGB color. Defaults to orange.
Returns:
go.Scatter: plotly graph objects for the lines representing generator
changes
"""
if len(y_range) != 2:
raise ValueError("y_range should have two values, lower and upper.")
data: List[go.Scatter] = []
for change in model_transitions:
data.append(
go.Scatter(
x=[change] * 2,
y=y_range,
mode="lines",
line={"dash": "dash", "color": rgba(generator_change_color)},
name="model change",
)
)
return data


def optimization_trace_single_method_plotly(
y: np.ndarray,
optimum: Optional[float] = None,
Expand Down
3 changes: 1 addition & 2 deletions ax/service/ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,10 +957,9 @@ def _constrained_trial_objective_mean(trial: BaseTrial) -> float:
else np.maximum.accumulate(best_objectives, axis=1)
),
optimum=objective_optimum,
title="Model performance vs. # of iterations",
title="Best objective found vs. # of iterations",
ylabel=objective_name.capitalize(),
hover_labels=hover_labels,
model_transitions=model_transitions,
)

def get_contour_plot(
Expand Down
2 changes: 0 additions & 2 deletions ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2938,8 +2938,6 @@ def test_with_node_based_gs(self) -> None:
) as mock_plot:
ax_client.get_optimization_trace()
mock_plot.assert_called_once()
call_kwargs = mock_plot.call_args.kwargs
self.assertIsNone(call_kwargs["model_transitions"])


# Utility functions for testing get_model_predictions without calling
Expand Down
8 changes: 0 additions & 8 deletions ax/service/utils/report_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ def _get_cross_validation_plots(model: ModelBridge) -> List[go.Figure]:
def _get_objective_trace_plot(
experiment: Experiment,
data: Data,
model_transitions: List[int],
true_objective_metric_name: Optional[str] = None,
) -> Iterable[go.Figure]:
if experiment.is_moo_problem:
Expand Down Expand Up @@ -295,7 +294,6 @@ def get_standard_plots(
experiment: Experiment,
model: Optional[ModelBridge],
data: Optional[Data] = None,
model_transitions: Optional[List[int]] = None,
true_objective_metric_name: Optional[str] = None,
early_stopping_strategy: Optional[BaseEarlyStoppingStrategy] = None,
limit_points_per_plot: Optional[int] = None,
Expand All @@ -311,9 +309,6 @@ def get_standard_plots(
Args:
- experiment: The ``Experiment`` from which to obtain standard plots.
- model: The ``ModelBridge`` used to suggest trial parameters.
- data: If specified, data, to which to fit the model before generating plots.
- model_transitions: The arm numbers at which shifts in generation_strategy
occur.
- true_objective_metric_name: Name of the metric to use as the true objective.
- early_stopping_strategy: Early stopping strategy used throughout the
experiment; used for visualizing when curves are stopped.
Expand Down Expand Up @@ -365,9 +360,6 @@ def get_standard_plots(
_get_objective_trace_plot(
experiment=experiment,
data=data,
model_transitions=(
model_transitions if model_transitions is not None else []
),
true_objective_metric_name=true_objective_metric_name,
)
)
Expand Down

0 comments on commit ad2e034

Please sign in to comment.