diff --git a/py-polars/polars/testing/asserts/series.py b/py-polars/polars/testing/asserts/series.py index f7c522315d03..e73e1b60697e 100644 --- a/py-polars/polars/testing/asserts/series.py +++ b/py-polars/polars/testing/asserts/series.py @@ -319,9 +319,11 @@ def _assert_series_values_within_tolerance( difference = (left_unequal - right_unequal).abs() tolerance = atol + rtol * right_unequal.abs() - exceeds_tolerance = difference > tolerance + within_tolerance = (difference <= tolerance) & right_unequal.is_finite() | ( + left_unequal == right_unequal + ) - if exceeds_tolerance.any(): + if not within_tolerance.all(): raise_assertion_error( "Series", "value mismatch", diff --git a/py-polars/tests/unit/operations/arithmetic/test_list.py b/py-polars/tests/unit/operations/arithmetic/test_list.py index 7930d595141c..b6a0d4a4859b 100644 --- a/py-polars/tests/unit/operations/arithmetic/test_list.py +++ b/py-polars/tests/unit/operations/arithmetic/test_list.py @@ -719,7 +719,7 @@ def test_list_arithmetic_div_ops_zero_denominator( assert_series_equal( s / pl.Series([1]).new_from_index(0, n), - pl.Series([[float("inf")], [1.0], [None], None], dtype=pl.List(pl.Float64)), + pl.Series([[0.0], [1.0], [None], None], dtype=pl.List(pl.Float64)), ) # floordiv diff --git a/py-polars/tests/unit/series/test_series.py b/py-polars/tests/unit/series/test_series.py index bd8baa002cf9..c4d704e8bbbe 100644 --- a/py-polars/tests/unit/series/test_series.py +++ b/py-polars/tests/unit/series/test_series.py @@ -1016,7 +1016,7 @@ def test_diff() -> None: def test_pct_change() -> None: s = pl.Series("a", [1, 2, 4, 8, 16, 32, 64]) - expected = pl.Series("a", [None, None, float("inf"), 3.0, 3.0, 3.0, 3.0]) + expected = pl.Series("a", [None, None, 3.0, 3.0, 3.0, 3.0, 3.0]) assert_series_equal(s.pct_change(2), expected) assert_series_equal(s.pct_change(pl.Series([2])), expected) # negative diff --git a/py-polars/tests/unit/testing/test_assert_series_equal.py b/py-polars/tests/unit/testing/test_assert_series_equal.py index c523fe193a30..d576762fc288 100644 --- a/py-polars/tests/unit/testing/test_assert_series_equal.py +++ b/py-polars/tests/unit/testing/test_assert_series_equal.py @@ -837,3 +837,29 @@ def test_series_data_type_fail(): assert "AssertionError: Series are different (nan value mismatch)" in stdout assert "AssertionError: Series are different (dtype mismatch)" in stdout assert "AssertionError: inputs are different (unexpected input types)" in stdout + + +def test_assert_series_equal_inf() -> None: + s1 = pl.Series([1.0, float("inf")]) + s2 = pl.Series([1.0, float("inf")]) + assert_series_equal(s1, s2) + + s1 = pl.Series([1.0, float("-inf")]) + s2 = pl.Series([1.0, float("-inf")]) + assert_series_equal(s1, s2) + + s1 = pl.Series([1.0, float("inf")]) + s2 = pl.Series([float("inf"), 1.0]) + assert_series_not_equal(s1, s2) + + s1 = pl.Series([1.0, float("inf")]) + s2 = pl.Series([1.0, float("-inf")]) + assert_series_not_equal(s1, s2) + + s1 = pl.Series([1.0, float("inf")]) + s2 = pl.Series([1.0, 2.0]) + assert_series_not_equal(s1, s2) + + s1 = pl.Series([1.0, float("inf")]) + s2 = pl.Series([1.0, float("nan")]) + assert_series_not_equal(s1, s2)