From 3db1a0ebae91b396db9b7344c5aa133931acdcf0 Mon Sep 17 00:00:00 2001
From: Sait Cakmak
Date: Wed, 4 Sep 2024 19:02:04 -0700
Subject: [PATCH] Silence input data warnings in tests (#2508)

Summary:
X-link: https://github.com/facebook/Ax/pull/2739

Pull Request resolved: https://github.com/pytorch/botorch/pull/2508

This logic broke since D61797434 updated the warning messages, leading to too
many of these warnings in test outputs again.

Reviewed By: Balandat, esantorella

Differential Revision: D62200731

fbshipit-source-id: a8c802abc613e0b144c6eb817f4692857a4cb83d
---
 botorch/models/utils/assorted.py   | 14 ++++++++++----
 botorch/utils/testing.py           |  4 ++--
 test/models/utils/test_assorted.py |  4 ++--
 test/utils/test_testing.py         | 16 ++++++++++++++++
 4 files changed, 30 insertions(+), 8 deletions(-)

diff --git a/botorch/models/utils/assorted.py b/botorch/models/utils/assorted.py
index fadab8da66..b25a8f5e8c 100644
--- a/botorch/models/utils/assorted.py
+++ b/botorch/models/utils/assorted.py
@@ -166,8 +166,10 @@ def check_min_max_scaling(
         if torch.any(Xmin < -atol) or torch.any(Xmax > 1 + atol):
             msg = "contained"
         if msg is not None:
+            # NOTE: If you update this message, update the warning filters as well.
+            # See https://github.com/pytorch/botorch/pull/2508.
             msg = (
-                f"Data (input features) not {msg} to the unit cube. "
+                f"Data (input features) is not {msg} to the unit cube. "
                 "Please consider min-max scaling the input data."
             )
             if raise_on_fail:
@@ -196,9 +198,11 @@ def check_standardization(
         mean_not_zero = torch.abs(Ymean).max() > atol_mean
         if Y.shape[-2] <= 1:
             if mean_not_zero:
+                # NOTE: If you update this message, update the warning filters as well.
+                # See https://github.com/pytorch/botorch/pull/2508.
                 msg = (
-                    f"Data (outcome observations) not standardized (mean = {Ymean}). "
-                    "Please consider scaling the input to zero mean and unit variance."
+                    f"Data (outcome observations) is not standardized (mean = {Ymean})."
+                    " Please consider scaling the input to zero mean and unit variance."
                 )
                 if raise_on_fail:
                     raise InputDataError(msg)
@@ -207,8 +211,10 @@
             Ystd = torch.std(Y, dim=-2)
             std_not_one = torch.abs(Ystd - 1).max() > atol_std
             if mean_not_zero or std_not_one:
+                # NOTE: If you update this message, update the warning filters as well.
+                # See https://github.com/pytorch/botorch/pull/2508.
                 msg = (
-                    "Data (outcome observations) not standardized "
+                    "Data (outcome observations) is not standardized "
                     f"(std = {Ystd}, mean = {Ymean})."
                     "Please consider scaling the input to zero mean and unit variance."
                 )
diff --git a/botorch/utils/testing.py b/botorch/utils/testing.py
index e0cfc6d035..aa58e85984 100644
--- a/botorch/utils/testing.py
+++ b/botorch/utils/testing.py
@@ -62,12 +62,12 @@ def setUp(self, suppress_input_warnings: bool = True) -> None:
             )
             warnings.filterwarnings(
                 "ignore",
-                message="Data is not standardized.",
+                message=r"Data \(outcome observations\) is not standardized ",
                 category=InputDataWarning,
             )
             warnings.filterwarnings(
                 "ignore",
-                message="Input data is not contained to the unit cube.",
+                message=r"Data \(input features\) is not",
                 category=InputDataWarning,
             )
 
diff --git a/test/models/utils/test_assorted.py b/test/models/utils/test_assorted.py
index 21533f762f..b8db931aa8 100644
--- a/test/models/utils/test_assorted.py
+++ b/test/models/utils/test_assorted.py
@@ -158,14 +158,14 @@ def test_check_standardization(self):
             check_standardization(Y=y, raise_on_fail=True)
 
         # check nonzero mean for case where >= 2 observations per batch
-        msg_more_than_1_obs = r"Data \(outcome observations\) not standardized \(std ="
+        msg_more_than_1_obs = r"Data \(outcome observations\) is not standardized \(std"
         with self.assertWarnsRegex(InputDataWarning, msg_more_than_1_obs):
             check_standardization(Y=Yst + 1)
         with self.assertRaisesRegex(InputDataError, msg_more_than_1_obs):
             check_standardization(Y=Yst + 1, raise_on_fail=True)
 
         # check nonzero mean for case where < 2 observations per batch
-        msg_one_obs = r"Data \(outcome observations\) not standardized \(mean ="
+        msg_one_obs = r"Data \(outcome observations\) is not standardized \(mean ="
         y = torch.ones((3, 1, 2), dtype=torch.float32)
         with self.assertWarnsRegex(InputDataWarning, msg_one_obs):
             check_standardization(Y=y)
diff --git a/test/utils/test_testing.py b/test/utils/test_testing.py
index 6ea60f7ed8..4fa3b887f6 100644
--- a/test/utils/test_testing.py
+++ b/test/utils/test_testing.py
@@ -4,7 +4,11 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 
+import warnings
+
 import torch
+from botorch.exceptions.warnings import InputDataWarning
+from botorch.models.gp_regression import SingleTaskGP
 from botorch.utils.testing import BotorchTestCase, MockModel, MockPosterior
 
 
@@ -49,3 +53,15 @@ def test_basic(self) -> None:
         self.assertEqual(mm.num_outputs, 0)
         mm.state_dict()
         mm.load_state_dict()
+
+
+class TestMisc(BotorchTestCase):
+    def test_warning_filtering(self) -> None:
+        with warnings.catch_warnings(record=True) as ws:
+            # Model with unstandardized float data, which would typically raise
+            # multiple warnings.
+            SingleTaskGP(
+                train_X=torch.rand(5, 2, dtype=torch.float) * 10,
+                train_Y=torch.rand(5, 1, dtype=torch.float) * 10,
+            )
+        self.assertFalse(any(w.category == InputDataWarning for w in ws))
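
A minimal standalone sketch (not part of the patch) of why the filter patterns have to track the exact warning wording: warnings.filterwarnings compiles its message argument as a regular expression and matches it against the start of the warning text, so the old filters (e.g. "Data is not standardized.") stopped matching once D61797434 reworded the messages. The message string below is an abbreviated stand-in, not the exact text BoTorch emits.

import re
import warnings

from botorch.exceptions.warnings import InputDataWarning

# Abbreviated stand-in for the reworded warning (std/mean tensor values elided).
message = (
    "Data (outcome observations) is not standardized (std = ..., mean = ...)."
    "Please consider scaling the input to zero mean and unit variance."
)
old_pattern = "Data is not standardized."  # filter removed in this patch
new_pattern = r"Data \(outcome observations\) is not standardized "  # filter added

# filterwarnings applies the regex to the *start* of the warning message,
# case-insensitively; the old pattern no longer matches the new wording.
assert re.match(old_pattern, message, re.IGNORECASE) is None
assert re.match(new_pattern, message, re.IGNORECASE) is not None

# With the updated filter installed, the warning is suppressed and nothing is
# recorded by catch_warnings, mirroring the new test_warning_filtering test.
warnings.filterwarnings("ignore", message=new_pattern, category=InputDataWarning)
with warnings.catch_warnings(record=True) as ws:
    warnings.warn(message, InputDataWarning)
assert not any(w.category is InputDataWarning for w in ws)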