From 4ecc8a6104345cf5a26309a1d53fd8bb4895422d Mon Sep 17 00:00:00 2001 From: nameexhaustion Date: Wed, 29 May 2024 16:15:15 +1000 Subject: [PATCH] fix: Error selecting columns after non-coalesced join (multiple join keys) (#16559) --- .../optimizer/projection_pushdown/joins.rs | 27 ++++++----- py-polars/tests/unit/test_projections.py | 47 +++++++++++++++++++ 2 files changed, 62 insertions(+), 12 deletions(-) diff --git a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/joins.rs b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/joins.rs index a30b674f8e39..04fcdb2832c9 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/joins.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/joins.rs @@ -97,7 +97,7 @@ pub(super) fn process_asof_join( // The join on keys can lead that columns are already added, we don't want to create // duplicates so store the names. - let mut already_added_local_to_local_projected = BTreeSet::new(); + let mut local_projected_names = BTreeSet::new(); // We need the join columns so we push the projection downwards for e in &left_on { @@ -110,7 +110,7 @@ pub(super) fn process_asof_join( true, ) .unwrap(); - already_added_local_to_local_projected.insert(local_name); + local_projected_names.insert(local_name); } // this differs from normal joins, as in `asof_joins` // both columns remain. So `add_local=true` also for the right table @@ -126,18 +126,18 @@ pub(super) fn process_asof_join( // insert the name. // if name was already added we pop the local projection // otherwise we would project duplicate columns - if !already_added_local_to_local_projected.insert(local_name) { + if !local_projected_names.insert(local_name) { local_projection.pop(); } }; } for proj in acc_projections { - let add_local = if already_added_local_to_local_projected.is_empty() { + let add_local = if local_projected_names.is_empty() { true } else { let name = column_node_to_name(proj, expr_arena); - !already_added_local_to_local_projected.contains(&name) + !local_projected_names.contains(&name) }; process_projection( @@ -242,11 +242,15 @@ pub(super) fn process_join( // The join on keys can lead that columns are already added, we don't want to create // duplicates so store the names. - let mut already_added_local_to_local_projected = BTreeSet::new(); + let mut local_projected_names = BTreeSet::new(); // We need the join columns so we push the projection downwards for e in &left_on { - let local_name = add_keys_to_accumulated_state( + if !local_projected_names.insert(e.output_name_arc().clone()) { + continue; + } + + add_keys_to_accumulated_state( e.node(), &mut pushdown_left, &mut local_projection, @@ -255,7 +259,6 @@ pub(super) fn process_join( true, ) .unwrap(); - already_added_local_to_local_projected.insert(local_name); } // In full outer joins both columns remain. So `add_local=true` also for the right table let add_local = !options.args.coalesce.coalesce(&options.args.how); @@ -263,7 +266,7 @@ pub(super) fn process_join( // In case of full outer joins we also add the columns. // But before we do that we must check if the column wasn't already added by the lhs. let add_local = if add_local { - !already_added_local_to_local_projected.contains(e.output_name()) + !local_projected_names.contains(e.output_name()) } else { false }; @@ -278,16 +281,16 @@ pub(super) fn process_join( ); if let Some(local_name) = local_name { - already_added_local_to_local_projected.insert(local_name); + local_projected_names.insert(local_name); } } for proj in acc_projections { - let add_local = if already_added_local_to_local_projected.is_empty() { + let add_local = if local_projected_names.is_empty() { true } else { let name = column_node_to_name(proj, expr_arena); - !already_added_local_to_local_projected.contains(&name) + !local_projected_names.contains(&name) }; process_projection( diff --git a/py-polars/tests/unit/test_projections.py b/py-polars/tests/unit/test_projections.py index 02b029a46a98..a0609401e8e6 100644 --- a/py-polars/tests/unit/test_projections.py +++ b/py-polars/tests/unit/test_projections.py @@ -453,3 +453,50 @@ def test_non_coalesce_join_projection_pushdown_16515( .item() == 1 ) + + +@pytest.mark.parametrize("join_type", ["inner", "left", "full"]) +def test_non_coalesce_multi_key_join_projection_pushdown_16554( + join_type: Literal["inner", "left", "full"], +) -> None: + lf1 = pl.LazyFrame( + { + "a": [1, 2, 3, 4, 5], + "b": [1, 2, 3, 4, 5], + } + ) + lf2 = pl.LazyFrame( + { + "a": [0, 2, 3, 4, 5], + "b": [1, 2, 3, 5, 6], + "c": [7, 5, 3, 5, 7], + } + ) + + expect = ( + lf1.with_columns(a2="a") + .join( + other=lf2, + how=join_type, + left_on=["a", "a2"], + right_on=["b", "c"], + coalesce=False, + ) + .select("a", "b", "c") + .sort("a") + .collect() + ) + + out = ( + lf1.join( + other=lf2, + how=join_type, + left_on=["a", "a"], + right_on=["b", "c"], + coalesce=False, + ) + .select("a", "b", "c") + .collect() + ) + + assert_frame_equal(out.sort("a"), expect)