From ea3c3a26ff2809af156f7083c0f159609766447c Mon Sep 17 00:00:00 2001 From: ritchie Date: Fri, 29 Mar 2024 09:50:48 +0100 Subject: [PATCH] fix: Conversion of expr_ir in partition fast path --- crates/polars-lazy/src/physical_plan/planner/lp.rs | 4 ++-- .../src/physical_plan/streaming/convert_alp.rs | 10 ++++------ crates/polars-plan/src/logical_plan/expr_ir.rs | 4 ++-- py-polars/polars/testing/__init__.py | 1 + py-polars/polars/testing/_constants.py | 2 ++ .../tests/unit/streaming/test_streaming_group_by.py | 8 ++++++++ 6 files changed, 19 insertions(+), 10 deletions(-) create mode 100644 py-polars/polars/testing/_constants.py diff --git a/crates/polars-lazy/src/physical_plan/planner/lp.rs b/crates/polars-lazy/src/physical_plan/planner/lp.rs index 68cfbc09ad2a..f2ccf79706d0 100644 --- a/crates/polars-lazy/src/physical_plan/planner/lp.rs +++ b/crates/polars-lazy/src/physical_plan/planner/lp.rs @@ -431,11 +431,11 @@ pub fn create_physical_plan( let input = create_physical_plan(input, lp_arena, expr_arena)?; let keys = keys .iter() - .map(|e| node_to_expr(e.node(), expr_arena)) + .map(|e| e.to_expr(expr_arena)) .collect::>(); let aggs = aggs .iter() - .map(|e| node_to_expr(e.node(), expr_arena)) + .map(|e| e.to_expr(expr_arena)) .collect::>(); Ok(Box::new(executors::PartitionGroupByExec::new( input, diff --git a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs index 8b2727dc3d45..acd293dead12 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs @@ -386,7 +386,7 @@ pub(crate) fn insert_streaming_nodes( aggs, maintain_order: false, apply: None, - schema, + schema: output_schema, options, .. } => { @@ -435,17 +435,15 @@ pub(crate) fn insert_streaming_nodes( let valid_key = || { keys.iter().all(|e| { - expr_arena - .get(e.node()) - .get_type(schema, Context::Default, expr_arena) - // ensure we don't group_by list + output_schema + .get(e.output_name()) .map(|dt| !matches!(dt, DataType::List(_))) .unwrap_or(false) }) }; let valid_types = || { - schema + output_schema .iter_dtypes() .all(|dt| allowed_dtype(dt, string_cache)) }; diff --git a/crates/polars-plan/src/logical_plan/expr_ir.rs b/crates/polars-plan/src/logical_plan/expr_ir.rs index 7a1b47d2cb57..f5f0b0ddf12a 100644 --- a/crates/polars-plan/src/logical_plan/expr_ir.rs +++ b/crates/polars-plan/src/logical_plan/expr_ir.rs @@ -94,11 +94,11 @@ impl ExprIR { self.output_name.unwrap() } - pub(crate) fn output_name(&self) -> &str { + pub fn output_name(&self) -> &str { self.output_name_arc().as_ref() } - pub(crate) fn to_expr(&self, expr_arena: &Arena) -> Expr { + pub fn to_expr(&self, expr_arena: &Arena) -> Expr { let out = node_to_expr(self.node, expr_arena); match &self.output_name { diff --git a/py-polars/polars/testing/__init__.py b/py-polars/polars/testing/__init__.py index b5962f7fba2c..06b4f6c91419 100644 --- a/py-polars/polars/testing/__init__.py +++ b/py-polars/polars/testing/__init__.py @@ -10,4 +10,5 @@ "assert_frame_not_equal", "assert_series_equal", "assert_series_not_equal", + "_constants", ] diff --git a/py-polars/polars/testing/_constants.py b/py-polars/polars/testing/_constants.py new file mode 100644 index 000000000000..8c11b6d0f176 --- /dev/null +++ b/py-polars/polars/testing/_constants.py @@ -0,0 +1,2 @@ +# On this limit Polars will start partitioning in debug builds +PARTITION_LIMIT = 15 diff --git a/py-polars/tests/unit/streaming/test_streaming_group_by.py b/py-polars/tests/unit/streaming/test_streaming_group_by.py index 506422dc38dd..e7915115b79a 100644 --- a/py-polars/tests/unit/streaming/test_streaming_group_by.py +++ b/py-polars/tests/unit/streaming/test_streaming_group_by.py @@ -8,6 +8,7 @@ import polars as pl from polars.testing import assert_frame_equal +from polars.testing._constants import PARTITION_LIMIT if TYPE_CHECKING: from pathlib import Path @@ -480,3 +481,10 @@ def test_streaming_groupby_binary_15116() -> None: "str": [b"A", b"BB", b"CCCC", b"DDDDDDDD", b"EEEEEEEEEEEEEEEE"], "count": [3, 2, 2, 2, 1], } + + +def test_streaming_group_by_convert_15380() -> None: + assert ( + pl.DataFrame({"a": [1] * PARTITION_LIMIT}).group_by(b="a").len()["len"].item() + == PARTITION_LIMIT + )