diff --git a/Cargo.lock b/Cargo.lock
index 2b3eeecf5d9b..ee6dda88a0e2 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3966,9 +3966,9 @@ dependencies = [
 
 [[package]]
 name = "libz-rs-sys"
-version = "0.4.2"
+version = "0.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "902bc563b5d65ad9bba616b490842ef0651066a1a1dc3ce1087113ffcb873c8d"
+checksum = "6489ca9bd760fe9642d7644e827b0c9add07df89857b0416ee15c1cc1a3b8c5a"
 dependencies = [
  "zlib-rs",
 ]
@@ -7512,9 +7512,9 @@ dependencies = [
 
 [[package]]
 name = "zlib-rs"
-version = "0.4.2"
+version = "0.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8b20717f0917c908dc63de2e44e97f1e6b126ca58d0e391cee86d504eb8fbd05"
+checksum = "868b928d7949e09af2f6086dfc1e01936064cc7a819253bce650d4e2a2d63ba8"
 
 [[package]]
 name = "zstd"
diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
index ff3b66986ced..eaa7c624c5ee 100644
--- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs
@@ -15,15 +15,15 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use std::sync::Arc;
-
 use crate::fuzz_cases::aggregation_fuzzer::{
     AggregationFuzzerBuilder, DatasetGeneratorConfig, QueryBuilder,
 };
+use std::pin::Pin;
+use std::sync::Arc;
 
 use arrow::array::{
     types::Int64Type, Array, ArrayRef, AsArray, Int32Array, Int64Array, RecordBatch,
-    StringArray,
+    StringArray, UInt64Array,
 };
 use arrow::compute::{concat_batches, SortOptions};
 use arrow::datatypes::DataType;
@@ -47,17 +47,23 @@ use datafusion_physical_expr::expressions::{col, lit, Column};
 use datafusion_physical_expr::PhysicalSortExpr;
 use datafusion_physical_expr_common::sort_expr::LexOrdering;
 use datafusion_physical_plan::InputOrderMode;
+use futures::StreamExt;
 use test_utils::{add_empty_batches, StringBatchGenerator};
 
-use datafusion_execution::memory_pool::FairSpillPool;
+use super::record_batch_generator::get_supported_types_columns;
+use crate::fuzz_cases::stream_exec::StreamExec;
+use datafusion_execution::memory_pool::units::{KB, MB};
+use datafusion_execution::memory_pool::{
+    FairSpillPool, MemoryConsumer, MemoryReservation,
+};
 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
 use datafusion_execution::TaskContext;
+use datafusion_functions_aggregate::array_agg::array_agg_udaf;
 use datafusion_physical_plan::metrics::MetricValue;
+use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
 use rand::rngs::StdRng;
 use rand::{random, thread_rng, Rng, SeedableRng};
 
-use super::record_batch_generator::get_supported_types_columns;
-
 // ========================================================================
 // The new aggregation fuzz tests based on [`AggregationFuzzer`]
 // ========================================================================
@@ -623,7 +629,10 @@ fn extract_result_counts(results: Vec<RecordBatch>) -> HashMap<Option<String>, i
     output
 }
 
-fn assert_spill_count_metric(expect_spill: bool, single_aggregate: Arc<AggregateExec>) {
+fn assert_spill_count_metric(
+    expect_spill: bool,
+    single_aggregate: Arc<AggregateExec>,
+) -> usize {
     if let Some(metrics_set) = single_aggregate.metrics() {
         let mut spill_count = 0;
 
@@ -640,6 +649,8 @@ fn assert_spill_count_metric(expect_spill: bool, single_aggregate: Arc<Aggregat
         } else if !expect_spill && spill_count > 0 {
             panic!("Expected no spill but found SpillCount metric with value greater than 0.");
         }
+
+        spill_count
     } else {
         panic!("No metrics returned from the operator; cannot verify spilling.");
     }
@@ -753,3 +764,334 @@ async fn test_single_mode_aggregate_with_spill() -> Result<()> {
     Ok(())
 }
+
+#[tokio::test]
+async fn test_high_cardinality_with_limited_memory() -> Result<()> {
+    let record_batch_size = 8192;
+    let pool_size = 2 * MB as usize;
+    let task_ctx = {
+        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
+        TaskContext::default()
+            .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
+            .with_runtime(Arc::new(
+                RuntimeEnvBuilder::new()
+                    .with_memory_pool(memory_pool)
+                    .build()?,
+            ))
+    };
+
+    let record_batch_size = pool_size / 16;
+
+    // Basic test with a lot of groups that cannot all fit in memory, where even one
+    // record batch from each spill file takes too much memory
+    let spill_count = run_test_high_cardinality(RunTestHighCardinalityArgs {
+        pool_size,
+        task_ctx,
+        number_of_record_batches: 100,
+        get_size_of_record_batch_to_generate: Box::pin(move |_| record_batch_size),
+        memory_behavior: Default::default(),
+    })
+    .await?;
+
+    let total_spill_files_size = spill_count * record_batch_size;
+    assert!(
+        total_spill_files_size > pool_size,
+        "Total spill files size {} should be greater than pool size {}",
+        total_spill_files_size,
+        pool_size
+    );
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch(
+) -> Result<()> {
+    let record_batch_size = 8192;
+    let pool_size = 2 * MB as usize;
+    let task_ctx = {
+        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
+        TaskContext::default()
+            .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
+            .with_runtime(Arc::new(
+                RuntimeEnvBuilder::new()
+                    .with_memory_pool(memory_pool)
+                    .build()?,
+            ))
+    };
+
+    run_test_high_cardinality(RunTestHighCardinalityArgs {
+        pool_size,
+        task_ctx,
+        number_of_record_batches: 100,
+        get_size_of_record_batch_to_generate: Box::pin(move |i| {
+            if i % 25 == 1 {
+                pool_size / 4
+            } else {
+                (16 * KB) as usize
+            }
+        }),
+        memory_behavior: Default::default(),
+    })
+    .await?;
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch_and_changing_memory_reservation(
+) -> Result<()> {
+    let record_batch_size = 8192;
+    let pool_size = 2 * MB as usize;
+    let task_ctx = {
+        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
+        TaskContext::default()
+            .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
+            .with_runtime(Arc::new(
+                RuntimeEnvBuilder::new()
+                    .with_memory_pool(memory_pool)
+                    .build()?,
+            ))
+    };
+
+    run_test_high_cardinality(RunTestHighCardinalityArgs {
+        pool_size,
+        task_ctx,
+        number_of_record_batches: 100,
+        get_size_of_record_batch_to_generate: Box::pin(move |i| {
+            if i % 25 == 1 {
+                pool_size / 4
+            } else {
+                (16 * KB) as usize
+            }
+        }),
+        memory_behavior: MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(10),
+    })
+    .await?;
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_high_cardinality_with_limited_memory_and_different_sizes_of_record_batch_and_take_all_memory(
+) -> Result<()> {
+    let record_batch_size = 8192;
+    let pool_size = 2 * MB as usize;
+    let task_ctx = {
+        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
+        TaskContext::default()
+            .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
+            .with_runtime(Arc::new(
+                RuntimeEnvBuilder::new()
+                    .with_memory_pool(memory_pool)
+                    .build()?,
+            ))
+    };
+
+    run_test_high_cardinality(RunTestHighCardinalityArgs {
+        pool_size,
+        task_ctx,
+        number_of_record_batches: 100,
+        get_size_of_record_batch_to_generate: Box::pin(move |i| {
+            if i % 25 == 1 {
+                pool_size / 4
+            } else {
+                (16 * KB) as usize
+            }
+        }),
+        memory_behavior: MemoryBehavior::TakeAllMemoryAtTheBeginning,
+    })
+    .await?;
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_high_cardinality_with_limited_memory_and_large_record_batch() -> Result<()>
+{
+    let record_batch_size = 8192;
+    let pool_size = 2 * MB as usize;
+    let task_ctx = {
+        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
+        TaskContext::default()
+            .with_session_config(SessionConfig::new().with_batch_size(record_batch_size))
+            .with_runtime(Arc::new(
+                RuntimeEnvBuilder::new()
+                    .with_memory_pool(memory_pool)
+                    .build()?,
+            ))
+    };
+
+    // Test that the merge degree of the multi-level merge sort cannot be a fixed size
+    // when there is not enough memory
+    run_test_high_cardinality(RunTestHighCardinalityArgs {
+        pool_size,
+        task_ctx,
+        number_of_record_batches: 100,
+        get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 4),
+        memory_behavior: Default::default(),
+    })
+    .await?;
+
+    Ok(())
+}
+
+struct RunTestHighCardinalityArgs {
+    pool_size: usize,
+    task_ctx: TaskContext,
+    number_of_record_batches: usize,
+    get_size_of_record_batch_to_generate:
+        Pin<Box<dyn Fn(usize) -> usize + Send + 'static>>,
+    memory_behavior: MemoryBehavior,
+}
+
+#[derive(Default)]
+enum MemoryBehavior {
+    #[default]
+    AsIs,
+    TakeAllMemoryAtTheBeginning,
+    TakeAllMemoryAndReleaseEveryNthBatch(usize),
+}
+
+async fn run_test_high_cardinality(args: RunTestHighCardinalityArgs) -> Result<usize> {
+    let RunTestHighCardinalityArgs {
+        pool_size,
+        task_ctx,
+        number_of_record_batches,
+        get_size_of_record_batch_to_generate,
+        memory_behavior,
+    } = args;
+    let scan_schema = Arc::new(Schema::new(vec![
+        Field::new("col_0", DataType::UInt64, true),
+        Field::new("col_1", DataType::Utf8, true),
+    ]));
+
+    let group_by = PhysicalGroupBy::new_single(vec![(
+        Arc::new(Column::new("col_0", 0)),
+        "col_0".to_string(),
+    )]);
+
+    let aggregate_expressions = vec![Arc::new(
+        AggregateExprBuilder::new(
+            array_agg_udaf(),
+            vec![col("col_1", &scan_schema).unwrap()],
+        )
+        .schema(Arc::clone(&scan_schema))
+        .alias("array_agg(col_1)")
+        .build()?,
+    )];
+
+    let record_batch_size = task_ctx.session_config().batch_size() as u64;
+
+    let schema = Arc::clone(&scan_schema);
+    let plan: Arc<dyn ExecutionPlan> =
+        Arc::new(StreamExec::new(Box::pin(RecordBatchStreamAdapter::new(
+            Arc::clone(&schema),
+            futures::stream::iter((0..number_of_record_batches as u64).map(
+                move |index| {
+                    let mut record_batch_memory_size =
+                        get_size_of_record_batch_to_generate(index as usize);
+                    record_batch_memory_size = record_batch_memory_size
+                        .saturating_sub(size_of::<u64>() * record_batch_size as usize);
+
+                    let string_item_size =
+                        record_batch_memory_size / record_batch_size as usize;
+                    let string_array = Arc::new(StringArray::from_iter_values(
+                        (0..record_batch_size).map(|_| "a".repeat(string_item_size)),
+                    ));
+
+                    RecordBatch::try_new(
+                        Arc::clone(&schema),
+                        vec![
+                            // Grouping key
+                            Arc::new(UInt64Array::from_iter_values(
+                                (index * record_batch_size)
+                                    ..(index * record_batch_size) + record_batch_size,
+                            )),
+                            // Grouping value
+                            string_array,
+                        ],
+                    )
+                    .map_err(|err| err.into())
+                },
+            )),
+        ))));
+
+    let aggregate_exec = Arc::new(AggregateExec::try_new(
+        AggregateMode::Partial,
+        group_by.clone(),
+        aggregate_expressions.clone(),
+        vec![None; aggregate_expressions.len()],
+        plan,
+        Arc::clone(&scan_schema),
+    )?);
+    let aggregate_final = Arc::new(AggregateExec::try_new(
+        AggregateMode::Final,
+        group_by,
+        aggregate_expressions.clone(),
+        vec![None; aggregate_expressions.len()],
+        aggregate_exec,
+        Arc::clone(&scan_schema),
+    )?);
+
+    let task_ctx = Arc::new(task_ctx);
+
+    let mut result = aggregate_final.execute(0, Arc::clone(&task_ctx))?;
+
+    let mut number_of_groups = 0;
+
+    let memory_pool = task_ctx.memory_pool();
+    let memory_consumer = MemoryConsumer::new("mock_memory_consumer");
+    let mut memory_reservation = memory_consumer.register(memory_pool);
+
+    let mut index = 0;
+    let mut memory_took = false;
+
+    while let Some(batch) = result.next().await {
+        match memory_behavior {
+            MemoryBehavior::AsIs => {
+                // Do nothing
+            }
+            MemoryBehavior::TakeAllMemoryAtTheBeginning => {
+                if !memory_took {
+                    memory_took = true;
+                    grow_memory_as_much_as_possible(10, &mut memory_reservation)?;
+                }
+            }
+            MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(n) => {
+                if !memory_took {
+                    memory_took = true;
+                    grow_memory_as_much_as_possible(pool_size, &mut memory_reservation)?;
+                } else if index % n == 0 {
+                    // release memory
+                    memory_reservation.free();
+                }
+            }
+        }
+
+        let batch = batch?;
+        number_of_groups += batch.num_rows();
+
+        index += 1;
+    }
+
+    assert_eq!(
+        number_of_groups,
+        number_of_record_batches * record_batch_size as usize
+    );
+
+    let spill_count = assert_spill_count_metric(true, aggregate_final);
+
+    Ok(spill_count)
+}
+
+fn grow_memory_as_much_as_possible(
+    memory_step: usize,
+    memory_reservation: &mut MemoryReservation,
+) -> Result<bool> {
+    let mut was_able_to_grow = false;
+    while memory_reservation.try_grow(memory_step).is_ok() {
+        was_able_to_grow = true;
+    }
+
+    Ok(was_able_to_grow)
+}
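Reviewer note (not part of the patch): the `MemoryBehavior` scenarios above lean on `MemoryReservation::try_grow` being all-or-nothing — it either reserves the requested bytes or returns `Err` and leaves the reservation unchanged — so `grow_memory_as_much_as_possible` loops until the first failure and leaves less than one `memory_step` of headroom in the pool. A minimal sketch of that contract; the use of `GreedyMemoryPool` (simple fixed-limit semantics) and the byte counts are illustrative only:

use std::sync::Arc;

use datafusion_execution::memory_pool::{GreedyMemoryPool, MemoryConsumer, MemoryPool};

fn main() {
    // A 100-byte pool, grown in 30-byte steps until the pool refuses.
    let pool: Arc<dyn MemoryPool> = Arc::new(GreedyMemoryPool::new(100));
    let mut reservation = MemoryConsumer::new("demo").register(&pool);
    while reservation.try_grow(30).is_ok() {}
    // Three steps fit (90 <= 100); a fourth would need 120 bytes and fails.
    assert_eq!(reservation.size(), 90);
    // Freeing returns everything to the pool, as the fuzz loop does every Nth batch.
    reservation.free();
}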
diff --git a/datafusion/core/tests/fuzz_cases/mod.rs b/datafusion/core/tests/fuzz_cases/mod.rs
index 8ccc2a5bc131..6430859e32b9 100644
--- a/datafusion/core/tests/fuzz_cases/mod.rs
+++ b/datafusion/core/tests/fuzz_cases/mod.rs
@@ -33,3 +33,4 @@ mod window_fuzz;
 
 // Utility modules
 mod record_batch_generator;
+mod stream_exec;
diff --git a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs
index 0b0f0aa2f105..7a5ebf4b439d 100644
--- a/datafusion/core/tests/fuzz_cases/sort_fuzz.rs
+++ b/datafusion/core/tests/fuzz_cases/sort_fuzz.rs
@@ -17,13 +17,17 @@
 //! Fuzz Test for various corner cases sorting RecordBatches exceeds available memory and should spill
 
+use std::pin::Pin;
 use std::sync::Arc;
 
+use arrow::array::UInt64Array;
 use arrow::{
     array::{as_string_array, ArrayRef, Int32Array, StringArray},
     compute::SortOptions,
     record_batch::RecordBatch,
 };
+use arrow_schema::{DataType, Field, Schema};
+use datafusion::common::Result;
 use datafusion::datasource::memory::MemorySourceConfig;
 use datafusion::execution::runtime_env::RuntimeEnvBuilder;
 use datafusion::physical_plan::expressions::PhysicalSortExpr;
@@ -31,10 +35,17 @@ use datafusion::physical_plan::sorts::sort::SortExec;
 use datafusion::physical_plan::{collect, ExecutionPlan};
 use datafusion::prelude::{SessionConfig, SessionContext};
 use datafusion_common::cast::as_int32_array;
-use datafusion_execution::memory_pool::GreedyMemoryPool;
+use datafusion_execution::memory_pool::{
+    FairSpillPool, GreedyMemoryPool, MemoryConsumer, MemoryReservation,
+};
 use datafusion_physical_expr::expressions::col;
 use datafusion_physical_expr_common::sort_expr::LexOrdering;
+use futures::StreamExt;
 
+use crate::fuzz_cases::stream_exec::StreamExec;
+use datafusion_execution::memory_pool::units::{KB, MB};
+use datafusion_execution::TaskContext;
+use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
 use rand::Rng;
 use test_utils::{batches_to_vec, partitions_to_sorted_vec};
 
@@ -379,3 +390,336 @@ fn make_staggered_i32_utf8_batches(len: usize) -> Vec<RecordBatch> {
     batches
 }
+
+#[tokio::test]
+async fn test_sort_with_limited_memory() -> Result<()> {
+    let record_batch_size = 8192;
+    let pool_size = 2 * MB as usize;
+    let task_ctx = {
+        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
+        TaskContext::default()
+            .with_session_config(
+                SessionConfig::new()
+                    .with_batch_size(record_batch_size)
+                    .with_sort_spill_reservation_bytes(1),
+            )
+            .with_runtime(Arc::new(
+                RuntimeEnvBuilder::new()
+                    .with_memory_pool(memory_pool)
+                    .build()?,
+            ))
+    };
+
+    let record_batch_size = pool_size / 16;
+
+    // Basic test with a lot of rows that cannot all fit in memory, where even one
+    // record batch from each spill file takes too much memory
+    let spill_count =
+        run_sort_test_with_limited_memory(RunSortTestWithLimitedMemoryArgs {
+            pool_size,
+            task_ctx,
+            number_of_record_batches: 100,
+            get_size_of_record_batch_to_generate: Box::pin(move |_| record_batch_size),
+            memory_behavior: Default::default(),
+        })
+        .await?;
+
+    let total_spill_files_size = spill_count * record_batch_size;
+    assert!(
+        total_spill_files_size > pool_size,
+        "Total spill files size {} should be greater than pool size {}",
+        total_spill_files_size,
+        pool_size
+    );
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch() -> Result<()>
+{
+    let record_batch_size = 8192;
+    let pool_size = 2 * MB as usize;
+    let task_ctx = {
+        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
+        TaskContext::default()
+            .with_session_config(
+                SessionConfig::new()
+                    .with_batch_size(record_batch_size)
+                    .with_sort_spill_reservation_bytes(1),
+            )
+            .with_runtime(Arc::new(
+                RuntimeEnvBuilder::new()
+                    .with_memory_pool(memory_pool)
+                    .build()?,
+            ))
+    };
+
+    run_sort_test_with_limited_memory(RunSortTestWithLimitedMemoryArgs {
+        pool_size,
+        task_ctx,
+        number_of_record_batches: 100,
+        get_size_of_record_batch_to_generate: Box::pin(move |i| {
+            if i % 25 == 1 {
+                pool_size / 4
+            } else {
+                (16 * KB) as usize
+            }
+        }),
+        memory_behavior: Default::default(),
+    })
+    .await?;
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_changing_memory_reservation(
+) -> Result<()> {
+    let record_batch_size = 8192;
+    let pool_size = 2 * MB as usize;
+    let task_ctx = {
+        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
+        TaskContext::default()
+            .with_session_config(
+                SessionConfig::new()
+                    .with_batch_size(record_batch_size)
+                    .with_sort_spill_reservation_bytes(1),
+            )
+            .with_runtime(Arc::new(
+                RuntimeEnvBuilder::new()
+                    .with_memory_pool(memory_pool)
+                    .build()?,
+            ))
+    };
+
+    run_sort_test_with_limited_memory(RunSortTestWithLimitedMemoryArgs {
+        pool_size,
+        task_ctx,
+        number_of_record_batches: 100,
+        get_size_of_record_batch_to_generate: Box::pin(move |i| {
+            if i % 25 == 1 {
+                pool_size / 4
+            } else {
+                (16 * KB) as usize
+            }
+        }),
+        memory_behavior: MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(10),
+    })
+    .await?;
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_sort_with_limited_memory_and_different_sizes_of_record_batch_and_take_all_memory(
+) -> Result<()> {
+    let record_batch_size = 8192;
+    let pool_size = 2 * MB as usize;
+    let task_ctx = {
+        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
+        TaskContext::default()
+            .with_session_config(
+                SessionConfig::new()
+                    .with_batch_size(record_batch_size)
+                    .with_sort_spill_reservation_bytes(1),
+            )
+            .with_runtime(Arc::new(
+                RuntimeEnvBuilder::new()
+                    .with_memory_pool(memory_pool)
+                    .build()?,
+            ))
+    };
+
+    run_sort_test_with_limited_memory(RunSortTestWithLimitedMemoryArgs {
+        pool_size,
+        task_ctx,
+        number_of_record_batches: 100,
+        get_size_of_record_batch_to_generate: Box::pin(move |i| {
+            if i % 25 == 1 {
+                pool_size / 4
+            } else {
+                (16 * KB) as usize
+            }
+        }),
+        memory_behavior: MemoryBehavior::TakeAllMemoryAtTheBeginning,
+    })
+    .await?;
+
+    Ok(())
+}
+
+#[tokio::test]
+async fn test_sort_with_limited_memory_and_large_record_batch() -> Result<()> {
+    let record_batch_size = 8192;
+    let pool_size = 2 * MB as usize;
+    let task_ctx = {
+        let memory_pool = Arc::new(FairSpillPool::new(pool_size));
+        TaskContext::default()
+            .with_session_config(
+                SessionConfig::new()
+                    .with_batch_size(record_batch_size)
+                    .with_sort_spill_reservation_bytes(1),
+            )
+            .with_runtime(Arc::new(
+                RuntimeEnvBuilder::new()
+                    .with_memory_pool(memory_pool)
+                    .build()?,
+            ))
+    };
+
+    // Test that the merge degree of the multi-level merge sort cannot be a fixed size
+    // when there is not enough memory
+    run_sort_test_with_limited_memory(RunSortTestWithLimitedMemoryArgs {
+        pool_size,
+        task_ctx,
+        number_of_record_batches: 100,
+        get_size_of_record_batch_to_generate: Box::pin(move |_| pool_size / 4),
+        memory_behavior: Default::default(),
+    })
+    .await?;
+
+    Ok(())
+}
+
+struct RunSortTestWithLimitedMemoryArgs {
+    pool_size: usize,
+    task_ctx: TaskContext,
+    number_of_record_batches: usize,
+    get_size_of_record_batch_to_generate:
+        Pin<Box<dyn Fn(usize) -> usize + Send + 'static>>,
+    memory_behavior: MemoryBehavior,
+}
+
+#[derive(Default)]
+enum MemoryBehavior {
+    #[default]
+    AsIs,
+    TakeAllMemoryAtTheBeginning,
+    TakeAllMemoryAndReleaseEveryNthBatch(usize),
+}
+
+async fn run_sort_test_with_limited_memory(
+    args: RunSortTestWithLimitedMemoryArgs,
+) -> Result<usize> {
+    let RunSortTestWithLimitedMemoryArgs {
+        pool_size,
+        task_ctx,
+        number_of_record_batches,
+        get_size_of_record_batch_to_generate,
+        memory_behavior,
+    } = args;
+    let scan_schema = Arc::new(Schema::new(vec![
+        Field::new("col_0", DataType::UInt64, true),
+        Field::new("col_1", DataType::Utf8, true),
+    ]));
+
+    let record_batch_size = task_ctx.session_config().batch_size() as u64;
+
+    let schema = Arc::clone(&scan_schema);
+    let plan: Arc<dyn ExecutionPlan> =
+        Arc::new(StreamExec::new(Box::pin(RecordBatchStreamAdapter::new(
+            Arc::clone(&schema),
+            futures::stream::iter((0..number_of_record_batches as u64).map(
+                move |index| {
+                    let mut record_batch_memory_size =
+                        get_size_of_record_batch_to_generate(index as usize);
+                    record_batch_memory_size = record_batch_memory_size
+                        .saturating_sub(size_of::<u64>() * record_batch_size as usize);
+
+                    let string_item_size =
+                        record_batch_memory_size / record_batch_size as usize;
+                    let string_array = Arc::new(StringArray::from_iter_values(
+                        (0..record_batch_size).map(|_| "a".repeat(string_item_size)),
+                    ));
+
+                    RecordBatch::try_new(
+                        Arc::clone(&schema),
+                        vec![
+                            Arc::new(UInt64Array::from_iter_values(
+                                (index * record_batch_size)
+                                    ..(index * record_batch_size) + record_batch_size,
+                            )),
+                            string_array,
+                        ],
+                    )
+                    .map_err(|err| err.into())
+                },
+            )),
+        ))));
+    let sort_exec = Arc::new(SortExec::new(
+        LexOrdering::new(vec![PhysicalSortExpr {
+            expr: col("col_0", &scan_schema).unwrap(),
+            options: SortOptions {
+                descending: false,
+                nulls_first: true,
+            },
+        }]),
+        plan,
+    ));
+
+    let task_ctx = Arc::new(task_ctx);
+
+    let mut result = sort_exec.execute(0, Arc::clone(&task_ctx))?;
+
+    let mut number_of_rows = 0;
+
+    let memory_pool = task_ctx.memory_pool();
+    let memory_consumer = MemoryConsumer::new("mock_memory_consumer");
+    let mut memory_reservation = memory_consumer.register(memory_pool);
+
+    let mut index = 0;
+    let mut memory_took = false;
+
+    while let Some(batch) = result.next().await {
+        match memory_behavior {
+            MemoryBehavior::AsIs => {
+                // Do nothing
+            }
+            MemoryBehavior::TakeAllMemoryAtTheBeginning => {
+                if !memory_took {
+                    memory_took = true;
+                    grow_memory_as_much_as_possible(10, &mut memory_reservation)?;
+                }
+            }
+            MemoryBehavior::TakeAllMemoryAndReleaseEveryNthBatch(n) => {
+                if !memory_took {
+                    memory_took = true;
+                    grow_memory_as_much_as_possible(pool_size, &mut memory_reservation)?;
+                } else if index % n == 0 {
+                    // release memory
+                    memory_reservation.free();
+                }
+            }
+        }
+
+        let batch = batch?;
+        number_of_rows += batch.num_rows();
+
+        index += 1;
+    }
+
+    assert_eq!(
+        number_of_rows,
+        number_of_record_batches * record_batch_size as usize
+    );
+
+    let spill_count = sort_exec.metrics().unwrap().spill_count().unwrap();
+    assert!(
+        spill_count > 0,
+        "Expected spill but none occurred, number_of_record_batches: {number_of_record_batches}"
+    );
+
+    Ok(spill_count)
+}
+
+fn grow_memory_as_much_as_possible(
+    memory_step: usize,
+    memory_reservation: &mut MemoryReservation,
+) -> Result<bool> {
+    let mut was_able_to_grow = false;
+    while memory_reservation.try_grow(memory_step).is_ok() {
+        was_able_to_grow = true;
+    }
+
+    Ok(was_able_to_grow)
+}
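Reviewer note (not part of the patch): both harnesses size the generated batches the same way — the requested byte budget is reduced by the 8 bytes per row taken by the `u64` key column, and the remainder is spread evenly across the rows as string padding. A standalone sketch of that arithmetic; `string_item_size` as a free function is a hypothetical helper name, mirroring the `saturating_sub`/divide logic in the stream closures above:

use std::mem::size_of;

// Bytes of string payload per row for a target batch byte budget.
fn string_item_size(target_batch_bytes: usize, rows: usize) -> usize {
    let key_bytes = size_of::<u64>() * rows; // 8 bytes per u64 key
    target_batch_bytes.saturating_sub(key_bytes) / rows
}

fn main() {
    // A 128 KiB budget with 8192 rows leaves 8 bytes of string per row.
    assert_eq!(string_item_size(128 * 1024, 8192), 8);
}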
diff --git a/datafusion/core/tests/fuzz_cases/stream_exec.rs b/datafusion/core/tests/fuzz_cases/stream_exec.rs
new file mode 100644
index 000000000000..6e71b9988d79
--- /dev/null
+++ b/datafusion/core/tests/fuzz_cases/stream_exec.rs
@@ -0,0 +1,115 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow_schema::SchemaRef;
+use datafusion_common::DataFusionError;
+use datafusion_execution::{SendableRecordBatchStream, TaskContext};
+use datafusion_physical_expr::{EquivalenceProperties, Partitioning};
+use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
+use datafusion_physical_plan::{
+    DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties,
+};
+use std::any::Any;
+use std::fmt::{Debug, Formatter};
+use std::sync::{Arc, Mutex};
+
+/// Execution plan that returns the stream on the first call to `execute`; further calls
+/// to `execute` will return an error
+pub struct StreamExec {
+    /// The results to send back
+    stream: Mutex<Option<SendableRecordBatchStream>>,
+    cache: PlanProperties,
+}
+
+impl Debug for StreamExec {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "StreamExec")
+    }
+}
+
+impl StreamExec {
+    pub fn new(stream: SendableRecordBatchStream) -> Self {
+        let cache = Self::compute_properties(stream.schema());
+        Self {
+            stream: Mutex::new(Some(stream)),
+            cache,
+        }
+    }
+
+    /// This function creates the cache object that stores the plan properties such as schema, equivalence properties, ordering, partitioning, etc.
+    fn compute_properties(schema: SchemaRef) -> PlanProperties {
+        PlanProperties::new(
+            EquivalenceProperties::new(schema),
+            Partitioning::UnknownPartitioning(1),
+            EmissionType::Incremental,
+            Boundedness::Bounded,
+        )
+    }
+}
+
+impl DisplayAs for StreamExec {
+    fn fmt_as(&self, t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result {
+        match t {
+            DisplayFormatType::Default | DisplayFormatType::Verbose => {
+                write!(f, "StreamExec:")
+            }
+            DisplayFormatType::TreeRender => {
+                write!(f, "")
+            }
+        }
+    }
+}
+
+impl ExecutionPlan for StreamExec {
+    fn name(&self) -> &'static str {
+        Self::static_name()
+    }
+
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        &self.cache
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
+        unimplemented!()
+    }
+
+    /// Returns a stream which yields data
+    fn execute(
+        &self,
+        partition: usize,
+        _context: Arc<TaskContext>,
+    ) -> datafusion_common::Result<SendableRecordBatchStream> {
+        assert_eq!(partition, 0);
+
+        let stream = self.stream.lock().unwrap().take();
+
+        stream.ok_or(DataFusionError::Internal(
+            "Stream already consumed".to_string(),
+        ))
+    }
+}
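Reviewer note (not part of the patch): a minimal sketch of how the one-shot `StreamExec` is meant to be driven, assuming the same crates the new tests already import; the empty stream and the `demo` function are illustrative only:

use std::sync::Arc;

use arrow_schema::{DataType, Field, Schema};
use datafusion_execution::TaskContext;
use datafusion_physical_plan::stream::RecordBatchStreamAdapter;
use datafusion_physical_plan::ExecutionPlan;

use crate::fuzz_cases::stream_exec::StreamExec;

fn demo() -> datafusion_common::Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("c", DataType::UInt64, true)]));
    let exec = StreamExec::new(Box::pin(RecordBatchStreamAdapter::new(
        schema,
        futures::stream::empty(),
    )));

    let ctx = Arc::new(TaskContext::default());
    // The first call hands out the wrapped stream...
    let _stream = exec.execute(0, Arc::clone(&ctx))?;
    // ...and any later call fails, because the stream was already taken.
    assert!(exec.execute(0, ctx).is_err());
    Ok(())
}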