From b87c949482310dc513f34214e92705aecd4a1238 Mon Sep 17 00:00:00 2001 From: Bernie Beckerman Date: Thu, 24 Oct 2024 19:05:30 -0700 Subject: [PATCH] introduce trial_indices argument to SupervisedDataset (#2960) Summary: X-link: https://github.com/pytorch/botorch/pull/2595 Adds optional `trial_indices` to SupervisedDataset, whose dimensionality should correspond 1:1 with the first few dimensions of X and Y tensors, as validated in `_validate` ([pointer](https://www.internalfb.com/diff/D64764019?permalink=1739375523489084)). Reviewed By: Balandat Differential Revision: D64764019 --- ax/core/observation.py | 15 +++++++++-- ax/modelbridge/map_torch.py | 10 ++++++++ ax/modelbridge/pairwise.py | 4 +++ .../tests/test_map_torch_modelbridge.py | 11 +++++++- ax/modelbridge/torch.py | 25 +++++++++++++++++-- 5 files changed, 60 insertions(+), 5 deletions(-) diff --git a/ax/core/observation.py b/ax/core/observation.py index d9e88458593..73fe0bab734 100644 --- a/ax/core/observation.py +++ b/ax/core/observation.py @@ -415,6 +415,15 @@ def _filter_data_on_status( def get_feature_cols(data: Data, is_map_data: bool = False) -> list[str]: + """Get the columns used to identify and group observations from a Data object. + + Args: + data: the Data object from which to extract the feature columns. + is_map_data: If True, the Data object's map_keys will be included. + + Returns: + A list of column names to be used to group observations. + """ feature_cols = OBS_COLS.intersection(data.df.columns) # note we use this check, rather than isinstance, since # only some Modelbridges (e.g. MapTorchModelBridge) @@ -432,8 +441,10 @@ def get_feature_cols(data: Data, is_map_data: bool = False) -> list[str]: stacklevel=5, ) feature_cols.discard(column) - - return list(feature_cols) + # NOTE: This ensures the order of feature_cols is deterministic so that the order + # of lists of observations are deterministic, to avoid nondeterministic tests. + # Necessary for test_TorchModelBridge. + return sorted(feature_cols) def observations_from_data( diff --git a/ax/modelbridge/map_torch.py b/ax/modelbridge/map_torch.py index ee5b6a4877e..3cd30ffe95d 100644 --- a/ax/modelbridge/map_torch.py +++ b/ax/modelbridge/map_torch.py @@ -11,6 +11,7 @@ import torch from ax.core.base_trial import TrialStatus +from ax.core.batch_trial import BatchTrial from ax.core.data import Data from ax.core.experiment import Experiment from ax.core.map_data import MapData @@ -118,6 +119,9 @@ def __init__( raise ValueError( "`MapTorchModelBridge expects `MapData` instead of `Data`." ) + + if any(isinstance(t, BatchTrial) for t in experiment.trials.values()): + raise ValueError("MapTorchModelBridge does not support batch trials.") # pyre-fixme[4]: Attribute must be annotated. self._map_key_features = data.map_keys self._map_data_limit_rows_per_metric = map_data_limit_rows_per_metric @@ -146,6 +150,12 @@ def statuses_to_fit_map_metric(self) -> set[TrialStatus]: @property def parameters_with_map_keys(self) -> list[str]: + """The parameters used for fitting the model, including map_keys.""" + # NOTE: This list determines the order of feature columns in the training data. + # Learning-curve-based modeling methods assume that the last columns are + # map_keys, so we place self._map_key_features on the end. + # TODO: Plumb down the `map_key` feature indices to the model, so that we don't + # have to make the assumption in the above note. return self.parameters + self._map_key_features def _predict( diff --git a/ax/modelbridge/pairwise.py b/ax/modelbridge/pairwise.py index bbe607b9a90..6cbb343c4a3 100644 --- a/ax/modelbridge/pairwise.py +++ b/ax/modelbridge/pairwise.py @@ -48,6 +48,7 @@ def _convert_observations( Yvars, candidate_metadata_dict, any_candidate_metadata_is_not_none, + trial_indices, ) = self._extract_observation_data( observation_data, observation_features, parameters ) @@ -69,6 +70,9 @@ def _convert_observations( Y=Y, feature_names=parameters, outcome_names=[outcome], + trial_indices=torch.tensor(trial_indices[outcome]) + if trial_indices + else None, ) datasets.append(dataset) diff --git a/ax/modelbridge/tests/test_map_torch_modelbridge.py b/ax/modelbridge/tests/test_map_torch_modelbridge.py index 3a14f203234..f8be3b540d8 100644 --- a/ax/modelbridge/tests/test_map_torch_modelbridge.py +++ b/ax/modelbridge/tests/test_map_torch_modelbridge.py @@ -43,15 +43,24 @@ def test_TorchModelBridge(self) -> None: experiment.trials[i].mark_as(status=TrialStatus.COMPLETED) experiment.attach_data(data=experiment.fetch_data()) + model = mock.MagicMock(TorchModel, autospec=True, instance=True) modelbridge = MapTorchModelBridge( experiment=experiment, search_space=experiment.search_space, data=experiment.lookup_data(), - model=TorchModel(), + model=model, transforms=[], fit_out_of_design=True, default_model_gen_options={"target_map_values": {"timestamp": 4.0}}, ) + # Check that indices are set correctly. + datasets_arg = model.fit.mock_calls[0][2]["datasets"] + t1 = datasets_arg[0].trial_indices + t2 = torch.tensor([0, 1, 2]) + self.assertTrue(torch.equal(t1, t2), msg=f"{t1} != {t2}") + t1 = datasets_arg[1].trial_indices + t2 = torch.tensor([0, 0, 1, 1, 2, 2]) + self.assertTrue(torch.equal(t1, t2), msg=f"{t1} != {t2}") # Check map data is converted to observations, that we get one Observation # per row of MapData # pyre-fixme[16]: `Data` has no attribute `map_df`. diff --git a/ax/modelbridge/torch.py b/ax/modelbridge/torch.py index 79112e88c80..8233e3bc770 100644 --- a/ax/modelbridge/torch.py +++ b/ax/modelbridge/torch.py @@ -343,6 +343,7 @@ def _convert_observations( Yvars, candidate_metadata_dict, any_candidate_metadata_is_not_none, + trial_indices, ) = self._extract_observation_data( observation_data, observation_features, parameters ) @@ -369,6 +370,9 @@ def _convert_observations( Yvar=Yvar, feature_names=parameters, outcome_names=[outcome], + trial_indices=torch.tensor(trial_indices[outcome]) + if trial_indices + else None, ) datasets.append(dataset) candidate_metadata.append(candidate_metadata_dict[outcome]) @@ -656,7 +660,7 @@ def _fit( ) # Fit self.model = model - self.model.fit( + not_none(self.model).fit( datasets=datasets, search_space_digest=search_space_digest, candidate_metadata=candidate_metadata, @@ -970,6 +974,7 @@ def _extract_observation_data( dict[str, list[Tensor]], dict[str, list[TCandidateMetadata]], bool, + dict[str, list[int]] | None, ]: """Extract observation features & data into tensors and metadata. @@ -991,14 +996,19 @@ def _extract_observation_data( observation tensors `Yvar`. - A dictionary mapping metric names to lists of corresponding metadata. - A boolean denoting whether any candidate metadata is not none. + - A dictionary mapping metric names to lists of corresponding trial indices, + or None if trial indices are not provided. This is used to support + learning-curve-based modeling. """ Xs: dict[str, list[Tensor]] = defaultdict(list) Ys: dict[str, list[Tensor]] = defaultdict(list) Yvars: dict[str, list[Tensor]] = defaultdict(list) candidate_metadata_dict: dict[str, list[TCandidateMetadata]] = defaultdict(list) any_candidate_metadata_is_not_none = False + trial_indices: dict[str, list[int]] = defaultdict(list) - for obsd, obsf in zip(observation_data, observation_features): + at_least_one_trial_index_is_none = False + for obsd, obsf in zip(observation_data, observation_features, strict=True): try: x = torch.tensor( [obsf.parameters[p] for p in parameters], @@ -1016,6 +1026,16 @@ def _extract_observation_data( if obsf.metadata is not None: any_candidate_metadata_is_not_none = True candidate_metadata_dict[metric_name].append(obsf.metadata) + trial_index = obsf.trial_index + if trial_index is not None: + trial_indices[metric_name].append(trial_index) + else: + at_least_one_trial_index_is_none = True + if len(trial_indices) > 0 and at_least_one_trial_index_is_none: + raise ValueError( + "Trial indices must be provided for all observation_features " + "or none of them." + ) return ( Xs, @@ -1023,6 +1043,7 @@ def _extract_observation_data( Yvars, candidate_metadata_dict, any_candidate_metadata_is_not_none, + trial_indices if len(trial_indices) > 0 else None, )