Skip to content

Commit

Permalink
fix: Fix lit().shrink_dtype() broadcasting (#18958)
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion authored Sep 27, 2024
1 parent a030634 commit 3342cc2
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
2 changes: 1 addition & 1 deletion crates/polars-plan/src/plans/aexpr/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub fn is_scalar_ae(node: Node, expr_arena: &Arena<AExpr>) -> bool {
AExpr::Literal(lv) => lv.is_scalar(),
AExpr::Function { options, input, .. }
| AExpr::AnonymousFunction { options, input, .. } => {
if options.is_elementwise() {
if options.is_elementwise() || !options.flags.contains(FunctionFlags::CHANGES_LENGTH) {
input.iter().all(|e| e.is_scalar(expr_arena))
} else {
options.flags.contains(FunctionFlags::RETURNS_SCALAR)
Expand Down
8 changes: 8 additions & 0 deletions py-polars/tests/unit/expr/test_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -645,3 +645,11 @@ def test_slice() -> None:
result = df.select(pl.all().slice(1, 1))
expected = pl.DataFrame({"a": data["a"][1:2], "b": data["b"][1:2]})
assert_frame_equal(result, expected)


def test_function_expr_scalar_identification_18755() -> None:
# The function uses `ApplyOptions::GroupWise`, however the input is scalar.
assert_frame_equal(
pl.DataFrame({"a": [1, 2]}).with_columns(pl.lit(5).shrink_dtype().alias("b")),
pl.DataFrame({"a": [1, 2], "b": pl.Series([5, 5], dtype=pl.Int8)}),
)

0 comments on commit 3342cc2

Please sign in to comment.