Skip to content

Commit

Permalink
fix: Fix elementwise-apply if any input is AggregatedScalar
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa committed Apr 12, 2024
1 parent 44f1097 commit 4085dd5
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
10 changes: 6 additions & 4 deletions crates/polars-lazy/src/physical_plan/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,10 +354,12 @@ impl PhysicalExpr for ApplyExpr {
},
ApplyOptions::GroupWise => self.apply_multiple_group_aware(acs, df),
ApplyOptions::ElementWise => {
if acs
.iter()
.any(|ac| matches!(ac.agg_state(), AggState::AggregatedList(_)))
{
if acs.iter().any(|ac| {
matches!(
ac.agg_state(),
AggState::AggregatedList(_) | AggState::AggregatedScalar(_)
)
}) {
self.apply_multiple_group_aware(acs, df)
} else {
apply_multiple_elementwise(
Expand Down
10 changes: 10 additions & 0 deletions py-polars/tests/unit/operations/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,3 +975,13 @@ def test_partitioned_group_by_14954(monkeypatch: Any) -> None:
[False, False, False, False, False, False, False, False, False, False],
],
}


def test_aggregated_scalar_elementwise_15602() -> None:
df = pl.DataFrame({"group": [1, 2, 1]})

out = df.group_by("group", maintain_order=True).agg(
foo=pl.col("group").is_between(1, pl.max("group"))
)
expected = pl.DataFrame({"group": [1, 2], "foo": [[True, True], [True]]})
assert_frame_equal(out, expected)

0 comments on commit 4085dd5

Please sign in to comment.