Skip to content

Commit

Permalink
Minor clean up of MBM & helpers (#1875)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #1875

Some no-op changes to Surrogate & helpers to clear out pyre & flake8 warnings.

Removes `original_metric_names` kwarg, which was introduced as a "hack", only for its only usage to disappear after a few months.

Reviewed By: esantorella

Differential Revision: D49657636

fbshipit-source-id: 3559a3171b1fa74af548d548c3efa40961e07eba
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Sep 27, 2023
1 parent f3216b5 commit 6c9cdba
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 55 deletions.
40 changes: 30 additions & 10 deletions ax/models/model_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,40 @@
import itertools
import warnings
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Set, Tuple, Union
from typing import Callable, Dict, List, Optional, Protocol, Set, Tuple, Union

import numpy as np
import torch
from ax.core.search_space import SearchSpaceDigest
from ax.core.types import TParamCounter
from ax.exceptions.core import SearchSpaceExhausted
from ax.models.torch_base import TorchModel
from ax.exceptions.core import SearchSpaceExhausted, UnsupportedError
from ax.models.types import TConfig
from ax.utils.common.typeutils import checked_cast
from botorch.acquisition.risk_measures import RiskMeasureMCObjective
from torch import Tensor


Tensoray = Union[torch.Tensor, np.ndarray]


class TorchModelLike(Protocol):
"""A protocol that stands in for ``TorchModel`` like objects that
have a ``predict`` method.
"""

def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
"""Predicts outcomes given an input tensor.
Args:
X: A ``n x d`` tensor of input parameters.
Returns:
Tensor: The predicted posterior mean as an ``n x o``-dim tensor.
Tensor: The predicted posterior covariance as a ``n x o x o``-dim tensor.
"""
...


DEFAULT_MAX_RS_DRAWS = 10000


Expand Down Expand Up @@ -227,7 +246,7 @@ def validate_bounds(


def best_observed_point(
model: TorchModel,
model: TorchModelLike,
bounds: List[Tuple[float, float]],
objective_weights: Optional[Tensoray],
outcome_constraints: Optional[Tuple[Tensoray, Tensoray]] = None,
Expand Down Expand Up @@ -267,7 +286,7 @@ def best_observed_point(
probability of feasibility (defaults 10k).
Args:
model: Numpy or Torch model.
model: A Torch model or Surrogate.
bounds: A list of (lower, upper) tuples for each feature.
objective_weights: The objective is to maximize a weighted sum of
the columns of f(x). These are the weights.
Expand Down Expand Up @@ -303,7 +322,7 @@ def best_observed_point(

def best_in_sample_point(
Xs: Union[List[torch.Tensor], List[np.ndarray]],
model: TorchModel,
model: TorchModelLike,
bounds: List[Tuple[float, float]],
objective_weights: Optional[Tensoray],
outcome_constraints: Optional[Tuple[Tensoray, Tensoray]] = None,
Expand Down Expand Up @@ -344,7 +363,7 @@ def best_in_sample_point(
Args:
Xs: Training data for the points, among which to select the best.
model: Numpy or Torch model.
model: A Torch model or Surrogate.
bounds: A list of (lower, upper) tuples for each feature.
objective_weights: The objective is to maximize a weighted sum of
the columns of f(x). These are the weights.
Expand Down Expand Up @@ -379,7 +398,7 @@ def best_in_sample_point(
# Get points observed for all objective and constraint outcomes
if objective_weights is None:
return None
objective_weights_np = as_array(objective_weights)
objective_weights_np = checked_cast(np.ndarray, as_array(objective_weights))
X_obs = get_observed(
Xs=Xs,
objective_weights=objective_weights,
Expand All @@ -399,7 +418,7 @@ def best_in_sample_point(
if isinstance(Xs[0], torch.Tensor):
X_obs = X_obs.detach().clone()
f, cov = as_array(model.predict(X_obs))
obj = objective_weights_np @ f.transpose() # pyre-ignore
obj = objective_weights_np @ f.transpose()
pfeas = np.ones_like(obj)
if outcome_constraints is not None:
A, b = as_array(outcome_constraints) # (m x j) and (m x 1)
Expand All @@ -418,7 +437,8 @@ def best_in_sample_point(
if B is None:
B = obj.min()
utility = (obj - B) * pfeas
# pyre-fixme[61]: `utility` may not be initialized here.
else: # pragma: no cover
raise UnsupportedError(f"Unknown best point method {method}.")
i = np.argmax(utility)
if utility[i] == -np.Inf:
return None
Expand Down
11 changes: 4 additions & 7 deletions ax/models/torch/botorch_modular/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,14 +189,11 @@ def __init__(
f"{Keys.AUTOSET_SURROGATE}, these are reserved."
)

self._surrogates = {}
self.surrogate_specs = {}
if surrogate_specs is not None:
self.surrogate_specs: Dict[str, SurrogateSpec] = {
label: spec for label, spec in surrogate_specs.items()
}
elif surrogate is not None:
self.surrogate_specs = dict((surrogate_specs or {}).items())
if surrogate is not None:
self._surrogates = {Keys.ONLY_SURROGATE: surrogate}
else:
self._surrogates = {}

self.acquisition_class = acquisition_class or Acquisition
self.acquisition_options = acquisition_options or {}
Expand Down
24 changes: 2 additions & 22 deletions ax/models/torch/botorch_modular/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from ax.models.torch.botorch_modular.input_constructors.input_transforms import (
input_transform_argparse,
)

from ax.models.torch.botorch_modular.utils import (
choose_model_class,
convert_to_block_design,
Expand Down Expand Up @@ -601,7 +600,6 @@ def fit(
candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None,
state_dict: Optional[Dict[str, Tensor]] = None,
refit: bool = True,
original_metric_names: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
"""Fits the underlying BoTorch ``Model`` to ``m`` outcomes.
Expand Down Expand Up @@ -633,13 +631,6 @@ def fit(
the order corresponding to the Xs.
state_dict: Optional state dict to load.
refit: Whether to re-optimize model parameters.
# TODO: we should refactor the fit() API to get rid of the metric_names
# and the concatenation hack that comes with it in BoTorchModel.fit()
# by attaching the individual metric_name to each dataset directly.
original_metric_names: sometimes the original list of metric_names
got tranformed into a different format before being passed down
into fit(). This arg preserves the original metric_names before
the transformation.
"""
if self._constructed_manually:
logger.debug(
Expand All @@ -656,11 +647,7 @@ def fit(
search_space_digest=search_space_digest,
**_kwargs,
)
self._outcomes = (
original_metric_names
if original_metric_names is not None
else metric_names
)
self._outcomes = metric_names

if state_dict:
self.model.load_state_dict(not_none(state_dict))
Expand All @@ -670,11 +657,9 @@ def fit(
)

def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]:
"""Predicts outcomes given a model and input tensor.
"""Predicts outcomes given an input tensor.
Args:
model: A botorch Model.
X: A ``n x d`` tensor of input parameters.
Returns:
Expand All @@ -698,11 +683,6 @@ def best_in_sample_point(
)
best_point_and_observed_value = best_in_sample_point(
Xs=self.Xs,
# pyre-ignore[6]: `best_in_sample_point` currently expects a `TorchModel`
# as `model` kwarg, but only uses them for `predict` function, the
# signature for which is the same on this `Surrogate`.
# TODO: When we move `botorch_modular` directory to OSS, we will extend
# the annotation for `model` kwarg to accept `Surrogate` too.
model=self,
bounds=search_space_digest.bounds,
objective_weights=torch_opt_config.objective_weights,
Expand Down
17 changes: 1 addition & 16 deletions ax/models/torch/tests/test_surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,6 @@ def setUp(self) -> None:
bounds=self.bounds,
target_fidelities={1: 1.0},
)
self.metric_names = ["x_y"]
self.original_metric_names = ["x", "y"]
self.fixed_features = {1: 2.0}
self.refit = True
self.objective_weights = torch.tensor(
Expand Down Expand Up @@ -475,19 +473,6 @@ def test_fit(
self.assertTrue(mock_fit.call_kwargs["jit_compile"])
mock_MLL.reset_mock()
mock_fit.reset_mock()
# Check that the optional original_metric_names arg propagates
# through surrogate._outcomes.
surrogate.fit(
datasets=self.training_data,
metric_names=self.metric_names,
search_space_digest=self.search_space_digest,
refit=self.refit,
original_metric_names=self.original_metric_names,
)
self.assertEqual(surrogate.outcomes, self.original_metric_names)
mock_state_dict.reset_mock()
mock_MLL.reset_mock()
mock_fit.reset_mock()
# Should `load_state_dict` when `state_dict` is not `None`
# and `refit` is `False`.
state_dict = {"state_attribute": torch.zeros(1)}
Expand Down Expand Up @@ -795,7 +780,7 @@ def test_fit_mixed(self) -> None:
surrogate = Surrogate(allow_batched_models=False)
surrogate.fit(
datasets=training_data,
metric_names=self.original_metric_names,
metric_names=self.metric_names,
search_space_digest=search_space_digest,
)
self.assertIsInstance(surrogate.model, ModelListGP)
Expand Down

0 comments on commit 6c9cdba

Please sign in to comment.