diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index d6fd3d4364cb..3536c3071332 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -145,6 +145,12 @@ impl LazyFrame { self } + /// Toggle collapse joins optimization. + pub fn with_collapse_joins(mut self, toggle: bool) -> Self { + self.opt_state.set(OptFlags::COLLAPSE_JOINS, toggle); + self + } + /// Toggle predicate pushdown optimization. pub fn with_predicate_pushdown(mut self, toggle: bool) -> Self { self.opt_state.set(OptFlags::PREDICATE_PUSHDOWN, toggle); diff --git a/crates/polars-plan/src/frame/opt_state.rs b/crates/polars-plan/src/frame/opt_state.rs index 934f42e6109f..699b4a09a0ce 100644 --- a/crates/polars-plan/src/frame/opt_state.rs +++ b/crates/polars-plan/src/frame/opt_state.rs @@ -33,6 +33,8 @@ bitflags! { const ROW_ESTIMATE = 1 << 13; /// Replace simple projections with a faster inlined projection that skips the expression engine. const FAST_PROJECTION = 1 << 14; + /// Collapse slower joins with filters into faster joins. + const COLLAPSE_JOINS = 1 << 15; } } diff --git a/crates/polars-plan/src/plans/expr_ir.rs b/crates/polars-plan/src/plans/expr_ir.rs index b70bca266ee1..8d66ceb9072f 100644 --- a/crates/polars-plan/src/plans/expr_ir.rs +++ b/crates/polars-plan/src/plans/expr_ir.rs @@ -152,6 +152,10 @@ impl ExprIR { self.output_name = OutputName::Alias(name) } + pub(crate) fn set_columnlhs(&mut self, name: PlSmallStr) { + self.output_name = OutputName::ColumnLhs(name) + } + pub fn output_name_inner(&self) -> &OutputName { &self.output_name } diff --git a/crates/polars-plan/src/plans/optimizer/collapse_joins.rs b/crates/polars-plan/src/plans/optimizer/collapse_joins.rs new file mode 100644 index 000000000000..5656734969fe --- /dev/null +++ b/crates/polars-plan/src/plans/optimizer/collapse_joins.rs @@ -0,0 +1,478 @@ +//! Optimization that collapses several a join with several filters into faster join. +//! +//! For example, `join(how='cross').filter(pl.col.l == pl.col.r)` can be collapsed to +//! `join(how='inner', left_on=pl.col.l, right_on=pl.col.r)`. + +use std::sync::Arc; + +use polars_core::schema::SchemaRef; +#[cfg(feature = "iejoin")] +use polars_ops::frame::{IEJoinOptions, InequalityOperator}; +use polars_ops::frame::{JoinCoalesce, JoinType}; +use polars_utils::arena::{Arena, Node}; +use polars_utils::pl_str::PlSmallStr; + +use super::{aexpr_to_leaf_names_iter, AExpr, JoinOptions, IR}; +use crate::dsl::Operator; +use crate::plans::{ExprIR, OutputName}; + +/// Join origin of an expression +#[derive(Debug, Clone, Copy)] +enum ExprOrigin { + /// Utilizes no columns + None, + /// Utilizes columns from the left side of the join + Left, + /// Utilizes columns from the right side of the join + Right, + /// Utilizes columns from both sides of the join + Both, +} + +fn get_origin( + root: Node, + expr_arena: &Arena, + left_schema: &SchemaRef, + right_schema: &SchemaRef, + suffix: &str, +) -> ExprOrigin { + let mut expr_origin = ExprOrigin::None; + + for name in aexpr_to_leaf_names_iter(root, expr_arena) { + let in_left = left_schema.contains(name.as_str()); + let in_right = right_schema.contains(name.as_str()); + let has_suffix = name.as_str().ends_with(suffix); + let in_right = in_right + | (has_suffix && right_schema.contains(&name.as_str()[..name.len() - suffix.len()])); + + let name_origin = match (in_left, in_right, has_suffix) { + (true, false, _) | (true, true, false) => ExprOrigin::Left, + (false, true, _) | (true, true, true) => ExprOrigin::Right, + (false, false, _) => { + unreachable!("Invalid filter column should have been filtered before") + }, + }; + + use ExprOrigin as O; + expr_origin = match (expr_origin, name_origin) { + (O::None, other) | (other, O::None) => other, + (O::Left, O::Left) => O::Left, + (O::Right, O::Right) => O::Right, + _ => O::Both, + }; + } + + expr_origin +} + +/// Remove the join suffixes from a list of expressions +fn remove_suffix( + exprs: &mut Vec, + expr_arena: &mut Arena, + schema: &SchemaRef, + suffix: &str, +) { + let mut stack = Vec::new(); + + for expr in exprs { + if let OutputName::ColumnLhs(colname) = expr.output_name_inner() { + if colname.ends_with(suffix) && !schema.contains(colname.as_str()) { + expr.set_columnlhs(PlSmallStr::from(&colname[..colname.len() - suffix.len()])); + } + } + + stack.clear(); + stack.push(expr.node()); + while let Some(node) = stack.pop() { + let expr = expr_arena.get_mut(node); + expr.nodes(&mut stack); + + let AExpr::Column(colname) = expr else { + continue; + }; + + if !colname.ends_with(suffix) || schema.contains(colname.as_str()) { + continue; + } + + *colname = PlSmallStr::from(&colname[..colname.len() - suffix.len()]); + } + } +} + +/// An iterator over all the minterms in a boolean expression boolean. +/// +/// In other words, all the terms that can `AND` together to form this expression. +/// +/// # Example +/// +/// ``` +/// a & (b | c) & (b & (c | (a & c))) +/// ``` +/// +/// Gives terms: +/// +/// ``` +/// a +/// b | c +/// b +/// c | (a & c) +/// ``` +struct MintermIter<'a> { + stack: Vec, + expr_arena: &'a Arena, +} + +impl<'a> Iterator for MintermIter<'a> { + type Item = Node; + + fn next(&mut self) -> Option { + let mut top = self.stack.pop()?; + + while let AExpr::BinaryExpr { + left, + op: Operator::And, + right, + } = self.expr_arena.get(top) + { + self.stack.push(*right); + top = *left; + } + + Some(top) + } +} + +impl<'a> MintermIter<'a> { + fn new(root: Node, expr_arena: &'a Arena) -> Self { + Self { + stack: vec![root], + expr_arena, + } + } +} + +fn and_expr(left: Node, right: Node, expr_arena: &mut Arena) -> Node { + expr_arena.add(AExpr::BinaryExpr { + left, + op: Operator::And, + right, + }) +} + +pub fn optimize(root: Node, lp_arena: &mut Arena, expr_arena: &mut Arena) { + let mut predicates = Vec::with_capacity(4); + + // Partition to: + // - equality predicates + // - IEjoin supported inequality predicates + // - remaining predicates + #[cfg(feature = "iejoin")] + let mut ie_op = Vec::new(); + let mut remaining_predicates = Vec::new(); + + let mut ir_stack = Vec::with_capacity(16); + ir_stack.push(root); + + while let Some(current) = ir_stack.pop() { + let current_ir = lp_arena.get(current); + current_ir.copy_inputs(&mut ir_stack); + + match current_ir { + IR::Filter { + input: _, + predicate, + } => { + predicates.push((current, predicate.node())); + }, + IR::Join { + input_left, + input_right, + schema, + left_on, + right_on, + options, + } if matches!(options.args.how, JoinType::Cross) => { + if predicates.is_empty() { + continue; + } + + let suffix = options.args.suffix(); + + debug_assert!(left_on.is_empty()); + debug_assert!(right_on.is_empty()); + + let mut eq_left_on = Vec::new(); + let mut eq_right_on = Vec::new(); + + #[cfg(feature = "iejoin")] + let mut ie_left_on = Vec::new(); + #[cfg(feature = "iejoin")] + let mut ie_right_on = Vec::new(); + + #[cfg(feature = "iejoin")] + { + ie_op.clear(); + } + + remaining_predicates.clear(); + + #[cfg(feature = "iejoin")] + fn to_inequality_operator(op: &Operator) -> Option { + match op { + Operator::Lt => Some(InequalityOperator::Lt), + Operator::LtEq => Some(InequalityOperator::LtEq), + Operator::Gt => Some(InequalityOperator::Gt), + Operator::GtEq => Some(InequalityOperator::GtEq), + _ => None, + } + } + + let left_schema = lp_arena.get(*input_left).schema(lp_arena); + let right_schema = lp_arena.get(*input_right).schema(lp_arena); + + let left_schema = left_schema.as_ref(); + let right_schema = right_schema.as_ref(); + + for (_, predicate_node) in &predicates { + for node in MintermIter::new(*predicate_node, expr_arena) { + let AExpr::BinaryExpr { left, op, right } = expr_arena.get(node) else { + remaining_predicates.push(node); + continue; + }; + + if !op.is_comparison() { + // @NOTE: This is not a valid predicate, but we should not handle that + // here. + remaining_predicates.push(node); + continue; + } + + let mut left = *left; + let mut op = *op; + let mut right = *right; + + let left_origin = get_origin( + left, + expr_arena, + left_schema, + right_schema, + suffix.as_str(), + ); + let right_origin = get_origin( + right, + expr_arena, + left_schema, + right_schema, + suffix.as_str(), + ); + + use ExprOrigin as EO; + + // We can only join if both sides of the binary expression stem from + // different sides of the join. + match (left_origin, right_origin) { + (EO::Both, _) | (_, EO::Both) => { + // If either expression originates from the both sides, we need to + // filter it afterwards. + remaining_predicates.push(node); + continue; + }, + (EO::None, _) | (_, EO::None) => { + // @TODO: This should probably be pushed down + remaining_predicates.push(node); + continue; + }, + (EO::Left, EO::Left) | (EO::Right, EO::Right) => { + // @TODO: This can probably be pushed down in the predicate + // pushdown, but for now just take it as is. + remaining_predicates.push(node); + continue; + }, + (EO::Right, EO::Left) => { + // Swap around the expressions so they match with the left_on and + // right_on. + std::mem::swap(&mut left, &mut right); + op = op.swap_operands(); + }, + (EO::Left, EO::Right) => {}, + } + + if matches!(op, Operator::Eq) { + eq_left_on.push(ExprIR::from_node(left, expr_arena)); + eq_right_on.push(ExprIR::from_node(right, expr_arena)); + } else { + #[cfg(feature = "iejoin")] + if let Some(ie_op_) = to_inequality_operator(&op) { + // We already have an IEjoin or an Inner join, push to remaining + if ie_op.len() >= 2 || !eq_left_on.is_empty() { + remaining_predicates.push(node); + } else { + ie_left_on.push(ExprIR::from_node(left, expr_arena)); + ie_right_on.push(ExprIR::from_node(right, expr_arena)); + ie_op.push(ie_op_); + } + } else { + remaining_predicates.push(node); + } + + #[cfg(not(feature = "iejoin"))] + remaining_predicates.push(node); + } + } + } + + let mut can_simplify_join = false; + + if !eq_left_on.is_empty() { + remove_suffix(&mut eq_right_on, expr_arena, right_schema, suffix.as_str()); + can_simplify_join = true; + } else { + #[cfg(feature = "iejoin")] + if !ie_op.is_empty() { + remove_suffix(&mut ie_right_on, expr_arena, right_schema, suffix.as_str()); + can_simplify_join = true; + } + } + + if can_simplify_join { + let new_join = insert_fitting_join( + eq_left_on, + eq_right_on, + #[cfg(feature = "iejoin")] + ie_left_on, + #[cfg(feature = "iejoin")] + ie_right_on, + #[cfg(feature = "iejoin")] + &ie_op, + &remaining_predicates, + lp_arena, + expr_arena, + options.as_ref().clone(), + *input_left, + *input_right, + schema.clone(), + ); + + lp_arena.swap(predicates[0].0, new_join); + } + + predicates.clear(); + }, + _ => { + predicates.clear(); + }, + } + } +} + +#[allow(clippy::too_many_arguments)] +pub fn insert_fitting_join( + eq_left_on: Vec, + eq_right_on: Vec, + #[cfg(feature = "iejoin")] ie_left_on: Vec, + #[cfg(feature = "iejoin")] ie_right_on: Vec, + #[cfg(feature = "iejoin")] ie_op: &[InequalityOperator], + remaining_predicates: &[Node], + lp_arena: &mut Arena, + expr_arena: &mut Arena, + mut options: JoinOptions, + input_left: Node, + input_right: Node, + schema: SchemaRef, +) -> Node { + debug_assert_eq!(eq_left_on.len(), eq_right_on.len()); + #[cfg(feature = "iejoin")] + { + debug_assert_eq!(ie_op.len(), ie_left_on.len()); + debug_assert_eq!(ie_left_on.len(), ie_right_on.len()); + debug_assert!(ie_op.len() <= 2); + } + debug_assert_eq!(options.args.how, JoinType::Cross); + + let remaining_predicates = remaining_predicates + .iter() + .copied() + .reduce(|left, right| and_expr(left, right, expr_arena)); + + let (left_on, right_on, remaining_predicates) = match () { + _ if !eq_left_on.is_empty() => { + options.args.how = JoinType::Inner; + // We need to make sure not to delete any columns + options.args.coalesce = JoinCoalesce::KeepColumns; + + #[cfg(feature = "iejoin")] + let remaining_predicates = ie_left_on.into_iter().zip(ie_op).zip(ie_right_on).fold( + remaining_predicates, + |acc, ((left, op), right)| { + let e = expr_arena.add(AExpr::BinaryExpr { + left: left.node(), + op: (*op).into(), + right: right.node(), + }); + Some(acc.map_or(e, |acc| and_expr(acc, e, expr_arena))) + }, + ); + + (eq_left_on, eq_right_on, remaining_predicates) + }, + #[cfg(feature = "iejoin")] + _ if !ie_op.is_empty() => { + // We can only IE join up to 2 operators + + let operator1 = ie_op[0]; + let operator2 = ie_op.get(1).copied(); + + // Do an IEjoin. + options.args.how = JoinType::IEJoin(IEJoinOptions { + operator1, + operator2, + }); + // We need to make sure not to delete any columns + options.args.coalesce = JoinCoalesce::KeepColumns; + + (ie_left_on, ie_right_on, remaining_predicates) + }, + // If anything just fall back to a cross join. + _ => { + options.args.how = JoinType::Cross; + // We need to make sure not to delete any columns + options.args.coalesce = JoinCoalesce::KeepColumns; + + #[cfg(feature = "iejoin")] + let remaining_predicates = ie_left_on.into_iter().zip(ie_op).zip(ie_right_on).fold( + remaining_predicates, + |acc, ((left, op), right)| { + let e = expr_arena.add(AExpr::BinaryExpr { + left: left.node(), + op: (*op).into(), + right: right.node(), + }); + Some(acc.map_or(e, |acc| and_expr(acc, e, expr_arena))) + }, + ); + + (Vec::new(), Vec::new(), remaining_predicates) + }, + }; + + let join_ir = IR::Join { + input_left, + input_right, + schema, + left_on, + right_on, + options: Arc::new(options), + }; + + let join_node = lp_arena.add(join_ir); + + if let Some(predicate) = remaining_predicates { + lp_arena.add(IR::Filter { + input: join_node, + predicate: ExprIR::from_node(predicate, &*expr_arena), + }) + } else { + join_node + } +} diff --git a/crates/polars-plan/src/plans/optimizer/collect_members.rs b/crates/polars-plan/src/plans/optimizer/collect_members.rs index 3974df05b598..3157821c7696 100644 --- a/crates/polars-plan/src/plans/optimizer/collect_members.rs +++ b/crates/polars-plan/src/plans/optimizer/collect_members.rs @@ -26,6 +26,7 @@ pub(super) struct MemberCollector { pub(crate) has_joins_or_unions: bool, pub(crate) has_cache: bool, pub(crate) has_ext_context: bool, + pub(crate) has_filter_with_join_input: bool, #[cfg(feature = "cse")] scans: UniqueScans, } @@ -36,6 +37,7 @@ impl MemberCollector { has_joins_or_unions: false, has_cache: false, has_ext_context: false, + has_filter_with_join_input: false, #[cfg(feature = "cse")] scans: UniqueScans::default(), } @@ -45,6 +47,9 @@ impl MemberCollector { for (_node, alp) in lp_arena.iter(root) { match alp { Join { .. } | Union { .. } => self.has_joins_or_unions = true, + Filter { input, .. } => { + self.has_filter_with_join_input |= matches!(lp_arena.get(*input), Join { options, .. } if options.args.how == JoinType::Cross) + }, Cache { .. } => self.has_cache = true, ExtContext { .. } => self.has_ext_context = true, #[cfg(feature = "cse")] diff --git a/crates/polars-plan/src/plans/optimizer/mod.rs b/crates/polars-plan/src/plans/optimizer/mod.rs index 34dc6dca9a29..70880ca78359 100644 --- a/crates/polars-plan/src/plans/optimizer/mod.rs +++ b/crates/polars-plan/src/plans/optimizer/mod.rs @@ -7,6 +7,7 @@ mod delay_rechunk; mod cluster_with_columns; mod collapse_and_project; +mod collapse_joins; mod collect_members; mod count_star; #[cfg(feature = "cse")] @@ -82,6 +83,7 @@ pub fn optimize( // get toggle values let cluster_with_columns = opt_state.contains(OptFlags::CLUSTER_WITH_COLUMNS); + let collapse_joins = opt_state.contains(OptFlags::COLLAPSE_JOINS); let predicate_pushdown = opt_state.contains(OptFlags::PREDICATE_PUSHDOWN); let projection_pushdown = opt_state.contains(OptFlags::PROJECTION_PUSHDOWN); let simplify_expr = opt_state.contains(OptFlags::SIMPLIFY_EXPR); @@ -164,6 +166,11 @@ pub fn optimize( cluster_with_columns::optimize(lp_top, lp_arena, expr_arena) } + // Make sure it is after predicate pushdown + if collapse_joins && members.has_filter_with_join_input { + collapse_joins::optimize(lp_top, lp_arena, expr_arena) + } + // Make sure its before slice pushdown. if fast_projection { rules.push(Box::new(SimpleProjectionAndCollapse::new(eager))); diff --git a/crates/polars-python/src/lazyframe/general.rs b/crates/polars-python/src/lazyframe/general.rs index 57c31d0b05da..3d0327c8f2cd 100644 --- a/crates/polars-python/src/lazyframe/general.rs +++ b/crates/polars-python/src/lazyframe/general.rs @@ -438,6 +438,7 @@ impl PyLazyFrame { comm_subplan_elim: bool, comm_subexpr_elim: bool, cluster_with_columns: bool, + collapse_joins: bool, streaming: bool, _eager: bool, #[allow(unused_variables)] new_streaming: bool, @@ -449,6 +450,7 @@ impl PyLazyFrame { .with_simplify_expr(simplify_expression) .with_slice_pushdown(slice_pushdown) .with_cluster_with_columns(cluster_with_columns) + .with_collapse_joins(collapse_joins) ._with_eager(_eager) .with_projection_pushdown(projection_pushdown); diff --git a/crates/polars-utils/src/idx_vec.rs b/crates/polars-utils/src/idx_vec.rs index 13ecfbb89448..8bfdfafa2fd4 100644 --- a/crates/polars-utils/src/idx_vec.rs +++ b/crates/polars-utils/src/idx_vec.rs @@ -71,6 +71,11 @@ impl UnitVec { self.capacity.get() } + #[inline(always)] + pub fn clear(&mut self) { + self.len = 0; + } + #[inline(always)] pub fn push(&mut self, idx: T) { if self.len == self.capacity.get() { diff --git a/py-polars/polars/functions/lazy.py b/py-polars/polars/functions/lazy.py index 8ba891d70a59..c78996e07753 100644 --- a/py-polars/polars/functions/lazy.py +++ b/py-polars/polars/functions/lazy.py @@ -1617,6 +1617,7 @@ def collect_all( comm_subplan_elim: bool = True, comm_subexpr_elim: bool = True, cluster_with_columns: bool = True, + collapse_joins: bool = True, streaming: bool = False, ) -> list[DataFrame]: """ @@ -1646,6 +1647,8 @@ def collect_all( Common subexpressions will be cached and reused. cluster_with_columns Combine sequential independent calls to with_columns + collapse_joins + Collapse a join and filters into a faster join streaming Process the query in batches to handle larger-than-memory data. If set to `False` (default), the entire query is processed in a single @@ -1671,6 +1674,7 @@ def collect_all( comm_subplan_elim = False comm_subexpr_elim = False cluster_with_columns = False + collapse_joins = False if streaming: issue_unstable_warning("Streaming mode is considered unstable.") @@ -1688,6 +1692,7 @@ def collect_all( comm_subplan_elim, comm_subexpr_elim, cluster_with_columns, + collapse_joins, streaming, _eager=False, new_streaming=False, @@ -1716,6 +1721,7 @@ def collect_all_async( comm_subplan_elim: bool = True, comm_subexpr_elim: bool = True, cluster_with_columns: bool = True, + collapse_joins: bool = True, streaming: bool = True, ) -> _GeventDataFrameResult[list[DataFrame]]: ... @@ -1734,6 +1740,7 @@ def collect_all_async( comm_subplan_elim: bool = True, comm_subexpr_elim: bool = True, cluster_with_columns: bool = True, + collapse_joins: bool = True, streaming: bool = False, ) -> Awaitable[list[DataFrame]]: ... @@ -1752,6 +1759,7 @@ def collect_all_async( comm_subplan_elim: bool = True, comm_subexpr_elim: bool = True, cluster_with_columns: bool = True, + collapse_joins: bool = True, streaming: bool = False, ) -> Awaitable[list[DataFrame]] | _GeventDataFrameResult[list[DataFrame]]: """ @@ -1792,6 +1800,8 @@ def collect_all_async( Common subexpressions will be cached and reused. cluster_with_columns Combine sequential independent calls to with_columns + collapse_joins + Collapse a join and filters into a faster join streaming Process the query in batches to handle larger-than-memory data. If set to `False` (default), the entire query is processed in a single @@ -1829,6 +1839,7 @@ def collect_all_async( comm_subplan_elim = False comm_subexpr_elim = False cluster_with_columns = False + collapse_joins = False if streaming: issue_unstable_warning("Streaming mode is considered unstable.") @@ -1846,6 +1857,7 @@ def collect_all_async( comm_subplan_elim, comm_subexpr_elim, cluster_with_columns, + collapse_joins, streaming, _eager=False, new_streaming=False, diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index b26160bc7e75..7b434a4147df 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -1018,6 +1018,7 @@ def explain( comm_subplan_elim: bool = True, comm_subexpr_elim: bool = True, cluster_with_columns: bool = True, + collapse_joins: bool = True, streaming: bool = False, tree_format: bool | None = None, ) -> str: @@ -1051,6 +1052,8 @@ def explain( Common subexpressions will be cached and reused. cluster_with_columns Combine sequential independent calls to with_columns + collapse_joins + Collapse a join and filters into a faster join streaming Run parts of the query in a streaming fashion (this is in an alpha state) @@ -1099,6 +1102,7 @@ def explain( comm_subplan_elim, comm_subexpr_elim, cluster_with_columns, + collapse_joins, streaming, _eager=False, new_streaming=False, @@ -1129,6 +1133,7 @@ def show_graph( comm_subplan_elim: bool = True, comm_subexpr_elim: bool = True, cluster_with_columns: bool = True, + collapse_joins: bool = True, streaming: bool = False, ) -> str | None: """ @@ -1165,6 +1170,8 @@ def show_graph( Common subexpressions will be cached and reused. cluster_with_columns Combine sequential independent calls to with_columns + collapse_joins + Collapse a join and filters into a faster join streaming Run parts of the query in a streaming fashion (this is in an alpha state) @@ -1190,6 +1197,7 @@ def show_graph( comm_subplan_elim, comm_subexpr_elim, cluster_with_columns, + collapse_joins, streaming, _eager=False, new_streaming=False, @@ -1626,6 +1634,7 @@ def profile( comm_subplan_elim: bool = True, comm_subexpr_elim: bool = True, cluster_with_columns: bool = True, + collapse_joins: bool = True, show_plot: bool = False, truncate_nodes: int = 0, figsize: tuple[int, int] = (18, 8), @@ -1660,6 +1669,8 @@ def profile( Common subexpressions will be cached and reused. cluster_with_columns Combine sequential independent calls to with_columns + collapse_joins + Collapse a join and filters into a faster join show_plot Show a gantt chart of the profiling result truncate_nodes @@ -1709,6 +1720,7 @@ def profile( comm_subplan_elim = False comm_subexpr_elim = False cluster_with_columns = False + collapse_joins = False ldf = self._ldf.optimization_toggle( type_coercion, @@ -1719,6 +1731,7 @@ def profile( comm_subplan_elim, comm_subexpr_elim, cluster_with_columns, + collapse_joins, streaming, _eager=False, new_streaming=False, @@ -1777,6 +1790,7 @@ def collect( comm_subplan_elim: bool = True, comm_subexpr_elim: bool = True, cluster_with_columns: bool = True, + collapse_joins: bool = True, no_optimization: bool = False, streaming: bool = False, engine: EngineType = "cpu", @@ -1796,6 +1810,7 @@ def collect( comm_subplan_elim: bool = True, comm_subexpr_elim: bool = True, cluster_with_columns: bool = True, + collapse_joins: bool = True, no_optimization: bool = False, streaming: bool = False, engine: EngineType = "cpu", @@ -1814,6 +1829,7 @@ def collect( comm_subplan_elim: bool = True, comm_subexpr_elim: bool = True, cluster_with_columns: bool = True, + collapse_joins: bool = True, no_optimization: bool = False, streaming: bool = False, engine: EngineType = "cpu", @@ -1845,6 +1861,8 @@ def collect( Common subexpressions will be cached and reused. cluster_with_columns Combine sequential independent calls to with_columns + collapse_joins + Collapse a join and filters into a faster join no_optimization Turn off (certain) optimizations. streaming @@ -1976,6 +1994,7 @@ def collect( comm_subplan_elim = False comm_subexpr_elim = False cluster_with_columns = False + collapse_joins = False if streaming: issue_unstable_warning("Streaming mode is considered unstable.") @@ -2004,6 +2023,7 @@ def collect( comm_subplan_elim, comm_subexpr_elim, cluster_with_columns, + collapse_joins, streaming, _eager, new_streaming, @@ -2046,6 +2066,7 @@ def collect_async( comm_subplan_elim: bool = True, comm_subexpr_elim: bool = True, cluster_with_columns: bool = True, + collapse_joins: bool = True, streaming: bool = True, ) -> _GeventDataFrameResult[DataFrame]: ... @@ -2063,6 +2084,7 @@ def collect_async( comm_subplan_elim: bool = True, comm_subexpr_elim: bool = True, cluster_with_columns: bool = True, + collapse_joins: bool = True, streaming: bool = True, ) -> Awaitable[DataFrame]: ... @@ -2079,6 +2101,7 @@ def collect_async( comm_subplan_elim: bool = True, comm_subexpr_elim: bool = True, cluster_with_columns: bool = True, + collapse_joins: bool = True, streaming: bool = False, ) -> Awaitable[DataFrame] | _GeventDataFrameResult[DataFrame]: """ @@ -2117,6 +2140,8 @@ def collect_async( Common subexpressions will be cached and reused. cluster_with_columns Combine sequential independent calls to with_columns + collapse_joins + Collapse a join and filters into a faster join streaming Process the query in batches to handle larger-than-memory data. If set to `False` (default), the entire query is processed in a single @@ -2182,6 +2207,7 @@ def collect_async( comm_subplan_elim = False comm_subexpr_elim = False cluster_with_columns = False + collapse_joins = False if streaming: issue_unstable_warning("Streaming mode is considered unstable.") @@ -2195,6 +2221,7 @@ def collect_async( comm_subplan_elim, comm_subexpr_elim, cluster_with_columns, + collapse_joins, streaming, _eager=False, new_streaming=False, @@ -2252,6 +2279,7 @@ def sink_parquet( projection_pushdown: bool = True, simplify_expression: bool = True, slice_pushdown: bool = True, + collapse_joins: bool = True, no_optimization: bool = False, ) -> None: """ @@ -2316,6 +2344,8 @@ def sink_parquet( Run simplify expressions optimization. slice_pushdown Slice pushdown optimization. + collapse_joins + Collapse a join and filters into a faster join no_optimization Turn off (certain) optimizations. @@ -2334,6 +2364,7 @@ def sink_parquet( projection_pushdown=projection_pushdown, simplify_expression=simplify_expression, slice_pushdown=slice_pushdown, + collapse_joins=collapse_joins, no_optimization=no_optimization, ) @@ -2376,6 +2407,7 @@ def sink_ipc( projection_pushdown: bool = True, simplify_expression: bool = True, slice_pushdown: bool = True, + collapse_joins: bool = True, no_optimization: bool = False, ) -> None: """ @@ -2407,6 +2439,8 @@ def sink_ipc( Run simplify expressions optimization. slice_pushdown Slice pushdown optimization. + collapse_joins + Collapse a join and filters into a faster join no_optimization Turn off (certain) optimizations. @@ -2425,6 +2459,7 @@ def sink_ipc( projection_pushdown=projection_pushdown, simplify_expression=simplify_expression, slice_pushdown=slice_pushdown, + collapse_joins=collapse_joins, no_optimization=no_optimization, ) @@ -2458,6 +2493,7 @@ def sink_csv( projection_pushdown: bool = True, simplify_expression: bool = True, slice_pushdown: bool = True, + collapse_joins: bool = True, no_optimization: bool = False, ) -> None: """ @@ -2537,6 +2573,8 @@ def sink_csv( Run simplify expressions optimization. slice_pushdown Slice pushdown optimization. + collapse_joins + Collapse a join and filters into a faster join no_optimization Turn off (certain) optimizations. @@ -2562,6 +2600,7 @@ def sink_csv( projection_pushdown=projection_pushdown, simplify_expression=simplify_expression, slice_pushdown=slice_pushdown, + collapse_joins=collapse_joins, no_optimization=no_optimization, ) @@ -2594,6 +2633,7 @@ def sink_ndjson( projection_pushdown: bool = True, simplify_expression: bool = True, slice_pushdown: bool = True, + collapse_joins: bool = True, no_optimization: bool = False, ) -> None: """ @@ -2622,6 +2662,8 @@ def sink_ndjson( Run simplify expressions optimization. slice_pushdown Slice pushdown optimization. + collapse_joins + Collapse a join and filters into a faster join no_optimization Turn off (certain) optimizations. @@ -2640,6 +2682,7 @@ def sink_ndjson( projection_pushdown=projection_pushdown, simplify_expression=simplify_expression, slice_pushdown=slice_pushdown, + collapse_joins=collapse_joins, no_optimization=no_optimization, ) @@ -2653,6 +2696,7 @@ def _set_sink_optimizations( projection_pushdown: bool = True, simplify_expression: bool = True, slice_pushdown: bool = True, + collapse_joins: bool = True, no_optimization: bool = False, ) -> PyLazyFrame: if no_optimization: @@ -2669,6 +2713,7 @@ def _set_sink_optimizations( comm_subplan_elim=False, comm_subexpr_elim=False, cluster_with_columns=False, + collapse_joins=collapse_joins, streaming=True, _eager=False, new_streaming=False, @@ -2692,6 +2737,7 @@ def fetch( comm_subplan_elim: bool = True, comm_subexpr_elim: bool = True, cluster_with_columns: bool = True, + collapse_joins: bool = True, streaming: bool = False, ) -> DataFrame: """ @@ -2725,6 +2771,7 @@ def fetch( comm_subplan_elim=comm_subplan_elim, comm_subexpr_elim=comm_subexpr_elim, cluster_with_columns=cluster_with_columns, + collapse_joins=collapse_joins, streaming=streaming, ) @@ -2741,6 +2788,7 @@ def _fetch( comm_subplan_elim: bool = True, comm_subexpr_elim: bool = True, cluster_with_columns: bool = True, + collapse_joins: bool = True, streaming: bool = False, ) -> DataFrame: """ @@ -2771,6 +2819,8 @@ def _fetch( Common subexpressions will be cached and reused. cluster_with_columns Combine sequential independent calls to with_columns + collapse_joins + Collapse a join and filters into a faster join streaming Run parts of the query in a streaming fashion (this is in an alpha state) @@ -2823,6 +2873,7 @@ def _fetch( comm_subplan_elim = False comm_subexpr_elim = False cluster_with_columns = False + collapse_joins = False if streaming: issue_unstable_warning("Streaming mode is considered unstable.") @@ -2836,6 +2887,7 @@ def _fetch( comm_subplan_elim, comm_subexpr_elim, cluster_with_columns, + collapse_joins, streaming, _eager=False, new_streaming=False, diff --git a/py-polars/tests/unit/lazyframe/test_optimizations.py b/py-polars/tests/unit/lazyframe/test_optimizations.py index 648bd3123787..e2ac38bd0424 100644 --- a/py-polars/tests/unit/lazyframe/test_optimizations.py +++ b/py-polars/tests/unit/lazyframe/test_optimizations.py @@ -1,3 +1,7 @@ +import itertools + +import pytest + import polars as pl from polars.testing import assert_frame_equal @@ -204,3 +208,122 @@ def test_drop_nulls_followed_by_count() -> None: ) assert "null_count" not in non_optimized_result_plan assert "drop_nulls" in non_optimized_result_plan + + +def test_collapse_joins() -> None: + a = pl.LazyFrame({"a": [1, 2, 3], "b": [2, 2, 2]}) + b = pl.LazyFrame({"x": [7, 1, 2]}) + + cross = a.join(b, how="cross") + + inner_join = cross.filter(pl.col.a == pl.col.x) + e = inner_join.explain() + assert "INNER JOIN" in e + assert "FILTER" not in e + assert_frame_equal(inner_join.collect(collapse_joins=False), inner_join.collect()) + + inner_join = cross.filter(pl.col.x == pl.col.a) + e = inner_join.explain() + assert "INNER JOIN" in e + assert "FILTER" not in e + assert_frame_equal( + inner_join.collect(collapse_joins=False), + inner_join.collect(), + check_row_order=False, + ) + + double_inner_join = cross.filter(pl.col.x == pl.col.a).filter(pl.col.x == pl.col.b) + e = double_inner_join.explain() + assert "INNER JOIN" in e + assert "FILTER" not in e + assert_frame_equal( + double_inner_join.collect(collapse_joins=False), + double_inner_join.collect(), + check_row_order=False, + ) + + dont_mix = cross.filter(pl.col.x + pl.col.a != 0) + e = dont_mix.explain() + assert "CROSS JOIN" in e + assert "FILTER" in e + assert_frame_equal( + dont_mix.collect(collapse_joins=False), + dont_mix.collect(), + check_row_order=False, + ) + + no_literals = cross.filter(pl.col.x == 2) + e = no_literals.explain() + assert "CROSS JOIN" in e + assert_frame_equal( + no_literals.collect(collapse_joins=False), + no_literals.collect(), + check_row_order=False, + ) + + iejoin = cross.filter(pl.col.x >= pl.col.a) + e = iejoin.explain() + assert "IEJOIN" in e + assert "CROSS JOIN" not in e + assert "FILTER" not in e + assert_frame_equal( + iejoin.collect(collapse_joins=False), + iejoin.collect(), + check_row_order=False, + ) + + iejoin = cross.filter(pl.col.x >= pl.col.a).filter(pl.col.x <= pl.col.b) + e = iejoin.explain() + assert "IEJOIN" in e + assert "CROSS JOIN" not in e + assert "FILTER" not in e + assert_frame_equal( + iejoin.collect(collapse_joins=False), iejoin.collect(), check_row_order=False + ) + + +@pytest.mark.slow +def test_collapse_joins_combinations() -> None: + # This just tests all possible combinations for expressions on a cross join. + + a = pl.LazyFrame({"a": [1, 2, 3], "x": [7, 2, 1]}) + b = pl.LazyFrame({"b": [2, 2, 2], "x": [7, 1, 3]}) + + cross = a.join(b, how="cross") + + exprs = [] + + for lhs in [pl.col.a, pl.col.b, pl.col.x, pl.lit(1), pl.col.a + pl.col.b]: + for rhs in [pl.col.a, pl.col.b, pl.col.x, pl.lit(1), pl.col.a * pl.col.x]: + for cmp in ["__eq__", "__ge__", "__lt__"]: + e = (getattr(lhs, cmp))(rhs) + exprs.append(e) + + for amount in range(3): + for merge in itertools.product(["__and__", "__or__"] * (amount - 1)): + for es in itertools.product(*([exprs] * amount)): + e = es[0] + for i in range(amount - 1): + e = (getattr(e, merge[i]))(es[i + 1]) + + # NOTE: We need to sort because the order of the cross-join & + # IE-join is unspecified. Therefore, this might not necessarily + # create the exact same dataframe. + optimized = cross.filter(e).sort(pl.all()).collect() + unoptimized = cross.filter(e).collect(collapse_joins=False) + + try: + assert_frame_equal(optimized, unoptimized, check_row_order=False) + except: + print(e) + print() + print("Optimized") + print(cross.filter(e).explain()) + print(optimized) + print() + print("Unoptimized") + print(cross.filter(e).explain(collapse_joins=False)) + print(unoptimized) + print() + + raise