@@ -87,7 +87,7 @@ def test_check_min_max_scaling(self):
87
87
X = 0.1 + 0.8 * torch .rand (4 , 2 , 3 )
88
88
with warnings .catch_warnings (record = True ) as ws :
89
89
check_min_max_scaling (X = X )
90
- self .assertFalse (any (issubclass (w .category , InputDataWarning ) for w in ws ))
90
+ self .assertFalse (any (issubclass (w .category , InputDataWarning ) for w in ws ))
91
91
check_min_max_scaling (X = X , raise_on_fail = True )
92
92
with self .assertWarnsRegex (
93
93
expected_warning = InputDataWarning , expected_regex = "not scaled"
@@ -100,30 +100,34 @@ def test_check_min_max_scaling(self):
100
100
Xstd = (X - Xmin ) / (Xmax - Xmin )
101
101
with warnings .catch_warnings (record = True ) as ws :
102
102
check_min_max_scaling (X = Xstd )
103
- self .assertFalse (any (issubclass (w .category , InputDataWarning ) for w in ws ))
103
+ self .assertFalse (any (issubclass (w .category , InputDataWarning ) for w in ws ))
104
104
check_min_max_scaling (X = Xstd , raise_on_fail = True )
105
105
with warnings .catch_warnings (record = True ) as ws :
106
106
check_min_max_scaling (X = Xstd , strict = True )
107
- self .assertFalse (any (issubclass (w .category , InputDataWarning ) for w in ws ))
107
+ self .assertFalse (any (issubclass (w .category , InputDataWarning ) for w in ws ))
108
108
check_min_max_scaling (X = Xstd , strict = True , raise_on_fail = True )
109
109
# check violation
110
110
X [0 , 0 , 0 ] = 2
111
111
with warnings .catch_warnings (record = True ) as ws :
112
112
check_min_max_scaling (X = X )
113
- self .assertTrue (any (issubclass (w .category , InputDataWarning ) for w in ws ))
114
- self .assertTrue (any ("not contained" in str (w .message ) for w in ws ))
113
+ self .assertTrue (any (issubclass (w .category , InputDataWarning ) for w in ws ))
114
+ self .assertTrue (any ("not contained" in str (w .message ) for w in ws ))
115
115
with self .assertRaises (InputDataError ):
116
116
check_min_max_scaling (X = X , raise_on_fail = True )
117
117
with warnings .catch_warnings (record = True ) as ws :
118
118
check_min_max_scaling (X = X , strict = True )
119
- self .assertTrue (any (issubclass (w .category , InputDataWarning ) for w in ws ))
120
- self .assertTrue (any ("not contained" in str (w .message ) for w in ws ))
119
+ self .assertTrue (any (issubclass (w .category , InputDataWarning ) for w in ws ))
120
+ self .assertTrue (any ("not contained" in str (w .message ) for w in ws ))
121
121
with self .assertRaises (InputDataError ):
122
122
check_min_max_scaling (X = X , strict = True , raise_on_fail = True )
123
123
# check ignore_dims
124
124
with warnings .catch_warnings (record = True ) as ws :
125
125
check_min_max_scaling (X = X , ignore_dims = [0 ])
126
- self .assertFalse (any (issubclass (w .category , InputDataWarning ) for w in ws ))
126
+ self .assertFalse (any (issubclass (w .category , InputDataWarning ) for w in ws ))
127
+ # all dims ignored
128
+ with warnings .catch_warnings (record = True ) as ws :
129
+ check_min_max_scaling (X = X , ignore_dims = [0 , 1 , 2 ])
130
+ self .assertFalse (any (issubclass (w .category , InputDataWarning ) for w in ws ))
127
131
128
132
def test_check_standardization (self ):
129
133
# Ensure that it is not filtered out.
@@ -181,6 +185,11 @@ def test_validate_input_scaling(self):
181
185
# check that errors are raised when requested
182
186
with self .assertRaises (InputDataError ):
183
187
validate_input_scaling (train_X = train_X , train_Y = train_Y , raise_on_fail = True )
188
+ # check that normalization & standardization checks & errors are skipped when
189
+ # check_nans_only is True
190
+ validate_input_scaling (
191
+ train_X = train_X , train_Y = train_Y , raise_on_fail = True , check_nans_only = True
192
+ )
184
193
# check that no errors are being raised if everything is standardized
185
194
train_X_min = train_X .min (dim = - 1 , keepdim = True )[0 ]
186
195
train_X_max = train_X .max (dim = - 1 , keepdim = True )[0 ]
@@ -202,6 +211,11 @@ def test_validate_input_scaling(self):
202
211
train_X_std [0 , 0 , 0 ] = float ("nan" )
203
212
with self .assertRaises (InputDataError ):
204
213
validate_input_scaling (train_X = train_X_std , train_Y = train_Y_std )
214
+ # NaNs still raise errors when check_nans_only is True
215
+ with self .assertRaises (InputDataError ):
216
+ validate_input_scaling (
217
+ train_X = train_X_std , train_Y = train_Y_std , check_nans_only = True
218
+ )
205
219
206
220
207
221
class TestGPTPosteriorSettings (BotorchTestCase ):
0 commit comments