diff --git a/ax/service/scheduler.py b/ax/service/scheduler.py index 072c728a339..1128f7d1f2a 100644 --- a/ax/service/scheduler.py +++ b/ax/service/scheduler.py @@ -230,7 +230,6 @@ def __init__( self.markdown_messages = { "Generation strategy": GS_TYPE_MSG.format(gs_name=generation_strategy.name) } - self._timeout_hours = options.timeout_hours @classmethod def get_default_db_settings(cls) -> DBSettings: @@ -736,7 +735,6 @@ def run_trials_and_yield_results( raise UserInputError( f"Expected `timeout_hours` >= 0, got {timeout_hours}." ) - self._timeout_hours = timeout_hours self._latest_optimization_start_timestamp = current_timestamp_in_millis() self.__ignore_global_stopping_strategy = ignore_global_stopping_strategy @@ -755,7 +753,7 @@ def run_trials_and_yield_results( self._num_remaining_requested_trials > 0 and not self.should_consider_optimization_complete()[0] ): - if self.should_abort_optimization(): + if self.should_abort_optimization(timeout_hours=timeout_hours): yield self._abort_optimization(num_preexisting_trials=n_existing) return @@ -766,7 +764,8 @@ def run_trials_and_yield_results( self.candidate_trials ) while self._num_remaining_requested_trials > 0 and self.run( - max_new_trials=n_remaining_to_generate + max_new_trials=n_remaining_to_generate, + timeout_hours=timeout_hours, ): # Not checking `should_abort_optimization` on every trial for perf. # reasons. @@ -799,7 +798,7 @@ def run_trials_and_yield_results( ) while self.running_trials: - if self.should_abort_optimization(): + if self.should_abort_optimization(timeout_hours=timeout_hours): yield self._abort_optimization(num_preexisting_trials=n_existing) return report_results = self._check_exit_status_and_report_results( @@ -975,7 +974,7 @@ def should_consider_optimization_complete(self) -> tuple[bool, str]: self.logger.info(f"Completing the optimization: {completion_message}.") return should_complete, completion_message - def should_abort_optimization(self) -> bool: + def should_abort_optimization(self, timeout_hours: float | None = None) -> bool: """Checks whether this scheduler has reached some intertuption / abort criterion, such as an overall optimization timeout, tolerated failure rate, etc. """ @@ -985,15 +984,15 @@ def should_abort_optimization(self) -> bool: # if optimization is timed out, return True, else return False timed_out = ( - self._timeout_hours is not None + timeout_hours is not None and self._latest_optimization_start_timestamp is not None and current_timestamp_in_millis() - none_throws(self._latest_optimization_start_timestamp) - >= none_throws(self._timeout_hours) * 60 * 60 * 1000 + >= none_throws(timeout_hours) * 60 * 60 * 1000 ) if timed_out: self.logger.error( - "Optimization timed out (timeout hours: " f"{self._timeout_hours})!" + "Optimization timed out (timeout hours: " f"{timeout_hours})!" ) return timed_out @@ -1179,7 +1178,7 @@ def _check_exit_status_and_report_results( idle_callback, force_refit=True ) - def run(self, max_new_trials: int) -> bool: + def run(self, max_new_trials: int, timeout_hours: float | None = None) -> bool: """Schedules trial evaluation(s) if stopping criterion is not triggered, maximum parallelism is not currently reached, and capacity allows. Logs any failures / issues. @@ -1189,6 +1188,10 @@ def run(self, max_new_trials: int) -> bool: and run (useful when generating and running trials in batches). Note that this function might also re-deploy existing ``CANDIDATE`` trials that failed to deploy before, which will not count against this number. + timeout_hours: Maximum number of hours, for which + to run the optimization. This function will abort after running + for `timeout_hours` even if stopping criterion has not been reached. + If set to `None`, no optimization timeout will be applied. Returns: Boolean representing success status. @@ -1204,7 +1207,7 @@ def run(self, max_new_trials: int) -> bool: ) return False - if self.should_abort_optimization(): + if self.should_abort_optimization(timeout_hours=timeout_hours): self.logger.info( "`should_abort_optimization` is `True`, not running more trials." ) diff --git a/ax/service/tests/scheduler_test_utils.py b/ax/service/tests/scheduler_test_utils.py index 651d4798cb0..26ace9d2bde 100644 --- a/ax/service/tests/scheduler_test_utils.py +++ b/ax/service/tests/scheduler_test_utils.py @@ -320,7 +320,7 @@ class AxSchedulerTestCase(TestCase): "min_failed_trials_for_failure_rate_check=5, log_filepath=None, " "logging_level=20, ttl_seconds_for_trials=None, init_seconds_between_" "polls=10, min_seconds_before_poll=1.0, seconds_between_polls_backoff_" - "factor=1.5, timeout_hours=None, run_trials_in_batches=False, " + "factor=1.5, run_trials_in_batches=False, " "debug_log_run_metadata=False, early_stopping_strategy=None, " "global_stopping_strategy=None, suppress_storage_errors_after_" "retries=False, wait_for_running_trials=True, fetch_kwargs={}, " diff --git a/ax/service/tests/test_scheduler.py b/ax/service/tests/test_scheduler.py index 6951295997b..b96e0f186b3 100644 --- a/ax/service/tests/test_scheduler.py +++ b/ax/service/tests/test_scheduler.py @@ -59,7 +59,7 @@ class TestAxSchedulerMultiTypeExperiment(AxSchedulerTestCase): "min_failed_trials_for_failure_rate_check=5, log_filepath=None, " "logging_level=20, ttl_seconds_for_trials=None, init_seconds_between_" "polls=10, min_seconds_before_poll=1.0, seconds_between_polls_backoff_" - "factor=1.5, timeout_hours=None, run_trials_in_batches=False, " + "factor=1.5, run_trials_in_batches=False, " "debug_log_run_metadata=False, early_stopping_strategy=None, " "global_stopping_strategy=None, suppress_storage_errors_after_" "retries=False, wait_for_running_trials=True, fetch_kwargs={}, " diff --git a/ax/service/utils/scheduler_options.py b/ax/service/utils/scheduler_options.py index aca015fa8a3..14564a4f9b9 100644 --- a/ax/service/utils/scheduler_options.py +++ b/ax/service/utils/scheduler_options.py @@ -135,7 +135,6 @@ class SchedulerOptions: init_seconds_between_polls: int | None = 1 min_seconds_before_poll: float = 1.0 seconds_between_polls_backoff_factor: float = 1.5 - timeout_hours: float | None = None run_trials_in_batches: bool = False debug_log_run_metadata: bool = False early_stopping_strategy: BaseEarlyStoppingStrategy | None = None