Skip to content

Commit

Permalink
do not group by time cols when creating observations (#2293)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2293

See title. Add the time columns only if all metrics for a given trial-arm pair have the same start/end time. This means that if one wants time to be included, then they should ensure that all metrics give the same start/end time for each trial-arm pair.

This was causing an issue where status quo data for a single trial could be split across multiple observations because the metrics had different start and end times. This largely reverts D54017025.

Reviewed By: Balandat

Differential Revision: D55099571

fbshipit-source-id: e4e469f37fc63cd200eb395a7487208144462a24
  • Loading branch information
sdaulton authored and facebook-github-bot committed Mar 21, 2024
1 parent 45be1fb commit fbcad21
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 19 deletions.
36 changes: 29 additions & 7 deletions ax/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from __future__ import annotations

import json
import warnings
from copy import deepcopy
from typing import Dict, Iterable, List, Optional, Set, Tuple

Expand All @@ -23,7 +24,7 @@
from ax.core.types import TCandidateMetadata, TParameterization
from ax.utils.common.base import Base
from ax.utils.common.constants import Keys
from ax.utils.common.typeutils import not_none
from ax.utils.common.typeutils import checked_cast, not_none


TIME_COLS = {"start_time", "end_time"}
Expand Down Expand Up @@ -315,6 +316,13 @@ def _observations_from_dataframe(
for f, val in features.items():
if f in OBS_KWARGS:
obs_kwargs[f] = val
# add start and end time of trial if the start and end time
# is the same for all metrics and arms
for col in TIME_COLS:
if col in d.columns:
times = d[col]
if times.nunique() == 1 and not times.isnull().any():
obs_kwargs[col] = times.iloc[0]
fidelities = features.get("fidelities")
if fidelities is not None:
obs_parameters.update(json.loads(fidelities))
Expand All @@ -338,12 +346,26 @@ def _observations_from_dataframe(
return observations


def get_feature_cols(data: Data, is_map_data: bool = False) -> List[str]:
    """Return the observation-feature column names present in ``data``.

    Args:
        data: The data to extract feature columns from.
        is_map_data: If True, ``data`` is cast to ``MapData`` and its
            ``map_keys`` are included as feature columns.

    Returns:
        A list of columns of ``data.df`` that are observation features.
        ``start_time``/``end_time`` columns are discarded (with a warning)
        if their values are not constant across all rows.
    """
    feature_cols = OBS_COLS.intersection(data.df.columns)
    # NOTE: we use this boolean flag, rather than isinstance, since
    # only some ModelBridges (e.g. MapTorchModelBridge)
    # use observations_from_map_data, which is required
    # to properly handle MapData features (e.g. fidelity).
    if is_map_data:
        data = checked_cast(MapData, data)
        feature_cols = feature_cols.union(data.map_keys)

    for column in TIME_COLS:
        if column in feature_cols and len(data.df[column].unique()) > 1:
            warnings.warn(
                # Fixed unbalanced backtick in the original message.
                f"`{column}` is not consistent and being discarded from "
                "observation data",
                stacklevel=5,
            )
            feature_cols.discard(column)

    return list(feature_cols)


def observations_from_data(
Expand Down Expand Up @@ -458,7 +480,7 @@ def observations_from_map_data(
limit_rows_per_group=limit_rows_per_group,
include_first_last=True,
)
feature_cols = get_feature_cols_from_map_data(map_data)
feature_cols = get_feature_cols(map_data, is_map_data=True)
observations = []
arm_name_only = len(feature_cols) == 1 # there will always be an arm name
# One DataFrame where all rows have all features.
Expand Down
111 changes: 111 additions & 0 deletions ax/core/tests/test_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
separate_observations,
)
from ax.core.trial import Trial
from ax.core.types import TParameterization
from ax.utils.common.testutils import TestCase
from ax.utils.common.typeutils import not_none


class ObservationsTest(TestCase):
Expand Down Expand Up @@ -662,6 +664,115 @@ def test_ObservationsFromDataWithSomeMissingTimes(self) -> None:
)
self.assertEqual(obs.arm_name, cname_truth[i])

def test_ObservationsFromDataWithDifferentTimesSingleTrial(self) -> None:
    """Observations are not split when metrics disagree on start/end times.

    A single trial-arm pair whose metrics report different start/end times
    must still yield exactly one observation per arm; a time column is
    attached only when it is consistent across all metrics for that arm.
    """
    params0: TParameterization = {"x": 0, "y": "a"}
    params1: TParameterization = {"x": 1, "y": "a"}
    # Arm 0_0: consistent start_time across metrics, metric "b" has no
    # end_time -> only start_time should be attached to the observation.
    # Arm 0_1: consistent end_time, inconsistent start_time -> only
    # end_time should be attached.
    truth = [
        {
            "arm_name": "0_0",
            "parameters": params0,
            "mean": 2.0,
            "sem": 2.0,
            "trial_index": 0,
            "metric_name": "a",
            "start_time": "2024-03-20 08:45:00",
            "end_time": "2024-03-20 08:47:00",
        },
        {
            "arm_name": "0_0",
            "parameters": params0,
            "mean": 3.0,
            "sem": 3.0,
            "trial_index": 0,
            "metric_name": "b",
            "start_time": "2024-03-20 08:45:00",
        },
        {
            "arm_name": "0_1",
            "parameters": params1,
            "mean": 4.0,
            "sem": 4.0,
            "trial_index": 0,
            "metric_name": "a",
            "start_time": "2024-03-20 08:43:00",
            "end_time": "2024-03-20 08:46:00",
        },
        {
            "arm_name": "0_1",
            "parameters": params1,
            "mean": 5.0,
            "sem": 5.0,
            "trial_index": 0,
            "metric_name": "b",
            "start_time": "2024-03-20 08:45:00",
            "end_time": "2024-03-20 08:46:00",
        },
    ]
    arms_by_name = {
        "0_0": Arm(name="0_0", parameters=params0),
        "0_1": Arm(name="0_1", parameters=params1),
    }
    # Mock out the experiment; only arms_by_name/trials are consulted.
    experiment = Mock()
    experiment._trial_indices_by_status = {status: set() for status in TrialStatus}
    trials = {
        0: BatchTrial(experiment, GeneratorRun(arms=list(arms_by_name.values())))
    }
    type(experiment).arms_by_name = PropertyMock(return_value=arms_by_name)
    type(experiment).trials = PropertyMock(return_value=trials)

    df = pd.DataFrame(truth)[
        [
            "arm_name",
            "trial_index",
            "mean",
            "sem",
            "metric_name",
            "start_time",
            "end_time",
        ]
    ]
    data = Data(df=df)
    observations = observations_from_data(experiment, data)

    # One observation per arm, despite the differing metric-level times.
    self.assertEqual(len(observations), 2)
    # Get them in the order we want for tests below
    if observations[0].features.parameters["x"] == 1:
        observations.reverse()

    obs_truth = {
        "arm_name": ["0_0", "0_1"],
        "parameters": [{"x": 0, "y": "a"}, {"x": 1, "y": "a"}],
        "metric_names": [["a", "b"], ["a", "b"]],
        "means": [np.array([2.0, 3.0]), np.array([4.0, 5.0])],
        "covariance": [np.diag([4.0, 9.0]), np.diag([16.0, 25.0])],
    }

    for i, obs in enumerate(observations):
        self.assertEqual(obs.features.parameters, obs_truth["parameters"][i])
        self.assertEqual(
            obs.features.trial_index,
            0,
        )
        self.assertEqual(obs.data.metric_names, obs_truth["metric_names"][i])
        self.assertTrue(np.array_equal(obs.data.means, obs_truth["means"][i]))
        self.assertTrue(
            np.array_equal(obs.data.covariance, obs_truth["covariance"][i])
        )
        # Removed a duplicated copy of this assertion from the original.
        self.assertEqual(obs.arm_name, obs_truth["arm_name"][i])
        if i == 0:
            # Arm 0_0: only the (consistent) start_time is attached.
            self.assertEqual(
                not_none(obs.features.start_time).strftime("%Y-%m-%d %X"),
                "2024-03-20 08:45:00",
            )
            self.assertIsNone(obs.features.end_time)
        else:
            # Arm 0_1: only the (consistent) end_time is attached.
            self.assertIsNone(obs.features.start_time)
            self.assertEqual(
                not_none(obs.features.end_time).strftime("%Y-%m-%d %X"),
                "2024-03-20 08:46:00",
            )

def test_SeparateObservations(self) -> None:
obs_arm_name = "0_0"
obs = Observation(
Expand Down
11 changes: 5 additions & 6 deletions ax/modelbridge/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,13 +431,12 @@ def _set_status_quo(

if len(sq_obs) == 0:
logger.warning(f"Status quo {status_quo_name} not present in data")
elif len(sq_obs) > 1:
logger.warning(
f"Status quo {status_quo_name} found in data with multiple "
"features. Use status_quo_features to specify which to use."
)
else:
if len(sq_obs) > 1:
logger.warning(
f"Status quo {status_quo_name} found in data with multiple "
"features. Use status_quo_features to specify which to use."
" Defaulting to the first observation."
)
self._status_quo = sq_obs[0]

elif status_quo_features is not None:
Expand Down
15 changes: 9 additions & 6 deletions ax/modelbridge/tests/test_base_modelbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -578,14 +578,17 @@ def test_status_quo_for_non_monolithic_data(self, mock_gen):

# create data where metrics vary in start and end times
data = get_non_monolithic_branin_moo_data()
bridge = ModelBridge(
experiment=exp,
data=data,
model=Model(),
search_space=exp.search_space,
)
with warnings.catch_warnings(record=True) as ws:
bridge = ModelBridge(
experiment=exp,
data=data,
model=Model(),
search_space=exp.search_space,
)
# just testing it doesn't error
bridge.gen(5)
self.assertTrue(any("start_time" in str(w.message) for w in ws))
self.assertTrue(any("end_time" in str(w.message) for w in ws))
# pyre-fixme[16]: Optional type has no attribute `arm_name`.
self.assertEqual(bridge.status_quo.arm_name, "status_quo")

Expand Down

0 comments on commit fbcad21

Please sign in to comment.