@@ -99,11 +99,7 @@ impl DatasetGenerator {
99
99
let base_batch = self . batch_generator . generate ( ) ;
100
100
let total_rows_num = base_batch. num_rows ( ) ;
101
101
let batches = stagger_batch ( base_batch. clone ( ) ) ;
102
- let dataset = Dataset {
103
- batches,
104
- total_rows_num,
105
- sort_keys : Vec :: new ( ) ,
106
- } ;
102
+ let dataset = Dataset :: new ( batches, Vec :: new ( ) ) ;
107
103
datasets. push ( dataset) ;
108
104
109
105
// Generate the related sorted batches
@@ -123,11 +119,7 @@ impl DatasetGenerator {
123
119
. expect ( "sort batch should not fail" ) ;
124
120
125
121
let batches = stagger_batch ( sorted_batch) ;
126
- let dataset = Dataset {
127
- batches,
128
- total_rows_num,
129
- sort_keys : Vec :: new ( ) ,
130
- } ;
122
+ let dataset = Dataset :: new ( batches, sort_keys) ;
131
123
datasets. push ( dataset) ;
132
124
}
133
125
@@ -143,6 +135,18 @@ pub struct Dataset {
143
135
pub sort_keys : Vec < String > ,
144
136
}
145
137
138
+ impl Dataset {
139
+ pub fn new ( batches : Vec < RecordBatch > , sort_keys : Vec < String > ) -> Self {
140
+ let total_rows_num = batches. iter ( ) . map ( |batch| batch. num_rows ( ) ) . sum :: < usize > ( ) ;
141
+
142
+ Self {
143
+ batches,
144
+ total_rows_num,
145
+ sort_keys,
146
+ }
147
+ }
148
+ }
149
+
146
150
#[ derive( Debug , Clone ) ]
147
151
pub struct ColumnDescr {
148
152
// Column name
@@ -362,6 +366,8 @@ mod test {
362
366
use arrow:: util:: pretty:: pretty_format_batches;
363
367
use arrow_array:: UInt32Array ;
364
368
369
+ use crate :: fuzz_cases:: aggregation_fuzzer:: check_equality_of_batches;
370
+
365
371
use super :: * ;
366
372
367
373
#[ test]
@@ -428,20 +434,8 @@ mod test {
428
434
}
429
435
430
436
// Two batches should be the same after sorting
431
- let formatted_batches0 = pretty_format_batches ( & datasets[ 0 ] . batches )
432
- . unwrap ( )
433
- . to_string ( ) ;
434
- let mut formatted_batches0_sorted: Vec < & str > =
435
- formatted_batches0. trim ( ) . lines ( ) . collect ( ) ;
436
- formatted_batches0_sorted. sort_unstable ( ) ;
437
- let formatted_batches1 = pretty_format_batches ( & datasets[ 1 ] . batches )
438
- . unwrap ( )
439
- . to_string ( ) ;
440
- let mut formatted_batches1_sorted: Vec < & str > =
441
- formatted_batches1. trim ( ) . lines ( ) . collect ( ) ;
442
- formatted_batches1_sorted. sort_unstable ( ) ;
443
- assert_eq ! ( formatted_batches0_sorted, formatted_batches1_sorted) ;
444
-
437
+ check_equality_of_batches ( & datasets[ 0 ] . batches , & datasets[ 1 ] . batches ) ;
438
+
445
439
// Rows num should be between [16, 32]
446
440
let rows_num0 = datasets[ 0 ]
447
441
. batches
0 commit comments