Skip to content

Commit

Permalink
Fix missing DB update on AxClient.stop_trial_early (#2337)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2337

`AxClient.stop_trial_early` was not updating the trial status in the DB, leading to reloaded experiments showing the trial as still running.

Reviewed By: Balandat

Differential Revision: D55900104

fbshipit-source-id: c2cea6259d21e03bdcbf93e4442874c4a9562d23
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Apr 9, 2024
1 parent 1f837a3 commit a161618
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 4 deletions.
9 changes: 5 additions & 4 deletions ax/service/ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,8 +546,7 @@ def get_next_trial(
)
trial.mark_running(no_runner_required=True)
self._save_or_update_trial_in_db_if_possible(
experiment=self.experiment,
trial=trial,
experiment=self.experiment, trial=trial
)
# TODO[T79183560]: Ensure correct handling of generator run when using
# foreign keys.
Expand Down Expand Up @@ -1294,6 +1293,9 @@ def stop_trial_early(self, trial_index: int) -> None:
trial = self.get_trial(trial_index)
trial.mark_early_stopped()
logger.info(f"Early stopped trial {trial_index}.")
self._save_or_update_trial_in_db_if_possible(
experiment=self.experiment, trial=trial
)

def estimate_early_stopping_savings(self, map_key: Optional[str] = None) -> float:
"""Estimate early stopping savings using progressions of the MapMetric present
Expand Down Expand Up @@ -1625,8 +1627,7 @@ def _update_trial_with_raw_data(
trial.mark_completed()

self._save_or_update_trial_in_db_if_possible(
experiment=self.experiment,
trial=trial,
experiment=self.experiment, trial=trial
)

return update_info
Expand Down
12 changes: 12 additions & 0 deletions ax/service/tests/test_ax_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1963,6 +1963,18 @@ def test_sqa_storage(self) -> None:
# Original experiment should still be in DB and not have been overwritten.
self.assertEqual(len(ax_client.experiment.trials), 5)

# Attach an early stopped trial.
parameters, trial_index = ax_client.get_next_trial()
ax_client.stop_trial_early(trial_index=trial_index)

# Reload experiment and check that trial status is accurate.
ax_client_new = AxClient(db_settings=db_settings)
ax_client_new.load_experiment_from_database("test_experiment")
self.assertEqual(
ax_client.experiment.trials_by_status,
ax_client_new.experiment.trials_by_status,
)

def test_overwrite(self) -> None:
init_test_engine_and_session_factory(force_init=True)
ax_client = AxClient()
Expand Down

0 comments on commit a161618

Please sign in to comment.