Skip to content

Commit

Permalink
Retain original data timestamp in Experiment.clone_with (#2269)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2269

Previously, `Experiment.clone_with` would attach the data using `Experiment.attach_data`, which would create a new timestamp for the attached data. Directly assigning `Experiment._data_by_trial` allows us to retain the original timestamp at which the data was attached.

Reviewed By: Cesar-Cardoso

Differential Revision: D54876890

fbshipit-source-id: c17fe56cac3f9664c01009acc60ec147c46b6407
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Mar 14, 2024
1 parent 62193e4 commit cbb17c3
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 19 deletions.
37 changes: 19 additions & 18 deletions ax/core/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -1542,6 +1542,10 @@ def clone_with(
r"""
Return a copy of this experiment with some attributes replaced.
NOTE: This method only retains the latest data attached to the experiment.
This is the same data that would be accessed using common APIs such as
``Experiment.lookup_data()``.
Args:
search_space: New search space. If None, it uses the cloned search space
of the original experiment.
Expand All @@ -1561,7 +1565,7 @@ def clone_with(
trial_indices: If specified, only clones the specified trials. If None,
clones all trials.
data: If specified, attach this data to the cloned experiment. If None,
clones the data attached to the original experiment if
clones the latest data attached to the original experiment if
the experiment has any data.
"""
search_space = (
Expand Down Expand Up @@ -1611,39 +1615,36 @@ def clone_with(
default_data_type=self._default_data_type,
)

datas = []
# clone only the specified trials
# Clone only the specified trials.
original_trial_indices = self.trials.keys()
# pyre-fixme[9]: trial_indices has type `Optional[List[int]]`; used as
# `Set[int]`.
trial_indices = (
trial_indices_to_keep = (
set(original_trial_indices) if trial_indices is None else set(trial_indices)
)
if (
# pyre-fixme[16]: `Optional` has no attribute `difference`.
len(trial_indices_diff := trial_indices.difference(original_trial_indices))
> 0
if trial_indices_diff := trial_indices_to_keep.difference(
original_trial_indices
):
warnings.warn(
f"Trials indexed with {trial_indices_diff} are not a part "
"of the original experiment. ",
stacklevel=2,
)
# pyre-fixme[16]: `Optional` has no attribute `intersection`.
for trial_index in trial_indices.intersection(original_trial_indices):

data_by_trial = {}
for trial_index in trial_indices_to_keep.intersection(original_trial_indices):
trial = self.trials[trial_index]
if isinstance(trial, BatchTrial) or isinstance(trial, Trial):
trial.clone_to(cloned_experiment)
trial_data, storage_time = self.lookup_data_for_trial(trial_index)
if (trial_data is not None) and (storage_time is not None):
datas.append(trial_data)
trial_data, timestamp = self.lookup_data_for_trial(trial_index)
if timestamp != -1:
data_by_trial[trial_index] = OrderedDict([(timestamp, trial_data)])
else:
raise NotImplementedError(f"Cloning of {type(trial)} is not supported.")

if (data is None) and (len(datas) > 0):
data = self.default_data_constructor.from_multiple_data(datas)
if data is not None:
# If user passed in data, use it.
cloned_experiment.attach_data(data)
else:
# Otherwise, attach the data extracted from the original experiment.
cloned_experiment._data_by_trial = data_by_trial

return cloned_experiment

Expand Down
11 changes: 10 additions & 1 deletion ax/core/tests/test_experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
# pyre-strict

import logging
from collections import OrderedDict
from typing import Dict, List, Type
from unittest.mock import MagicMock, patch

import numpy as np

import pandas as pd
import torch
from ax.core import BatchTrial, Trial
Expand Down Expand Up @@ -1041,6 +1041,7 @@ def test_clone_with(self) -> None:
search_space=larger_search_space,
status_quo=new_status_quo,
)
self.assertEqual(cloned_experiment._data_by_trial, experiment._data_by_trial)
self.assertEqual(len(cloned_experiment.trials), 2)
x1 = checked_cast(
RangeParameter, cloned_experiment.search_space.parameters["x1"]
Expand Down Expand Up @@ -1105,7 +1106,15 @@ def test_clone_with(self) -> None:
status_quo=new_status_quo,
)
new_data = cloned_experiment.lookup_data()
self.assertNotEqual(cloned_experiment._data_by_trial, experiment._data_by_trial)
self.assertIsInstance(new_data, MapData)
expected_data_by_trial = {}
for trial_index in experiment.trials:
if original_trial_data := experiment._data_by_trial.get(trial_index, None):
expected_data_by_trial[trial_index] = OrderedDict(
list(original_trial_data.items())[-1:]
)
self.assertEqual(cloned_experiment.data_by_trial, expected_data_by_trial)

experiment = get_experiment()
cloned_experiment = experiment.clone_with()
Expand Down

0 comments on commit cbb17c3

Please sign in to comment.