diff --git a/botorch/fit.py b/botorch/fit.py
index aff72bd9b7..1b2ab0fc85 100644
--- a/botorch/fit.py
+++ b/botorch/fit.py
@@ -9,6 +9,7 @@ from __future__ import annotations
 
 import logging
+from copy import deepcopy
 from functools import partial
 from itertools import filterfalse
 from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
@@ -118,10 +119,11 @@ def _fit_fallback(
     __: Type[object],
     *,
     closure: Optional[Callable[[], Tuple[Tensor, Sequence[Optional[Tensor]]]]] = None,
-    optimizer: Optional[Callable] = fit_gpytorch_mll_scipy,
+    optimizer: Callable = fit_gpytorch_mll_scipy,
     closure_kwargs: Optional[Dict[str, Any]] = None,
     optimizer_kwargs: Optional[Dict[str, Any]] = None,
     max_attempts: int = 5,
+    pick_best_of_all_attempts: bool = False,
     warning_handler: Callable[[WarningMessage], bool] = DEFAULT_WARNING_HANDLER,
     caught_exception_types: Tuple[Type[BaseException], ...] = (NotPSDError,),
     **ignore: Any,
@@ -137,11 +139,20 @@ def _fit_fallback(
         closure: Forward-backward closure for obtaining objective values and gradients.
             Responsible for setting parameters' `grad` attributes. If no closure is
             provided, one will be obtained by calling `get_loss_closure_with_grads`.
-        optimizer: The underlying optimization algorithm to run.
+        optimizer: The underlying optimization algorithm to run. Should return
+            an `OptimizationResult` object, whose `fval` field records the negative
+            MLL value. Defaults to `fit_gpytorch_mll_scipy`.
         closure_kwargs: Keyword arguments passed to `closure`.
         optimizer_kwargs: Keyword arguments passed to `optimizer`.
         max_attempts: The maximum number of fit attempts allowed. The attempt budget
             is NOT shared between calls to this method.
+        pick_best_of_all_attempts: If True, the model will be fit `max_attempts` times,
+            and the attempt that produces the largest MLL value will be returned.
+            The first attempt uses the initial hyperparameter values; subsequent
+            attempts call `sample_all_priors` to sample new initial values.
+            If any attempt produces an error, the resulting parameters are discarded.
+            If an optimizer timeout is used, `timeout_sec` applies to each attempt
+            individually and should be adjusted accordingly.
         warning_handler: A function used to filter warnings produced when calling
             `optimizer`. Any unfiltered warnings (those for which `warning_handler`
             returns `False`) will be rethrown and trigger a model fitting retry.
@@ -168,6 +179,9 @@ def _fit_fallback(
     if closure_kwargs is not None:
         closure = partial(closure, **closure_kwargs)
 
+    # Record best MLL & corresponding state dict.
+    best_mll: float = -float("inf")
+    best_state_dict = None
     # Attempt to fit the model
     for attempt in range(1, 1 + max_attempts):
         # Wrap with rollback contextmanager so that each loop iteration reloads the
@@ -187,33 +201,56 @@ def _fit_fallback(
                 # Fit the model
                 with catch_warnings(record=True) as warning_list, debug(True):
                     simplefilter("always", category=OptimizationWarning)
-                    optimizer(mll, closure=closure, **optimizer_kwargs)
+                    result = optimizer(mll, closure=closure, **optimizer_kwargs)
 
-                # Resolved warnings and determine whether or not to retry
-                done = True
+                # Resolve warnings and determine whether or not to retry
+                success = True
                 for w in filterfalse(warning_handler, warning_list):
                     warn_explicit(str(w.message), w.category, w.filename, w.lineno)
-                    done = False
+                    success = False
 
-                if done:
+                if success and not pick_best_of_all_attempts:
+                    # If not picking best of all attempts, return the first
+                    # successful attempt.
                    ckpt.clear()  # do not rollback upon exiting
                     return mll.eval()
-
-                # Ensure mll is in the right mode if fitting failed
+                elif success:
+                    # Update best MLL and corresponding state dict.
+                    # Optimizers minimize negative MLL, so we negate fval.
+                    current_mll = -result.fval
+                    if current_mll > best_mll:
+                        best_mll = current_mll
+                        # Deepcopy is important here; otherwise the saved tensors get updated in place.
+                        best_state_dict = deepcopy(mll.state_dict())
+                        message = f"Fit attempt #{attempt}: New best MLL: {best_mll}."
+                    else:
+                        message = (
+                            f"Fit attempt #{attempt}: Current MLL {current_mll} did "
+                            f"not beat best MLL so far {best_mll}."
+                        )
+                    logging.log(logging.DEBUG, msg=message)
+
+                # Ensure mll is in the right mode if going for another attempt.
                 mll = mll if mll.training else mll.train()
-                logging.log(
-                    logging.DEBUG,
-                    f"Fit attempt #{attempt} of {max_attempts} triggered retry policy"
-                    f"{'.' if attempt == max_attempts else '; retrying...'}",
-                )
+                if not success:
+                    logging.log(
+                        logging.DEBUG,
+                        f"Fit attempt #{attempt} of {max_attempts} triggered retry "
+                        f"policy{'.' if attempt == max_attempts else '; retrying...'}",
+                    )
 
             except caught_exception_types as err:
                 logging.log(
                     logging.DEBUG,
-                    f"Fit attempt #{attempt} of {max_attempts} failed with exception: "
+                    f"Fit attempt #{attempt} of {max_attempts} failed with exception:\n"
                     f"{err}",
                 )
 
+    # If picking best of all attempts, return the MLL with the best state dict.
+    if best_state_dict is not None:
+        mll.load_state_dict(best_state_dict)
+        return mll.eval()
+
     msg = "All attempts to fit the model have failed."
     if debug.off():
         msg = msg + " For more information, try enabling botorch.settings.debug mode."
diff --git a/botorch/models/fully_bayesian.py b/botorch/models/fully_bayesian.py
index b50458d6ae..c3de0befc4 100644
--- a/botorch/models/fully_bayesian.py
+++ b/botorch/models/fully_bayesian.py
@@ -114,7 +114,7 @@ class PyroModel:
 
     def set_inputs(
         self, train_X: Tensor, train_Y: Tensor, train_Yvar: Optional[Tensor] = None
-    ):
+    ) -> None:
         """Set the training data.
 
         Args:
@@ -162,7 +162,7 @@ class SaasPyroModel(PyroModel):
 
     def set_inputs(
         self, train_X: Tensor, train_Y: Tensor, train_Yvar: Optional[Tensor] = None
-    ):
+    ) -> None:
         super().set_inputs(train_X, train_Y, train_Yvar)
         self.ard_num_dims = self.train_X.shape[-1]
 
@@ -394,7 +394,7 @@ def __init__(
         pyro_model.set_inputs(
             train_X=transformed_X, train_Y=train_Y, train_Yvar=train_Yvar
         )
-        self.pyro_model = pyro_model
+        self.pyro_model: PyroModel = pyro_model
         if outcome_transform is not None:
             self.outcome_transform = outcome_transform
         if input_transform is not None:
diff --git a/botorch/models/fully_bayesian_multitask.py b/botorch/models/fully_bayesian_multitask.py
index b8f6064721..bdc5a48723 100644
--- a/botorch/models/fully_bayesian_multitask.py
+++ b/botorch/models/fully_bayesian_multitask.py
@@ -48,7 +48,7 @@ def set_inputs(
         train_Yvar: Optional[Tensor],
         task_feature: int,
         task_rank: Optional[int] = None,
-    ):
+    ) -> None:
         """Set the training data.
 
         Args:
@@ -276,7 +276,7 @@ def __init__(
             task_feature=task_feature,
             task_rank=self._rank,
         )
-        self.pyro_model = pyro_model
+        self.pyro_model: MultitaskSaasPyroModel = pyro_model
         if outcome_transform is not None:
             self.outcome_transform = outcome_transform
         if input_transform is not None:
diff --git a/botorch/models/gp_regression.py b/botorch/models/gp_regression.py
index a942224a8d..4abf32e663 100644
--- a/botorch/models/gp_regression.py
+++ b/botorch/models/gp_regression.py
@@ -201,7 +201,7 @@ def __init__(
         }
         if train_Yvar is None:
             self._subset_batch_dict["likelihood.noise_covar.raw_noise"] = -2
-        self.covar_module = covar_module
+        self.covar_module: Module = covar_module
         # TODO: Allow subsetting of other covar modules
         if outcome_transform is not None:
             self.outcome_transform = outcome_transform
diff --git a/test/test_fit.py b/test/test_fit.py
index 47eccbec69..9e9ac6d644 100644
--- a/test/test_fit.py
+++ b/test/test_fit.py
@@ -6,6 +6,7 @@
 
 import math
 from contextlib import ExitStack, nullcontext
+from copy import deepcopy
 from itertools import filterfalse, product
 from typing import Callable, Iterable, Optional
 from unittest.mock import MagicMock, patch
@@ -19,8 +20,8 @@
 from botorch.models import SingleTaskGP, SingleTaskVariationalGP
 from botorch.models.transforms.input import Normalize
 from botorch.models.transforms.outcome import Standardize
-
 from botorch.optim.closures import get_loss_closure_with_grads
+from botorch.optim.core import OptimizationResult, OptimizationStatus
 from botorch.optim.fit import fit_gpytorch_mll_scipy, fit_gpytorch_mll_torch
 from botorch.optim.utils import get_data_loader
 from botorch.settings import debug
@@ -45,8 +46,9 @@ def __init__(
         self.warnings = warnings
         self.exception = exception
         self.call_count = 0
+        self.state_dicts = []
 
-    def __call__(self, mll, closure: Optional[Callable] = None):
+    def __call__(self, mll, closure: Optional[Callable] = None) -> OptimizationResult:
         self.call_count += 1
         for w in self.warnings:
             warn(str(w.message), w.category)
@@ -60,14 +62,21 @@ def __call__(self, mll, closure: Optional[Callable] = None):
         if self.exception is not None:
             raise self.exception
 
-        return mll, None
+        self.state_dicts.append(deepcopy(mll.state_dict()))
+        return OptimizationResult(
+            fval=torch.rand(1).item(),
+            step=1,
+            status=OptimizationStatus.SUCCESS,
+            message="Mock Success!",
+            runtime=1.0,
+        )
 
 
 class TestFitAPI(BotorchTestCase):
     r"""Unit tests for general fitting API"""
 
-    def setUp(self) -> None:
-        super().setUp()
+    def setUp(self, suppress_input_warnings: bool = True) -> None:
+        super().setUp(suppress_input_warnings=suppress_input_warnings)
         with torch.random.fork_rng():
             torch.manual_seed(0)
             train_X = torch.linspace(0, 1, 10).unsqueeze(-1)
@@ -108,8 +117,8 @@ def test_fit_gpytorch_mll(self):
 
 
 class TestFitFallback(BotorchTestCase):
-    def setUp(self) -> None:
-        super().setUp()
+    def setUp(self, suppress_input_warnings: bool = True) -> None:
+        super().setUp(suppress_input_warnings=suppress_input_warnings)
         with torch.random.fork_rng():
             torch.manual_seed(0)
             train_X = torch.linspace(0, 1, 10).unsqueeze(-1)
@@ -117,26 +126,22 @@ def setUp(self) -> None:
 
         self.mlls = {}
         self.checkpoints = {}
-        for model_type, output_dim in product([SingleTaskGP], [1, 2]):
+        for fixed_noise, output_dim in product([True, False], [1, 2]):
             train_Y = train_F.repeat(1, output_dim)
             train_Y = train_Y + 0.1 * torch.randn_like(train_Y)
-            model = model_type(
+            model = SingleTaskGP(
                 train_X=train_X,
                 train_Y=train_Y,
+                train_Yvar=torch.full_like(train_Y, 0.1) if fixed_noise else None,
                 input_transform=Normalize(d=1),
                 outcome_transform=Standardize(m=output_dim),
-                **(
-                    {}
-                    if model_type is SingleTaskGP
-                    else {"train_Yvar": torch.full_like(train_Y, 0.1)}
-                ),
             )
             self.assertIsInstance(model.covar_module.base_kernel, MaternKernel)
             model.covar_module.base_kernel.nu = 2.5
             mll = ExactMarginalLogLikelihood(model.likelihood, model)
             for dtype in (torch.float32, torch.float64):
-                key = model_type, output_dim
+                key = fixed_noise, output_dim
                 self.mlls[key] = mll.to(dtype=dtype)
                 self.checkpoints[key] = {
                     k: TensorCheckpoint(
@@ -310,10 +315,43 @@ def _test_exceptions(self, mll, ckpt):
                 all(v.equal(ckpt[k].values) for k, v in mll.state_dict().items())
             )
 
-
-class TestFitFallbackAppoximate(BotorchTestCase):
-    def setUp(self) -> None:
-        super().setUp()
+    def test_pick_best_of_all_attempts(self) -> None:
+        mll = next(iter(self.mlls.values()))
+        optimizer = MockOptimizer()
+        max_attempts = 10
+        with patch("botorch.fit.logging.log") as mock_log:
+            fit._fit_fallback(
+                mll,
+                None,
+                None,
+                max_attempts=max_attempts,
+                pick_best_of_all_attempts=True,
+                optimizer=optimizer,
+            )
+        # Check that the optimizer is called `max_attempts` times.
+        self.assertEqual(optimizer.call_count, max_attempts)
+        # Check that we log after each call.
+        self.assertEqual(mock_log.call_count, max_attempts)
+        # The logged best-MLL values form a non-decreasing sequence.
+        mll_vals = []
+        for call in mock_log.call_args_list:
+            message = call.kwargs["msg"]
+            mll_val = message.split(" ")[-1][:-1]
+            mll_vals.append(float(mll_val))
+        self.assertEqual(mll_vals, sorted(mll_vals))
+        # Check that the returned MLL is in eval mode.
+        self.assertFalse(mll.training)
+        # Check that the state dict matches the state dict of the best attempt.
+        final_state_dict = mll.state_dict()
+        best_idx = mll_vals.index(max(mll_vals))
+        best_state_dict = optimizer.state_dicts[best_idx]
+        for key, val in final_state_dict.items():
+            self.assertAllClose(val, best_state_dict[key])
+
+
+class TestFitFallbackApproximate(BotorchTestCase):
+    def setUp(self, suppress_input_warnings: bool = True) -> None:
+        super().setUp(suppress_input_warnings=suppress_input_warnings)
         with torch.random.fork_rng():
             torch.manual_seed(0)
             train_X = torch.linspace(0, 1, 10).unsqueeze(-1)
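
Usage note (not part of the diff above): a minimal sketch of how the new `pick_best_of_all_attempts` flag would be exercised through the public fitting API. It assumes that `fit_gpytorch_mll` forwards extra keyword arguments such as `max_attempts` and `pick_best_of_all_attempts` down to the dispatched `_fit_fallback` routine, and that the default `fit_gpytorch_mll_scipy` optimizer returns the required `OptimizationResult`; the data and model are illustrative only.

    import torch
    from botorch.fit import fit_gpytorch_mll
    from botorch.models import SingleTaskGP
    from gpytorch.mlls import ExactMarginalLogLikelihood

    # Illustrative training data: 20 points in [0, 1]^2 with a smooth response.
    train_X = torch.rand(20, 2, dtype=torch.float64)
    train_Y = train_X.sum(dim=-1, keepdim=True).sin()

    model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
    mll = ExactMarginalLogLikelihood(model.likelihood, model)

    # Run all five attempts (the first uses the initial hyperparameters; later
    # attempts re-sample them from their priors via `sample_all_priors`) and keep
    # the state dict of the attempt with the largest MLL, instead of returning
    # on the first success.
    fit_gpytorch_mll(mll, max_attempts=5, pick_best_of_all_attempts=True)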