Skip to content

Commit

Permalink
feat: Convert concat during IR conversion (#16016)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored May 2, 2024
1 parent f03e7e0 commit b7b3da6
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 92 deletions.
97 changes: 16 additions & 81 deletions crates/polars-lazy/src/dsl/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub(crate) fn concat_impl<L: AsRef<[LazyFrame]>>(
) -> PolarsResult<LazyFrame> {
let mut inputs = inputs.as_ref().to_vec();

let mut lf = std::mem::take(
let lf = std::mem::take(
inputs
.get_mut(0)
.ok_or_else(|| polars_err!(NoData: "empty container given"))?,
Expand All @@ -31,89 +31,24 @@ pub(crate) fn concat_impl<L: AsRef<[LazyFrame]>>(
..Default::default()
};

let lf = match &mut lf.logical_plan {
// reuse the same union
DslPlan::Union {
inputs: existing_inputs,
options: opts,
} if opts == &options => {
for lf in &mut inputs[1..] {
// ensure we enable file caching if any lf has it enabled
opt_state.file_caching |= lf.opt_state.file_caching;
let lp = std::mem::take(&mut lf.logical_plan);
existing_inputs.push(lp)
}
lf
},
_ => {
let mut lps = Vec::with_capacity(inputs.len());
lps.push(lf.logical_plan);

for lf in &mut inputs[1..] {
// ensure we enable file caching if any lf has it enabled
opt_state.file_caching |= lf.opt_state.file_caching;
let lp = std::mem::take(&mut lf.logical_plan);
lps.push(lp)
}
let mut lps = Vec::with_capacity(inputs.len());
lps.push(lf.logical_plan);

let lp = DslPlan::Union {
inputs: lps,
options,
};
let mut lf = LazyFrame::from(lp);
lf.opt_state = opt_state;
for lf in &mut inputs[1..] {
// ensure we enable file caching if any lf has it enabled
opt_state.file_caching |= lf.opt_state.file_caching;
let lp = std::mem::take(&mut lf.logical_plan);
lps.push(lp)
}

lf
},
let lp = DslPlan::Union {
inputs: lps,
options,
convert_supertypes,
};

if convert_supertypes {
let DslPlan::Union {
mut inputs,
options,
} = lf.logical_plan
else {
unreachable!()
};
// TODO! Make this properly lazy.
let mut schema = inputs[0].compute_schema()?.as_ref().clone();

let mut changed = false;
for input in inputs[1..].iter() {
changed |= schema.to_supertype(input.compute_schema()?.as_ref())?;
}

let mut placeholder = DslPlan::default();
if changed {
let mut exprs = vec![];
for input in &mut inputs {
std::mem::swap(input, &mut placeholder);
let input_schema = placeholder.compute_schema()?;

exprs.clear();
let to_cast = input_schema.iter().zip(schema.iter_dtypes()).flat_map(
|((left_name, left_type), st)| {
if left_type != st {
Some(col(left_name.as_ref()).cast(st.clone()))
} else {
None
}
},
);
exprs.extend(to_cast);
let mut lf = LazyFrame::from(placeholder);
if !exprs.is_empty() {
lf = lf.with_columns(exprs.as_slice());
}

placeholder = lf.logical_plan;
std::mem::swap(&mut placeholder, input);
}
}
Ok(LazyFrame::from(DslPlan::Union { inputs, options }))
} else {
Ok(lf)
}
let mut lf = LazyFrame::from(lp);
lf.opt_state = opt_state;
Ok(lf)
}

#[cfg(feature = "diagonal_concat")]
Expand Down
44 changes: 44 additions & 0 deletions crates/polars-plan/src/logical_plan/conversion/convert_utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use super::*;

pub(super) fn convert_st_union(
inputs: &mut [Node],
lp_arena: &mut Arena<IR>,
expr_arena: &mut Arena<AExpr>,
) -> PolarsResult<()> {
let mut schema = (**lp_arena.get(inputs[0]).schema(lp_arena)).clone();

let mut changed = false;
for input in inputs[1..].iter() {
let schema_other = lp_arena.get(*input).schema(lp_arena);
changed |= schema.to_supertype(schema_other.as_ref())?;
}

if changed {
for input in inputs {
let mut exprs = vec![];
let input_schema = lp_arena.get(*input).schema(lp_arena);

let to_cast = input_schema.iter().zip(schema.iter_dtypes()).flat_map(
|((left_name, left_type), st)| {
if left_type != st {
Some(col(left_name.as_ref()).cast(st.clone()))
} else {
None
}
},
);
exprs.extend(to_cast);

if !exprs.is_empty() {
let expr = to_expr_irs(exprs, expr_arena);
let lp = IRBuilder::new(*input, expr_arena, lp_arena)
.with_columns(expr, Default::default())
.build();

let node = lp_arena.add(lp);
*input = node
}
}
}
Ok(())
}
Original file line number Diff line number Diff line change
Expand Up @@ -146,12 +146,21 @@ pub fn to_alp_impl(
options,
predicate: None,
},
DslPlan::Union { inputs, options } => {
let inputs = inputs
DslPlan::Union {
inputs,
options,
convert_supertypes,
} => {
let mut inputs = inputs
.into_iter()
.map(|lp| to_alp_impl(lp, expr_arena, lp_arena, convert))
.collect::<PolarsResult<_>>()
.collect::<PolarsResult<Vec<_>>>()
.map_err(|e| e.context(failed_input!(vertical concat)))?;

if convert_supertypes {
convert_utils::convert_st_union(&mut inputs, lp_arena, expr_arena)
.map_err(|e| e.context(failed_input!(vertical concat)))?;
}
IR::Union { inputs, options }
},
DslPlan::HConcat {
Expand Down
15 changes: 10 additions & 5 deletions crates/polars-plan/src/logical_plan/conversion/mod.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
mod dsl_plan_to_ir_plan;
mod expr_to_expr_ir;
mod convert_utils;
mod dsl_to_ir;
mod expr_to_ir;
mod ir_to_dsl;
#[cfg(any(feature = "ipc", feature = "parquet", feature = "csv"))]
mod scans;
mod stack_opt;

use std::borrow::Cow;

pub use dsl_plan_to_ir_plan::*;
pub use expr_to_expr_ir::*;
pub use dsl_to_ir::*;
pub use expr_to_ir::*;
pub use ir_to_dsl::*;
use polars_core::prelude::*;
use polars_utils::vec::ConvertVec;
Expand Down Expand Up @@ -58,7 +59,11 @@ impl IR {
.into_iter()
.map(|node| convert_to_lp(node, lp_arena))
.collect();
DslPlan::Union { inputs, options }
DslPlan::Union {
inputs,
options,
convert_supertypes: false,
}
},
IR::HConcat {
inputs,
Expand Down
4 changes: 3 additions & 1 deletion crates/polars-plan/src/logical_plan/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,9 @@ impl DslPlan {
options.n_rows,
)
},
Union { inputs, options } => {
Union {
inputs, options, ..
} => {
let mut name = String::new();
let name = if let Some(slice) = options.slice {
write!(name, "SLICED UNION: {slice:?}")?;
Expand Down
3 changes: 2 additions & 1 deletion crates/polars-plan/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ pub enum DslPlan {
Union {
inputs: Vec<DslPlan>,
options: UnionOptions,
convert_supertypes: bool,
},
/// Horizontal concatenation of multiple plans
HConcat {
Expand Down Expand Up @@ -196,7 +197,7 @@ impl Clone for DslPlan {
Self::Sort {input,by_column, slice, sort_options } => Self::Sort { input: input.clone(), by_column: by_column.clone(), slice: slice.clone(), sort_options: sort_options.clone() },
Self::Slice { input, offset, len } => Self::Slice { input: input.clone(), offset: offset.clone(), len: len.clone() },
Self::MapFunction { input, function } => Self::MapFunction { input: input.clone(), function: function.clone() },
Self::Union { inputs, options } => Self::Union { inputs: inputs.clone(), options: options.clone() },
Self::Union { inputs, options, convert_supertypes } => Self::Union { inputs: inputs.clone(), options: options.clone(), convert_supertypes: *convert_supertypes },
Self::HConcat { inputs, schema, options } => Self::HConcat { inputs: inputs.clone(), schema: schema.clone(), options: options.clone() },
Self::ExtContext { input, contexts, } => Self::ExtContext { input: input.clone(), contexts: contexts.clone() },
Self::Sink { input, payload } => Self::Sink { input: input.clone(), payload: payload.clone() },
Expand Down
7 changes: 6 additions & 1 deletion crates/polars-plan/src/logical_plan/tree_format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,12 @@ impl<'a> TreeFmtNode<'a> {
vec![]
},
),
NL(h, Union { inputs, options }) => ND(
NL(
h,
Union {
inputs, options, ..
},
) => ND(
wh(
h,
&(if let Some(slice) = options.slice {
Expand Down

0 comments on commit b7b3da6

Please sign in to comment.