Skip to content

Commit

Permalink
outcome transform to outcome classes + options (#1880)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1880

- Make Surrogate and SurrogateSpec take outcome_transform_classes and outcome_transform_options as inputs instead of BoTorch outcome_transform
- Construct BoTorch outcome transforms using outcome_transform_classes and outcome_transform_options plus other input available to Surrogate during Surrogate.fit(). Currently only support Standardize.
- Updated the storage code to be able to encode/decode outcome_transform_classes and outcome_transform_options, and to be backward compatible.

Reviewed By: saitcakmak

Differential Revision: D49609383

fbshipit-source-id: 085f39acd68bf1017cccdf6fbd023a393333b9ad
  • Loading branch information
Susan Xia authored and facebook-github-bot committed Sep 28, 2023
1 parent af543fe commit bf74e8e
Show file tree
Hide file tree
Showing 13 changed files with 372 additions and 28 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


from __future__ import annotations

from typing import Any, Dict, Optional, Type

from ax.utils.common.typeutils import _argparse_type_encoder
from botorch.models.transforms.outcome import OutcomeTransform, Standardize
from botorch.utils.datasets import SupervisedDataset
from botorch.utils.dispatcher import Dispatcher

outcome_transform_argparse = Dispatcher(
name="outcome_transform_argparse", encoder=_argparse_type_encoder
)


@outcome_transform_argparse.register(OutcomeTransform)
def _outcome_transform_argparse_base(
outcome_transform_class: Type[OutcomeTransform],
dataset: Optional[SupervisedDataset] = None,
outcome_transform_options: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""
Extract the outcome transform kwargs from the given arguments.
Args:
outcome_transform_class: Outcome transform class.
dataset: Dataset containing feature matrix and the response.
outcome_transform_options: An optional dictionary of outcome transform options.
This may include overrides for the above options. For example, when
`outcome_transform_class` is Standardize this dictionary might include
{
"m": 1, # the output dimension
}
See `botorch/models/transforms/outcome.py` for more options.
Returns:
A dictionary with outcome transform kwargs.
"""
return outcome_transform_options or {}


@outcome_transform_argparse.register(Standardize)
def _outcome_transform_argparse_standardize(
outcome_transform_class: Type[Standardize],
dataset: SupervisedDataset,
outcome_transform_options: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""Extract the outcome transform kwargs form the given arguments.
Args:
outcome_transform_class: Outcome transform class, which is Standardize in this
case.
dataset: Dataset containing feature matrix and the response.
outcome_transform_options: Outcome transform kwargs.
See botorch.models.transforms.outcome.Standardize for all available options
Returns:
A dictionary with outcome transform kwargs.
"""

outcome_transform_options = outcome_transform_options or {}
m = dataset.Y.shape[-1]
outcome_transform_options.setdefault("m", m)

return outcome_transform_options
7 changes: 5 additions & 2 deletions ax/models/torch/botorch_modular/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,9 @@ class SurrogateSpec:

input_transform_classes: Optional[List[Type[InputTransform]]] = None
input_transform_options: Optional[Dict[str, Dict[str, Any]]] = None
outcome_transform: Optional[OutcomeTransform] = None

outcome_transform_classes: Optional[List[Type[OutcomeTransform]]] = None
outcome_transform_options: Optional[Dict[str, Dict[str, Any]]] = None

allow_batched_models: bool = True

Expand Down Expand Up @@ -304,7 +306,8 @@ def fit(
likelihood_options=spec.likelihood_kwargs,
input_transform_classes=spec.input_transform_classes,
input_transform_options=spec.input_transform_options,
outcome_transform=spec.outcome_transform,
outcome_transform_classes=spec.outcome_transform_classes,
outcome_transform_options=spec.outcome_transform_options,
allow_batched_models=spec.allow_batched_models,
)
for label, spec in self.surrogate_specs.items()
Expand Down
94 changes: 83 additions & 11 deletions ax/models/torch/botorch_modular/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@
from ax.models.torch.botorch_modular.input_constructors.input_transforms import (
input_transform_argparse,
)
from ax.models.torch.botorch_modular.input_constructors.outcome_transform import (
outcome_transform_argparse,
)

from ax.models.torch.botorch_modular.utils import (
choose_model_class,
convert_to_block_design,
Expand All @@ -50,7 +54,8 @@
InputPerturbation,
InputTransform,
)
from botorch.models.transforms.outcome import OutcomeTransform
from botorch.models.transforms.outcome import ChainedOutcomeTransform, OutcomeTransform

from botorch.utils.datasets import RankingDataset, SupervisedDataset
from gpytorch.kernels import Kernel
from gpytorch.likelihoods.likelihood import Likelihood
Expand Down Expand Up @@ -84,9 +89,18 @@ class Surrogate(Base):
``Model`` constructed during ``Surrogate.fit``.
mll_class: ``MarginalLogLikelihood`` class to use for model-fitting.
mll_options: Dictionary of options / kwargs for the MLL.
outcome_transform: BoTorch outcome transforms. Passed down to the
BoTorch ``Model``. Multiple outcome transforms can be chained
outcome_transform_classes: List of BoTorch outcome transforms classes. Passed
down to the BoTorch ``Model``. Multiple outcome transforms can be chained
together using ``ChainedOutcomeTransform``.
outcome_transform_options: Outcome transform classes kwargs. The keys are
class string names and the values are dictionaries of outcome transform
kwargs. For example,
`
outcome_transform_classes = [Standardize]
outcome_transform_options = {
"Standardize": {"m": 1},
`
For more options see `botorch/models/transforms/outcome.py`.
input_transform_classes: List of BoTorch input transforms classes.
Passed down to the BoTorch ``Model``. Multiple input transforms
will be chained together using ``ChainedInputTransform``.
Expand Down Expand Up @@ -116,7 +130,8 @@ class string names and the values are dictionaries of input transform
model_options: Dict[str, Any]
mll_class: Type[MarginalLogLikelihood]
mll_options: Dict[str, Any]
outcome_transform: Optional[OutcomeTransform] = None
outcome_transform_classes: Optional[List[Type[OutcomeTransform]]] = None
outcome_transform_options: Optional[Dict[str, Dict[str, Any]]] = None
input_transform_classes: Optional[List[Type[InputTransform]]] = None
input_transform_options: Optional[Dict[str, Dict[str, Any]]] = None
covar_module_class: Optional[Type[Kernel]] = None
Expand All @@ -139,7 +154,8 @@ def __init__(
model_options: Optional[Dict[str, Any]] = None,
mll_class: Type[MarginalLogLikelihood] = ExactMarginalLogLikelihood,
mll_options: Optional[Dict[str, Any]] = None,
outcome_transform: Optional[OutcomeTransform] = None,
outcome_transform_classes: Optional[List[Type[OutcomeTransform]]] = None,
outcome_transform_options: Optional[Dict[str, Dict[str, Any]]] = None,
input_transform_classes: Optional[List[Type[InputTransform]]] = None,
input_transform_options: Optional[Dict[str, Dict[str, Any]]] = None,
covar_module_class: Optional[Type[Kernel]] = None,
Expand All @@ -152,7 +168,8 @@ def __init__(
self.model_options = model_options or {}
self.mll_class = mll_class
self.mll_options = mll_options or {}
self.outcome_transform = outcome_transform
self.outcome_transform_classes = outcome_transform_classes
self.outcome_transform_options = outcome_transform_options or {}
self.input_transform_classes = input_transform_classes
self.input_transform_options = input_transform_options or {}
self.covar_module_class = covar_module_class
Expand All @@ -166,7 +183,7 @@ def __repr__(self) -> str:
f"<{self.__class__.__name__}"
f" botorch_model_class={self.botorch_model_class} "
f"mll_class={self.mll_class} "
f"outcome_transform={self.outcome_transform} "
f"outcome_transform_classes={self.outcome_transform_classes} "
f"input_transform_classes={self.input_transform_classes} "
)

Expand Down Expand Up @@ -352,7 +369,12 @@ def _construct_model(
None,
],
["likelihood", self.likelihood_class, self.likelihood_options, None],
["outcome_transform", None, None, self.outcome_transform],
[
"outcome_transform",
self.outcome_transform_classes,
self.outcome_transform_options,
None,
],
[
"input_transform",
self.input_transform_classes,
Expand Down Expand Up @@ -451,9 +473,9 @@ class can be automatically selected.
],
[
"outcome_transform",
self.outcome_transform_classes,
deepcopy(self.outcome_transform_options),
None,
None,
deepcopy(self.outcome_transform),
],
[
"input_transform",
Expand Down Expand Up @@ -521,6 +543,15 @@ def _set_formatted_inputs(
search_space_digest=search_space_digest,
)

elif input_name == "outcome_transform":

formatted_model_inputs[
input_name
] = self._make_botorch_outcome_transform(
input_classes=input_class,
input_options=input_options,
dataset=dataset,
)
else:
formatted_model_inputs[input_name] = input_class(**input_options)

Expand Down Expand Up @@ -592,6 +623,46 @@ def _make_botorch_input_transform(

return input_instance

def _make_botorch_outcome_transform(
self,
input_classes: List[Type[OutcomeTransform]],
input_options: Dict[str, Dict[str, Any]],
dataset: SupervisedDataset,
) -> OutcomeTransform:
"""
Makes a BoTorch outcome transform from the provided classes and options.
"""
if not (
isinstance(input_classes, list)
and all(issubclass(c, OutcomeTransform) for c in input_classes)
):
raise UserInputError("Expected a list of outcome transforms.")

outcome_transform_kwargs = [
outcome_transform_argparse(
input_class,
outcome_transform_options=input_options.get(input_class.__name__, {}),
dataset=dataset,
)
for input_class in input_classes
]

outcome_transforms = [
input_class(**single_outcome_transform_kwargs)
for input_class, single_outcome_transform_kwargs in zip(
input_classes, outcome_transform_kwargs
)
]

outcome_transform_instance = (
ChainedOutcomeTransform(
**{f"otf{i}": otf for i, otf in enumerate(outcome_transforms)}
)
if len(outcome_transforms) > 1
else outcome_transforms[0]
)
return outcome_transform_instance

def fit(
self,
datasets: List[SupervisedDataset],
Expand Down Expand Up @@ -823,7 +894,8 @@ def _serialize_attributes_as_kwargs(self) -> Dict[str, Any]:
"model_options": self.model_options,
"mll_class": self.mll_class,
"mll_options": self.mll_options,
"outcome_transform": self.outcome_transform,
"outcome_transform_classes": self.outcome_transform_classes,
"outcome_transform_options": self.outcome_transform_options,
"input_transform_classes": self.input_transform_classes,
"input_transform_options": self.input_transform_options,
"covar_module_class": self.covar_module_class,
Expand Down
6 changes: 4 additions & 2 deletions ax/models/torch/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,7 +848,8 @@ def test_surrogate_model_options_propagation(
likelihood_options=None,
input_transform_classes=None,
input_transform_options=None,
outcome_transform=None,
outcome_transform_classes=None,
outcome_transform_options=None,
allow_batched_models=True,
)

Expand All @@ -875,7 +876,8 @@ def test_surrogate_options_propagation(self, _: Mock, mock_init: Mock) -> None:
likelihood_options=None,
input_transform_classes=None,
input_transform_options=None,
outcome_transform=None,
outcome_transform_classes=None,
outcome_transform_options=None,
allow_batched_models=False,
)

Expand Down
59 changes: 59 additions & 0 deletions ax/models/torch/tests/test_outcome_transform_argparse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from ax.models.torch.botorch_modular.input_constructors.outcome_transform import (
outcome_transform_argparse,
)
from ax.utils.common.testutils import TestCase
from botorch.models.transforms.outcome import OutcomeTransform, Standardize
from botorch.utils.datasets import SupervisedDataset


class DummyOutcomeTransform(OutcomeTransform):
pass


class OutcomeTransformArgparseTest(TestCase):
def setUp(self) -> None:
X = torch.randn((10, 4))
Y = torch.randn((10, 1))
self.dataset = SupervisedDataset(
X=X,
Y=Y,
feature_names=["chicken", "eggs", "pigeons", "bunnies"],
outcome_names=["farm"],
)

def test_notImplemented(self) -> None:
with self.assertRaises(NotImplementedError) as e:
outcome_transform_argparse[type(None)]
self.assertTrue("Could not find signature for" in str(e))

def test_register(self) -> None:
@outcome_transform_argparse.register(DummyOutcomeTransform)
def _argparse(outcome_transform: DummyOutcomeTransform) -> None:
pass

self.assertEqual(_argparse, outcome_transform_argparse[DummyOutcomeTransform])

def test_argparse_outcome_transform(self) -> None:
outcome_transform_kwargs_a = outcome_transform_argparse(OutcomeTransform)
outcome_transform_kwargs_b = outcome_transform_argparse(
OutcomeTransform, outcome_transform_options={"x": 5}, dataset=self.dataset
)

self.assertEqual(outcome_transform_kwargs_a, {})
self.assertEqual(outcome_transform_kwargs_b, {"x": 5})

def test_argparse_standardize(self) -> None:
outcome_transform_kwargs_a = outcome_transform_argparse(
Standardize, dataset=self.dataset
)
outcome_transform_kwargs_b = outcome_transform_argparse(
Standardize, dataset=self.dataset, outcome_transform_options={"m": 10}
)
self.assertEqual(outcome_transform_kwargs_a, {"m": 1})
self.assertEqual(outcome_transform_kwargs_b, {"m": 10})
Loading

0 comments on commit bf74e8e

Please sign in to comment.