Skip to content

Commit

Permalink
fix: Incorrect logic in assert_series_equal for infinities (#20763)
Browse files Browse the repository at this point in the history
  • Loading branch information
lukemanley authored Jan 17, 2025
1 parent d05b942 commit 6753bb6
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 4 deletions.
6 changes: 4 additions & 2 deletions py-polars/polars/testing/asserts/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,9 +319,11 @@ def _assert_series_values_within_tolerance(

difference = (left_unequal - right_unequal).abs()
tolerance = atol + rtol * right_unequal.abs()
exceeds_tolerance = difference > tolerance
within_tolerance = (difference <= tolerance) & right_unequal.is_finite() | (
left_unequal == right_unequal
)

if exceeds_tolerance.any():
if not within_tolerance.all():
raise_assertion_error(
"Series",
"value mismatch",
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/operations/arithmetic/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -719,7 +719,7 @@ def test_list_arithmetic_div_ops_zero_denominator(

assert_series_equal(
s / pl.Series([1]).new_from_index(0, n),
pl.Series([[float("inf")], [1.0], [None], None], dtype=pl.List(pl.Float64)),
pl.Series([[0.0], [1.0], [None], None], dtype=pl.List(pl.Float64)),
)

# floordiv
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,7 +1016,7 @@ def test_diff() -> None:

def test_pct_change() -> None:
s = pl.Series("a", [1, 2, 4, 8, 16, 32, 64])
expected = pl.Series("a", [None, None, float("inf"), 3.0, 3.0, 3.0, 3.0])
expected = pl.Series("a", [None, None, 3.0, 3.0, 3.0, 3.0, 3.0])
assert_series_equal(s.pct_change(2), expected)
assert_series_equal(s.pct_change(pl.Series([2])), expected)
# negative
Expand Down
26 changes: 26 additions & 0 deletions py-polars/tests/unit/testing/test_assert_series_equal.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,3 +837,29 @@ def test_series_data_type_fail():
assert "AssertionError: Series are different (nan value mismatch)" in stdout
assert "AssertionError: Series are different (dtype mismatch)" in stdout
assert "AssertionError: inputs are different (unexpected input types)" in stdout


def test_assert_series_equal_inf() -> None:
s1 = pl.Series([1.0, float("inf")])
s2 = pl.Series([1.0, float("inf")])
assert_series_equal(s1, s2)

s1 = pl.Series([1.0, float("-inf")])
s2 = pl.Series([1.0, float("-inf")])
assert_series_equal(s1, s2)

s1 = pl.Series([1.0, float("inf")])
s2 = pl.Series([float("inf"), 1.0])
assert_series_not_equal(s1, s2)

s1 = pl.Series([1.0, float("inf")])
s2 = pl.Series([1.0, float("-inf")])
assert_series_not_equal(s1, s2)

s1 = pl.Series([1.0, float("inf")])
s2 = pl.Series([1.0, 2.0])
assert_series_not_equal(s1, s2)

s1 = pl.Series([1.0, float("inf")])
s2 = pl.Series([1.0, float("nan")])
assert_series_not_equal(s1, s2)

0 comments on commit 6753bb6

Please sign in to comment.