Commit e1a8e9d

refactor: Use SpillManager for all spilling scenarios (#15405)
* Use SpillManager in all spilling scenarios
* resolve conflict
* fix ci format
1 parent 51b6a65 commit e1a8e9d

File tree

3 files changed: +76 -109 lines changed

  datafusion/physical-plan/src/aggregates/row_hash.rs
  datafusion/physical-plan/src/joins/sort_merge_join.rs
  datafusion/physical-plan/src/spill/mod.rs
datafusion/physical-plan/src/aggregates/row_hash.rs

Lines changed: 29 additions & 30 deletions

@@ -30,7 +30,7 @@ use crate::aggregates::{
 use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput};
 use crate::sorts::sort::sort_batch;
 use crate::sorts::streaming_merge::StreamingMergeBuilder;
-use crate::spill::{read_spill_as_stream, spill_record_batch_by_size};
+use crate::spill::spill_manager::SpillManager;
 use crate::stream::RecordBatchStreamAdapter;
 use crate::{aggregates, metrics, ExecutionPlan, PhysicalExpr};
 use crate::{RecordBatchStream, SendableRecordBatchStream};
@@ -42,7 +42,6 @@ use datafusion_common::{internal_err, DataFusionError, Result};
 use datafusion_execution::disk_manager::RefCountedTempFile;
 use datafusion_execution::memory_pool::proxy::VecAllocExt;
 use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
-use datafusion_execution::runtime_env::RuntimeEnv;
 use datafusion_execution::TaskContext;
 use datafusion_expr::{EmitTo, GroupsAccumulator};
 use datafusion_physical_expr::expressions::Column;
@@ -91,6 +90,9 @@ struct SpillState {
     /// GROUP BY expressions for merging spilled data
     merging_group_by: PhysicalGroupBy,
 
+    /// Manages the process of spilling and reading back intermediate data
+    spill_manager: SpillManager,
+
     // ========================================================================
     // STATES:
     // Fields changes during execution. Can be buffer, or state flags that
@@ -109,12 +111,7 @@ struct SpillState {
     /// Peak memory used for buffered data.
     /// Calculated as sum of peak memory values across partitions
     peak_mem_used: metrics::Gauge,
-    /// count of spill files during the execution of the operator
-    spill_count: metrics::Count,
-    /// total spilled bytes during the execution of the operator
-    spilled_bytes: metrics::Count,
-    /// total spilled rows during the execution of the operator
-    spilled_rows: metrics::Count,
+    // Metrics related to spilling are managed inside `spill_manager`
 }
 
 /// Tracks if the aggregate should skip partial aggregations
@@ -435,9 +432,6 @@ pub(crate) struct GroupedHashAggregateStream {
 
     /// Execution metrics
     baseline_metrics: BaselineMetrics,
-
-    /// The [`RuntimeEnv`] associated with the [`TaskContext`] argument
-    runtime: Arc<RuntimeEnv>,
 }
 
 impl GroupedHashAggregateStream {
@@ -544,6 +538,12 @@ impl GroupedHashAggregateStream {
 
         let exec_state = ExecutionState::ReadingInput;
 
+        let spill_manager = SpillManager::new(
+            context.runtime_env(),
+            metrics::SpillMetrics::new(&agg.metrics, partition),
+            Arc::clone(&partial_agg_schema),
+        );
+
         let spill_state = SpillState {
             spills: vec![],
             spill_expr,
@@ -553,9 +553,7 @@ impl GroupedHashAggregateStream {
             merging_group_by: PhysicalGroupBy::new_single(agg_group_by.expr.clone()),
             peak_mem_used: MetricBuilder::new(&agg.metrics)
                 .gauge("peak_mem_used", partition),
-            spill_count: MetricBuilder::new(&agg.metrics).spill_count(partition),
-            spilled_bytes: MetricBuilder::new(&agg.metrics).spilled_bytes(partition),
-            spilled_rows: MetricBuilder::new(&agg.metrics).spilled_rows(partition),
+            spill_manager,
         };
 
         // Skip aggregation is supported if:
@@ -604,7 +602,6 @@ impl GroupedHashAggregateStream {
             batch_size,
             group_ordering,
            input_done: false,
-            runtime: context.runtime_env(),
             spill_state,
             group_values_soft_limit: agg.limit,
             skip_aggregation_probe,
@@ -981,28 +978,30 @@ impl GroupedHashAggregateStream {
         Ok(())
     }
 
-    /// Emit all rows, sort them, and store them on disk.
+    /// Emit all intermediate aggregation states, sort them, and store them on disk.
+    /// This process helps in reducing memory pressure by allowing the data to be
+    /// read back with streaming merge.
     fn spill(&mut self) -> Result<()> {
+        // Emit and sort intermediate aggregation state
         let Some(emit) = self.emit(EmitTo::All, true)? else {
            return Ok(());
        };
        let sorted = sort_batch(&emit, self.spill_state.spill_expr.as_ref(), None)?;
-        let spillfile = self.runtime.disk_manager.create_tmp_file("HashAggSpill")?;
-        // TODO: slice large `sorted` and write to multiple files in parallel
-        spill_record_batch_by_size(
+
+        // Spill sorted state to disk
+        let spillfile = self.spill_state.spill_manager.spill_record_batch_by_size(
            &sorted,
-            spillfile.path().into(),
-            sorted.schema(),
+            "HashAggSpill",
            self.batch_size,
        )?;
-        self.spill_state.spills.push(spillfile);
-
-        // Update metrics
-        self.spill_state.spill_count.add(1);
-        self.spill_state
-            .spilled_bytes
-            .add(sorted.get_array_memory_size());
-        self.spill_state.spilled_rows.add(sorted.num_rows());
+        match spillfile {
+            Some(spillfile) => self.spill_state.spills.push(spillfile),
+            None => {
+                return internal_err!(
+                    "Calling spill with no intermediate batch to spill"
+                );
+            }
+        }
 
        Ok(())
    }
@@ -1058,7 +1057,7 @@ impl GroupedHashAggregateStream {
            })),
        )));
        for spill in self.spill_state.spills.drain(..) {
-            let stream = read_spill_as_stream(spill, Arc::clone(&schema), 2)?;
+            let stream = self.spill_state.spill_manager.read_spill_as_stream(spill)?;
            streams.push(stream);
        }
        self.spill_state.is_stream_merging = true;

datafusion/physical-plan/src/joins/sort_merge_join.rs

Lines changed: 25 additions & 31 deletions

@@ -41,12 +41,14 @@ use crate::joins::utils::{
     reorder_output_after_swap, symmetric_join_output_partitioning, JoinFilter, JoinOn,
     JoinOnRef,
 };
-use crate::metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
+use crate::metrics::{
+    Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, SpillMetrics,
+};
 use crate::projection::{
     join_allows_pushdown, join_table_borders, new_join_children,
     physical_to_column_exprs, update_join_on, ProjectionExec,
 };
-use crate::spill::spill_record_batches;
+use crate::spill::spill_manager::SpillManager;
 use crate::{
     metrics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan,
     ExecutionPlanProperties, PhysicalExpr, PlanProperties, RecordBatchStream,
@@ -596,12 +598,8 @@ struct SortMergeJoinMetrics {
     /// Peak memory used for buffered data.
     /// Calculated as sum of peak memory values across partitions
     peak_mem_used: metrics::Gauge,
-    /// count of spills during the execution of the operator
-    spill_count: Count,
-    /// total spilled bytes during the execution of the operator
-    spilled_bytes: Count,
-    /// total spilled rows during the execution of the operator
-    spilled_rows: Count,
+    /// Metrics related to spilling
+    spill_metrics: SpillMetrics,
 }
 
 impl SortMergeJoinMetrics {
@@ -615,9 +613,7 @@ impl SortMergeJoinMetrics {
            MetricBuilder::new(metrics).counter("output_batches", partition);
        let output_rows = MetricBuilder::new(metrics).output_rows(partition);
        let peak_mem_used = MetricBuilder::new(metrics).gauge("peak_mem_used", partition);
-        let spill_count = MetricBuilder::new(metrics).spill_count(partition);
-        let spilled_bytes = MetricBuilder::new(metrics).spilled_bytes(partition);
-        let spilled_rows = MetricBuilder::new(metrics).spilled_rows(partition);
+        let spill_metrics = SpillMetrics::new(metrics, partition);
 
        Self {
            join_time,
@@ -626,9 +622,7 @@ impl SortMergeJoinMetrics {
            output_batches,
            output_rows,
            peak_mem_used,
-            spill_count,
-            spilled_bytes,
-            spilled_rows,
+            spill_metrics,
        }
    }
 }
@@ -884,6 +878,8 @@ struct SortMergeJoinStream {
     pub reservation: MemoryReservation,
     /// Runtime env
     pub runtime_env: Arc<RuntimeEnv>,
+    /// Manages the process of spilling and reading back intermediate data
+    pub spill_manager: SpillManager,
     /// A unique number for each batch
     pub streamed_batch_counter: AtomicUsize,
 }
@@ -1301,6 +1297,11 @@ impl SortMergeJoinStream {
    ) -> Result<Self> {
        let streamed_schema = streamed.schema();
        let buffered_schema = buffered.schema();
+        let spill_manager = SpillManager::new(
+            Arc::clone(&runtime_env),
+            join_metrics.spill_metrics.clone(),
+            Arc::clone(&buffered_schema),
+        );
        Ok(Self {
            state: SortMergeJoinState::Init,
            sort_options,
@@ -1333,6 +1334,7 @@ impl SortMergeJoinStream {
            join_metrics,
            reservation,
            runtime_env,
+            spill_manager,
            streamed_batch_counter: AtomicUsize::new(0),
        })
    }
@@ -1402,27 +1404,19 @@ impl SortMergeJoinStream {
                Ok(())
            }
            Err(_) if self.runtime_env.disk_manager.tmp_files_enabled() => {
-                // spill buffered batch to disk
-                let spill_file = self
-                    .runtime_env
-                    .disk_manager
-                    .create_tmp_file("sort_merge_join_buffered_spill")?;
-
+                // Spill buffered batch to disk
                if let Some(batch) = buffered_batch.batch {
-                    spill_record_batches(
-                        &[batch],
-                        spill_file.path().into(),
-                        Arc::clone(&self.buffered_schema),
-                    )?;
+                    let spill_file = self
+                        .spill_manager
+                        .spill_record_batch_and_finish(
+                            &[batch],
+                            "sort_merge_join_buffered_spill",
+                        )?
+                        .unwrap(); // Operation only return None if no batches are spilled, here we ensure that at least one batch is spilled
+
                    buffered_batch.spill_file = Some(spill_file);
                    buffered_batch.batch = None;
 
-                    // update metrics to register spill
-                    self.join_metrics.spill_count.add(1);
-                    self.join_metrics
-                        .spilled_bytes
-                        .add(buffered_batch.size_estimation);
-                    self.join_metrics.spilled_rows.add(buffered_batch.num_rows);
                    Ok(())
                } else {
                    internal_err!("Buffered batch has empty body")
datafusion/physical-plan/src/spill/mod.rs

Lines changed: 22 additions & 48 deletions

@@ -29,55 +29,9 @@ use arrow::array::ArrayData;
 use arrow::datatypes::{Schema, SchemaRef};
 use arrow::ipc::{reader::StreamReader, writer::StreamWriter};
 use arrow::record_batch::RecordBatch;
-use log::debug;
 use tokio::sync::mpsc::Sender;
 
 use datafusion_common::{exec_datafusion_err, HashSet, Result};
-use datafusion_execution::disk_manager::RefCountedTempFile;
-use datafusion_execution::memory_pool::human_readable_size;
-use datafusion_execution::SendableRecordBatchStream;
-
-use crate::stream::RecordBatchReceiverStream;
-
-/// Read spilled batches from the disk
-///
-/// `path` - temp file
-/// `schema` - batches schema, should be the same across batches
-/// `buffer` - internal buffer of capacity batches
-pub(crate) fn read_spill_as_stream(
-    path: RefCountedTempFile,
-    schema: SchemaRef,
-    buffer: usize,
-) -> Result<SendableRecordBatchStream> {
-    let mut builder = RecordBatchReceiverStream::builder(schema, buffer);
-    let sender = builder.tx();
-
-    builder.spawn_blocking(move || read_spill(sender, path.path()));
-
-    Ok(builder.build())
-}
-
-/// Spills in-memory `batches` to disk.
-///
-/// Returns total number of the rows spilled to disk.
-pub(crate) fn spill_record_batches(
-    batches: &[RecordBatch],
-    path: PathBuf,
-    schema: SchemaRef,
-) -> Result<(usize, usize)> {
-    let mut writer = IPCStreamWriter::new(path.as_ref(), schema.as_ref())?;
-    for batch in batches {
-        writer.write(batch)?;
-    }
-    writer.finish()?;
-    debug!(
-        "Spilled {} batches of total {} rows to disk, memory released {}",
-        writer.num_batches,
-        writer.num_rows,
-        human_readable_size(writer.num_bytes),
-    );
-    Ok((writer.num_rows, writer.num_bytes))
-}
 
 fn read_spill(sender: Sender<Result<RecordBatch>>, path: &Path) -> Result<()> {
    let file = BufReader::new(File::open(path)?);
@@ -92,6 +46,10 @@ fn read_spill(sender: Sender<Result<RecordBatch>>, path: &Path) -> Result<()> {
 
 /// Spill the `RecordBatch` to disk as smaller batches
 /// split by `batch_size_rows`
+#[deprecated(
+    since = "46.0.0",
+    note = "This method is deprecated. Use `SpillManager::spill_record_batch_by_size` instead."
+)]
 pub fn spill_record_batch_by_size(
    batch: &RecordBatch,
    path: PathBuf,
@@ -619,12 +577,28 @@ mod tests {
 
        let spill_manager =
            Arc::new(SpillManager::new(env, metrics, Arc::clone(&schema)));
-        let mut in_progress_file = spill_manager.create_in_progress_file("Test")?;
 
-        // Attempt to finish without appending any batches
+        // Test write empty batch with interface `InProgressSpillFile` and `append_batch()`
+        let mut in_progress_file = spill_manager.create_in_progress_file("Test")?;
        let completed_file = in_progress_file.finish()?;
        assert!(completed_file.is_none());
 
+        // Test write empty batch with interface `spill_record_batch_and_finish()`
+        let completed_file = spill_manager.spill_record_batch_and_finish(&[], "Test")?;
+        assert!(completed_file.is_none());
+
+        // Test write empty batch with interface `spill_record_batch_by_size()`
+        let empty_batch = RecordBatch::try_new(
+            Arc::clone(&schema),
+            vec![
+                Arc::new(Int32Array::from(Vec::<Option<i32>>::new())),
+                Arc::new(StringArray::from(Vec::<Option<&str>>::new())),
+            ],
+        )?;
+        let completed_file =
+            spill_manager.spill_record_batch_by_size(&empty_batch, "Test", 1)?;
+        assert!(completed_file.is_none());
+
        Ok(())
    }
 }
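Given the deprecation above, external callers of the free function `spill_record_batch_by_size` would migrate roughly as sketched below. The setup mirrors the first sketch and carries the same path assumptions; `spill_in_chunks` and the `"ChunkedSpill"` request description are illustrative names, not part of this commit.

use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_execution::disk_manager::RefCountedTempFile;
use datafusion_physical_plan::spill::spill_manager::SpillManager;

/// Spill `batch` in chunks of at most `max_rows` rows per written batch.
/// Previously this meant creating a temp file through the DiskManager and
/// calling the free function with an explicit path and schema; the manager
/// now owns the disk manager, the schema, and the spill metrics.
fn spill_in_chunks(
    spill_manager: &SpillManager,
    batch: &RecordBatch,
    max_rows: usize,
) -> Result<Option<RefCountedTempFile>> {
    // Returns Ok(None) if `batch` is empty, as exercised by the new test above.
    spill_manager.spill_record_batch_by_size(batch, "ChunkedSpill", max_rows)
}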
