diff --git a/ax/service/tests/scheduler_test_utils.py b/ax/service/tests/scheduler_test_utils.py index 924dcb04c71..e19f16bb10b 100644 --- a/ax/service/tests/scheduler_test_utils.py +++ b/ax/service/tests/scheduler_test_utils.py @@ -7,7 +7,7 @@ # pyre-strict import os -from datetime import timedelta +from datetime import datetime, timedelta from logging import WARNING from math import ceil from random import randint @@ -76,6 +76,7 @@ SpecialGenerationStrategy, ) from ax.utils.testing.mock import fast_botorch_optimize +from pyre_extensions import none_throws from sqlalchemy.orm.exc import StaleDataError @@ -243,6 +244,33 @@ def run_multiple(self, trials: Iterable[BaseTrial]) -> Dict[int, Dict[str, Any]] raise RuntimeError("Failing for testing purposes.") +class RunnerToAllowMultipleMapMetricFetches(SyntheticRunnerWithStatusPolling): + """``Runner`` that gives a trial 3 seconds to run before considering + the trial completed, which gives us some time to fetch the ``MapMetric`` + a few times, if there is one on the experiment. Useful for testing behavior + with repeated ``MapMetric`` fetches. + """ + + def poll_trial_status( + self, trials: Iterable[BaseTrial] + ) -> Dict[TrialStatus, Set[int]]: + running_trials = next(iter(trials)).experiment.trials_by_status[ + TrialStatus.RUNNING + ] + completed, still_running = set(), set() + for t in running_trials: + # pyre-ignore[58]: Operand is actually supported between these + if datetime.now() - t.time_run_started > timedelta(seconds=3): + completed.add(t.index) + else: + still_running.add(t.index) + + return { + TrialStatus.COMPLETED: completed, + TrialStatus.RUNNING: still_running, + } + + class AxSchedulerTestCase(TestCase): """Tests base `Scheduler` functionality. This test case is meant to test Scheduler using `GenerationStrategy`, but be extensible so @@ -281,6 +309,9 @@ def setUp(self) -> None: self.branin_timestamp_map_metric_experiment = ( get_branin_experiment_with_timestamp_map_metric() ) + self.branin_timestamp_map_metric_experiment.runner = ( + RunnerToAllowMultipleMapMetricFetches() + ) self.runner = SyntheticRunnerWithStatusPolling() self.branin_experiment.runner = self.runner @@ -334,6 +365,7 @@ def runner_registry(self) -> Dict[Type[Runner], int]: BrokenRunnerRuntimeError: 2006, SyntheticRunnerWithSingleRunningTrial: 2007, SyntheticRunnerWithPredictableStatusPolling: 2008, + RunnerToAllowMultipleMapMetricFetches: 2009, **CORE_RUNNER_REGISTRY, } @@ -1061,6 +1093,67 @@ def test_sqa_storage_without_experiment_name(self) -> None: db_settings=self.db_settings, ) + def test_sqa_storage_map_metric_experiment(self) -> None: + init_test_engine_and_session_factory(force_init=True) + gs = self._get_generation_strategy_strategy_for_test( + experiment=self.branin_timestamp_map_metric_experiment, + generation_strategy=self.two_sobol_steps_GS, + ) + self.assertIsNotNone(self.branin_timestamp_map_metric_experiment) + NUM_TRIALS = 5 + scheduler = Scheduler( + experiment=self.branin_timestamp_map_metric_experiment, + generation_strategy=gs, + options=SchedulerOptions( + total_trials=NUM_TRIALS, + init_seconds_between_polls=0, # No wait between polls so test is fast. + ), + db_settings=self.db_settings, + ) + with patch.object( + scheduler.experiment, + "attach_data", + Mock(wraps=scheduler.experiment.attach_data), + ) as mock_experiment_attach_data: + # Artificial timestamp logic so we can later check that it's the + # last-timestamp data that was preserved after multiple `attach_ + # data` calls. + with patch( + f"{Experiment.__module__}.current_timestamp_in_millis", + side_effect=lambda: len( + scheduler.experiment.trials_by_status[TrialStatus.COMPLETED] + ) + * 1000 + + mock_experiment_attach_data.call_count, + ): + scheduler.run_all_trials() + # Check that experiment and GS were saved and test reloading with reduced state. + exp, loaded_gs = scheduler._load_experiment_and_generation_strategy( + self.branin_timestamp_map_metric_experiment.name, reduced_state=True + ) + exp = none_throws(exp) + self.assertEqual(len(exp.trials), NUM_TRIALS) + + # There should only be one data object for each trial, since by default the + # `Scheduler` should override previous data objects when it gets new ones in + # a subsequent `fetch` call. + for _, datas in exp.data_by_trial.items(): + self.assertEqual(len(datas), 1) + + # We also should have attempted the fetch more times + # than there are trials because we have a `MapMetric` (many more since we are + # waiting 3 seconds for each trial). + self.assertGreater(mock_experiment_attach_data.call_count, NUM_TRIALS * 2) + + # Check that it's the last-attached data that was kept, using + # expected value based on logic in mocked "current_timestamp_in_millis" + num_attach_calls = mock_experiment_attach_data.call_count + expected_ts_last_trial = len(exp.trials) * 1000 + num_attach_calls + self.assertEqual( + next(iter(exp.data_by_trial[len(exp.trials) - 1])), + expected_ts_last_trial, + ) + def test_sqa_storage_with_experiment_name(self) -> None: init_test_engine_and_session_factory(force_init=True) gs = self._get_generation_strategy_strategy_for_test( @@ -1083,10 +1176,11 @@ def test_sqa_storage_with_experiment_name(self) -> None: self.branin_experiment.name ) self.assertEqual(exp, self.branin_experiment) + exp = none_throws(exp) self.assertEqual( # pyre-fixme[16]: Add `_generator_runs` back to GSI interface or move # interface to node-level from strategy-level (the latter is likely the - # better option) + # better option) TODO len(gs._generator_runs), len(not_none(loaded_gs)._generator_runs), ) @@ -1095,7 +1189,7 @@ def test_sqa_storage_with_experiment_name(self) -> None: exp, loaded_gs = scheduler._load_experiment_and_generation_strategy( self.branin_experiment.name, reduced_state=True ) - # pyre-fixme[16]: `Optional` has no attribute `trials`. + exp = none_throws(exp) self.assertEqual(len(exp.trials), NUM_TRIALS) # Because of RGS, gs has queued additional unused candidates self.assertGreaterEqual(len(gs._generator_runs), NUM_TRIALS) @@ -1109,7 +1203,6 @@ def test_sqa_storage_with_experiment_name(self) -> None: ) # Hack "resumed from storage timestamp" into `exp` to make sure all other fields # are equal, since difference in resumed from storage timestamps is expected. - # pyre-fixme[16]: `Optional` has no attribute `_properties`. exp._properties[ExperimentStatusProperties.RESUMED_FROM_STORAGE_TIMESTAMPS] = ( new_scheduler.experiment._properties[ ExperimentStatusProperties.RESUMED_FROM_STORAGE_TIMESTAMPS diff --git a/ax/storage/sqa_store/save.py b/ax/storage/sqa_store/save.py index 6f598e5283f..dd8798497a2 100644 --- a/ax/storage/sqa_store/save.py +++ b/ax/storage/sqa_store/save.py @@ -9,9 +9,10 @@ import os from logging import Logger -from typing import Any, Callable, Dict, List, Optional, Sequence, Union +from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Type, Union from ax.core.base_trial import BaseTrial +from ax.core.data import Data from ax.core.experiment import Experiment from ax.core.generator_run import GeneratorRun from ax.core.metric import Metric @@ -215,21 +216,19 @@ def _save_or_update_trials( will also be added to the experiment, but existing data objects in the database will *not* be updated or removed. """ - experiment_id = experiment._db_id - if experiment_id is None: - raise ValueError("Must save experiment first.") + if experiment._db_id is None: + raise ValueError("Must save experiment before saving/updating its trials.") - # pyre-fixme[53]: Captured variable `experiment_id` is not annotated. - # pyre-fixme[3]: Return type must be annotated. - def add_experiment_id(sqa: Union[SQATrial, SQAData]): + experiment_id: int = experiment._db_id + + def add_experiment_id(sqa: Union[SQATrial, SQAData]) -> None: sqa.experiment_id = experiment_id if reduce_state_generator_runs: latest_trial = trials[-1] trials_to_reduce_state = trials[0:-1] - # pyre-fixme[3]: Return type must be annotated. - def trial_to_reduced_state_sqa_encoder(t: BaseTrial): + def trial_to_reduced_state_sqa_encoder(t: BaseTrial) -> SQATrial: return encoder.trial_to_sqa(t, generator_run_reduced_state=True) _bulk_merge_into_session( @@ -259,16 +258,32 @@ def trial_to_reduced_state_sqa_encoder(t: BaseTrial): batch_size=batch_size, ) - datas = [] - data_encode_args = [] + datas, data_encode_args, datas_to_keep, trial_idcs = [], [], [], [] + data_sqa_class: Type[SQAData] = cast( + Type[SQAData], encoder.config.class_to_sqa_class[Data] + ) for trial in trials: + trial_idcs.append(trial.index) trial_datas = experiment.data_by_trial.get(trial.index, {}) for ts, data in trial_datas.items(): if data.db_id is None: - # Only need to worry about new data, since it's not really possible - # or supported to modify or remove existing data. + # This is data we have not saved before; we should add it to the + # database. Previously saved data for this experiment can be removed. datas.append(data) data_encode_args.append({"trial_index": trial.index, "timestamp": ts}) + else: + datas_to_keep.append(data.db_id) + + # For trials, for which we saved new data, we can first remove previously + # saved data if it's no longer on the experiment. + with session_scope() as session: + session.query(data_sqa_class).filter_by(experiment_id=experiment_id).filter( + data_sqa_class.trial_index.isnot(None) # pyre-ignore[16] + ).filter( + data_sqa_class.trial_index.in_(trial_idcs) # pyre-ignore[16] + ).filter( + data_sqa_class.id.not_in(datas_to_keep) # pyre-ignore[16] + ).delete() _bulk_merge_into_session( objs=datas,