From cbb17c3d8fad9d1354ceb30394daa94d4ab4ed6c Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Wed, 13 Mar 2024 18:51:49 -0700 Subject: [PATCH] Retain original data timestamp in Experiment.clone_with (#2269) Summary: Pull Request resolved: https://github.com/facebook/Ax/pull/2269 Previously, this would attach the data using `Experiment.attach_data` which would create a new timestamp for the attached data. Directly assigning `Experiment._data_by_trial` allows us to retain the original timestamp that the data was attached at. Reviewed By: Cesar-Cardoso Differential Revision: D54876890 fbshipit-source-id: c17fe56cac3f9664c01009acc60ec147c46b6407 --- ax/core/experiment.py | 37 ++++++++++++++++---------------- ax/core/tests/test_experiment.py | 11 +++++++++- 2 files changed, 29 insertions(+), 19 deletions(-) diff --git a/ax/core/experiment.py b/ax/core/experiment.py index 757f3151a5f..e09ac614229 100644 --- a/ax/core/experiment.py +++ b/ax/core/experiment.py @@ -1542,6 +1542,10 @@ def clone_with( r""" Return a copy of this experiment with some attributes replaced. + NOTE: This method only retains the latest data attached to the experiment. + This is the same data that would be accessed using common APIs such as + ``Experiment.lookup_data()``. + Args: search_space: New search space. If None, it uses the cloned search space of the original experiment. @@ -1561,7 +1565,7 @@ def clone_with( trial_indices: If specified, only clones the specified trials. If None, clones all trials. data: If specified, attach this data to the cloned experiment. If None, - clones the data attached to the original experiment if + clones the latest data attached to the original experiment if the experiment has any data. """ search_space = ( @@ -1611,39 +1615,36 @@ def clone_with( default_data_type=self._default_data_type, ) - datas = [] - # clone only the specified trials + # Clone only the specified trials. 
original_trial_indices = self.trials.keys() - # pyre-fixme[9]: trial_indices has type `Optional[List[int]]`; used as - # `Set[int]`. - trial_indices = ( + trial_indices_to_keep = ( set(original_trial_indices) if trial_indices is None else set(trial_indices) ) - if ( - # pyre-fixme[16]: `Optional` has no attribute `difference`. - len(trial_indices_diff := trial_indices.difference(original_trial_indices)) - > 0 + if trial_indices_diff := trial_indices_to_keep.difference( + original_trial_indices ): warnings.warn( f"Trials indexed with {trial_indices_diff} are not a part " "of the original experiment. ", stacklevel=2, ) - # pyre-fixme[16]: `Optional` has no attribute `intersection`. - for trial_index in trial_indices.intersection(original_trial_indices): + + data_by_trial = {} + for trial_index in trial_indices_to_keep.intersection(original_trial_indices): trial = self.trials[trial_index] if isinstance(trial, BatchTrial) or isinstance(trial, Trial): trial.clone_to(cloned_experiment) - trial_data, storage_time = self.lookup_data_for_trial(trial_index) - if (trial_data is not None) and (storage_time is not None): - datas.append(trial_data) + trial_data, timestamp = self.lookup_data_for_trial(trial_index) + if timestamp != -1: + data_by_trial[trial_index] = OrderedDict([(timestamp, trial_data)]) else: raise NotImplementedError(f"Cloning of {type(trial)} is not supported.") - - if (data is None) and (len(datas) > 0): - data = self.default_data_constructor.from_multiple_data(datas) if data is not None: + # If user passed in data, use it. cloned_experiment.attach_data(data) + else: + # Otherwise, attach the data extracted from the original experiment. 
+ cloned_experiment._data_by_trial = data_by_trial return cloned_experiment diff --git a/ax/core/tests/test_experiment.py b/ax/core/tests/test_experiment.py index b85643b9d73..641a213ba94 100644 --- a/ax/core/tests/test_experiment.py +++ b/ax/core/tests/test_experiment.py @@ -7,11 +7,11 @@ # pyre-strict import logging +from collections import OrderedDict from typing import Dict, List, Type from unittest.mock import MagicMock, patch import numpy as np - import pandas as pd import torch from ax.core import BatchTrial, Trial @@ -1041,6 +1041,7 @@ def test_clone_with(self) -> None: search_space=larger_search_space, status_quo=new_status_quo, ) + self.assertEqual(cloned_experiment._data_by_trial, experiment._data_by_trial) self.assertEqual(len(cloned_experiment.trials), 2) x1 = checked_cast( RangeParameter, cloned_experiment.search_space.parameters["x1"] @@ -1105,7 +1106,15 @@ def test_clone_with(self) -> None: status_quo=new_status_quo, ) new_data = cloned_experiment.lookup_data() + self.assertNotEqual(cloned_experiment._data_by_trial, experiment._data_by_trial) self.assertIsInstance(new_data, MapData) + expected_data_by_trial = {} + for trial_index in experiment.trials: + if original_trial_data := experiment._data_by_trial.get(trial_index, None): + expected_data_by_trial[trial_index] = OrderedDict( + list(original_trial_data.items())[-1:] + ) + self.assertEqual(cloned_experiment.data_by_trial, expected_data_by_trial) experiment = get_experiment() cloned_experiment = experiment.clone_with()