From 3342cc27d0169e7d12e2bace0ebd4d73915a40bf Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Fri, 27 Sep 2024 20:48:18 +1000 Subject: [PATCH] fix: Fix `lit().shrink_dtype()` broadcasting (#18958) --- crates/polars-plan/src/plans/aexpr/scalar.rs | 2 +- py-polars/tests/unit/expr/test_exprs.py | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/crates/polars-plan/src/plans/aexpr/scalar.rs b/crates/polars-plan/src/plans/aexpr/scalar.rs index f7d681b407d4..553c8800b538 100644 --- a/crates/polars-plan/src/plans/aexpr/scalar.rs +++ b/crates/polars-plan/src/plans/aexpr/scalar.rs @@ -8,7 +8,7 @@ pub fn is_scalar_ae(node: Node, expr_arena: &Arena) -> bool { AExpr::Literal(lv) => lv.is_scalar(), AExpr::Function { options, input, .. } | AExpr::AnonymousFunction { options, input, .. } => { - if options.is_elementwise() { + if options.is_elementwise() || !options.flags.contains(FunctionFlags::CHANGES_LENGTH) { input.iter().all(|e| e.is_scalar(expr_arena)) } else { options.flags.contains(FunctionFlags::RETURNS_SCALAR) diff --git a/py-polars/tests/unit/expr/test_exprs.py b/py-polars/tests/unit/expr/test_exprs.py index b97b3c8f1288..31bf08534df7 100644 --- a/py-polars/tests/unit/expr/test_exprs.py +++ b/py-polars/tests/unit/expr/test_exprs.py @@ -645,3 +645,11 @@ def test_slice() -> None: result = df.select(pl.all().slice(1, 1)) expected = pl.DataFrame({"a": data["a"][1:2], "b": data["b"][1:2]}) assert_frame_equal(result, expected) + + +def test_function_expr_scalar_identification_18755() -> None: + # The function uses `ApplyOptions::GroupWise`, however the input is scalar. + assert_frame_equal( + pl.DataFrame({"a": [1, 2]}).with_columns(pl.lit(5).shrink_dtype().alias("b")), + pl.DataFrame({"a": [1, 2], "b": pl.Series([5, 5], dtype=pl.Int8)}), + )