From dc61aa206b0e8833fe74cf71909b4e3a747264d5 Mon Sep 17 00:00:00 2001 From: Louis Tiao Date: Wed, 28 Aug 2024 09:33:10 -0700 Subject: [PATCH] Provide more informative warning messages in `InputDataWarning` (#2489) Summary: X-link: https://github.com/facebook/Ax/pull/2713 Pull Request resolved: https://github.com/pytorch/botorch/pull/2489 Provide more informative warning messages when `InputDataWarning`s are raised to specify whether it pertains to input features or output targets. Updated unit tests accordingly to ensure coverage. Differential Revision: D61797434 --- botorch/models/utils/assorted.py | 7 ++++--- test/models/utils/test_assorted.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/botorch/models/utils/assorted.py b/botorch/models/utils/assorted.py index fa1e62f30a..fadab8da66 100644 --- a/botorch/models/utils/assorted.py +++ b/botorch/models/utils/assorted.py @@ -167,7 +167,7 @@ def check_min_max_scaling( msg = "contained" if msg is not None: msg = ( - f"Input data is not {msg} to the unit cube. " + f"Data (input features) not {msg} to the unit cube. " "Please consider min-max scaling the input data." ) if raise_on_fail: @@ -197,7 +197,7 @@ def check_standardization( if Y.shape[-2] <= 1: if mean_not_zero: msg = ( - f"Data is not standardized (mean = {Ymean}). " + f"Data (outcome observations) not standardized (mean = {Ymean}). " "Please consider scaling the input to zero mean and unit variance." ) if raise_on_fail: @@ -208,7 +208,8 @@ def check_standardization( std_not_one = torch.abs(Ystd - 1).max() > atol_std if mean_not_zero or std_not_one: msg = ( - f"Data is not standardized (std = {Ystd}, mean = {Ymean}). " + "Data (outcome observations) not standardized " + f"(std = {Ystd}, mean = {Ymean})." "Please consider scaling the input to zero mean and unit variance." ) if raise_on_fail: diff --git a/test/models/utils/test_assorted.py b/test/models/utils/test_assorted.py index 459893363e..21533f762f 100644 --- a/test/models/utils/test_assorted.py +++ b/test/models/utils/test_assorted.py @@ -158,14 +158,14 @@ def test_check_standardization(self): check_standardization(Y=y, raise_on_fail=True) # check nonzero mean for case where >= 2 observations per batch - msg_more_than_1_obs = r"Data is not standardized \(std =" + msg_more_than_1_obs = r"Data \(outcome observations\) not standardized \(std =" with self.assertWarnsRegex(InputDataWarning, msg_more_than_1_obs): check_standardization(Y=Yst + 1) with self.assertRaisesRegex(InputDataError, msg_more_than_1_obs): check_standardization(Y=Yst + 1, raise_on_fail=True) # check nonzero mean for case where < 2 observations per batch - msg_one_obs = r"Data is not standardized \(mean =" + msg_one_obs = r"Data \(outcome observations\) not standardized \(mean =" y = torch.ones((3, 1, 2), dtype=torch.float32) with self.assertWarnsRegex(InputDataWarning, msg_one_obs): check_standardization(Y=y)