Skip to content

Commit

Permalink
refactor: Streamline internal SQL join condition processing (#19658)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Nov 6, 2024
1 parent 55625a1 commit a186f12
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 67 deletions.
96 changes: 36 additions & 60 deletions crates/polars-sql/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1461,32 +1461,31 @@ fn process_join_on(
tbl_left: &TableInfo,
tbl_right: &TableInfo,
) -> PolarsResult<(Vec<Expr>, Vec<Expr>)> {
if let SQLExpr::BinaryOp { left, op, right } = expression {
match *op {
BinaryOperator::Eq => match (left.as_ref(), right.as_ref()) {
(SQLExpr::CompoundIdentifier(left), SQLExpr::CompoundIdentifier(right)) => {
collect_compound_identifiers(left, right, &tbl_left.name, &tbl_right.name)
},
_ => {
polars_bail!(SQLInterface: "only equi-join constraints (on identifiers) are currently supported; found lhs={:?}, rhs={:?}", left, right);
},
},
match expression {
SQLExpr::BinaryOp { left, op, right } => match op {
BinaryOperator::And => {
let (mut left_i, mut right_i) = process_join_on(left, tbl_left, tbl_right)?;
let (mut left_j, mut right_j) = process_join_on(right, tbl_left, tbl_right)?;

left_i.append(&mut left_j);
right_i.append(&mut right_j);
Ok((left_i, right_i))
},
BinaryOperator::Eq => match (left.as_ref(), right.as_ref()) {
(SQLExpr::CompoundIdentifier(left), SQLExpr::CompoundIdentifier(right)) => {
collect_compound_identifiers(left, right, &tbl_left.name, &tbl_right.name)
},
_ => {
polars_bail!(SQLInterface: "only equi-join constraints (on identifiers) are currently supported; found lhs={:?}, rhs={:?}", left, right)
},
},
_ => {
polars_bail!(SQLInterface: "only equi-join constraints (combined with 'AND') are currently supported; found op = '{:?}'", op);
polars_bail!(SQLInterface: "only equi-join constraints (combined with 'AND') are currently supported; found op = '{:?}'", op)
},
}
} else if let SQLExpr::Nested(expr) = expression {
process_join_on(expr, tbl_left, tbl_right)
} else {
polars_bail!(SQLInterface: "only equi-join constraints (combined with 'AND') are currently supported; found expression = {:?}", expression);
},
SQLExpr::Nested(expr) => process_join_on(expr, tbl_left, tbl_right),
_ => {
polars_bail!(SQLInterface: "only equi-join constraints are currently supported; found expression = {:?}", expression)
},
}
}

Expand All @@ -1495,49 +1494,26 @@ fn process_join_constraint(
tbl_left: &TableInfo,
tbl_right: &TableInfo,
) -> PolarsResult<(Vec<Expr>, Vec<Expr>)> {
if let JoinConstraint::On(SQLExpr::BinaryOp { left, op, right }) = constraint {
if op == &BinaryOperator::And {
let (mut left_on, mut right_on) = process_join_on(left, tbl_left, tbl_right)?;
let (left_on_, right_on_) = process_join_on(right, tbl_left, tbl_right)?;
left_on.extend(left_on_);
right_on.extend(right_on_);
return Ok((left_on, right_on));
}
if op != &BinaryOperator::Eq {
polars_bail!(SQLInterface:
"only equi-join constraints are currently supported; found '{:?}' op in\n{:?}", op, constraint)
}
match (left.as_ref(), right.as_ref()) {
(SQLExpr::CompoundIdentifier(left), SQLExpr::CompoundIdentifier(right)) => {
return collect_compound_identifiers(left, right, &tbl_left.name, &tbl_right.name);
},
(SQLExpr::Identifier(left), SQLExpr::Identifier(right)) => {
return Ok((
vec![col(left.value.as_str())],
vec![col(right.value.as_str())],
))
},
_ => {},
}
};
if let JoinConstraint::Using(idents) = constraint {
if !idents.is_empty() {
match constraint {
JoinConstraint::On(expr @ SQLExpr::BinaryOp { .. }) => {
process_join_on(expr, tbl_left, tbl_right)
},
JoinConstraint::Using(idents) if !idents.is_empty() => {
let using: Vec<Expr> = idents.iter().map(|id| col(id.value.as_str())).collect();
return Ok((using.clone(), using.clone()));
}
};
if let JoinConstraint::Natural = constraint {
let left_names = tbl_left.schema.iter_names().collect::<PlHashSet<_>>();
let right_names = tbl_right.schema.iter_names().collect::<PlHashSet<_>>();
let on = left_names
.intersection(&right_names)
.map(|&name| col(name.clone()))
.collect::<Vec<_>>();
if on.is_empty() {
polars_bail!(SQLInterface: "no common columns found for NATURAL JOIN")
}
Ok((on.clone(), on))
} else {
polars_bail!(SQLInterface: "unsupported SQL join constraint:\n{:?}", constraint);
Ok((using.clone(), using))
},
JoinConstraint::Natural => {
let left_names = tbl_left.schema.iter_names().collect::<PlHashSet<_>>();
let right_names = tbl_right.schema.iter_names().collect::<PlHashSet<_>>();
let on: Vec<Expr> = left_names
.intersection(&right_names)
.map(|&name| col(name.clone()))
.collect();
if on.is_empty() {
polars_bail!(SQLInterface: "no common columns found for NATURAL JOIN")
}
Ok((on.clone(), on))
},
_ => polars_bail!(SQLInterface: "unsupported SQL join constraint:\n{:?}", constraint),
}
}
9 changes: 3 additions & 6 deletions crates/polars-sql/tests/statements.rs
Original file line number Diff line number Diff line change
Expand Up @@ -603,29 +603,26 @@ fn test_join_utf8() {
);
}

// Empty placeholder test: the body contains no statements and asserts nothing.
#[test]
fn test_table() {}

// Expects a panic when executing a join whose ON clause is `a AND b`, i.e. two
// bare identifiers combined with AND and no equality comparison.
// NOTE(review): this listing looks like a rendered diff; the two `execute`
// statements below appear to be the old and new versions of the same line
// (the latter no longer calls `collect()`) — confirm against the real file.
#[test]
#[should_panic]
fn test_compound_invalid_1() {
let mut ctx = prepare_compound_join_context();
let sql = "SELECT * FROM df1 OUTER JOIN df2 ON a AND b";
ctx.execute(sql).unwrap().collect().unwrap();
let _ = ctx.execute(sql).unwrap();
}

// Expects a panic when executing a join whose ON clause mixes a table-qualified
// equality (`df1.a = df2.a`) with an unqualified one (`b = b`).
// Presumably the unqualified `b = b` comparison is what gets rejected — verify
// against the join-constraint handling in the SQL context.
// NOTE(review): this listing looks like a rendered diff; the two `execute`
// statements below appear to be the old and new versions of the same line
// — confirm against the real file.
#[test]
#[should_panic]
fn test_compound_invalid_2() {
let mut ctx = prepare_compound_join_context();
let sql = "SELECT * FROM df1 LEFT JOIN df2 ON df1.a = df2.a AND b = b";
ctx.execute(sql).unwrap().collect().unwrap();
let _ = ctx.execute(sql).unwrap();
}

// Expects a panic when executing a join whose ON clause combines a valid
// table-qualified equality (`df1.a = df2.a`) with a bare identifier (`b`)
// that is not part of any comparison.
// NOTE(review): this listing looks like a rendered diff; the two `execute`
// statements below appear to be the old and new versions of the same line
// — confirm against the real file.
#[test]
#[should_panic]
fn test_compound_invalid_3() {
let mut ctx = prepare_compound_join_context();
let sql = "SELECT * FROM df1 INNER JOIN df2 ON df1.a = df2.a AND b";
ctx.execute(sql).unwrap().collect().unwrap();
let _ = ctx.execute(sql).unwrap();
}
2 changes: 1 addition & 1 deletion py-polars/tests/unit/sql/test_joins.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def test_non_equi_joins(constraint: str) -> None:
with (
pytest.raises(
SQLInterfaceError,
match=r"only equi-join constraints are currently supported",
match=r"only equi-join constraints \(combined with 'AND'\) are currently supported",
),
pl.SQLContext({"tbl": pl.DataFrame({"a": [1, 2, 3], "b": [4, 3, 2]})}) as ctx,
):
Expand Down

0 comments on commit a186f12

Please sign in to comment.