From b794a6365ddef118d61eef635ab4ea7d6da75f7e Mon Sep 17 00:00:00 2001 From: Martin Hilton Date: Wed, 25 Sep 2024 14:21:20 +0100 Subject: [PATCH 1/6] create smaller output batches in nested loop join Nested loop join creates a single output batch for each (right side) input batch. When performing an outer join the size of the output batch can be as large as number of left data rows * batch rows. If the size of the left data is large then this can produce unreasonably large output batches. Attempt to reduce the size of the output batches by only processing a subset of the input batch at a time where the output could be very large. The trade-off is that this can produce a ;arge number of very small batches instead if the left data is large but there is a highly selective filter. --- .../src/joins/nested_loop_join.rs | 118 ++++++++++++------ datafusion/sqllogictest/test_files/joins.slt | 4 +- 2 files changed, 85 insertions(+), 37 deletions(-) diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 029003374acc..3b8ddd7adc28 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -334,6 +334,13 @@ impl ExecutionPlan for NestedLoopJoinExec { MemoryConsumer::new(format!("NestedLoopJoinLoad[{partition}]")) .register(context.memory_pool()); + // Initialization of reservation for processing the outer table + let memory_reservation = + MemoryConsumer::new(format!("NestedLoopJoin[{partition}]")) + .register(context.memory_pool()); + + let target_batch_size = context.session_config().batch_size(); + let inner_table = self.inner_table.once(|| { collect_left_input( Arc::clone(&self.left), @@ -344,7 +351,6 @@ impl ExecutionPlan for NestedLoopJoinExec { self.right().output_partitioning().partition_count(), ) }); - let outer_table = self.right.execute(partition, context)?; let indices_cache = (UInt64Array::new_null(0), UInt32Array::new_null(0)); @@ -363,6 +369,10 @@ impl ExecutionPlan for NestedLoopJoinExec { join_metrics, indices_cache, right_side_ordered, + target_batch_size, + outer_record_batch: None, + outer_record_batch_row: 0, + memory_reservation, })) } @@ -466,6 +476,14 @@ struct NestedLoopJoinStream { indices_cache: (UInt64Array, UInt32Array), /// Whether the right side is ordered right_side_ordered: bool, + /// Target size for output batches + target_batch_size: usize, + /// The current record batch being processed. + outer_record_batch: Option, + /// The current index of the outer record batch being processed. + outer_record_batch_row: usize, + /// The memory reservation used to track memory usage. + memory_reservation: MemoryReservation, } /// Creates a Cartesian product of two input batches, preserving the order of the right batch, @@ -566,37 +584,18 @@ impl NestedLoopJoinStream { return Poll::Ready(None); } - self.outer_table - .poll_next_unpin(cx) - .map(|maybe_batch| match maybe_batch { - Some(Ok(right_batch)) => { - // Setting up timer & updating input metrics + if self.outer_record_batch.is_none() { + // Get the next outer record batch + match ready!(self.outer_table.poll_next_unpin(cx)) { + Some(Ok(batch)) => { self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(right_batch.num_rows()); - let timer = self.join_metrics.join_time.timer(); - - let result = join_left_and_right_batch( - left_data.batch(), - &right_batch, - self.join_type, - self.filter.as_ref(), - &self.column_indices, - &self.schema, - visited_left_side, - &mut self.indices_cache, - self.right_side_ordered, - ); - - // Recording time & updating output metrics - if let Ok(batch) = &result { - timer.done(); - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } - - Some(result) + self.join_metrics.input_rows.add(batch.num_rows()); + self.memory_reservation + .try_grow(batch.get_array_memory_size())?; + self.outer_record_batch = Some(batch); + self.outer_record_batch_row = 0; } - Some(err) => Some(err), + Some(Err(e)) => return Poll::Ready(Some(Err(e))), None => { if need_produce_result_in_final(self.join_type) { // At this stage `visited_left_side` won't be updated, so it's @@ -606,7 +605,7 @@ impl NestedLoopJoinStream { // multiple calls of `report_probe_completed()` if !left_data.report_probe_completed() { self.is_exhausted = true; - return None; + return Poll::Ready(None); }; // Only setting up timer, input is exhausted @@ -638,13 +637,60 @@ impl NestedLoopJoinStream { self.join_metrics.output_rows.add(batch.num_rows()); } - Some(result) + return Poll::Ready(Some(result)); } else { // end of the join loop - None + return Poll::Ready(None); } } - }) + } + } + + debug_assert!(self.outer_record_batch.is_some()); + let right_batch = self.outer_record_batch.as_ref().unwrap(); + let num_rows = match (self.join_type, left_data.batch().num_rows()) { + // An inner join will only produce 1 output row per input row. + (JoinType::Inner, _) | (_, 0) => self.target_batch_size, + // Outer joins can produce as many rows as there are in the build input for + // each row in the batch. + (_, rows) => std::cmp::max(1, self.target_batch_size / rows), + }; + let num_rows = std::cmp::min( + num_rows, + right_batch.num_rows() - self.outer_record_batch_row, + ); + let sliced_right_batch = right_batch.slice(self.outer_record_batch_row, num_rows); + + // Setting up timer + let timer = self.join_metrics.join_time.timer(); + + let result = join_left_and_right_batch( + left_data.batch(), + &sliced_right_batch, + self.join_type, + self.filter.as_ref(), + &self.column_indices, + &self.schema, + visited_left_side, + &mut self.indices_cache, + self.right_side_ordered, + ); + + // Recording time & updating output metrics + if let Ok(batch) = &result { + timer.done(); + self.outer_record_batch_row += num_rows; + if self.outer_record_batch_row >= right_batch.num_rows() { + if let Some(batch) = self.outer_record_batch.take() { + self.memory_reservation + .shrink(batch.get_array_memory_size()) + } + } + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + } + + Poll::Ready(Some(result)) } } @@ -1172,8 +1218,10 @@ mod tests { assert_contains!( err.to_string(), - "External error: Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: NestedLoopJoinLoad[0]" + "External error: Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:" ); + assert_contains!(err.to_string(), "NestedLoopJoinLoad[0] consumed 0 bytes"); + assert_contains!(err.to_string(), "NestedLoopJoin[0] consumed 0 bytes"); } Ok(()) diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt index 64d5f6c7b88f..36d006dbe768 100644 --- a/datafusion/sqllogictest/test_files/joins.slt +++ b/datafusion/sqllogictest/test_files/joins.slt @@ -2130,7 +2130,7 @@ physical_plan 09)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 10)--------MemoryExec: partitions=1, partition_sizes=[1] -query II +query II rowsort SELECT join_t1.t1_id, join_t2.t2_id FROM (select t1_id from join_t1 where join_t1.t1_id > 22) as join_t1 RIGHT JOIN (select t2_id from join_t2 where join_t2.t2_id > 11) as join_t2 @@ -4187,4 +4187,4 @@ physical_plan 02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(b@1, y@1)], filter=a@0 < x@1 03)----MemoryExec: partitions=1, partition_sizes=[0] 04)----SortExec: expr=[x@0 ASC NULLS LAST], preserve_partitioning=[false] -05)------MemoryExec: partitions=1, partition_sizes=[0] \ No newline at end of file +05)------MemoryExec: partitions=1, partition_sizes=[0] From b5c7e72b84360f70def4be5c4bb6fa8b05098231 Mon Sep 17 00:00:00 2001 From: Martin Hilton Date: Thu, 26 Sep 2024 10:25:52 +0100 Subject: [PATCH 2/6] buffer output batches in nested loop join Use buffering to keep the size of output batches from nested loop join around the configured batch size. Small record batches are buffered until there is enough rows available to fill a full batch at which point the small batches are combined into a single batch. Larger batches have batch sized slices taken from them until they become smaller than the configured batch size. --- datafusion/core/tests/memory_limit/mod.rs | 5 +- .../src/joins/nested_loop_join.rs | 391 +++++++++++++----- 2 files changed, 289 insertions(+), 107 deletions(-) diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs index 69ef6058a2f6..1bf50ea5751b 100644 --- a/datafusion/core/tests/memory_limit/mod.rs +++ b/datafusion/core/tests/memory_limit/mod.rs @@ -141,7 +141,10 @@ async fn join_by_expression() { TestCase::new() .with_query("select t1.* from t t1 JOIN t t2 ON t1.service != t2.service") .with_expected_errors(vec![ - "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: NestedLoopJoinLoad[0]", + "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:", + "NestedLoopJoinLoad[0] consumed 0 bytes", + "NestedLoopJoin[0] consumed 0 bytes", + "NestedLoopJoinBuffer[0] consumed 0 bytes", ]) .with_memory_limit(1_000) .run() diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 3b8ddd7adc28..220bfee48a15 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -20,6 +20,7 @@ //! determined by the [`JoinType`]. use std::any::Any; +use std::collections::VecDeque; use std::fmt::Formatter; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; @@ -100,6 +101,144 @@ impl JoinLeftData { } } +/// Buffer to store the output batches. The buffer is used to +/// keep the size of output batches around the target buffer size. +/// Output batches are either combined or split to ensure that the +/// batches produced are as close to the target size as possible +/// without being any bigger. +struct OutputBuffer { + /// Whether the buffer has output any record batches, + /// used to ensure that at least one record batch is + /// produced even if no joins happen. + have_output: bool, + /// Output batch + batches: VecDeque, + /// Number of rows in the buffer + total_rows: usize, + /// Batch that is ready to be output. + next_batch: Option<(RecordBatch, usize)>, + /// Schema of the output batches + schema: SchemaRef, + /// Target output batch size + target_batch_size: usize, + /// Memory reservation for tracking batch memory use + memory_reservation: MemoryReservation, +} + +impl OutputBuffer { + fn new( + schema: SchemaRef, + target_batch_size: usize, + memory_reservation: MemoryReservation, + ) -> Self { + Self { + have_output: false, + batches: VecDeque::new(), + total_rows: 0, + next_batch: None, + schema, + target_batch_size, + memory_reservation, + } + } + + /// Pash a new batch into the output buffer. + fn push(&mut self, batch: RecordBatch) -> Result<()> { + assert!(self.next_batch.is_none()); + assert!(self.total_rows < self.target_batch_size); + + self.memory_reservation + .try_grow(batch.get_array_memory_size())?; + if self.total_rows + batch.num_rows() < self.target_batch_size { + // Not enough rows to fill a batch. + self.total_rows += batch.num_rows(); + self.batches.push_back(batch); + return Ok(()); + } + let mut batches = std::mem::take(&mut self.batches); + if self.total_rows + batch.num_rows() == self.target_batch_size { + self.total_rows = 0; + batches.push_back(batch); + } else { + self.total_rows = batch.num_rows(); + self.batches.push_back(batch); + } + let mut total_array_memory_size = 0; + let batch = concat_batches( + &self.schema, + batches.iter().inspect(|batch| { + total_array_memory_size += batch.get_array_memory_size() + }), + )?; + self.next_batch = Some((batch, total_array_memory_size)); + Ok(()) + } + + /// Return a batch if there are enough rows in the buffer. + fn next(&mut self) -> Option { + if let Some((batch, total_array_memory_size)) = self.next_batch.take() { + self.have_output = true; + self.memory_reservation.shrink(total_array_memory_size); + return Some(batch); + } + if self.total_rows < self.target_batch_size { + return None; + } + if self.total_rows == self.target_batch_size { + self.total_rows = 0; + return self.batches.pop_front().inspect(|batch| { + self.memory_reservation + .shrink(batch.get_array_memory_size()) + }); + } + let batch = self.batches[0].slice(0, self.target_batch_size); + let rest = self.batches[0].slice( + self.target_batch_size, + self.total_rows - self.target_batch_size, + ); + self.have_output = true; + self.total_rows -= self.target_batch_size; + self.batches[0] = rest; + Some(batch) + } + + /// The number of rows needed to fill the next batch. + fn needed_rows(&self) -> usize { + if self.total_rows < self.target_batch_size { + self.target_batch_size - self.total_rows + } else { + 0 + } + } + + /// Flush any remaining rows into a batch even it it's not yet full. + fn flush(&mut self) -> Result> { + if self.total_rows == 0 { + return Ok(None); + } + let batch = std::mem::take(&mut self.batches); + let total_array_memory_size = + batch.iter().map(|b| b.get_array_memory_size()).sum(); + self.total_rows = 0; + self.memory_reservation.shrink(total_array_memory_size); + self.have_output = true; + Ok(Some(concat_batches(&self.schema, batch.iter())?)) + } + + /// Finish the output buffer, returning the last batch. If the + /// buffer has never output a batch then create an empty batch. + fn finish(&mut self) -> Result> { + Ok(self.flush()?.or_else(|| { + if self.have_output { + None + } else { + self.have_output = true; + Some(RecordBatch::new_empty(Arc::clone(&self.schema))) + } + })) + } +} + /// NestedLoopJoinExec is build-probe join operator, whose main task is to /// perform joins without any equijoin conditions in `ON` clause. /// @@ -339,6 +478,11 @@ impl ExecutionPlan for NestedLoopJoinExec { MemoryConsumer::new(format!("NestedLoopJoin[{partition}]")) .register(context.memory_pool()); + // Initialization of reservation for buffering output + let buffer_memory_reservation = + MemoryConsumer::new(format!("NestedLoopJoinBuffer[{partition}]")) + .register(context.memory_pool()); + let target_batch_size = context.session_config().batch_size(); let inner_table = self.inner_table.once(|| { @@ -358,6 +502,13 @@ impl ExecutionPlan for NestedLoopJoinExec { // Right side has an order and it is maintained during operation. let right_side_ordered = self.maintains_input_order()[1] && self.right.output_ordering().is_some(); + + let output_buffer = OutputBuffer::new( + Arc::clone(&self.schema), + target_batch_size, + buffer_memory_reservation, + ); + Ok(Box::pin(NestedLoopJoinStream { schema: Arc::clone(&self.schema), filter: self.filter.clone(), @@ -369,10 +520,10 @@ impl ExecutionPlan for NestedLoopJoinExec { join_metrics, indices_cache, right_side_ordered, - target_batch_size, outer_record_batch: None, outer_record_batch_row: 0, memory_reservation, + output_buffer, })) } @@ -476,14 +627,14 @@ struct NestedLoopJoinStream { indices_cache: (UInt64Array, UInt32Array), /// Whether the right side is ordered right_side_ordered: bool, - /// Target size for output batches - target_batch_size: usize, - /// The current record batch being processed. + /// The current record batch being processed outer_record_batch: Option, - /// The current index of the outer record batch being processed. + /// The current index of the outer record batch being processed outer_record_batch_row: usize, - /// The memory reservation used to track memory usage. + /// The memory reservation used to track memory usage memory_reservation: MemoryReservation, + /// The buffer to store the output rows + output_buffer: OutputBuffer, } /// Creates a Cartesian product of two input batches, preserving the order of the right batch, @@ -578,119 +729,147 @@ impl NestedLoopJoinStream { // Get or initialize visited_left_side bitmap if required by join type let visited_left_side = left_data.bitmap(); - // Check is_exhausted before polling the outer_table, such that when the outer table - // does not support `FusedStream`, Self will not poll it again - if self.is_exhausted { - return Poll::Ready(None); - } + loop { + if let Some(batch) = self.output_buffer.next() { + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + return Poll::Ready(Some(Ok(batch))); + } - if self.outer_record_batch.is_none() { - // Get the next outer record batch - match ready!(self.outer_table.poll_next_unpin(cx)) { - Some(Ok(batch)) => { - self.join_metrics.input_batches.add(1); - self.join_metrics.input_rows.add(batch.num_rows()); - self.memory_reservation - .try_grow(batch.get_array_memory_size())?; - self.outer_record_batch = Some(batch); - self.outer_record_batch_row = 0; + // Check is_exhausted before polling the outer_table, such that when the outer table + // does not support `FusedStream`, Self will not poll it again + if self.is_exhausted { + let batch = self.output_buffer.finish()?; + if let Some(batch) = &batch { + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); } - Some(Err(e)) => return Poll::Ready(Some(Err(e))), - None => { - if need_produce_result_in_final(self.join_type) { - // At this stage `visited_left_side` won't be updated, so it's - // safe to report about probe completion. - // - // Setting `is_exhausted` / returning None will prevent from - // multiple calls of `report_probe_completed()` - if !left_data.report_probe_completed() { - self.is_exhausted = true; - return Poll::Ready(None); - }; + return Poll::Ready(Ok(batch).transpose()); + } - // Only setting up timer, input is exhausted - let timer = self.join_metrics.join_time.timer(); - // use the global left bitmap to produce the left indices and right indices - let (left_side, right_side) = - get_final_indices_from_shared_bitmap( - visited_left_side, - self.join_type, + if self.outer_record_batch.is_none() { + // Get the next outer record batch + match self.outer_table.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(batch))) => { + self.join_metrics.input_batches.add(1); + self.join_metrics.input_rows.add(batch.num_rows()); + self.memory_reservation + .try_grow(batch.get_array_memory_size())?; + self.outer_record_batch = Some(batch); + self.outer_record_batch_row = 0; + } + Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))), + Poll::Ready(None) => { + if need_produce_result_in_final(self.join_type) { + // At this stage `visited_left_side` won't be updated, so it's + // safe to report about probe completion. + // + // Setting `is_exhausted` will prevent from multiple calls of + // `report_probe_completed()` + if !left_data.report_probe_completed() { + self.is_exhausted = true; + continue; + }; + + // Only setting up timer, input is exhausted + let timer = self.join_metrics.join_time.timer(); + // use the global left bitmap to produce the left indices and right indices + let (left_side, right_side) = + get_final_indices_from_shared_bitmap( + visited_left_side, + self.join_type, + ); + let empty_right_batch = + RecordBatch::new_empty(self.outer_table.schema()); + // use the left and right indices to produce the batch result + let result = build_batch_from_indices( + &self.schema, + left_data.batch(), + &empty_right_batch, + &left_side, + &right_side, + &self.column_indices, + JoinSide::Left, ); - let empty_right_batch = - RecordBatch::new_empty(self.outer_table.schema()); - // use the left and right indices to produce the batch result - let result = build_batch_from_indices( - &self.schema, - left_data.batch(), - &empty_right_batch, - &left_side, - &right_side, - &self.column_indices, - JoinSide::Left, - ); - self.is_exhausted = true; - - // Recording time & updating output metrics - if let Ok(batch) = &result { - timer.done(); - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); - } + self.is_exhausted = true; - return Poll::Ready(Some(result)); - } else { - // end of the join loop - return Poll::Ready(None); + // Recording time & updating output metrics + match result { + Ok(batch) => { + timer.done(); + self.output_buffer.push(batch)?; + continue; + } + Err(e) => return Poll::Ready(Some(Err(e))), + } + } else { + self.is_exhausted = true; + continue; + } + } + Poll::Pending => { + return match self.output_buffer.flush() { + Ok(Some(batch)) => { + // If there was anything in the output buffer flush it + // so that it can be processed. + self.join_metrics.output_batches.add(1); + self.join_metrics.output_rows.add(batch.num_rows()); + Poll::Ready(Some(Ok(batch))) + } + Ok(None) => Poll::Pending, + Err(err) => Poll::Ready(Some(Err(err))), + }; } } } - } - - debug_assert!(self.outer_record_batch.is_some()); - let right_batch = self.outer_record_batch.as_ref().unwrap(); - let num_rows = match (self.join_type, left_data.batch().num_rows()) { - // An inner join will only produce 1 output row per input row. - (JoinType::Inner, _) | (_, 0) => self.target_batch_size, - // Outer joins can produce as many rows as there are in the build input for - // each row in the batch. - (_, rows) => std::cmp::max(1, self.target_batch_size / rows), - }; - let num_rows = std::cmp::min( - num_rows, - right_batch.num_rows() - self.outer_record_batch_row, - ); - let sliced_right_batch = right_batch.slice(self.outer_record_batch_row, num_rows); - - // Setting up timer - let timer = self.join_metrics.join_time.timer(); - let result = join_left_and_right_batch( - left_data.batch(), - &sliced_right_batch, - self.join_type, - self.filter.as_ref(), - &self.column_indices, - &self.schema, - visited_left_side, - &mut self.indices_cache, - self.right_side_ordered, - ); + debug_assert!(self.outer_record_batch.is_some()); + let right_batch = self.outer_record_batch.as_ref().unwrap(); + let num_rows = match (self.join_type, left_data.batch().num_rows()) { + // An inner join will only produce 1 output row per input row. + (JoinType::Inner, _) | (_, 0) => self.output_buffer.needed_rows(), + // Outer joins can produce as many rows as there are in the build input for + // each row in the batch. + (_, rows) => std::cmp::max(1, self.output_buffer.needed_rows() / rows), + }; + let num_rows = std::cmp::min( + num_rows, + right_batch.num_rows() - self.outer_record_batch_row, + ); + let sliced_right_batch = + right_batch.slice(self.outer_record_batch_row, num_rows); + + // Setting up timer + let timer = self.join_metrics.join_time.timer(); + + let result = join_left_and_right_batch( + left_data.batch(), + &sliced_right_batch, + self.join_type, + self.filter.as_ref(), + &self.column_indices, + &self.schema, + visited_left_side, + &mut self.indices_cache, + self.right_side_ordered, + ); - // Recording time & updating output metrics - if let Ok(batch) = &result { - timer.done(); - self.outer_record_batch_row += num_rows; - if self.outer_record_batch_row >= right_batch.num_rows() { - if let Some(batch) = self.outer_record_batch.take() { - self.memory_reservation - .shrink(batch.get_array_memory_size()) + // Recording time & updating output metrics + match result { + Ok(batch) => { + timer.done(); + self.outer_record_batch_row += num_rows; + if self.outer_record_batch_row >= right_batch.num_rows() { + if let Some(batch) = self.outer_record_batch.take() { + self.memory_reservation + .shrink(batch.get_array_memory_size()) + } + } + self.output_buffer.push(batch)?; } + Err(e) => return Poll::Ready(Some(Err(e))), } - self.join_metrics.output_batches.add(1); - self.join_metrics.output_rows.add(batch.num_rows()); } - - Poll::Ready(Some(result)) } } From e0dadee775015d7ec40a99921dd36af441cc4223 Mon Sep 17 00:00:00 2001 From: Martin Hilton Date: Thu, 26 Sep 2024 12:11:44 +0100 Subject: [PATCH 3/6] nested loop join batch size test Add a test that the nested loop join keeps the output batches smaller than the configured batch size. --- .../src/joins/nested_loop_join.rs | 56 +++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 220bfee48a15..460dbe7915ee 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -959,6 +959,7 @@ mod tests { use arrow_array::Int32Array; use arrow_schema::SortOptions; use datafusion_common::{assert_batches_sorted_eq, assert_contains, ScalarValue}; + use datafusion_execution::config::SessionConfig; use datafusion_execution::runtime_env::RuntimeEnvBuilder; use datafusion_expr::Operator; use datafusion_physical_expr::expressions::{BinaryExpr, Literal}; @@ -1568,4 +1569,59 @@ mod tests { fn columns(schema: &Schema) -> Vec { schema.fields().iter().map(|f| f.name().clone()).collect() } + + #[tokio::test] + async fn test_batch_size() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), + ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), + ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]), + None, + Vec::new(), + ); + let right = build_table( + ("a2", &vec![10, 11, 12, 13, 14]), + ("b2", &vec![20, 21, 22, 23, 24]), + ("c2", &vec![30, 31, 32, 33, 34]), + None, + Vec::new(), + ); + let filter = prepare_join_filter(); + + let join_types = vec![ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + JoinType::LeftSemi, + JoinType::LeftAnti, + JoinType::RightSemi, + JoinType::RightAnti, + ]; + + for join_type in join_types { + let config = SessionConfig::new().with_batch_size(3); + let task_ctx = TaskContext::default().with_session_config(config); + let task_ctx = Arc::new(task_ctx); + + let (_, batches) = multi_partitioned_join_collect( + Arc::clone(&left), + Arc::clone(&right), + &join_type, + Some(filter.clone()), + task_ctx, + ) + .await + .unwrap(); + + let max_rows = batches + .iter() + .map(|batch| batch.num_rows()) + .max() + .unwrap_or(0); + assert!(max_rows <= 3); + } + + Ok(()) + } } From 0fa1977ae65dfaec6cbcd5cb9b3d3fbc7469ba27 Mon Sep 17 00:00:00 2001 From: Martin Hilton Date: Fri, 27 Sep 2024 08:52:01 +0100 Subject: [PATCH 4/6] nested loop join issue 12633 test Add a test that exercises the large batch size issue described in issue #12633. This was a code review request. --- .../src/joins/nested_loop_join.rs | 54 +++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 460dbe7915ee..4047a6f282d8 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -1624,4 +1624,58 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_issue_12633() -> Result<()> { + let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, false)])); + let batches = vec![ + RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![0, 1, 2, 3, 4]))], + )?, + RecordBatch::try_new( + schema.clone(), + vec![Arc::new(Int32Array::from(vec![5, 6, 7, 8, 9]))], + )?, + ]; + let left = MemoryExec::try_new(&[batches.clone()], schema.clone(), None)?; + let right = MemoryExec::try_new(&[batches.clone()], schema.clone(), None)?; + + let (schema, column_indices) = + build_join_schema(schema.as_ref(), schema.as_ref(), &JoinType::Full); + + let filter = JoinFilter::new( + Arc::new(BinaryExpr::new( + Arc::new(Column::new("v", 0)), + Operator::NotEq, + Arc::new(Column::new("v", 1)), + )), + column_indices, + schema, + ); + + let config = SessionConfig::new().with_batch_size(5); + let task_ctx = Arc::new(TaskContext::default().with_session_config(config)); + + let (_, batches) = multi_partitioned_join_collect( + Arc::new(left), + Arc::new(right), + &JoinType::Full, + Some(filter), + task_ctx, + ) + .await + .unwrap(); + + let rows = batches.iter().map(|batch| batch.num_rows()).sum::(); + assert_eq!(rows, 90); + let max_rows = batches + .iter() + .map(|batch| batch.num_rows()) + .max() + .unwrap_or(0); + assert!(max_rows <= 5); + + Ok(()) + } } From 215315088d4dac234b7ed2770eb570ad262f0eea Mon Sep 17 00:00:00 2001 From: Martin Hilton Date: Fri, 27 Sep 2024 09:04:58 +0100 Subject: [PATCH 5/6] fix clippy errors --- datafusion/physical-plan/src/joins/nested_loop_join.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index 4047a6f282d8..a6b347b033e0 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -1630,16 +1630,16 @@ mod tests { let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, false)])); let batches = vec![ RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![Arc::new(Int32Array::from(vec![0, 1, 2, 3, 4]))], )?, RecordBatch::try_new( - schema.clone(), + Arc::clone(&schema), vec![Arc::new(Int32Array::from(vec![5, 6, 7, 8, 9]))], )?, ]; - let left = MemoryExec::try_new(&[batches.clone()], schema.clone(), None)?; - let right = MemoryExec::try_new(&[batches.clone()], schema.clone(), None)?; + let left = MemoryExec::try_new(&[batches.clone()], Arc::clone(&schema), None)?; + let right = MemoryExec::try_new(&[batches.clone()], Arc::clone(&schema), None)?; let (schema, column_indices) = build_join_schema(schema.as_ref(), schema.as_ref(), &JoinType::Full); From 8dd7412fa24edc96cdc939c5507471831a8baf0b Mon Sep 17 00:00:00 2001 From: Martin Hilton Date: Fri, 27 Sep 2024 10:48:57 +0100 Subject: [PATCH 6/6] inner join review suggestion Stop assuming that an INNER join cannot produce more output rows than input. Use the same row count logic for all join types. --- .../src/joins/nested_loop_join.rs | 79 ++++++++++++------- 1 file changed, 49 insertions(+), 30 deletions(-) diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs index a6b347b033e0..2175cf6d0664 100644 --- a/datafusion/physical-plan/src/joins/nested_loop_join.rs +++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs @@ -825,12 +825,13 @@ impl NestedLoopJoinStream { debug_assert!(self.outer_record_batch.is_some()); let right_batch = self.outer_record_batch.as_ref().unwrap(); - let num_rows = match (self.join_type, left_data.batch().num_rows()) { - // An inner join will only produce 1 output row per input row. - (JoinType::Inner, _) | (_, 0) => self.output_buffer.needed_rows(), - // Outer joins can produce as many rows as there are in the build input for - // each row in the batch. - (_, rows) => std::cmp::max(1, self.output_buffer.needed_rows() / rows), + let num_rows = if left_data.batch().num_rows() == 0 { + self.output_buffer.needed_rows() + } else { + std::cmp::max( + 1, + self.output_buffer.needed_rows() / left_data.batch().num_rows(), + ) }; let num_rows = std::cmp::min( num_rows, @@ -1638,12 +1639,22 @@ mod tests { vec![Arc::new(Int32Array::from(vec![5, 6, 7, 8, 9]))], )?, ]; - let left = MemoryExec::try_new(&[batches.clone()], Arc::clone(&schema), None)?; - let right = MemoryExec::try_new(&[batches.clone()], Arc::clone(&schema), None)?; - - let (schema, column_indices) = - build_join_schema(schema.as_ref(), schema.as_ref(), &JoinType::Full); + let left = Arc::new(MemoryExec::try_new( + &[batches.clone()], + Arc::clone(&schema), + None, + )?) as Arc; + let right = Arc::new(MemoryExec::try_new( + &[batches.clone()], + Arc::clone(&schema), + None, + )?) as Arc; + let column_indices = JoinFilter::build_column_indices(vec![0], vec![0]); + let intermediate_schema = Schema::new(vec![ + Field::new("v", DataType::Int32, false), + Field::new("v", DataType::Int32, false), + ]); let filter = JoinFilter::new( Arc::new(BinaryExpr::new( Arc::new(Column::new("v", 0)), @@ -1651,31 +1662,39 @@ mod tests { Arc::new(Column::new("v", 1)), )), column_indices, - schema, + intermediate_schema, ); let config = SessionConfig::new().with_batch_size(5); let task_ctx = Arc::new(TaskContext::default().with_session_config(config)); - let (_, batches) = multi_partitioned_join_collect( - Arc::new(left), - Arc::new(right), - &JoinType::Full, - Some(filter), - task_ctx, - ) - .await - .unwrap(); - - let rows = batches.iter().map(|batch| batch.num_rows()).sum::(); - assert_eq!(rows, 90); - let max_rows = batches - .iter() - .map(|batch| batch.num_rows()) - .max() - .unwrap_or(0); - assert!(max_rows <= 5); + let join_types = vec![ + JoinType::Inner, + JoinType::Left, + JoinType::Right, + JoinType::Full, + ]; + + for join_type in join_types { + let (_, batches) = multi_partitioned_join_collect( + Arc::clone(&left), + Arc::clone(&right), + &join_type, + Some(filter.clone()), + Arc::clone(&task_ctx), + ) + .await + .unwrap(); + let rows = batches.iter().map(|batch| batch.num_rows()).sum::(); + assert_eq!(rows, 90); + let max_rows = batches + .iter() + .map(|batch| batch.num_rows()) + .max() + .unwrap_or(0); + assert!(max_rows <= 5); + } Ok(()) } }