Skip to content

Commit 35525d5

Browse files
committed
refactor PrimitiveArrayGenerator.
1 parent e61792e commit 35525d5

File tree

2 files changed

+70
-56
lines changed

2 files changed

+70
-56
lines changed

datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs

+23-12
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
use std::sync::Arc;
1919

20+
use arrow::datatypes::{Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type};
2021
use arrow_array::{ArrayRef, RecordBatch};
2122
use arrow_schema::{DataType, Field, Schema};
2223
use datafusion_common::{arrow_datafusion_err, DataFusionError, Result};
@@ -222,7 +223,7 @@ macro_rules! generate_string_array {
222223
}
223224

224225
macro_rules! generate_primitive_array {
225-
($SELF:ident, $NUM_ROWS:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $DATA_TYPE:ident) => {
226+
($SELF:ident, $NUM_ROWS:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $DATA_TYPE:ident, $ARROW_TYPE:ident) => {
226227
paste::paste! {{
227228
let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len());
228229
let null_pct = $SELF.candidate_null_pcts[null_pct_idx];
@@ -239,7 +240,7 @@ macro_rules! generate_primitive_array {
239240
rng: $ARRAY_GEN_RNG,
240241
};
241242

242-
generator.[< gen_data_ $DATA_TYPE >]()
243+
generator.gen_data::<$DATA_TYPE, $ARROW_TYPE>()
243244
}}}
244245
}
245246

@@ -297,7 +298,8 @@ impl RecordBatchGenerator {
297298
num_rows,
298299
batch_gen_rng,
299300
array_gen_rng,
300-
i8
301+
i8,
302+
Int8Type
301303
)
302304
}
303305
DataType::Int16 => {
@@ -306,7 +308,8 @@ impl RecordBatchGenerator {
306308
num_rows,
307309
batch_gen_rng,
308310
array_gen_rng,
309-
i16
311+
i16,
312+
Int16Type
310313
)
311314
}
312315
DataType::Int32 => {
@@ -315,7 +318,8 @@ impl RecordBatchGenerator {
315318
num_rows,
316319
batch_gen_rng,
317320
array_gen_rng,
318-
i32
321+
i32,
322+
Int32Type
319323
)
320324
}
321325
DataType::Int64 => {
@@ -324,7 +328,8 @@ impl RecordBatchGenerator {
324328
num_rows,
325329
batch_gen_rng,
326330
array_gen_rng,
327-
i64
331+
i64,
332+
Int64Type
328333
)
329334
}
330335
DataType::UInt8 => {
@@ -333,7 +338,8 @@ impl RecordBatchGenerator {
333338
num_rows,
334339
batch_gen_rng,
335340
array_gen_rng,
336-
u8
341+
u8,
342+
UInt8Type
337343
)
338344
}
339345
DataType::UInt16 => {
@@ -342,7 +348,8 @@ impl RecordBatchGenerator {
342348
num_rows,
343349
batch_gen_rng,
344350
array_gen_rng,
345-
u16
351+
u16,
352+
UInt16Type
346353
)
347354
}
348355
DataType::UInt32 => {
@@ -351,7 +358,8 @@ impl RecordBatchGenerator {
351358
num_rows,
352359
batch_gen_rng,
353360
array_gen_rng,
354-
u32
361+
u32,
362+
UInt32Type
355363
)
356364
}
357365
DataType::UInt64 => {
@@ -360,7 +368,8 @@ impl RecordBatchGenerator {
360368
num_rows,
361369
batch_gen_rng,
362370
array_gen_rng,
363-
u64
371+
u64,
372+
UInt64Type
364373
)
365374
}
366375
DataType::Float32 => {
@@ -369,7 +378,8 @@ impl RecordBatchGenerator {
369378
num_rows,
370379
batch_gen_rng,
371380
array_gen_rng,
372-
f32
381+
f32,
382+
Float32Type
373383
)
374384
}
375385
DataType::Float64 => {
@@ -378,7 +388,8 @@ impl RecordBatchGenerator {
378388
num_rows,
379389
batch_gen_rng,
380390
array_gen_rng,
381-
f64
391+
f64,
392+
Float64Type
382393
)
383394
}
384395
DataType::Utf8 => {

test-utils/src/array_gen/primitive.rs

+47-44
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,10 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::array::{ArrayRef, PrimitiveArray, UInt32Array};
19-
use arrow::datatypes::{
20-
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
21-
UInt32Type, UInt64Type, UInt8Type,
22-
};
18+
use arrow::array::{ArrayRef, ArrowPrimitiveType, PrimitiveArray, UInt32Array};
19+
use arrow::datatypes::DataType;
20+
use rand::distributions::Standard;
21+
use rand::prelude::Distribution;
2322
use rand::rngs::StdRng;
2423
use rand::Rng;
2524

@@ -35,46 +34,50 @@ pub struct PrimitiveArrayGenerator {
3534
pub rng: StdRng,
3635
}
3736

38-
macro_rules! impl_gen_data {
39-
($NATIVE_TYPE:ty, $ARROW_TYPE:ident) => {
40-
paste::paste! {
41-
pub fn [< gen_data_ $NATIVE_TYPE >](&mut self) -> ArrayRef {
42-
// table of strings from which to draw
43-
let distinct_primitives: PrimitiveArray<$ARROW_TYPE> = (0..self.num_distinct_primitives)
44-
.map(|_| Some(self.rng.gen::<$NATIVE_TYPE>()))
45-
.collect();
37+
// TODO: support generating more primitive arrays
38+
impl PrimitiveArrayGenerator {
39+
pub fn gen_data<N, A: ArrowPrimitiveType>(&mut self) -> ArrayRef
40+
where
41+
A: ArrowPrimitiveType<Native = N>,
42+
N: std::marker::Sync + std::marker::Send,
43+
Standard: Distribution<N>
44+
{
45+
// table of primitives from which to draw
46+
let distinct_primitives: PrimitiveArray<A> = (0..self.num_distinct_primitives)
47+
.map(|_| Some(match A::DATA_TYPE {
48+
DataType::Int8
49+
| DataType::Int16
50+
| DataType::Int32
51+
| DataType::Int64
52+
| DataType::UInt8
53+
| DataType::UInt16
54+
| DataType::UInt32
55+
| DataType::UInt64
56+
| DataType::Float32
57+
| DataType::Float64 => self.rng.gen::<N>(),
4658

47-
// pick num_strings randomly from the distinct string table
48-
let indicies: UInt32Array = (0..self.num_primitives)
49-
.map(|_| {
50-
if self.rng.gen::<f64>() < self.null_pct {
51-
None
52-
} else if self.num_distinct_primitives > 1 {
53-
let range = 1..(self.num_distinct_primitives as u32);
54-
Some(self.rng.gen_range(range))
55-
} else {
56-
Some(0)
57-
}
58-
})
59-
.collect();
59+
_ => {
60+
let arrow_type = A::DATA_TYPE;
61+
panic!("Unsupported arrow data type: {arrow_type}")
62+
}
63+
}))
64+
.collect();
6065

61-
let options = None;
62-
arrow::compute::take(&distinct_primitives, &indicies, options).unwrap()
63-
}
64-
}
65-
};
66-
}
66+
// pick num_primitives randomly from the distinct primitive table
67+
let indicies: UInt32Array = (0..self.num_primitives)
68+
.map(|_| {
69+
if self.rng.gen::<f64>() < self.null_pct {
70+
None
71+
} else if self.num_distinct_primitives > 1 {
72+
let range = 1..(self.num_distinct_primitives as u32);
73+
Some(self.rng.gen_range(range))
74+
} else {
75+
Some(0)
76+
}
77+
})
78+
.collect();
6779

68-
// TODO: support generating more primitive arrays
69-
impl PrimitiveArrayGenerator {
70-
impl_gen_data!(i8, Int8Type);
71-
impl_gen_data!(i16, Int16Type);
72-
impl_gen_data!(i32, Int32Type);
73-
impl_gen_data!(i64, Int64Type);
74-
impl_gen_data!(u8, UInt8Type);
75-
impl_gen_data!(u16, UInt16Type);
76-
impl_gen_data!(u32, UInt32Type);
77-
impl_gen_data!(u64, UInt64Type);
78-
impl_gen_data!(f32, Float32Type);
79-
impl_gen_data!(f64, Float64Type);
80+
let options = None;
81+
arrow::compute::take(&distinct_primitives, &indicies, options).unwrap()
82+
}
8083
}

0 commit comments

Comments
 (0)