Commit

Support picking best of multiple fit attempts in fit_gpytorch_mll (#2373)

Summary:
Pull Request resolved: #2373

Adds an option to `_fit_fallback` to fit the model `max_attempts` times and return the result of the attempt that produced the largest MLL value. This has been requested by users from time to time, with the latest request being #2367.
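
A minimal usage sketch (assuming `fit_gpytorch_mll` forwards extra keyword arguments through to `_fit_fallback`; the training data below is made up for illustration only):

```python
import torch
from botorch.fit import fit_gpytorch_mll
from botorch.models import SingleTaskGP
from gpytorch.mlls import ExactMarginalLogLikelihood

# Hypothetical training data, for illustration only.
train_X = torch.rand(20, 2, dtype=torch.float64)
train_Y = torch.sin(train_X).sum(dim=-1, keepdim=True)

model = SingleTaskGP(train_X=train_X, train_Y=train_Y)
mll = ExactMarginalLogLikelihood(model.likelihood, model)

# Fit up to 5 times and keep the hyperparameters of the attempt with the largest MLL.
# Note: a per-attempt timeout (e.g. optimizer_kwargs={"timeout_sec": 15.0}) applies to
# each attempt separately, so the total budget would be roughly 5 x 15s.
fit_gpytorch_mll(mll, pick_best_of_all_attempts=True, max_attempts=5)
```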

Also ended up making some minor changes to address pyre complaints.

Reviewed By: sdaulton

Differential Revision: D58397740

fbshipit-source-id: a9da6bc8d3a612750dad28218079bf8091e0f6d2
saitcakmak authored and facebook-github-bot committed Jun 11, 2024
1 parent d753706 commit 1e73b30
Showing 5 changed files with 115 additions and 40 deletions.
67 changes: 52 additions & 15 deletions botorch/fit.py
@@ -9,6 +9,7 @@
from __future__ import annotations

import logging
from copy import deepcopy
from functools import partial
from itertools import filterfalse
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Type, Union
@@ -118,10 +119,11 @@ def _fit_fallback(
__: Type[object],
*,
closure: Optional[Callable[[], Tuple[Tensor, Sequence[Optional[Tensor]]]]] = None,
optimizer: Optional[Callable] = fit_gpytorch_mll_scipy,
optimizer: Callable = fit_gpytorch_mll_scipy,
closure_kwargs: Optional[Dict[str, Any]] = None,
optimizer_kwargs: Optional[Dict[str, Any]] = None,
max_attempts: int = 5,
pick_best_of_all_attempts: bool = False,
warning_handler: Callable[[WarningMessage], bool] = DEFAULT_WARNING_HANDLER,
caught_exception_types: Tuple[Type[BaseException], ...] = (NotPSDError,),
**ignore: Any,
@@ -137,11 +139,20 @@ def _fit_fallback(
closure: Forward-backward closure for obtaining objective values and gradients.
Responsible for setting parameters' `grad` attributes. If no closure is
provided, one will be obtained by calling `get_loss_closure_with_grads`.
optimizer: The underlying optimization algorithm to run.
optimizer: The underlying optimization algorithm to run. Should return
an `OptimizationResult` object, whose `fval` field records the negative
MLL value. Defaults to `fit_gpytorch_mll_scipy`.
closure_kwargs: Keyword arguments passed to `closure`.
optimizer_kwargs: Keyword arguments passed to `optimizer`.
max_attempts: The maximum number of fit attempts allowed. The attempt budget
is NOT shared between calls to this method.
pick_best_of_all_attempts: If True, the model will be fit `max_attempts` times,
and the attempt that produces the largest MLL value will be returned.
The first attempt uses the initial hyperparameter values; subsequent
attempts call `sample_all_priors` to sample new initial values.
If any attempt produces an error, the resulting parameters are discarded.
If an optimizer timeout is used, `timeout_sec` is applied as-is to
each attempt, and should be adjusted manually accordingly.
warning_handler: A function used to filter warnings produced when calling
`optimizer`. Any unfiltered warnings (those for which `warning_handler`
returns `False`) will be rethrown and trigger a model fitting retry.
@@ -168,6 +179,9 @@ def _fit_fallback(
if closure_kwargs is not None:
closure = partial(closure, **closure_kwargs)

# Record best MLL & corresponding state dict.
best_mll: float = -float("inf")
best_state_dict = None
# Attempt to fit the model
for attempt in range(1, 1 + max_attempts):
# Wrap with rollback contextmanager so that each loop iteration reloads the
@@ -187,33 +201,56 @@
# Fit the model
with catch_warnings(record=True) as warning_list, debug(True):
simplefilter("always", category=OptimizationWarning)
optimizer(mll, closure=closure, **optimizer_kwargs)
result = optimizer(mll, closure=closure, **optimizer_kwargs)

# Resolved warnings and determine whether or not to retry
done = True
# Resolve warnings and determine whether or not to retry
success = True
for w in filterfalse(warning_handler, warning_list):
warn_explicit(str(w.message), w.category, w.filename, w.lineno)
done = False
success = False

if done:
if success and not pick_best_of_all_attempts:
# If not picking best of all attempts, return the first
# successful attempt.
ckpt.clear() # do not rollback upon exiting
return mll.eval()

# Ensure mll is in the right mode if fitting failed
elif success:
# Update best MLL and corresponding state dict.
# Optimizers minimize negative MLL, so we negate fval.
current_mll = -result.fval
if current_mll > best_mll:
best_mll = current_mll
# Deepcopy is important here; otherwise the stored values get updated in later attempts.
best_state_dict = deepcopy(mll.state_dict())
message = f"Fit attempt #{attempt}: New best MLL: {best_mll}."
else:
message = (
f"Fit attempt #{attempt}: Current MLL {current_mll} did "
f"not beat best MLL so far {best_mll}."
)
logging.log(logging.DEBUG, msg=message)

# Ensure mll is in the right mode if going for another attempt.
mll = mll if mll.training else mll.train()
logging.log(
logging.DEBUG,
f"Fit attempt #{attempt} of {max_attempts} triggered retry policy"
f"{'.' if attempt == max_attempts else '; retrying...'}",
)
if not success:
logging.log(
logging.DEBUG,
f"Fit attempt #{attempt} of {max_attempts} triggered retry "
f"policy {'.' if attempt == max_attempts else '; retrying...'}",
)

except caught_exception_types as err:
logging.log(
logging.DEBUG,
f"Fit attempt #{attempt} of {max_attempts} failed with exception: "
f"Fit attempt #{attempt} of {max_attempts} failed with exception:\n"
f"{err}",
)

# If picking best of all attempts, return MLL with best state dict.
if best_state_dict is not None:
mll.load_state_dict(best_state_dict)
return mll.eval()

msg = "All attempts to fit the model have failed."
if debug.off():
msg = msg + " For more information, try enabling botorch.settings.debug mode."
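
Since the new docstring requires any custom `optimizer` to return an `OptimizationResult` whose `fval` records the negative MLL, here is a minimal sketch of a compatible callable (illustrative only, not part of this commit; it relies on the forward-backward closure that `_fit_fallback` supplies and a simple Adam loop — in practice `fit_gpytorch_mll_torch` already provides such a routine):

```python
import time

import torch
from botorch.optim.core import OptimizationResult, OptimizationStatus


def fit_mll_adam(mll, closure, step_limit: int = 100, lr: float = 0.05, **kwargs):
    """Illustrative optimizer: runs Adam on the MLL parameters.

    `closure` is the forward-backward closure passed in by `_fit_fallback`;
    it returns the loss (the negative MLL) and populates the parameter grads.
    """
    start = time.monotonic()
    params = [p for p in mll.parameters() if p.requires_grad]
    opt = torch.optim.Adam(params, lr=lr)
    loss = torch.tensor(float("nan"))
    for _ in range(step_limit):
        opt.zero_grad()
        loss, _grads = closure()  # computes the loss and sets .grad attributes
        opt.step()
    return OptimizationResult(
        step=step_limit,
        fval=float(loss),  # negative MLL, as `_fit_fallback` expects
        status=OptimizationStatus.SUCCESS,
        message="Adam loop finished.",
        runtime=time.monotonic() - start,
    )
```

A callable like this could then be passed via the `optimizer` argument, exactly as `MockOptimizer` is in the new test below.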
6 changes: 3 additions & 3 deletions botorch/models/fully_bayesian.py
@@ -114,7 +114,7 @@ class PyroModel:

def set_inputs(
self, train_X: Tensor, train_Y: Tensor, train_Yvar: Optional[Tensor] = None
):
) -> None:
"""Set the training data.
Args:
@@ -162,7 +162,7 @@ class SaasPyroModel(PyroModel):

def set_inputs(
self, train_X: Tensor, train_Y: Tensor, train_Yvar: Optional[Tensor] = None
):
) -> None:
super().set_inputs(train_X, train_Y, train_Yvar)
self.ard_num_dims = self.train_X.shape[-1]

@@ -394,7 +394,7 @@ def __init__(
pyro_model.set_inputs(
train_X=transformed_X, train_Y=train_Y, train_Yvar=train_Yvar
)
self.pyro_model = pyro_model
self.pyro_model: PyroModel = pyro_model
if outcome_transform is not None:
self.outcome_transform = outcome_transform
if input_transform is not None:
4 changes: 2 additions & 2 deletions botorch/models/fully_bayesian_multitask.py
@@ -48,7 +48,7 @@ def set_inputs(
train_Yvar: Optional[Tensor],
task_feature: int,
task_rank: Optional[int] = None,
):
) -> None:
"""Set the training data.
Args:
@@ -276,7 +276,7 @@ def __init__(
task_feature=task_feature,
task_rank=self._rank,
)
self.pyro_model = pyro_model
self.pyro_model: MultitaskSaasPyroModel = pyro_model
if outcome_transform is not None:
self.outcome_transform = outcome_transform
if input_transform is not None:
2 changes: 1 addition & 1 deletion botorch/models/gp_regression.py
@@ -201,7 +201,7 @@ def __init__(
}
if train_Yvar is None:
self._subset_batch_dict["likelihood.noise_covar.raw_noise"] = -2
self.covar_module = covar_module
self.covar_module: Module = covar_module
# TODO: Allow subsetting of other covar modules
if outcome_transform is not None:
self.outcome_transform = outcome_transform
76 changes: 57 additions & 19 deletions test/test_fit.py
@@ -6,6 +6,7 @@

import math
from contextlib import ExitStack, nullcontext
from copy import deepcopy
from itertools import filterfalse, product
from typing import Callable, Iterable, Optional
from unittest.mock import MagicMock, patch
@@ -19,8 +20,8 @@
from botorch.models import SingleTaskGP, SingleTaskVariationalGP
from botorch.models.transforms.input import Normalize
from botorch.models.transforms.outcome import Standardize

from botorch.optim.closures import get_loss_closure_with_grads
from botorch.optim.core import OptimizationResult, OptimizationStatus
from botorch.optim.fit import fit_gpytorch_mll_scipy, fit_gpytorch_mll_torch
from botorch.optim.utils import get_data_loader
from botorch.settings import debug
@@ -45,8 +46,9 @@ def __init__(
self.warnings = warnings
self.exception = exception
self.call_count = 0
self.state_dicts = []

def __call__(self, mll, closure: Optional[Callable] = None):
def __call__(self, mll, closure: Optional[Callable] = None) -> OptimizationResult:
self.call_count += 1
for w in self.warnings:
warn(str(w.message), w.category)
@@ -60,14 +62,21 @@ def __call__(self, mll, closure: Optional[Callable] = None):
if self.exception is not None:
raise self.exception

return mll, None
self.state_dicts.append(deepcopy(mll.state_dict()))
return OptimizationResult(
fval=torch.rand(1).item(),
step=1,
status=OptimizationStatus.SUCCESS,
message="Mock Success!",
runtime=1.0,
)


class TestFitAPI(BotorchTestCase):
r"""Unit tests for general fitting API"""

def setUp(self) -> None:
super().setUp()
def setUp(self, suppress_input_warnings: bool = True) -> None:
super().setUp(suppress_input_warnings=suppress_input_warnings)
with torch.random.fork_rng():
torch.manual_seed(0)
train_X = torch.linspace(0, 1, 10).unsqueeze(-1)
@@ -108,35 +117,31 @@ def test_fit_gpytorch_mll(self):


class TestFitFallback(BotorchTestCase):
def setUp(self) -> None:
super().setUp()
def setUp(self, suppress_input_warnings: bool = True) -> None:
super().setUp(suppress_input_warnings=suppress_input_warnings)
with torch.random.fork_rng():
torch.manual_seed(0)
train_X = torch.linspace(0, 1, 10).unsqueeze(-1)
train_F = torch.sin(2 * math.pi * train_X)

self.mlls = {}
self.checkpoints = {}
for model_type, output_dim in product([SingleTaskGP], [1, 2]):
for fixed_noise, output_dim in product([True, False], [1, 2]):
train_Y = train_F.repeat(1, output_dim)
train_Y = train_Y + 0.1 * torch.randn_like(train_Y)
model = model_type(
model = SingleTaskGP(
train_X=train_X,
train_Y=train_Y,
train_Yvar=torch.full_like(train_Y, 0.1) if fixed_noise else None,
input_transform=Normalize(d=1),
outcome_transform=Standardize(m=output_dim),
**(
{}
if model_type is SingleTaskGP
else {"train_Yvar": torch.full_like(train_Y, 0.1)}
),
)
self.assertIsInstance(model.covar_module.base_kernel, MaternKernel)
model.covar_module.base_kernel.nu = 2.5

mll = ExactMarginalLogLikelihood(model.likelihood, model)
for dtype in (torch.float32, torch.float64):
key = model_type, output_dim
key = fixed_noise, output_dim
self.mlls[key] = mll.to(dtype=dtype)
self.checkpoints[key] = {
k: TensorCheckpoint(
@@ -310,10 +315,43 @@ def _test_exceptions(self, mll, ckpt):
all(v.equal(ckpt[k].values) for k, v in mll.state_dict().items())
)


class TestFitFallbackAppoximate(BotorchTestCase):
def setUp(self) -> None:
super().setUp()
def test_pick_best_of_all_attempts(self) -> None:
mll = next(iter(self.mlls.values()))
optimizer = MockOptimizer()
max_attempts = 10
with patch("botorch.fit.logging.log") as mock_log:
fit._fit_fallback(
mll,
None,
None,
max_attempts=max_attempts,
pick_best_of_all_attempts=True,
optimizer=optimizer,
)
# Check that the optimizer is called `max_attempts` times.
self.assertEqual(optimizer.call_count, max_attempts)
# Check that we log after each call.
self.assertEqual(mock_log.call_count, max_attempts)
# The logged best MLL values form a non-decreasing sequence.
mll_vals = []
for call in mock_log.call_args_list:
message = call.kwargs["msg"]
mll_val = message.split(" ")[-1][:-1]
mll_vals.append(float(mll_val))
self.assertEqual(mll_vals, sorted(mll_vals))
# Check that the returned MLL is in eval mode.
self.assertFalse(mll.training)
# Check that the state dict matches the state dict of the best attempt.
final_statedict = mll.state_dict()
best_idx = mll_vals.index(max(mll_vals))
best_state_dict = optimizer.state_dicts[best_idx]
for key, val in final_statedict.items():
self.assertAllClose(val, best_state_dict[key])


class TestFitFallbackApproximate(BotorchTestCase):
def setUp(self, suppress_input_warnings: bool = True) -> None:
super().setUp(suppress_input_warnings=suppress_input_warnings)
with torch.random.fork_rng():
torch.manual_seed(0)
train_X = torch.linspace(0, 1, 10).unsqueeze(-1)