Skip to content

Commit

Permalink
Mark candidate trials failed before generation (#2981)
Browse files Browse the repository at this point in the history
Summary:

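Existing candidate trials can be marked failed before new candidates are generated (opt-in via `remove_stale_candidates`). This is done for 2 reasons:
- We don't want a huge queue of stale candidates
- We don't want candidates that are being superseded to be used as pending points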

Reviewed By: Cesar-Cardoso

Differential Revision: D65092785
  • Loading branch information
Daniel Cohen authored and facebook-github-bot committed Oct 28, 2024
1 parent 418fa6c commit 3f29cfd
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 1 deletion.
15 changes: 14 additions & 1 deletion ax/service/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,7 @@ def generate_candidates(
self,
num_trials: int = 1,
reduce_state_generator_runs: bool = False,
remove_stale_candidates: bool = False,
) -> tuple[list[BaseTrial], Exception | None]:
"""Fetch the latest data and generate new candidate trials.
Expand All @@ -501,10 +502,22 @@ def generate_candidates(
whether to save model state for every generator run (default)
or to only save model state on the final generator run of each
batch.
remove_stale_candidates: If true, mark any existing candidate trials
failed before trial generation because:
- they should not be treated as pending points
- they will no longer be relevant
Returns:
List of trials, empty if generation is not possible.
"""
if remove_stale_candidates:
stale_candidate_trials = self.experiment.trials_by_status[
TrialStatus.CANDIDATE
]
for trial in stale_candidate_trials:
trial.mark_failed(reason="Newer candidates generated.", unsafe=True)
else:
stale_candidate_trials = []
new_trials, err = self._get_next_trials(
num_trials=num_trials,
n=self.options.batch_size,
Expand All @@ -513,7 +526,7 @@ def generate_candidates(
new_generator_runs = [gr for t in new_trials for gr in t.generator_runs]
self._save_or_update_trials_and_generation_strategy_if_possible(
experiment=self.experiment,
trials=new_trials,
trials=new_trials + stale_candidate_trials,
generation_strategy=self.generation_strategy,
new_generator_runs=new_generator_runs,
reduce_state_generator_runs=reduce_state_generator_runs,
Expand Down
140 changes: 140 additions & 0 deletions ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2460,6 +2460,146 @@ def test_generate_candidates_works_for_sobol(self) -> None:
options.batch_size,
)

def test_generate_candidates_can_remove_stale_candidates(self) -> None:
    """Generating with ``remove_stale_candidates=True`` fails the older candidate.

    Candidates are generated twice; the second call asks for stale candidates
    to be removed. The experiment is then reloaded from storage and the first
    trial is expected to be FAILED with the reason
    "Newer candidates generated.", while the second stays CANDIDATE.
    """
    init_test_engine_and_session_factory(force_init=True)
    # GIVEN a scheduler whose generation strategy is built from
    # ``two_sobol_steps_GS``.
    generation_strategy = self._get_generation_strategy_strategy_for_test(
        experiment=self.branin_experiment,
        generation_strategy=self.two_sobol_steps_GS,
    )

    # This is a HITL experiment, so trials must not complete on their own;
    # an InfinitePollRunner is installed for that purpose.
    if not isinstance(self.branin_experiment, MultiTypeExperiment):
        self.branin_experiment.runner = InfinitePollRunner()
    else:
        self.branin_experiment.update_runner("type1", InfinitePollRunner())
    scheduler_options = SchedulerOptions(
        init_seconds_between_polls=0,  # Zero wait between polls keeps the test fast.
        batch_size=10,
        trial_type=TrialType.BATCH_TRIAL,
        **self.scheduler_options_kwargs,
    )
    live_scheduler = Scheduler(
        experiment=self.branin_experiment,
        generation_strategy=generation_strategy,
        options=scheduler_options,
        db_settings=self.db_settings,
    )

    # WHEN candidates are generated twice on a fresh experiment, the second
    # time requesting removal of stale candidates
    live_scheduler.generate_candidates(num_trials=1)
    live_scheduler.generate_candidates(num_trials=1, remove_stale_candidates=True)

    # THEN, after reloading from storage, the first candidate is failed and
    # the second one is still a candidate
    reloaded_scheduler = Scheduler.from_stored_experiment(
        experiment_name=self.branin_experiment.name,
        options=scheduler_options,
        db_settings=self.db_settings,
    )
    stored_experiment = reloaded_scheduler.experiment
    self.assertEqual(len(stored_experiment.trials), 2)
    superseded_trial = stored_experiment.trials[0]
    self.assertEqual(superseded_trial.status, TrialStatus.FAILED)
    self.assertEqual(superseded_trial.failed_reason, "Newer candidates generated.")
    self.assertEqual(stored_experiment.trials[1].status, TrialStatus.CANDIDATE)

def test_generate_candidates_can_choose_not_to_remove_stale_candidates(
    self,
) -> None:
    """Generating with ``remove_stale_candidates=False`` keeps older candidates.

    Candidates are generated twice with stale-candidate removal explicitly
    disabled on the second call. The experiment reloaded from storage is
    expected to hold two trials, both still in CANDIDATE status.
    """
    init_test_engine_and_session_factory(force_init=True)
    # GIVEN a scheduler whose generation strategy is built from
    # ``two_sobol_steps_GS``.
    generation_strategy = self._get_generation_strategy_strategy_for_test(
        experiment=self.branin_experiment,
        generation_strategy=self.two_sobol_steps_GS,
    )

    # This is a HITL experiment, so trials must not complete on their own;
    # an InfinitePollRunner is installed for that purpose.
    if not isinstance(self.branin_experiment, MultiTypeExperiment):
        self.branin_experiment.runner = InfinitePollRunner()
    else:
        self.branin_experiment.update_runner("type1", InfinitePollRunner())
    scheduler_options = SchedulerOptions(
        init_seconds_between_polls=0,  # Zero wait between polls keeps the test fast.
        batch_size=10,
        trial_type=TrialType.BATCH_TRIAL,
        **self.scheduler_options_kwargs,
    )
    live_scheduler = Scheduler(
        experiment=self.branin_experiment,
        generation_strategy=generation_strategy,
        options=scheduler_options,
        db_settings=self.db_settings,
    )

    # WHEN candidates are generated twice on a fresh experiment without
    # removing stale candidates
    live_scheduler.generate_candidates(num_trials=1)
    live_scheduler.generate_candidates(num_trials=1, remove_stale_candidates=False)

    # THEN, after reloading from storage, both trials are still candidates
    reloaded_scheduler = Scheduler.from_stored_experiment(
        experiment_name=self.branin_experiment.name,
        options=scheduler_options,
        db_settings=self.db_settings,
    )
    stored_experiment = reloaded_scheduler.experiment
    self.assertEqual(len(stored_experiment.trials), 2)
    candidate_trials = stored_experiment.trials_by_status[TrialStatus.CANDIDATE]
    self.assertEqual(len(candidate_trials), 2)

def test_generate_candidates_does_not_fail_stale_candidates_if_fails_to_gen(
    self,
) -> None:
    """A generation attempt that yields no trials leaves the stored candidate.

    One candidate is generated normally. A second call, made with
    ``remove_stale_candidates=True``, runs while
    ``Scheduler._gen_new_trials_from_generation_strategy`` is patched to
    return an empty list. The experiment reloaded from storage is expected
    to contain exactly one trial, still in CANDIDATE status.
    """
    init_test_engine_and_session_factory(force_init=True)
    # GIVEN a scheduler whose generation strategy is built from
    # ``two_sobol_steps_GS``.
    generation_strategy = self._get_generation_strategy_strategy_for_test(
        experiment=self.branin_experiment,
        generation_strategy=self.two_sobol_steps_GS,
    )

    # This is a HITL experiment, so trials must not complete on their own;
    # an InfinitePollRunner is installed for that purpose.
    if not isinstance(self.branin_experiment, MultiTypeExperiment):
        self.branin_experiment.runner = InfinitePollRunner()
    else:
        self.branin_experiment.update_runner("type1", InfinitePollRunner())
    scheduler_options = SchedulerOptions(
        init_seconds_between_polls=0,  # Zero wait between polls keeps the test fast.
        batch_size=10,
        trial_type=TrialType.BATCH_TRIAL,
        **self.scheduler_options_kwargs,
    )
    live_scheduler = Scheduler(
        experiment=self.branin_experiment,
        generation_strategy=generation_strategy,
        options=scheduler_options,
        db_settings=self.db_settings,
    )

    # WHEN a candidate is generated and a second attempt (with stale-candidate
    # removal requested) produces no new trials
    live_scheduler.generate_candidates(num_trials=1)
    no_new_trials = patch.object(
        Scheduler, "_gen_new_trials_from_generation_strategy", return_value=[]
    )
    with no_new_trials:
        live_scheduler.generate_candidates(
            num_trials=1, remove_stale_candidates=True
        )

    # THEN, after reloading from storage, the original trial is the only one
    # and it is still a candidate
    reloaded_scheduler = Scheduler.from_stored_experiment(
        experiment_name=self.branin_experiment.name,
        options=scheduler_options,
        db_settings=self.db_settings,
    )
    stored_experiment = reloaded_scheduler.experiment
    self.assertEqual(len(stored_experiment.trials), 1)
    candidate_trials = stored_experiment.trials_by_status[TrialStatus.CANDIDATE]
    self.assertEqual(len(candidate_trials), 1)

def test_generate_candidates_works_with_status_quo(self) -> None:
# GIVEN a scheduler with an experiment that has a status quo
self.branin_experiment.status_quo = Arm(parameters={"x1": 0.0, "x2": 0.0})
Expand Down

0 comments on commit 3f29cfd

Please sign in to comment.