Skip to content

Commit

Permalink
Silence input data warnings in tests (pytorch#2508)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebook/Ax#2739

Pull Request resolved: pytorch#2508

This logic broke since D61797434 updated the warning messages, leading to too many of these warnings in test outputs again.

Reviewed By: Balandat, esantorella

Differential Revision: D62200731
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Sep 5, 2024
1 parent a6ed603 commit 944a0f6
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 8 deletions.
14 changes: 10 additions & 4 deletions botorch/models/utils/assorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,8 +166,10 @@ def check_min_max_scaling(
if torch.any(Xmin < -atol) or torch.any(Xmax > 1 + atol):
msg = "contained"
if msg is not None:
# NOTE: If you update this message, update the warning filters as well.
# See https://github.com/pytorch/botorch/pull/2508.
msg = (
f"Data (input features) not {msg} to the unit cube. "
f"Data (input features) is not {msg} to the unit cube. "
"Please consider min-max scaling the input data."
)
if raise_on_fail:
Expand Down Expand Up @@ -196,9 +198,11 @@ def check_standardization(
mean_not_zero = torch.abs(Ymean).max() > atol_mean
if Y.shape[-2] <= 1:
if mean_not_zero:
# NOTE: If you update this message, update the warning filters as well.
# See https://github.com/pytorch/botorch/pull/2508.
msg = (
f"Data (outcome observations) not standardized (mean = {Ymean}). "
"Please consider scaling the input to zero mean and unit variance."
f"Data (outcome observations) is not standardized (mean = {Ymean})."
" Please consider scaling the input to zero mean and unit variance."
)
if raise_on_fail:
raise InputDataError(msg)
Expand All @@ -207,8 +211,10 @@ def check_standardization(
Ystd = torch.std(Y, dim=-2)
std_not_one = torch.abs(Ystd - 1).max() > atol_std
if mean_not_zero or std_not_one:
# NOTE: If you update this message, update the warning filters as well.
# See https://github.com/pytorch/botorch/pull/2508.
msg = (
"Data (outcome observations) not standardized "
"Data (outcome observations) is not standardized "
f"(std = {Ystd}, mean = {Ymean})."
"Please consider scaling the input to zero mean and unit variance."
)
Expand Down
4 changes: 2 additions & 2 deletions botorch/utils/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,12 @@ def setUp(self, suppress_input_warnings: bool = True) -> None:
)
warnings.filterwarnings(
"ignore",
message="Data is not standardized.",
message=r"Data \(outcome observations\) is not standardized ",
category=InputDataWarning,
)
warnings.filterwarnings(
"ignore",
message="Input data is not contained to the unit cube.",
message=r"Data \(input features\) is not",
category=InputDataWarning,
)

Expand Down
4 changes: 2 additions & 2 deletions test/models/utils/test_assorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,14 @@ def test_check_standardization(self):
check_standardization(Y=y, raise_on_fail=True)

# check nonzero mean for case where >= 2 observations per batch
msg_more_than_1_obs = r"Data \(outcome observations\) not standardized \(std ="
msg_more_than_1_obs = r"Data \(outcome observations\) is not standardized \(std"
with self.assertWarnsRegex(InputDataWarning, msg_more_than_1_obs):
check_standardization(Y=Yst + 1)
with self.assertRaisesRegex(InputDataError, msg_more_than_1_obs):
check_standardization(Y=Yst + 1, raise_on_fail=True)

# check nonzero mean for case where < 2 observations per batch
msg_one_obs = r"Data \(outcome observations\) not standardized \(mean ="
msg_one_obs = r"Data \(outcome observations\) is not standardized \(mean ="
y = torch.ones((3, 1, 2), dtype=torch.float32)
with self.assertWarnsRegex(InputDataWarning, msg_one_obs):
check_standardization(Y=y)
Expand Down
16 changes: 16 additions & 0 deletions test/utils/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import warnings

import torch
from botorch.exceptions.warnings import InputDataWarning
from botorch.models.gp_regression import SingleTaskGP
from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior


Expand Down Expand Up @@ -49,3 +53,15 @@ def test_basic(self) -> None:
self.assertEqual(mm.num_outputs, 0)
mm.state_dict()
mm.load_state_dict()


class TestMisc(BotorchTestCase):
def test_warning_filtering(self) -> None:
with warnings.catch_warnings(record=True) as ws:
# Model with unstandardized float data, which would typically raise
# multiple warnings.
SingleTaskGP(
train_X=torch.rand(5, 2, dtype=torch.float) * 10,
train_Y=torch.rand(5, 1, dtype=torch.float) * 10,
)
self.assertFalse(any(w.category == InputDataWarning for w in ws))

0 comments on commit 944a0f6

Please sign in to comment.