From 3672e86f0a2356869637848ccd13c41852ad1c28 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Tue, 7 Jan 2025 15:39:04 +0000 Subject: [PATCH] chore: validate predicates in `nw.when` one level higher (#1756) * chore: validate predicates in `nw.when` one level higher * sort out fail --- narwhals/_arrow/namespace.py | 7 +------ narwhals/_dask/namespace.py | 7 +------ narwhals/_pandas_like/namespace.py | 7 +------ narwhals/expr.py | 3 +++ tests/expr_and_series/when_test.py | 6 +----- 5 files changed, 7 insertions(+), 23 deletions(-) diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 99f043ebd..b02ad32ee 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -359,12 +359,7 @@ def when( *predicates: IntoArrowExpr, ) -> ArrowWhen: plx = self.__class__(backend_version=self._backend_version, version=self._version) - if predicates: - condition = plx.all_horizontal(*predicates) - else: - msg = "at least one predicate needs to be provided" - raise TypeError(msg) - + condition = plx.all_horizontal(*predicates) return ArrowWhen(condition, self._backend_version, version=self._version) def concat_str( diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index d9a1a8ac6..9a16d7f13 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -310,12 +310,7 @@ def when( *predicates: IntoDaskExpr, ) -> DaskWhen: plx = self.__class__(backend_version=self._backend_version, version=self._version) - if predicates: - condition = plx.all_horizontal(*predicates) - else: - msg = "at least one predicate needs to be provided" - raise TypeError(msg) - + condition = plx.all_horizontal(*predicates) return DaskWhen( condition, self._backend_version, returns_scalar=False, version=self._version ) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 7885d7de0..212c9c938 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -371,12 +371,7 @@ def when( plx = self.__class__( self._implementation, self._backend_version, version=self._version ) - if predicates: - condition = plx.all_horizontal(*predicates) - else: - msg = "at least one predicate needs to be provided" - raise TypeError(msg) - + condition = plx.all_horizontal(*predicates) return PandasWhen( condition, self._implementation, self._backend_version, version=self._version ) diff --git a/narwhals/expr.py b/narwhals/expr.py index 809f76e77..653300da8 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -7643,6 +7643,9 @@ def max_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr: class When: def __init__(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> None: self._predicates = flatten([predicates]) + if not self._predicates: + msg = "At least one predicate needs to be provided to `narwhals.when`." + raise TypeError(msg) def _extract_predicates(self, plx: Any) -> Any: return [extract_compliant(plx, v) for v in self._predicates] diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index b59dda488..739b00e2d 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -54,11 +54,7 @@ def test_multiple_conditions( assert_equal_data(result, expected) -def test_no_arg_when_fail( - constructor: Constructor, request: pytest.FixtureRequest -) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_no_arg_when_fail(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) with pytest.raises((TypeError, ValueError)): df.select(nw.when().then(value=3).alias("a_when"))