Skip to content

Commit

Permalink
only infer reference point in global stopping if there is data
Browse files Browse the repository at this point in the history
Summary: see title

Reviewed By: Balandat

Differential Revision: D55924472
  • Loading branch information
sdaulton authored and facebook-github-bot committed Apr 9, 2024
1 parent a161618 commit 1ad750f
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 6 deletions.
5 changes: 3 additions & 2 deletions ax/plot/pareto_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -579,7 +579,7 @@ def _build_new_optimization_config(


def infer_reference_point_from_experiment(
experiment: Experiment,
experiment: Experiment, data: Data
) -> List[ObjectiveThreshold]:
"""This functions is a wrapper around ``infer_reference_point`` to find the nadir
point from the pareto front of an experiment. Aside from converting experiment
Expand All @@ -600,7 +600,8 @@ def infer_reference_point_from_experiment(

# Reading experiment data.
mb_reference = get_tensor_converter_model(
experiment=experiment, data=experiment.fetch_data()
experiment=experiment,
data=data,
)
obs_feats, obs_data, _ = _get_modelbridge_training_data(modelbridge=mb_reference)

Expand Down
14 changes: 10 additions & 4 deletions ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,10 +450,16 @@ def completion_criterion(self) -> Tuple[bool, str]:
and len(self.experiment.trials_by_status[TrialStatus.COMPLETED])
>= gss.min_trials
):
# We infer the nadir reference point to be used by the GSS.
self.__inferred_reference_point = infer_reference_point_from_experiment(
self.experiment
)
# only infer reference point if there is data on the experiment.
data = self.experiment.fetch_data()
if not data.df.empty:
# We infer the nadir reference point to be used by the GSS.
self.__inferred_reference_point = (
infer_reference_point_from_experiment(
self.experiment,
data=data,
)
)

stop_optimization, global_stopping_msg = gss.should_stop_optimization(
experiment=self.experiment,
Expand Down
34 changes: 34 additions & 0 deletions ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -820,6 +820,40 @@ def test_inferring_reference_point(self) -> None:
scheduler.run_n_trials(max_trials=10)
mock_infer_rp.assert_called_once()

def test_inferring_reference_point_no_data(self) -> None:
init_test_engine_and_session_factory(force_init=True)
experiment = get_branin_experiment_with_multi_objective()
experiment.runner = self.runner
gs = self._get_generation_strategy_strategy_for_test(
experiment=experiment,
generation_strategy=self.sobol_GS_no_parallelism,
)

scheduler = Scheduler(
experiment=experiment,
generation_strategy=gs,
options=SchedulerOptions(
# Stops the optimization after 5 trials.
global_stopping_strategy=DummyGlobalStoppingStrategy(
min_trials=0,
trial_to_stop=5,
),
),
db_settings=self.db_settings,
)
empty_data = Data(
df=pd.DataFrame(
columns=["metric_name", "arm_name", "trial_index", "mean", "sem"]
)
)
with patch(
"ax.service.scheduler.infer_reference_point_from_experiment"
) as mock_infer_rp, patch.object(
scheduler.experiment, "fetch_data", return_value=empty_data
):
scheduler.run_n_trials(max_trials=1)
mock_infer_rp.assert_not_called()

def test_global_stopping(self) -> None:
gs = self._get_generation_strategy_strategy_for_test(
experiment=self.branin_experiment,
Expand Down

0 comments on commit 1ad750f

Please sign in to comment.