Skip to content

Commit

Permalink
refactor(python): Minor updates to assertion utils and docstrings (#1…
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Oct 17, 2023
1 parent a507d67 commit 003ca4d
Show file tree
Hide file tree
Showing 2 changed files with 135 additions and 65 deletions.
188 changes: 123 additions & 65 deletions py-polars/polars/testing/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,43 +34,63 @@ def assert_frame_equal(
categorical_as_str: bool = False,
) -> None:
"""
Raise detailed AssertionError if `left` does NOT equal `right`.
Assert that the left and right frame are equal.
Raises a detailed ``AssertionError`` if the frames differ.
This function is intended for use in unit tests.
Parameters
----------
left
the DataFrame to compare.
The first DataFrame or LazyFrame to compare.
right
the DataFrame to compare with.
The second DataFrame or LazyFrame to compare.
check_row_order
if False, frames will compare equal if the required rows are present,
irrespective of the order in which they appear; as this requires
sorting, you cannot set on frames that contain unsortable columns.
Require row order to match.
.. note::
Setting this to ``False`` requires sorting the data, which will fail on
frames that contain unsortable columns.
check_column_order
if False, frames will compare equal if the required columns are present,
irrespective of the order in which they appear.
Require column order to match.
check_dtype
if True, data types need to match exactly.
Require data types to match.
check_exact
if False, test if values are within tolerance of each other
(see `rtol` & `atol`).
Require data values to match exactly. If set to ``False``, values are considered
equal when within tolerance of each other (see ``rtol`` and ``atol``).
rtol
relative tolerance for inexact checking. Fraction of values in `right`.
Relative tolerance for inexact checking. Fraction of values in ``right``.
atol
absolute tolerance for inexact checking.
Absolute tolerance for inexact checking.
nans_compare_equal
if your assert/test requires float NaN != NaN, set this to False.
Consider NaN values to be equal.
categorical_as_str
Cast categorical columns to string before comparing. Enabling this helps
compare DataFrames that do not share the same string cache.
compare columns that do not share the same string cache.
See Also
--------
assert_series_equal
assert_frame_not_equal
Examples
--------
>>> from polars.testing import assert_frame_equal
>>> df1 = pl.DataFrame({"a": [1, 2, 3]})
>>> df2 = pl.DataFrame({"a": [2, 3, 4]})
>>> df2 = pl.DataFrame({"a": [1, 5, 3]})
>>> assert_frame_equal(df1, df2) # doctest: +SKIP
AssertionError: Values for column 'a' are different.
Traceback (most recent call last):
...
AssertionError: Series are different (value mismatch)
[left]: [1, 2, 3]
[right]: [1, 5, 3]
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
...
AssertionError: values for column 'a' are different
"""
collect_input_frames = isinstance(left, LazyFrame) and isinstance(right, LazyFrame)
if collect_input_frames:
Expand All @@ -79,23 +99,23 @@ def assert_frame_equal(
objs = "DataFrames"
else:
_raise_assertion_error(
"Inputs", "unexpected input types", type(left), type(right)
"Inputs",
"unexpected input types",
type(left).__name__,
type(right).__name__,
)

if left_not_right := [c for c in left.columns if c not in right.columns]:
raise AssertionError(
f"columns {left_not_right!r} in left frame, but not in right"
)
msg = f"columns {left_not_right!r} in left frame, but not in right"
raise AssertionError(msg)

if right_not_left := [c for c in right.columns if c not in left.columns]:
raise AssertionError(
f"columns {right_not_left!r} in right frame, but not in left"
)
msg = f"columns {right_not_left!r} in right frame, but not in left"
raise AssertionError(msg)

if check_column_order and left.columns != right.columns:
raise AssertionError(
f"columns are not in the same order:\n{left.columns!r}\n{right.columns!r}"
)
msg = f"columns are not in the same order:\n{left.columns!r}\n{right.columns!r}"
raise AssertionError(msg)

if collect_input_frames:
if check_dtype: # check this _before_ we collect
Expand All @@ -114,9 +134,8 @@ def assert_frame_equal(
left = left.sort(by=left.columns)
right = right.sort(by=left.columns)
except ComputeError as exc:
raise InvalidAssert(
"cannot set `check_row_order=False` on frame with unsortable columns"
) from exc
msg = "cannot set `check_row_order=False` on frame with unsortable columns"
raise InvalidAssert(msg) from exc

# note: does not assume a particular column order
for c in left.columns:
Expand Down Expand Up @@ -150,42 +169,53 @@ def assert_frame_not_equal(
categorical_as_str: bool = False,
) -> None:
"""
Raise AssertionError if `left` DOES equal `right`.
Assert that the left and right frame are **not** equal.
This function is intended for use in unit tests.
Parameters
----------
left
the DataFrame to compare.
The first DataFrame or LazyFrame to compare.
right
the DataFrame to compare with.
The second DataFrame or LazyFrame to compare.
check_row_order
if False, frames will compare equal if the required rows are present,
irrespective of the order in which they appear; as this requires
sorting, you cannot set on frames that contain unsortable columns.
Require row order to match.
.. note::
Setting this to ``False`` requires sorting the data, which will fail on
frames that contain unsortable columns.
check_column_order
if False, frames will compare equal if the required columns are present,
irrespective of the order in which they appear.
Require column order to match.
check_dtype
if True, data types need to match exactly.
Require data types to match.
check_exact
if False, test if values are within tolerance of each other
(see `rtol` & `atol`).
Require data values to match exactly. If set to ``False``, values are considered
equal when within tolerance of each other (see ``rtol`` and ``atol``).
rtol
relative tolerance for inexact checking. Fraction of values in `right`.
Relative tolerance for inexact checking. Fraction of values in ``right``.
atol
absolute tolerance for inexact checking.
Absolute tolerance for inexact checking.
nans_compare_equal
if your assert/test requires float NaN != NaN, set this to False.
Consider NaN values to be equal.
categorical_as_str
Cast categorical columns to string before comparing. Enabling this helps
compare DataFrames that do not share the same string cache.
compare columns that do not share the same string cache.
See Also
--------
assert_frame_equal
assert_series_not_equal
Examples
--------
>>> from polars.testing import assert_frame_not_equal
>>> df1 = pl.DataFrame({"a": [1, 2, 3]})
>>> df2 = pl.DataFrame({"a": [2, 3, 4]})
>>> assert_frame_not_equal(df1, df2)
>>> df2 = pl.DataFrame({"a": [1, 2, 3]})
>>> assert_frame_not_equal(df1, df2) # doctest: +SKIP
Traceback (most recent call last):
...
AssertionError: frames are equal
"""
try:
Expand All @@ -204,7 +234,8 @@ def assert_frame_not_equal(
except AssertionError:
return
else:
raise AssertionError("expected the input frames to be unequal")
msg = "frames are equal"
raise AssertionError(msg)


def assert_series_equal(
Expand All @@ -220,42 +251,58 @@ def assert_series_equal(
categorical_as_str: bool = False,
) -> None:
"""
Raise detailed AssertionError if `left` does NOT equal `right`.
Assert that the left and right Series are equal.
Raises a detailed ``AssertionError`` if the Series differ.
This function is intended for use in unit tests.
Parameters
----------
left
the series to compare.
The first Series to compare.
right
the series to compare with.
The second Series to compare.
check_dtype
if True, data types need to match exactly.
Require data types to match.
check_names
if True, names need to match.
Require names to match.
check_exact
if False, test if values are within tolerance of each other
(see `rtol` & `atol`).
Require data values to match exactly. If set to ``False``, values are considered
equal when within tolerance of each other (see ``rtol`` and ``atol``).
rtol
relative tolerance for inexact checking. Fraction of values in `right`.
Relative tolerance for inexact checking. Fraction of values in ``right``.
atol
absolute tolerance for inexact checking.
Absolute tolerance for inexact checking.
nans_compare_equal
if your assert/test requires float NaN != NaN, set this to False.
Consider NaN values to be equal.
categorical_as_str
Cast categorical columns to string before comparing. Enabling this helps
compare DataFrames that do not share the same string cache.
compare columns that do not share the same string cache.
See Also
--------
assert_frame_equal
assert_series_not_equal
Examples
--------
>>> from polars.testing import assert_series_equal
>>> s1 = pl.Series([1, 2, 3])
>>> s2 = pl.Series([2, 3, 4])
>>> s2 = pl.Series([1, 5, 3])
>>> assert_series_equal(s1, s2) # doctest: +SKIP
Traceback (most recent call last):
...
AssertionError: Series are different (value mismatch)
[left]: [1, 2, 3]
[right]: [1, 5, 3]
"""
if not (isinstance(left, Series) and isinstance(right, Series)): # type: ignore[redundant-expr]
_raise_assertion_error(
"Inputs", "unexpected input types", type(left), type(right)
"Inputs",
"unexpected input types",
type(left).__name__,
type(right).__name__,
)

if len(left) != len(right):
Expand Down Expand Up @@ -289,7 +336,9 @@ def assert_series_not_equal(
categorical_as_str: bool = False,
) -> None:
"""
Raise AssertionError if `left` DOES equal `right`.
Assert that the left and right Series are **not** equal.
This function is intended for use in unit tests.
Parameters
----------
Expand All @@ -314,12 +363,20 @@ def assert_series_not_equal(
Cast categorical columns to string before comparing. Enabling this helps
compare DataFrames that do not share the same string cache.
See Also
--------
assert_series_equal
assert_frame_not_equal
Examples
--------
>>> from polars.testing import assert_series_not_equal
>>> s1 = pl.Series([1, 2, 3])
>>> s2 = pl.Series([2, 3, 4])
>>> assert_series_not_equal(s1, s2)
>>> s2 = pl.Series([1, 2, 3])
>>> assert_series_not_equal(s1, s2) # doctest: +SKIP
Traceback (most recent call last):
...
AssertionError: Series are equal
"""
try:
Expand All @@ -337,7 +394,8 @@ def assert_series_not_equal(
except AssertionError:
return
else:
raise AssertionError("expected the input Series to be unequal")
msg = "Series are equal"
raise AssertionError(msg)


def _assert_series_inner(
Expand Down
12 changes: 12 additions & 0 deletions py-polars/tests/unit/testing/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,3 +1044,15 @@ def test_assert_series_equal_full_series() -> None:
)
with pytest.raises(AssertionError, match=msg):
assert_series_equal(s1, s2)


def test_assert_frame_not_equal() -> None:
df = pl.DataFrame({"a": [1, 2]})
with pytest.raises(AssertionError, match="frames are equal"):
assert_frame_not_equal(df, df)


def test_assert_series_not_equal() -> None:
s = pl.Series("a", [1, 2])
with pytest.raises(AssertionError, match="Series are equal"):
assert_series_not_equal(s, s)

0 comments on commit 003ca4d

Please sign in to comment.