From 856471175fc28b7999d7a95500a26630b1b0a759 Mon Sep 17 00:00:00 2001 From: Francesco Bruzzesi <42817048+FBruzzesi@users.noreply.github.com> Date: Thu, 16 Jan 2025 12:17:15 +0100 Subject: [PATCH] patch: fix when-then double lit case (#1810) --- narwhals/_arrow/namespace.py | 19 ++++++++----------- narwhals/_dask/namespace.py | 17 ++++++++++++++--- narwhals/_pandas_like/namespace.py | 26 +++++++++++++++----------- tests/expr_and_series/when_test.py | 10 ++++++++++ 4 files changed, 47 insertions(+), 25 deletions(-) diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index b02ad32ee..70d372a5f 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -435,19 +435,18 @@ def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]: plx = df.__narwhals_namespace__() condition = parse_into_expr(self._condition, namespace=plx)(df)[0] + try: value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0] except TypeError: - # `self._otherwise_value` is a scalar and can't be converted to an expression - value_series = condition.__class__._from_iterable( - pa.repeat(pa.scalar(self._then_value), len(condition)), - name="literal", - backend_version=self._backend_version, - version=self._version, + # `self._then_value` is a scalar and can't be converted to an expression + value_series = plx._create_series_from_scalar( + self._then_value, reference_series=condition ) - value_series_native = value_series._native_series - condition_native = condition._native_series + condition_native, value_series_native = broadcast_series( + [condition, value_series] + ) if self._otherwise_value is None: otherwise_native = pa.repeat( @@ -472,9 +471,7 @@ def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]: ] else: otherwise_series = otherwise_expr(df)[0] - condition_native, otherwise_native = broadcast_series( - [condition, otherwise_series] - ) + _, otherwise_native = broadcast_series([condition, otherwise_series]) return [ value_series._from_native_series( pc.if_else(condition_native, value_series_native, otherwise_native) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index 23805afdc..936b9abb9 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -398,13 +398,24 @@ def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]: plx = df.__narwhals_namespace__() condition = parse_into_expr(self._condition, namespace=plx)(df)[0] condition = cast("dx.Series", condition) + try: - value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0] + then_expr = parse_into_expr(self._then_value, namespace=plx) except TypeError: - # `self._otherwise_value` is a scalar and can't be converted to an expression + # `self._then_value` is a scalar and can't be converted to an expression + value_sequence: Sequence[Any] = [self._then_value] + is_scalar = True + else: + is_scalar = then_expr._returns_scalar # type: ignore[attr-defined] + value_sequence = then_expr(df)[0] + + if is_scalar: _df = condition.to_frame("a") - _df["tmp"] = self._then_value + _df["tmp"] = value_sequence[0] value_series = _df["tmp"] + else: + value_series = value_sequence + value_series = cast("dx.Series", value_series) validate_comparand(condition, value_series) diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 52e56d34f..4265a3402 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -461,20 +461,17 @@ def __call__(self, df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: plx = df.__narwhals_namespace__() condition = parse_into_expr(self._condition, namespace=plx)(df)[0] + try: value_series = parse_into_expr(self._then_value, namespace=plx)(df)[0] except TypeError: - # `self._otherwise_value` is a scalar and can't be converted to an expression - value_series = condition.__class__._from_iterable( - [self._then_value] * len(condition), - name="literal", - index=condition._native_series.index, - implementation=self._implementation, - backend_version=self._backend_version, - version=self._version, + # `self._then_value` is a scalar and can't be converted to an expression + value_series = plx._create_series_from_scalar( + self._then_value, reference_series=condition ) - value_series_native, condition_native = broadcast_align_and_extract_native( - value_series, condition + + condition_native, value_series_native = broadcast_align_and_extract_native( + condition, value_series ) if self._otherwise_value is None: @@ -494,7 +491,14 @@ def __call__(self, df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: ] else: otherwise_series = otherwise_expr(df)[0] - return [value_series.zip_with(condition, otherwise_series)] + _, otherwise_native = broadcast_align_and_extract_native( + condition, otherwise_series + ) + return [ + value_series._from_native_series( + value_series_native.where(condition_native, otherwise_native) + ) + ] def then(self, value: PandasLikeExpr | PandasLikeSeries | Any) -> PandasThen: self._then_value = value diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index 94e37aaa3..4f768db06 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -141,3 +141,13 @@ def test_when_then_otherwise_lit_str( result = df.select(nw.when(nw.col("a") > 1).then(nw.col("b")).otherwise(nw.lit("z"))) expected = {"b": ["z", "b", "c"]} assert_equal_data(result, expected) + + +def test_when_then_otherwise_both_lit(constructor: Constructor) -> None: + df = nw.from_native(constructor(data)) + result = df.select( + x1=nw.when(nw.col("a") > 1).then(nw.lit(42)).otherwise(nw.lit(-1)), + x2=nw.when(nw.col("a") > 2).then(nw.lit(42)).otherwise(nw.lit(-1)), + ) + expected = {"x1": [-1, 42, 42], "x2": [-1, -1, 42]} + assert_equal_data(result, expected)