Skip to content

Commit

Permalink
chore: validate predicates in nw.when one level higher (#1756)
Browse files Browse the repository at this point in the history
* chore: validate predicates in `nw.when` one level higher

* sort out fail
  • Loading branch information
MarcoGorelli authored Jan 7, 2025
1 parent a6d76e1 commit 3672e86
Show file tree
Hide file tree
Showing 5 changed files with 7 additions and 23 deletions.
7 changes: 1 addition & 6 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,12 +359,7 @@ def when(
*predicates: IntoArrowExpr,
) -> ArrowWhen:
plx = self.__class__(backend_version=self._backend_version, version=self._version)
if predicates:
condition = plx.all_horizontal(*predicates)
else:
msg = "at least one predicate needs to be provided"
raise TypeError(msg)

condition = plx.all_horizontal(*predicates)
return ArrowWhen(condition, self._backend_version, version=self._version)

def concat_str(
Expand Down
7 changes: 1 addition & 6 deletions narwhals/_dask/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,12 +310,7 @@ def when(
*predicates: IntoDaskExpr,
) -> DaskWhen:
plx = self.__class__(backend_version=self._backend_version, version=self._version)
if predicates:
condition = plx.all_horizontal(*predicates)
else:
msg = "at least one predicate needs to be provided"
raise TypeError(msg)

condition = plx.all_horizontal(*predicates)
return DaskWhen(
condition, self._backend_version, returns_scalar=False, version=self._version
)
Expand Down
7 changes: 1 addition & 6 deletions narwhals/_pandas_like/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,12 +371,7 @@ def when(
plx = self.__class__(
self._implementation, self._backend_version, version=self._version
)
if predicates:
condition = plx.all_horizontal(*predicates)
else:
msg = "at least one predicate needs to be provided"
raise TypeError(msg)

condition = plx.all_horizontal(*predicates)
return PandasWhen(
condition, self._implementation, self._backend_version, version=self._version
)
Expand Down
3 changes: 3 additions & 0 deletions narwhals/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -7643,6 +7643,9 @@ def max_horizontal(*exprs: IntoExpr | Iterable[IntoExpr]) -> Expr:
class When:
def __init__(self, *predicates: IntoExpr | Iterable[IntoExpr]) -> None:
self._predicates = flatten([predicates])
if not self._predicates:
msg = "At least one predicate needs to be provided to `narwhals.when`."
raise TypeError(msg)

def _extract_predicates(self, plx: Any) -> Any:
return [extract_compliant(plx, v) for v in self._predicates]
Expand Down
6 changes: 1 addition & 5 deletions tests/expr_and_series/when_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,7 @@ def test_multiple_conditions(
assert_equal_data(result, expected)


def test_no_arg_when_fail(
constructor: Constructor, request: pytest.FixtureRequest
) -> None:
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
def test_no_arg_when_fail(constructor: Constructor) -> None:
df = nw.from_native(constructor(data))
with pytest.raises((TypeError, ValueError)):
df.select(nw.when().then(value=3).alias("a_when"))
Expand Down

0 comments on commit 3672e86

Please sign in to comment.