Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into fix-align-single-row
Browse files Browse the repository at this point in the history
  • Loading branch information
lukemanley committed Dec 27, 2024
2 parents 4b521fc + 0c290d6 commit 6d7f28f
Show file tree
Hide file tree
Showing 3 changed files with 109 additions and 100 deletions.
99 changes: 4 additions & 95 deletions crates/polars-mem-engine/src/planner/lp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,10 @@ use crate::utils::*;
fn partitionable_gb(
keys: &[ExprIR],
aggs: &[ExprIR],
_input_schema: &Schema,
input_schema: &Schema,
expr_arena: &Arena<AExpr>,
apply: &Option<Arc<dyn DataFrameUdf>>,
) -> bool {
// We first check if we can partition the group_by on the latest moment.
let mut partitionable = true;

// checks:
// 1. complex expressions in the group_by itself are also not partitionable
// in this case anything more than col("foo")
Expand All @@ -28,102 +25,14 @@ fn partitionable_gb(
// in this case anything more than col("foo")
for key in keys {
if (expr_arena).iter(key.node()).count() > 1 {
partitionable = false;
break;
return false;
}
}

if partitionable {
for agg in aggs {
let agg = agg.node();
let aexpr = expr_arena.get(agg);
let depth = (expr_arena).iter(agg).count();

// These single expressions are partitionable
if matches!(aexpr, AExpr::Len) {
continue;
}
// col()
// lit() etc.
if depth == 1 {
partitionable = false;
break;
}

let has_aggregation =
|node: Node| has_aexpr(node, expr_arena, |ae| matches!(ae, AExpr::Agg(_)));

// check if the aggregation type is partitionable
// only simple aggregation like col().sum
// that can be divided in to the aggregation of their partitions are allowed
if !((expr_arena).iter(agg).all(|(_, ae)| {
use AExpr::*;
match ae {
// struct is needed to keep both states
#[cfg(feature = "dtype-struct")]
Agg(IRAggExpr::Mean(_)) => {
// only numeric means for now.
// logical types seem to break because of casts to float.
matches!(expr_arena.get(agg).get_type(_input_schema, Context::Default, expr_arena).map(|dt| {
dt.is_numeric()}), Ok(true))
},
// only allowed expressions
Agg(agg_e) => {
matches!(
agg_e,
IRAggExpr::Min{..}
| IRAggExpr::Max{..}
| IRAggExpr::Sum(_)
| IRAggExpr::Last(_)
| IRAggExpr::First(_)
| IRAggExpr::Count(_, true)
)
},
Function {input, options, ..} => {
matches!(options.collect_groups, ApplyOptions::ElementWise) && input.len() == 1 &&
!has_aggregation(input[0].node())
}
BinaryExpr {left, right, ..} => {
!has_aggregation(*left) && !has_aggregation(*right)
}
Ternary {truthy, falsy, predicate,..} => {
!has_aggregation(*truthy) && !has_aggregation(*falsy) && !has_aggregation(*predicate)
}
Column(_) | Len | Literal(_) | Cast {..} => {
true
}
_ => {
false
},
}
}) &&
// we only allow expressions that end with an aggregation
matches!(aexpr, AExpr::Agg(_)))
{
partitionable = false;
break;
}

#[cfg(feature = "object")]
{
for name in aexpr_to_leaf_names(agg, expr_arena) {
let dtype = _input_schema.get(&name).unwrap();

if let DataType::Object(_, _) = dtype {
partitionable = false;
break;
}
}
if !partitionable {
break;
}
}
}
}
can_pre_agg_exprs(aggs, expr_arena, input_schema)
} else {
partitionable = false;
false
}
partitionable
}

struct ConversionState {
Expand Down
93 changes: 93 additions & 0 deletions crates/polars-plan/src/plans/aexpr/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,96 @@ pub fn permits_filter_pushdown_rec<'a>(mut ae: &'a AExpr, expr_arena: &'a Arena<

true
}

pub fn can_pre_agg_exprs(
exprs: &[ExprIR],
expr_arena: &Arena<AExpr>,
_input_schema: &Schema,
) -> bool {
exprs
.iter()
.all(|e| can_pre_agg(e.node(), expr_arena, _input_schema))
}

/// Checks whether an expression can be pre-aggregated in a group-by. Note that this also must be
/// implemented physically, so this isn't a complete list.
pub fn can_pre_agg(agg: Node, expr_arena: &Arena<AExpr>, _input_schema: &Schema) -> bool {
let aexpr = expr_arena.get(agg);

match aexpr {
AExpr::Len => true,
AExpr::Column(_) | AExpr::Literal(_) => false,
// We only allow expressions that end with an aggregation.
AExpr::Agg(_) => {
let has_aggregation =
|node: Node| has_aexpr(node, expr_arena, |ae| matches!(ae, AExpr::Agg(_)));

// check if the aggregation type is partitionable
// only simple aggregation like col().sum
// that can be divided in to the aggregation of their partitions are allowed
let can_partition = (expr_arena).iter(agg).all(|(_, ae)| {
use AExpr::*;
match ae {
// struct is needed to keep both states
#[cfg(feature = "dtype-struct")]
Agg(IRAggExpr::Mean(_)) => {
// only numeric means for now.
// logical types seem to break because of casts to float.
matches!(
expr_arena
.get(agg)
.get_type(_input_schema, Context::Default, expr_arena)
.map(|dt| { dt.is_numeric() }),
Ok(true)
)
},
// only allowed expressions
Agg(agg_e) => {
matches!(
agg_e,
IRAggExpr::Min { .. }
| IRAggExpr::Max { .. }
| IRAggExpr::Sum(_)
| IRAggExpr::Last(_)
| IRAggExpr::First(_)
| IRAggExpr::Count(_, true)
)
},
Function { input, options, .. } => {
matches!(options.collect_groups, ApplyOptions::ElementWise)
&& input.len() == 1
&& !has_aggregation(input[0].node())
},
BinaryExpr { left, right, .. } => {
!has_aggregation(*left) && !has_aggregation(*right)
},
Ternary {
truthy,
falsy,
predicate,
..
} => {
!has_aggregation(*truthy)
&& !has_aggregation(*falsy)
&& !has_aggregation(*predicate)
},
Column(_) | Len | Literal(_) | Cast { .. } => true,
_ => false,
}
});

#[cfg(feature = "object")]
{
for name in aexpr_to_leaf_names(agg, expr_arena) {
let dtype = _input_schema.get(&name).unwrap();

if let DataType::Object(_, _) = dtype {
return false;
}
}
}
can_partition
},
_ => false,
}
}
17 changes: 12 additions & 5 deletions crates/polars-plan/src/plans/ir/format.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::borrow::Cow;
use std::fmt;
use std::fmt::{Display, Formatter};
use std::fmt::{self, Display, Formatter};

use polars_core::datatypes::AnyValue;
use polars_core::schema::Schema;
Expand Down Expand Up @@ -299,13 +298,21 @@ impl<'a> IRDisplay<'a> {
self.with_root(*input)._format(f, sub_indent)
},
GroupBy {
input, keys, aggs, ..
input,
keys,
aggs,
apply,
..
} => {
let aggs = self.display_expr_slice(aggs);
let keys = self.display_expr_slice(keys);

write!(f, "{:indent$}AGGREGATE", "")?;
write!(f, "\n{:indent$}\t{aggs} BY {keys} FROM", "")?;
if apply.is_some() {
write!(f, "\n{:indent$}\tMAP_GROUPS BY {keys} FROM", "")?;
} else {
let aggs = self.display_expr_slice(aggs);
write!(f, "\n{:indent$}\t{aggs} BY {keys} FROM", "")?;
}
self.with_root(*input)._format(f, sub_indent)
},
Join {
Expand Down

0 comments on commit 6d7f28f

Please sign in to comment.