
Commit 2e87dbb

Address review comments
1 parent b652cee commit 2e87dbb

File tree: 2 files changed (+208, -22)

datafusion/core/tests/fuzz_cases/sort_fuzz.rs (+197, -17)

@@ -20,7 +20,7 @@
 use std::sync::Arc;

 use arrow::{
-    array::{ArrayRef, Int32Array},
+    array::{as_string_array, ArrayRef, Int32Array, StringArray},
     compute::SortOptions,
     record_batch::RecordBatch,
 };
@@ -29,6 +29,7 @@ use datafusion::physical_plan::expressions::PhysicalSortExpr;
 use datafusion::physical_plan::sorts::sort::SortExec;
 use datafusion::physical_plan::{collect, ExecutionPlan};
 use datafusion::prelude::{SessionConfig, SessionContext};
+use datafusion_common::cast::as_int32_array;
 use datafusion_execution::memory_pool::GreedyMemoryPool;
 use datafusion_physical_expr::expressions::col;
 use datafusion_physical_expr_common::sort_expr::LexOrdering;
@@ -42,12 +43,17 @@ const KB: usize = 1 << 10;
 #[cfg_attr(tarpaulin, ignore)]
 async fn test_sort_10k_mem() {
     for (batch_size, should_spill) in [(5, false), (20000, true), (500000, true)] {
-        SortTest::new()
+        let (input, collected) = SortTest::new()
             .with_int32_batches(batch_size)
+            .with_sort_columns(vec!["x"])
             .with_pool_size(10 * KB)
             .with_should_spill(should_spill)
             .run()
             .await;
+
+        let expected = partitions_to_sorted_vec(&input);
+        let actual = batches_to_vec(&collected);
+        assert_eq!(expected, actual, "failure in @ batch_size {batch_size:?}");
     }
 }

@@ -57,29 +63,123 @@ async fn test_sort_100k_mem() {
     for (batch_size, should_spill) in
         [(5, false), (10000, false), (20000, true), (1000000, true)]
     {
-        SortTest::new()
+        let (input, collected) = SortTest::new()
             .with_int32_batches(batch_size)
+            .with_sort_columns(vec!["x"])
+            .with_pool_size(100 * KB)
+            .with_should_spill(should_spill)
+            .run()
+            .await;
+
+        let expected = partitions_to_sorted_vec(&input);
+        let actual = batches_to_vec(&collected);
+        assert_eq!(expected, actual, "failure in @ batch_size {batch_size:?}");
+    }
+}
+
+#[tokio::test]
+#[cfg_attr(tarpaulin, ignore)]
+async fn test_sort_strings_100k_mem() {
+    for (batch_size, should_spill) in
+        [(5, false), (1000, false), (10000, true), (20000, true)]
+    {
+        let (input, collected) = SortTest::new()
+            .with_utf8_batches(batch_size)
+            .with_sort_columns(vec!["x"])
             .with_pool_size(100 * KB)
             .with_should_spill(should_spill)
             .run()
             .await;
+
+        let mut input = input
+            .iter()
+            .flat_map(|p| p.iter())
+            .map(|b| {
+                let array = b.column(0);
+                as_string_array(array)
+                    .iter()
+                    .map(|s| s.unwrap().to_string())
+            })
+            .flatten()
+            .collect::<Vec<String>>();
+        input.sort_unstable();
+        let actual = collected
+            .iter()
+            .map(|b| {
+                let array = b.column(0);
+                as_string_array(array)
+                    .iter()
+                    .map(|s| s.unwrap().to_string())
+            })
+            .flatten()
+            .collect::<Vec<String>>();
+        assert_eq!(input, actual);
+    }
+}
+
+#[tokio::test]
+#[cfg_attr(tarpaulin, ignore)]
+async fn test_sort_multi_columns_100k_mem() {
+    for (batch_size, should_spill) in
+        [(5, false), (1000, false), (10000, true), (20000, true)]
+    {
+        let (input, collected) = SortTest::new()
+            .with_int32_utf8_batches(batch_size)
+            .with_sort_columns(vec!["x", "y"])
+            .with_pool_size(100 * KB)
+            .with_should_spill(should_spill)
+            .run()
+            .await;
+
+        fn record_batch_to_vec(b: &RecordBatch) -> Vec<(i32, String)> {
+            let mut rows: Vec<_> = Vec::new();
+            let i32_array = as_int32_array(b.column(0)).unwrap();
+            let string_array = as_string_array(b.column(1));
+            for i in 0..b.num_rows() {
+                let str = string_array.value(i).to_string();
+                let i32 = i32_array.value(i);
+                rows.push((i32, str));
+            }
+            rows
+        }
+        let mut input = input
+            .iter()
+            .flat_map(|p| p.iter())
+            .map(record_batch_to_vec)
+            .flatten()
+            .collect::<Vec<(i32, String)>>();
+        input.sort_unstable();
+        let actual = collected
+            .iter()
+            .map(record_batch_to_vec)
+            .flatten()
+            .collect::<Vec<(i32, String)>>();
+        assert_eq!(input, actual);
     }
 }

 #[tokio::test]
 async fn test_sort_unlimited_mem() {
     for (batch_size, should_spill) in [(5, false), (20000, false), (1000000, false)] {
-        SortTest::new()
+        let (input, collected) = SortTest::new()
             .with_int32_batches(batch_size)
+            .with_sort_columns(vec!["x"])
             .with_pool_size(usize::MAX)
             .with_should_spill(should_spill)
             .run()
             .await;
+
+        let expected = partitions_to_sorted_vec(&input);
+        let actual = batches_to_vec(&collected);
+        assert_eq!(expected, actual, "failure in @ batch_size {batch_size:?}");
     }
 }
+
 #[derive(Debug, Default)]
 struct SortTest {
     input: Vec<Vec<RecordBatch>>,
+    /// The names of the columns to sort by
+    sort_columns: Vec<String>,
     /// GreedyMemoryPool size, if specified
     pool_size: Option<usize>,
     /// If true, expect the sort to spill
@@ -91,12 +191,29 @@ impl SortTest {
         Default::default()
     }

+    fn with_sort_columns(mut self, sort_columns: Vec<&str>) -> Self {
+        self.sort_columns = sort_columns.iter().map(|s| s.to_string()).collect();
+        self
+    }
+
     /// Create batches of int32 values of rows
     fn with_int32_batches(mut self, rows: usize) -> Self {
        self.input = vec![make_staggered_i32_batches(rows)];
        self
     }

+    /// Create batches of utf8 values of rows
+    fn with_utf8_batches(mut self, rows: usize) -> Self {
+        self.input = vec![make_staggered_utf8_batches(rows)];
+        self
+    }
+
+    /// Create batches of int32 and utf8 values of rows
+    fn with_int32_utf8_batches(mut self, rows: usize) -> Self {
+        self.input = vec![make_staggered_i32_utf8_batches(rows)];
+        self
+    }
+
     /// specify that this test should use a memory pool of the specified size
     fn with_pool_size(mut self, pool_size: usize) -> Self {
         self.pool_size = Some(pool_size);
@@ -110,7 +227,7 @@ impl SortTest {

     /// Sort the input using SortExec and ensure the results are
     /// correct according to `Vec::sort` both with and without spilling
-    async fn run(&self) {
+    async fn run(&self) -> (Vec<Vec<RecordBatch>>, Vec<RecordBatch>) {
         let input = self.input.clone();
         let first_batch = input
             .iter()
@@ -119,16 +236,21 @@ impl SortTest {
             .expect("at least one batch");
         let schema = first_batch.schema();

-        let sort = LexOrdering::new(vec![PhysicalSortExpr {
-            expr: col("x", &schema).unwrap(),
-            options: SortOptions {
-                descending: false,
-                nulls_first: true,
-            },
-        }]);
+        let sort_ordering = LexOrdering::new(
+            self.sort_columns
+                .iter()
+                .map(|c| PhysicalSortExpr {
+                    expr: col(c, &schema).unwrap(),
+                    options: SortOptions {
+                        descending: false,
+                        nulls_first: true,
+                    },
+                })
+                .collect(),
+        );

         let exec = MemorySourceConfig::try_new_exec(&input, schema, None).unwrap();
-        let sort = Arc::new(SortExec::new(sort, exec));
+        let sort = Arc::new(SortExec::new(sort_ordering, exec));

         let session_config = SessionConfig::new();
         let session_ctx = if let Some(pool_size) = self.pool_size {
@@ -153,9 +275,6 @@ impl SortTest {
         let task_ctx = session_ctx.task_ctx();
         let collected = collect(sort.clone(), task_ctx).await.unwrap();

-        let expected = partitions_to_sorted_vec(&input);
-        let actual = batches_to_vec(&collected);
-
         if self.should_spill {
             assert_ne!(
                 sort.metrics().unwrap().spill_count().unwrap(),
@@ -175,7 +294,8 @@ impl SortTest {
             0,
             "The sort should have returned all memory used back to the memory pool"
         );
-        assert_eq!(expected, actual, "failure in @ pool_size {self:?}");
+
+        (input, collected)
     }
 }

@@ -203,3 +323,63 @@ fn make_staggered_i32_batches(len: usize) -> Vec<RecordBatch> {
     }
     batches
 }
+
+/// Return randomly sized record batches in a field named 'x' of type `Utf8`
+/// with randomized content
+fn make_staggered_utf8_batches(len: usize) -> Vec<RecordBatch> {
+    let mut rng = rand::thread_rng();
+    let max_batch = 1024;
+
+    let mut batches = vec![];
+    let mut remaining = len;
+    while remaining != 0 {
+        let to_read = rng.gen_range(0..=remaining.min(max_batch));
+        remaining -= to_read;
+
+        batches.push(
+            RecordBatch::try_from_iter(vec![(
+                "x",
+                Arc::new(StringArray::from_iter_values(
+                    (0..to_read).map(|_| format!("test_string_{}", rng.gen::<u32>())),
+                )) as ArrayRef,
+            )])
+            .unwrap(),
+        )
+    }
+    batches
+}
+
+/// Return randomly sized record batches in a field named 'x' of type `Int32`
+/// with randomized i32 content and a field named 'y' of type `Utf8`
+/// with randomized content
+fn make_staggered_i32_utf8_batches(len: usize) -> Vec<RecordBatch> {
+    let mut rng = rand::thread_rng();
+    let max_batch = 1024;
+
+    let mut batches = vec![];
+    let mut remaining = len;
+    while remaining != 0 {
+        let to_read = rng.gen_range(0..=remaining.min(max_batch));
+        remaining -= to_read;
+
+        batches.push(
+            RecordBatch::try_from_iter(vec![
+                (
+                    "x",
+                    Arc::new(Int32Array::from_iter_values(
+                        (0..to_read).map(|_| rng.gen()),
+                    )) as ArrayRef,
+                ),
+                (
+                    "y",
+                    Arc::new(StringArray::from_iter_values(
+                        (0..to_read).map(|_| format!("test_string_{}", rng.gen::<u32>())),
+                    )) as ArrayRef,
+                ),
+            ])
+            .unwrap(),
+        )
+    }
+
+    batches
+}
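
Note: the int32 tests above verify results with `partitions_to_sorted_vec` and `batches_to_vec`, which this diff uses but does not define (they come from the shared test-utility code the fuzz tests already import). A minimal sketch of what such helpers plausibly do, assuming the first column of every batch is an `Int32Array` — an illustration, not the actual helper code:

```rust
use arrow::array::Int32Array;
use arrow::record_batch::RecordBatch;

/// Flatten the first column of each batch into plain values (nulls as None).
fn batches_to_vec(batches: &[RecordBatch]) -> Vec<Option<i32>> {
    batches
        .iter()
        .flat_map(|batch| {
            batch
                .column(0)
                .as_any()
                .downcast_ref::<Int32Array>()
                .expect("Int32 column")
                .iter()
                .collect::<Vec<_>>()
        })
        .collect()
}

/// Flatten every partition, then sort, to build the expected output of SortExec.
fn partitions_to_sorted_vec(partitions: &[Vec<RecordBatch>]) -> Vec<Option<i32>> {
    let mut values: Vec<_> = partitions.iter().flat_map(|p| batches_to_vec(p)).collect();
    values.sort_unstable();
    values
}
```

The string and multi-column tests added in this commit cannot use these int32-only helpers, which is why they build their own `Vec<String>` and `Vec<(i32, String)>` comparisons inline.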

datafusion/physical-plan/src/sorts/sort.rs (+11, -5)

@@ -225,6 +225,8 @@ struct ExternalSorter {
     // ========================================================================
     /// Potentially unsorted in memory buffer
     in_mem_batches: Vec<RecordBatch>,
+    /// if `Self::in_mem_batches` are sorted
+    in_mem_batches_sorted: bool,

     /// If data has previously been spilled, the locations of the
     /// spill files (in Arrow IPC format)
@@ -277,6 +279,7 @@ impl ExternalSorter {
         Self {
             schema,
             in_mem_batches: vec![],
+            in_mem_batches_sorted: false,
             spills: vec![],
             expr: expr.into(),
             metrics,
@@ -309,6 +312,7 @@ impl ExternalSorter {
         }

         self.in_mem_batches.push(input);
+        self.in_mem_batches_sorted = false;
         Ok(())
     }

@@ -423,7 +427,8 @@ impl ExternalSorter {
     async fn sort_or_spill_in_mem_batches(&mut self) -> Result<()> {
         // Release the memory reserved for merge back to the pool so
         // there is some left when `in_mem_sort_stream` requests an
-        // allocation.
+        // allocation. At the end of this function, memory will be
+        // reserved again for the next spill.
         self.merge_reservation.free();

         let before = self.reservation.size();
@@ -458,6 +463,7 @@ impl ExternalSorter {
                     self.spills.push(spill_file);
                 } else {
                     self.in_mem_batches.push(batch);
+                    self.in_mem_batches_sorted = true;
                 }
             }
             Some(writer) => {
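
Condensing the `in_mem_batches_sorted` changes in this file into one place: the flag is cleared whenever a new input batch is buffered, and set only when `sort_or_spill_in_mem_batches` keeps a freshly sorted batch in memory instead of spilling it. The sketch below restates that lifecycle with trimmed, hypothetical method names (`buffer_input`, `retain_sorted_batch`); it is an illustration of the state changes shown above, not code from this commit:

```rust
use arrow::record_batch::RecordBatch;

struct SorterSketch {
    in_mem_batches: Vec<RecordBatch>,
    /// if `Self::in_mem_batches` are sorted
    in_mem_batches_sorted: bool,
}

impl SorterSketch {
    /// New input may arrive in any order, so buffering it clears the flag.
    fn buffer_input(&mut self, input: RecordBatch) {
        self.in_mem_batches.push(input);
        self.in_mem_batches_sorted = false;
    }

    /// When the in-memory data was just sorted and kept (not spilled),
    /// the buffer is known to be sorted.
    fn retain_sorted_batch(&mut self, batch: RecordBatch) {
        self.in_mem_batches.push(batch);
        self.in_mem_batches_sorted = true;
    }
}
```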
@@ -662,10 +668,10 @@ impl ExternalSorter {
     /// Estimate how much memory is needed to sort a `RecordBatch`.
     ///
     /// This is used to pre-reserve memory for the sort/merge. The sort/merge process involves
-    /// creating sorted copies of sorted columns in record batches, the sorted copies could be
-    /// in either row format or array format. Please refer to cursor.rs and stream.rs for more
-    /// details. No matter what format the sorted copies are, they will use more memory than
-    /// the original record batch.
+    /// creating sorted copies of sorted columns in record batches for speeding up comparison
+    /// in sorting and merging. The sorted copies are in either row format or array format.
+    /// Please refer to cursor.rs and stream.rs for more details. No matter what format the
+    /// sorted copies are, they will use more memory than the original record batch.
     fn get_reserved_byte_for_record_batch(batch: &RecordBatch) -> usize {
         // 2x may not be enough for some cases, but it's a good start.
         // If 2x is not enough, user can set a larger value for `sort_spill_reservation_bytes`
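
The hunk above is cut off before the body of `get_reserved_byte_for_record_batch`, but the surviving comments describe the heuristic: pre-reserve roughly twice the batch's in-memory size, and let users raise `sort_spill_reservation_bytes` when that is too small. A minimal sketch of such a 2x estimate, assuming Arrow's `RecordBatch::get_array_memory_size` as the base measurement — an illustration, not necessarily the function's exact body:

```rust
use arrow::record_batch::RecordBatch;

/// Rough pre-reservation for sorting/merging one batch: sorted copies of the
/// sort columns (row format or array format) need more memory than the
/// original arrays, so reserve a multiple of the batch's array memory.
fn estimated_sort_reservation_bytes(batch: &RecordBatch) -> usize {
    batch.get_array_memory_size() * 2
}
```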
