-
Notifications
You must be signed in to change notification settings - Fork 310
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
outcome transform to outcome classes + options (#1880)
Summary: Pull Request resolved: #1880 - Make Surrogate and SurrogateSpec take outcome_transform_classes and outcome_transform_options as inputs instead of BoTorch outcome_transform - Construct BoTorch outcome transforms using outcome_transform_classes and outcome_transform_options plus other input available to Surrogate during Surrogate.fit(). Currently only support Standardize. - Updated the storage code to be able to encode/decode outcome_transform_classes and outcome_transform_options, and to be backward compatible. Reviewed By: saitcakmak Differential Revision: D49609383 fbshipit-source-id: 085f39acd68bf1017cccdf6fbd023a393333b9ad
- Loading branch information
1 parent
af543fe
commit bf74e8e
Showing
13 changed files
with
372 additions
and
28 deletions.
There are no files selected for viewing
70 changes: 70 additions & 0 deletions
70
ax/models/torch/botorch_modular/input_constructors/outcome_transform.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
|
||
from __future__ import annotations | ||
|
||
from typing import Any, Dict, Optional, Type | ||
|
||
from ax.utils.common.typeutils import _argparse_type_encoder | ||
from botorch.models.transforms.outcome import OutcomeTransform, Standardize | ||
from botorch.utils.datasets import SupervisedDataset | ||
from botorch.utils.dispatcher import Dispatcher | ||
|
||
outcome_transform_argparse = Dispatcher( | ||
name="outcome_transform_argparse", encoder=_argparse_type_encoder | ||
) | ||
|
||
|
||
@outcome_transform_argparse.register(OutcomeTransform) | ||
def _outcome_transform_argparse_base( | ||
outcome_transform_class: Type[OutcomeTransform], | ||
dataset: Optional[SupervisedDataset] = None, | ||
outcome_transform_options: Optional[Dict[str, Any]] = None, | ||
) -> Dict[str, Any]: | ||
""" | ||
Extract the outcome transform kwargs from the given arguments. | ||
Args: | ||
outcome_transform_class: Outcome transform class. | ||
dataset: Dataset containing feature matrix and the response. | ||
outcome_transform_options: An optional dictionary of outcome transform options. | ||
This may include overrides for the above options. For example, when | ||
`outcome_transform_class` is Standardize this dictionary might include | ||
{ | ||
"m": 1, # the output dimension | ||
} | ||
See `botorch/models/transforms/outcome.py` for more options. | ||
Returns: | ||
A dictionary with outcome transform kwargs. | ||
""" | ||
return outcome_transform_options or {} | ||
|
||
|
||
@outcome_transform_argparse.register(Standardize) | ||
def _outcome_transform_argparse_standardize( | ||
outcome_transform_class: Type[Standardize], | ||
dataset: SupervisedDataset, | ||
outcome_transform_options: Optional[Dict[str, Any]] = None, | ||
) -> Dict[str, Any]: | ||
"""Extract the outcome transform kwargs form the given arguments. | ||
Args: | ||
outcome_transform_class: Outcome transform class, which is Standardize in this | ||
case. | ||
dataset: Dataset containing feature matrix and the response. | ||
outcome_transform_options: Outcome transform kwargs. | ||
See botorch.models.transforms.outcome.Standardize for all available options | ||
Returns: | ||
A dictionary with outcome transform kwargs. | ||
""" | ||
|
||
outcome_transform_options = outcome_transform_options or {} | ||
m = dataset.Y.shape[-1] | ||
outcome_transform_options.setdefault("m", m) | ||
|
||
return outcome_transform_options |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# | ||
# This source code is licensed under the MIT license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
import torch | ||
from ax.models.torch.botorch_modular.input_constructors.outcome_transform import ( | ||
outcome_transform_argparse, | ||
) | ||
from ax.utils.common.testutils import TestCase | ||
from botorch.models.transforms.outcome import OutcomeTransform, Standardize | ||
from botorch.utils.datasets import SupervisedDataset | ||
|
||
|
||
class DummyOutcomeTransform(OutcomeTransform): | ||
pass | ||
|
||
|
||
class OutcomeTransformArgparseTest(TestCase): | ||
def setUp(self) -> None: | ||
X = torch.randn((10, 4)) | ||
Y = torch.randn((10, 1)) | ||
self.dataset = SupervisedDataset( | ||
X=X, | ||
Y=Y, | ||
feature_names=["chicken", "eggs", "pigeons", "bunnies"], | ||
outcome_names=["farm"], | ||
) | ||
|
||
def test_notImplemented(self) -> None: | ||
with self.assertRaises(NotImplementedError) as e: | ||
outcome_transform_argparse[type(None)] | ||
self.assertTrue("Could not find signature for" in str(e)) | ||
|
||
def test_register(self) -> None: | ||
@outcome_transform_argparse.register(DummyOutcomeTransform) | ||
def _argparse(outcome_transform: DummyOutcomeTransform) -> None: | ||
pass | ||
|
||
self.assertEqual(_argparse, outcome_transform_argparse[DummyOutcomeTransform]) | ||
|
||
def test_argparse_outcome_transform(self) -> None: | ||
outcome_transform_kwargs_a = outcome_transform_argparse(OutcomeTransform) | ||
outcome_transform_kwargs_b = outcome_transform_argparse( | ||
OutcomeTransform, outcome_transform_options={"x": 5}, dataset=self.dataset | ||
) | ||
|
||
self.assertEqual(outcome_transform_kwargs_a, {}) | ||
self.assertEqual(outcome_transform_kwargs_b, {"x": 5}) | ||
|
||
def test_argparse_standardize(self) -> None: | ||
outcome_transform_kwargs_a = outcome_transform_argparse( | ||
Standardize, dataset=self.dataset | ||
) | ||
outcome_transform_kwargs_b = outcome_transform_argparse( | ||
Standardize, dataset=self.dataset, outcome_transform_options={"m": 10} | ||
) | ||
self.assertEqual(outcome_transform_kwargs_a, {"m": 1}) | ||
self.assertEqual(outcome_transform_kwargs_b, {"m": 10}) |
Oops, something went wrong.