From b7b3da60f670671afb91a40f6f312604d9a4b406 Mon Sep 17 00:00:00 2001 From: Ritchie Vink Date: Thu, 2 May 2024 15:19:04 +0200 Subject: [PATCH] feat: Convert concat during IR conversion (#16016) --- crates/polars-lazy/src/dsl/functions.rs | 97 +++---------------- .../logical_plan/conversion/convert_utils.rs | 44 +++++++++ .../{dsl_plan_to_ir_plan.rs => dsl_to_ir.rs} | 15 ++- .../{expr_to_expr_ir.rs => expr_to_ir.rs} | 0 .../src/logical_plan/conversion/mod.rs | 15 ++- crates/polars-plan/src/logical_plan/format.rs | 4 +- crates/polars-plan/src/logical_plan/mod.rs | 3 +- .../src/logical_plan/tree_format.rs | 7 +- 8 files changed, 93 insertions(+), 92 deletions(-) create mode 100644 crates/polars-plan/src/logical_plan/conversion/convert_utils.rs rename crates/polars-plan/src/logical_plan/conversion/{dsl_plan_to_ir_plan.rs => dsl_to_ir.rs} (98%) rename crates/polars-plan/src/logical_plan/conversion/{expr_to_expr_ir.rs => expr_to_ir.rs} (100%) diff --git a/crates/polars-lazy/src/dsl/functions.rs b/crates/polars-lazy/src/dsl/functions.rs index 188f4a78ab84..7d401ea76334 100644 --- a/crates/polars-lazy/src/dsl/functions.rs +++ b/crates/polars-lazy/src/dsl/functions.rs @@ -17,7 +17,7 @@ pub(crate) fn concat_impl>( ) -> PolarsResult { let mut inputs = inputs.as_ref().to_vec(); - let mut lf = std::mem::take( + let lf = std::mem::take( inputs .get_mut(0) .ok_or_else(|| polars_err!(NoData: "empty container given"))?, @@ -31,89 +31,24 @@ pub(crate) fn concat_impl>( ..Default::default() }; - let lf = match &mut lf.logical_plan { - // reuse the same union - DslPlan::Union { - inputs: existing_inputs, - options: opts, - } if opts == &options => { - for lf in &mut inputs[1..] { - // ensure we enable file caching if any lf has it enabled - opt_state.file_caching |= lf.opt_state.file_caching; - let lp = std::mem::take(&mut lf.logical_plan); - existing_inputs.push(lp) - } - lf - }, - _ => { - let mut lps = Vec::with_capacity(inputs.len()); - lps.push(lf.logical_plan); - - for lf in &mut inputs[1..] { - // ensure we enable file caching if any lf has it enabled - opt_state.file_caching |= lf.opt_state.file_caching; - let lp = std::mem::take(&mut lf.logical_plan); - lps.push(lp) - } + let mut lps = Vec::with_capacity(inputs.len()); + lps.push(lf.logical_plan); - let lp = DslPlan::Union { - inputs: lps, - options, - }; - let mut lf = LazyFrame::from(lp); - lf.opt_state = opt_state; + for lf in &mut inputs[1..] { + // ensure we enable file caching if any lf has it enabled + opt_state.file_caching |= lf.opt_state.file_caching; + let lp = std::mem::take(&mut lf.logical_plan); + lps.push(lp) + } - lf - }, + let lp = DslPlan::Union { + inputs: lps, + options, + convert_supertypes, }; - - if convert_supertypes { - let DslPlan::Union { - mut inputs, - options, - } = lf.logical_plan - else { - unreachable!() - }; - // TODO! Make this properly lazy. - let mut schema = inputs[0].compute_schema()?.as_ref().clone(); - - let mut changed = false; - for input in inputs[1..].iter() { - changed |= schema.to_supertype(input.compute_schema()?.as_ref())?; - } - - let mut placeholder = DslPlan::default(); - if changed { - let mut exprs = vec![]; - for input in &mut inputs { - std::mem::swap(input, &mut placeholder); - let input_schema = placeholder.compute_schema()?; - - exprs.clear(); - let to_cast = input_schema.iter().zip(schema.iter_dtypes()).flat_map( - |((left_name, left_type), st)| { - if left_type != st { - Some(col(left_name.as_ref()).cast(st.clone())) - } else { - None - } - }, - ); - exprs.extend(to_cast); - let mut lf = LazyFrame::from(placeholder); - if !exprs.is_empty() { - lf = lf.with_columns(exprs.as_slice()); - } - - placeholder = lf.logical_plan; - std::mem::swap(&mut placeholder, input); - } - } - Ok(LazyFrame::from(DslPlan::Union { inputs, options })) - } else { - Ok(lf) - } + let mut lf = LazyFrame::from(lp); + lf.opt_state = opt_state; + Ok(lf) } #[cfg(feature = "diagonal_concat")] diff --git a/crates/polars-plan/src/logical_plan/conversion/convert_utils.rs b/crates/polars-plan/src/logical_plan/conversion/convert_utils.rs new file mode 100644 index 000000000000..db7c591d16c6 --- /dev/null +++ b/crates/polars-plan/src/logical_plan/conversion/convert_utils.rs @@ -0,0 +1,44 @@ +use super::*; + +pub(super) fn convert_st_union( + inputs: &mut [Node], + lp_arena: &mut Arena, + expr_arena: &mut Arena, +) -> PolarsResult<()> { + let mut schema = (**lp_arena.get(inputs[0]).schema(lp_arena)).clone(); + + let mut changed = false; + for input in inputs[1..].iter() { + let schema_other = lp_arena.get(*input).schema(lp_arena); + changed |= schema.to_supertype(schema_other.as_ref())?; + } + + if changed { + for input in inputs { + let mut exprs = vec![]; + let input_schema = lp_arena.get(*input).schema(lp_arena); + + let to_cast = input_schema.iter().zip(schema.iter_dtypes()).flat_map( + |((left_name, left_type), st)| { + if left_type != st { + Some(col(left_name.as_ref()).cast(st.clone())) + } else { + None + } + }, + ); + exprs.extend(to_cast); + + if !exprs.is_empty() { + let expr = to_expr_irs(exprs, expr_arena); + let lp = IRBuilder::new(*input, expr_arena, lp_arena) + .with_columns(expr, Default::default()) + .build(); + + let node = lp_arena.add(lp); + *input = node + } + } + } + Ok(()) +} diff --git a/crates/polars-plan/src/logical_plan/conversion/dsl_plan_to_ir_plan.rs b/crates/polars-plan/src/logical_plan/conversion/dsl_to_ir.rs similarity index 98% rename from crates/polars-plan/src/logical_plan/conversion/dsl_plan_to_ir_plan.rs rename to crates/polars-plan/src/logical_plan/conversion/dsl_to_ir.rs index c0b2f3f3f571..69584815feb2 100644 --- a/crates/polars-plan/src/logical_plan/conversion/dsl_plan_to_ir_plan.rs +++ b/crates/polars-plan/src/logical_plan/conversion/dsl_to_ir.rs @@ -146,12 +146,21 @@ pub fn to_alp_impl( options, predicate: None, }, - DslPlan::Union { inputs, options } => { - let inputs = inputs + DslPlan::Union { + inputs, + options, + convert_supertypes, + } => { + let mut inputs = inputs .into_iter() .map(|lp| to_alp_impl(lp, expr_arena, lp_arena, convert)) - .collect::>() + .collect::>>() .map_err(|e| e.context(failed_input!(vertical concat)))?; + + if convert_supertypes { + convert_utils::convert_st_union(&mut inputs, lp_arena, expr_arena) + .map_err(|e| e.context(failed_input!(vertical concat)))?; + } IR::Union { inputs, options } }, DslPlan::HConcat { diff --git a/crates/polars-plan/src/logical_plan/conversion/expr_to_expr_ir.rs b/crates/polars-plan/src/logical_plan/conversion/expr_to_ir.rs similarity index 100% rename from crates/polars-plan/src/logical_plan/conversion/expr_to_expr_ir.rs rename to crates/polars-plan/src/logical_plan/conversion/expr_to_ir.rs diff --git a/crates/polars-plan/src/logical_plan/conversion/mod.rs b/crates/polars-plan/src/logical_plan/conversion/mod.rs index 0c451394be4d..f5480f37e78f 100644 --- a/crates/polars-plan/src/logical_plan/conversion/mod.rs +++ b/crates/polars-plan/src/logical_plan/conversion/mod.rs @@ -1,5 +1,6 @@ -mod dsl_plan_to_ir_plan; -mod expr_to_expr_ir; +mod convert_utils; +mod dsl_to_ir; +mod expr_to_ir; mod ir_to_dsl; #[cfg(any(feature = "ipc", feature = "parquet", feature = "csv"))] mod scans; @@ -7,8 +8,8 @@ mod stack_opt; use std::borrow::Cow; -pub use dsl_plan_to_ir_plan::*; -pub use expr_to_expr_ir::*; +pub use dsl_to_ir::*; +pub use expr_to_ir::*; pub use ir_to_dsl::*; use polars_core::prelude::*; use polars_utils::vec::ConvertVec; @@ -58,7 +59,11 @@ impl IR { .into_iter() .map(|node| convert_to_lp(node, lp_arena)) .collect(); - DslPlan::Union { inputs, options } + DslPlan::Union { + inputs, + options, + convert_supertypes: false, + } }, IR::HConcat { inputs, diff --git a/crates/polars-plan/src/logical_plan/format.rs b/crates/polars-plan/src/logical_plan/format.rs index 7b0930a3b8e9..4c8db461b264 100644 --- a/crates/polars-plan/src/logical_plan/format.rs +++ b/crates/polars-plan/src/logical_plan/format.rs @@ -81,7 +81,9 @@ impl DslPlan { options.n_rows, ) }, - Union { inputs, options } => { + Union { + inputs, options, .. + } => { let mut name = String::new(); let name = if let Some(slice) = options.slice { write!(name, "SLICED UNION: {slice:?}")?; diff --git a/crates/polars-plan/src/logical_plan/mod.rs b/crates/polars-plan/src/logical_plan/mod.rs index f89ff0ab1eab..2f5a8891e3eb 100644 --- a/crates/polars-plan/src/logical_plan/mod.rs +++ b/crates/polars-plan/src/logical_plan/mod.rs @@ -156,6 +156,7 @@ pub enum DslPlan { Union { inputs: Vec, options: UnionOptions, + convert_supertypes: bool, }, /// Horizontal concatenation of multiple plans HConcat { @@ -196,7 +197,7 @@ impl Clone for DslPlan { Self::Sort {input,by_column, slice, sort_options } => Self::Sort { input: input.clone(), by_column: by_column.clone(), slice: slice.clone(), sort_options: sort_options.clone() }, Self::Slice { input, offset, len } => Self::Slice { input: input.clone(), offset: offset.clone(), len: len.clone() }, Self::MapFunction { input, function } => Self::MapFunction { input: input.clone(), function: function.clone() }, - Self::Union { inputs, options } => Self::Union { inputs: inputs.clone(), options: options.clone() }, + Self::Union { inputs, options, convert_supertypes } => Self::Union { inputs: inputs.clone(), options: options.clone(), convert_supertypes: *convert_supertypes }, Self::HConcat { inputs, schema, options } => Self::HConcat { inputs: inputs.clone(), schema: schema.clone(), options: options.clone() }, Self::ExtContext { input, contexts, } => Self::ExtContext { input: input.clone(), contexts: contexts.clone() }, Self::Sink { input, payload } => Self::Sink { input: input.clone(), payload: payload.clone() }, diff --git a/crates/polars-plan/src/logical_plan/tree_format.rs b/crates/polars-plan/src/logical_plan/tree_format.rs index f64c4dfc3f61..a78ece11e493 100644 --- a/crates/polars-plan/src/logical_plan/tree_format.rs +++ b/crates/polars-plan/src/logical_plan/tree_format.rs @@ -163,7 +163,12 @@ impl<'a> TreeFmtNode<'a> { vec![] }, ), - NL(h, Union { inputs, options }) => ND( + NL( + h, + Union { + inputs, options, .. + }, + ) => ND( wh( h, &(if let Some(slice) = options.slice {