From da2ba82c7f3027c7e1519d7128209f188724da72 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Fri, 1 Nov 2024 12:01:03 +0000 Subject: [PATCH] fix(rust,python): Fix predicate pushdown into inequality joins (#19582) --- crates/polars-ops/src/frame/join/args.rs | 4 +++ .../polars-plan/src/plans/conversion/join.rs | 2 +- .../src/plans/optimizer/collapse_joins.rs | 2 +- .../src/plans/optimizer/collect_members.rs | 2 +- .../optimizer/predicate_pushdown/join.rs | 3 +- .../unit/operations/test_inequality_join.py | 31 +++++++++++++++++++ 6 files changed, 40 insertions(+), 4 deletions(-) diff --git a/crates/polars-ops/src/frame/join/args.rs b/crates/polars-ops/src/frame/join/args.rs index 4c845a2ba541..d34c37e7ff67 100644 --- a/crates/polars-ops/src/frame/join/args.rs +++ b/crates/polars-ops/src/frame/join/args.rs @@ -174,6 +174,10 @@ impl JoinType { } } + pub fn is_cross(&self) -> bool { + matches!(self, JoinType::Cross) + } + pub fn is_ie(&self) -> bool { #[cfg(feature = "iejoin")] { diff --git a/crates/polars-plan/src/plans/conversion/join.rs b/crates/polars-plan/src/plans/conversion/join.rs index a81c36bef1f8..5222ca73cda1 100644 --- a/crates/polars-plan/src/plans/conversion/join.rs +++ b/crates/polars-plan/src/plans/conversion/join.rs @@ -41,7 +41,7 @@ pub fn resolve_join( } let owned = Arc::unwrap_or_clone; - if matches!(options.args.how, JoinType::Cross) { + if options.args.how.is_cross() { polars_ensure!(left_on.len() + right_on.len() == 0, InvalidOperation: "a 'cross' join doesn't expect any join keys"); } else { polars_ensure!(left_on.len() + right_on.len() > 0, InvalidOperation: "expected join keys/predicates"); diff --git a/crates/polars-plan/src/plans/optimizer/collapse_joins.rs b/crates/polars-plan/src/plans/optimizer/collapse_joins.rs index 778efee6aa9b..07a4a45948b8 100644 --- a/crates/polars-plan/src/plans/optimizer/collapse_joins.rs +++ b/crates/polars-plan/src/plans/optimizer/collapse_joins.rs @@ -192,7 +192,7 @@ pub fn optimize(root: Node, lp_arena: &mut Arena, expr_arena: &mut Arena { + } if options.args.how.is_cross() => { if predicates.is_empty() { continue; } diff --git a/crates/polars-plan/src/plans/optimizer/collect_members.rs b/crates/polars-plan/src/plans/optimizer/collect_members.rs index 3157821c7696..f829b3362473 100644 --- a/crates/polars-plan/src/plans/optimizer/collect_members.rs +++ b/crates/polars-plan/src/plans/optimizer/collect_members.rs @@ -48,7 +48,7 @@ impl MemberCollector { match alp { Join { .. } | Union { .. } => self.has_joins_or_unions = true, Filter { input, .. } => { - self.has_filter_with_join_input |= matches!(lp_arena.get(*input), Join { options, .. } if options.args.how == JoinType::Cross) + self.has_filter_with_join_input |= matches!(lp_arena.get(*input), Join { options, .. } if options.args.how.is_cross()) }, Cache { .. } => self.has_cache = true, ExtContext { .. } => self.has_ext_context = true, diff --git a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs index 3b23faef8e04..b79de0e2b959 100644 --- a/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs +++ b/crates/polars-plan/src/plans/optimizer/predicate_pushdown/join.rs @@ -149,7 +149,8 @@ pub(super) fn process_join( for (_, predicate) in acc_predicates { // Cross joins produce a cartesian product, so if a predicate combines columns from both tables, we should not push down. - if matches!(options.args.how, JoinType::Cross) + // Inequality joins logically produce a cartesian product, so the same logic applies. + if (options.args.how.is_cross() || options.args.how.is_ie()) && predicate_applies_to_both_tables( predicate.node(), expr_arena, diff --git a/py-polars/tests/unit/operations/test_inequality_join.py b/py-polars/tests/unit/operations/test_inequality_join.py index 89eec6bfd224..d306dd173431 100644 --- a/py-polars/tests/unit/operations/test_inequality_join.py +++ b/py-polars/tests/unit/operations/test_inequality_join.py @@ -624,3 +624,34 @@ def test_join_partial_column_name_overlap_19119() -> None: "a_right": [2], "d": [0], } + + +def test_join_predicate_pushdown_19580() -> None: + left = pl.LazyFrame( + { + "a": [1, 2, 3, 1], + "b": [1, 2, 3, 4], + "c": [2, 3, 4, 5], + } + ) + + right = pl.LazyFrame({"a": [1, 3], "c": [2, 4], "d": [6, 3]}) + + q = left.join_where( + right, + pl.col("b") < pl.col("c_right"), + pl.col("a") < pl.col("a_right"), + pl.col("a") < pl.col("d"), + ) + + expect = ( + left.join(right, how="cross") + .collect() + .filter( + (pl.col("a") < pl.col("d")) + & (pl.col("b") < pl.col("c_right")) + & (pl.col("a") < pl.col("a_right")) + ) + ) + + assert_frame_equal(expect, q.collect(), check_row_order=False)