Commit d14baee
add OptimizationGradientError
jduerholt committed Sep 14, 2024
1 parent c895a8d commit d14baee
Showing 4 changed files with 25 additions and 2 deletions.
14 changes: 14 additions & 0 deletions botorch/exceptions/errors.py
@@ -72,3 +72,17 @@ def __init__(
         super().__init__(*args, **kwargs)
         self.current_x = current_x
         self.runtime = runtime
+
+
+class OptimizationGradientError(BotorchError):
+    r"""Exception raised when the gradient array `gradf` contains NaNs."""
+
+    def __init__(self, /, *args: Any, current_x: np.ndarray, **kwargs: Any) -> None:
+        r"""
+        Args:
+            *args: Standard args to `BoTorchError`.
+            current_x: A numpy array representing the current iterate.
+            **kwargs: Standard kwargs to `BoTorchError`.
+        """
+        super().__init__(*args, **kwargs)
+        self.current_x = current_x
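
The point of carrying `current_x` on the exception is that a caller can recover the iterate at which optimization failed. A minimal sketch of that pattern (the `run_optimization` callable is a hypothetical stand-in, not part of this commit):

import numpy as np
from botorch.exceptions.errors import OptimizationGradientError

def optimize_with_recovery(run_optimization, x0: np.ndarray) -> np.ndarray:
    # run_optimization stands in for any routine that raises
    # OptimizationGradientError when the gradient array contains NaNs.
    try:
        return run_optimization(x0)
    except OptimizationGradientError as e:
        # The exception carries the iterate at which the NaN gradient
        # appeared, so it can be logged or returned as a fallback.
        print(f"NaN gradient encountered at x = {e.current_x}")
        return e.current_x
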
3 changes: 2 additions & 1 deletion botorch/generation/gen.py
@@ -18,6 +18,7 @@
 import numpy as np
 import torch
 from botorch.acquisition import AcquisitionFunction
+from botorch.exceptions.errors import OptimizationGradientError
 from botorch.exceptions.warnings import OptimizationWarning
 from botorch.generation.utils import (
     _convert_nonlinear_inequality_constraints,
@@ -215,7 +216,7 @@ def f_np_wrapper(x: np.ndarray, f: Callable):
             )
             if initial_conditions.dtype != torch.double:
                 msg += " Consider using `dtype=torch.double`."
-            raise RuntimeError(msg)
+            raise OptimizationGradientError(msg, current_x=x)
         fval = loss.item()
         return fval, gradf

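Because `gen_candidates_scipy` now raises this specific subclass instead of a bare `RuntimeError`, callers can target NaN-gradient failures precisely. A hedged sketch of one possible retry pattern (`acqf`, `bounds`, and `ics` are assumed to be defined elsewhere; this is not code from the commit):

from botorch.exceptions.errors import OptimizationGradientError
from botorch.generation.gen import gen_candidates_scipy

try:
    batch_candidates, batch_acq_values = gen_candidates_scipy(
        initial_conditions=ics,
        acquisition_function=acqf,
        lower_bounds=bounds[0],
        upper_bounds=bounds[1],
    )
except OptimizationGradientError:
    # Retry in double precision; the error message itself suggests
    # `dtype=torch.double` when the initial conditions are single precision.
    batch_candidates, batch_acq_values = gen_candidates_scipy(
        initial_conditions=ics.double(),
        acquisition_function=acqf,
        lower_bounds=bounds[0],
        upper_bounds=bounds[1],
    )
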
7 changes: 7 additions & 0 deletions test/exceptions/test_errors.py
@@ -12,6 +12,7 @@
     CandidateGenerationError,
     DeprecationError,
     InputDataError,
+    OptimizationGradientError,
     OptimizationTimeoutError,
     UnsupportedError,
 )
@@ -49,3 +50,9 @@ def test_OptimizationTimeoutError(self):
         self.assertTrue(np.array_equal(error.current_x, np.array([1.0])))
         with self.assertRaises(OptimizationTimeoutError):
             raise error
+
+    def test_OptimizationGradientError(self):
+        error = OptimizationGradientError("message", current_x=np.array([1.0]))
+        self.assertTrue(np.array_equal(error.current_x, np.array([1.0])))
+        with self.assertRaises(OptimizationGradientError):
+            raise error
3 changes: 2 additions & 1 deletion test/generation/test_gen.py
@@ -10,6 +10,7 @@

 import torch
 from botorch.acquisition import qExpectedImprovement, qKnowledgeGradient
+from botorch.exceptions.errors import OptimizationGradientError
 from botorch.exceptions.warnings import OptimizationWarning
 from botorch.fit import fit_gpytorch_mll
 from botorch.generation.gen import (
@@ -318,7 +319,7 @@ def test_gen_candidates_scipy_nan_handling(self):
         test_grad = torch.tensor([0.5, 0.2, float("nan")], **ckwargs)
         # test NaN in grad
         with mock.patch("torch.autograd.grad", return_value=[test_grad]):
-            with self.assertRaisesRegex(RuntimeError, expected_regex):
+            with self.assertRaisesRegex(OptimizationGradientError, expected_regex):
                 gen_candidates_scipy(
                     initial_conditions=test_ics,
                     acquisition_function=mock.Mock(return_value=test_ics),
