diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py index 58fa30b4d35..39a0768b168 100644 --- a/ax/service/scheduler.py +++ b/ax/service/scheduler.py @@ -492,6 +492,7 @@ def generate_candidates( self, num_trials: int = 1, reduce_state_generator_runs: bool = False, + remove_stale_candidates: bool = False, ) -> tuple[list[BaseTrial], Exception | None]: """Fetch the latest data and generate new candidate trials. @@ -501,10 +502,22 @@ def generate_candidates( whether to save model state for every generator run (default) or to only save model state on the final generator run of each batch. + remove_stale_candidates: If true, mark any existing candidate trials + failed before trial generation because: + - they should not be treated as pending points + - they will no longer be relevant Returns: List of trials, empty if generation is not possible. """ + if remove_stale_candidates: + stale_candidate_trials = self.experiment.trials_by_status[ + TrialStatus.CANDIDATE + ] + for trial in stale_candidate_trials: + trial.mark_failed(reason="Newer candidates generated.", unsafe=True) + else: + stale_candidate_trials = [] new_trials, err = self._get_next_trials( num_trials=num_trials, n=self.options.batch_size, @@ -513,7 +526,7 @@ def generate_candidates( new_generator_runs = [gr for t in new_trials for gr in t.generator_runs] self._save_or_update_trials_and_generation_strategy_if_possible( experiment=self.experiment, - trials=new_trials, + trials=new_trials + stale_candidate_trials, generation_strategy=self.generation_strategy, new_generator_runs=new_generator_runs, reduce_state_generator_runs=reduce_state_generator_runs, diff --git a/ax/service/tests/scheduler_test_utils.py b/ax/service/tests/scheduler_test_utils.py index e603d94c500..dce5c7da894 100644 --- a/ax/service/tests/scheduler_test_utils.py +++ b/ax/service/tests/scheduler_test_utils.py @@ -2460,6 +2460,146 @@ def test_generate_candidates_works_for_sobol(self) -> None: options.batch_size, ) + def 
test_generate_candidates_can_remove_stale_candidates(self) -> None: + init_test_engine_and_session_factory(force_init=True) + # GIVEN a scheduler using a GS with MBM. + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) + + # this is a HITL experiment, so we don't want trials completing on their own. + if isinstance(self.branin_experiment, MultiTypeExperiment): + self.branin_experiment.update_runner("type1", InfinitePollRunner()) + else: + self.branin_experiment.runner = InfinitePollRunner() + options = SchedulerOptions( + init_seconds_between_polls=0, # No wait bw polls so test is fast. + batch_size=10, + trial_type=TrialType.BATCH_TRIAL, + **self.scheduler_options_kwargs, + ) + scheduler = Scheduler( + experiment=self.branin_experiment, + generation_strategy=gs, + options=options, + db_settings=self.db_settings, + ) + + # WHEN generating candidates on a new experiment twice + scheduler.generate_candidates(num_trials=1) + scheduler.generate_candidates(num_trials=1, remove_stale_candidates=True) + + # THEN the first candidate should be failed + scheduler = Scheduler.from_stored_experiment( + experiment_name=self.branin_experiment.name, + options=options, + db_settings=self.db_settings, + ) + self.assertEqual(len(scheduler.experiment.trials), 2) + self.assertEqual( + scheduler.experiment.trials[0].status, + TrialStatus.FAILED, + ) + self.assertEqual( + scheduler.experiment.trials[0].failed_reason, "Newer candidates generated." + ) + self.assertEqual( + scheduler.experiment.trials[1].status, + TrialStatus.CANDIDATE, + ) + + def test_generate_candidates_can_choose_not_to_remove_stale_candidates( + self, + ) -> None: + init_test_engine_and_session_factory(force_init=True) + # GIVEN a scheduler using a GS with MBM. 
+ gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) + + # this is a HITL experiment, so we don't want trials completing on their own. + if isinstance(self.branin_experiment, MultiTypeExperiment): + self.branin_experiment.update_runner("type1", InfinitePollRunner()) + else: + self.branin_experiment.runner = InfinitePollRunner() + options = SchedulerOptions( + init_seconds_between_polls=0, # No wait bw polls so test is fast. + batch_size=10, + trial_type=TrialType.BATCH_TRIAL, + **self.scheduler_options_kwargs, + ) + scheduler = Scheduler( + experiment=self.branin_experiment, + generation_strategy=gs, + options=options, + db_settings=self.db_settings, + ) + + # WHEN generating candidates on a new experiment twice + scheduler.generate_candidates(num_trials=1) + scheduler.generate_candidates(num_trials=1, remove_stale_candidates=False) + + # THEN both candidates should still be candidates + scheduler = Scheduler.from_stored_experiment( + experiment_name=self.branin_experiment.name, + options=options, + db_settings=self.db_settings, + ) + self.assertEqual(len(scheduler.experiment.trials), 2) + self.assertEqual( + len(scheduler.experiment.trials_by_status[TrialStatus.CANDIDATE]), + 2, + ) + + def test_generate_candidates_does_not_fail_stale_candidates_if_fails_to_gen( + self, + ) -> None: + init_test_engine_and_session_factory(force_init=True) + # GIVEN a scheduler using a GS with MBM. + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) + + # this is a HITL experiment, so we don't want trials completing on their own. 
+ if isinstance(self.branin_experiment, MultiTypeExperiment): + self.branin_experiment.update_runner("type1", InfinitePollRunner()) + else: + self.branin_experiment.runner = InfinitePollRunner() + options = SchedulerOptions( + init_seconds_between_polls=0, # No wait bw polls so test is fast. + batch_size=10, + trial_type=TrialType.BATCH_TRIAL, + **self.scheduler_options_kwargs, + ) + scheduler = Scheduler( + experiment=self.branin_experiment, + generation_strategy=gs, + options=options, + db_settings=self.db_settings, + ) + + # WHEN generating candidates on a new experiment twice + scheduler.generate_candidates(num_trials=1) + with patch.object( + Scheduler, "_gen_new_trials_from_generation_strategy", return_value=[] + ): + scheduler.generate_candidates(num_trials=1, remove_stale_candidates=True) + + # THEN the first candidate should not be failed + scheduler = Scheduler.from_stored_experiment( + experiment_name=self.branin_experiment.name, + options=options, + db_settings=self.db_settings, + ) + self.assertEqual(len(scheduler.experiment.trials), 1) + self.assertEqual( + len(scheduler.experiment.trials_by_status[TrialStatus.CANDIDATE]), + 1, + ) + def test_generate_candidates_works_with_status_quo(self) -> None: # GIVEN a scheduler with an experiment that has a status quo self.branin_experiment.status_quo = Arm(parameters={"x1": 0.0, "x2": 0.0})