Skip to content

Commit 3c2b542

Browse files
authored
Add string aggregate grouping fuzz test (#9190)
1 parent d7dcb12 commit 3c2b542

File tree

7 files changed

+381
-118
lines changed

7 files changed

+381
-118
lines changed

datafusion/core/src/datasource/memory.rs

Lines changed: 53 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,10 @@ use arrow::datatypes::SchemaRef;
2929
use arrow::record_batch::RecordBatch;
3030
use async_trait::async_trait;
3131
use datafusion_common::{
32-
not_impl_err, plan_err, Constraints, DataFusionError, SchemaExt,
32+
not_impl_err, plan_err, Constraints, DFSchema, DataFusionError, SchemaExt,
3333
};
3434
use datafusion_execution::TaskContext;
35+
use parking_lot::Mutex;
3536
use tokio::sync::RwLock;
3637
use tokio::task::JoinSet;
3738

@@ -44,6 +45,7 @@ use crate::physical_plan::memory::MemoryExec;
4445
use crate::physical_plan::{common, SendableRecordBatchStream};
4546
use crate::physical_plan::{repartition::RepartitionExec, Partitioning};
4647
use crate::physical_plan::{DisplayAs, DisplayFormatType, ExecutionPlan};
48+
use crate::physical_planner::create_physical_sort_expr;
4749

4850
/// Type alias for partition data
4951
pub type PartitionData = Arc<RwLock<Vec<RecordBatch>>>;
@@ -58,6 +60,9 @@ pub struct MemTable {
5860
pub(crate) batches: Vec<PartitionData>,
5961
constraints: Constraints,
6062
column_defaults: HashMap<String, Expr>,
63+
/// Optional pre-known sort order(s). Must be `SortExpr`s.
64+
/// Inserting data into this table removes the order.
65+
pub sort_order: Arc<Mutex<Vec<Vec<Expr>>>>,
6166
}
6267

6368
impl MemTable {
@@ -82,6 +87,7 @@ impl MemTable {
8287
.collect::<Vec<_>>(),
8388
constraints: Constraints::empty(),
8489
column_defaults: HashMap::new(),
90+
sort_order: Arc::new(Mutex::new(vec![])),
8591
})
8692
}
8793

@@ -100,6 +106,21 @@ impl MemTable {
100106
self
101107
}
102108

109+
/// Specify an optional pre-known sort order(s). Must be `SortExpr`s.
110+
///
111+
/// If the data is not sorted by this order, DataFusion may produce
112+
/// incorrect results.
113+
///
114+
/// DataFusion may take advantage of this ordering to omit sorts
115+
/// or use more efficient algorithms.
116+
///
117+
/// Note that multiple sort orders are supported, if some are known to be
118+
/// equivalent,
119+
pub fn with_sort_order(self, mut sort_order: Vec<Vec<Expr>>) -> Self {
120+
std::mem::swap(self.sort_order.lock().as_mut(), &mut sort_order);
121+
self
122+
}
123+
103124
/// Create a mem table by reading from another data source
104125
pub async fn load(
105126
t: Arc<dyn TableProvider>,
@@ -184,7 +205,7 @@ impl TableProvider for MemTable {
184205

185206
async fn scan(
186207
&self,
187-
_state: &SessionState,
208+
state: &SessionState,
188209
projection: Option<&Vec<usize>>,
189210
_filters: &[Expr],
190211
_limit: Option<usize>,
@@ -194,11 +215,33 @@ impl TableProvider for MemTable {
194215
let inner_vec = arc_inner_vec.read().await;
195216
partitions.push(inner_vec.clone())
196217
}
197-
Ok(Arc::new(MemoryExec::try_new(
198-
&partitions,
199-
self.schema(),
200-
projection.cloned(),
201-
)?))
218+
let mut exec =
219+
MemoryExec::try_new(&partitions, self.schema(), projection.cloned())?;
220+
221+
// add sort information if present
222+
let sort_order = self.sort_order.lock();
223+
if !sort_order.is_empty() {
224+
let df_schema = DFSchema::try_from(self.schema.as_ref().clone())?;
225+
226+
let file_sort_order = sort_order
227+
.iter()
228+
.map(|sort_exprs| {
229+
sort_exprs
230+
.iter()
231+
.map(|expr| {
232+
create_physical_sort_expr(
233+
expr,
234+
&df_schema,
235+
state.execution_props(),
236+
)
237+
})
238+
.collect::<Result<Vec<_>>>()
239+
})
240+
.collect::<Result<Vec<_>>>()?;
241+
exec = exec.with_sort_information(file_sort_order);
242+
}
243+
244+
Ok(Arc::new(exec))
202245
}
203246

204247
/// Returns an ExecutionPlan that inserts the execution results of a given [`ExecutionPlan`] into this [`MemTable`].
@@ -219,6 +262,9 @@ impl TableProvider for MemTable {
219262
input: Arc<dyn ExecutionPlan>,
220263
overwrite: bool,
221264
) -> Result<Arc<dyn ExecutionPlan>> {
265+
// If we are inserting into the table, any sort order may be messed up so reset it here
266+
*self.sort_order.lock() = vec![];
267+
222268
// Create a physical plan from the logical plan.
223269
// Check that the schema of the plan matches the schema of this table.
224270
if !self

datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs

Lines changed: 178 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -22,21 +22,32 @@ use arrow::compute::{concat_batches, SortOptions};
2222
use arrow::datatypes::DataType;
2323
use arrow::record_batch::RecordBatch;
2424
use arrow::util::pretty::pretty_format_batches;
25-
use datafusion::physical_plan::aggregates::{
26-
AggregateExec, AggregateMode, PhysicalGroupBy,
27-
};
25+
use arrow_array::cast::AsArray;
26+
use arrow_array::types::Int64Type;
27+
use arrow_array::Array;
28+
use hashbrown::HashMap;
2829
use rand::rngs::StdRng;
2930
use rand::{Rng, SeedableRng};
31+
use tokio::task::JoinSet;
3032

33+
use datafusion::common::Result;
34+
use datafusion::datasource::MemTable;
35+
use datafusion::physical_plan::aggregates::{
36+
AggregateExec, AggregateMode, PhysicalGroupBy,
37+
};
3138
use datafusion::physical_plan::memory::MemoryExec;
3239
use datafusion::physical_plan::{collect, displayable, ExecutionPlan};
33-
use datafusion::prelude::{SessionConfig, SessionContext};
40+
use datafusion::prelude::{DataFrame, SessionConfig, SessionContext};
41+
use datafusion_common::tree_node::{TreeNode, TreeNodeVisitor, VisitRecursion};
3442
use datafusion_physical_expr::expressions::{col, Sum};
3543
use datafusion_physical_expr::{AggregateExpr, PhysicalSortExpr};
36-
use test_utils::add_empty_batches;
44+
use datafusion_physical_plan::InputOrderMode;
45+
use test_utils::{add_empty_batches, StringBatchGenerator};
3746

38-
#[tokio::test(flavor = "multi_thread", worker_threads = 8)]
39-
async fn aggregate_test() {
47+
/// Tests that streaming aggregate and batch (non streaming) aggregate produce
48+
/// same results
49+
#[tokio::test(flavor = "multi_thread")]
50+
async fn streaming_aggregate_test() {
4051
let test_cases = vec![
4152
vec!["a"],
4253
vec!["b", "a"],
@@ -50,18 +61,18 @@ async fn aggregate_test() {
5061
let n = 300;
5162
let distincts = vec![10, 20];
5263
for distinct in distincts {
53-
let mut handles = Vec::new();
64+
let mut join_set = JoinSet::new();
5465
for i in 0..n {
5566
let test_idx = i % test_cases.len();
5667
let group_by_columns = test_cases[test_idx].clone();
57-
let job = tokio::spawn(run_aggregate_test(
68+
join_set.spawn(run_aggregate_test(
5869
make_staggered_batches::<true>(1000, distinct, i as u64),
5970
group_by_columns,
6071
));
61-
handles.push(job);
6272
}
63-
for job in handles {
64-
job.await.unwrap();
73+
while let Some(join_handle) = join_set.join_next().await {
74+
// propagate errors
75+
join_handle.unwrap();
6576
}
6677
}
6778
}
@@ -234,3 +245,158 @@ pub(crate) fn make_staggered_batches<const STREAM: bool>(
234245
}
235246
add_empty_batches(batches, &mut rng)
236247
}
248+
249+
/// Test group by with string/large string columns
250+
#[tokio::test(flavor = "multi_thread")]
251+
async fn group_by_strings() {
252+
let mut join_set = JoinSet::new();
253+
for large in [true, false] {
254+
for sorted in [true, false] {
255+
for generator in StringBatchGenerator::interesting_cases() {
256+
join_set.spawn(group_by_string_test(generator, sorted, large));
257+
}
258+
}
259+
}
260+
while let Some(join_handle) = join_set.join_next().await {
261+
// propagate errors
262+
join_handle.unwrap();
263+
}
264+
}
265+
266+
/// Run GROUP BY <x> using SQL and ensure the results are correct
267+
///
268+
/// If sorted is true, the input batches will be sorted by the group by column
269+
/// to test the streaming group by case
270+
///
271+
/// if large is true, the input batches will be LargeStringArray
272+
async fn group_by_string_test(
273+
mut generator: StringBatchGenerator,
274+
sorted: bool,
275+
large: bool,
276+
) {
277+
let column_name = "a";
278+
let input = if sorted {
279+
generator.make_sorted_input_batches(large)
280+
} else {
281+
generator.make_input_batches()
282+
};
283+
284+
let expected = compute_counts(&input, column_name);
285+
286+
let schema = input[0].schema();
287+
let session_config = SessionConfig::new().with_batch_size(50);
288+
let ctx = SessionContext::new_with_config(session_config);
289+
290+
let provider = MemTable::try_new(schema.clone(), vec![input]).unwrap();
291+
let provider = if sorted {
292+
let sort_expr = datafusion::prelude::col("a").sort(true, true);
293+
provider.with_sort_order(vec![vec![sort_expr]])
294+
} else {
295+
provider
296+
};
297+
298+
ctx.register_table("t", Arc::new(provider)).unwrap();
299+
300+
let df = ctx
301+
.sql("SELECT a, COUNT(*) FROM t GROUP BY a")
302+
.await
303+
.unwrap();
304+
verify_ordered_aggregate(&df, sorted).await;
305+
let results = df.collect().await.unwrap();
306+
307+
// verify that the results are correct
308+
let actual = extract_result_counts(results);
309+
assert_eq!(expected, actual);
310+
}
311+
async fn verify_ordered_aggregate(frame: &DataFrame, expected_sort: bool) {
312+
struct Visitor {
313+
expected_sort: bool,
314+
}
315+
let mut visitor = Visitor { expected_sort };
316+
317+
impl TreeNodeVisitor for Visitor {
318+
type N = Arc<dyn ExecutionPlan>;
319+
fn pre_visit(&mut self, node: &Self::N) -> Result<VisitRecursion> {
320+
if let Some(exec) = node.as_any().downcast_ref::<AggregateExec>() {
321+
if self.expected_sort {
322+
assert!(matches!(
323+
exec.input_order_mode(),
324+
InputOrderMode::PartiallySorted(_) | InputOrderMode::Sorted
325+
));
326+
} else {
327+
assert!(matches!(exec.input_order_mode(), InputOrderMode::Linear));
328+
}
329+
}
330+
Ok(VisitRecursion::Continue)
331+
}
332+
}
333+
334+
let plan = frame.clone().create_physical_plan().await.unwrap();
335+
plan.visit(&mut visitor).unwrap();
336+
}
337+
338+
/// Compute the count of each distinct value in the specified column
339+
///
340+
/// ```text
341+
/// +---------------+---------------+
342+
/// | a | b |
343+
/// +---------------+---------------+
344+
/// | 𭏷񬝜󓴻𼇪󄶛𑩁򽵐󦊟 | 󺚤𘱦𫎛񐕿 |
345+
/// | 󂌿󶴬񰶨񺹭𿑵󖺉 | 񥼧􋽐󮋋󑤐𬿪𜋃 |
346+
/// ```
347+
fn compute_counts(batches: &[RecordBatch], col: &str) -> HashMap<Option<String>, i64> {
348+
let mut output = HashMap::new();
349+
for arr in batches
350+
.iter()
351+
.map(|batch| batch.column_by_name(col).unwrap())
352+
{
353+
for value in to_str_vec(arr) {
354+
output.entry(value).and_modify(|e| *e += 1).or_insert(1);
355+
}
356+
}
357+
output
358+
}
359+
360+
fn to_str_vec(array: &ArrayRef) -> Vec<Option<String>> {
361+
match array.data_type() {
362+
DataType::Utf8 => array
363+
.as_string::<i32>()
364+
.iter()
365+
.map(|x| x.map(|x| x.to_string()))
366+
.collect(),
367+
DataType::LargeUtf8 => array
368+
.as_string::<i64>()
369+
.iter()
370+
.map(|x| x.map(|x| x.to_string()))
371+
.collect(),
372+
_ => panic!("unexpected type"),
373+
}
374+
}
375+
376+
/// extracts the value of the first column and the count of the second column
377+
/// ```text
378+
/// +----------------+----------+
379+
/// | a | COUNT(*) |
380+
/// +----------------+----------+
381+
/// | 񩢰񴠍 | 8 |
382+
/// | 󇿺򷜄򩨝񜖫𑟑񣶏󣥽𹕉 | 11 |
383+
/// ```
384+
fn extract_result_counts(results: Vec<RecordBatch>) -> HashMap<Option<String>, i64> {
385+
let group_arrays = results.iter().map(|batch| batch.column(0));
386+
387+
let count_arrays = results
388+
.iter()
389+
.map(|batch| batch.column(1).as_primitive::<Int64Type>());
390+
391+
let mut output = HashMap::new();
392+
for (group_arr, count_arr) in group_arrays.zip(count_arrays) {
393+
assert_eq!(group_arr.len(), count_arr.len());
394+
let group_values = to_str_vec(group_arr);
395+
for (group, count) in group_values.into_iter().zip(count_arr.iter()) {
396+
assert!(output.get(&group).is_none());
397+
let count = count.unwrap(); // counts can never be null
398+
output.insert(group, count);
399+
}
400+
}
401+
output
402+
}

0 commit comments

Comments
 (0)