Skip to content

Commit

Permalink
updates trial status during clone (#2290)
Browse files Browse the repository at this point in the history
Summary:

Updates cloned trial's status to match the trial from which it was cloned. This will eventually help us deduplicate some of warm_start_from_old_experiment.

Reviewed By: mpolson64

Differential Revision: D55024671
  • Loading branch information
Bernie Beckerman authored and facebook-github-bot committed Mar 26, 2024
1 parent d120362 commit 93e2e68
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 7 deletions.
19 changes: 19 additions & 0 deletions ax/core/base_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,3 +883,22 @@ def _update_trial_attrs_on_clone(
new_trial._stop_metadata = deepcopy(self._stop_metadata)
new_trial._num_arms_created = self._num_arms_created
new_trial.runner = self._runner.clone() if self._runner else None

# Set status and reason accordingly.
match self.status:
case TrialStatus.CANDIDATE:
pass
case TrialStatus.STAGED:
new_trial.mark_staged()
case running_or_later_status:
# Other statuses require the state first be set to `RUNNING`.
new_trial.mark_running(no_runner_required=True)
match running_or_later_status:
case TrialStatus.RUNNING:
pass
case TrialStatus.ABANDONED:
new_trial.mark_abandoned(reason=self.abandoned_reason)
case TrialStatus.FAILED:
new_trial.mark_failed(reason=self.failed_reason)
case terminal_status:
new_trial.mark_as(terminal_status)
1 change: 1 addition & 0 deletions ax/core/tests/test_batch_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,7 @@ def test_clone_to(self, _) -> None:
name="status_quo", parameters={"w": 0.0, "x": 1, "y": "foo", "z": True}
)
batch.set_status_quo_and_optimize_power(status_quo)
batch.mark_running(no_runner_required=True)
new_batch_trial = batch.clone_to()
self.assertEqual(new_batch_trial.index, 2)
# Set index to original trial's value for equality check.
Expand Down
4 changes: 4 additions & 0 deletions ax/core/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,6 +1043,10 @@ def test_clone_with(self) -> None:
)
self.assertEqual(cloned_experiment._data_by_trial, experiment._data_by_trial)
self.assertEqual(len(cloned_experiment.trials), 2)
for trial_index in cloned_experiment.trials.keys():
cloned_trial = cloned_experiment.trials[trial_index]
original_trial = experiment.trials[trial_index]
self.assertEqual(cloned_trial.status, original_trial.status)
x1 = checked_cast(
RangeParameter, cloned_experiment.search_space.parameters["x1"]
)
Expand Down
25 changes: 24 additions & 1 deletion ax/core/tests/test_trial.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,11 +338,12 @@ def test_update_trial_data(self) -> None:

def test_clone_to(self) -> None:
# cloned trial attached to the same experiment
self.trial.mark_running(no_runner_required=True)
new_trial = self.trial.clone_to()
self.assertIs(new_trial.experiment, self.trial.experiment)
# Test equality of all attributes except index, time_created, and experiment.
for k, v in new_trial.__dict__.items():
if k in ["_index", "_time_created", "_experiment"]:
if k in ["_index", "_time_created", "_experiment", "_time_run_started"]:
continue
self.assertEqual(v, self.trial.__dict__[k])

Expand All @@ -355,3 +356,25 @@ def test_clone_to(self) -> None:
new_trial._status = TrialStatus.COMPLETED
self.assertTrue(new_trial.status.is_completed)
self.assertFalse(self.trial.status.is_completed)

def test_update_trial_status_on_clone(self) -> None:
for status in [
TrialStatus.CANDIDATE,
TrialStatus.STAGED,
TrialStatus.RUNNING,
TrialStatus.EARLY_STOPPED,
TrialStatus.COMPLETED,
TrialStatus.FAILED,
TrialStatus.ABANDONED,
]:
self.trial._failed_reason = self.trial._abandoned_reason = None
if status != TrialStatus.CANDIDATE:
self.trial.mark_as(
status=status, unsafe=True, no_runner_required=True, reason="test"
)
test_trial = self.trial.clone_to()
# Overwrite unimportant attrs before equality check.
test_trial._index = self.trial.index
test_trial._time_created = self.trial._time_created
test_trial._time_staged = self.trial._time_staged
self.assertEqual(self.trial, test_trial)
15 changes: 10 additions & 5 deletions ax/global_stopping/tests/tests_strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import numpy as np
import pandas as pd
from ax.core.arm import Arm
from ax.core.base_trial import TrialStatus
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data
from ax.core.experiment import Experiment
Expand Down Expand Up @@ -66,9 +67,10 @@ def test_base_cases(self) -> None:
)

# Check that we properly count completed trials.
for _ in range(4):
checked_cast(BatchTrial, exp.trials[0]).clone()
exp.trials[3].mark_running(no_runner_required=True).mark_completed()
for i in range(4):
trial = checked_cast(BatchTrial, exp.trials[0]).clone_to()
if i < 3:
trial._status = TrialStatus.CANDIDATE
stop, message = gss.should_stop_optimization(experiment=exp)
self.assertFalse(stop)
self.assertEqual(
Expand All @@ -78,9 +80,12 @@ def test_base_cases(self) -> None:
)

# Should raise ValueError if trying to check an invalid trial
with self.assertRaises(ValueError):
with self.assertRaisesRegex(
ValueError,
r"trial_to_check is larger than the total number of trials \(=4\).",
):
stop, message = gss.should_stop_optimization(
experiment=exp, trial_to_check=4
experiment=exp, trial_to_check=5
)

def _get_arm(self) -> Arm:
Expand Down
2 changes: 1 addition & 1 deletion ax/service/tests/test_best_point.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def test_get_trace(self) -> None:
# test that there is performance metric in the trace for each
# completed/early-stopped trial
trial1 = checked_cast(BatchTrial, trial).clone_to()
trial1.mark_abandoned()
trial1.mark_abandoned(unsafe=True)
arms = get_arms_from_dict(get_arm_weights2())
trial2 = exp.new_batch_trial(GeneratorRun(arms))
trial2.mark_running(no_runner_required=True).mark_completed()
Expand Down

0 comments on commit 93e2e68

Please sign in to comment.