Skip to content

Commit

Permalink
Add Experiment.trial_indices_expecting_data (facebook#2879)
Browse files Browse the repository at this point in the history
Summary:


I often find myself wishing for a performant version of this property, I think it was time to just add it

Reviewed By: Balandat

Differential Revision: D63999347
  • Loading branch information
Lena Kashtelyan authored and facebook-github-bot committed Oct 14, 2024
1 parent a7ecfef commit 0ef1de0
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
19 changes: 18 additions & 1 deletion ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@
import pandas as pd
from ax.core.arm import Arm
from ax.core.auxiliary import AuxiliaryExperiment, AuxiliaryExperimentPurpose
from ax.core.base_trial import BaseTrial, DEFAULT_STATUSES_TO_WARM_START, TrialStatus
from ax.core.base_trial import (
BaseTrial,
DEFAULT_STATUSES_TO_WARM_START,
STATUSES_EXPECTING_DATA,
TrialStatus,
)
from ax.core.batch_trial import BatchTrial, LifecycleStage
from ax.core.data import Data
from ax.core.formatting_utils import DATA_TYPE_LOOKUP, DataType
Expand Down Expand Up @@ -1056,6 +1061,18 @@ def running_trial_indices(self) -> set[int]:
"""Indices of running trials, associated with the experiment."""
return self._trial_indices_by_status[TrialStatus.RUNNING]

@property
def trial_indices_expecting_data(self) -> set[int]:
"""Set of indices of trials, statuses of which indicate that we expect
these trials to have data, either already or in the future.
"""
return set.union(
*(
self.trial_indices_by_status[status]
for status in STATUSES_EXPECTING_DATA
)
)

@property
def default_data_type(self) -> DataType:
return self._default_data_type
Expand Down
19 changes: 18 additions & 1 deletion ax/core/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,7 +769,6 @@ def test_AttachBatchTrialWithArmNames(self) -> None:
3,
)
self.assertEqual(type(self.experiment.trials[trial_index]), BatchTrial)
print({arm.name for arm in self.experiment.trials[trial_index].arms})
self.assertEqual(
{"arm1", "arm2", "arm3"},
set(self.experiment.trials[trial_index].arms_by_name) - {"status_quo"},
Expand Down Expand Up @@ -1267,6 +1266,24 @@ def test_arms_by_signature_for_deduplication(self) -> None:
experiment.arms_by_signature_for_deduplication, expected_with_other
)

def test_trial_indices(self) -> None:
experiment = self.experiment
for _ in range(6):
experiment.new_trial()
self.assertEqual(experiment.trial_indices_expecting_data, set())
experiment.trials[0].mark_staged()
experiment.trials[1].mark_running(no_runner_required=True)
experiment.trials[2].mark_running(no_runner_required=True).mark_completed()
self.assertEqual(experiment.trial_indices_expecting_data, {1, 2})
experiment.trials[1].mark_abandoned()
self.assertEqual(experiment.trial_indices_expecting_data, {2})
experiment.trials[4].mark_running(no_runner_required=True)
self.assertEqual(experiment.trial_indices_expecting_data, {2, 4})
experiment.trials[4].mark_failed()
self.assertEqual(experiment.trial_indices_expecting_data, {2})
experiment.trials[5].mark_running(no_runner_required=True).mark_early_stopped()
self.assertEqual(experiment.trial_indices_expecting_data, {2, 5})


class ExperimentWithMapDataTest(TestCase):
def setUp(self) -> None:
Expand Down

0 comments on commit 0ef1de0

Please sign in to comment.