From 48615d53ead820a3ed111e876e384301ffad2961 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Fri, 12 Apr 2024 19:58:54 +0800 Subject: [PATCH] fix: Fix elementwise-apply if any input is `AggregatedScalar` (#15606) --- .../src/physical_plan/expressions/apply.rs | 18 +++++++++++++----- .../tests/unit/operations/test_group_by.py | 10 ++++++++++ 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/crates/polars-lazy/src/physical_plan/expressions/apply.rs b/crates/polars-lazy/src/physical_plan/expressions/apply.rs index 7a447dbade11..1bfb9f12a035 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/apply.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/apply.rs @@ -354,11 +354,19 @@ impl PhysicalExpr for ApplyExpr { }, ApplyOptions::GroupWise => self.apply_multiple_group_aware(acs, df), ApplyOptions::ElementWise => { - if acs - .iter() - .any(|ac| matches!(ac.agg_state(), AggState::AggregatedList(_))) - { - self.apply_multiple_group_aware(acs, df) + let mut has_agg_list = false; + let mut has_agg_scalar = false; + let mut has_not_agg = false; + for ac in &acs { + match ac.state { + AggState::AggregatedList(_) => has_agg_list = true, + AggState::AggregatedScalar(_) => has_agg_scalar = true, + AggState::NotAggregated(_) => has_not_agg = true, + _ => {}, + } + } + if has_agg_list || (has_agg_scalar && has_not_agg) { + return self.apply_multiple_group_aware(acs, df); } else { apply_multiple_elementwise( acs, diff --git a/py-polars/tests/unit/operations/test_group_by.py b/py-polars/tests/unit/operations/test_group_by.py index 85ed79f9be40..e432d1e8b9bd 100644 --- a/py-polars/tests/unit/operations/test_group_by.py +++ b/py-polars/tests/unit/operations/test_group_by.py @@ -975,3 +975,13 @@ def test_partitioned_group_by_14954(monkeypatch: Any) -> None: [False, False, False, False, False, False, False, False, False, False], ], } + + +def test_aggregated_scalar_elementwise_15602() -> None: + df = pl.DataFrame({"group": [1, 2, 1]}) + + out = df.group_by("group", maintain_order=True).agg( + foo=pl.col("group").is_between(1, pl.max("group")) + ) + expected = pl.DataFrame({"group": [1, 2], "foo": [[True, True], [True]]}) + assert_frame_equal(out, expected)