Skip to content

Commit

Permalink
inner join review suggestion
Browse files Browse the repository at this point in the history
Stop assuming that an INNER join cannot produce more output rows
than input rows. Use the same row-count logic for all join types.
  • Loading branch information
mhilton committed Sep 27, 2024
1 parent 2153150 commit 8dd7412
Showing 1 changed file with 49 additions and 30 deletions.
79 changes: 49 additions & 30 deletions datafusion/physical-plan/src/joins/nested_loop_join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -825,12 +825,13 @@ impl NestedLoopJoinStream {

debug_assert!(self.outer_record_batch.is_some());
let right_batch = self.outer_record_batch.as_ref().unwrap();
let num_rows = match (self.join_type, left_data.batch().num_rows()) {
// An inner join will only produce 1 output row per input row.
(JoinType::Inner, _) | (_, 0) => self.output_buffer.needed_rows(),
// Outer joins can produce as many rows as there are in the build input for
// each row in the batch.
(_, rows) => std::cmp::max(1, self.output_buffer.needed_rows() / rows),
let num_rows = if left_data.batch().num_rows() == 0 {
self.output_buffer.needed_rows()
} else {
std::cmp::max(
1,
self.output_buffer.needed_rows() / left_data.batch().num_rows(),
)
};
let num_rows = std::cmp::min(
num_rows,
Expand Down Expand Up @@ -1638,44 +1639,62 @@ mod tests {
vec![Arc::new(Int32Array::from(vec![5, 6, 7, 8, 9]))],
)?,
];
let left = MemoryExec::try_new(&[batches.clone()], Arc::clone(&schema), None)?;
let right = MemoryExec::try_new(&[batches.clone()], Arc::clone(&schema), None)?;

let (schema, column_indices) =
build_join_schema(schema.as_ref(), schema.as_ref(), &JoinType::Full);
let left = Arc::new(MemoryExec::try_new(
&[batches.clone()],
Arc::clone(&schema),
None,
)?) as Arc<dyn ExecutionPlan>;
let right = Arc::new(MemoryExec::try_new(
&[batches.clone()],
Arc::clone(&schema),
None,
)?) as Arc<dyn ExecutionPlan>;

let column_indices = JoinFilter::build_column_indices(vec![0], vec![0]);
let intermediate_schema = Schema::new(vec![
Field::new("v", DataType::Int32, false),
Field::new("v", DataType::Int32, false),
]);
let filter = JoinFilter::new(
Arc::new(BinaryExpr::new(
Arc::new(Column::new("v", 0)),
Operator::NotEq,
Arc::new(Column::new("v", 1)),
)),
column_indices,
schema,
intermediate_schema,
);

let config = SessionConfig::new().with_batch_size(5);
let task_ctx = Arc::new(TaskContext::default().with_session_config(config));

let (_, batches) = multi_partitioned_join_collect(
Arc::new(left),
Arc::new(right),
&JoinType::Full,
Some(filter),
task_ctx,
)
.await
.unwrap();

let rows = batches.iter().map(|batch| batch.num_rows()).sum::<usize>();
assert_eq!(rows, 90);
let max_rows = batches
.iter()
.map(|batch| batch.num_rows())
.max()
.unwrap_or(0);
assert!(max_rows <= 5);
let join_types = vec![
JoinType::Inner,
JoinType::Left,
JoinType::Right,
JoinType::Full,
];

for join_type in join_types {
let (_, batches) = multi_partitioned_join_collect(
Arc::clone(&left),
Arc::clone(&right),
&join_type,
Some(filter.clone()),
Arc::clone(&task_ctx),
)
.await
.unwrap();

let rows = batches.iter().map(|batch| batch.num_rows()).sum::<usize>();
assert_eq!(rows, 90);
let max_rows = batches
.iter()
.map(|batch| batch.num_rows())
.max()
.unwrap_or(0);
assert!(max_rows <= 5);
}
Ok(())
}
}

0 comments on commit 8dd7412

Please sign in to comment.