Skip to content

Commit

Permalink
fix(rust): Ensure deduced join key names are unique (#16551)
Browse files Browse the repository at this point in the history
  • Loading branch information
wence- authored May 28, 2024
1 parent a2b1e9b commit 54213a2
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
15 changes: 10 additions & 5 deletions crates/polars-plan/src/logical_plan/conversion/dsl_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -325,11 +325,6 @@ pub fn to_alp_impl(
}
}

let mut joined_on = PlHashSet::new();
for (l, r) in left_on.iter().zip(right_on.iter()) {
polars_ensure!(joined_on.insert((l, r)), InvalidOperation: "joins on same keys twice; already joined on {} and {}", l, r)
}
drop(joined_on);
options.args.validation.is_valid_join(&options.args.how)?;

polars_ensure!(
Expand All @@ -356,6 +351,16 @@ pub fn to_alp_impl(

let left_on = to_expr_irs_ignore_alias(left_on, expr_arena);
let right_on = to_expr_irs_ignore_alias(right_on, expr_arena);
let mut joined_on = PlHashSet::new();
for (l, r) in left_on.iter().zip(right_on.iter()) {
polars_ensure!(
joined_on.insert((l.output_name(), r.output_name())),
InvalidOperation: "joining with repeated key names; already joined on {} and {}",
l.output_name(),
r.output_name()
)
}
drop(joined_on);

convert.fill_scratch(&left_on, expr_arena);
convert.fill_scratch(&right_on, expr_arena);
Expand Down
10 changes: 10 additions & 0 deletions py-polars/tests/unit/operations/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -1014,3 +1014,13 @@ def test_left_join_coalesce_default_deprecation_message() -> None:
right = pl.DataFrame({"a": [2, 3, 4], "c": [4, 5, 6]})
with pytest.deprecated_call():
left.join(right, on="a", how="left")


@pytest.mark.parametrize("coalesce", [False, True])
def test_join_raise_on_repeated_expression_key_names(coalesce: bool) -> None:
    """A join whose key expressions repeat an output name must be rejected.

    The same key list is used for both sides of a full join, with and
    without coalescing. The raised error is expected to be an
    ``InvalidOperationError`` whose message contains "already joined on".
    """
    lhs = pl.DataFrame({"a": [1, 2, 3], "b": [3, 4, 5], "c": [5, 6, 7]})
    rhs = pl.DataFrame({"a": [2, 3, 4], "c": [4, 5, 6]})

    # Two distinct expressions that are both rooted in column "a"; the join
    # is expected to treat their deduced key names as a repeat.
    join_keys = [pl.col("a"), pl.col("a") % 2]

    with pytest.raises(pl.InvalidOperationError, match="already joined on"):
        lhs.join(rhs, on=join_keys, how="full", coalesce=coalesce)

0 comments on commit 54213a2

Please sign in to comment.