From addf6a31c9d784e22bcc620cb0cde4e5dd203c0c Mon Sep 17 00:00:00 2001 From: Sait Cakmak Date: Mon, 23 Oct 2023 22:48:14 -0700 Subject: [PATCH] Remove Surrogate.from_botorch Summary: This is buggy and unsupported. Reviewed By: Balandat Differential Revision: D50581677 --- ax/models/torch/botorch_modular/surrogate.py | 51 +++---------------- ax/models/torch/tests/test_surrogate.py | 19 +------ ...up_and_Usage_of_BoTorch_Models_in_Ax.ipynb | 23 --------- tutorials/modular_botax.ipynb | 36 +------------ 4 files changed, 10 insertions(+), 119 deletions(-) diff --git a/ax/models/torch/botorch_modular/surrogate.py b/ax/models/torch/botorch_modular/surrogate.py index 2ce0af804d8..70c6afc9037 100644 --- a/ax/models/torch/botorch_modular/surrogate.py +++ b/ax/models/torch/botorch_modular/surrogate.py @@ -14,7 +14,7 @@ import torch from ax.core.search_space import SearchSpaceDigest from ax.core.types import TCandidateMetadata -from ax.exceptions.core import UnsupportedError, UserInputError +from ax.exceptions.core import UserInputError from ax.models.model_utils import best_in_sample_point from ax.models.torch.botorch_modular.input_constructors.covar_modules import ( covar_module_argparse, @@ -159,10 +159,6 @@ def __init__( self._training_data: Optional[List[SupervisedDataset]] = None self._outcomes: Optional[List[str]] = None self._model: Optional[Model] = None - # Special setting for surrogates instantiated via `Surrogate.from_botorch`, - # to avoid re-constructing the underlying BoTorch model on `Surrogate.fit` - # when set to `False`. - self._constructed_manually: bool = False def __repr__(self) -> str: return ( @@ -212,22 +208,6 @@ def dtype(self) -> torch.dtype: def device(self) -> torch.device: return self.training_data[0].X.device - @classmethod - def from_botorch( - cls, - model: Model, - mll_class: Type[MarginalLogLikelihood] = ExactMarginalLogLikelihood, - ) -> Surrogate: - """Instantiate a `Surrogate` from a pre-instantiated Botorch `Model`.""" - surrogate = cls(botorch_model_class=model.__class__, mll_class=mll_class) - surrogate._model = model - # Temporarily disallowing `update` for surrogates instantiated from - # pre-made BoTorch `Model` instances to avoid reconstructing models - # that were likely pre-constructed for a reason (e.g. if this setup - # doesn't fully allow to constuct them). - surrogate._constructed_manually = True - return surrogate - def clone_reset(self) -> Surrogate: return self.__class__(**self._serialize_attributes_as_kwargs()) @@ -247,9 +227,6 @@ def construct( search_space_digest: Information about the search space used for inferring suitable botorch model class. """ - if self._constructed_manually: - logger.warning("Reconstructing a manually constructed `Model`.") - # To determine whether to use ModelList under the hood, we need to check for # the batched multi-output case, so we first see which model would be chosen # given the Yvars and the properties of data. @@ -535,19 +512,12 @@ def fit( state_dict: Optional state dict to load. refit: Whether to re-optimize model parameters. """ - if self._constructed_manually: - logger.debug( - "For manually constructed surrogates (via `Surrogate.from_botorch`), " - "`fit` skips setting the training data on model and only reoptimizes " - "its parameters if `refit=True`." - ) - else: - self.construct( - datasets=datasets, - metric_names=metric_names, - search_space_digest=search_space_digest, - ) - self._outcomes = metric_names + self.construct( + datasets=datasets, + metric_names=metric_names, + search_space_digest=search_space_digest, + ) + self._outcomes = metric_names if state_dict: self.model.load_state_dict(not_none(state_dict)) @@ -662,13 +632,6 @@ def _serialize_attributes_as_kwargs(self) -> Dict[str, Any]: """Serialize attributes of this surrogate, to be passed back to it as kwargs on reinstantiation. """ - if self._constructed_manually: - raise UnsupportedError( - "Surrogates constructed manually (ie Surrogate.from_botorch) may not " - "be serialized. If serialization is necessary please initialize from " - "the constructor." - ) - return { "botorch_model_class": self.botorch_model_class, "model_options": self.model_options, diff --git a/ax/models/torch/tests/test_surrogate.py b/ax/models/torch/tests/test_surrogate.py index e2b6725f9f7..4d82ec4b4bb 100644 --- a/ax/models/torch/tests/test_surrogate.py +++ b/ax/models/torch/tests/test_surrogate.py @@ -12,7 +12,7 @@ import numpy as np import torch from ax.core.search_space import RobustSearchSpaceDigest, SearchSpaceDigest -from ax.exceptions.core import UnsupportedError, UserInputError +from ax.exceptions.core import UserInputError from ax.models.torch.botorch_modular.acquisition import Acquisition from ax.models.torch.botorch_modular.surrogate import Surrogate from ax.models.torch.botorch_modular.utils import choose_model_class, fit_botorch_model @@ -304,15 +304,6 @@ def test_device_property(self) -> None: ) self.assertEqual(self.device, surrogate.device) - def test_from_botorch(self) -> None: - for botorch_model_class in [SaasFullyBayesianSingleTaskGP, SingleTaskGP]: - surrogate_kwargs = botorch_model_class.construct_inputs( - self.training_data[0] - ) - surrogate = Surrogate.from_botorch(botorch_model_class(**surrogate_kwargs)) - self.assertIsInstance(surrogate.model, botorch_model_class) - self.assertTrue(surrogate._constructed_manually) - @patch(f"{CURRENT_PATH}.SaasFullyBayesianSingleTaskGP.__init__", return_value=None) @patch(f"{CURRENT_PATH}.SingleTaskGP.__init__", return_value=None) def test_construct(self, mock_GP: Mock, mock_SAAS: Mock) -> None: @@ -338,7 +329,6 @@ def test_construct(self, mock_GP: Mock, mock_SAAS: Mock) -> None: call_kwargs = mock_GPs[i].call_args[1] self.assertTrue(torch.equal(call_kwargs["train_X"], self.Xs[0])) self.assertTrue(torch.equal(call_kwargs["train_Y"], self.Ys[0])) - self.assertFalse(surrogate._constructed_manually) # Check that `model_options` passed to the `Surrogate` constructor are # properly propagated. @@ -591,13 +581,6 @@ def test_serialize_attributes_as_kwargs(self) -> None: } self.assertEqual(surrogate._serialize_attributes_as_kwargs(), expected) - with self.assertRaisesRegex( - UnsupportedError, "Surrogates constructed manually" - ): - surrogate, _ = self._get_surrogate(botorch_model_class=SingleTaskGP) - surrogate._constructed_manually = True - surrogate._serialize_attributes_as_kwargs() - def test_w_robust_digest(self) -> None: surrogate = Surrogate( botorch_model_class=SingleTaskGP, diff --git a/tutorials/Setup_and_Usage_of_BoTorch_Models_in_Ax.ipynb b/tutorials/Setup_and_Usage_of_BoTorch_Models_in_Ax.ipynb index 69a708da006..21ced8e2942 100644 --- a/tutorials/Setup_and_Usage_of_BoTorch_Models_in_Ax.ipynb +++ b/tutorials/Setup_and_Usage_of_BoTorch_Models_in_Ax.ipynb @@ -299,29 +299,6 @@ ")" ] }, - { - "cell_type": "markdown", - "metadata": { - "originalKey": "f3c267de-f9b2-4524-852b-156fc47d1745" - }, - "source": [ - "Alternatively, for BoTorch `Model`-s that require complex instantiation procedures, leverage the `from_BoTorch` instantiation method of `Surrogate`:" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "metadata": { - "originalKey": "905ba2bc-37e4-4a12-8442-1c772ffd15d0" - }, - "outputs": [], - "source": [ - "surrogate_from_botorch_model = Surrogate.from_BoTorch(\n", - " model=..., # BoTorch `Model` instance, with training data already set\n", - " mll_class=ExactMarginalLogLikelihood, # Optional, MLL class with which to optimize model parameters\n", - ")" - ] - }, { "cell_type": "markdown", "metadata": { diff --git a/tutorials/modular_botax.ipynb b/tutorials/modular_botax.ipynb index 1207165410d..7835915ec39 100644 --- a/tutorials/modular_botax.ipynb +++ b/tutorials/modular_botax.ipynb @@ -58,8 +58,7 @@ "1. **`BoTorchModel` = `Surrogate` + `Acquisition` (overview)**\n", " 1. Example with minimal options that uses the defaults\n", " 2. Example showing all possible options\n", - " 3. Using a pre-constructed BoTorch Model (e.g. in research or development)\n", - " 4. Surrogate and Acquisition Q&A\n", + " 3. Surrogate and Acquisition Q&A\n", "2. **I know which Botorch Model and AcquisitionFunction I'd like to combine in Ax. How do set this up?**\n", " 1. Making a `Surrogate` from BoTorch `Model`\n", " 2. Using an arbitrary BoTorch `AcquisitionFunction` in Ax\n", @@ -269,37 +268,6 @@ ")" ] }, - { - "cell_type": "markdown", - "id": "critical-receptor", - "metadata": { - "originalKey": "5b15f6d8-27a0-410e-95ff-4a304bf35498" - }, - "source": [ - "## 2C. `Surrogate` from pre-instantiated BoTorch `Model`\n", - "\n", - "Alternatively, for BoTorch `Model`-s that require complex instantiation procedures (or is in development stage), leverage the `from_botorch` instantiation method of Surrogate:" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "fourth-moore", - "metadata": { - "originalKey": "ce873686-cedd-4f3a-9476-b6e78f1c3650" - }, - "outputs": [], - "source": [ - "from_botorch_model = BoTorchModel(\n", - " surrogate=Surrogate.from_botorch(\n", - " # Pre-constructed BoTorch `Model` instance, with training data already set\n", - " model=...,\n", - " # Optional, MLL class with which to optimize model parameters\n", - " mll_class=ExactMarginalLogLikelihood,\n", - " )\n", - ")" - ] - }, { "cell_type": "markdown", "id": "fourth-material", @@ -307,7 +275,7 @@ "originalKey": "db0feafe-8af9-40a3-9f67-72c7d1fd808e" }, "source": [ - "## 2D. `Surrogate` and `Acquisition` Q&A\n", + "## 2C. `Surrogate` and `Acquisition` Q&A\n", "\n", "**Why is the `surrogate` argument expected to be an instance, but `botorch_acqf_class` –– a class?** Because a BoTorch `AcquisitionFunction` object (and therefore its Ax wrapper, `Acquisition`) is ephemeral: it is constructed, immediately used, and destroyed during `BoTorchModel.gen`, so there is no reason to keep around an `Acquisition` instance. A `Surrogate`, on another hand, is kept in memory as long as its parent `BoTorchModel` is.\n", "\n",