Skip to content

Commit

Permalink
fix: Conversion of expr_ir in partition fast path
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Mar 29, 2024
1 parent f61594a commit ea3c3a2
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 10 deletions.
4 changes: 2 additions & 2 deletions crates/polars-lazy/src/physical_plan/planner/lp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -431,11 +431,11 @@ pub fn create_physical_plan(
let input = create_physical_plan(input, lp_arena, expr_arena)?;
let keys = keys
.iter()
.map(|e| node_to_expr(e.node(), expr_arena))
.map(|e| e.to_expr(expr_arena))
.collect::<Vec<_>>();
let aggs = aggs
.iter()
.map(|e| node_to_expr(e.node(), expr_arena))
.map(|e| e.to_expr(expr_arena))
.collect::<Vec<_>>();
Ok(Box::new(executors::PartitionGroupByExec::new(
input,
Expand Down
10 changes: 4 additions & 6 deletions crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,7 @@ pub(crate) fn insert_streaming_nodes(
aggs,
maintain_order: false,
apply: None,
schema,
schema: output_schema,
options,
..
} => {
Expand Down Expand Up @@ -435,17 +435,15 @@ pub(crate) fn insert_streaming_nodes(

let valid_key = || {
keys.iter().all(|e| {
expr_arena
.get(e.node())
.get_type(schema, Context::Default, expr_arena)
// ensure we don't group_by list
output_schema
.get(e.output_name())
.map(|dt| !matches!(dt, DataType::List(_)))
.unwrap_or(false)
})
};

let valid_types = || {
schema
output_schema
.iter_dtypes()
.all(|dt| allowed_dtype(dt, string_cache))
};
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-plan/src/logical_plan/expr_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ impl ExprIR {
self.output_name.unwrap()
}

pub(crate) fn output_name(&self) -> &str {
pub fn output_name(&self) -> &str {
self.output_name_arc().as_ref()
}

pub(crate) fn to_expr(&self, expr_arena: &Arena<AExpr>) -> Expr {
pub fn to_expr(&self, expr_arena: &Arena<AExpr>) -> Expr {
let out = node_to_expr(self.node, expr_arena);

match &self.output_name {
Expand Down
1 change: 1 addition & 0 deletions py-polars/polars/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@
"assert_frame_not_equal",
"assert_series_equal",
"assert_series_not_equal",
"_constants",
]
2 changes: 2 additions & 0 deletions py-polars/polars/testing/_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
# On this limit Polars will start partitioning in debug builds
PARTITION_LIMIT = 15
8 changes: 8 additions & 0 deletions py-polars/tests/unit/streaming/test_streaming_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import polars as pl
from polars.testing import assert_frame_equal
from polars.testing._constants import PARTITION_LIMIT

if TYPE_CHECKING:
from pathlib import Path
Expand Down Expand Up @@ -480,3 +481,10 @@ def test_streaming_groupby_binary_15116() -> None:
"str": [b"A", b"BB", b"CCCC", b"DDDDDDDD", b"EEEEEEEEEEEEEEEE"],
"count": [3, 2, 2, 2, 1],
}


def test_streaming_group_by_convert_15380() -> None:
assert (
pl.DataFrame({"a": [1] * PARTITION_LIMIT}).group_by(b="a").len()["len"].item()
== PARTITION_LIMIT
)

0 comments on commit ea3c3a2

Please sign in to comment.