Skip to content

Commit

Permalink
Fix issue with SQA storage never removing Data objects + upgrade te…
Browse files Browse the repository at this point in the history
…sting for scheduler with `MapData` intermediate results

Summary: As titled

Differential Revision: D54879641
  • Loading branch information
Lena Kashtelyan authored and facebook-github-bot committed Mar 14, 2024
1 parent cbb17c3 commit c2cb684
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 17 deletions.
102 changes: 98 additions & 4 deletions ax/service/tests/scheduler_test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
# pyre-strict

import os
from datetime import timedelta
from datetime import datetime, timedelta
from logging import WARNING
from math import ceil
from random import randint
Expand Down Expand Up @@ -76,6 +76,7 @@
SpecialGenerationStrategy,
)
from ax.utils.testing.mock import fast_botorch_optimize
from pyre_extensions import none_throws

from sqlalchemy.orm.exc import StaleDataError

Expand Down Expand Up @@ -243,6 +244,33 @@ def run_multiple(self, trials: Iterable[BaseTrial]) -> Dict[int, Dict[str, Any]]
raise RuntimeError("Failing for testing purposes.")


class RunnerToAllowMultipleMapMetricFetches(SyntheticRunnerWithStatusPolling):
    """``Runner`` that gives a trial 3 seconds to run before considering
    the trial completed, which gives us some time to fetch the ``MapMetric``
    a few times, if there is one on the experiment. Useful for testing behavior
    with repeated ``MapMetric`` fetches.
    """

    def poll_trial_status(
        self, trials: Iterable[BaseTrial]
    ) -> Dict[TrialStatus, Set[int]]:
        """Report which running trials have been running for over 3 seconds.

        The passed-in trials are only used to locate the experiment; the
        statuses are computed for *all* trials that are currently ``RUNNING``
        on that experiment.

        Args:
            trials: Trials to poll. May be empty, in which case both returned
                sets are empty.

        Returns:
            A dict with exactly two keys: ``TrialStatus.COMPLETED`` (indices of
            trials that started more than 3 seconds ago) and
            ``TrialStatus.RUNNING`` (indices of all other running trials).
        """
        completed: Set[int] = set()
        still_running: Set[int] = set()

        # An empty iterable previously raised ``StopIteration`` here; with
        # nothing to poll there is also no experiment to look up.
        first_trial = next(iter(trials), None)
        if first_trial is not None:
            run_duration = timedelta(seconds=3)
            # Read the clock once so all trials are compared to the same instant.
            now = datetime.now()
            running_trials = first_trial.experiment.trials_by_status[
                TrialStatus.RUNNING
            ]
            for t in running_trials:
                # pyre-ignore[58]: Operand is actually supported between these
                if now - t.time_run_started > run_duration:
                    completed.add(t.index)
                else:
                    still_running.add(t.index)

        return {
            TrialStatus.COMPLETED: completed,
            TrialStatus.RUNNING: still_running,
        }


class AxSchedulerTestCase(TestCase):
"""Tests base `Scheduler` functionality. This test case is meant to
test Scheduler using `GenerationStrategy`, but be extensible so
Expand Down Expand Up @@ -281,6 +309,9 @@ def setUp(self) -> None:
self.branin_timestamp_map_metric_experiment = (
get_branin_experiment_with_timestamp_map_metric()
)
self.branin_timestamp_map_metric_experiment.runner = (
RunnerToAllowMultipleMapMetricFetches()
)

self.runner = SyntheticRunnerWithStatusPolling()
self.branin_experiment.runner = self.runner
Expand Down Expand Up @@ -334,6 +365,7 @@ def runner_registry(self) -> Dict[Type[Runner], int]:
BrokenRunnerRuntimeError: 2006,
SyntheticRunnerWithSingleRunningTrial: 2007,
SyntheticRunnerWithPredictableStatusPolling: 2008,
RunnerToAllowMultipleMapMetricFetches: 2009,
**CORE_RUNNER_REGISTRY,
}

Expand Down Expand Up @@ -1061,6 +1093,67 @@ def test_sqa_storage_without_experiment_name(self) -> None:
db_settings=self.db_settings,
)

def test_sqa_storage_map_metric_experiment(self) -> None:
    """Run a ``MapMetric`` experiment through the scheduler with SQA storage
    and check that only the most recently attached data is kept per trial.
    """
    init_test_engine_and_session_factory(force_init=True)
    gs = self._get_generation_strategy_strategy_for_test(
        experiment=self.branin_timestamp_map_metric_experiment,
        generation_strategy=self.two_sobol_steps_GS,
    )
    self.assertIsNotNone(self.branin_timestamp_map_metric_experiment)
    num_trials = 5
    scheduler_options = SchedulerOptions(
        total_trials=num_trials,
        init_seconds_between_polls=0,  # No wait between polls so test is fast.
    )
    scheduler = Scheduler(
        experiment=self.branin_timestamp_map_metric_experiment,
        generation_strategy=gs,
        options=scheduler_options,
        db_settings=self.db_settings,
    )
    # Spy on `attach_data` so that calls still reach the real method while
    # being counted.
    attach_data_spy = Mock(wraps=scheduler.experiment.attach_data)

    def fake_timestamp_in_millis() -> int:
        # Artificial timestamp logic so we can later check that it's the
        # last-timestamp data that was preserved after multiple
        # `attach_data` calls.
        num_completed = len(
            scheduler.experiment.trials_by_status[TrialStatus.COMPLETED]
        )
        return num_completed * 1000 + attach_data_spy.call_count

    with patch.object(
        scheduler.experiment, "attach_data", attach_data_spy
    ), patch(
        f"{Experiment.__module__}.current_timestamp_in_millis",
        side_effect=fake_timestamp_in_millis,
    ):
        scheduler.run_all_trials()

    # Check that experiment and GS were saved and test reloading with reduced
    # state.
    loaded_exp, _ = scheduler._load_experiment_and_generation_strategy(
        self.branin_timestamp_map_metric_experiment.name, reduced_state=True
    )
    exp = none_throws(loaded_exp)
    self.assertEqual(len(exp.trials), num_trials)

    # Each trial should have exactly one data object: by default the
    # `Scheduler` overrides previously attached data when a subsequent
    # `fetch` call brings in new data.
    for trial_datas in exp.data_by_trial.values():
        self.assertEqual(len(trial_datas), 1)

    # With a `MapMetric` the fetch is attempted more often than there are
    # trials (many more, since each trial is given 3 seconds to run).
    total_attach_calls = attach_data_spy.call_count
    self.assertGreater(total_attach_calls, num_trials * 2)

    # The data that survived must be the last one attached; its expected
    # timestamp follows from the logic in `fake_timestamp_in_millis`.
    expected_ts_last_trial = len(exp.trials) * 1000 + total_attach_calls
    last_trial_datas = exp.data_by_trial[len(exp.trials) - 1]
    stored_ts = next(iter(last_trial_datas))
    self.assertEqual(stored_ts, expected_ts_last_trial)

def test_sqa_storage_with_experiment_name(self) -> None:
init_test_engine_and_session_factory(force_init=True)
gs = self._get_generation_strategy_strategy_for_test(
Expand All @@ -1083,10 +1176,12 @@ def test_sqa_storage_with_experiment_name(self) -> None:
self.branin_experiment.name
)
self.assertEqual(exp, self.branin_experiment)
exp = none_throws(exp)
self.assertEqual(exp, self.branin_timestamp_map_metric_experiment)
self.assertEqual(
# pyre-fixme[16]: Add `_generator_runs` back to GSI interface or move
# interface to node-level from strategy-level (the latter is likely the
# better option)
# better option) TODO
len(gs._generator_runs),
len(not_none(loaded_gs)._generator_runs),
)
Expand All @@ -1095,7 +1190,7 @@ def test_sqa_storage_with_experiment_name(self) -> None:
exp, loaded_gs = scheduler._load_experiment_and_generation_strategy(
self.branin_experiment.name, reduced_state=True
)
# pyre-fixme[16]: `Optional` has no attribute `trials`.
exp = none_throws(exp)
self.assertEqual(len(exp.trials), NUM_TRIALS)
# Because of RGS, gs has queued additional unused candidates
self.assertGreaterEqual(len(gs._generator_runs), NUM_TRIALS)
Expand All @@ -1109,7 +1204,6 @@ def test_sqa_storage_with_experiment_name(self) -> None:
)
# Hack "resumed from storage timestamp" into `exp` to make sure all other fields
# are equal, since difference in resumed from storage timestamps is expected.
# pyre-fixme[16]: `Optional` has no attribute `_properties`.
exp._properties[ExperimentStatusProperties.RESUMED_FROM_STORAGE_TIMESTAMPS] = (
new_scheduler.experiment._properties[
ExperimentStatusProperties.RESUMED_FROM_STORAGE_TIMESTAMPS
Expand Down
45 changes: 32 additions & 13 deletions ax/storage/sqa_store/save.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@
import os

from logging import Logger
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Type, Union

from ax.core.base_trial import BaseTrial
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.generator_run import GeneratorRun
from ax.core.metric import Metric
Expand Down Expand Up @@ -215,21 +216,19 @@ def _save_or_update_trials(
will also be added to the experiment. Existing data objects in the database
will *not* be updated, but previously saved data objects for the given trials
that are no longer on the experiment will be removed.
"""
experiment_id = experiment._db_id
if experiment_id is None:
raise ValueError("Must save experiment first.")
if experiment._db_id is None:
raise ValueError("Must save experiment before saving/updating its trials.")

# pyre-fixme[53]: Captured variable `experiment_id` is not annotated.
# pyre-fixme[3]: Return type must be annotated.
def add_experiment_id(sqa: Union[SQATrial, SQAData]):
experiment_id: int = experiment._db_id

def add_experiment_id(sqa: Union[SQATrial, SQAData]) -> None:
sqa.experiment_id = experiment_id

if reduce_state_generator_runs:
latest_trial = trials[-1]
trials_to_reduce_state = trials[0:-1]

# pyre-fixme[3]: Return type must be annotated.
def trial_to_reduced_state_sqa_encoder(t: BaseTrial):
def trial_to_reduced_state_sqa_encoder(t: BaseTrial) -> SQATrial:
return encoder.trial_to_sqa(t, generator_run_reduced_state=True)

_bulk_merge_into_session(
Expand Down Expand Up @@ -259,16 +258,36 @@ def trial_to_reduced_state_sqa_encoder(t: BaseTrial):
batch_size=batch_size,
)

datas = []
data_encode_args = []
datas, data_encode_args, datas_to_keep, trial_idcs = [], [], [], []
data_sqa_class: Type[SQAData] = cast(
Type[SQAData], encoder.config.class_to_sqa_class[Data]
)
for trial in trials:
trial_idcs.append(trial.index)
trial_datas = experiment.data_by_trial.get(trial.index, {})
for ts, data in trial_datas.items():
if data.db_id is None:
# Only need to worry about new data, since it's not really possible
# or supported to modify or remove existing data.
# This is data we have not saved before; we should add it to the
# database. Previously saved data for this experiment can be removed.
datas.append(data)
data_encode_args.append({"trial_index": trial.index, "timestamp": ts})
else:
datas_to_keep.append(data.db_id)

# For the trials being saved, first remove any previously saved data that is
# no longer on the experiment.
# TODO: Consider doing this only for trials for which we saved new data
# (reason: why else would we need to remove data?), but at first glance that
# seems overthought; we should just bring the DB object in alignment with
# the new state of the experiment.
with session_scope() as session:
session.query(data_sqa_class).filter_by(experiment_id=experiment_id).filter(
data_sqa_class.trial_index.isnot(None) # pyre-ignore[16]
).filter(
data_sqa_class.trial_index.in_(trial_idcs) # pyre-ignore[16]
).filter(
data_sqa_class.id.not_in(datas_to_keep) # pyre-ignore[16]
).delete()

_bulk_merge_into_session(
objs=datas,
Expand Down

0 comments on commit c2cb684

Please sign in to comment.