From 823450ad5ef8e8edf5ec195e4415e000e7ae75f8 Mon Sep 17 00:00:00 2001
From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com>
Date: Tue, 15 Apr 2025 22:41:51 +0300
Subject: [PATCH 1/9] test: add fuzz test for doing aggregation with larger than memory groups

---
 .../core/tests/fuzz_cases/aggregate_fuzz.rs   | 131 +++++++++++++++++-
 datafusion/core/tests/fuzz_cases/mod.rs       |   1 +
 .../core/tests/fuzz_cases/stream_exec.rs      | 104 ++++++++++++++
 3 files changed, 231 insertions(+), 5 deletions(-)
 create mode 100644 datafusion/core/tests/fuzz_cases/stream_exec.rs

diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
index ff3b66986ced..606080492b05 100644
--- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
@@ -15,15 +15,14 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::sync::Arc;
-
 use crate::fuzz_cases::aggregation_fuzzer::{
     AggregationFuzzerBuilder, DatasetGeneratorConfig, QueryBuilder,
 };
+use std::sync::Arc;
 
 use arrow::array::{
     types::Int64Type, Array, ArrayRef, AsArray, Int32Array, Int64Array, RecordBatch,
-    StringArray,
+    StringArray, UInt64Array,
 };
 use arrow::compute::{concat_batches, SortOptions};
 use arrow::datatypes::DataType;
@@ -47,17 +46,21 @@ use datafusion_physical_expr::expressions::{col, lit, Column};
 use datafusion_physical_expr::PhysicalSortExpr;
 use datafusion_physical_expr_common::sort_expr::LexOrdering;
 use datafusion_physical_plan::InputOrderMode;
+use futures::StreamExt;
 use test_utils::{add_empty_batches, StringBatchGenerator};
 
+use super::record_batch_generator::get_supported_types_columns;
+use crate::fuzz_cases::stream_exec::StreamExec;
+use datafusion_execution::memory_pool::units::MB;
 use datafusion_execution::memory_pool::FairSpillPool;
 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
 use datafusion_execution::TaskContext;
+use datafusion_functions_aggregate::array_agg::array_agg_udaf;
 use datafusion_physical_plan::metrics::MetricValue;
+use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
 use rand::rngs::StdRng;
 use rand::{random, thread_rng, Rng, SeedableRng};
 
-use super::record_batch_generator::get_supported_types_columns;
-
 // ========================================================================
 // The new aggregation fuzz tests based on [`AggregationFuzzer`]
 // ========================================================================
@@ -640,6 +643,8 @@ fn assert_spill_count_metric(expect_spill: bool, single_aggregate: Arc<AggregateExec>) {
         } else if !expect_spill && spill_count > 0 {
             panic!("Expected no spill but found SpillCount metric with value greater than 0.");
         }
+
+        println!("SpillCount = {}", spill_count);
     } else {
         panic!("No metrics returned from the operator; cannot verify spilling.");
     }
@@ -753,3 +758,119 @@ async fn test_single_mode_aggregate_with_spill() -> Result<()> {
 
     Ok(())
 }
+
+#[tokio::test]
+async fn test_high_cardinality_with_limited_memory() -> Result<()> {
+    let record_batch_size = 8192;
+    let two_mb = 2 * MB as usize;
+    let task_ctx = {
+        let memory_pool = Arc::new(FairSpillPool::new(two_mb));
+        TaskContext::default()
+            .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
+            .with_runtime(Arc::new(
+                RuntimeEnvBuilder::new()
+                    .with_memory_pool(memory_pool)
+                    .build()?,
+            ))
+    };
+    run_test_high_cardinality(task_ctx, 100).await
+}
+
+async fn run_test_high_cardinality(
+    task_ctx: TaskContext,
+    number_of_record_batches: usize,
+) -> Result<()> {
+    let scan_schema = Arc::new(Schema::new(vec![
+        Field::new("col_0", DataType::UInt64, true),
+        Field::new("col_1", DataType::Utf8, true),
+    ]));
+
+    let group_by = PhysicalGroupBy::new_single(vec![(
+        Arc::new(Column::new("col_0", 0)),
+        "col_0".to_string(),
+    )]);
+
+    let aggregate_expressions = vec![Arc::new(
+        AggregateExprBuilder::new(
+            array_agg_udaf(),
+            vec![col("col_1", &scan_schema).unwrap()],
+        )
+        .schema(Arc::clone(&scan_schema))
+        .alias("array_agg(col_1)")
+        .build()?,
+    )];
+
+    let record_batch_size = task_ctx.session_config().batch_size() as u64;
+
+    let schema = Arc::clone(&scan_schema);
+    let plan: Arc<dyn ExecutionPlan> =
+        Arc::new(StreamExec::new(Box::pin(RecordBatchStreamAdapter::new(
+            Arc::clone(&schema),
+            futures::stream::iter((0..number_of_record_batches as u64).map(
+                move |index| {
+                    // Simulate a large record batch 3 times in a stream
+                    let string_array =
+                        if index % (number_of_record_batches as u64 / 3) == 0 {
+                            Arc::new(StringArray::from_iter_values(
+                                (0..record_batch_size).map(|_| "b".repeat(64)),
+                            ))
+                        } else {
+                            Arc::new(StringArray::from_iter_values(
+                                (0..record_batch_size).map(|_| "a".repeat(8)),
+                            ))
+                        };
+
+                    RecordBatch::try_new(
+                        Arc::clone(&schema),
+                        vec![
+                            // Grouping key
+                            Arc::new(UInt64Array::from_iter_values(
+                                (index * record_batch_size)
+                                    ..(index * record_batch_size) + record_batch_size,
+                            )),
+                            // Grouping value
+                            string_array,
+                        ],
+                    )
+                    .map_err(|err| err.into())
+                },
+            )),
+        ))));
+
+    let aggregate_exec = Arc::new(AggregateExec::try_new(
+        AggregateMode::Partial,
+        group_by.clone(),
+        aggregate_expressions.clone(),
+        vec![None; aggregate_expressions.len()],
+        plan,
+        Arc::clone(&scan_schema),
+    )?);
+    let aggregate_final = Arc::new(AggregateExec::try_new(
+        AggregateMode::Final,
+        group_by,
+        aggregate_expressions.clone(),
+        vec![None; aggregate_expressions.len()],
+        aggregate_exec,
+        Arc::clone(&scan_schema),
+    )?);
+
+    let task_ctx = Arc::new(task_ctx);
+
+    let mut result = aggregate_final.execute(0, task_ctx)?;
+
+    let mut number_of_groups = 0;
+
+    while let Some(batch) = result.next().await {
+        let batch = batch?;
+        number_of_groups += batch.num_rows();
+    }
+
+    assert_eq!(
+        number_of_groups,
+        number_of_record_batches * record_batch_size as usize
+    );
+
+    assert_spill_count_metric(true, aggregate_final);
+
+    Ok(())
+}
diff --git a/datafusion/core/tests/fuzz_cases/mod.rs b/datafusion/core/tests/fuzz_cases/mod.rs
index 8ccc2a5bc131..6430859e32b9 100644
--- a/datafusion/core/tests/fuzz_cases/mod.rs
+++ b/datafusion/core/tests/fuzz_cases/mod.rs
@@ -33,3 +33,4 @@ mod window_fuzz;
 
 // Utility modules
 mod record_batch_generator;
+mod stream_exec;
diff --git a/datafusion/core/tests/fuzz_cases/stream_exec.rs b/datafusion/core/tests/fuzz_cases/stream_exec.rs
new file mode 100644
index 000000000000..a9b75d7547ca
--- /dev/null
+++ b/datafusion/core/tests/fuzz_cases/stream_exec.rs
@@ -0,0 +1,104 @@
+use arrow_schema::SchemaRef;
+use datafusion_common::DataFusionError;
+use datafusion_execution::{SendableRecordBatchStream, TaskContext};
+use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
+use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
+use datafusion_physical_plan::{
+    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
+};
+use std::any::Any;
+use std::fmt::{Debug, Formatter};
+use std::sync::{Arc, Mutex};
+
+/// Execution plan that returns the stream on the call to `execute`.
+pub struct StreamExec {
+    /// the results to send back
+    stream: Mutex<Option<SendableRecordBatchStream>>,
+    cache: PlanProperties,
+}
+
+impl Debug for StreamExec {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "StreamExec")
+    }
+}
+
+impl StreamExec {
+    /// Create a new `MockExec` with a single partition that returns
+    /// the specified `Results`s.
+    ///
+    /// By default, the batches are not produced immediately (the
+    /// caller has to actually yield and another task must run) to
+    /// ensure any poll loops are correct. This behavior can be
+    /// changed with `with_use_task`
+    pub fn new(stream: SendableRecordBatchStream) -> Self {
+        let cache = Self::compute_properties(stream.schema());
+        Self {
+            stream: Mutex::new(Some(stream)),
+            cache,
+        }
+    }
+
+    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
+    fn compute_properties(schema: SchemaRef) -> PlanProperties {
+        PlanProperties::new(
+            EquivalenceProperties::new(schema),
+            Partitioning::UnknownPartitioning(1),
+            EmissionType::Incremental,
+            Boundedness::Bounded,
+        )
+    }
+}
+
+impl DisplayAs for StreamExec {
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
+        match t {
+            DisplayFormatType::Default | DisplayFormatType::Verbose => {
+                write!(f, "StreamExec:")
+            }
+            DisplayFormatType::TreeRender => {
+                write!(f, "")
+            }
+        }
+    }
+}
+
+impl ExecutionPlan for StreamExec {
+    fn name(&self) -> &'static str {
+        Self::static_name()
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        &self.cache
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
+        unimplemented!()
+    }
+
+    /// Returns a stream which yields data
+    fn execute(
+        &self,
+        partition: usize,
+        _context: Arc<TaskContext>,
+    ) -> datafusion_common::Result<SendableRecordBatchStream> {
+        assert_eq!(partition, 0);
+
+        let stream = self.stream.lock().unwrap().take();
+
+        stream.ok_or(DataFusionError::Internal(
+            "Stream already consumed".to_string(),
+        ))
+    }
+}

From da369dbd2307877943b3744ffee3db83d6df3473 Mon Sep 17 00:00:00 2001
From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com>
Date: Tue, 15 Apr 2025 22:59:03 +0300
Subject: [PATCH 2/9] add more tests

---
 .../core/tests/fuzz_cases/aggregate_fuzz.rs   | 80 ++++++++++++++-----
 1 file changed, 60 insertions(+), 20 deletions(-)

diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
index 606080492b05..9c38ea7715a2 100644
--- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
@@ -25,7 +25,7 @@ use arrow::array::{
     StringArray, UInt64Array,
 };
 use arrow::compute::{concat_batches, SortOptions};
-use arrow::datatypes::DataType;
+use arrow::datatypes::{DataType, UInt64Type};
 use arrow::util::pretty::pretty_format_batches;
 use arrow_schema::{Field, Schema, SchemaRef};
 use datafusion::common::Result;
@@ -51,7 +51,7 @@ use test_utils::{add_empty_batches, StringBatchGenerator};
 
 use super::record_batch_generator::get_supported_types_columns;
 use crate::fuzz_cases::stream_exec::StreamExec;
-use datafusion_execution::memory_pool::units::MB;
+use datafusion_execution::memory_pool::units::{KB, MB};
 use datafusion_execution::memory_pool::FairSpillPool;
 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
 use datafusion_execution::TaskContext;
@@ -766,19 +766,63 @@ async fn test_high_cardinality_with_limited_memory() -> Result<()> {
     let task_ctx = {
         let memory_pool = Arc::new(FairSpillPool::new(two_mb));
         TaskContext::default()
-            .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
-            .with_runtime(Arc::new(
-                RuntimeEnvBuilder::new()
-                    .with_memory_pool(memory_pool)
-                    .build()?,
-            ))
+          .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
+          .with_runtime(Arc::new(
+              RuntimeEnvBuilder::new()
+                  .with_memory_pool(memory_pool)
+                  .build()?,
+          ))
     };
-    run_test_high_cardinality(task_ctx, 100).await
+
+    run_test_high_cardinality(task_ctx, 100, |_| (16 * KB) as usize).await
+}
+
+
+#[tokio::test]
+async fn test_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch() -> Result<()> {
+    let record_batch_size = 8192;
+    let two_mb = 2 * MB as usize;
+    let task_ctx = {
+        let memory_pool = Arc::new(FairSpillPool::new(two_mb));
+        TaskContext::default()
+          .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
+          .with_runtime(Arc::new(
+              RuntimeEnvBuilder::new()
+                  .with_memory_pool(memory_pool)
+                  .build()?,
+          ))
+    };
+
+    run_test_high_cardinality(task_ctx, 100, |i| {
+        if i % 25 == 0 {
+            (64 * KB) as usize
+        } else {
+            (16 * KB) as usize
+        }
+    }).await
+}
+
+#[tokio::test]
+async fn test_high_cardinality_with_limited_memory_and_large_record_batch() -> Result<()> {
+    let record_batch_size = 8192;
+    let two_mb = 2 * MB as usize;
+    let task_ctx = {
+        let memory_pool = Arc::new(FairSpillPool::new(two_mb));
+        TaskContext::default()
+          .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
+          .with_runtime(Arc::new(
+              RuntimeEnvBuilder::new()
+                  .with_memory_pool(memory_pool)
+                  .build()?,
+          ))
+    };
+    run_test_high_cardinality(task_ctx, 100, |_| two_mb / 5).await
 }
 
 async fn run_test_high_cardinality(
     task_ctx: TaskContext,
     number_of_record_batches: usize,
+    get_size_of_record_batch_to_generate: impl Fn(usize) -> usize,
 ) -> Result<()> {
     let scan_schema = Arc::new(Schema::new(vec![
         Field::new("col_0", DataType::UInt64, true),
@@ -808,17 +852,13 @@ async fn run_test_high_cardinality(
             Arc::clone(&schema),
             futures::stream::iter((0..number_of_record_batches as u64).map(
                 move |index| {
-                    // Simulate a large record batch 3 times in a stream
-                    let string_array =
-                        if index % (number_of_record_batches as u64 / 3) == 0 {
-                            Arc::new(StringArray::from_iter_values(
-                                (0..record_batch_size).map(|_| "b".repeat(64)),
-                            ))
-                        } else {
-                            Arc::new(StringArray::from_iter_values(
-                                (0..record_batch_size).map(|_| "a".repeat(8)),
-                            ))
-                        };
+                    let mut record_batch_memory_size = get_size_of_record_batch_to_generate(index as usize);
+                    record_batch_memory_size = record_batch_memory_size.saturating_sub(size_of::<u64>() * record_batch_memory_size);
+
+                    let string_item_size = record_batch_memory_size / record_batch_size as usize;
+                    let string_array = Arc::new(StringArray::from_iter_values(
+                        (0..record_batch_size).map(|_| "a".repeat(string_item_size)),
+                    ));
 
                     RecordBatch::try_new(
                         Arc::clone(&schema),

From 5c34e59d500fb17614ff90b9a0b83e9bd45b2677 Mon Sep 17 00:00:00 2001
From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com>
Date: Tue, 15 Apr 2025 22:59:29 +0300
Subject: [PATCH 3/9] format

---
 .../core/tests/fuzz_cases/aggregate_fuzz.rs   | 58 ++++++++++---------
 1 file changed, 32 insertions(+), 26 deletions(-)

diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
index 9c38ea7715a2..f466cd44b791 100644
--- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
@@ -766,31 +766,31 @@ async fn test_high_cardinality_with_limited_memory() -> Result<()> {
     let task_ctx = {
         let memory_pool = Arc::new(FairSpillPool::new(two_mb));
         TaskContext::default()
-          .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
-          .with_runtime(Arc::new(
-              RuntimeEnvBuilder::new()
-                  .with_memory_pool(memory_pool)
-                  .build()?,
-          ))
+            .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
+            .with_runtime(Arc::new(
+                RuntimeEnvBuilder::new()
+                    .with_memory_pool(memory_pool)
+                    .build()?,
+            ))
     };
 
     run_test_high_cardinality(task_ctx, 100, |_| (16 * KB) as usize).await
 }
 
-
 #[tokio::test]
-async fn test_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch() -> Result<()> {
+async fn test_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch(
+) -> Result<()> {
     let record_batch_size = 8192;
     let two_mb = 2 * MB as usize;
     let task_ctx = {
         let memory_pool = Arc::new(FairSpillPool::new(two_mb));
         TaskContext::default()
-          .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
-          .with_runtime(Arc::new(
-              RuntimeEnvBuilder::new()
-                  .with_memory_pool(memory_pool)
-                  .build()?,
-          ))
+            .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
+            .with_runtime(Arc::new(
+                RuntimeEnvBuilder::new()
+                    .with_memory_pool(memory_pool)
+                    .build()?,
+            ))
     };
 
     run_test_high_cardinality(task_ctx, 100, |i| {
@@ -799,22 +799,24 @@ async fn test_high_cardinality_with_limited_memory_and_different_sizes_of_record
         } else {
             (16 * KB) as usize
         }
-    }).await
+    })
+    .await
 }
 
 #[tokio::test]
-async fn test_high_cardinality_with_limited_memory_and_large_record_batch() -> Result<()> {
+async fn test_high_cardinality_with_limited_memory_and_large_record_batch() -> Result<()>
+{
     let record_batch_size = 8192;
     let two_mb = 2 * MB as usize;
     let task_ctx = {
         let memory_pool = Arc::new(FairSpillPool::new(two_mb));
         TaskContext::default()
-          .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
-          .with_runtime(Arc::new(
-              RuntimeEnvBuilder::new()
-                  .with_memory_pool(memory_pool)
-                  .build()?,
-          ))
+            .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
+            .with_runtime(Arc::new(
+                RuntimeEnvBuilder::new()
+                    .with_memory_pool(memory_pool)
+                    .build()?,
+            ))
     };
     run_test_high_cardinality(task_ctx, 100, |_| two_mb / 5).await
 }
@@ -852,10 +854,14 @@ async fn run_test_high_cardinality(
             Arc::clone(&schema),
             futures::stream::iter((0..number_of_record_batches as u64).map(
                 move |index| {
-                    let mut record_batch_memory_size = get_size_of_record_batch_to_generate(index as usize);
-                    record_batch_memory_size = record_batch_memory_size.saturating_sub(size_of::<u64>() * record_batch_memory_size);
-
-                    let string_item_size = record_batch_memory_size / record_batch_size as usize;
+                    let mut record_batch_memory_size =
+                        get_size_of_record_batch_to_generate(index as usize);
+                    record_batch_memory_size = record_batch_memory_size.saturating_sub(
+                        size_of::<u64>() * record_batch_memory_size,
+                    );
+
+                    let string_item_size =
+                        record_batch_memory_size / record_batch_size as usize;
                     let string_array = Arc::new(StringArray::from_iter_values(
                         (0..record_batch_size).map(|_| "a".repeat(string_item_size)),
                     ));

From c51a850dcb2dc14b87fe7d741aee07d34d54fc56 Mon Sep 17 00:00:00 2001
From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com>
Date: Tue, 15 Apr 2025 23:25:04 +0300
Subject: [PATCH 4/9] add more tests
---
 .../core/tests/fuzz_cases/aggregate_fuzz.rs   | 79 +++++++++++++------
 1 file changed, 54 insertions(+), 25 deletions(-)

diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
index f466cd44b791..550b04129d3d 100644
--- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
@@ -18,6 +18,7 @@
 use crate::fuzz_cases::aggregation_fuzzer::{
     AggregationFuzzerBuilder, DatasetGeneratorConfig, QueryBuilder,
 };
+use std::pin::Pin;
 use std::sync::Arc;
 
 use arrow::array::{
@@ -626,7 +627,10 @@ fn extract_result_counts(results: Vec<RecordBatch>) -> HashMap<Option<String>, i
     output
 }
 
-fn assert_spill_count_metric(expect_spill: bool, single_aggregate: Arc<AggregateExec>) {
+fn assert_spill_count_metric(
+    expect_spill: bool,
+    single_aggregate: Arc<AggregateExec>,
+) -> usize {
     if let Some(metrics_set) = single_aggregate.metrics() {
         let mut spill_count = 0;
 
@@ -644,7 +648,7 @@ fn assert_spill_count_metric(expect_spill: bool, single_aggregate: Arc<AggregateExec>) {
             panic!("Expected no spill but found SpillCount metric with value greater than 0.");
         }
 
-        println!("SpillCount = {}", spill_count);
+        spill_count
     } else {
         panic!("No metrics returned from the operator; cannot verify spilling.");
     }
@@ -753,9 +760,9 @@ async fn test_single_mode_aggregate_with_spill() -> Result<()> {
 
 #[tokio::test]
 async fn test_high_cardinality_with_limited_memory() -> Result<()> {
     let record_batch_size = 8192;
-    let two_mb = 2 * MB as usize;
+    let pool_size = 2 * MB as usize;
     let task_ctx = {
-        let memory_pool = Arc::new(FairSpillPool::new(two_mb));
+        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
         TaskContext::default()
             .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
             .with_runtime(Arc::new(
@@ -774,16 +778,30 @@ async fn test_high_cardinality_with_limited_memory() -> Result<()> {
             ))
     };
 
-    run_test_high_cardinality(task_ctx, 100, |_| (16 * KB) as usize).await
+    // Basic test with a lot of groups that cannot all fit in memory and 1 record batch
+    // from each spill file is too much memory
+    let spill_count =
+        run_test_high_cardinality(task_ctx, 100, Box::pin(|_| (16 * KB) as usize))
+            .await?;
+
+    let total_spill_files_size = spill_count * 16 * KB as usize;
+    assert!(
+        total_spill_files_size > pool_size,
+        "Total spill files size {} should be greater than pool size {}",
+        total_spill_files_size,
+        pool_size
+    );
+
+    Ok(())
 }
 
 #[tokio::test]
 async fn test_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch(
 ) -> Result<()> {
     let record_batch_size = 8192;
-    let two_mb = 2 * MB as usize;
+    let pool_size = 2 * MB as usize;
     let task_ctx = {
-        let memory_pool = Arc::new(FairSpillPool::new(two_mb));
+        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
         TaskContext::default()
             .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
             .with_runtime(Arc::new(
@@ -793,23 +811,29 @@ async fn test_high_cardinality_with_limited_memory_and_different_sizes_of_record
             ))
     };
 
-    run_test_high_cardinality(task_ctx, 100, |i| {
-        if i % 25 == 0 {
-            (64 * KB) as usize
-        } else {
-            (16 * KB) as usize
-        }
-    })
-    .await
+    run_test_high_cardinality(
+        task_ctx,
+        100,
+        Box::pin(|i| {
+            if i + 1 % 25 == 0 {
+                pool_size / 4
+            } else {
+                (16 * KB) as usize
+            }
+        }),
+    )
+    .await?;
+
+    Ok(())
 }
 
 #[tokio::test]
 async fn test_high_cardinality_with_limited_memory_and_large_record_batch() -> Result<()>
 {
     let record_batch_size = 8192;
-    let two_mb = 2 * MB as usize;
+    let pool_size = 2 * MB as usize;
     let task_ctx = {
-        let memory_pool = Arc::new(FairSpillPool::new(two_mb));
+        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
         TaskContext::default()
             .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
             .with_runtime(Arc::new(
@@ -818,14 +842,20 @@ async fn test_high_cardinality_with_limited_memory_and_large_record_batch() -> R
                 RuntimeEnvBuilder::new()
                     .with_memory_pool(memory_pool)
                     .build()?,
             ))
     };
-    run_test_high_cardinality(task_ctx, 100, |_| two_mb / 5).await
+
+    // Test that the merge degree of multi level merge sort cannot be fixed size when there is not enough memory
+    run_test_high_cardinality(task_ctx, 100, Box::pin(move |_| pool_size / 4)).await?;
+
+    Ok(())
 }
 
 async fn run_test_high_cardinality(
     task_ctx: TaskContext,
     number_of_record_batches: usize,
-    get_size_of_record_batch_to_generate: impl Fn(usize) -> usize,
-) -> Result<()> {
+    get_size_of_record_batch_to_generate: Pin<
+        Box<dyn Fn(usize) -> usize + Send + 'static>,
+    >,
+) -> Result<usize> {
     let scan_schema = Arc::new(Schema::new(vec![
         Field::new("col_0", DataType::UInt64, true),
         Field::new("col_1", DataType::Utf8, true),
@@ -856,9 +886,8 @@ async fn run_test_high_cardinality(
             move |index| {
                 let mut record_batch_memory_size =
                     get_size_of_record_batch_to_generate(index as usize);
-                record_batch_memory_size = record_batch_memory_size.saturating_sub(
-                    size_of::<u64>() * record_batch_memory_size,
-                );
+                record_batch_memory_size = record_batch_memory_size
+                    .saturating_sub(size_of::<u64>() * record_batch_size as usize);
 
                 let string_item_size =
                     record_batch_memory_size / record_batch_size as usize;
@@ -916,7 +945,7 @@ async fn run_test_high_cardinality(
         number_of_record_batches * record_batch_size as usize
     );
 
-    assert_spill_count_metric(true, aggregate_final);
+    let spill_count = assert_spill_count_metric(true, aggregate_final);
 
-    Ok(())
+    Ok(spill_count)
 }

From dd0917d0d436de2b73fe5bb9c17a347d87b39564 Mon Sep 17 00:00:00 2001
From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com>
Date: Tue, 15 Apr 2025 23:30:19 +0300
Subject: [PATCH 5/9] update tests

---
 .../core/tests/fuzz_cases/aggregate_fuzz.rs   | 12 +++++----
 .../core/tests/fuzz_cases/stream_exec.rs      | 27 +++++++++++++------
 2 files changed, 26 insertions(+), 13 deletions(-)

diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
index 550b04129d3d..8fdbf21c0e3e 100644
--- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
@@ -26,7 +26,7 @@ use arrow::array::{
     StringArray, UInt64Array,
 };
 use arrow::compute::{concat_batches, SortOptions};
-use arrow::datatypes::{DataType, UInt64Type};
+use arrow::datatypes::DataType;
 use arrow::util::pretty::pretty_format_batches;
 use arrow_schema::{Field, Schema, SchemaRef};
 use datafusion::common::Result;
@@ -778,13 +778,15 @@ async fn test_high_cardinality_with_limited_memory() -> Result<()> {
         ))
     };
 
+    let record_batch_size = pool_size / 16;
+
     // Basic test with a lot of groups that cannot all fit in memory and 1 record batch
     // from each spill file is too much memory
     let spill_count =
-        run_test_high_cardinality(task_ctx, 100, Box::pin(|_| (16 * KB) as usize))
+        run_test_high_cardinality(task_ctx, 100, Box::pin(move |_| record_batch_size))
             .await?;
 
-    let total_spill_files_size = spill_count * 16 * KB as usize;
+    let total_spill_files_size = spill_count * record_batch_size;
     assert!(
         total_spill_files_size > pool_size,
         "Total spill files size {} should be greater than pool size {}",
@@ -814,8 +816,8 @@ async fn test_high_cardinality_with_limited_memory_and_different_sizes_of_record
     run_test_high_cardinality(
         task_ctx,
         100,
-        Box::pin(|i| {
-            if i + 1 % 25 == 0 {
+        Box::pin(move |i| {
+            if i % 25 == 1 {
                 pool_size / 4
             } else {
                 (16 * KB) as usize
diff --git a/datafusion/core/tests/fuzz_cases/stream_exec.rs b/datafusion/core/tests/fuzz_cases/stream_exec.rs
index a9b75d7547ca..6e71b9988d79 100644
--- a/datafusion/core/tests/fuzz_cases/stream_exec.rs
+++ b/datafusion/core/tests/fuzz_cases/stream_exec.rs
@@ -1,3 +1,20 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
 use arrow_schema::SchemaRef;
 use datafusion_common::DataFusionError;
 use datafusion_execution::{SendableRecordBatchStream, TaskContext};
@@ -10,7 +27,8 @@ use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
 use std::any::Any;
 use std::fmt::{Debug, Formatter};
 use std::sync::{Arc, Mutex};
 
-/// Execution plan that returns the stream on the call to `execute`.
+/// Execution plan that returns the stream on the call to `execute`. Further calls to `execute` will
+/// return an error
 pub struct StreamExec {
     /// the results to send back
     stream: Mutex<Option<SendableRecordBatchStream>>,
     cache: PlanProperties,
@@ -24,13 +42,6 @@ impl Debug for StreamExec {
 }
 
 impl StreamExec {
-    /// Create a new `MockExec` with a single partition that returns
-    /// the specified `Results`s.
-    ///
-    /// By default, the batches are not produced immediately (the
-    /// caller has to actually yield and another task must run) to
-    /// ensure any poll loops are correct. This behavior can be
-    /// changed with `with_use_task`
     pub fn new(stream: SendableRecordBatchStream) -> Self {
         let cache = Self::compute_properties(stream.schema());
         Self {

From 67b268f1e5c1531e38e0c78220bcbbb069b1efe9 Mon Sep 17 00:00:00 2001
From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com>
Date: Tue, 15 Apr 2025 23:38:56 +0300
Subject: [PATCH 6/9] update lock

---
 Cargo.lock | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 2b3eeecf5d9b..ee6dda88a0e2 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3966,9 +3966,9 @@ dependencies = [
 
 [[package]]
 name = "libz-rs-sys"
-version = "0.4.2"
+version = "0.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "902bc563b5d65ad9bba616b490842ef0651066a1a1dc3ce1087113ffcb873c8d"
+checksum = "6489ca9bd760fe9642d7644e827b0c9add07df89857b0416ee15c1cc1a3b8c5a"
 dependencies = [
  "zlib-rs",
 ]
@@ -7512,9 +7512,9 @@ dependencies = [
 
 [[package]]
 name = "zlib-rs"
-version = "0.4.2"
+version = "0.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8b20717f0917c908dc63de2e44e97f1e6b126ca58d0e391cee86d504eb8fbd05"
+checksum = "868b928d7949e09af2f6086dfc1e01936064cc7a819253bce650d4e2a2d63ba8"
 
 [[package]]
 name = "zstd"

From e4d777867664f0d622e024fa4bdb5e6e822cd20b Mon Sep 17 00:00:00 2001
From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com>
Date: Tue, 15 Apr 2025 23:45:35 +0300
Subject: [PATCH 7/9] added to sort fuzz

---
 datafusion/core/tests/fuzz_cases/sort_fuzz.rs | 184 +++++++++++++++++-
 1 file changed, 182 insertions(+), 2 deletions(-)

diff --git a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs
index 0b0f0aa2f105..ae4c6478550f 100644
--- a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs
@@ -17,6 +17,7 @@
 //! Fuzz Test for various corner cases sorting RecordBatches exceeds available memory and should spill
 
+use std::pin::Pin;
 use std::sync::Arc;
 
 use arrow::{
     array::{as_string_array, ArrayRef, Int32Array, StringArray},
@@ -24,6 +25,9 @@ use arrow::{
     compute::SortOptions,
     record_batch::RecordBatch,
 };
+use arrow::array::UInt64Array;
+use arrow_schema::{DataType, Field, Schema};
+use futures::StreamExt;
 use datafusion::datasource::memory::MemorySourceConfig;
 use datafusion::execution::runtime_env::RuntimeEnvBuilder;
 use datafusion::physical_plan::expressions::PhysicalSortExpr;
@@ -31,12 +35,19 @@ use datafusion::physical_plan::sorts::sort::SortExec;
 use datafusion::physical_plan::{collect, ExecutionPlan};
 use datafusion::prelude::{SessionConfig, SessionContext};
 use datafusion_common::cast::as_int32_array;
-use datafusion_execution::memory_pool::GreedyMemoryPool;
-use datafusion_physical_expr::expressions::col;
+use datafusion_execution::memory_pool::{FairSpillPool, GreedyMemoryPool};
+use datafusion_physical_expr::expressions::{col, Column};
 use datafusion_physical_expr_common::sort_expr::LexOrdering;
 
 use rand::Rng;
+use datafusion_execution::memory_pool::units::MB;
+use datafusion_execution::TaskContext;
+use datafusion_functions_aggregate::array_agg::array_agg_udaf;
+use datafusion_physical_expr::aggregate::AggregateExprBuilder;
+use datafusion_physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
+use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
 use test_utils::{batches_to_vec, partitions_to_sorted_vec};
+use crate::fuzz_cases::stream_exec::StreamExec;
 
 const KB: usize = 1 << 10;
 #[tokio::test]
@@ -379,3 +390,172 @@ fn make_staggered_i32_utf8_batches(len: usize) -> Vec<RecordBatch> {
 
     batches
 }
+
+#[tokio::test]
+async fn test_with_limited_memory() -> datafusion_common::Result<()> {
+    let record_batch_size = 8192;
+    let pool_size = 2 * MB as usize;
+    let task_ctx = {
+        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
+        TaskContext::default()
+          .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
+          .with_runtime(Arc::new(
+              RuntimeEnvBuilder::new()
+                  .with_memory_pool(memory_pool)
+                  .build()?,
+          ))
+    };
+
+    let record_batch_size = pool_size / 16;
+
+    // Basic test with a lot of groups that cannot all fit in memory and 1 record batch
+    // from each spill file is too much memory
+    let spill_count =
+        run_memory_test_for_limited_memory(task_ctx, 100, Box::pin(move |_| record_batch_size))
+            .await?;
+
+    let total_spill_files_size = spill_count * record_batch_size;
+    assert!(
+        total_spill_files_size > pool_size,
+        "Total spill files size {} should be greater than pool size {}",
+        total_spill_files_size,
+        pool_size
+    );
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_with_limited_memory_and_different_sizes_of_record_batch(
+) -> datafusion_common::Result<()> {
+    let record_batch_size = 8192;
+    let pool_size = 2 * MB as usize;
+    let task_ctx = {
+        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
+        TaskContext::default()
+          .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
+          .with_runtime(Arc::new(
+              RuntimeEnvBuilder::new()
+                  .with_memory_pool(memory_pool)
+                  .build()?,
+          ))
+    };
+
+    run_memory_test_for_limited_memory(
+        task_ctx,
+        100,
+        Box::pin(move |i| {
+            if i % 25 == 1 {
+                pool_size / 4
+            } else {
+                (16 * datafusion_execution::memory_pool::units::KB) as usize
+            }
+        }),
+    )
+    .await?;
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_with_limited_memory_and_large_record_batch() -> datafusion_common::Result<()>
+{
+    let record_batch_size = 8192;
+    let pool_size = 2 * MB as usize;
+    let task_ctx = {
+        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
+        TaskContext::default()
+          .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
+          .with_runtime(Arc::new(
+              RuntimeEnvBuilder::new()
+                  .with_memory_pool(memory_pool)
+                  .build()?,
+          ))
+    };
+
+    // Test that the merge degree of multi level merge sort cannot be fixed size when there is not enough memory
+    run_memory_test_for_limited_memory(task_ctx, 100, Box::pin(move |_| pool_size / 4)).await?;
+
+    Ok(())
+}
+
+async fn run_memory_test_for_limited_memory(
+    task_ctx: TaskContext,
+    number_of_record_batches: usize,
+    get_size_of_record_batch_to_generate: Pin<
+        Box<dyn Fn(usize) -> usize + Send + 'static>,
+    >,
+) -> datafusion_common::Result<usize> {
+    let scan_schema = Arc::new(Schema::new(vec![
+        Field::new("col_0", DataType::UInt64, true),
+        Field::new("col_1", DataType::Utf8, true),
+    ]));
+
+    let record_batch_size = task_ctx.session_config().batch_size() as u64;
+
+    let schema = Arc::clone(&scan_schema);
+    let plan: Arc<dyn ExecutionPlan> =
+      Arc::new(StreamExec::new(Box::pin(RecordBatchStreamAdapter::new(
+          Arc::clone(&schema),
+          futures::stream::iter((0..number_of_record_batches as u64).map(
+              move |index| {
+                  let mut record_batch_memory_size =
+                      get_size_of_record_batch_to_generate(index as usize);
+                  record_batch_memory_size = record_batch_memory_size
+                      .saturating_sub(size_of::<u64>() * record_batch_size as usize);
+
+                  let string_item_size =
+                      record_batch_memory_size / record_batch_size as usize;
+                  let string_array = Arc::new(StringArray::from_iter_values(
+                      (0..record_batch_size).map(|_| "a".repeat(string_item_size)),
+                  ));
+
+                  RecordBatch::try_new(
+                      Arc::clone(&schema),
+                      vec![
+                          Arc::new(UInt64Array::from_iter_values(
+                              (index * record_batch_size)
+                                  ..(index * record_batch_size) + record_batch_size,
+                          )),
+                          string_array,
+                      ],
+                  )
+                  .map_err(|err| err.into())
+              },
+          )),
+      ))));
+
+
+    let sort_exec = Arc::new(SortExec::new(
+        LexOrdering::new(vec![PhysicalSortExpr {
+            expr: col("col_0", &scan_schema).unwrap(),
+            options: SortOptions {
+                descending: false,
+                nulls_first: true,
+            },
+        }]),
+        plan,
+    ));
+
+    let task_ctx = Arc::new(task_ctx);
+
+    let mut result = sort_exec.execute(0, task_ctx)?;
+
+    let mut number_of_rows = 0;
+
+    while let Some(batch) = result.next().await {
+        let batch = batch?;
+        number_of_rows += batch.num_rows();
+    }
+
+    assert_eq!(
+        number_of_rows,
+        number_of_record_batches * record_batch_size as usize
+    );
+
+    let spill_count = sort_exec.metrics().unwrap().spill_count().unwrap();
+
+    assert!(spill_count > 0, "Expected spill, but did not: {number_of_record_batches:?}");
+
+    Ok(spill_count)
+}

From bd0a402ebb749ed649460827c1d707088f466cef Mon Sep 17 00:00:00 2001
From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com>
Date: Wed, 16 Apr 2025 00:13:35 +0300
Subject: [PATCH 8/9] added memory tests for the aggregate

---
 .../core/tests/fuzz_cases/aggregate_fuzz.rs   | 174 ++++++++++++++--
 datafusion/core/tests/fuzz_cases/sort_fuzz.rs | 124 +++++++------
 2 files changed, 225 insertions(+), 73 deletions(-)

diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
index 8fdbf21c0e3e..eaa7c624c5ee 100644
--- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
@@ -53,7 +53,9 @@ use test_utils::{add_empty_batches, StringBatchGenerator};
 
 use super::record_batch_generator::get_supported_types_columns;
 use crate::fuzz_cases::stream_exec::StreamExec;
 use datafusion_execution::memory_pool::units::{KB, MB};
-use datafusion_execution::memory_pool::FairSpillPool;
+use datafusion_execution::memory_pool::{
+    FairSpillPool, MemoryConsumer, MemoryReservation,
+};
 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
 use datafusion_execution::TaskContext;
 use datafusion_functions_aggregate::array_agg::array_agg_udaf;
@@ -782,9 +784,14 @@ async fn test_high_cardinality_with_limited_memory() -> Result<()> {
 
     // Basic test with a lot of groups that cannot all fit in memory and 1 record batch
     // from each spill file is too much memory
-    let spill_count =
-        run_test_high_cardinality(task_ctx, 100, Box::pin(move |_| record_batch_size))
-            .await?;
+    let spill_count = run_test_high_cardinality(RunTestHighCardinalityArgs {
+        pool_size,
+        task_ctx,
+        number_of_record_batches: 100,
+        get_size_of_record_batch_to_generate: Box::pin(move |_| record_batch_size),
+        memory_behavior: Default::default(),
+    })
+    .await?;
 
     let total_spill_files_size = spill_count * record_batch_size;
     assert!(
@@ -813,17 +820,87 @@ async fn test_high_cardinality_with_limited_memory_and_different_sizes_of_record
         ))
     };
 
-    run_test_high_cardinality(
+    run_test_high_cardinality(RunTestHighCardinalityArgs {
+        pool_size,
         task_ctx,
-        100,
-        Box::pin(move |i| {
+        number_of_record_batches: 100,
+        get_size_of_record_batch_to_generate: Box::pin(move |i| {
             if i % 25 == 1 {
                 pool_size / 4
             } else {
                 (16 * KB) as usize
             }
         }),
-    )
+        memory_behavior: Default::default(),
+    })
     .await?;
 
     Ok(())
 }
 
+#[tokio::test]
+async fn test_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch_and_changing_memory_reservation(
+) -> Result<()> {
+    let record_batch_size = 8192;
+    let pool_size = 2 * MB as usize;
+    let task_ctx = {
+        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
+        TaskContext::default()
+            .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
+            .with_runtime(Arc::new(
+                RuntimeEnvBuilder::new()
+                    .with_memory_pool(memory_pool)
+                    .build()?,
+            ))
+    };
+
+    run_test_high_cardinality(RunTestHighCardinalityArgs {
+        pool_size,
+        task_ctx,
+        number_of_record_batches: 100,
+        get_size_of_record_batch_to_generate: Box::pin(move |i| {
+            if i % 25 == 1 {
+                pool_size / 4
+            } else {
+                (16 * KB) as usize
+            }
+        }),
+        memory_behavior: MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(10),
+    })
+    .await?;
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch_and_take_all_memory(
+) -> Result<()> {
+    let record_batch_size = 8192;
+    let pool_size = 2 * MB as usize;
+    let task_ctx = {
+        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
+        TaskContext::default()
+            .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
+            .with_runtime(Arc::new(
+                RuntimeEnvBuilder::new()
+                    .with_memory_pool(memory_pool)
+                    .build()?,
+            ))
+    };
+
+    run_test_high_cardinality(RunTestHighCardinalityArgs {
+        pool_size,
+        task_ctx,
+        number_of_record_batches: 100,
+        get_size_of_record_batch_to_generate: Box::pin(move |i| {
+            if i % 25 == 1 {
+                pool_size / 4
+            } else {
+                (16 * KB) as usize
+            }
+        }),
+        memory_behavior: MemoryBehavior::TakeAllMemoryAtTheBeginning,
+    })
+    .await?;
+
+    Ok(())
+}
@@ -846,18 +923,43 @@ async fn test_high_cardinality_with_limited_memory_and_large_record_batch() -> R
     };
 
     // Test that the merge degree of multi level merge sort cannot be fixed size when there is not enough memory
-    run_test_high_cardinality(task_ctx, 100, Box::pin(move |_| pool_size / 4)).await?;
+    run_test_high_cardinality(RunTestHighCardinalityArgs {
+        pool_size,
+        task_ctx,
+        number_of_record_batches: 100,
+        get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 4),
+        memory_behavior: Default::default(),
+    })
+    .await?;
 
     Ok(())
 }
 
-async fn run_test_high_cardinality(
-    task_ctx: TaskContext,
-    number_of_record_batches: usize,
-    get_size_of_record_batch_to_generate: Pin<
-        Box<dyn Fn(usize) -> usize + Send + 'static>,
-    >,
-) -> Result<usize> {
+struct RunTestHighCardinalityArgs {
+    pool_size: usize,
+    task_ctx: TaskContext,
+    number_of_record_batches: usize,
+    get_size_of_record_batch_to_generate:
+        Pin<Box<dyn Fn(usize) -> usize + Send + 'static>>,
+    memory_behavior: MemoryBehavior,
+}
+
+#[derive(Default)]
+enum MemoryBehavior {
+    #[default]
+    AsIs,
+    TakeAllMemoryAtTheBeginning,
+    TakeAllMemoryAndReleaseEveryNthBatch(usize),
+}
+
+async fn run_test_high_cardinality(args: RunTestHighCardinalityArgs) -> Result<usize> {
+    let RunTestHighCardinalityArgs {
+        pool_size,
+        task_ctx,
+        number_of_record_batches,
+        get_size_of_record_batch_to_generate,
+        memory_behavior,
+    } = args;
     let scan_schema = Arc::new(Schema::new(vec![
         Field::new("col_0", DataType::UInt64, true),
         Field::new("col_1", DataType::Utf8, true),
@@ -933,13 +1035,43 @@ async fn run_test_high_cardinality(
 
     let task_ctx = Arc::new(task_ctx);
 
-    let mut result = aggregate_final.execute(0, task_ctx)?;
+    let mut result = aggregate_final.execute(0, Arc::clone(&task_ctx))?;
 
     let mut number_of_groups = 0;
 
+    let memory_pool = task_ctx.memory_pool();
+    let memory_consumer = MemoryConsumer::new("mock_memory_consumer");
+    let mut memory_reservation = memory_consumer.register(memory_pool);
+
+    let mut index = 0;
+    let mut memory_took = false;
+
     while let Some(batch) = result.next().await {
+        match memory_behavior {
+            MemoryBehavior::AsIs => {
+                // Do nothing
+            }
+            MemoryBehavior::TakeAllMemoryAtTheBeginning => {
+                if !memory_took {
+                    memory_took = true;
+                    grow_memory_as_much_as_possible(10, &mut memory_reservation)?;
+                }
+            }
+            MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(n) => {
+                if !memory_took {
+                    memory_took = true;
+                    grow_memory_as_much_as_possible(pool_size, &mut memory_reservation)?;
+                } else if index % n == 0 {
+                    // release memory
+                    memory_reservation.free();
+                }
+            }
+        }
+
         let batch = batch?;
         number_of_groups += batch.num_rows();
+
+        index += 1;
     }
 
     assert_eq!(
@@ -951,3 +1083,15 @@ async fn run_test_high_cardinality(
 
     Ok(spill_count)
 }
+
+fn grow_memory_as_much_as_possible(
+    memory_step: usize,
+    memory_reservation: &mut MemoryReservation,
+) -> Result<bool> {
+    let mut was_able_to_grow = false;
+    while memory_reservation.try_grow(memory_step).is_ok() {
+        was_able_to_grow = true;
+    }
+
+    Ok(was_able_to_grow)
+}
diff --git a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs
index ae4c6478550f..2c87d9524778 100644
--- a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs
@@ -20,14 +20,13 @@
 use std::pin::Pin;
 use std::sync::Arc;
 
+use arrow::array::UInt64Array;
 use arrow::{
     array::{as_string_array, ArrayRef, Int32Array, StringArray},
     compute::SortOptions,
     record_batch::RecordBatch,
 };
-use arrow::array::UInt64Array;
 use arrow_schema::{DataType, Field, Schema};
-use futures::StreamExt;
 use datafusion::datasource::memory::MemorySourceConfig;
 use datafusion::execution::runtime_env::RuntimeEnvBuilder;
 use datafusion::physical_plan::expressions::PhysicalSortExpr;
@@ -38,16 +37,19 @@ use datafusion::physical_plan::sorts::sort::SortExec;
 use datafusion::physical_plan::{collect, ExecutionPlan};
 use datafusion::prelude::{SessionConfig, SessionContext};
 use datafusion_common::cast::as_int32_array;
 use datafusion_execution::memory_pool::{FairSpillPool, GreedyMemoryPool};
 use datafusion_physical_expr::expressions::{col, Column};
 use datafusion_physical_expr_common::sort_expr::LexOrdering;
+use futures::StreamExt;
 
-use rand::Rng;
+use crate::fuzz_cases::stream_exec::StreamExec;
 use datafusion_execution::memory_pool::units::MB;
 use datafusion_execution::TaskContext;
 use datafusion_functions_aggregate::array_agg::array_agg_udaf;
 use datafusion_physical_expr::aggregate::AggregateExprBuilder;
-use datafusion_physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
+use datafusion_physical_plan::aggregates::{
+    AggregateExec, AggregateMode, PhysicalGroupBy,
+};
 use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
+use rand::Rng;
 use test_utils::{batches_to_vec, partitions_to_sorted_vec};
-use crate::fuzz_cases::stream_exec::StreamExec;
 
 const KB: usize = 1 << 10;
 
@@ -398,21 +400,24 @@ async fn test_with_limited_memory() -> datafusion_common::Result<()> {
     let task_ctx = {
         let memory_pool = Arc::new(FairSpillPool::new(pool_size));
         TaskContext::default()
-          .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
-          .with_runtime(Arc::new(
-              RuntimeEnvBuilder::new()
-                  .with_memory_pool(memory_pool)
-                  .build()?,
-          ))
+            .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
+            .with_runtime(Arc::new(
+                RuntimeEnvBuilder::new()
+                    .with_memory_pool(memory_pool)
+                    .build()?,
+            ))
     };
 
     let record_batch_size = pool_size / 16;
 
     // Basic test with a lot of groups that cannot all fit in memory and 1 record batch
     // from each spill file is too much memory
-    let spill_count =
-        run_memory_test_for_limited_memory(task_ctx, 100, Box::pin(move |_| record_batch_size))
-            .await?;
+    let spill_count = run_memory_test_for_limited_memory(
+        task_ctx,
+        100,
+        Box::pin(move |_| record_batch_size),
+    )
+    .await?;
 
     let total_spill_files_size = spill_count * record_batch_size;
     assert!(
@@ -433,12 +438,12 @@ async fn test_with_limited_memory_and_different_sizes_of_record_batch(
     let task_ctx = {
         let memory_pool = Arc::new(FairSpillPool::new(pool_size));
         TaskContext::default()
-          .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
-          .with_runtime(Arc::new(
-              RuntimeEnvBuilder::new()
-                  .with_memory_pool(memory_pool)
-                  .build()?,
-          ))
+            .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
+            .with_runtime(Arc::new(
+                RuntimeEnvBuilder::new()
+                    .with_memory_pool(memory_pool)
+                    .build()?,
+            ))
     };
 
     run_memory_test_for_limited_memory(
@@ -452,7 +457,7 @@ async fn test_with_limited_memory_and_different_sizes_of_record_batch(
             }
         }),
     )
-    .await?;
+    .await?;
 
     Ok(())
 }
@@ -465,16 +470,17 @@ async fn test_with_limited_memory_and_large_record_batch() -> datafusion_common:
     let task_ctx = {
         let memory_pool = Arc::new(FairSpillPool::new(pool_size));
         TaskContext::default()
-          .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
-          .with_runtime(Arc::new(
-              RuntimeEnvBuilder::new()
-                  .with_memory_pool(memory_pool)
-                  .build()?,
-          ))
+            .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
+            .with_runtime(Arc::new(
+                RuntimeEnvBuilder::new()
+                    .with_memory_pool(memory_pool)
+                    .build()?,
+            ))
     };
 
     // Test that the merge degree of multi level merge sort cannot be fixed size when there is not enough memory
-    run_memory_test_for_limited_memory(task_ctx, 100, Box::pin(move |_| pool_size / 4)).await?;
+    run_memory_test_for_limited_memory(task_ctx, 100, Box::pin(move |_| pool_size / 4))
+        .await?;
 
     Ok(())
 }
@@ -495,36 +501,35 @@ async fn run_memory_test_for_limited_memory(
     let schema = Arc::clone(&scan_schema);
     let plan: Arc<dyn ExecutionPlan> =
-      Arc::new(StreamExec::new(Box::pin(RecordBatchStreamAdapter::new(
-          Arc::clone(&schema),
-          futures::stream::iter((0..number_of_record_batches as u64).map(
-              move |index| {
-                  let mut record_batch_memory_size =
-                      get_size_of_record_batch_to_generate(index as usize);
-                  record_batch_memory_size = record_batch_memory_size
-                      .saturating_sub(size_of::<u64>() * record_batch_size as usize);
-
-                  let string_item_size =
-                      record_batch_memory_size / record_batch_size as usize;
-                  let string_array = Arc::new(StringArray::from_iter_values(
-                      (0..record_batch_size).map(|_| "a".repeat(string_item_size)),
-                  ));
-
-                  RecordBatch::try_new(
-                      Arc::clone(&schema),
-                      vec![
-                          Arc::new(UInt64Array::from_iter_values(
-                              (index * record_batch_size)
-                                  ..(index * record_batch_size) + record_batch_size,
-                          )),
-                          string_array,
-                      ],
-                  )
-                  .map_err(|err| err.into())
-              },
-          )),
-      ))));
-
+        Arc::new(StreamExec::new(Box::pin(RecordBatchStreamAdapter::new(
+            Arc::clone(&schema),
+            futures::stream::iter((0..number_of_record_batches as u64).map(
+                move |index| {
+                    let mut record_batch_memory_size =
+                        get_size_of_record_batch_to_generate(index as usize);
+                    record_batch_memory_size = record_batch_memory_size
+                        .saturating_sub(size_of::<u64>() * record_batch_size as usize);
+
+                    let string_item_size =
+                        record_batch_memory_size / record_batch_size as usize;
+                    let string_array = Arc::new(StringArray::from_iter_values(
+                        (0..record_batch_size).map(|_| "a".repeat(string_item_size)),
+                    ));
+
+                    RecordBatch::try_new(
+                        Arc::clone(&schema),
+                        vec![
+                            Arc::new(UInt64Array::from_iter_values(
+                                (index * record_batch_size)
+                                    ..(index * record_batch_size) + record_batch_size,
+                            )),
+                            string_array,
+                        ],
+                    )
+                    .map_err(|err| err.into())
+                },
+            )),
+        ))));
 
     let sort_exec = Arc::new(SortExec::new(
@@ -555,7 +560,10 @@ async fn run_memory_test_for_limited_memory(
 
     let spill_count = sort_exec.metrics().unwrap().spill_count().unwrap();
 
-    assert!(spill_count > 0, "Expected spill, but did not: {number_of_record_batches:?}");
+    assert!(
+        spill_count > 0,
+        "Expected spill, but did not: {number_of_record_batches:?}"
+    );
 
     Ok(spill_count)
 }

From 12dc8a871b0d8fc89d1236c738869178a203b4a7 Mon Sep 17 00:00:00 2001
From: Raz Luvaton <16746759+rluvaton@users.noreply.github.com>
Date: Wed, 16 Apr 2025 00:27:18 +0300
Subject: [PATCH 9/9] add sort fuzz tests

---
 datafusion/core/tests/fuzz_cases/sort_fuzz.rs | 228 +++++++++++++++---
 1 file changed, 192 insertions(+), 36 deletions(-)

diff --git a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs
index 2c87d9524778..7a5ebf4b439d 100644
--- a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs
@@ -27,6 +27,7 @@ use arrow::{
     record_batch::RecordBatch,
 };
 use arrow_schema::{DataType, Field, Schema};
+use datafusion::common::Result;
 use datafusion::datasource::memory::MemorySourceConfig;
 use datafusion::execution::runtime_env::RuntimeEnvBuilder;
 use datafusion::physical_plan::expressions::PhysicalSortExpr;
@@ -34,19 +35,16 @@ use datafusion::physical_plan::sorts::sort::SortExec;
 use datafusion::physical_plan::{collect, ExecutionPlan};
 use datafusion::prelude::{SessionConfig, SessionContext};
 use datafusion_common::cast::as_int32_array;
-use datafusion_execution::memory_pool::{FairSpillPool, GreedyMemoryPool};
-use datafusion_physical_expr::expressions::{col, Column};
+use datafusion_execution::memory_pool::{
+    FairSpillPool, GreedyMemoryPool, MemoryConsumer, MemoryReservation,
+};
+use datafusion_physical_expr::expressions::col;
 use datafusion_physical_expr_common::sort_expr::LexOrdering;
 use futures::StreamExt;
 
 use crate::fuzz_cases::stream_exec::StreamExec;
 use datafusion_execution::memory_pool::units::MB;
 use datafusion_execution::TaskContext;
-use datafusion_functions_aggregate::array_agg::array_agg_udaf;
-use datafusion_physical_expr::aggregate::AggregateExprBuilder;
-use datafusion_physical_plan::aggregates::{
-    AggregateExec, AggregateMode, PhysicalGroupBy,
-};
 use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
 use rand::Rng;
 use test_utils::{batches_to_vec, partitions_to_sorted_vec};
 
 const KB: usize = 1 << 10;
 
@@ -394,13 +392,17 @@ fn make_staggered_i32_utf8_batches(len: usize) -> Vec<RecordBatch> {
 }
 
 #[tokio::test]
-async fn test_with_limited_memory() -> datafusion_common::Result<()> {
+async fn test_sort_with_limited_memory() -> Result<()> {
     let record_batch_size = 8192;
     let pool_size = 2 * MB as usize;
     let task_ctx = {
         let memory_pool = Arc::new(FairSpillPool::new(pool_size));
         TaskContext::default()
-            .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
+            .with_session_config(
+                SessionConfig::new()
+                    .with_batch_size(record_batch_size)
+                    .with_sort_spill_reservation_bytes(1),
+            )
             .with_runtime(Arc::new(
                 RuntimeEnvBuilder::new()
                     .with_memory_pool(memory_pool)
                     .build()?,
             ))
     };
 
     let record_batch_size = pool_size / 16;
 
     // Basic test with a lot of groups that cannot all fit in memory and 1 record batch
     // from each spill file is too much memory
-    let spill_count = run_memory_test_for_limited_memory(
-        task_ctx,
-        100,
-        Box::pin(move |_| record_batch_size),
-    )
-    .await?;
+    let spill_count =
+        run_sort_test_with_limited_memory(RunSortTestWithLimitedMemoryArgs {
+            pool_size,
+            task_ctx,
+            number_of_record_batches: 100,
+            get_size_of_record_batch_to_generate: Box::pin(move |_| record_batch_size),
+            memory_behavior: Default::default(),
+        })
+        .await?;
 
     let total_spill_files_size = spill_count * record_batch_size;
     assert!(
         total_spill_files_size > pool_size,
         "Total spill files size {} should be greater than pool size {}",
         total_spill_files_size,
         pool_size
     );
 
     Ok(())
 }
 
 #[tokio::test]
-async fn test_with_limited_memory_and_different_sizes_of_record_batch(
-) -> datafusion_common::Result<()> {
+async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch() -> Result<()>
+{
     let record_batch_size = 8192;
     let pool_size = 2 * MB as usize;
     let task_ctx = {
         let memory_pool = Arc::new(FairSpillPool::new(pool_size));
         TaskContext::default()
-            .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
+            .with_session_config(
+                SessionConfig::new()
+                    .with_batch_size(record_batch_size)
+                    .with_sort_spill_reservation_bytes(1),
+            )
             .with_runtime(Arc::new(
                 RuntimeEnvBuilder::new()
                     .with_memory_pool(memory_pool)
                     .build()?,
             ))
     };
 
-    run_memory_test_for_limited_memory(
+    run_sort_test_with_limited_memory(RunSortTestWithLimitedMemoryArgs {
+        pool_size,
         task_ctx,
-        100,
-        Box::pin(move |i| {
+        number_of_record_batches: 100,
+        get_size_of_record_batch_to_generate: Box::pin(move |i| {
             if i % 25 == 1 {
                 pool_size / 4
             } else {
-                (16 * datafusion_execution::memory_pool::units::KB) as usize
+                16 * KB
             }
         }),
-    )
+        memory_behavior: Default::default(),
+    })
     .await?;
 
     Ok(())
 }
 
 #[tokio::test]
+async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_changing_memory_reservation(
+) -> Result<()>
+{
+    let record_batch_size = 8192;
+    let pool_size = 2 * MB as usize;
+    let task_ctx = {
+        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
+        TaskContext::default()
+            .with_session_config(
+                SessionConfig::new()
+                    .with_batch_size(record_batch_size)
+                    .with_sort_spill_reservation_bytes(1),
+            )
+            .with_runtime(Arc::new(
+                RuntimeEnvBuilder::new()
+                    .with_memory_pool(memory_pool)
+                    .build()?,
+            ))
+    };
+
+    run_sort_test_with_limited_memory(RunSortTestWithLimitedMemoryArgs {
+        pool_size,
+        task_ctx,
+        number_of_record_batches: 100,
+        get_size_of_record_batch_to_generate: Box::pin(move |i| {
+            if i % 25 == 1 {
+                pool_size / 4
+            } else {
+                16 * KB
+            }
+        }),
+        memory_behavior: MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(10),
+    })
+    .await?;
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_take_all_memory(
+) -> Result<()> {
+    let record_batch_size = 8192;
+    let pool_size = 2 * MB as usize;
+    let task_ctx = {
+        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
+        TaskContext::default()
+            .with_session_config(
+                SessionConfig::new()
+                    .with_batch_size(record_batch_size)
+                    .with_sort_spill_reservation_bytes(1),
+            )
+            .with_runtime(Arc::new(
+                RuntimeEnvBuilder::new()
+                    .with_memory_pool(memory_pool)
+                    .build()?,
+            ))
+    };
+
+    run_sort_test_with_limited_memory(RunSortTestWithLimitedMemoryArgs {
+        pool_size,
+        task_ctx,
+        number_of_record_batches: 100,
+        get_size_of_record_batch_to_generate: Box::pin(move |i| {
+            if i % 25 == 1 {
+                pool_size / 4
+            } else {
+                16 * KB
+            }
+        }),
+        memory_behavior: MemoryBehavior::TakeAllMemoryAtTheBeginning,
+    })
+    .await?;
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_sort_with_limited_memory_and_large_record_batch() -> Result<()> {
+    let record_batch_size = 8192;
+    let pool_size = 2 * MB as usize;
+    let task_ctx = {
+        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
+        TaskContext::default()
+            .with_session_config(
+                SessionConfig::new()
+                    .with_batch_size(record_batch_size)
+                    .with_sort_spill_reservation_bytes(1),
+            )
+            .with_runtime(Arc::new(
+                RuntimeEnvBuilder::new()
+                    .with_memory_pool(memory_pool)
+                    .build()?,
+            ))
+    };
 
     // Test that the merge degree of multi level merge sort cannot be fixed size when there is not enough memory
-    run_memory_test_for_limited_memory(task_ctx, 100, Box::pin(move |_| pool_size / 4))
-        .await?;
+    run_sort_test_with_limited_memory(RunSortTestWithLimitedMemoryArgs {
+        pool_size,
+        task_ctx,
+        number_of_record_batches: 100,
+        get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 4),
+        memory_behavior: Default::default(),
+    })
+    .await?;
 
     Ok(())
 }
 
-async fn run_memory_test_for_limited_memory(
-    task_ctx: TaskContext,
-    number_of_record_batches: usize,
-    get_size_of_record_batch_to_generate: Pin<
-        Box<dyn Fn(usize) -> usize + Send + 'static>,
-    >,
-) -> datafusion_common::Result<usize> {
+struct RunSortTestWithLimitedMemoryArgs {
+    pool_size: usize,
+    task_ctx: TaskContext,
+    number_of_record_batches: usize,
+    get_size_of_record_batch_to_generate:
+        Pin<Box<dyn Fn(usize) -> usize + Send + 'static>>,
+    memory_behavior: MemoryBehavior,
+}
+
+#[derive(Default)]
+enum MemoryBehavior {
+    #[default]
+    AsIs,
+    TakeAllMemoryAtTheBeginning,
+    TakeAllMemoryAndReleaseEveryNthBatch(usize),
+}
+
+async fn run_sort_test_with_limited_memory(
+    args: RunSortTestWithLimitedMemoryArgs,
+) -> Result<usize> {
+    let RunSortTestWithLimitedMemoryArgs {
+        pool_size,
+        task_ctx,
+        number_of_record_batches,
+        get_size_of_record_batch_to_generate,
+        memory_behavior,
+    } = args;
     let scan_schema = Arc::new(Schema::new(vec![
         Field::new("col_0", DataType::UInt64, true),
         Field::new("col_1", DataType::Utf8, true),
     ]));
 
     let record_batch_size = task_ctx.session_config().batch_size() as u64;
 
     let schema = Arc::clone(&scan_schema);
     let plan: Arc<dyn ExecutionPlan> =
         Arc::new(StreamExec::new(Box::pin(RecordBatchStreamAdapter::new(
             Arc::clone(&schema),
             futures::stream::iter((0..number_of_record_batches as u64).map(
                 move |index| {
                     let mut record_batch_memory_size =
                         get_size_of_record_batch_to_generate(index as usize);
                     record_batch_memory_size = record_batch_memory_size
                         .saturating_sub(size_of::<u64>() * record_batch_size as usize);
 
                     let string_item_size =
                         record_batch_memory_size / record_batch_size as usize;
                     let string_array = Arc::new(StringArray::from_iter_values(
                         (0..record_batch_size).map(|_| "a".repeat(string_item_size)),
                     ));
 
                     RecordBatch::try_new(
                         Arc::clone(&schema),
                         vec![
                             Arc::new(UInt64Array::from_iter_values(
                                 (index * record_batch_size)
                                     ..(index * record_batch_size) + record_batch_size,
                             )),
                             string_array,
                         ],
                     )
                     .map_err(|err| err.into())
                 },
             )),
         ))));
-
     let sort_exec = Arc::new(SortExec::new(
         LexOrdering::new(vec![PhysicalSortExpr {
             expr: col("col_0", &scan_schema).unwrap(),
             options: SortOptions {
                 descending: false,
                 nulls_first: true,
             },
         }]),
         plan,
     ));
 
     let task_ctx = Arc::new(task_ctx);
 
-    let mut result = sort_exec.execute(0, task_ctx)?;
+    let mut result = sort_exec.execute(0, Arc::clone(&task_ctx))?;
 
     let mut number_of_rows = 0;
 
+    let memory_pool = task_ctx.memory_pool();
+    let memory_consumer = MemoryConsumer::new("mock_memory_consumer");
+    let mut memory_reservation = memory_consumer.register(memory_pool);
+
+    let mut index = 0;
+    let mut memory_took = false;
+
     while let Some(batch) = result.next().await {
+        match memory_behavior {
+            MemoryBehavior::AsIs => {
+                // Do nothing
+            }
+            MemoryBehavior::TakeAllMemoryAtTheBeginning => {
+                if !memory_took {
+                    memory_took = true;
+                    grow_memory_as_much_as_possible(10, &mut memory_reservation)?;
+                }
+            }
+            MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(n) => {
+                if !memory_took {
+                    memory_took = true;
+                    grow_memory_as_much_as_possible(pool_size, &mut memory_reservation)?;
+                } else if index % n == 0 {
+                    // release memory
+                    memory_reservation.free();
+                }
+            }
+        }
+
         let batch = batch?;
         number_of_rows += batch.num_rows();
+
+        index += 1;
     }
 
     assert_eq!(
         number_of_rows,
         number_of_record_batches * record_batch_size as usize
     );
 
     let spill_count = sort_exec.metrics().unwrap().spill_count().unwrap();
-
     assert!(
         spill_count > 0,
         "Expected spill, but did not: {number_of_record_batches:?}"
     );
 
     Ok(spill_count)
 }
+
+fn grow_memory_as_much_as_possible(
+    memory_step: usize,
+    memory_reservation: &mut MemoryReservation,
+) -> Result<bool> {
+    let mut was_able_to_grow = false;
+    while memory_reservation.try_grow(memory_step).is_ok() {
+        was_able_to_grow = true;
+    }
+
+    Ok(was_able_to_grow)
+}