Skip to content

Commit 8ef0d83

Browse files
committed
Extract CoalesceBatchesStream to a struct
1 parent 77311a5 commit 8ef0d83

File tree

1 file changed

+173
-117
lines changed

1 file changed

+173
-117
lines changed

datafusion/physical-plan/src/coalesce_batches.rs

Lines changed: 173 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,12 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
//! CoalesceBatchesExec combines small batches into larger batches for more efficient use of
19-
//! vectorized processing by upstream operators.
18+
//! [`CoalesceBatchesExec`] combines small batches into larger batches.
2019
2120
use std::any::Any;
2221
use std::pin::Pin;
2322
use std::sync::Arc;
24-
use std::task::{Context, Poll};
23+
use std::task::{ready, Context, Poll};
2524

2625
use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet};
2726
use super::{DisplayAs, ExecutionPlanProperties, PlanProperties, Statistics};
@@ -38,8 +37,35 @@ use datafusion_execution::TaskContext;
3837
use futures::stream::{Stream, StreamExt};
3938
use log::trace;
4039

41-
/// CoalesceBatchesExec combines small batches into larger batches for more efficient use of
42-
/// vectorized processing by upstream operators.
40+
/// `CoalesceBatchesExec` combines small batches into larger batches for more
41+
/// efficient use of vectorized processing by upstream operators.
42+
///
43+
/// Generally speaking, larger RecordBatches are more efficient to process than
44+
/// smaller record batches (until the CPU cache is exceeded) because there is
45+
/// fixed processing overhead per batch. This code concatenates multiple small
46+
/// record batches into larger ones to amortize this overhead.
47+
///
48+
/// ```text
49+
/// ┌────────────────────┐
50+
/// │ RecordBatch │
51+
/// │ num_rows = 23 │
52+
/// └────────────────────┘ ┌────────────────────┐
53+
/// │ │
54+
/// ┌────────────────────┐ Coalesce │ │
55+
/// │ │ Batches │ │
56+
/// │ RecordBatch │ │ │
57+
/// │ num_rows = 50 │ ─ ─ ─ ─ ─ ─ ▶ │ │
58+
/// │ │ │ RecordBatch │
59+
/// │ │ │ num_rows = 106 │
60+
/// └────────────────────┘ │ │
61+
/// │ │
62+
/// ┌────────────────────┐ │ │
63+
/// │ │ │ │
64+
/// │ RecordBatch │ │ │
65+
/// │ num_rows = 33 │ └────────────────────┘
66+
/// │ │
67+
/// └────────────────────┘
68+
/// ```
4369
#[derive(Debug)]
4470
pub struct CoalesceBatchesExec {
4571
/// The input plan
@@ -146,10 +172,7 @@ impl ExecutionPlan for CoalesceBatchesExec {
146172
) -> Result<SendableRecordBatchStream> {
147173
Ok(Box::pin(CoalesceBatchesStream {
148174
input: self.input.execute(partition, context)?,
149-
schema: self.input.schema(),
150-
target_batch_size: self.target_batch_size,
151-
buffer: Vec::new(),
152-
buffered_rows: 0,
175+
coalescer: BatchCoalescer::new(self.input.schema(), self.target_batch_size),
153176
is_closed: false,
154177
baseline_metrics: BaselineMetrics::new(&self.metrics, partition),
155178
}))
@@ -164,17 +187,12 @@ impl ExecutionPlan for CoalesceBatchesExec {
164187
}
165188
}
166189

190+
/// Stream for [`CoalesceBatchesExec`]. See [`CoalesceBatchesExec`] for more details.
167191
struct CoalesceBatchesStream {
168192
/// The input plan
169193
input: SendableRecordBatchStream,
170-
/// The input schema
171-
schema: SchemaRef,
172-
/// Minimum number of rows for coalesces batches
173-
target_batch_size: usize,
174-
/// Buffered batches
175-
buffer: Vec<RecordBatch>,
176-
/// Buffered row count
177-
buffered_rows: usize,
194+
/// Buffer for combining batches
195+
coalescer: BatchCoalescer,
178196
/// Whether the stream has finished returning all of its data or not
179197
is_closed: bool,
180198
/// Execution metrics
@@ -213,66 +231,35 @@ impl CoalesceBatchesStream {
213231
let input_batch = self.input.poll_next_unpin(cx);
214232
// records time on drop
215233
let _timer = cloned_time.timer();
216-
match input_batch {
217-
Poll::Ready(x) => match x {
218-
Some(Ok(batch)) => {
219-
if batch.num_rows() >= self.target_batch_size
220-
&& self.buffer.is_empty()
221-
{
222-
return Poll::Ready(Some(Ok(batch)));
223-
} else if batch.num_rows() == 0 {
224-
// discard empty batches
225-
} else {
226-
// add to the buffered batches
227-
self.buffered_rows += batch.num_rows();
228-
self.buffer.push(batch);
229-
// check to see if we have enough batches yet
230-
if self.buffered_rows >= self.target_batch_size {
231-
// combine the batches and return
232-
let batch = concat_batches(
233-
&self.schema,
234-
&self.buffer,
235-
self.buffered_rows,
236-
)?;
237-
// reset buffer state
238-
self.buffer.clear();
239-
self.buffered_rows = 0;
240-
// return batch
241-
return Poll::Ready(Some(Ok(batch)));
242-
}
243-
}
244-
}
245-
None => {
246-
self.is_closed = true;
247-
// we have reached the end of the input stream but there could still
248-
// be buffered batches
249-
if self.buffer.is_empty() {
250-
return Poll::Ready(None);
251-
} else {
252-
// combine the batches and return
253-
let batch = concat_batches(
254-
&self.schema,
255-
&self.buffer,
256-
self.buffered_rows,
257-
)?;
258-
// reset buffer state
259-
self.buffer.clear();
260-
self.buffered_rows = 0;
261-
// return batch
262-
return Poll::Ready(Some(Ok(batch)));
263-
}
234+
match ready!(input_batch) {
235+
Some(result) => {
236+
let Ok(input_batch) = result else {
237+
return Poll::Ready(Some(result)); // pass back error
238+
};
239+
// Buffer the batch and either get more input if not enough
240+
// rows yet or output
241+
match self.coalescer.push_batch(input_batch) {
242+
Ok(None) => continue,
243+
res => return Poll::Ready(res.transpose()),
264244
}
265-
other => return Poll::Ready(other),
266-
},
267-
Poll::Pending => return Poll::Pending,
245+
}
246+
None => {
247+
self.is_closed = true;
248+
// we have reached the end of the input stream but there could still
249+
// be buffered batches
250+
return match self.coalescer.finish() {
251+
Ok(None) => Poll::Ready(None),
252+
res => Poll::Ready(res.transpose()),
253+
};
254+
}
268255
}
269256
}
270257
}
271258
}
272259

273260
impl RecordBatchStream for CoalesceBatchesStream {
274261
fn schema(&self) -> SchemaRef {
275-
Arc::clone(&self.schema)
262+
self.coalescer.schema()
276263
}
277264
}
278265

@@ -290,26 +277,106 @@ pub fn concat_batches(
290277
arrow::compute::concat_batches(schema, batches)
291278
}
292279

280+
/// Concatenate multiple record batches into larger batches
281+
///
282+
/// See [`CoalesceBatchesExec`] for more details.
283+
///
284+
/// Notes:
285+
///
286+
/// 1. The output rows is the same order as the input rows
287+
///
288+
/// 2. The output is a sequence of batches, with all but the last being at least
289+
/// `target_batch_size` rows.
290+
///
291+
/// 3. Eventually this may also be able to handle other optimizations such as a
292+
/// combined filter/coalesce operation.
293+
#[derive(Debug)]
294+
struct BatchCoalescer {
295+
/// The input schema
296+
schema: SchemaRef,
297+
/// Minimum number of rows for coalesces batches
298+
target_batch_size: usize,
299+
/// Buffered batches
300+
buffer: Vec<RecordBatch>,
301+
/// Buffered row count
302+
buffered_rows: usize,
303+
}
304+
305+
impl BatchCoalescer {
306+
/// Create a new BatchCoalescer that produces batches of at least `target_batch_size` rows
307+
fn new(schema: SchemaRef, target_batch_size: usize) -> Self {
308+
Self {
309+
schema,
310+
target_batch_size,
311+
buffer: vec![],
312+
buffered_rows: 0,
313+
}
314+
}
315+
316+
/// Return the schema of the output batches
317+
fn schema(&self) -> SchemaRef {
318+
Arc::clone(&self.schema)
319+
}
320+
321+
/// Add a batch to the coalescer, returning a batch if the target batch size is reached
322+
fn push_batch(&mut self, batch: RecordBatch) -> Result<Option<RecordBatch>> {
323+
if batch.num_rows() >= self.target_batch_size && self.buffer.is_empty() {
324+
return Ok(Some(batch));
325+
}
326+
// discard empty batches
327+
if batch.num_rows() == 0 {
328+
return Ok(None);
329+
}
330+
// add to the buffered batches
331+
self.buffered_rows += batch.num_rows();
332+
self.buffer.push(batch);
333+
// check to see if we have enough batches yet
334+
let batch = if self.buffered_rows >= self.target_batch_size {
335+
// combine the batches and return
336+
let batch = concat_batches(&self.schema, &self.buffer, self.buffered_rows)?;
337+
// reset buffer state
338+
self.buffer.clear();
339+
self.buffered_rows = 0;
340+
// return batch
341+
Some(batch)
342+
} else {
343+
None
344+
};
345+
Ok(batch)
346+
}
347+
348+
/// Finish the coalescing process, returning all buffered data as a final,
349+
/// single batch, if any
350+
fn finish(&mut self) -> Result<Option<RecordBatch>> {
351+
if self.buffer.is_empty() {
352+
Ok(None)
353+
} else {
354+
// combine the batches and return
355+
let batch = concat_batches(&self.schema, &self.buffer, self.buffered_rows)?;
356+
// reset buffer state
357+
self.buffer.clear();
358+
self.buffered_rows = 0;
359+
// return batch
360+
Ok(Some(batch))
361+
}
362+
}
363+
}
364+
293365
#[cfg(test)]
294366
mod tests {
295367
use super::*;
296-
use crate::{memory::MemoryExec, repartition::RepartitionExec, Partitioning};
297-
298368
use arrow::datatypes::{DataType, Field, Schema};
299369
use arrow_array::UInt32Array;
300370

301371
#[tokio::test(flavor = "multi_thread")]
302372
async fn test_concat_batches() -> Result<()> {
303-
let schema = test_schema();
304-
let partition = create_vec_batches(&schema, 10);
305-
let partitions = vec![partition];
306-
307-
let output_partitions = coalesce_batches(&schema, partitions, 21).await?;
308-
assert_eq!(1, output_partitions.len());
373+
let Scenario { schema, batch } = uint32_scenario();
309374

310375
// input is 10 batches x 8 rows (80 rows)
376+
let input = std::iter::repeat(batch).take(10);
377+
311378
// expected output is batches of at least 20 rows (except for the final batch)
312-
let batches = &output_partitions[0];
379+
let batches = do_coalesce_batches(&schema, input, 21);
313380
assert_eq!(4, batches.len());
314381
assert_eq!(24, batches[0].num_rows());
315382
assert_eq!(24, batches[1].num_rows());
@@ -319,54 +386,43 @@ mod tests {
319386
Ok(())
320387
}
321388

322-
fn test_schema() -> Arc<Schema> {
323-
Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]))
324-
}
325-
326-
async fn coalesce_batches(
389+
// Coalesce the batches with a BatchCoalescer function with the given input
390+
// and target batch size returning the resulting batches
391+
fn do_coalesce_batches(
327392
schema: &SchemaRef,
328-
input_partitions: Vec<Vec<RecordBatch>>,
393+
input: impl IntoIterator<Item = RecordBatch>,
329394
target_batch_size: usize,
330-
) -> Result<Vec<Vec<RecordBatch>>> {
395+
) -> Vec<RecordBatch> {
331396
// create physical plan
332-
let exec = MemoryExec::try_new(&input_partitions, Arc::clone(schema), None)?;
333-
let exec =
334-
RepartitionExec::try_new(Arc::new(exec), Partitioning::RoundRobinBatch(1))?;
335-
let exec: Arc<dyn ExecutionPlan> =
336-
Arc::new(CoalesceBatchesExec::new(Arc::new(exec), target_batch_size));
337-
338-
// execute and collect results
339-
let output_partition_count = exec.output_partitioning().partition_count();
340-
let mut output_partitions = Vec::with_capacity(output_partition_count);
341-
for i in 0..output_partition_count {
342-
// execute this *output* partition and collect all batches
343-
let task_ctx = Arc::new(TaskContext::default());
344-
let mut stream = exec.execute(i, Arc::clone(&task_ctx))?;
345-
let mut batches = vec![];
346-
while let Some(result) = stream.next().await {
347-
batches.push(result?);
348-
}
349-
output_partitions.push(batches);
397+
let mut coalescer = BatchCoalescer::new(Arc::clone(schema), target_batch_size);
398+
let mut output_batches: Vec<_> = input
399+
.into_iter()
400+
.filter_map(|batch| coalescer.push_batch(batch).unwrap())
401+
.collect();
402+
if let Some(batch) = coalescer.finish().unwrap() {
403+
output_batches.push(batch);
350404
}
351-
Ok(output_partitions)
405+
output_batches
352406
}
353407

354-
/// Create vector batches
355-
fn create_vec_batches(schema: &Schema, n: usize) -> Vec<RecordBatch> {
356-
let batch = create_batch(schema);
357-
let mut vec = Vec::with_capacity(n);
358-
for _ in 0..n {
359-
vec.push(batch.clone());
360-
}
361-
vec
408+
/// Test scenario
409+
#[derive(Debug)]
410+
struct Scenario {
411+
schema: Arc<Schema>,
412+
batch: RecordBatch,
362413
}
363414

364-
/// Create batch
365-
fn create_batch(schema: &Schema) -> RecordBatch {
366-
RecordBatch::try_new(
367-
Arc::new(schema.clone()),
415+
/// a batch of 8 rows of UInt32
416+
fn uint32_scenario() -> Scenario {
417+
let schema =
418+
Arc::new(Schema::new(vec![Field::new("c0", DataType::UInt32, false)]));
419+
420+
let batch = RecordBatch::try_new(
421+
Arc::clone(&schema),
368422
vec![Arc::new(UInt32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8]))],
369423
)
370-
.unwrap()
424+
.unwrap();
425+
426+
Scenario { schema, batch }
371427
}
372428
}

0 commit comments

Comments
 (0)