refactor: Use SpillManager for all spilling scenarios #15405

Merged · 4 commits · Mar 26, 2025
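The diffs below route every spill call site through a single `SpillManager`. As orientation, here is a minimal round-trip sketch of that surface, with constructor and method signatures inferred from the changed call sites in this PR; the external module paths and the helper function are assumptions, not taken from crate docs:

```rust
use std::sync::Arc;

use arrow::record_batch::RecordBatch;
use datafusion_common::Result;
use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_execution::SendableRecordBatchStream;
// Inside this PR the type lives at `crate::spill::spill_manager::SpillManager`;
// the re-export paths below are assumed for an external caller.
use datafusion_physical_plan::metrics::{ExecutionPlanMetricsSet, SpillMetrics};
use datafusion_physical_plan::spill::spill_manager::SpillManager;

/// Hypothetical helper (not in the PR): spill one batch to disk, then read it
/// back as a stream, exercising the surface the diffs converge on.
fn spill_round_trip(
    env: Arc<RuntimeEnv>,
    batch: RecordBatch,
) -> Result<Option<SendableRecordBatchStream>> {
    let metrics = ExecutionPlanMetricsSet::new();
    // The manager owns the spill_count / spilled_bytes / spilled_rows metrics.
    let spill_metrics = SpillMetrics::new(&metrics, /* partition */ 0);
    let spill_manager = SpillManager::new(env, spill_metrics, batch.schema());

    // Spill methods return Ok(None) when nothing was written (empty input).
    match spill_manager.spill_record_batch_and_finish(&[batch], "Example")? {
        Some(file) => Ok(Some(spill_manager.read_spill_as_stream(file)?)),
        None => Ok(None),
    }
}
```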
59 changes: 29 additions & 30 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -30,7 +30,7 @@ use crate::aggregates::{
use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput};
use crate::sorts::sort::sort_batch;
use crate::sorts::streaming_merge::StreamingMergeBuilder;
use crate::spill::{read_spill_as_stream, spill_record_batch_by_size};
use crate::spill::spill_manager::SpillManager;
use crate::stream::RecordBatchStreamAdapter;
use crate::{aggregates, metrics, ExecutionPlan, PhysicalExpr};
use crate::{RecordBatchStream, SendableRecordBatchStream};
@@ -42,7 +42,6 @@ use datafusion_common::{internal_err, DataFusionError, Result};
use datafusion_execution::disk_manager::RefCountedTempFile;
use datafusion_execution::memory_pool::proxy::VecAllocExt;
use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
use datafusion_execution::runtime_env::RuntimeEnv;
use datafusion_execution::TaskContext;
use datafusion_expr::{EmitTo, GroupsAccumulator};
use datafusion_physical_expr::expressions::Column;
@@ -91,6 +90,9 @@ struct SpillState {
/// GROUP BY expressions for merging spilled data
merging_group_by: PhysicalGroupBy,

/// Manages the process of spilling and reading back intermediate data
spill_manager: SpillManager,

// ========================================================================
// STATES:
// Fields changes during execution. Can be buffer, or state flags that
@@ -109,12 +111,7 @@ struct SpillState {
/// Peak memory used for buffered data.
/// Calculated as sum of peak memory values across partitions
peak_mem_used: metrics::Gauge,
/// count of spill files during the execution of the operator
spill_count: metrics::Count,
/// total spilled bytes during the execution of the operator
spilled_bytes: metrics::Count,
/// total spilled rows during the execution of the operator
spilled_rows: metrics::Count,
// Metrics related to spilling are managed inside `spill_manager`
}

/// Tracks if the aggregate should skip partial aggregations
@@ -435,9 +432,6 @@ pub(crate) struct GroupedHashAggregateStream {

/// Execution metrics
baseline_metrics: BaselineMetrics,

/// The [`RuntimeEnv`] associated with the [`TaskContext`] argument
runtime: Arc<RuntimeEnv>,
}

impl GroupedHashAggregateStream {
@@ -544,6 +538,12 @@ impl GroupedHashAggregateStream {

let exec_state = ExecutionState::ReadingInput;

let spill_manager = SpillManager::new(
context.runtime_env(),
metrics::SpillMetrics::new(&agg.metrics, partition),
Arc::clone(&partial_agg_schema),
);

let spill_state = SpillState {
spills: vec![],
spill_expr,
@@ -553,9 +553,7 @@
merging_group_by: PhysicalGroupBy::new_single(agg_group_by.expr.clone()),
peak_mem_used: MetricBuilder::new(&agg.metrics)
.gauge("peak_mem_used", partition),
spill_count: MetricBuilder::new(&agg.metrics).spill_count(partition),
spilled_bytes: MetricBuilder::new(&agg.metrics).spilled_bytes(partition),
spilled_rows: MetricBuilder::new(&agg.metrics).spilled_rows(partition),
spill_manager,
};

// Skip aggregation is supported if:
@@ -604,7 +602,6 @@
batch_size,
group_ordering,
input_done: false,
runtime: context.runtime_env(),
spill_state,
group_values_soft_limit: agg.limit,
skip_aggregation_probe,
@@ -981,28 +978,30 @@ impl GroupedHashAggregateStream {
Ok(())
}

/// Emit all rows, sort them, and store them on disk.
/// Emit all intermediate aggregation states, sort them, and store them on disk.
/// This process helps in reducing memory pressure by allowing the data to be
/// read back with streaming merge.
fn spill(&mut self) -> Result<()> {
// Emit and sort intermediate aggregation state
let Some(emit) = self.emit(EmitTo::All, true)? else {
return Ok(());
};
let sorted = sort_batch(&emit, self.spill_state.spill_expr.as_ref(), None)?;
[Review comment, Contributor] Eventually it might make sense to have the spill manager handle sorting the runs too (so it could potentially merge multiple files into a single run to reduce fanout, etc.).

let spillfile = self.runtime.disk_manager.create_tmp_file("HashAggSpill")?;
// TODO: slice large `sorted` and write to multiple files in parallel
spill_record_batch_by_size(

// Spill sorted state to disk
let spillfile = self.spill_state.spill_manager.spill_record_batch_by_size(
&sorted,
spillfile.path().into(),
sorted.schema(),
"HashAggSpill",
self.batch_size,
)?;
self.spill_state.spills.push(spillfile);

// Update metrics
self.spill_state.spill_count.add(1);
self.spill_state
.spilled_bytes
.add(sorted.get_array_memory_size());
self.spill_state.spilled_rows.add(sorted.num_rows());
match spillfile {
Some(spillfile) => self.spill_state.spills.push(spillfile),
None => {
return internal_err!(
"Calling spill with no intermediate batch to spill"
);
}
}

Ok(())
}
@@ -1058,7 +1057,7 @@ impl GroupedHashAggregateStream {
})),
)));
for spill in self.spill_state.spills.drain(..) {
let stream = read_spill_as_stream(spill, Arc::clone(&schema), 2)?;
let stream = self.spill_state.spill_manager.read_spill_as_stream(spill)?;
streams.push(stream);
}
self.spill_state.is_stream_merging = true;
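Reassembled from the hunks above, the spill path in `GroupedHashAggregateStream` now reads as a single pass through the `SpillManager` (a condensed sketch; the enclosing `impl` block, fields, and imports are elided):

```rust
// Condensed from the hunks above; `self` is the GroupedHashAggregateStream.
fn spill(&mut self) -> Result<()> {
    // Emit and sort all buffered intermediate aggregation state.
    let Some(emit) = self.emit(EmitTo::All, true)? else {
        return Ok(());
    };
    let sorted = sort_batch(&emit, self.spill_state.spill_expr.as_ref(), None)?;

    // The SpillManager now owns tmp-file creation, IPC writing, and the
    // spill metrics updates that this function previously did by hand.
    let spillfile = self.spill_state.spill_manager.spill_record_batch_by_size(
        &sorted,
        "HashAggSpill",
        self.batch_size,
    )?;
    match spillfile {
        Some(spillfile) => self.spill_state.spills.push(spillfile),
        // `sorted` came from a successful emit, so an empty spill is a bug.
        None => {
            return internal_err!("Calling spill with no intermediate batch to spill")
        }
    }

    Ok(())
}
```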
56 changes: 25 additions & 31 deletions datafusion/physical-plan/src/joins/sort_merge_join.rs
@@ -41,12 +41,14 @@ use crate::joins::utils::{
reorder_output_after_swap, symmetric_join_output_partitioning, JoinFilter, JoinOn,
JoinOnRef,
};
use crate::metrics::{Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet};
use crate::metrics::{
Count, ExecutionPlanMetricsSet, MetricBuilder, MetricsSet, SpillMetrics,
};
use crate::projection::{
join_allows_pushdown, join_table_borders, new_join_children,
physical_to_column_exprs, update_join_on, ProjectionExec,
};
use crate::spill::spill_record_batches;
use crate::spill::spill_manager::SpillManager;
use crate::{
metrics, DisplayAs, DisplayFormatType, Distribution, ExecutionPlan,
ExecutionPlanProperties, PhysicalExpr, PlanProperties, RecordBatchStream,
@@ -596,12 +598,8 @@ struct SortMergeJoinMetrics {
/// Peak memory used for buffered data.
/// Calculated as sum of peak memory values across partitions
peak_mem_used: metrics::Gauge,
/// count of spills during the execution of the operator
spill_count: Count,
/// total spilled bytes during the execution of the operator
spilled_bytes: Count,
/// total spilled rows during the execution of the operator
spilled_rows: Count,
/// Metrics related to spilling
spill_metrics: SpillMetrics,
}

impl SortMergeJoinMetrics {
@@ -615,9 +613,7 @@ impl SortMergeJoinMetrics {
MetricBuilder::new(metrics).counter("output_batches", partition);
let output_rows = MetricBuilder::new(metrics).output_rows(partition);
let peak_mem_used = MetricBuilder::new(metrics).gauge("peak_mem_used", partition);
let spill_count = MetricBuilder::new(metrics).spill_count(partition);
let spilled_bytes = MetricBuilder::new(metrics).spilled_bytes(partition);
let spilled_rows = MetricBuilder::new(metrics).spilled_rows(partition);
let spill_metrics = SpillMetrics::new(metrics, partition);

Self {
join_time,
@@ -626,9 +622,7 @@
output_batches,
output_rows,
peak_mem_used,
spill_count,
spilled_bytes,
spilled_rows,
spill_metrics,
}
}
}
@@ -884,6 +878,8 @@ struct SortMergeJoinStream {
pub reservation: MemoryReservation,
/// Runtime env
pub runtime_env: Arc<RuntimeEnv>,
/// Manages the process of spilling and reading back intermediate data
pub spill_manager: SpillManager,
/// A unique number for each batch
pub streamed_batch_counter: AtomicUsize,
}
@@ -1301,6 +1297,11 @@
) -> Result<Self> {
let streamed_schema = streamed.schema();
let buffered_schema = buffered.schema();
let spill_manager = SpillManager::new(
Arc::clone(&runtime_env),
join_metrics.spill_metrics.clone(),
Arc::clone(&buffered_schema),
);
Ok(Self {
state: SortMergeJoinState::Init,
sort_options,
@@ -1333,6 +1334,7 @@
join_metrics,
reservation,
runtime_env,
spill_manager,
streamed_batch_counter: AtomicUsize::new(0),
})
}
@@ -1402,27 +1404,19 @@
Ok(())
}
Err(_) if self.runtime_env.disk_manager.tmp_files_enabled() => {
// spill buffered batch to disk
let spill_file = self
.runtime_env
.disk_manager
.create_tmp_file("sort_merge_join_buffered_spill")?;

// Spill buffered batch to disk
if let Some(batch) = buffered_batch.batch {
spill_record_batches(
&[batch],
spill_file.path().into(),
Arc::clone(&self.buffered_schema),
)?;
let spill_file = self
.spill_manager
.spill_record_batch_and_finish(
&[batch],
"sort_merge_join_buffered_spill",
)?
.unwrap(); // The operation only returns None if no batches were spilled; here we ensure at least one batch is spilled

buffered_batch.spill_file = Some(spill_file);
buffered_batch.batch = None;

// update metrics to register spill
self.join_metrics.spill_count.add(1);
self.join_metrics
.spilled_bytes
.add(buffered_batch.size_estimation);
self.join_metrics.spilled_rows.add(buffered_batch.num_rows);
Ok(())
} else {
internal_err!("Buffered batch has empty body")
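The new `spill_record_batch_and_finish` returns `Ok(None)` when no batch was written, which is why the join code above reaches for `.unwrap()` only after confirming `buffered_batch.batch` is `Some`. A hypothetical wrapper (not in the PR; `SpillManager` path assumed as before) that encodes the same invariant without `unwrap`:

```rust
use arrow::record_batch::RecordBatch;
use datafusion_common::{internal_err, Result};
use datafusion_execution::disk_manager::RefCountedTempFile;
use datafusion_physical_plan::spill::spill_manager::SpillManager; // assumed path

/// Hypothetical helper: spill a batch the caller knows to be present,
/// turning the "nothing was spilled" case into an internal error.
fn spill_non_empty(
    spill_manager: &SpillManager,
    batch: RecordBatch,
    request_description: &str,
) -> Result<RefCountedTempFile> {
    match spill_manager.spill_record_batch_and_finish(&[batch], request_description)? {
        Some(file) => Ok(file),
        // Only reachable if the manager wrote no batches; we passed exactly one,
        // so report an internal invariant violation instead of panicking.
        None => internal_err!("expected a spill file for {request_description}"),
    }
}
```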
70 changes: 22 additions & 48 deletions datafusion/physical-plan/src/spill/mod.rs
@@ -29,55 +29,9 @@ use arrow::array::ArrayData;
use arrow::datatypes::{Schema, SchemaRef};
use arrow::ipc::{reader::StreamReader, writer::StreamWriter};
use arrow::record_batch::RecordBatch;
use log::debug;
use tokio::sync::mpsc::Sender;

use datafusion_common::{exec_datafusion_err, HashSet, Result};
use datafusion_execution::disk_manager::RefCountedTempFile;
use datafusion_execution::memory_pool::human_readable_size;
use datafusion_execution::SendableRecordBatchStream;

use crate::stream::RecordBatchReceiverStream;

/// Read spilled batches from the disk
///
/// `path` - temp file
/// `schema` - batches schema, should be the same across batches
/// `buffer` - internal buffer of capacity batches
pub(crate) fn read_spill_as_stream(
path: RefCountedTempFile,
schema: SchemaRef,
buffer: usize,
) -> Result<SendableRecordBatchStream> {
let mut builder = RecordBatchReceiverStream::builder(schema, buffer);
let sender = builder.tx();

builder.spawn_blocking(move || read_spill(sender, path.path()));

Ok(builder.build())
}

/// Spills in-memory `batches` to disk.
///
/// Returns total number of the rows spilled to disk.
pub(crate) fn spill_record_batches(
batches: &[RecordBatch],
path: PathBuf,
schema: SchemaRef,
) -> Result<(usize, usize)> {
let mut writer = IPCStreamWriter::new(path.as_ref(), schema.as_ref())?;
for batch in batches {
writer.write(batch)?;
}
writer.finish()?;
debug!(
"Spilled {} batches of total {} rows to disk, memory released {}",
writer.num_batches,
writer.num_rows,
human_readable_size(writer.num_bytes),
);
Ok((writer.num_rows, writer.num_bytes))
}

fn read_spill(sender: Sender<Result<RecordBatch>>, path: &Path) -> Result<()> {
let file = BufReader::new(File::open(path)?);
@@ -92,6 +46,10 @@ fn read_spill(sender: Sender<Result<RecordBatch>>, path: &Path) -> Result<()> {

/// Spill the `RecordBatch` to disk as smaller batches
/// split by `batch_size_rows`
#[deprecated(
since = "46.0.0",
note = "This method is deprecated. Use `SpillManager::spill_record_batch_by_size` instead."
)]
pub fn spill_record_batch_by_size(
batch: &RecordBatch,
path: PathBuf,
@@ -619,12 +577,28 @@ mod tests {

let spill_manager =
Arc::new(SpillManager::new(env, metrics, Arc::clone(&schema)));
let mut in_progress_file = spill_manager.create_in_progress_file("Test")?;

// Attempt to finish without appending any batches
// Test write empty batch with interface `InProgressSpillFile` and `append_batch()`
let mut in_progress_file = spill_manager.create_in_progress_file("Test")?;
let completed_file = in_progress_file.finish()?;
assert!(completed_file.is_none());

// Test write empty batch with interface `spill_record_batch_and_finish()`
let completed_file = spill_manager.spill_record_batch_and_finish(&[], "Test")?;
assert!(completed_file.is_none());

// Test write empty batch with interface `spill_record_batch_by_size()`
let empty_batch = RecordBatch::try_new(
Arc::clone(&schema),
vec![
Arc::new(Int32Array::from(Vec::<Option<i32>>::new())),
Arc::new(StringArray::from(Vec::<Option<&str>>::new())),
],
)?;
let completed_file =
spill_manager.spill_record_batch_by_size(&empty_batch, "Test", 1)?;
assert!(completed_file.is_none());

Ok(())
}
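Taken together with the deprecation above, callers of the old free functions migrate roughly as follows. This is a sketch assuming the old call shapes shown in the removed hunks; `spill_metrics` and `schema` are constructed as in the other files of this PR:

```rust
// Before (removed or deprecated in this PR): free functions plus manual
// tmp-file creation and manual metrics bookkeeping at every call site:
//
//   let file = runtime_env.disk_manager.create_tmp_file("Label")?;
//   spill_record_batch_by_size(&batch, file.path().into(), schema, batch_size)?;
//   let stream = read_spill_as_stream(file, schema, 2)?;
//   spill_count.add(1); // ... plus spilled_bytes / spilled_rows updates
//
// After: one SpillManager per operator, which owns file naming, IPC writing,
// and the spill metrics; note the Option on the write side:
//
//   let spill_manager = SpillManager::new(runtime_env, spill_metrics, schema);
//   if let Some(file) =
//       spill_manager.spill_record_batch_by_size(&batch, "Label", batch_size)?
//   {
//       let stream = spill_manager.read_spill_as_stream(file)?;
//   }
```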
}