Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Surrogate.from_botorch #1927

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 7 additions & 44 deletions ax/models/torch/botorch_modular/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import torch
from ax.core.search_space import SearchSpaceDigest
from ax.core.types import TCandidateMetadata
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.exceptions.core import UserInputError
from ax.models.model_utils import best_in_sample_point
from ax.models.torch.botorch_modular.input_constructors.covar_modules import (
covar_module_argparse,
Expand Down Expand Up @@ -159,10 +159,6 @@ def __init__(
self._training_data: Optional[List[SupervisedDataset]] = None
self._outcomes: Optional[List[str]] = None
self._model: Optional[Model] = None
# Special setting for surrogates instantiated via `Surrogate.from_botorch`,
# to avoid re-constructing the underlying BoTorch model on `Surrogate.fit`
# when set to `False`.
self._constructed_manually: bool = False

def __repr__(self) -> str:
return (
Expand Down Expand Up @@ -212,22 +208,6 @@ def dtype(self) -> torch.dtype:
def device(self) -> torch.device:
return self.training_data[0].X.device

@classmethod
def from_botorch(
cls,
model: Model,
mll_class: Type[MarginalLogLikelihood] = ExactMarginalLogLikelihood,
) -> Surrogate:
"""Instantiate a `Surrogate` from a pre-instantiated Botorch `Model`."""
surrogate = cls(botorch_model_class=model.__class__, mll_class=mll_class)
surrogate._model = model
# Temporarily disallowing `update` for surrogates instantiated from
# pre-made BoTorch `Model` instances to avoid reconstructing models
# that were likely pre-constructed for a reason (e.g. if this setup
# doesn't fully allow to constuct them).
surrogate._constructed_manually = True
return surrogate

def clone_reset(self) -> Surrogate:
return self.__class__(**self._serialize_attributes_as_kwargs())

Expand All @@ -247,9 +227,6 @@ def construct(
search_space_digest: Information about the search space used for
inferring suitable botorch model class.
"""
if self._constructed_manually:
logger.warning("Reconstructing a manually constructed `Model`.")

# To determine whether to use ModelList under the hood, we need to check for
# the batched multi-output case, so we first see which model would be chosen
# given the Yvars and the properties of data.
Expand Down Expand Up @@ -535,19 +512,12 @@ def fit(
state_dict: Optional state dict to load.
refit: Whether to re-optimize model parameters.
"""
if self._constructed_manually:
logger.debug(
"For manually constructed surrogates (via `Surrogate.from_botorch`), "
"`fit` skips setting the training data on model and only reoptimizes "
"its parameters if `refit=True`."
)
else:
self.construct(
datasets=datasets,
metric_names=metric_names,
search_space_digest=search_space_digest,
)
self._outcomes = metric_names
self.construct(
datasets=datasets,
metric_names=metric_names,
search_space_digest=search_space_digest,
)
self._outcomes = metric_names

if state_dict:
self.model.load_state_dict(not_none(state_dict))
Expand Down Expand Up @@ -662,13 +632,6 @@ def _serialize_attributes_as_kwargs(self) -> Dict[str, Any]:
"""Serialize attributes of this surrogate, to be passed back to it
as kwargs on reinstantiation.
"""
if self._constructed_manually:
raise UnsupportedError(
"Surrogates constructed manually (ie Surrogate.from_botorch) may not "
"be serialized. If serialization is necessary please initialize from "
"the constructor."
)

return {
"botorch_model_class": self.botorch_model_class,
"model_options": self.model_options,
Expand Down
19 changes: 1 addition & 18 deletions ax/models/torch/tests/test_surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import numpy as np
import torch
from ax.core.search_space import RobustSearchSpaceDigest, SearchSpaceDigest
from ax.exceptions.core import UnsupportedError, UserInputError
from ax.exceptions.core import UserInputError
from ax.models.torch.botorch_modular.acquisition import Acquisition
from ax.models.torch.botorch_modular.surrogate import Surrogate
from ax.models.torch.botorch_modular.utils import choose_model_class, fit_botorch_model
Expand Down Expand Up @@ -304,15 +304,6 @@ def test_device_property(self) -> None:
)
self.assertEqual(self.device, surrogate.device)

def test_from_botorch(self) -> None:
for botorch_model_class in [SaasFullyBayesianSingleTaskGP, SingleTaskGP]:
surrogate_kwargs = botorch_model_class.construct_inputs(
self.training_data[0]
)
surrogate = Surrogate.from_botorch(botorch_model_class(**surrogate_kwargs))
self.assertIsInstance(surrogate.model, botorch_model_class)
self.assertTrue(surrogate._constructed_manually)

@patch(f"{CURRENT_PATH}.SaasFullyBayesianSingleTaskGP.__init__", return_value=None)
@patch(f"{CURRENT_PATH}.SingleTaskGP.__init__", return_value=None)
def test_construct(self, mock_GP: Mock, mock_SAAS: Mock) -> None:
Expand All @@ -338,7 +329,6 @@ def test_construct(self, mock_GP: Mock, mock_SAAS: Mock) -> None:
call_kwargs = mock_GPs[i].call_args[1]
self.assertTrue(torch.equal(call_kwargs["train_X"], self.Xs[0]))
self.assertTrue(torch.equal(call_kwargs["train_Y"], self.Ys[0]))
self.assertFalse(surrogate._constructed_manually)

# Check that `model_options` passed to the `Surrogate` constructor are
# properly propagated.
Expand Down Expand Up @@ -591,13 +581,6 @@ def test_serialize_attributes_as_kwargs(self) -> None:
}
self.assertEqual(surrogate._serialize_attributes_as_kwargs(), expected)

with self.assertRaisesRegex(
UnsupportedError, "Surrogates constructed manually"
):
surrogate, _ = self._get_surrogate(botorch_model_class=SingleTaskGP)
surrogate._constructed_manually = True
surrogate._serialize_attributes_as_kwargs()

def test_w_robust_digest(self) -> None:
surrogate = Surrogate(
botorch_model_class=SingleTaskGP,
Expand Down
23 changes: 0 additions & 23 deletions tutorials/Setup_and_Usage_of_BoTorch_Models_in_Ax.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -299,29 +299,6 @@
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"originalKey": "f3c267de-f9b2-4524-852b-156fc47d1745"
},
"source": [
"Alternatively, for BoTorch `Model`-s that require complex instantiation procedures, leverage the `from_BoTorch` instantiation method of `Surrogate`:"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"originalKey": "905ba2bc-37e4-4a12-8442-1c772ffd15d0"
},
"outputs": [],
"source": [
"surrogate_from_botorch_model = Surrogate.from_BoTorch(\n",
" model=..., # BoTorch `Model` instance, with training data already set\n",
" mll_class=ExactMarginalLogLikelihood, # Optional, MLL class with which to optimize model parameters\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
Expand Down
36 changes: 2 additions & 34 deletions tutorials/modular_botax.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@
"1. **`BoTorchModel` = `Surrogate` + `Acquisition` (overview)**\n",
" 1. Example with minimal options that uses the defaults\n",
" 2. Example showing all possible options\n",
" 3. Using a pre-constructed BoTorch Model (e.g. in research or development)\n",
" 4. Surrogate and Acquisition Q&A\n",
" 3. Surrogate and Acquisition Q&A\n",
"2. **I know which Botorch Model and AcquisitionFunction I'd like to combine in Ax. How do set this up?**\n",
" 1. Making a `Surrogate` from BoTorch `Model`\n",
" 2. Using an arbitrary BoTorch `AcquisitionFunction` in Ax\n",
Expand Down Expand Up @@ -269,45 +268,14 @@
")"
]
},
{
"cell_type": "markdown",
"id": "critical-receptor",
"metadata": {
"originalKey": "5b15f6d8-27a0-410e-95ff-4a304bf35498"
},
"source": [
"## 2C. `Surrogate` from pre-instantiated BoTorch `Model`\n",
"\n",
"Alternatively, for BoTorch `Model`-s that require complex instantiation procedures (or is in development stage), leverage the `from_botorch` instantiation method of Surrogate:"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "fourth-moore",
"metadata": {
"originalKey": "ce873686-cedd-4f3a-9476-b6e78f1c3650"
},
"outputs": [],
"source": [
"from_botorch_model = BoTorchModel(\n",
" surrogate=Surrogate.from_botorch(\n",
" # Pre-constructed BoTorch `Model` instance, with training data already set\n",
" model=...,\n",
" # Optional, MLL class with which to optimize model parameters\n",
" mll_class=ExactMarginalLogLikelihood,\n",
" )\n",
")"
]
},
{
"cell_type": "markdown",
"id": "fourth-material",
"metadata": {
"originalKey": "db0feafe-8af9-40a3-9f67-72c7d1fd808e"
},
"source": [
"## 2D. `Surrogate` and `Acquisition` Q&A\n",
"## 2C. `Surrogate` and `Acquisition` Q&A\n",
"\n",
"**Why is the `surrogate` argument expected to be an instance, but `botorch_acqf_class` –– a class?** Because a BoTorch `AcquisitionFunction` object (and therefore its Ax wrapper, `Acquisition`) is ephemeral: it is constructed, immediately used, and destroyed during `BoTorchModel.gen`, so there is no reason to keep around an `Acquisition` instance. A `Surrogate`, on another hand, is kept in memory as long as its parent `BoTorchModel` is.\n",
"\n",
Expand Down