diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index 70d372a5f..e1ad95508 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -441,7 +441,7 @@ def __call__(self: Self, df: ArrowDataFrame) -> Sequence[ArrowSeries]: except TypeError: # `self._then_value` is a scalar and can't be converted to an expression value_series = plx._create_series_from_scalar( - self._then_value, reference_series=condition + self._then_value, reference_series=condition.alias("literal") ) condition_native, value_series_native = broadcast_series( diff --git a/narwhals/_duckdb/namespace.py b/narwhals/_duckdb/namespace.py index c91d11d3f..c516336c9 100644 --- a/narwhals/_duckdb/namespace.py +++ b/narwhals/_duckdb/namespace.py @@ -248,22 +248,27 @@ def __call__(self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]: value = parse_into_expr(self._then_value, namespace=plx)(df)[0] except TypeError: # `self._otherwise_value` is a scalar and can't be converted to an expression - value = ConstantExpression(self._then_value) + value = ConstantExpression(self._then_value).alias("literal") value = cast("duckdb.Expression", value) + value_name = get_column_name(df, value) if self._otherwise_value is None: - return [CaseExpression(condition=condition, value=value)] + return [CaseExpression(condition=condition, value=value).alias(value_name)] try: otherwise_expr = parse_into_expr(self._otherwise_value, namespace=plx) except TypeError: # `self._otherwise_value` is a scalar and can't be converted to an expression return [ - CaseExpression(condition=condition, value=value).otherwise( - ConstantExpression(self._otherwise_value) - ) + CaseExpression(condition=condition, value=value) + .otherwise(ConstantExpression(self._otherwise_value)) + .alias(value_name) ] otherwise = otherwise_expr(df)[0] - return [CaseExpression(condition=condition, value=value).otherwise(otherwise)] + return [ + CaseExpression(condition=condition, value=value) + .otherwise(otherwise) + .alias(value_name) + ] def then(self, value: DuckDBExpr | Any) -> DuckDBThen: self._then_value = value diff --git a/narwhals/_pandas_like/namespace.py b/narwhals/_pandas_like/namespace.py index 4265a3402..84efef836 100644 --- a/narwhals/_pandas_like/namespace.py +++ b/narwhals/_pandas_like/namespace.py @@ -467,13 +467,12 @@ def __call__(self, df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: except TypeError: # `self._then_value` is a scalar and can't be converted to an expression value_series = plx._create_series_from_scalar( - self._then_value, reference_series=condition + self._then_value, reference_series=condition.alias("literal") ) condition_native, value_series_native = broadcast_align_and_extract_native( condition, value_series ) - if self._otherwise_value is None: return [ value_series._from_native_series( diff --git a/narwhals/_pandas_like/utils.py b/narwhals/_pandas_like/utils.py index 159a9f1c9..28b7a9030 100644 --- a/narwhals/_pandas_like/utils.py +++ b/narwhals/_pandas_like/utils.py @@ -128,7 +128,7 @@ def broadcast_align_and_extract_native( s = rhs._native_series return ( lhs._native_series, - s.__class__(s.iloc[0], index=lhs_index, dtype=s.dtype), + s.__class__(s.iloc[0], index=lhs_index, dtype=s.dtype, name=rhs.name), ) if lhs.len() == 1: # broadcast diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index 4f768db06..140626e4e 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -24,6 +24,11 @@ def test_when(constructor: Constructor) -> None: "a_when": [3, None, None], } assert_equal_data(result, expected) + result = df.select(nw.when(nw.col("a") == 1).then(value=3)) + expected = { + "literal": [3, None, None], + } + assert_equal_data(result, expected) def test_when_otherwise(constructor: Constructor) -> None: @@ -121,22 +126,14 @@ def test_otherwise_expression(constructor: Constructor) -> None: assert_equal_data(result, expected) -def test_when_then_otherwise_into_expr( - constructor: Constructor, request: pytest.FixtureRequest -) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_when_then_otherwise_into_expr(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.when(nw.col("a") > 1).then("c").otherwise("e")) expected = {"c": [7, 5, 6]} assert_equal_data(result, expected) -def test_when_then_otherwise_lit_str( - constructor: Constructor, request: pytest.FixtureRequest -) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) +def test_when_then_otherwise_lit_str(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.when(nw.col("a") > 1).then(nw.col("b")).otherwise(nw.lit("z"))) expected = {"b": ["z", "b", "c"]}