Skip to content

Commit

Permalink
introduce trial_indices argument to SupervisedDataset (facebook#2960)
Browse files Browse the repository at this point in the history
Summary:

X-link: pytorch/botorch#2595

Adds an optional `trial_indices` argument to SupervisedDataset, whose dimensionality should correspond 1:1 with the first few dimensions of the X and Y tensors, as validated in `_validate` ([pointer](https://www.internalfb.com/diff/D64764019?permalink=1739375523489084)).

Reviewed By: Balandat

Differential Revision: D64764019
  • Loading branch information
Bernie Beckerman authored and facebook-github-bot committed Oct 25, 2024
1 parent 3033b14 commit b87c949
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 5 deletions.
15 changes: 13 additions & 2 deletions ax/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,15 @@ def _filter_data_on_status(


def get_feature_cols(data: Data, is_map_data: bool = False) -> list[str]:
"""Get the columns used to identify and group observations from a Data object.
Args:
data: the Data object from which to extract the feature columns.
is_map_data: If True, the Data object's map_keys will be included.
Returns:
A list of column names to be used to group observations.
"""
feature_cols = OBS_COLS.intersection(data.df.columns)
# note we use this check, rather than isinstance, since
# only some Modelbridges (e.g. MapTorchModelBridge)
Expand All @@ -432,8 +441,10 @@ def get_feature_cols(data: Data, is_map_data: bool = False) -> list[str]:
stacklevel=5,
)
feature_cols.discard(column)

return list(feature_cols)
# NOTE: This ensures the order of feature_cols is deterministic so that the order
# of lists of observations are deterministic, to avoid nondeterministic tests.
# Necessary for test_TorchModelBridge.
return sorted(feature_cols)


def observations_from_data(
Expand Down
10 changes: 10 additions & 0 deletions ax/modelbridge/map_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import torch
from ax.core.base_trial import TrialStatus
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.map_data import MapData
Expand Down Expand Up @@ -118,6 +119,9 @@ def __init__(
raise ValueError(
"`MapTorchModelBridge expects `MapData` instead of `Data`."
)

if any(isinstance(t, BatchTrial) for t in experiment.trials.values()):
raise ValueError("MapTorchModelBridge does not support batch trials.")
# pyre-fixme[4]: Attribute must be annotated.
self._map_key_features = data.map_keys
self._map_data_limit_rows_per_metric = map_data_limit_rows_per_metric
Expand Down Expand Up @@ -146,6 +150,12 @@ def statuses_to_fit_map_metric(self) -> set[TrialStatus]:

@property
def parameters_with_map_keys(self) -> list[str]:
"""The parameters used for fitting the model, including map_keys."""
# NOTE: This list determines the order of feature columns in the training data.
# Learning-curve-based modeling methods assume that the last columns are
# map_keys, so we place self._map_key_features on the end.
# TODO: Plumb down the `map_key` feature indices to the model, so that we don't
# have to make the assumption in the above note.
return self.parameters + self._map_key_features

def _predict(
Expand Down
4 changes: 4 additions & 0 deletions ax/modelbridge/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def _convert_observations(
Yvars,
candidate_metadata_dict,
any_candidate_metadata_is_not_none,
trial_indices,
) = self._extract_observation_data(
observation_data, observation_features, parameters
)
Expand All @@ -69,6 +70,9 @@ def _convert_observations(
Y=Y,
feature_names=parameters,
outcome_names=[outcome],
trial_indices=torch.tensor(trial_indices[outcome])
if trial_indices
else None,
)

datasets.append(dataset)
Expand Down
11 changes: 10 additions & 1 deletion ax/modelbridge/tests/test_map_torch_modelbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,24 @@ def test_TorchModelBridge(self) -> None:
experiment.trials[i].mark_as(status=TrialStatus.COMPLETED)

experiment.attach_data(data=experiment.fetch_data())
model = mock.MagicMock(TorchModel, autospec=True, instance=True)
modelbridge = MapTorchModelBridge(
experiment=experiment,
search_space=experiment.search_space,
data=experiment.lookup_data(),
model=TorchModel(),
model=model,
transforms=[],
fit_out_of_design=True,
default_model_gen_options={"target_map_values": {"timestamp": 4.0}},
)
# Check that indices are set correctly.
datasets_arg = model.fit.mock_calls[0][2]["datasets"]
t1 = datasets_arg[0].trial_indices
t2 = torch.tensor([0, 1, 2])
self.assertTrue(torch.equal(t1, t2), msg=f"{t1} != {t2}")
t1 = datasets_arg[1].trial_indices
t2 = torch.tensor([0, 0, 1, 1, 2, 2])
self.assertTrue(torch.equal(t1, t2), msg=f"{t1} != {t2}")
# Check map data is converted to observations, that we get one Observation
# per row of MapData
# pyre-fixme[16]: `Data` has no attribute `map_df`.
Expand Down
25 changes: 23 additions & 2 deletions ax/modelbridge/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ def _convert_observations(
Yvars,
candidate_metadata_dict,
any_candidate_metadata_is_not_none,
trial_indices,
) = self._extract_observation_data(
observation_data, observation_features, parameters
)
Expand All @@ -369,6 +370,9 @@ def _convert_observations(
Yvar=Yvar,
feature_names=parameters,
outcome_names=[outcome],
trial_indices=torch.tensor(trial_indices[outcome])
if trial_indices
else None,
)
datasets.append(dataset)
candidate_metadata.append(candidate_metadata_dict[outcome])
Expand Down Expand Up @@ -656,7 +660,7 @@ def _fit(
)
# Fit
self.model = model
self.model.fit(
not_none(self.model).fit(
datasets=datasets,
search_space_digest=search_space_digest,
candidate_metadata=candidate_metadata,
Expand Down Expand Up @@ -970,6 +974,7 @@ def _extract_observation_data(
dict[str, list[Tensor]],
dict[str, list[TCandidateMetadata]],
bool,
dict[str, list[int]] | None,
]:
"""Extract observation features & data into tensors and metadata.
Expand All @@ -991,14 +996,19 @@ def _extract_observation_data(
observation tensors `Yvar`.
- A dictionary mapping metric names to lists of corresponding metadata.
- A boolean denoting whether any candidate metadata is not none.
- A dictionary mapping metric names to lists of corresponding trial indices,
or None if trial indices are not provided. This is used to support
learning-curve-based modeling.
"""
Xs: dict[str, list[Tensor]] = defaultdict(list)
Ys: dict[str, list[Tensor]] = defaultdict(list)
Yvars: dict[str, list[Tensor]] = defaultdict(list)
candidate_metadata_dict: dict[str, list[TCandidateMetadata]] = defaultdict(list)
any_candidate_metadata_is_not_none = False
trial_indices: dict[str, list[int]] = defaultdict(list)

for obsd, obsf in zip(observation_data, observation_features):
at_least_one_trial_index_is_none = False
for obsd, obsf in zip(observation_data, observation_features, strict=True):
try:
x = torch.tensor(
[obsf.parameters[p] for p in parameters],
Expand All @@ -1016,13 +1026,24 @@ def _extract_observation_data(
if obsf.metadata is not None:
any_candidate_metadata_is_not_none = True
candidate_metadata_dict[metric_name].append(obsf.metadata)
trial_index = obsf.trial_index
if trial_index is not None:
trial_indices[metric_name].append(trial_index)
else:
at_least_one_trial_index_is_none = True
if len(trial_indices) > 0 and at_least_one_trial_index_is_none:
raise ValueError(
"Trial indices must be provided for all observation_features "
"or none of them."
)

return (
Xs,
Ys,
Yvars,
candidate_metadata_dict,
any_candidate_metadata_is_not_none,
trial_indices if len(trial_indices) > 0 else None,
)


Expand Down

0 comments on commit b87c949

Please sign in to comment.