Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

introduce trial_indices argument to SupervisedDataset #2960

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 13 additions & 2 deletions ax/core/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,6 +415,15 @@ def _filter_data_on_status(


def get_feature_cols(data: Data, is_map_data: bool = False) -> list[str]:
"""Get the columns used to identify and group observations from a Data object.
Args:
data: the Data object from which to extract the feature columns.
is_map_data: If True, the Data object's map_keys will be included.
Returns:
A list of column names to be used to group observations.
"""
feature_cols = OBS_COLS.intersection(data.df.columns)
# note we use this check, rather than isinstance, since
# only some Modelbridges (e.g. MapTorchModelBridge)
Expand All @@ -432,8 +441,10 @@ def get_feature_cols(data: Data, is_map_data: bool = False) -> list[str]:
stacklevel=5,
)
feature_cols.discard(column)

return list(feature_cols)
# NOTE: This ensures the order of feature_cols is deterministic so that the order
# of lists of observations are deterministic, to avoid nondeterministic tests.
# Necessary for test_TorchModelBridge.
return sorted(feature_cols)


def observations_from_data(
Expand Down
10 changes: 10 additions & 0 deletions ax/modelbridge/map_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

import torch
from ax.core.base_trial import TrialStatus
from ax.core.batch_trial import BatchTrial
from ax.core.data import Data
from ax.core.experiment import Experiment
from ax.core.map_data import MapData
Expand Down Expand Up @@ -118,6 +119,9 @@ def __init__(
raise ValueError(
"`MapTorchModelBridge expects `MapData` instead of `Data`."
)

if any(isinstance(t, BatchTrial) for t in experiment.trials.values()):
raise ValueError("MapTorchModelBridge does not support batch trials.")
# pyre-fixme[4]: Attribute must be annotated.
self._map_key_features = data.map_keys
self._map_data_limit_rows_per_metric = map_data_limit_rows_per_metric
Expand Down Expand Up @@ -146,6 +150,12 @@ def statuses_to_fit_map_metric(self) -> set[TrialStatus]:

@property
def parameters_with_map_keys(self) -> list[str]:
"""The parameters used for fitting the model, including map_keys."""
# NOTE: This list determines the order of feature columns in the training data.
# Learning-curve-based modeling methods assume that the last columns are
# map_keys, so we place self._map_key_features on the end.
# TODO: Plumb down the `map_key` feature indices to the model, so that we don't
# have to make the assumption in the above note.
return self.parameters + self._map_key_features

def _predict(
Expand Down
4 changes: 4 additions & 0 deletions ax/modelbridge/pairwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def _convert_observations(
Yvars,
candidate_metadata_dict,
any_candidate_metadata_is_not_none,
trial_indices,
) = self._extract_observation_data(
observation_data, observation_features, parameters
)
Expand All @@ -69,6 +70,9 @@ def _convert_observations(
Y=Y,
feature_names=parameters,
outcome_names=[outcome],
trial_indices=torch.tensor(trial_indices[outcome])
if trial_indices
else None,
)

datasets.append(dataset)
Expand Down
11 changes: 10 additions & 1 deletion ax/modelbridge/tests/test_map_torch_modelbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,24 @@ def test_TorchModelBridge(self) -> None:
experiment.trials[i].mark_as(status=TrialStatus.COMPLETED)

experiment.attach_data(data=experiment.fetch_data())
model = mock.MagicMock(TorchModel, autospec=True, instance=True)
modelbridge = MapTorchModelBridge(
experiment=experiment,
search_space=experiment.search_space,
data=experiment.lookup_data(),
model=TorchModel(),
model=model,
transforms=[],
fit_out_of_design=True,
default_model_gen_options={"target_map_values": {"timestamp": 4.0}},
)
# Check that indices are set correctly.
datasets_arg = model.fit.mock_calls[0][2]["datasets"]
t1 = datasets_arg[0].trial_indices
t2 = torch.tensor([0, 1, 2])
self.assertTrue(torch.equal(t1, t2), msg=f"{t1} != {t2}")
t1 = datasets_arg[1].trial_indices
t2 = torch.tensor([0, 0, 1, 1, 2, 2])
self.assertTrue(torch.equal(t1, t2), msg=f"{t1} != {t2}")
# Check map data is converted to observations, that we get one Observation
# per row of MapData
# pyre-fixme[16]: `Data` has no attribute `map_df`.
Expand Down
25 changes: 23 additions & 2 deletions ax/modelbridge/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ def _convert_observations(
Yvars,
candidate_metadata_dict,
any_candidate_metadata_is_not_none,
trial_indices,
) = self._extract_observation_data(
observation_data, observation_features, parameters
)
Expand All @@ -369,6 +370,9 @@ def _convert_observations(
Yvar=Yvar,
feature_names=parameters,
outcome_names=[outcome],
trial_indices=torch.tensor(trial_indices[outcome])
if trial_indices
else None,
)
datasets.append(dataset)
candidate_metadata.append(candidate_metadata_dict[outcome])
Expand Down Expand Up @@ -656,7 +660,7 @@ def _fit(
)
# Fit
self.model = model
self.model.fit(
not_none(self.model).fit(
datasets=datasets,
search_space_digest=search_space_digest,
candidate_metadata=candidate_metadata,
Expand Down Expand Up @@ -970,6 +974,7 @@ def _extract_observation_data(
dict[str, list[Tensor]],
dict[str, list[TCandidateMetadata]],
bool,
dict[str, list[int]] | None,
]:
"""Extract observation features & data into tensors and metadata.
Expand All @@ -991,14 +996,19 @@ def _extract_observation_data(
observation tensors `Yvar`.
- A dictionary mapping metric names to lists of corresponding metadata.
- A boolean denoting whether any candidate metadata is not none.
- A dictionary mapping metric names to lists of corresponding trial indices,
or None if trial indices are not provided. This is used to support
learning-curve-based modeling.
"""
Xs: dict[str, list[Tensor]] = defaultdict(list)
Ys: dict[str, list[Tensor]] = defaultdict(list)
Yvars: dict[str, list[Tensor]] = defaultdict(list)
candidate_metadata_dict: dict[str, list[TCandidateMetadata]] = defaultdict(list)
any_candidate_metadata_is_not_none = False
trial_indices: dict[str, list[int]] = defaultdict(list)

for obsd, obsf in zip(observation_data, observation_features):
at_least_one_trial_index_is_none = False
for obsd, obsf in zip(observation_data, observation_features, strict=True):
try:
x = torch.tensor(
[obsf.parameters[p] for p in parameters],
Expand All @@ -1016,13 +1026,24 @@ def _extract_observation_data(
if obsf.metadata is not None:
any_candidate_metadata_is_not_none = True
candidate_metadata_dict[metric_name].append(obsf.metadata)
trial_index = obsf.trial_index
if trial_index is not None:
trial_indices[metric_name].append(trial_index)
else:
at_least_one_trial_index_is_none = True
if len(trial_indices) > 0 and at_least_one_trial_index_is_none:
raise ValueError(
"Trial indices must be provided for all observation_features "
"or none of them."
)

return (
Xs,
Ys,
Yvars,
candidate_metadata_dict,
any_candidate_metadata_is_not_none,
trial_indices if len(trial_indices) > 0 else None,
)


Expand Down
Loading