From b4e43d7b81ac3b482acb3fac9173a9e2ce02e717 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Sun, 21 Apr 2024 13:05:14 +0200 Subject: [PATCH] refactor: Always expand horizontal_any/all (#15816) --- crates/polars-lazy/src/tests/cse.rs | 2 +- .../polars-ops/src/series/ops/horizontal.rs | 40 ------------------- .../src/dsl/function_expr/boolean.rs | 15 ++----- .../src/dsl/functions/horizontal.rs | 27 +------------ .../src/logical_plan/alp/inputs.rs | 17 ++++---- .../src/logical_plan/alp/schema.rs | 15 +++++++ .../conversion/expr_to_expr_ir.rs | 21 ++++++++++ .../optimizer/predicate_pushdown/mod.rs | 2 +- .../logical_plan/optimizer/simplify_expr.rs | 40 +++++++++++++++---- 9 files changed, 85 insertions(+), 94 deletions(-) diff --git a/crates/polars-lazy/src/tests/cse.rs b/crates/polars-lazy/src/tests/cse.rs index 465eda559caf..b9e23427cde9 100644 --- a/crates/polars-lazy/src/tests/cse.rs +++ b/crates/polars-lazy/src/tests/cse.rs @@ -5,7 +5,7 @@ use super::*; fn cached_before_root(q: LazyFrame) { let (mut expr_arena, mut lp_arena) = get_arenas(); let lp = q.optimize(&mut lp_arena, &mut expr_arena).unwrap(); - for input in lp_arena.get(lp).get_inputs() { + for input in lp_arena.get(lp).get_inputs_vec() { assert!(matches!(lp_arena.get(input), IR::Cache { .. })); } } diff --git a/crates/polars-ops/src/series/ops/horizontal.rs b/crates/polars-ops/src/series/ops/horizontal.rs index 4aad55f7d966..c8e3488aab93 100644 --- a/crates/polars-ops/src/series/ops/horizontal.rs +++ b/crates/polars-ops/src/series/ops/horizontal.rs @@ -1,45 +1,5 @@ -use std::ops::{BitAnd, BitOr}; - use polars_core::frame::NullStrategy; use polars_core::prelude::*; -use polars_core::POOL; -use rayon::prelude::*; - -pub fn any_horizontal(s: &[Series]) -> PolarsResult { - let out = POOL - .install(|| { - s.par_iter() - .try_fold( - || BooleanChunked::new("", &[false]), - |acc, b| { - let b = b.cast(&DataType::Boolean)?; - let b = b.bool()?; - PolarsResult::Ok((&acc).bitor(b)) - }, - ) - .try_reduce(|| BooleanChunked::new("", [false]), |a, b| Ok(a.bitor(b))) - })? - .with_name(s[0].name()); - Ok(out.into_series()) -} - -pub fn all_horizontal_impl(s: &[Series]) -> PolarsResult { - let out = POOL - .install(|| { - s.par_iter() - .try_fold( - || BooleanChunked::new("", &[true]), - |acc, b| { - let b = b.cast(&DataType::Boolean)?; - let b = b.bool()?; - PolarsResult::Ok((&acc).bitand(b)) - }, - ) - .try_reduce(|| BooleanChunked::new("", [true]), |a, b| Ok(a.bitand(b))) - })? - .with_name(s[0].name()); - Ok(out.into_series()) -} pub fn max_horizontal(s: &[Series]) -> PolarsResult> { let df = unsafe { DataFrame::new_no_checks(Vec::from(s)) }; diff --git a/crates/polars-plan/src/dsl/function_expr/boolean.rs b/crates/polars-plan/src/dsl/function_expr/boolean.rs index 6f6cf1824eb8..6f1bacb0c8d5 100644 --- a/crates/polars-plan/src/dsl/function_expr/boolean.rs +++ b/crates/polars-plan/src/dsl/function_expr/boolean.rs @@ -1,7 +1,9 @@ use super::*; +use crate::map; +#[cfg(feature = "is_between")] +use crate::map_as_slice; #[cfg(feature = "is_in")] use crate::wrap; -use crate::{map, map_as_slice}; #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[derive(Clone, PartialEq, Debug, Eq, Hash)] @@ -112,9 +114,8 @@ impl From for SpecialEq> { IsBetween { closed } => map_as_slice!(is_between, closed), #[cfg(feature = "is_in")] IsIn => wrap!(is_in), - AllHorizontal => map_as_slice!(all_horizontal), - AnyHorizontal => map_as_slice!(any_horizontal), Not => map!(not), + AllHorizontal | AnyHorizontal => unreachable!(), } } } @@ -202,14 +203,6 @@ fn is_in(s: &mut [Series]) -> PolarsResult> { polars_ops::prelude::is_in(left, other).map(|ca| Some(ca.into_series())) } -fn any_horizontal(s: &[Series]) -> PolarsResult { - polars_ops::prelude::any_horizontal(s) -} - -fn all_horizontal(s: &[Series]) -> PolarsResult { - polars_ops::prelude::all_horizontal_impl(s) -} - fn not(s: &Series) -> PolarsResult { polars_ops::series::negate_bitwise(s) } diff --git a/crates/polars-plan/src/dsl/functions/horizontal.rs b/crates/polars-plan/src/dsl/functions/horizontal.rs index 13837fc59796..bab9212dc8eb 100644 --- a/crates/polars-plan/src/dsl/functions/horizontal.rs +++ b/crates/polars-plan/src/dsl/functions/horizontal.rs @@ -195,26 +195,12 @@ where pub fn all_horizontal>(exprs: E) -> PolarsResult { let exprs = exprs.as_ref().to_vec(); polars_ensure!(!exprs.is_empty(), ComputeError: "cannot return empty fold because the number of output rows is unknown"); - - // We prefer this path as the optimizer can better deal with the binary operations. - // However if we have a single expression, we might lose information. - // E.g. `all().is_null()` would reduce to `all().is_null()` (the & is not needed as there is no rhs (yet) - // And upon expansion, it becomes - // `col(i).is_null() for i in len(df))` - // so we would miss the boolean operator. - if exprs.len() > 1 { - return Ok(exprs.into_iter().reduce(|l, r| l.logical_and(r)).unwrap()); - } - + // This will be reduced to `expr & expr` during conversion to IR. Ok(Expr::Function { input: exprs, function: FunctionExpr::Boolean(BooleanFunction::AllHorizontal), options: FunctionOptions { - collect_groups: ApplyOptions::ElementWise, input_wildcard_expansion: true, - returns_scalar: false, - cast_to_supertypes: false, - allow_rename: true, ..Default::default() }, }) @@ -226,21 +212,12 @@ pub fn all_horizontal>(exprs: E) -> PolarsResult { pub fn any_horizontal>(exprs: E) -> PolarsResult { let exprs = exprs.as_ref().to_vec(); polars_ensure!(!exprs.is_empty(), ComputeError: "cannot return empty fold because the number of output rows is unknown"); - - // See comment in `all_horizontal`. - if exprs.len() > 1 { - return Ok(exprs.into_iter().reduce(|l, r| l.logical_or(r)).unwrap()); - } - + // This will be reduced to `expr | expr` during conversion to IR. Ok(Expr::Function { input: exprs, function: FunctionExpr::Boolean(BooleanFunction::AnyHorizontal), options: FunctionOptions { - collect_groups: ApplyOptions::ElementWise, input_wildcard_expansion: true, - returns_scalar: false, - cast_to_supertypes: false, - allow_rename: true, ..Default::default() }, }) diff --git a/crates/polars-plan/src/logical_plan/alp/inputs.rs b/crates/polars-plan/src/logical_plan/alp/inputs.rs index 497c26044c3b..dd9ca8ff5638 100644 --- a/crates/polars-plan/src/logical_plan/alp/inputs.rs +++ b/crates/polars-plan/src/logical_plan/alp/inputs.rs @@ -265,17 +265,18 @@ impl IR { container.push_node(input) } - pub fn get_inputs(&self) -> Vec { - let mut inputs = Vec::new(); + pub fn get_inputs(&self) -> UnitVec { + let mut inputs: UnitVec = unitvec!(); + self.copy_inputs(&mut inputs); + inputs + } + + pub fn get_inputs_vec(&self) -> Vec { + let mut inputs = vec![]; self.copy_inputs(&mut inputs); inputs } - /// panics if more than one input - #[cfg(any( - all(feature = "strings", feature = "concat_str"), - feature = "streaming", - feature = "fused" - ))] + pub(crate) fn get_input(&self) -> Option { let mut inputs: UnitVec = unitvec!(); self.copy_inputs(&mut inputs); diff --git a/crates/polars-plan/src/logical_plan/alp/schema.rs b/crates/polars-plan/src/logical_plan/alp/schema.rs index 0ed719d9f5cc..db4d77b61b03 100644 --- a/crates/polars-plan/src/logical_plan/alp/schema.rs +++ b/crates/polars-plan/src/logical_plan/alp/schema.rs @@ -44,6 +44,21 @@ impl IR { } } + pub fn input_schema<'a>(&'a self, arena: &'a Arena) -> Option> { + use IR::*; + let schema = match self { + #[cfg(feature = "python")] + PythonScan { options, .. } => &options.schema, + DataFrameScan { schema, .. } => schema, + Scan { file_info, .. } => &file_info.schema, + node => { + let input = node.get_input()?; + return Some(arena.get(input).schema(arena)); + }, + }; + Some(Cow::Borrowed(schema)) + } + /// Get the schema of the logical plan node. pub fn schema<'a>(&'a self, arena: &'a Arena) -> Cow<'a, SchemaRef> { use IR::*; diff --git a/crates/polars-plan/src/logical_plan/conversion/expr_to_expr_ir.rs b/crates/polars-plan/src/logical_plan/conversion/expr_to_expr_ir.rs index 15f076f8064f..336ca65d6b78 100644 --- a/crates/polars-plan/src/logical_plan/conversion/expr_to_expr_ir.rs +++ b/crates/polars-plan/src/logical_plan/conversion/expr_to_expr_ir.rs @@ -233,6 +233,27 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta function, options, } => { + match function { + // Convert to binary expression as the optimizer understands those. + FunctionExpr::Boolean(BooleanFunction::AllHorizontal) => { + let expr = input + .into_iter() + .reduce(|l, r| l.logical_and(r)) + .unwrap() + .cast(DataType::Boolean); + return to_aexpr_impl(expr, arena, state); + }, + FunctionExpr::Boolean(BooleanFunction::AnyHorizontal) => { + let expr = input + .into_iter() + .reduce(|l, r| l.logical_or(r)) + .unwrap() + .cast(DataType::Boolean); + return to_aexpr_impl(expr, arena, state); + }, + _ => {}, + } + let e = to_expr_irs(input, arena); if state.output_name.is_none() { diff --git a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs index 11c374a3d356..4477e1176d64 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/predicate_pushdown/mod.rs @@ -78,7 +78,7 @@ impl<'a> PredicatePushDown<'a> { expr_arena: &mut Arena, has_projections: bool, ) -> PolarsResult { - let inputs = lp.get_inputs(); + let inputs = lp.get_inputs_vec(); let exprs = lp.get_exprs(); if has_projections { diff --git a/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs b/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs index 028fef156e99..0339e3bfade1 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs @@ -419,8 +419,8 @@ impl OptimizationRule for SimplifyExprRule { &mut self, expr_arena: &mut Arena, expr_node: Node, - _lp_arena: &Arena, - _lp_node: Node, + lp_arena: &Arena, + lp_node: Node, ) -> PolarsResult> { let expr = expr_arena.get(expr_node).clone(); @@ -443,8 +443,8 @@ impl OptimizationRule for SimplifyExprRule { #[cfg(all(feature = "strings", feature = "concat_str"))] { string_addition_to_linear_concat( - _lp_arena, - _lp_node, + lp_arena, + lp_node, expr_arena, *left, *right, @@ -595,7 +595,7 @@ impl OptimizationRule for SimplifyExprRule { strict, } => { let input = expr_arena.get(*expr); - inline_cast(input, data_type, *strict)? + inline_or_prune_cast(input, data_type, *strict, lp_node, lp_arena, expr_arena)? }, _ => None, }; @@ -603,11 +603,35 @@ impl OptimizationRule for SimplifyExprRule { } } -fn inline_cast(input: &AExpr, dtype: &DataType, strict: bool) -> PolarsResult> { +fn inline_or_prune_cast( + aexpr: &AExpr, + dtype: &DataType, + strict: bool, + lp_node: Node, + lp_arena: &Arena, + expr_arena: &Arena, +) -> PolarsResult> { if !dtype.is_known() { return Ok(None); } - let lv = match (input, dtype) { + let lv = match (aexpr, dtype) { + // PRUNE + ( + AExpr::BinaryExpr { + op: Operator::LogicalOr | Operator::LogicalAnd, + .. + }, + _, + ) => { + if let Some(schema) = lp_arena.get(lp_node).input_schema(lp_arena) { + let field = aexpr.to_field(&schema, Context::Default, expr_arena)?; + if field.dtype == *dtype { + return Ok(Some(aexpr.clone())); + } + } + return Ok(None); + }, + // INLINE (AExpr::Literal(lv), _) => match lv { LiteralValue::Series(s) => { let s = if strict { @@ -622,7 +646,7 @@ fn inline_cast(input: &AExpr, dtype: &DataType, strict: bool) -> PolarsResult