Skip to content

Commit

Permalink
fix: Fix elementwise-apply if any input is AggregatedScalar (#15606)
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa authored Apr 12, 2024
1 parent 0b84b14 commit 48615d5
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
18 changes: 13 additions & 5 deletions crates/polars-lazy/src/physical_plan/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -354,11 +354,19 @@ impl PhysicalExpr for ApplyExpr {
},
ApplyOptions::GroupWise => self.apply_multiple_group_aware(acs, df),
ApplyOptions::ElementWise => {
if acs
.iter()
.any(|ac| matches!(ac.agg_state(), AggState::AggregatedList(_)))
{
self.apply_multiple_group_aware(acs, df)
let mut has_agg_list = false;
let mut has_agg_scalar = false;
let mut has_not_agg = false;
for ac in &acs {
match ac.state {
AggState::AggregatedList(_) => has_agg_list = true,
AggState::AggregatedScalar(_) => has_agg_scalar = true,
AggState::NotAggregated(_) => has_not_agg = true,
_ => {},
}
}
if has_agg_list || (has_agg_scalar && has_not_agg) {
return self.apply_multiple_group_aware(acs, df);
} else {
apply_multiple_elementwise(
acs,
Expand Down
10 changes: 10 additions & 0 deletions py-polars/tests/unit/operations/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,3 +975,13 @@ def test_partitioned_group_by_14954(monkeypatch: Any) -> None:
[False, False, False, False, False, False, False, False, False, False],
],
}


def test_aggregated_scalar_elementwise_15602() -> None:
df = pl.DataFrame({"group": [1, 2, 1]})

out = df.group_by("group", maintain_order=True).agg(
foo=pl.col("group").is_between(1, pl.max("group"))
)
expected = pl.DataFrame({"group": [1, 2], "foo": [[True, True], [True]]})
assert_frame_equal(out, expected)

0 comments on commit 48615d5

Please sign in to comment.