Skip to content

Commit

Permalink
all predicates
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Sep 6, 2024
1 parent c1da0c6 commit dbc0460
Show file tree
Hide file tree
Showing 4 changed files with 195 additions and 131 deletions.
139 changes: 37 additions & 102 deletions crates/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2107,106 +2107,41 @@ impl JoinBuilder {
}


// // Finish with join predicates
// pub fn join_where(self, predicates: Vec<Expr>) -> PolarsResult<LazyFrame> {
// let to_inner = Arc::unwrap_or_clone;
//
// let mut ie_left_on = vec![];
// let mut ie_right_on = vec![];
// let mut ie_op = vec![];
//
// let mut eq_left_on = vec![];
// let mut eq_right_on = vec![];
//
// let mut remaining_preds = vec![];
//
// fn to_inequality_operator(op: &Operator) -> Option<InequalityOperator> {
// match op {
// Operator::Lt => Some(InequalityOperator::Lt),
// Operator::LtEq => Some(InequalityOperator::LtEq),
// Operator::Gt => Some(InequalityOperator::Gt),
// Operator::GtEq => Some(InequalityOperator::GtEq),
// _ => None,
// }
// }
//
// for pred in predicates.into_iter() {
// let Expr::BinaryExpr {left, op, right} = pred else { polars_bail!(InvalidOperation: "can only join on binary expressions") };
// polars_ensure!(op.is_comparison(), InvalidOperation: "expected comparison in join predicate");
//
// if let Some(ie_op_) = to_inequality_operator(&op) {
// ie_left_on.push(to_inner(left));
// ie_right_on.push(to_inner(right));
// ie_op.push(ie_op_)
// } else if matches!(op, Operator::Eq) {
// eq_left_on.push(to_inner(left));
// eq_right_on.push(to_inner(right));
// } else {
//
// remaining_preds.push(pred);
// }
// }
//
//
// fn parse_ie_join_expressions(
// expressions: Vec<Expr>,
// ) -> PolarsResult<(Vec<Expr>, Vec<InequalityOperator>, Vec<Expr>)> {
//
// let mut left_on = Vec::with_capacity(2);
// let mut operators = Vec::with_capacity(2);
// let mut right_on = Vec::with_capacity(2);
//
// for expression in expressions.into_iter() {
// let (left, op, right) = parse_inequality_expression(expression)?;
// left_on.push(left);
// operators.push(op);
// right_on.push(right);
// }
//
// Ok((left_on, operators, right_on))
// }
//
// fn parse_inequality_expression(expression: Expr) -> PolarsResult<(Expr, InequalityOperator, Expr)> {
// fn to_inequality_operator(op: &Operator) -> PolarsResult<InequalityOperator> {
// match op {
// Operator::Lt => Ok(InequalityOperator::Lt),
// Operator::LtEq => Ok(InequalityOperator::LtEq),
// Operator::Gt => Ok(InequalityOperator::Gt),
// Operator::GtEq => Ok(InequalityOperator::GtEq),
// _ => Err(PyValueError::new_err(format!(
// "expected an inequality operator in join inequality, got '{}'",
// op
// ))),
// }
// }
//
// match expression.inner {
// Expr::BinaryExpr { left, op, right } => {
// let inequality_op = to_inequality_operator(&op)?;
// Ok(((*left).clone(), inequality_op, (*right).clone()))
// },
// _ => Err(PyValueError::new_err(
// "expected a binary expression for a join inequality",
// )),
// }
// }
//
// let mut opt_state = self.lf.opt_state;
// let other = self.other.expect("with not set");
//
// // If any of the nodes reads from files we must activate this plan as well.
// if other.opt_state.contains(OptFlags::FILE_CACHING) {
// opt_state |= OptFlags::FILE_CACHING;
// }
//
// let args = JoinArgs {
// how: self.how,
// validation: self.validation,
// suffix: self.suffix,
// slice: None,
// join_nulls: self.join_nulls,
// coalesce: self.coalesce,
// };
//
// }
/// Finish the builder with a predicate-based join.
///
/// The produced `DslPlan::Join` carries `predicates`; its `left_on` and
/// `right_on` key lists are left at their defaults.
///
/// # Panics
///
/// Panics with "with not set" when no right-hand frame was set on the builder.
pub fn join_where(self, predicates: Vec<Expr>) -> LazyFrame {
    let right = self.other.expect("with not set");

    // If any of the nodes reads from files we must activate this plan as well.
    let mut flags = self.lf.opt_state;
    if right.opt_state.contains(OptFlags::FILE_CACHING) {
        flags |= OptFlags::FILE_CACHING;
    }

    // Join arguments are built in place; everything not set explicitly keeps its default.
    let options = JoinOptions {
        allow_parallel: self.allow_parallel,
        force_parallel: self.force_parallel,
        args: JoinArgs {
            how: self.how,
            validation: self.validation,
            suffix: self.suffix,
            slice: None,
            join_nulls: self.join_nulls,
            coalesce: self.coalesce,
        },
        ..Default::default()
    };

    let plan = DslPlan::Join {
        input_left: Arc::new(self.lf.logical_plan),
        input_right: Arc::new(right.logical_plan),
        left_on: Default::default(),
        right_on: Default::default(),
        predicates,
        options: Arc::from(options),
    };

    LazyFrame::from_logical_plan(plan, flags)
}
}
22 changes: 11 additions & 11 deletions crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,21 +85,22 @@ pub(super) struct DslConversionContext<'a> {
pub(super) opt_flags: &'a mut OptFlags,
}

/// Adds `lp` to the logical-plan arena, runs type coercion on the new node and returns it.
///
/// # Errors
///
/// Any error raised by the coercion step is returned with the context
/// `'{name}' failed` attached.
pub(super) fn run_conversion(
    lp: IR,
    ctxt: &mut DslConversionContext,
    name: &str,
) -> PolarsResult<Node> {
    let node = ctxt.lp_arena.add(lp);

    let coerced = ctxt
        .conversion_optimizer
        .coerce_types(ctxt.expr_arena, ctxt.lp_arena, node);

    match coerced {
        Ok(_) => Ok(node),
        Err(err) => Err(err.context(format!("'{name}' failed").into())),
    }
}

/// converts LogicalPlan to IR
/// it adds expressions & lps to the respective arenas as it traverses the plan
/// finally it returns the top node of the logical plan
#[recursive]
pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult<Node> {
let owned = Arc::unwrap_or_clone;

fn run_conversion(lp: IR, ctxt: &mut DslConversionContext, name: &str) -> PolarsResult<Node> {
let lp_node = ctxt.lp_arena.add(lp);
ctxt.conversion_optimizer
.coerce_types(ctxt.expr_arena, ctxt.lp_arena, lp_node)
.map_err(|e| e.context(format!("'{name}' failed").into()))?;

Ok(lp_node)
}

let v = match lp {
DslPlan::Scan {
Expand Down Expand Up @@ -541,10 +542,9 @@ pub fn to_alp_impl(lp: DslPlan, ctxt: &mut DslConversionContext) -> PolarsResult
left_on,
right_on,
predicates,
mut options,
options,
} => {
let ir = join::resolve_join(input_left, input_right, left_on, right_on, predicates, options, ctxt)?;
return run_conversion(ir, ctxt, "join");
return join::resolve_join(input_left, input_right, left_on, right_on, predicates, options, ctxt)
},
DslPlan::HStack {
input,
Expand Down
151 changes: 142 additions & 9 deletions crates/polars-plan/src/plans/conversion/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,18 @@ use crate::plans::AExpr;
use crate::prelude::FunctionOptions;
use super::*;

fn check_join_keys(keys: &[Expr]) -> PolarsResult<()> {
for e in keys {
if has_expr(e, |e| matches!(e, Expr::Alias(_, _))) {
polars_bail!(
InvalidOperation:
"'alias' is not allowed in a join key, use 'with_columns' first",
)
}
}
Ok(())

}
pub fn resolve_join(
input_left: Arc<DslPlan>,
input_right: Arc<DslPlan>,
Expand All @@ -13,19 +25,21 @@ pub fn resolve_join(
predicates: Vec<Expr>,
mut options: Arc<JoinOptions>,
ctxt: &mut DslConversionContext
) -> PolarsResult<IR> {
) -> PolarsResult<Node> {
if !predicates.is_empty() {
debug_assert!(left_on.is_empty() && right_on.is_empty());
return resolve_join_where(input_left, input_right, predicates, options, ctxt)
}

let owned = Arc::unwrap_or_clone;
if matches!(options.args.how, JoinType::Cross) {
polars_ensure!(left_on.len() + right_on.len() == 0, InvalidOperation: "a 'cross' join doesn't expect any join keys");
} else {
check_join_keys(&left_on)?;
check_join_keys(&right_on)?;

let mut turn_off_coalesce = false;
for e in left_on.iter().chain(right_on.iter()) {
if has_expr(e, |e| matches!(e, Expr::Alias(_, _))) {
polars_bail!(
ComputeError:
"'alias' is not allowed in a join key, use 'with_columns' first",
)
}
// Any expression that is not a simple column expression will turn off coalescing.
turn_off_coalesce |= has_expr(e, |e| !matches!(e, Expr::Column(_)));
}
Expand All @@ -41,7 +55,7 @@ pub fn resolve_join(

polars_ensure!(
left_on.len() == right_on.len(),
ComputeError:
InvalidOperation:
format!(
"the number of columns given as join key (left: {}, right:{}) should be equal",
left_on.len(),
Expand Down Expand Up @@ -96,6 +110,125 @@ pub fn resolve_join(
right_on,
options,
};
Ok(lp)
run_conversion(lp, ctxt, "join")
}

impl From<InequalityOperator> for Operator {
    /// Maps an inequality-join operator onto the expression operator of the same name.
    fn from(ie_op: InequalityOperator) -> Self {
        match ie_op {
            InequalityOperator::Lt => Self::Lt,
            InequalityOperator::LtEq => Self::LtEq,
            InequalityOperator::Gt => Self::Gt,
            InequalityOperator::GtEq => Self::GtEq,
        }
    }
}

/// Lowers a predicate-based join (`join_where`) into a join node followed by filter nodes.
///
/// Every predicate must be a binary comparison expression. The predicates are classified
/// and the join strategy is picked from the result:
///
/// * at least one equality predicate: an equi-join on the equality operands; all other
///   predicates are applied afterwards as filters.
/// * no equality predicate and exactly two inequality predicates: an `IEJoin` driven by
///   those two; the rest become filters.
/// * otherwise: a cross join, with *every* predicate applied afterwards as a filter.
///
/// Returns the node of the last filter, or the join node itself when nothing is left to filter.
///
/// # Errors
///
/// * `InvalidOperation` when a predicate contains an alias, is not a binary expression, or
///   does not use a comparison operator.
/// * `ColumnNotFound` when a column in the right operand of a filter predicate is found in
///   the right input's schema neither under its own name nor under its suffixed name.
/// * Any error produced by `resolve_join` or by the expression conversion.
fn resolve_join_where(
    input_left: Arc<DslPlan>,
    input_right: Arc<DslPlan>,
    predicates: Vec<Expr>,
    mut options: Arc<JoinOptions>,
    ctxt: &mut DslConversionContext,
) -> PolarsResult<Node> {
    check_join_keys(&predicates)?;

    // Takes the expression out of the `Arc` without cloning when it is not shared.
    let owned = Arc::unwrap_or_clone;

    let mut ie_left_on = vec![];
    let mut ie_right_on = vec![];
    let mut ie_op = vec![];

    let mut eq_left_on = vec![];
    let mut eq_right_on = vec![];

    let mut remaining_preds = vec![];

    fn to_inequality_operator(op: &Operator) -> Option<InequalityOperator> {
        match op {
            Operator::Lt => Some(InequalityOperator::Lt),
            Operator::LtEq => Some(InequalityOperator::LtEq),
            Operator::Gt => Some(InequalityOperator::Gt),
            Operator::GtEq => Some(InequalityOperator::GtEq),
            _ => None,
        }
    }

    // Re-assembles inequality predicates that were split into (left, operator, right) parts,
    // so they can be applied as ordinary filters.
    fn rebuild_inequalities(
        left_on: Vec<Expr>,
        ops: Vec<InequalityOperator>,
        right_on: Vec<Expr>,
    ) -> impl Iterator<Item = Expr> {
        left_on
            .into_iter()
            .zip(ops)
            .zip(right_on)
            .map(|((l, op), r)| Expr::BinaryExpr {
                left: Arc::from(l),
                op: op.into(),
                right: Arc::from(r),
            })
    }

    for pred in predicates {
        let Expr::BinaryExpr { left, op, right } = pred else {
            polars_bail!(InvalidOperation: "can only join on binary expressions")
        };
        polars_ensure!(op.is_comparison(), InvalidOperation: "expected comparison in join predicate");

        if let Some(ie_op_) = to_inequality_operator(&op) {
            // We already have an IEjoin or an Inner join, push to remaining
            if ie_op.len() >= 2 || !eq_right_on.is_empty() {
                remaining_preds.push(Expr::BinaryExpr { left, op, right })
            } else {
                ie_left_on.push(owned(left));
                ie_right_on.push(owned(right));
                ie_op.push(ie_op_)
            }
        } else if matches!(op, Operator::Eq) {
            eq_left_on.push(owned(left));
            eq_right_on.push(owned(right));
        } else {
            // A comparison that is neither an equality nor an inequality handled above.
            remaining_preds.push(Expr::BinaryExpr { left, op, right });
        }
    }

    let join_node = if !eq_left_on.is_empty() {
        // NOTE(review): the equi-join keeps `options.args.how` exactly as the caller passed
        // it — confirm that this is the join type intended for `join_where`.
        let join_node = resolve_join(
            input_left,
            input_right,
            eq_left_on,
            eq_right_on,
            vec![],
            options.clone(),
            ctxt,
        )?;

        // Inequalities collected before the first equality was seen become filters.
        remaining_preds.extend(rebuild_inequalities(ie_left_on, ie_op, ie_right_on));
        join_node
    } else if ie_right_on.len() == 2 {
        let opts = Arc::make_mut(&mut options);
        opts.args.how = JoinType::IEJoin(IEJoinOptions {
            operator1: ie_op[0],
            operator2: ie_op[1],
        });

        resolve_join(
            input_left,
            input_right,
            ie_left_on,
            ie_right_on,
            vec![],
            options.clone(),
            ctxt,
        )?
    } else {
        // Fewer than two inequalities cannot drive an IEJoin. Any inequality that was
        // collected must still be enforced, so it is re-queued as a filter on top of the
        // cross join instead of being dropped.
        remaining_preds.extend(rebuild_inequalities(ie_left_on, ie_op, ie_right_on));

        let opts = Arc::make_mut(&mut options);
        opts.args.how = JoinType::Cross;

        resolve_join(
            input_left,
            input_right,
            vec![],
            vec![],
            vec![],
            options.clone(),
            ctxt,
        )?
    };

    let IR::Join { input_right, .. } = ctxt.lp_arena.get(join_node) else {
        unreachable!()
    };
    let schema_right = ctxt
        .lp_arena
        .get(*input_right)
        .schema(ctxt.lp_arena)
        .into_owned();

    let suffix = options.args.suffix();

    let mut last_node = join_node;

    // Ensure that the predicates use the proper suffix
    for e in remaining_preds {
        let predicate = to_expr_ir_ignore_alias(e, ctxt.expr_arena)?;
        let AExpr::BinaryExpr { mut right, .. } = *ctxt.expr_arena.get(predicate.node()) else {
            unreachable!()
        };

        let original_right = right;

        // A leaf column of the right operand that is missing from the right input's schema
        // is looked up under its suffixed name instead, and renamed in the expression.
        for name in aexpr_to_leaf_names(right, ctxt.expr_arena) {
            if !schema_right.contains(name.as_str()) {
                let new_name = _join_suffix_name(name.as_str(), suffix.as_str());
                polars_ensure!(schema_right.contains(new_name.as_str()), ColumnNotFound: "could not find column {name} in the right table during join operation");

                right = rename_matching_aexpr_leaf_names(right, ctxt.expr_arena, name.as_str(), new_name);
            }
        }
        // Presumably this makes the node referenced by the predicate hold the renamed
        // expression — TODO confirm against the arena's `swap` semantics.
        ctxt.expr_arena.swap(right, original_right);

        // Stack one filter per predicate on top of the join.
        let ir = IR::Filter {
            input: last_node,
            predicate,
        };
        last_node = ctxt.lp_arena.add(ir);
    }
    Ok(last_node)
}
14 changes: 5 additions & 9 deletions crates/polars-python/src/lazyframe/general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -969,21 +969,17 @@ impl PyLazyFrame {
.into())
}

fn join_where(&self, other: Self, on: Vec<PyExpr>, suffix: String) -> PyResult<Self> {
fn join_where(&self, other: Self, predicates: Vec<PyExpr>, suffix: String) -> PyResult<Self> {
let ldf = self.ldf.clone();
let other = other.ldf;
let (left_on, operators, right_on) = parse_ie_join_expressions(on)?;

let predicates = predicates.to_exprs();

Ok(ldf
.join_builder()
.with(other)
.left_on(left_on)
.right_on(right_on)
.how(JoinType::IEJoin(IEJoinOptions {
operator1: operators[0],
operator2: operators[1],
}))
.suffix(suffix)
.finish()
.join_where(predicates)
.into())
}

Expand Down

0 comments on commit dbc0460

Please sign in to comment.