Skip to content

Commit

Permalink
fix: Validate asof join by args in IR resolving phase (#20473)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Dec 27, 2024
1 parent 8b3afb7 commit 2685a86
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 13 deletions.
36 changes: 26 additions & 10 deletions crates/polars-plan/src/plans/conversion/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,16 @@ pub fn resolve_join(
}

let owned = Arc::unwrap_or_clone;
let mut input_left = input_left.map_right(Ok).right_or_else(|input| {
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(join left)))
})?;
let mut input_right = input_right.map_right(Ok).right_or_else(|input| {
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(join right)))
})?;

let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena);
let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena);

if options.args.how.is_cross() {
polars_ensure!(left_on.len() + right_on.len() == 0, InvalidOperation: "a 'cross' join doesn't expect any join keys");
} else {
Expand All @@ -65,6 +75,21 @@ pub fn resolve_join(

options.args.validation.is_valid_join(&options.args.how)?;

#[cfg(feature = "asof_join")]
if let JoinType::AsOf(opt) = &options.args.how {
match (&opt.left_by, &opt.right_by) {
(None, None) => {},
(Some(l), Some(r)) => {
polars_ensure!(l.len() == r.len(), InvalidOperation: "expected equal number of columns in 'by_left' and 'by_right' in 'asof_join'");
validate_columns_in_input(l, &schema_left, "asof_join")?;
validate_columns_in_input(r, &schema_right, "asof_join")?;
},
_ => {
polars_bail!(InvalidOperation: "expected both 'by_left' and 'by_right' to be set in 'asof_join'")
},
}
}

polars_ensure!(
left_on.len() == right_on.len(),
InvalidOperation:
Expand All @@ -76,16 +101,6 @@ pub fn resolve_join(
);
}

let mut input_left = input_left.map_right(Ok).right_or_else(|input| {
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(join left)))
})?;
let mut input_right = input_right.map_right(Ok).right_or_else(|input| {
to_alp_impl(owned(input), ctxt).map_err(|e| e.context(failed_here!(join right)))
})?;

let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena);
let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena);

let schema = det_join_schema(&schema_left, &schema_right, &left_on, &right_on, &options)
.map_err(|e| e.context(failed_here!(join schema resolving)))?;

Expand Down Expand Up @@ -120,6 +135,7 @@ pub fn resolve_join(
.coerce_types(ctxt.expr_arena, ctxt.lp_arena, input_right)
.map_err(|e| e.context("'join' failed".into()))?;

// Re-evaluate because of mutable borrows earlier.
let schema_left = ctxt.lp_arena.get(input_left).schema(ctxt.lp_arena);
let schema_right = ctxt.lp_arena.get(input_right).schema(ctxt.lp_arena);

Expand Down
7 changes: 4 additions & 3 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -4568,9 +4568,10 @@ def join_asof(
if by is not None:
by_left_ = [by] if isinstance(by, str) else by
by_right_ = by_left_
elif (by_left is not None) and (by_right is not None):
by_left_ = [by_left] if isinstance(by_left, str) else by_left
by_right_ = [by_right] if isinstance(by_right, str) else by_right
elif (by_left is not None) or (by_right is not None):
by_left_ = [by_left] if isinstance(by_left, str) else by_left # type: ignore[assignment]
by_right_ = [by_right] if isinstance(by_right, str) else by_right # type: ignore[assignment]

else:
# no by
by_left_ = None
Expand Down
18 changes: 18 additions & 0 deletions py-polars/tests/unit/operations/test_join_asof.py
Original file line number Diff line number Diff line change
Expand Up @@ -1196,3 +1196,21 @@ def test_asof_join_by_schema() -> None:
)

assert q.collect_schema() == q.collect().schema


def test_raise_invalid_by_arg_13020() -> None:
    """Check that `join_asof` rejects a one-sided `by` specification.

    Only `by_right` is supplied (`by_left=None`); the call must raise
    `InvalidOperationError` whose message contains "expected both".
    The numeric suffix in the test name presumably refers to the originating
    issue -- confirm against the tracker.
    """
    df1 = pl.DataFrame({"asOfDate": [date(2020, 1, 1)]})
    df2 = pl.DataFrame(
        {
            # Spelled to match the `by_right` argument below, so the missing
            # `by_left` is the only invalid part of the call.
            "entityId": [date(2020, 1, 1)],
            "eventDate": ["A"],
        }
    )
    with pytest.raises(pl.exceptions.InvalidOperationError, match="expected both"):
        df1.sort("asOfDate").join_asof(
            df2.sort("eventDate"),
            left_on="asOfDate",
            right_on="eventDate",
            by_left=None,
            by_right=["entityId"],
        )

0 comments on commit 2685a86

Please sign in to comment.