diff --git a/datafusion/core/tests/memory_limit/mod.rs b/datafusion/core/tests/memory_limit/mod.rs
index 69ef6058a2f6..1bf50ea5751b 100644
--- a/datafusion/core/tests/memory_limit/mod.rs
+++ b/datafusion/core/tests/memory_limit/mod.rs
@@ -141,7 +141,10 @@ async fn join_by_expression() {
     TestCase::new()
         .with_query("select t1.* from t t1 JOIN t t2 ON t1.service != t2.service")
         .with_expected_errors(vec![
-            "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: NestedLoopJoinLoad[0]",
+            "Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:",
+            "NestedLoopJoinLoad[0] consumed 0 bytes",
+            "NestedLoopJoin[0] consumed 0 bytes",
+            "NestedLoopJoinBuffer[0] consumed 0 bytes",
         ])
         .with_memory_limit(1_000)
         .run()
diff --git a/datafusion/physical-plan/src/joins/nested_loop_join.rs b/datafusion/physical-plan/src/joins/nested_loop_join.rs
index 029003374acc..2175cf6d0664 100644
--- a/datafusion/physical-plan/src/joins/nested_loop_join.rs
+++ b/datafusion/physical-plan/src/joins/nested_loop_join.rs
@@ -20,6 +20,7 @@
 //! determined by the [`JoinType`].
 
 use std::any::Any;
+use std::collections::VecDeque;
 use std::fmt::Formatter;
 use std::sync::atomic::{AtomicUsize, Ordering};
 use std::sync::Arc;
@@ -100,6 +101,144 @@ impl JoinLeftData {
     }
 }
 
+/// Buffer to store the output batches. The buffer is used to
+/// keep the size of output batches around the target batch size.
+/// Output batches are either combined or split to ensure that the
+/// batches produced are as close to the target size as possible
+/// without being any bigger.
+struct OutputBuffer {
+    /// Whether the buffer has output any record batches,
+    /// used to ensure that at least one record batch is
+    /// produced even if no joins happen.
+    have_output: bool,
+    /// Buffered output batches
+    batches: VecDeque<RecordBatch>,
+    /// Number of rows in the buffer
+    total_rows: usize,
+    /// Batch that is ready to be output, along with its array memory size
+    next_batch: Option<(RecordBatch, usize)>,
+    /// Schema of the output batches
+    schema: SchemaRef,
+    /// Target output batch size
+    target_batch_size: usize,
+    /// Memory reservation for tracking batch memory use
+    memory_reservation: MemoryReservation,
+}
+
+impl OutputBuffer {
+    fn new(
+        schema: SchemaRef,
+        target_batch_size: usize,
+        memory_reservation: MemoryReservation,
+    ) -> Self {
+        Self {
+            have_output: false,
+            batches: VecDeque::new(),
+            total_rows: 0,
+            next_batch: None,
+            schema,
+            target_batch_size,
+            memory_reservation,
+        }
+    }
+
+    /// Push a new batch into the output buffer.
+    fn push(&mut self, batch: RecordBatch) -> Result<()> {
+        assert!(self.next_batch.is_none());
+        assert!(self.total_rows < self.target_batch_size);
+
+        self.memory_reservation
+            .try_grow(batch.get_array_memory_size())?;
+        if self.total_rows + batch.num_rows() < self.target_batch_size {
+            // Not enough rows to fill a batch.
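+            // Buffer the incoming rows; a full batch is only emitted once at
+            // least `target_batch_size` rows have accumulated.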
+            self.total_rows += batch.num_rows();
+            self.batches.push_back(batch);
+            return Ok(());
+        }
+        let mut batches = std::mem::take(&mut self.batches);
+        if self.total_rows + batch.num_rows() == self.target_batch_size {
+            self.total_rows = 0;
+            batches.push_back(batch);
+        } else {
+            self.total_rows = batch.num_rows();
+            self.batches.push_back(batch);
+        }
+        let mut total_array_memory_size = 0;
+        let batch = concat_batches(
+            &self.schema,
+            batches.iter().inspect(|batch| {
+                total_array_memory_size += batch.get_array_memory_size()
+            }),
+        )?;
+        self.next_batch = Some((batch, total_array_memory_size));
+        Ok(())
+    }
+
+    /// Return a batch if there are enough rows in the buffer.
+    fn next(&mut self) -> Option<RecordBatch> {
+        if let Some((batch, total_array_memory_size)) = self.next_batch.take() {
+            self.have_output = true;
+            self.memory_reservation.shrink(total_array_memory_size);
+            return Some(batch);
+        }
+        if self.total_rows < self.target_batch_size {
+            return None;
+        }
+        if self.total_rows == self.target_batch_size {
+            self.total_rows = 0;
+            self.have_output = true;
+            return self.batches.pop_front().inspect(|batch| {
+                self.memory_reservation
+                    .shrink(batch.get_array_memory_size())
+            });
+        }
+        let batch = self.batches[0].slice(0, self.target_batch_size);
+        let rest = self.batches[0].slice(
+            self.target_batch_size,
+            self.total_rows - self.target_batch_size,
+        );
+        self.have_output = true;
+        self.total_rows -= self.target_batch_size;
+        self.batches[0] = rest;
+        Some(batch)
+    }
+
+    /// The number of rows needed to fill the next batch.
+    fn needed_rows(&self) -> usize {
+        if self.total_rows < self.target_batch_size {
+            self.target_batch_size - self.total_rows
+        } else {
+            0
+        }
+    }
+
+    /// Flush any remaining rows into a batch even if it's not yet full.
+    fn flush(&mut self) -> Result<Option<RecordBatch>> {
+        if self.total_rows == 0 {
+            return Ok(None);
+        }
+        let batches = std::mem::take(&mut self.batches);
+        let total_array_memory_size =
+            batches.iter().map(|b| b.get_array_memory_size()).sum();
+        self.total_rows = 0;
+        self.memory_reservation.shrink(total_array_memory_size);
+        self.have_output = true;
+        Ok(Some(concat_batches(&self.schema, batches.iter())?))
+    }
+
+    /// Finish the output buffer, returning the last batch. If the
+    /// buffer has never output a batch then create an empty batch.
+    fn finish(&mut self) -> Result<Option<RecordBatch>> {
+        Ok(self.flush()?.or_else(|| {
+            if self.have_output {
+                None
+            } else {
+                self.have_output = true;
+                Some(RecordBatch::new_empty(Arc::clone(&self.schema)))
+            }
+        }))
+    }
+}
+
 /// NestedLoopJoinExec is build-probe join operator, whose main task is to
 /// perform joins without any equijoin conditions in `ON` clause.
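+///
+/// Output batches are staged in an `OutputBuffer` and re-chunked so that each
+/// emitted batch contains close to, and never more than, the configured
+/// `batch_size` rows.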
 ///
@@ -334,6 +473,18 @@ impl ExecutionPlan for NestedLoopJoinExec {
             MemoryConsumer::new(format!("NestedLoopJoinLoad[{partition}]"))
                 .register(context.memory_pool());
 
+        // Initialization of the reservation for processing the outer table
+        let memory_reservation =
+            MemoryConsumer::new(format!("NestedLoopJoin[{partition}]"))
+                .register(context.memory_pool());
+
+        // Initialization of the reservation for buffering output
+        let buffer_memory_reservation =
+            MemoryConsumer::new(format!("NestedLoopJoinBuffer[{partition}]"))
+                .register(context.memory_pool());
+
+        let target_batch_size = context.session_config().batch_size();
+
         let inner_table = self.inner_table.once(|| {
             collect_left_input(
                 Arc::clone(&self.left),
@@ -344,7 +495,6 @@ impl ExecutionPlan for NestedLoopJoinExec {
                 self.right().output_partitioning().partition_count(),
             )
         });
-
         let outer_table = self.right.execute(partition, context)?;
 
         let indices_cache = (UInt64Array::new_null(0), UInt32Array::new_null(0));
@@ -352,6 +502,13 @@ impl ExecutionPlan for NestedLoopJoinExec {
         // Right side has an order and it is maintained during operation.
         let right_side_ordered =
             self.maintains_input_order()[1] && self.right.output_ordering().is_some();
+
+        let output_buffer = OutputBuffer::new(
+            Arc::clone(&self.schema),
+            target_batch_size,
+            buffer_memory_reservation,
+        );
+
         Ok(Box::pin(NestedLoopJoinStream {
             schema: Arc::clone(&self.schema),
             filter: self.filter.clone(),
@@ -363,6 +520,10 @@ impl ExecutionPlan for NestedLoopJoinExec {
             join_metrics,
             indices_cache,
             right_side_ordered,
+            outer_record_batch: None,
+            outer_record_batch_row: 0,
+            memory_reservation,
+            output_buffer,
         }))
     }
 
@@ -466,6 +627,14 @@ struct NestedLoopJoinStream {
     indices_cache: (UInt64Array, UInt32Array),
     /// Whether the right side is ordered
     right_side_ordered: bool,
+    /// The current outer record batch being processed
+    outer_record_batch: Option<RecordBatch>,
+    /// The current row index within the outer record batch being processed
+    outer_record_batch_row: usize,
+    /// The memory reservation used to track the memory of the outer record batch
+    memory_reservation: MemoryReservation,
+    /// The buffer to store the output rows
+    output_buffer: OutputBuffer,
 }
 
 /// Creates a Cartesian product of two input batches, preserving the order of the right batch,
@@ -560,91 +729,148 @@ impl NestedLoopJoinStream {
         // Get or initialize visited_left_side bitmap if required by join type
         let visited_left_side = left_data.bitmap();
 
-        // Check is_exhausted before polling the outer_table, such that when the outer table
-        // does not support `FusedStream`, Self will not poll it again
-        if self.is_exhausted {
-            return Poll::Ready(None);
-        }
+        loop {
+            if let Some(batch) = self.output_buffer.next() {
+                self.join_metrics.output_batches.add(1);
+                self.join_metrics.output_rows.add(batch.num_rows());
+                return Poll::Ready(Some(Ok(batch)));
+            }
 
-        self.outer_table
-            .poll_next_unpin(cx)
-            .map(|maybe_batch| match maybe_batch {
-                Some(Ok(right_batch)) => {
-                    // Setting up timer & updating input metrics
-                    self.join_metrics.input_batches.add(1);
-                    self.join_metrics.input_rows.add(right_batch.num_rows());
-                    let timer = self.join_metrics.join_time.timer();
-
-                    let result = join_left_and_right_batch(
-                        left_data.batch(),
-                        &right_batch,
-                        self.join_type,
-                        self.filter.as_ref(),
-                        &self.column_indices,
-                        &self.schema,
-                        visited_left_side,
-                        &mut self.indices_cache,
-                        self.right_side_ordered,
-                    );
-
-                    // Recording time & updating output metrics
-                    if let Ok(batch) = &result {
-                        timer.done();
-                        self.join_metrics.output_batches.add(1);
-                        self.join_metrics.output_rows.add(batch.num_rows());
+            // Check is_exhausted before polling the outer_table, such that when
+            // the outer table does not support `FusedStream`, it will not be
+            // polled again
+            if self.is_exhausted {
+                let batch = self.output_buffer.finish()?;
+                if let Some(batch) = &batch {
+                    self.join_metrics.output_batches.add(1);
+                    self.join_metrics.output_rows.add(batch.num_rows());
+                }
+                return Poll::Ready(Ok(batch).transpose());
+            }
+
+            if self.outer_record_batch.is_none() {
+                // Get the next outer record batch
+                match self.outer_table.poll_next_unpin(cx) {
+                    Poll::Ready(Some(Ok(batch))) => {
+                        self.join_metrics.input_batches.add(1);
+                        self.join_metrics.input_rows.add(batch.num_rows());
+                        self.memory_reservation
+                            .try_grow(batch.get_array_memory_size())?;
+                        self.outer_record_batch = Some(batch);
+                        self.outer_record_batch_row = 0;
                     }
+                    Poll::Ready(Some(Err(e))) => return Poll::Ready(Some(Err(e))),
+                    Poll::Ready(None) => {
+                        if need_produce_result_in_final(self.join_type) {
+                            // At this stage `visited_left_side` won't be updated, so it's
+                            // safe to report about probe completion.
+                            //
+                            // Setting `is_exhausted` will prevent multiple calls of
+                            // `report_probe_completed()`
+                            if !left_data.report_probe_completed() {
+                                self.is_exhausted = true;
+                                continue;
+                            };
+
+                            // Only setting up timer, input is exhausted
+                            let timer = self.join_metrics.join_time.timer();
+                            // use the global left bitmap to produce the left indices and right indices
+                            let (left_side, right_side) =
+                                get_final_indices_from_shared_bitmap(
+                                    visited_left_side,
+                                    self.join_type,
+                                );
+                            let empty_right_batch =
+                                RecordBatch::new_empty(self.outer_table.schema());
+                            // use the left and right indices to produce the batch result
+                            let result = build_batch_from_indices(
+                                &self.schema,
+                                left_data.batch(),
+                                &empty_right_batch,
+                                &left_side,
+                                &right_side,
+                                &self.column_indices,
+                                JoinSide::Left,
+                            );
+                            self.is_exhausted = true;
 
-                    Some(result)
-                }
-                Some(err) => Some(err),
-                None => {
-                    if need_produce_result_in_final(self.join_type) {
-                        // At this stage `visited_left_side` won't be updated, so it's
-                        // safe to report about probe completion.
-                        //
-                        // Setting `is_exhausted` / returning None will prevent from
-                        // multiple calls of `report_probe_completed()`
-                        if !left_data.report_probe_completed() {
+                            // Recording time & updating output metrics
+                            match result {
+                                Ok(batch) => {
+                                    timer.done();
+                                    self.output_buffer.push(batch)?;
+                                    continue;
+                                }
+                                Err(e) => return Poll::Ready(Some(Err(e))),
+                            }
+                        } else {
                             self.is_exhausted = true;
-                            return None;
+                            continue;
+                        }
+                    }
+                    Poll::Pending => {
+                        return match self.output_buffer.flush() {
+                            Ok(Some(batch)) => {
+                                // If there was anything in the output buffer, flush
+                                // it so that it can be processed while the input
+                                // is pending.
+                                self.join_metrics.output_batches.add(1);
+                                self.join_metrics.output_rows.add(batch.num_rows());
+                                Poll::Ready(Some(Ok(batch)))
+                            }
+                            Ok(None) => Poll::Pending,
+                            Err(err) => Poll::Ready(Some(Err(err))),
                         };
+                    }
+                }
+            }
 
-                        // Only setting up timer, input is exhausted
-                        let timer = self.join_metrics.join_time.timer();
-                        // use the global left bitmap to produce the left indices and right indices
-                        let (left_side, right_side) =
-                            get_final_indices_from_shared_bitmap(
-                                visited_left_side,
-                                self.join_type,
-                            );
-                        let empty_right_batch =
-                            RecordBatch::new_empty(self.outer_table.schema());
-                        // use the left and right indices to produce the batch result
-                        let result = build_batch_from_indices(
-                            &self.schema,
-                            left_data.batch(),
-                            &empty_right_batch,
-                            &left_side,
-                            &right_side,
-                            &self.column_indices,
-                            JoinSide::Left,
-                        );
-                        self.is_exhausted = true;
-
-                        // Recording time & updating output metrics
-                        if let Ok(batch) = &result {
-                            timer.done();
-                            self.join_metrics.output_batches.add(1);
-                            self.join_metrics.output_rows.add(batch.num_rows());
-                        }
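+            // Each outer row can match every build-side row, so slice the outer
+            // batch to roughly `needed_rows / left_rows` rows per iteration to
+            // keep the joined output close to the target batch size.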
+            debug_assert!(self.outer_record_batch.is_some());
+            let right_batch = self.outer_record_batch.as_ref().unwrap();
+            let num_rows = if left_data.batch().num_rows() == 0 {
+                self.output_buffer.needed_rows()
+            } else {
+                std::cmp::max(
+                    1,
+                    self.output_buffer.needed_rows() / left_data.batch().num_rows(),
+                )
+            };
+            let num_rows = std::cmp::min(
+                num_rows,
+                right_batch.num_rows() - self.outer_record_batch_row,
+            );
+            let sliced_right_batch =
+                right_batch.slice(self.outer_record_batch_row, num_rows);
+
+            // Setting up timer
+            let timer = self.join_metrics.join_time.timer();
+
+            let result = join_left_and_right_batch(
+                left_data.batch(),
+                &sliced_right_batch,
+                self.join_type,
+                self.filter.as_ref(),
+                &self.column_indices,
+                &self.schema,
+                visited_left_side,
+                &mut self.indices_cache,
+                self.right_side_ordered,
+            );
 
-                    Some(result)
-                } else {
-                    // end of the join loop
-                    None
+            // Recording time & updating output metrics
+            match result {
+                Ok(batch) => {
+                    timer.done();
+                    self.outer_record_batch_row += num_rows;
+                    if self.outer_record_batch_row >= right_batch.num_rows() {
+                        if let Some(batch) = self.outer_record_batch.take() {
+                            self.memory_reservation
+                                .shrink(batch.get_array_memory_size())
+                        }
                     }
+                    self.output_buffer.push(batch)?;
                 }
-            })
+                Err(e) => return Poll::Ready(Some(Err(e))),
+            }
+        }
     }
 }
 
@@ -734,6 +960,7 @@ mod tests {
     use arrow_array::Int32Array;
     use arrow_schema::SortOptions;
     use datafusion_common::{assert_batches_sorted_eq, assert_contains, ScalarValue};
+    use datafusion_execution::config::SessionConfig;
     use datafusion_execution::runtime_env::RuntimeEnvBuilder;
     use datafusion_expr::Operator;
     use datafusion_physical_expr::expressions::{BinaryExpr, Literal};
@@ -1172,8 +1399,10 @@ mod tests {
 
         assert_contains!(
             err.to_string(),
-            "External error: Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as: NestedLoopJoinLoad[0]"
+            "External error: Resources exhausted: Additional allocation failed with top memory consumers (across reservations) as:"
         );
+        assert_contains!(err.to_string(), "NestedLoopJoinLoad[0] consumed 0 bytes");
+        assert_contains!(err.to_string(), "NestedLoopJoin[0] consumed 0 bytes");
     }
 
     Ok(())
@@ -1341,4 +1570,131 @@ mod tests {
     fn columns(schema: &Schema) -> Vec<String> {
         schema.fields().iter().map(|f| f.name().clone()).collect()
     }
+
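+    // Verifies that the join respects the configured batch size: with
+    // `batch_size` set to 3, no emitted batch may contain more than 3 rows.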
+    #[tokio::test]
+    async fn test_batch_size() -> Result<()> {
+        let left = build_table(
+            ("a1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
+            ("b1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
+            ("c1", &vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
+            None,
+            Vec::new(),
+        );
+        let right = build_table(
+            ("a2", &vec![10, 11, 12, 13, 14]),
+            ("b2", &vec![20, 21, 22, 23, 24]),
+            ("c2", &vec![30, 31, 32, 33, 34]),
+            None,
+            Vec::new(),
+        );
+        let filter = prepare_join_filter();
+
+        let join_types = vec![
+            JoinType::Inner,
+            JoinType::Left,
+            JoinType::Right,
+            JoinType::Full,
+            JoinType::LeftSemi,
+            JoinType::LeftAnti,
+            JoinType::RightSemi,
+            JoinType::RightAnti,
+        ];
+
+        for join_type in join_types {
+            let config = SessionConfig::new().with_batch_size(3);
+            let task_ctx = TaskContext::default().with_session_config(config);
+            let task_ctx = Arc::new(task_ctx);
+
+            let (_, batches) = multi_partitioned_join_collect(
+                Arc::clone(&left),
+                Arc::clone(&right),
+                &join_type,
+                Some(filter.clone()),
+                task_ctx,
+            )
+            .await
+            .unwrap();
+
+            let max_rows = batches
+                .iter()
+                .map(|batch| batch.num_rows())
+                .max()
+                .unwrap_or(0);
+            assert!(max_rows <= 3);
+        }
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_issue_12633() -> Result<()> {
+        let schema = Arc::new(Schema::new(vec![Field::new("v", DataType::Int32, false)]));
+        let batches = vec![
+            RecordBatch::try_new(
+                Arc::clone(&schema),
+                vec![Arc::new(Int32Array::from(vec![0, 1, 2, 3, 4]))],
+            )?,
+            RecordBatch::try_new(
+                Arc::clone(&schema),
+                vec![Arc::new(Int32Array::from(vec![5, 6, 7, 8, 9]))],
+            )?,
+        ];
+        let left = Arc::new(MemoryExec::try_new(
+            &[batches.clone()],
+            Arc::clone(&schema),
+            None,
+        )?) as Arc<dyn ExecutionPlan>;
+        let right = Arc::new(MemoryExec::try_new(
+            &[batches.clone()],
+            Arc::clone(&schema),
+            None,
+        )?) as Arc<dyn ExecutionPlan>;
+
+        let column_indices = JoinFilter::build_column_indices(vec![0], vec![0]);
+        let intermediate_schema = Schema::new(vec![
+            Field::new("v", DataType::Int32, false),
+            Field::new("v", DataType::Int32, false),
+        ]);
+        let filter = JoinFilter::new(
+            Arc::new(BinaryExpr::new(
+                Arc::new(Column::new("v", 0)),
+                Operator::NotEq,
+                Arc::new(Column::new("v", 1)),
+            )),
+            column_indices,
+            intermediate_schema,
+        );
+
+        let config = SessionConfig::new().with_batch_size(5);
+        let task_ctx = Arc::new(TaskContext::default().with_session_config(config));
+
+        let join_types = vec![
+            JoinType::Inner,
+            JoinType::Left,
+            JoinType::Right,
+            JoinType::Full,
+        ];
+
+        for join_type in join_types {
+            let (_, batches) = multi_partitioned_join_collect(
+                Arc::clone(&left),
+                Arc::clone(&right),
+                &join_type,
+                Some(filter.clone()),
+                Arc::clone(&task_ctx),
+            )
+            .await
+            .unwrap();
+
+            let rows = batches.iter().map(|batch| batch.num_rows()).sum::<usize>();
+            assert_eq!(rows, 90);
+            let max_rows = batches
+                .iter()
+                .map(|batch| batch.num_rows())
+                .max()
+                .unwrap_or(0);
+            assert!(max_rows <= 5);
+        }
+        Ok(())
+    }
+}
diff --git a/datafusion/sqllogictest/test_files/joins.slt b/datafusion/sqllogictest/test_files/joins.slt
index 64d5f6c7b88f..36d006dbe768 100644
--- a/datafusion/sqllogictest/test_files/joins.slt
+++ b/datafusion/sqllogictest/test_files/joins.slt
@@ -2130,7 +2130,7 @@ physical_plan
 09)------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
 10)--------MemoryExec: partitions=1, partition_sizes=[1]
 
-query II
+query II rowsort
 SELECT join_t1.t1_id, join_t2.t2_id
 FROM (select t1_id from join_t1 where join_t1.t1_id > 22) as join_t1
 RIGHT JOIN (select t2_id from join_t2 where join_t2.t2_id > 11) as join_t2
@@ -4187,4 +4187,4 @@ physical_plan
 02)--HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(b@1, y@1)], filter=a@0 < x@1
 03)----MemoryExec: partitions=1, partition_sizes=[0]
 04)----SortExec: expr=[x@0 ASC NULLS LAST], preserve_partitioning=[false]
-05)------MemoryExec: partitions=1, partition_sizes=[0]
\ No newline at end of file
+05)------MemoryExec: partitions=1, partition_sizes=[0]