From 35ec4cdfb10091b4680ebd59192786bb0838534a Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 11 Apr 2023 12:23:41 -0700 Subject: [PATCH 1/2] finfo_test: improve min/max test --- ml_dtypes/tests/finfo_test.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/ml_dtypes/tests/finfo_test.py b/ml_dtypes/tests/finfo_test.py index a3cd262b..12429057 100644 --- a/ml_dtypes/tests/finfo_test.py +++ b/ml_dtypes/tests/finfo_test.py @@ -12,11 +12,22 @@ # See the License for the specific language governing permissions and # limitations under the License. +import contextlib +import warnings + from absl.testing import absltest from absl.testing import parameterized import ml_dtypes import numpy as np + +@contextlib.contextmanager +def ignore_warning(**kw): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", **kw) + yield + + ALL_DTYPES = [ ml_dtypes.bfloat16, ml_dtypes.float8_e4m3b11, @@ -26,6 +37,13 @@ ml_dtypes.float8_e5m2fnuz, ] +DTYPES_WITH_NO_INFINITY = [ + ml_dtypes.float8_e4m3b11, + ml_dtypes.float8_e4m3fn, + ml_dtypes.float8_e4m3fnuz, + ml_dtypes.float8_e5m2fnuz, +] + UINT_TYPES = { 8: np.uint8, 16: np.uint16, @@ -58,7 +76,11 @@ def assert_representable(val): self.assertEqual(make_val(val).item(), val) def assert_infinite(val): - self.assertNanEqual(make_val(val), make_val(np.inf)) + typed_val = make_val(val) + if typed_val.dtype in DTYPES_WITH_NO_INFINITY: + self.assertTrue(np.isnan(typed_val)) + else: + self.assertTrue(np.isposinf(typed_val)) def assert_zero(val): self.assertEqual(make_val(val), make_val(0)) @@ -71,6 +93,13 @@ def assert_zero(val): assert_representable(info.tiny) assert_representable(info.max) + assert_representable(info.min) + + msg = "overflow encountered in nextafter" + with ignore_warning(category=RuntimeWarning, message=msg): + assert_infinite(np.nextafter(info.max, make_val(np.inf))) + assert_infinite(-np.nextafter(info.min, make_val(-np.inf))) + assert_representable(2.0 ** (info.maxexp - 1)) assert_infinite(2.0**info.maxexp) From c555b4021774a39bb8eaa2a5809c40cac43967e6 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 11 Apr 2023 12:34:27 -0700 Subject: [PATCH 2/2] finfo_test: improve min/max test --- ml_dtypes/tests/finfo_test.py | 28 ++++++++-------------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/ml_dtypes/tests/finfo_test.py b/ml_dtypes/tests/finfo_test.py index 12429057..5372a9d8 100644 --- a/ml_dtypes/tests/finfo_test.py +++ b/ml_dtypes/tests/finfo_test.py @@ -12,22 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -import contextlib -import warnings - from absl.testing import absltest from absl.testing import parameterized import ml_dtypes import numpy as np - -@contextlib.contextmanager -def ignore_warning(**kw): - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", **kw) - yield - - ALL_DTYPES = [ ml_dtypes.bfloat16, ml_dtypes.float8_e4m3b11, @@ -76,11 +65,11 @@ def assert_representable(val): self.assertEqual(make_val(val).item(), val) def assert_infinite(val): - typed_val = make_val(val) - if typed_val.dtype in DTYPES_WITH_NO_INFINITY: - self.assertTrue(np.isnan(typed_val)) + val = make_val(val) + if dtype in DTYPES_WITH_NO_INFINITY: + self.assertTrue(np.isnan(val), f"expected NaN, got {val}") else: - self.assertTrue(np.isposinf(typed_val)) + self.assertTrue(np.isposinf(val), f"expected inf, got {val}") def assert_zero(val): self.assertEqual(make_val(val), make_val(0)) @@ -92,13 +81,12 @@ def assert_zero(val): self.assertEqual(info.nmant + info.nexp + 1, info.bits) assert_representable(info.tiny) + assert_representable(info.max) - assert_representable(info.min) + assert_infinite(np.spacing(info.max)) - msg = "overflow encountered in nextafter" - with ignore_warning(category=RuntimeWarning, message=msg): - assert_infinite(np.nextafter(info.max, make_val(np.inf))) - assert_infinite(-np.nextafter(info.min, make_val(-np.inf))) + assert_representable(info.min) + assert_infinite(-np.spacing(info.min)) assert_representable(2.0 ** (info.maxexp - 1)) assert_infinite(2.0**info.maxexp)