Skip to content

Commit

Permalink
only infer reference point in global stopping if there is data (#2338)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2338

see title

Reviewed By: Balandat

Differential Revision: D55924472

fbshipit-source-id: 67168307420d7b19b26bfe3c998c9d0882ad8cda
  • Loading branch information
sdaulton authored and facebook-github-bot committed Apr 9, 2024
1 parent 6720be4 commit 99615d6
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 11 deletions.
5 changes: 3 additions & 2 deletions ax/plot/pareto_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ def _build_new_optimization_config(


def infer_reference_point_from_experiment(
experiment: Experiment,
experiment: Experiment, data: Data
) -> List[ObjectiveThreshold]:
"""This functions is a wrapper around ``infer_reference_point`` to find the nadir
point from the pareto front of an experiment. Aside from converting experiment
Expand All @@ -600,7 +600,8 @@ def infer_reference_point_from_experiment(

# Reading experiment data.
mb_reference = get_tensor_converter_model(
experiment=experiment, data=experiment.fetch_data()
experiment=experiment,
data=data,
)
obs_feats, obs_data, _ = _get_modelbridge_training_data(modelbridge=mb_reference)

Expand Down
16 changes: 11 additions & 5 deletions ax/plot/tests/test_pareto_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,10 @@ def test_infer_reference_point_from_experiment(self) -> None:
scalarized=False,
constrained=False,
)
inferred_reference_point = infer_reference_point_from_experiment(experiment)
data = experiment.fetch_data()
inferred_reference_point = infer_reference_point_from_experiment(
experiment, data=data
)
# The nadir point for this experiment is [-0.5, 0.5]. The function actually
# deducts 0.1*Y_range from each of the objectives. Since the range for each
# of the objectives is +/-1.5, the inferred reference point would
Expand All @@ -265,7 +268,7 @@ def test_infer_reference_point_from_experiment(self) -> None:
return_value=([], [], [], []),
):
with self.assertRaisesRegex(RuntimeError, "No frontier observations found"):
infer_reference_point_from_experiment(experiment)
infer_reference_point_from_experiment(experiment, data=data)

def test_constrained_infer_reference_point_from_experiment(self) -> None:
experiments = []
Expand All @@ -290,14 +293,15 @@ def test_constrained_infer_reference_point_from_experiment(self) -> None:

for experiment in experiments:
# special case logs a warning message.
data = experiment.fetch_data()
if experiment.optimization_config.outcome_constraints[0].bound == 1000.0:
with self.assertLogs(logger, "WARNING"):
inferred_reference_point = infer_reference_point_from_experiment(
experiment
experiment, data=data
)
else:
inferred_reference_point = infer_reference_point_from_experiment(
experiment
experiment, data=data
)
# The nadir point for this experiment is [-0.5, 0.5]. The function actually
# deducts 0.1*Y_range from each of the objectives. Since the range for each
Expand Down Expand Up @@ -377,7 +381,9 @@ def test_infer_reference_point_from_experiment_shuffled_metrics(self) -> None:
obj_t_shuffled,
),
):
inferred_reference_point = infer_reference_point_from_experiment(experiment)
inferred_reference_point = infer_reference_point_from_experiment(
experiment, data=experiment.fetch_data()
)

self.assertEqual(inferred_reference_point[0].op, ComparisonOp.LEQ)
self.assertEqual(inferred_reference_point[0].bound, -0.35)
Expand Down
14 changes: 10 additions & 4 deletions ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,10 +450,16 @@ def completion_criterion(self) -> Tuple[bool, str]:
and len(self.experiment.trials_by_status[TrialStatus.COMPLETED])
>= gss.min_trials
):
# We infer the nadir reference point to be used by the GSS.
self.__inferred_reference_point = infer_reference_point_from_experiment(
self.experiment
)
# only infer reference point if there is data on the experiment.
data = self.experiment.fetch_data()
if not data.df.empty:
# We infer the nadir reference point to be used by the GSS.
self.__inferred_reference_point = (
infer_reference_point_from_experiment(
self.experiment,
data=data,
)
)

stop_optimization, global_stopping_msg = gss.should_stop_optimization(
experiment=self.experiment,
Expand Down
34 changes: 34 additions & 0 deletions ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,40 @@ def test_inferring_reference_point(self) -> None:
scheduler.run_n_trials(max_trials=10)
mock_infer_rp.assert_called_once()

def test_inferring_reference_point_no_data(self) -> None:
init_test_engine_and_session_factory(force_init=True)
experiment = get_branin_experiment_with_multi_objective()
experiment.runner = self.runner
gs = self._get_generation_strategy_strategy_for_test(
experiment=experiment,
generation_strategy=self.sobol_GS_no_parallelism,
)

scheduler = Scheduler(
experiment=experiment,
generation_strategy=gs,
options=SchedulerOptions(
# Stops the optimization after 5 trials.
global_stopping_strategy=DummyGlobalStoppingStrategy(
min_trials=0,
trial_to_stop=5,
),
),
db_settings=self.db_settings,
)
empty_data = Data(
df=pd.DataFrame(
columns=["metric_name", "arm_name", "trial_index", "mean", "sem"]
)
)
with patch(
"ax.service.scheduler.infer_reference_point_from_experiment"
) as mock_infer_rp, patch.object(
scheduler.experiment, "fetch_data", return_value=empty_data
):
scheduler.run_n_trials(max_trials=1)
mock_infer_rp.assert_not_called()

def test_global_stopping(self) -> None:
gs = self._get_generation_strategy_strategy_for_test(
experiment=self.branin_experiment,
Expand Down

0 comments on commit 99615d6

Please sign in to comment.