Skip to content

Commit

Permalink
fix: Fixed some error/assertion types (#18811)
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp authored Sep 18, 2024
1 parent 8114b52 commit 16781f6
Show file tree
Hide file tree
Showing 7 changed files with 18 additions and 12 deletions.
2 changes: 1 addition & 1 deletion crates/polars-expr/src/expressions/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ impl PhysicalExpr for BinaryExpr {
polars_ensure!(
lhs.len() == rhs.len() || lhs.len() == 1 || rhs.len() == 1,
expr = self.expr,
ComputeError: "cannot evaluate two Series of different lengths ({} and {})",
ShapeMismatch: "cannot evaluate two Series of different lengths ({} and {})",
lhs.len(), rhs.len(),
);
apply_operator_owned(lhs, rhs, self.op)
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-expr/src/expressions/sortby.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ impl PhysicalExpr for SortByExpr {
for i in 1..s_sort_by.len() {
polars_ensure!(
s_sort_by[0].len() == s_sort_by[i].len(),
expr = self.expr, ComputeError:
expr = self.expr, ShapeMismatch:
"`sort_by` produced different length ({}) than earlier Series' length in `by` ({})",
s_sort_by[0].len(), s_sort_by[i].len()
);
Expand All @@ -254,7 +254,7 @@ impl PhysicalExpr for SortByExpr {
let (sorted_idx, series) = (sorted_idx?, series?);
polars_ensure!(
sorted_idx.len() == series.len(),
expr = self.expr, ComputeError:
expr = self.expr, ShapeMismatch:
"`sort_by` produced different length ({}) than the Series that has to be sorted ({})",
sorted_idx.len(), series.len()
);
Expand Down
10 changes: 9 additions & 1 deletion py-polars/polars/testing/asserts/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
unpack_dtypes,
)
from polars.datatypes.group import FLOAT_DTYPES
from polars.exceptions import ComputeError, InvalidOperationError
from polars.exceptions import ComputeError, InvalidOperationError, ShapeError
from polars.series import Series
from polars.testing.asserts.utils import raise_assertion_error

Expand Down Expand Up @@ -157,6 +157,14 @@ def _assert_series_values_equal(
right=right.dtype,
cause=exc,
)
except ShapeError as exc:
raise_assertion_error(
"Series",
"incompatible lengths",
left=left,
right=right,
cause=exc,
)

# Check nested dtypes in separate function
if _comparing_nested_floats(left.dtype, right.dtype):
Expand Down
1 change: 1 addition & 0 deletions py-polars/tests/unit/expr/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,7 @@ def test_power_by_expression() -> None:
assert out["pow_op_left"].to_list() == [2.0, 4.0, None, 16.0, None, 64.0]


@pytest.mark.may_fail_auto_streaming
def test_expression_appends() -> None:
df = pl.DataFrame({"a": [1, 1, 2]})

Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/operations/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,7 @@ def test_sort_string_nulls() -> None:

def test_sort_by_unequal_lengths_7207() -> None:
df = pl.DataFrame({"a": [0, 1, 1, 0], "b": [3, 2, 3, 2]})
with pytest.raises(pl.exceptions.ComputeError):
with pytest.raises(pl.exceptions.ShapeError):
df.select(pl.col.a.sort_by(["a", 1]))


Expand Down
10 changes: 3 additions & 7 deletions py-polars/tests/unit/test_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
PanicException,
SchemaError,
SchemaFieldNotFoundError,
ShapeError,
StructFieldNotFoundError,
)
from tests.unit.conftest import TEMPORAL_DTYPES
Expand Down Expand Up @@ -293,10 +294,7 @@ def test_invalid_sort_by() -> None:
)

# `select a where b order by c desc`
with pytest.raises(
ComputeError,
match=r"`sort_by` produced different length \(5\) than the Series that has to be sorted \(3\)",
):
with pytest.raises(ShapeError):
df.select(pl.col("a").filter(pl.col("b") == "M").sort_by("c", descending=True))


Expand Down Expand Up @@ -447,9 +445,7 @@ def test_compare_different_len() -> None:
)

s = pl.Series([2, 5, 8])
with pytest.raises(
ComputeError, match=r"cannot evaluate two Series of different lengths"
):
with pytest.raises(ShapeError):
df.filter(pl.col("idx") == s)


Expand Down
1 change: 1 addition & 0 deletions py-polars/tests/unit/test_scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import polars as pl


@pytest.mark.may_fail_auto_streaming
def test_invalid_broadcast() -> None:
df = pl.DataFrame(
{
Expand Down

0 comments on commit 16781f6

Please sign in to comment.