Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix issue with SQA storage never removing Data objects + upgrade testing for scheduler with MapData intermediate results #2276

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 97 additions & 4 deletions ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# pyre-strict

import os
from datetime import timedelta
from datetime import datetime, timedelta
from logging import WARNING
from math import ceil
from random import randint
Expand Down Expand Up @@ -76,6 +76,7 @@
SpecialGenerationStrategy,
)
from ax.utils.testing.mock import fast_botorch_optimize
from pyre_extensions import none_throws

from sqlalchemy.orm.exc import StaleDataError

Expand Down Expand Up @@ -243,6 +244,33 @@ def run_multiple(self, trials: Iterable[BaseTrial]) -> Dict[int, Dict[str, Any]]
raise RuntimeError("Failing for testing purposes.")


class RunnerToAllowMultipleMapMetricFetches(SyntheticRunnerWithStatusPolling):
    """``Runner`` that gives a trial 3 seconds to run before considering
    the trial completed, which gives us some time to fetch the ``MapMetric``
    a few times, if there is one on the experiment. Useful for testing behavior
    with repeated ``MapMetric`` fetches.
    """

    def poll_trial_status(
        self, trials: Iterable[BaseTrial]
    ) -> Dict[TrialStatus, Set[int]]:
        """Poll the given trials, completing those that have been running
        for more than 3 seconds.

        Args:
            trials: Trials on the experiment to poll; only used to reach the
                experiment, whose RUNNING trials are actually inspected.

        Returns:
            Mapping from ``TrialStatus`` to the set of trial indices with
            that status (only COMPLETED and RUNNING keys are populated).
        """
        trials = list(trials)
        if not trials:
            # Guard: `next(iter(trials))` below would raise `StopIteration`
            # on an empty iterable; report no status changes instead.
            return {TrialStatus.COMPLETED: set(), TrialStatus.RUNNING: set()}
        running_trials = next(iter(trials)).experiment.trials_by_status[
            TrialStatus.RUNNING
        ]
        completed, still_running = set(), set()
        for t in running_trials:
            # `time_run_started` is set once a trial is RUNNING, so the
            # subtraction is safe here despite the Optional annotation.
            # pyre-ignore[58]: Operand is actually supported between these
            if datetime.now() - t.time_run_started > timedelta(seconds=3):
                completed.add(t.index)
            else:
                still_running.add(t.index)

        return {
            TrialStatus.COMPLETED: completed,
            TrialStatus.RUNNING: still_running,
        }


class AxSchedulerTestCase(TestCase):
"""Tests base `Scheduler` functionality. This test case is meant to
test Scheduler using `GenerationStrategy`, but be extensible so
Expand Down Expand Up @@ -281,6 +309,9 @@ def setUp(self) -> None:
self.branin_timestamp_map_metric_experiment = (
get_branin_experiment_with_timestamp_map_metric()
)
self.branin_timestamp_map_metric_experiment.runner = (
RunnerToAllowMultipleMapMetricFetches()
)

self.runner = SyntheticRunnerWithStatusPolling()
self.branin_experiment.runner = self.runner
Expand Down Expand Up @@ -334,6 +365,7 @@ def runner_registry(self) -> Dict[Type[Runner], int]:
BrokenRunnerRuntimeError: 2006,
SyntheticRunnerWithSingleRunningTrial: 2007,
SyntheticRunnerWithPredictableStatusPolling: 2008,
RunnerToAllowMultipleMapMetricFetches: 2009,
**CORE_RUNNER_REGISTRY,
}

Expand Down Expand Up @@ -1061,6 +1093,67 @@ def test_sqa_storage_without_experiment_name(self) -> None:
db_settings=self.db_settings,
)

def test_sqa_storage_map_metric_experiment(self) -> None:
    """Check that for a ``MapMetric`` experiment, SQA storage keeps only
    the latest-attached data object per trial after repeated fetches."""
    init_test_engine_and_session_factory(force_init=True)
    map_metric_experiment = self.branin_timestamp_map_metric_experiment
    gs = self._get_generation_strategy_strategy_for_test(
        experiment=map_metric_experiment,
        generation_strategy=self.two_sobol_steps_GS,
    )
    self.assertIsNotNone(map_metric_experiment)
    num_trials = 5
    scheduler = Scheduler(
        experiment=map_metric_experiment,
        generation_strategy=gs,
        options=SchedulerOptions(
            total_trials=num_trials,
            # No wait between polls so test is fast.
            init_seconds_between_polls=0,
        ),
        db_settings=self.db_settings,
    )
    attach_data_mock = Mock(wraps=scheduler.experiment.attach_data)
    with patch.object(scheduler.experiment, "attach_data", attach_data_mock):
        # Artificial timestamps: encode the number of completed trials and
        # the number of `attach_data` calls so far, so that we can later
        # verify it is the last-attached data that was preserved.
        with patch(
            f"{Experiment.__module__}.current_timestamp_in_millis",
            side_effect=lambda: (
                len(scheduler.experiment.trials_by_status[TrialStatus.COMPLETED])
                * 1000
                + attach_data_mock.call_count
            ),
        ):
            scheduler.run_all_trials()
    # Experiment and GS should have been saved; reload with reduced state.
    exp, loaded_gs = scheduler._load_experiment_and_generation_strategy(
        map_metric_experiment.name, reduced_state=True
    )
    exp = none_throws(exp)
    self.assertEqual(len(exp.trials), num_trials)

    # Each trial should carry exactly one data object, since by default the
    # `Scheduler` overrides previously stored data when a subsequent `fetch`
    # call produces new data for the same trial.
    for trial_datas in exp.data_by_trial.values():
        self.assertEqual(len(trial_datas), 1)

    # With a `MapMetric` present, fetches happen more often than once per
    # trial (many more, since each trial runs for ~3 seconds).
    self.assertGreater(attach_data_mock.call_count, num_trials * 2)

    # The surviving data should be the last one attached; its timestamp is
    # predictable from the mocked `current_timestamp_in_millis` above.
    total_attach_calls = attach_data_mock.call_count
    last_trial_expected_ts = len(exp.trials) * 1000 + total_attach_calls
    self.assertEqual(
        next(iter(exp.data_by_trial[len(exp.trials) - 1])),
        last_trial_expected_ts,
    )

def test_sqa_storage_with_experiment_name(self) -> None:
init_test_engine_and_session_factory(force_init=True)
gs = self._get_generation_strategy_strategy_for_test(
Expand All @@ -1083,10 +1176,11 @@ def test_sqa_storage_with_experiment_name(self) -> None:
self.branin_experiment.name
)
self.assertEqual(exp, self.branin_experiment)
exp = none_throws(exp)
self.assertEqual(
# pyre-fixme[16]: Add `_generator_runs` back to GSI interface or move
# interface to node-level from strategy-level (the latter is likely the
# better option)
# better option) TODO
len(gs._generator_runs),
len(not_none(loaded_gs)._generator_runs),
)
Expand All @@ -1095,7 +1189,7 @@ def test_sqa_storage_with_experiment_name(self) -> None:
exp, loaded_gs = scheduler._load_experiment_and_generation_strategy(
self.branin_experiment.name, reduced_state=True
)
# pyre-fixme[16]: `Optional` has no attribute `trials`.
exp = none_throws(exp)
self.assertEqual(len(exp.trials), NUM_TRIALS)
# Because of RGS, gs has queued additional unused candidates
self.assertGreaterEqual(len(gs._generator_runs), NUM_TRIALS)
Expand All @@ -1109,7 +1203,6 @@ def test_sqa_storage_with_experiment_name(self) -> None:
)
# Hack "resumed from storage timestamp" into `exp` to make sure all other fields
# are equal, since difference in resumed from storage timestamps is expected.
# pyre-fixme[16]: `Optional` has no attribute `_properties`.
exp._properties[ExperimentStatusProperties.RESUMED_FROM_STORAGE_TIMESTAMPS] = (
new_scheduler.experiment._properties[
ExperimentStatusProperties.RESUMED_FROM_STORAGE_TIMESTAMPS
Expand Down
41 changes: 28 additions & 13 deletions ax/storage/sqa_store/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
import os

from logging import Logger
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Type, Union

from ax.core.base_trial import BaseTrial
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.metric import Metric
Expand Down Expand Up @@ -215,21 +216,19 @@ def _save_or_update_trials(
will also be added to the experiment, but existing data objects in the
database will *not* be updated or removed.
"""
experiment_id = experiment._db_id
if experiment_id is None:
raise ValueError("Must save experiment first.")
if experiment._db_id is None:
raise ValueError("Must save experiment before saving/updating its trials.")

# pyre-fixme[53]: Captured variable `experiment_id` is not annotated.
# pyre-fixme[3]: Return type must be annotated.
def add_experiment_id(sqa: Union[SQATrial, SQAData]):
experiment_id: int = experiment._db_id

def add_experiment_id(sqa: Union[SQATrial, SQAData]) -> None:
sqa.experiment_id = experiment_id

if reduce_state_generator_runs:
latest_trial = trials[-1]
trials_to_reduce_state = trials[0:-1]

# pyre-fixme[3]: Return type must be annotated.
def trial_to_reduced_state_sqa_encoder(t: BaseTrial):
def trial_to_reduced_state_sqa_encoder(t: BaseTrial) -> SQATrial:
return encoder.trial_to_sqa(t, generator_run_reduced_state=True)

_bulk_merge_into_session(
Expand Down Expand Up @@ -259,16 +258,32 @@ def trial_to_reduced_state_sqa_encoder(t: BaseTrial):
batch_size=batch_size,
)

datas = []
data_encode_args = []
datas, data_encode_args, datas_to_keep, trial_idcs = [], [], [], []
data_sqa_class: Type[SQAData] = cast(
Type[SQAData], encoder.config.class_to_sqa_class[Data]
)
for trial in trials:
trial_idcs.append(trial.index)
trial_datas = experiment.data_by_trial.get(trial.index, {})
for ts, data in trial_datas.items():
if data.db_id is None:
# Only need to worry about new data, since it's not really possible
# or supported to modify or remove existing data.
# This is data we have not saved before; we should add it to the
# database. Previously saved data for this experiment can be removed.
datas.append(data)
data_encode_args.append({"trial_index": trial.index, "timestamp": ts})
else:
datas_to_keep.append(data.db_id)

# For trials, for which we saved new data, we can first remove previously
# saved data if it's no longer on the experiment.
with session_scope() as session:
session.query(data_sqa_class).filter_by(experiment_id=experiment_id).filter(
data_sqa_class.trial_index.isnot(None) # pyre-ignore[16]
).filter(
data_sqa_class.trial_index.in_(trial_idcs) # pyre-ignore[16]
).filter(
data_sqa_class.id.not_in(datas_to_keep) # pyre-ignore[16]
).delete()

_bulk_merge_into_session(
objs=datas,
Expand Down
Loading