Skip to content

Commit

Permalink
refactor: Always expand horizontal_any/all (#15816)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Apr 21, 2024
1 parent 5b11f28 commit b4e43d7
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 94 deletions.
2 changes: 1 addition & 1 deletion crates/polars-lazy/src/tests/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use super::*;
fn cached_before_root(q: LazyFrame) {
let (mut expr_arena, mut lp_arena) = get_arenas();
let lp = q.optimize(&mut lp_arena, &mut expr_arena).unwrap();
for input in lp_arena.get(lp).get_inputs() {
for input in lp_arena.get(lp).get_inputs_vec() {
assert!(matches!(lp_arena.get(input), IR::Cache { .. }));
}
}
Expand Down
40 changes: 0 additions & 40 deletions crates/polars-ops/src/series/ops/horizontal.rs
Original file line number Diff line number Diff line change
@@ -1,45 +1,5 @@
use std::ops::{BitAnd, BitOr};

use polars_core::frame::NullStrategy;
use polars_core::prelude::*;
use polars_core::POOL;
use rayon::prelude::*;

pub fn any_horizontal(s: &[Series]) -> PolarsResult<Series> {
let out = POOL
.install(|| {
s.par_iter()
.try_fold(
|| BooleanChunked::new("", &[false]),
|acc, b| {
let b = b.cast(&DataType::Boolean)?;
let b = b.bool()?;
PolarsResult::Ok((&acc).bitor(b))
},
)
.try_reduce(|| BooleanChunked::new("", [false]), |a, b| Ok(a.bitor(b)))
})?
.with_name(s[0].name());
Ok(out.into_series())
}

pub fn all_horizontal_impl(s: &[Series]) -> PolarsResult<Series> {
let out = POOL
.install(|| {
s.par_iter()
.try_fold(
|| BooleanChunked::new("", &[true]),
|acc, b| {
let b = b.cast(&DataType::Boolean)?;
let b = b.bool()?;
PolarsResult::Ok((&acc).bitand(b))
},
)
.try_reduce(|| BooleanChunked::new("", [true]), |a, b| Ok(a.bitand(b)))
})?
.with_name(s[0].name());
Ok(out.into_series())
}

pub fn max_horizontal(s: &[Series]) -> PolarsResult<Option<Series>> {
let df = unsafe { DataFrame::new_no_checks(Vec::from(s)) };
Expand Down
15 changes: 4 additions & 11 deletions crates/polars-plan/src/dsl/function_expr/boolean.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
use super::*;
use crate::map;
#[cfg(feature = "is_between")]
use crate::map_as_slice;
#[cfg(feature = "is_in")]
use crate::wrap;
use crate::{map, map_as_slice};

#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Clone, PartialEq, Debug, Eq, Hash)]
Expand Down Expand Up @@ -112,9 +114,8 @@ impl From<BooleanFunction> for SpecialEq<Arc<dyn SeriesUdf>> {
IsBetween { closed } => map_as_slice!(is_between, closed),
#[cfg(feature = "is_in")]
IsIn => wrap!(is_in),
AllHorizontal => map_as_slice!(all_horizontal),
AnyHorizontal => map_as_slice!(any_horizontal),
Not => map!(not),
AllHorizontal | AnyHorizontal => unreachable!(),
}
}
}
Expand Down Expand Up @@ -202,14 +203,6 @@ fn is_in(s: &mut [Series]) -> PolarsResult<Option<Series>> {
polars_ops::prelude::is_in(left, other).map(|ca| Some(ca.into_series()))
}

fn any_horizontal(s: &[Series]) -> PolarsResult<Series> {
polars_ops::prelude::any_horizontal(s)
}

fn all_horizontal(s: &[Series]) -> PolarsResult<Series> {
polars_ops::prelude::all_horizontal_impl(s)
}

fn not(s: &Series) -> PolarsResult<Series> {
polars_ops::series::negate_bitwise(s)
}
27 changes: 2 additions & 25 deletions crates/polars-plan/src/dsl/functions/horizontal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,26 +195,12 @@ where
pub fn all_horizontal<E: AsRef<[Expr]>>(exprs: E) -> PolarsResult<Expr> {
let exprs = exprs.as_ref().to_vec();
polars_ensure!(!exprs.is_empty(), ComputeError: "cannot return empty fold because the number of output rows is unknown");

// We prefer this path as the optimizer can better deal with the binary operations.
// However if we have a single expression, we might lose information.
// E.g. `all().is_null()` would reduce to `all().is_null()` (the & is not needed as there is no rhs (yet)
// And upon expansion, it becomes
// `col(i).is_null() for i in len(df))`
// so we would miss the boolean operator.
if exprs.len() > 1 {
return Ok(exprs.into_iter().reduce(|l, r| l.logical_and(r)).unwrap());
}

// This will be reduced to `expr & expr` during conversion to IR.
Ok(Expr::Function {
input: exprs,
function: FunctionExpr::Boolean(BooleanFunction::AllHorizontal),
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
input_wildcard_expansion: true,
returns_scalar: false,
cast_to_supertypes: false,
allow_rename: true,
..Default::default()
},
})
Expand All @@ -226,21 +212,12 @@ pub fn all_horizontal<E: AsRef<[Expr]>>(exprs: E) -> PolarsResult<Expr> {
pub fn any_horizontal<E: AsRef<[Expr]>>(exprs: E) -> PolarsResult<Expr> {
let exprs = exprs.as_ref().to_vec();
polars_ensure!(!exprs.is_empty(), ComputeError: "cannot return empty fold because the number of output rows is unknown");

// See comment in `all_horizontal`.
if exprs.len() > 1 {
return Ok(exprs.into_iter().reduce(|l, r| l.logical_or(r)).unwrap());
}

// This will be reduced to `expr | expr` during conversion to IR.
Ok(Expr::Function {
input: exprs,
function: FunctionExpr::Boolean(BooleanFunction::AnyHorizontal),
options: FunctionOptions {
collect_groups: ApplyOptions::ElementWise,
input_wildcard_expansion: true,
returns_scalar: false,
cast_to_supertypes: false,
allow_rename: true,
..Default::default()
},
})
Expand Down
17 changes: 9 additions & 8 deletions crates/polars-plan/src/logical_plan/alp/inputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -265,17 +265,18 @@ impl IR {
container.push_node(input)
}

pub fn get_inputs(&self) -> Vec<Node> {
let mut inputs = Vec::new();
pub fn get_inputs(&self) -> UnitVec<Node> {
let mut inputs: UnitVec<Node> = unitvec!();
self.copy_inputs(&mut inputs);
inputs
}

pub fn get_inputs_vec(&self) -> Vec<Node> {
let mut inputs = vec![];
self.copy_inputs(&mut inputs);
inputs
}
/// panics if more than one input
#[cfg(any(
all(feature = "strings", feature = "concat_str"),
feature = "streaming",
feature = "fused"
))]

pub(crate) fn get_input(&self) -> Option<Node> {
let mut inputs: UnitVec<Node> = unitvec!();
self.copy_inputs(&mut inputs);
Expand Down
15 changes: 15 additions & 0 deletions crates/polars-plan/src/logical_plan/alp/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,21 @@ impl IR {
}
}

pub fn input_schema<'a>(&'a self, arena: &'a Arena<IR>) -> Option<Cow<'a, SchemaRef>> {
use IR::*;
let schema = match self {
#[cfg(feature = "python")]
PythonScan { options, .. } => &options.schema,
DataFrameScan { schema, .. } => schema,
Scan { file_info, .. } => &file_info.schema,
node => {
let input = node.get_input()?;
return Some(arena.get(input).schema(arena));
},
};
Some(Cow::Borrowed(schema))
}

/// Get the schema of the logical plan node.
pub fn schema<'a>(&'a self, arena: &'a Arena<IR>) -> Cow<'a, SchemaRef> {
use IR::*;
Expand Down
21 changes: 21 additions & 0 deletions crates/polars-plan/src/logical_plan/conversion/expr_to_expr_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,27 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena<AExpr>, state: &mut ConversionSta
function,
options,
} => {
match function {
// Convert to binary expression as the optimizer understands those.
FunctionExpr::Boolean(BooleanFunction::AllHorizontal) => {
let expr = input
.into_iter()
.reduce(|l, r| l.logical_and(r))
.unwrap()
.cast(DataType::Boolean);
return to_aexpr_impl(expr, arena, state);
},
FunctionExpr::Boolean(BooleanFunction::AnyHorizontal) => {
let expr = input
.into_iter()
.reduce(|l, r| l.logical_or(r))
.unwrap()
.cast(DataType::Boolean);
return to_aexpr_impl(expr, arena, state);
},
_ => {},
}

let e = to_expr_irs(input, arena);

if state.output_name.is_none() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ impl<'a> PredicatePushDown<'a> {
expr_arena: &mut Arena<AExpr>,
has_projections: bool,
) -> PolarsResult<IR> {
let inputs = lp.get_inputs();
let inputs = lp.get_inputs_vec();
let exprs = lp.get_exprs();

if has_projections {
Expand Down
40 changes: 32 additions & 8 deletions crates/polars-plan/src/logical_plan/optimizer/simplify_expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -419,8 +419,8 @@ impl OptimizationRule for SimplifyExprRule {
&mut self,
expr_arena: &mut Arena<AExpr>,
expr_node: Node,
_lp_arena: &Arena<IR>,
_lp_node: Node,
lp_arena: &Arena<IR>,
lp_node: Node,
) -> PolarsResult<Option<AExpr>> {
let expr = expr_arena.get(expr_node).clone();

Expand All @@ -443,8 +443,8 @@ impl OptimizationRule for SimplifyExprRule {
#[cfg(all(feature = "strings", feature = "concat_str"))]
{
string_addition_to_linear_concat(
_lp_arena,
_lp_node,
lp_arena,
lp_node,
expr_arena,
*left,
*right,
Expand Down Expand Up @@ -595,19 +595,43 @@ impl OptimizationRule for SimplifyExprRule {
strict,
} => {
let input = expr_arena.get(*expr);
inline_cast(input, data_type, *strict)?
inline_or_prune_cast(input, data_type, *strict, lp_node, lp_arena, expr_arena)?
},
_ => None,
};
Ok(out)
}
}

fn inline_cast(input: &AExpr, dtype: &DataType, strict: bool) -> PolarsResult<Option<AExpr>> {
fn inline_or_prune_cast(
aexpr: &AExpr,
dtype: &DataType,
strict: bool,
lp_node: Node,
lp_arena: &Arena<IR>,
expr_arena: &Arena<AExpr>,
) -> PolarsResult<Option<AExpr>> {
if !dtype.is_known() {
return Ok(None);
}
let lv = match (input, dtype) {
let lv = match (aexpr, dtype) {
// PRUNE
(
AExpr::BinaryExpr {
op: Operator::LogicalOr | Operator::LogicalAnd,
..
},
_,
) => {
if let Some(schema) = lp_arena.get(lp_node).input_schema(lp_arena) {
let field = aexpr.to_field(&schema, Context::Default, expr_arena)?;
if field.dtype == *dtype {
return Ok(Some(aexpr.clone()));
}
}
return Ok(None);
},
// INLINE
(AExpr::Literal(lv), _) => match lv {
LiteralValue::Series(s) => {
let s = if strict {
Expand All @@ -622,7 +646,7 @@ fn inline_cast(input: &AExpr, dtype: &DataType, strict: bool) -> PolarsResult<Op
return Ok(None);
};
if dtype == &av.dtype() {
return Ok(Some(input.clone()));
return Ok(Some(aexpr.clone()));
}
match (av, dtype) {
// casting null always remains null
Expand Down

0 comments on commit b4e43d7

Please sign in to comment.