Skip to content

Commit

Permalink
fix(rust,python): Fix predicate pushdown into inequality joins (#19582)
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- authored Nov 1, 2024
1 parent 3fe10a3 commit da2ba82
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 4 deletions.
4 changes: 4 additions & 0 deletions crates/polars-ops/src/frame/join/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,10 @@ impl JoinType {
}
}

/// Returns `true` if this join type is the [`JoinType::Cross`] variant.
pub fn is_cross(&self) -> bool {
matches!(self, JoinType::Cross)
}

pub fn is_ie(&self) -> bool {
#[cfg(feature = "iejoin")]
{
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/plans/conversion/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ pub fn resolve_join(
}

let owned = Arc::unwrap_or_clone;
if matches!(options.args.how, JoinType::Cross) {
if options.args.how.is_cross() {
polars_ensure!(left_on.len() + right_on.len() == 0, InvalidOperation: "a 'cross' join doesn't expect any join keys");
} else {
polars_ensure!(left_on.len() + right_on.len() > 0, InvalidOperation: "expected join keys/predicates");
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/plans/optimizer/collapse_joins.rs
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ pub fn optimize(root: Node, lp_arena: &mut Arena<IR>, expr_arena: &mut Arena<AEx
left_on,
right_on,
options,
} if matches!(options.args.how, JoinType::Cross) => {
} if options.args.how.is_cross() => {
if predicates.is_empty() {
continue;
}
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/plans/optimizer/collect_members.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ impl MemberCollector {
match alp {
Join { .. } | Union { .. } => self.has_joins_or_unions = true,
Filter { input, .. } => {
self.has_filter_with_join_input |= matches!(lp_arena.get(*input), Join { options, .. } if options.args.how == JoinType::Cross)
self.has_filter_with_join_input |= matches!(lp_arena.get(*input), Join { options, .. } if options.args.how.is_cross())
},
Cache { .. } => self.has_cache = true,
ExtContext { .. } => self.has_ext_context = true,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ pub(super) fn process_join(

for (_, predicate) in acc_predicates {
// Cross joins produce a cartesian product, so if a predicate combines columns from both tables, we should not push down.
if matches!(options.args.how, JoinType::Cross)
// Inequality joins logically produce a cartesian product, so the same logic applies.
if (options.args.how.is_cross() || options.args.how.is_ie())
&& predicate_applies_to_both_tables(
predicate.node(),
expr_arena,
Expand Down
31 changes: 31 additions & 0 deletions py-polars/tests/unit/operations/test_inequality_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,3 +624,34 @@ def test_join_partial_column_name_overlap_19119() -> None:
"a_right": [2],
"d": [0],
}


def test_join_predicate_pushdown_19580() -> None:
    """Regression test for #19580 (predicate pushdown with `join_where`).

    A lazy `join_where` with three inequality predicates must return the same
    rows as a cross join that is collected first and then filtered with the
    conjunction of the same predicates.
    """
    lhs = pl.LazyFrame(
        {
            "a": [1, 2, 3, 1],
            "b": [1, 2, 3, 4],
            "c": [2, 3, 4, 5],
        }
    )
    rhs = pl.LazyFrame({"a": [1, 3], "c": [2, 4], "d": [6, 3]})

    # Query under test: lazy inequality join on three predicates.
    query = lhs.join_where(
        rhs,
        pl.col("b") < pl.col("c_right"),
        pl.col("a") < pl.col("a_right"),
        pl.col("a") < pl.col("d"),
    )

    # Reference result: collect the full cross join first, then apply the
    # filter to the already-materialised frame.
    cross_product = lhs.join(rhs, how="cross").collect()
    expected = cross_product.filter(
        (pl.col("a") < pl.col("d"))
        & (pl.col("b") < pl.col("c_right"))
        & (pl.col("a") < pl.col("a_right"))
    )

    # Row order is not compared.
    assert_frame_equal(expected, query.collect(), check_row_order=False)

0 comments on commit da2ba82

Please sign in to comment.