diff --git a/ax/models/model_utils.py b/ax/models/model_utils.py index 5d572cd621f..cc683f8bc4d 100644 --- a/ax/models/model_utils.py +++ b/ax/models/model_utils.py @@ -9,21 +9,40 @@ import itertools import warnings from collections import defaultdict -from typing import Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Callable, Dict, List, Optional, Protocol, Set, Tuple, Union import numpy as np import torch from ax.core.search_space import SearchSpaceDigest from ax.core.types import TParamCounter -from ax.exceptions.core import SearchSpaceExhausted -from ax.models.torch_base import TorchModel +from ax.exceptions.core import SearchSpaceExhausted, UnsupportedError from ax.models.types import TConfig +from ax.utils.common.typeutils import checked_cast from botorch.acquisition.risk_measures import RiskMeasureMCObjective +from torch import Tensor Tensoray = Union[torch.Tensor, np.ndarray] +class TorchModelLike(Protocol): + """A protocol that stands in for ``TorchModel`` like objects that + have a ``predict`` method. + """ + + def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]: + """Predicts outcomes given an input tensor. + + Args: + X: A ``n x d`` tensor of input parameters. + + Returns: + Tensor: The predicted posterior mean as an ``n x o``-dim tensor. + Tensor: The predicted posterior covariance as a ``n x o x o``-dim tensor. + """ + ... + + DEFAULT_MAX_RS_DRAWS = 10000 @@ -227,7 +246,7 @@ def validate_bounds( def best_observed_point( - model: TorchModel, + model: TorchModelLike, bounds: List[Tuple[float, float]], objective_weights: Optional[Tensoray], outcome_constraints: Optional[Tuple[Tensoray, Tensoray]] = None, @@ -267,7 +286,7 @@ def best_observed_point( probability of feasibility (defaults 10k). Args: - model: Numpy or Torch model. + model: A Torch model or Surrogate. bounds: A list of (lower, upper) tuples for each feature. objective_weights: The objective is to maximize a weighted sum of the columns of f(x). These are the weights. @@ -303,7 +322,7 @@ def best_observed_point( def best_in_sample_point( Xs: Union[List[torch.Tensor], List[np.ndarray]], - model: TorchModel, + model: TorchModelLike, bounds: List[Tuple[float, float]], objective_weights: Optional[Tensoray], outcome_constraints: Optional[Tuple[Tensoray, Tensoray]] = None, @@ -344,7 +363,7 @@ def best_in_sample_point( Args: Xs: Training data for the points, among which to select the best. - model: Numpy or Torch model. + model: A Torch model or Surrogate. bounds: A list of (lower, upper) tuples for each feature. objective_weights: The objective is to maximize a weighted sum of the columns of f(x). These are the weights. @@ -379,7 +398,7 @@ def best_in_sample_point( # Get points observed for all objective and constraint outcomes if objective_weights is None: return None - objective_weights_np = as_array(objective_weights) + objective_weights_np = checked_cast(np.ndarray, as_array(objective_weights)) X_obs = get_observed( Xs=Xs, objective_weights=objective_weights, @@ -399,7 +418,7 @@ def best_in_sample_point( if isinstance(Xs[0], torch.Tensor): X_obs = X_obs.detach().clone() f, cov = as_array(model.predict(X_obs)) - obj = objective_weights_np @ f.transpose() # pyre-ignore + obj = objective_weights_np @ f.transpose() pfeas = np.ones_like(obj) if outcome_constraints is not None: A, b = as_array(outcome_constraints) # (m x j) and (m x 1) @@ -418,7 +437,8 @@ def best_in_sample_point( if B is None: B = obj.min() utility = (obj - B) * pfeas - # pyre-fixme[61]: `utility` may not be initialized here. + else: # pragma: no cover + raise UnsupportedError(f"Unknown best point method {method}.") i = np.argmax(utility) if utility[i] == -np.Inf: return None diff --git a/ax/models/torch/botorch_modular/model.py b/ax/models/torch/botorch_modular/model.py index 5b17835bd27..d96434c66a3 100644 --- a/ax/models/torch/botorch_modular/model.py +++ b/ax/models/torch/botorch_modular/model.py @@ -189,14 +189,11 @@ def __init__( f"{Keys.AUTOSET_SURROGATE}, these are reserved." ) - self._surrogates = {} - self.surrogate_specs = {} - if surrogate_specs is not None: - self.surrogate_specs: Dict[str, SurrogateSpec] = { - label: spec for label, spec in surrogate_specs.items() - } - elif surrogate is not None: + self.surrogate_specs = dict((surrogate_specs or {}).items()) + if surrogate is not None: self._surrogates = {Keys.ONLY_SURROGATE: surrogate} + else: + self._surrogates = {} self.acquisition_class = acquisition_class or Acquisition self.acquisition_options = acquisition_options or {} diff --git a/ax/models/torch/botorch_modular/surrogate.py b/ax/models/torch/botorch_modular/surrogate.py index 716cbc8f606..225f9e8afad 100644 --- a/ax/models/torch/botorch_modular/surrogate.py +++ b/ax/models/torch/botorch_modular/surrogate.py @@ -25,7 +25,6 @@ from ax.models.torch.botorch_modular.input_constructors.input_transforms import ( input_transform_argparse, ) - from ax.models.torch.botorch_modular.utils import ( choose_model_class, convert_to_block_design, @@ -601,7 +600,6 @@ def fit( candidate_metadata: Optional[List[List[TCandidateMetadata]]] = None, state_dict: Optional[Dict[str, Tensor]] = None, refit: bool = True, - original_metric_names: Optional[List[str]] = None, **kwargs: Any, ) -> None: """Fits the underlying BoTorch ``Model`` to ``m`` outcomes. @@ -633,13 +631,6 @@ def fit( the order corresponding to the Xs. state_dict: Optional state dict to load. refit: Whether to re-optimize model parameters. - # TODO: we should refactor the fit() API to get rid of the metric_names - # and the concatenation hack that comes with it in BoTorchModel.fit() - # by attaching the individual metric_name to each dataset directly. - original_metric_names: sometimes the original list of metric_names - got tranformed into a different format before being passed down - into fit(). This arg preserves the original metric_names before - the transformation. """ if self._constructed_manually: logger.debug( @@ -656,11 +647,7 @@ def fit( search_space_digest=search_space_digest, **_kwargs, ) - self._outcomes = ( - original_metric_names - if original_metric_names is not None - else metric_names - ) + self._outcomes = metric_names if state_dict: self.model.load_state_dict(not_none(state_dict)) @@ -670,11 +657,9 @@ def fit( ) def predict(self, X: Tensor) -> Tuple[Tensor, Tensor]: - """Predicts outcomes given a model and input tensor. - + """Predicts outcomes given an input tensor. Args: - model: A botorch Model. X: A ``n x d`` tensor of input parameters. Returns: @@ -698,11 +683,6 @@ def best_in_sample_point( ) best_point_and_observed_value = best_in_sample_point( Xs=self.Xs, - # pyre-ignore[6]: `best_in_sample_point` currently expects a `TorchModel` - # as `model` kwarg, but only uses them for `predict` function, the - # signature for which is the same on this `Surrogate`. - # TODO: When we move `botorch_modular` directory to OSS, we will extend - # the annotation for `model` kwarg to accept `Surrogate` too. model=self, bounds=search_space_digest.bounds, objective_weights=torch_opt_config.objective_weights, diff --git a/ax/models/torch/tests/test_surrogate.py b/ax/models/torch/tests/test_surrogate.py index 736f7c92491..f5bb3423243 100644 --- a/ax/models/torch/tests/test_surrogate.py +++ b/ax/models/torch/tests/test_surrogate.py @@ -94,8 +94,6 @@ def setUp(self) -> None: bounds=self.bounds, target_fidelities={1: 1.0}, ) - self.metric_names = ["x_y"] - self.original_metric_names = ["x", "y"] self.fixed_features = {1: 2.0} self.refit = True self.objective_weights = torch.tensor( @@ -475,19 +473,6 @@ def test_fit( self.assertTrue(mock_fit.call_kwargs["jit_compile"]) mock_MLL.reset_mock() mock_fit.reset_mock() - # Check that the optional original_metric_names arg propagates - # through surrogate._outcomes. - surrogate.fit( - datasets=self.training_data, - metric_names=self.metric_names, - search_space_digest=self.search_space_digest, - refit=self.refit, - original_metric_names=self.original_metric_names, - ) - self.assertEqual(surrogate.outcomes, self.original_metric_names) - mock_state_dict.reset_mock() - mock_MLL.reset_mock() - mock_fit.reset_mock() # Should `load_state_dict` when `state_dict` is not `None` # and `refit` is `False`. state_dict = {"state_attribute": torch.zeros(1)} @@ -795,7 +780,7 @@ def test_fit_mixed(self) -> None: surrogate = Surrogate(allow_batched_models=False) surrogate.fit( datasets=training_data, - metric_names=self.original_metric_names, + metric_names=self.metric_names, search_space_digest=search_space_digest, ) self.assertIsInstance(surrogate.model, ModelListGP)