Skip to content

Commit

Permalink
Provide more informative warning messages in InputDataWarning (#2489)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebook/Ax#2713

Pull Request resolved: #2489

Provide more informative warning messages when `InputDataWarning`s are raised, specifying whether they pertain to input features or output targets.

Updated unit tests accordingly to ensure coverage.

Differential Revision: D61797434
  • Loading branch information
ltiao authored and facebook-github-bot committed Aug 28, 2024
1 parent 017a124 commit dc61aa2
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
7 changes: 4 additions & 3 deletions botorch/models/utils/assorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def check_min_max_scaling(
msg = "contained"
if msg is not None:
msg = (
f"Input data is not {msg} to the unit cube. "
f"Data (input features) not {msg} to the unit cube. "
"Please consider min-max scaling the input data."
)
if raise_on_fail:
Expand Down Expand Up @@ -197,7 +197,7 @@ def check_standardization(
if Y.shape[-2] <= 1:
if mean_not_zero:
msg = (
f"Data is not standardized (mean = {Ymean}). "
f"Data (outcome observations) not standardized (mean = {Ymean}). "
"Please consider scaling the input to zero mean and unit variance."
)
if raise_on_fail:
Expand All @@ -208,7 +208,8 @@ def check_standardization(
std_not_one = torch.abs(Ystd - 1).max() > atol_std
if mean_not_zero or std_not_one:
msg = (
f"Data is not standardized (std = {Ystd}, mean = {Ymean}). "
"Data (outcome observations) not standardized "
f"(std = {Ystd}, mean = {Ymean})."
"Please consider scaling the input to zero mean and unit variance."
)
if raise_on_fail:
Expand Down
4 changes: 2 additions & 2 deletions test/models/utils/test_assorted.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,14 @@ def test_check_standardization(self):
check_standardization(Y=y, raise_on_fail=True)

# check nonzero mean for case where >= 2 observations per batch
msg_more_than_1_obs = r"Data is not standardized \(std ="
msg_more_than_1_obs = r"Data \(outcome observations\) not standardized \(std ="
with self.assertWarnsRegex(InputDataWarning, msg_more_than_1_obs):
check_standardization(Y=Yst + 1)
with self.assertRaisesRegex(InputDataError, msg_more_than_1_obs):
check_standardization(Y=Yst + 1, raise_on_fail=True)

# check nonzero mean for case where < 2 observations per batch
msg_one_obs = r"Data is not standardized \(mean ="
msg_one_obs = r"Data \(outcome observations\) not standardized \(mean ="
y = torch.ones((3, 1, 2), dtype=torch.float32)
with self.assertWarnsRegex(InputDataWarning, msg_one_obs):
check_standardization(Y=y)
Expand Down

0 comments on commit dc61aa2

Please sign in to comment.