Skip to content

Commit 68e3d1e

Browse files
authored
Merge branch 'pytorch:main' into fix/fixed_features_dimensionality
2 parents 7da91b1 + 9a7c517 commit 68e3d1e

File tree

7 files changed

+179
-115
lines changed

7 files changed

+179
-115
lines changed

botorch/models/approximate_gp.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,11 @@ def __init__(
399399
)
400400
train_Y, _ = outcome_transform(train_Y, X=transformed_X)
401401
self._validate_tensor_args(X=transformed_X, Y=train_Y)
402-
validate_input_scaling(train_X=transformed_X, train_Y=train_Y)
402+
validate_input_scaling(
403+
train_X=transformed_X,
404+
train_Y=train_Y,
405+
check_nans_only=covar_module is not None,
406+
)
403407
if train_Y.shape[-1] != num_outputs:
404408
num_outputs = train_Y.shape[-1]
405409

botorch/models/gp_regression.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,7 @@ def __init__(
171171
train_Y=train_Y,
172172
train_Yvar=train_Yvar,
173173
ignore_X_dims=ignore_X_dims,
174+
check_nans_only=covar_module is not None,
174175
)
175176
self._set_dimensions(train_X=train_X, train_Y=train_Y)
176177
train_X, train_Y, train_Yvar = self._transform_tensor_args(

botorch/models/utils/assorted.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ def validate_input_scaling(
228228
train_Yvar: Tensor | None = None,
229229
raise_on_fail: bool = False,
230230
ignore_X_dims: list[int] | None = None,
231+
check_nans_only: bool = False,
231232
) -> None:
232233
r"""Helper function to validate input data to models.
233234
@@ -243,6 +244,10 @@ def validate_input_scaling(
243244
raised if NaN values are present).
244245
ignore_X_dims: For this subset of dimensions from `{1, ..., d}`, ignore the
245246
min-max scaling check.
247+
check_nans_only: If True, only check for NaN values. Skips min-max scaling
248+
and standardization checks. This is used when the model is provided
249+
with a custom covariance module, to avoid potentially irrelevant
250+
warnings.
246251
247252
This function is typically called inside the constructor of standard BoTorch
248253
models. It validates the following:
@@ -261,10 +266,11 @@ def validate_input_scaling(
261266
check_no_nans(train_Yvar)
262267
if torch.any(train_Yvar < 0):
263268
raise InputDataError("Input data contains negative variances.")
264-
check_min_max_scaling(
265-
X=train_X, raise_on_fail=raise_on_fail, ignore_dims=ignore_X_dims
266-
)
267-
check_standardization(Y=train_Y, raise_on_fail=raise_on_fail)
269+
if not check_nans_only:
270+
check_min_max_scaling(
271+
X=train_X, raise_on_fail=raise_on_fail, ignore_dims=ignore_X_dims
272+
)
273+
check_standardization(Y=train_Y, raise_on_fail=raise_on_fail)
268274

269275

270276
def mod_batch_shape(module: Module, names: list[str], b: int) -> None:

botorch/optim/optimize.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,10 @@ def __post_init__(self) -> None:
116116
f"shape is {batch_initial_conditions_shape}."
117117
)
118118

119-
if len(batch_initial_conditions_shape) == 2:
119+
if (
120+
len(batch_initial_conditions_shape) == 2
121+
and self.raw_samples is not None
122+
):
120123
warnings.warn(
121124
"If using a 2-dim `batch_initial_conditions` botorch will "
122125
"default to old behavior of ignoring `num_restarts` and just "
@@ -132,6 +135,7 @@ def __post_init__(self) -> None:
132135
len(batch_initial_conditions_shape) == 3
133136
and batch_initial_conditions_shape[0] < self.num_restarts
134137
and batch_initial_conditions_shape[-2] != self.q
138+
and self.raw_samples is not None
135139
):
136140
warnings.warn(
137141
"If using a 3-dim `batch_initial_conditions` where the "

botorch/optim/optimize_mixed.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -533,6 +533,7 @@ def continuous_step(
533533
opt_inputs,
534534
q=1,
535535
num_restarts=1,
536+
raw_samples=None,
536537
batch_initial_conditions=current_x.unsqueeze(0),
537538
fixed_features={
538539
**dict(zip(discrete_dims.tolist(), current_x[discrete_dims])),

test/models/utils/test_assorted.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ def test_check_min_max_scaling(self):
8787
X = 0.1 + 0.8 * torch.rand(4, 2, 3)
8888
with warnings.catch_warnings(record=True) as ws:
8989
check_min_max_scaling(X=X)
90-
self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws))
90+
self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws))
9191
check_min_max_scaling(X=X, raise_on_fail=True)
9292
with self.assertWarnsRegex(
9393
expected_warning=InputDataWarning, expected_regex="not scaled"
@@ -100,30 +100,34 @@ def test_check_min_max_scaling(self):
100100
Xstd = (X - Xmin) / (Xmax - Xmin)
101101
with warnings.catch_warnings(record=True) as ws:
102102
check_min_max_scaling(X=Xstd)
103-
self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws))
103+
self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws))
104104
check_min_max_scaling(X=Xstd, raise_on_fail=True)
105105
with warnings.catch_warnings(record=True) as ws:
106106
check_min_max_scaling(X=Xstd, strict=True)
107-
self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws))
107+
self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws))
108108
check_min_max_scaling(X=Xstd, strict=True, raise_on_fail=True)
109109
# check violation
110110
X[0, 0, 0] = 2
111111
with warnings.catch_warnings(record=True) as ws:
112112
check_min_max_scaling(X=X)
113-
self.assertTrue(any(issubclass(w.category, InputDataWarning) for w in ws))
114-
self.assertTrue(any("not contained" in str(w.message) for w in ws))
113+
self.assertTrue(any(issubclass(w.category, InputDataWarning) for w in ws))
114+
self.assertTrue(any("not contained" in str(w.message) for w in ws))
115115
with self.assertRaises(InputDataError):
116116
check_min_max_scaling(X=X, raise_on_fail=True)
117117
with warnings.catch_warnings(record=True) as ws:
118118
check_min_max_scaling(X=X, strict=True)
119-
self.assertTrue(any(issubclass(w.category, InputDataWarning) for w in ws))
120-
self.assertTrue(any("not contained" in str(w.message) for w in ws))
119+
self.assertTrue(any(issubclass(w.category, InputDataWarning) for w in ws))
120+
self.assertTrue(any("not contained" in str(w.message) for w in ws))
121121
with self.assertRaises(InputDataError):
122122
check_min_max_scaling(X=X, strict=True, raise_on_fail=True)
123123
# check ignore_dims
124124
with warnings.catch_warnings(record=True) as ws:
125125
check_min_max_scaling(X=X, ignore_dims=[0])
126-
self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws))
126+
self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws))
127+
# all dims ignored
128+
with warnings.catch_warnings(record=True) as ws:
129+
check_min_max_scaling(X=X, ignore_dims=[0, 1, 2])
130+
self.assertFalse(any(issubclass(w.category, InputDataWarning) for w in ws))
127131

128132
def test_check_standardization(self):
129133
# Ensure that it is not filtered out.
@@ -181,6 +185,11 @@ def test_validate_input_scaling(self):
181185
# check that errors are raised when requested
182186
with self.assertRaises(InputDataError):
183187
validate_input_scaling(train_X=train_X, train_Y=train_Y, raise_on_fail=True)
188+
# check that normalization & standardization checks & errors are skipped when
189+
# check_nans_only is True
190+
validate_input_scaling(
191+
train_X=train_X, train_Y=train_Y, raise_on_fail=True, check_nans_only=True
192+
)
184193
# check that no errors are being raised if everything is standardized
185194
train_X_min = train_X.min(dim=-1, keepdim=True)[0]
186195
train_X_max = train_X.max(dim=-1, keepdim=True)[0]
@@ -202,6 +211,11 @@ def test_validate_input_scaling(self):
202211
train_X_std[0, 0, 0] = float("nan")
203212
with self.assertRaises(InputDataError):
204213
validate_input_scaling(train_X=train_X_std, train_Y=train_Y_std)
214+
# NaNs still raise errors when check_nans_only is True
215+
with self.assertRaises(InputDataError):
216+
validate_input_scaling(
217+
train_X=train_X_std, train_Y=train_Y_std, check_nans_only=True
218+
)
205219

206220

207221
class TestGPTPosteriorSettings(BotorchTestCase):

0 commit comments

Comments
 (0)