diff --git a/ml_dtypes/tests/finfo_test.py b/ml_dtypes/tests/finfo_test.py index 12429057..5372a9d8 100644 --- a/ml_dtypes/tests/finfo_test.py +++ b/ml_dtypes/tests/finfo_test.py @@ -12,22 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib -import warnings - from absl.testing import absltest from absl.testing import parameterized import ml_dtypes import numpy as np - -@contextlib.contextmanager -def ignore_warning(**kw): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", **kw) - yield - - ALL_DTYPES = [ ml_dtypes.bfloat16, ml_dtypes.float8_e4m3b11, @@ -76,11 +65,11 @@ def assert_representable(val): self.assertEqual(make_val(val).item(), val) def assert_infinite(val): - typed_val = make_val(val) - if typed_val.dtype in DTYPES_WITH_NO_INFINITY: - self.assertTrue(np.isnan(typed_val)) + val = make_val(val) + if dtype in DTYPES_WITH_NO_INFINITY: + self.assertTrue(np.isnan(val), f"expected NaN, got {val}") else: - self.assertTrue(np.isposinf(typed_val)) + self.assertTrue(np.isposinf(val), f"expected inf, got {val}") def assert_zero(val): self.assertEqual(make_val(val), make_val(0)) @@ -92,13 +81,12 @@ def assert_zero(val): self.assertEqual(info.nmant + info.nexp + 1, info.bits) assert_representable(info.tiny) + assert_representable(info.max) - assert_representable(info.min) + assert_infinite(np.spacing(info.max)) - msg = "overflow encountered in nextafter" - with ignore_warning(category=RuntimeWarning, message=msg): - assert_infinite(np.nextafter(info.max, make_val(np.inf))) - assert_infinite(-np.nextafter(info.min, make_val(-np.inf))) + assert_representable(info.min) + assert_infinite(-np.spacing(info.min)) assert_representable(2.0 ** (info.maxexp - 1)) assert_infinite(2.0**info.maxexp)