Skip to content

Commit

Permalink
Silence input data warnings in tests
Browse files Browse the repository at this point in the history
Summary:
X-link: pytorch/botorch#2508

This logic broke since D61797434 updated the warning messages, leading to too many of these warnings in test outputs again.

Differential Revision: D62200731
  • Loading branch information
saitcakmak authored and facebook-github-bot committed Sep 4, 2024
1 parent ac3a7ec commit f76d0eb
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 4 deletions.
2 changes: 1 addition & 1 deletion ax/modelbridge/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,7 +904,7 @@ def cross_validate(
# users with this warning, we filter it out.
warnings.filterwarnings(
"ignore",
message=r"Data \(outcome observations\) not standardized",
message=r"Data \(outcome observations\) is not standardized",
category=InputDataWarning,
)
cv_predictions = self._cross_validate(
Expand Down
2 changes: 1 addition & 1 deletion ax/modelbridge/cross_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ def cross_validate(
# To avoid confusing users with this warning, we filter it out.
warnings.filterwarnings(
"ignore",
message=r"Data \(outcome observations\) not standardized",
message=r"Data \(outcome observations\) is not standardized",
category=InputDataWarning,
)
cv_test_predictions = model._cross_validate(
Expand Down
2 changes: 1 addition & 1 deletion ax/modelbridge/tests/test_base_modelbridge.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def warn_and_return_mock_obs(
nonlocal called
called = True
warnings.warn(
"Data (outcome observations) not standardized",
"Data (outcome observations) is not standardized",
InputDataWarning,
stacklevel=2,
)
Expand Down
14 changes: 14 additions & 0 deletions ax/utils/common/tests/test_testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@

import io
import sys
import warnings

import torch
from ax.utils.common.base import Base
from ax.utils.common.testutils import TestCase
from botorch.exceptions.warnings import InputDataWarning
from botorch.models.gp_regression import SingleTaskGP


# pyre-fixme[3]: Return type must be annotated.
Expand Down Expand Up @@ -113,3 +117,13 @@ def decorated_test() -> None:
self.assertEqual(None, self._long_test_active_reason)
decorated_test()
self.assertEqual(None, self._long_test_active_reason)

def test_warning_filtering(self) -> None:
with warnings.catch_warnings(record=True) as ws:
# Model with unstandardized float data, which would typically raise
# multiple warnings.
SingleTaskGP(
train_X=torch.rand(5, 2, dtype=torch.float) * 10,
train_Y=torch.rand(5, 1, dtype=torch.float) * 10,
)
self.assertFalse(any(w.category == InputDataWarning for w in ws))
7 changes: 6 additions & 1 deletion ax/utils/common/testutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,12 @@ def setUp(self) -> None:
# BoTorch input standardization warnings.
warnings.filterwarnings(
"ignore",
message="Input data is not",
message=r"Data \(outcome observations\) is not standardized ",
category=InputDataWarning,
)
warnings.filterwarnings(
"ignore",
message=r"Data \(input features\) is not",
category=InputDataWarning,
)

Expand Down

0 comments on commit f76d0eb

Please sign in to comment.