diff --git a/crates/polars-mem-engine/src/planner/lp.rs b/crates/polars-mem-engine/src/planner/lp.rs index 45ca341d858e..4eaaac7d9004 100644 --- a/crates/polars-mem-engine/src/planner/lp.rs +++ b/crates/polars-mem-engine/src/planner/lp.rs @@ -10,13 +10,10 @@ use crate::utils::*; fn partitionable_gb( keys: &[ExprIR], aggs: &[ExprIR], - _input_schema: &Schema, + input_schema: &Schema, expr_arena: &Arena, apply: &Option>, ) -> bool { - // We first check if we can partition the group_by on the latest moment. - let mut partitionable = true; - // checks: // 1. complex expressions in the group_by itself are also not partitionable // in this case anything more than col("foo") @@ -28,102 +25,14 @@ fn partitionable_gb( // in this case anything more than col("foo") for key in keys { if (expr_arena).iter(key.node()).count() > 1 { - partitionable = false; - break; + return false; } } - if partitionable { - for agg in aggs { - let agg = agg.node(); - let aexpr = expr_arena.get(agg); - let depth = (expr_arena).iter(agg).count(); - - // These single expressions are partitionable - if matches!(aexpr, AExpr::Len) { - continue; - } - // col() - // lit() etc. - if depth == 1 { - partitionable = false; - break; - } - - let has_aggregation = - |node: Node| has_aexpr(node, expr_arena, |ae| matches!(ae, AExpr::Agg(_))); - - // check if the aggregation type is partitionable - // only simple aggregation like col().sum - // that can be divided in to the aggregation of their partitions are allowed - if !((expr_arena).iter(agg).all(|(_, ae)| { - use AExpr::*; - match ae { - // struct is needed to keep both states - #[cfg(feature = "dtype-struct")] - Agg(IRAggExpr::Mean(_)) => { - // only numeric means for now. - // logical types seem to break because of casts to float. - matches!(expr_arena.get(agg).get_type(_input_schema, Context::Default, expr_arena).map(|dt| { - dt.is_numeric()}), Ok(true)) - }, - // only allowed expressions - Agg(agg_e) => { - matches!( - agg_e, - IRAggExpr::Min{..} - | IRAggExpr::Max{..} - | IRAggExpr::Sum(_) - | IRAggExpr::Last(_) - | IRAggExpr::First(_) - | IRAggExpr::Count(_, true) - ) - }, - Function {input, options, ..} => { - matches!(options.collect_groups, ApplyOptions::ElementWise) && input.len() == 1 && - !has_aggregation(input[0].node()) - } - BinaryExpr {left, right, ..} => { - !has_aggregation(*left) && !has_aggregation(*right) - } - Ternary {truthy, falsy, predicate,..} => { - !has_aggregation(*truthy) && !has_aggregation(*falsy) && !has_aggregation(*predicate) - } - Column(_) | Len | Literal(_) | Cast {..} => { - true - } - _ => { - false - }, - } - }) && - // we only allow expressions that end with an aggregation - matches!(aexpr, AExpr::Agg(_))) - { - partitionable = false; - break; - } - - #[cfg(feature = "object")] - { - for name in aexpr_to_leaf_names(agg, expr_arena) { - let dtype = _input_schema.get(&name).unwrap(); - - if let DataType::Object(_, _) = dtype { - partitionable = false; - break; - } - } - if !partitionable { - break; - } - } - } - } + can_pre_agg_exprs(aggs, expr_arena, input_schema) } else { - partitionable = false; + false } - partitionable } struct ConversionState { diff --git a/crates/polars-plan/src/plans/aexpr/utils.rs b/crates/polars-plan/src/plans/aexpr/utils.rs index c26af7cb0d3b..eb0d8b62c7e2 100644 --- a/crates/polars-plan/src/plans/aexpr/utils.rs +++ b/crates/polars-plan/src/plans/aexpr/utils.rs @@ -157,3 +157,96 @@ pub fn permits_filter_pushdown_rec<'a>(mut ae: &'a AExpr, expr_arena: &'a Arena< true } + +pub fn can_pre_agg_exprs( + exprs: &[ExprIR], + expr_arena: &Arena, + _input_schema: &Schema, +) -> bool { + exprs + .iter() + .all(|e| can_pre_agg(e.node(), expr_arena, _input_schema)) +} + +/// Checks whether an expression can be pre-aggregated in a group-by. Note that this also must be +/// implemented physically, so this isn't a complete list. +pub fn can_pre_agg(agg: Node, expr_arena: &Arena, _input_schema: &Schema) -> bool { + let aexpr = expr_arena.get(agg); + + match aexpr { + AExpr::Len => true, + AExpr::Column(_) | AExpr::Literal(_) => false, + // We only allow expressions that end with an aggregation. + AExpr::Agg(_) => { + let has_aggregation = + |node: Node| has_aexpr(node, expr_arena, |ae| matches!(ae, AExpr::Agg(_))); + + // check if the aggregation type is partitionable + // only simple aggregation like col().sum + // that can be divided in to the aggregation of their partitions are allowed + let can_partition = (expr_arena).iter(agg).all(|(_, ae)| { + use AExpr::*; + match ae { + // struct is needed to keep both states + #[cfg(feature = "dtype-struct")] + Agg(IRAggExpr::Mean(_)) => { + // only numeric means for now. + // logical types seem to break because of casts to float. + matches!( + expr_arena + .get(agg) + .get_type(_input_schema, Context::Default, expr_arena) + .map(|dt| { dt.is_numeric() }), + Ok(true) + ) + }, + // only allowed expressions + Agg(agg_e) => { + matches!( + agg_e, + IRAggExpr::Min { .. } + | IRAggExpr::Max { .. } + | IRAggExpr::Sum(_) + | IRAggExpr::Last(_) + | IRAggExpr::First(_) + | IRAggExpr::Count(_, true) + ) + }, + Function { input, options, .. } => { + matches!(options.collect_groups, ApplyOptions::ElementWise) + && input.len() == 1 + && !has_aggregation(input[0].node()) + }, + BinaryExpr { left, right, .. } => { + !has_aggregation(*left) && !has_aggregation(*right) + }, + Ternary { + truthy, + falsy, + predicate, + .. + } => { + !has_aggregation(*truthy) + && !has_aggregation(*falsy) + && !has_aggregation(*predicate) + }, + Column(_) | Len | Literal(_) | Cast { .. } => true, + _ => false, + } + }); + + #[cfg(feature = "object")] + { + for name in aexpr_to_leaf_names(agg, expr_arena) { + let dtype = _input_schema.get(&name).unwrap(); + + if let DataType::Object(_, _) = dtype { + return false; + } + } + } + can_partition + }, + _ => false, + } +} diff --git a/crates/polars-plan/src/plans/ir/format.rs b/crates/polars-plan/src/plans/ir/format.rs index 469d8064f271..69e42e06322c 100644 --- a/crates/polars-plan/src/plans/ir/format.rs +++ b/crates/polars-plan/src/plans/ir/format.rs @@ -1,6 +1,5 @@ use std::borrow::Cow; -use std::fmt; -use std::fmt::{Display, Formatter}; +use std::fmt::{self, Display, Formatter}; use polars_core::datatypes::AnyValue; use polars_core::schema::Schema; @@ -299,13 +298,21 @@ impl<'a> IRDisplay<'a> { self.with_root(*input)._format(f, sub_indent) }, GroupBy { - input, keys, aggs, .. + input, + keys, + aggs, + apply, + .. } => { - let aggs = self.display_expr_slice(aggs); let keys = self.display_expr_slice(keys); write!(f, "{:indent$}AGGREGATE", "")?; - write!(f, "\n{:indent$}\t{aggs} BY {keys} FROM", "")?; + if apply.is_some() { + write!(f, "\n{:indent$}\tMAP_GROUPS BY {keys} FROM", "")?; + } else { + let aggs = self.display_expr_slice(aggs); + write!(f, "\n{:indent$}\t{aggs} BY {keys} FROM", "")?; + } self.with_root(*input)._format(f, sub_indent) }, Join {