@@ -174,11 +174,16 @@ impl Dataset {
174
174
175
175
#[ derive( Debug , Clone ) ]
176
176
pub struct ColumnDescr {
177
- // Column name
177
+ /// Column name
178
178
name : String ,
179
179
180
- // Data type of this column
180
+ /// Data type of this column
181
181
column_type : DataType ,
182
+
183
+ /// Optional upper bound on the number of distinct values generated for this column; `None` leaves the randomly chosen count uncapped.
184
+ ///
185
+ /// See [`ColumnDescr::with_max_num_distinct`] for more information.
186
+ max_num_distinct : Option < usize > ,
182
187
}
183
188
184
189
impl ColumnDescr {
@@ -187,8 +192,18 @@ impl ColumnDescr {
187
192
Self {
188
193
name : name. to_string ( ) ,
189
194
column_type,
195
+ max_num_distinct : None ,
190
196
}
191
197
}
198
+
199
+ /// Set the maximum number of distinct values in this column.
200
+ ///
201
+ /// When an array is generated, a distinct-value count is chosen at random and then capped at
202
+ /// `num_distinct`. By default (`None`, as set by `new`) no cap is applied.
203
+ pub fn with_max_num_distinct ( mut self , num_distinct : usize ) -> Self {
204
+ self . max_num_distinct = Some ( num_distinct) ;
205
+ self
206
+ }
192
207
}
193
208
194
209
/// Record batch generator
@@ -203,20 +218,15 @@ struct RecordBatchGenerator {
203
218
}
204
219
205
220
macro_rules! generate_string_array {
206
- ( $SELF: ident, $NUM_ROWS: ident, $BATCH_GEN_RNG: ident, $ARRAY_GEN_RNG: ident, $OFFSET_TYPE: ty) => { {
221
+ ( $SELF: ident, $NUM_ROWS: ident, $MAX_NUM_DISTINCT : expr , $ BATCH_GEN_RNG: ident, $ARRAY_GEN_RNG: ident, $OFFSET_TYPE: ty) => { {
207
222
let null_pct_idx = $BATCH_GEN_RNG. gen_range( 0 ..$SELF. candidate_null_pcts. len( ) ) ;
208
223
let null_pct = $SELF. candidate_null_pcts[ null_pct_idx] ;
209
224
let max_len = $BATCH_GEN_RNG. gen_range( 1 ..50 ) ;
210
- let num_distinct_strings = if $NUM_ROWS > 1 {
211
- $BATCH_GEN_RNG. gen_range( 1 ..$NUM_ROWS)
212
- } else {
213
- $NUM_ROWS
214
- } ;
215
225
216
226
let mut generator = StringArrayGenerator {
217
227
max_len,
218
228
num_strings: $NUM_ROWS,
219
- num_distinct_strings,
229
+ num_distinct_strings: $MAX_NUM_DISTINCT ,
220
230
null_pct,
221
231
rng: $ARRAY_GEN_RNG,
222
232
} ;
@@ -226,19 +236,14 @@ macro_rules! generate_string_array {
226
236
}
227
237
228
238
macro_rules! generate_primitive_array {
229
- ( $SELF: ident, $NUM_ROWS: ident, $BATCH_GEN_RNG: ident, $ARRAY_GEN_RNG: ident, $ARROW_TYPE: ident) => {
239
+ ( $SELF: ident, $NUM_ROWS: ident, $MAX_NUM_DISTINCT : expr , $ BATCH_GEN_RNG: ident, $ARRAY_GEN_RNG: ident, $ARROW_TYPE: ident) => {
230
240
paste:: paste! { {
231
241
let null_pct_idx = $BATCH_GEN_RNG. gen_range( 0 ..$SELF. candidate_null_pcts. len( ) ) ;
232
242
let null_pct = $SELF. candidate_null_pcts[ null_pct_idx] ;
233
- let num_distinct_primitives = if $NUM_ROWS > 1 {
234
- $BATCH_GEN_RNG. gen_range( 1 ..$NUM_ROWS)
235
- } else {
236
- $NUM_ROWS
237
- } ;
238
243
239
244
let mut generator = PrimitiveArrayGenerator {
240
245
num_primitives: $NUM_ROWS,
241
- num_distinct_primitives,
246
+ num_distinct_primitives: $MAX_NUM_DISTINCT ,
242
247
null_pct,
243
248
rng: $ARRAY_GEN_RNG,
244
249
} ;
@@ -268,7 +273,7 @@ impl RecordBatchGenerator {
268
273
let mut arrays = Vec :: with_capacity ( self . columns . len ( ) ) ;
269
274
for col in self . columns . iter ( ) {
270
275
let array = self . generate_array_of_type (
271
- col. column_type . clone ( ) ,
276
+ col,
272
277
num_rows,
273
278
& mut rng,
274
279
array_gen_rng. clone ( ) ,
@@ -289,16 +294,28 @@ impl RecordBatchGenerator {
289
294
290
295
fn generate_array_of_type (
291
296
& self ,
292
- data_type : DataType ,
297
+ col : & ColumnDescr ,
293
298
num_rows : usize ,
294
299
batch_gen_rng : & mut ThreadRng ,
295
300
array_gen_rng : StdRng ,
296
301
) -> ArrayRef {
297
- match data_type {
302
+ let num_distinct = if num_rows > 1 {
303
+ batch_gen_rng. gen_range ( 1 ..num_rows)
304
+ } else {
305
+ num_rows
306
+ } ;
307
+ // cap the randomly chosen count at the column's configured maximum, if any
308
+ let max_num_distinct = col
309
+ . max_num_distinct
310
+ . map ( |max| num_distinct. min ( max) )
311
+ . unwrap_or ( num_distinct) ;
312
+
313
+ match col. column_type {
298
314
DataType :: Int8 => {
299
315
generate_primitive_array ! (
300
316
self ,
301
317
num_rows,
318
+ max_num_distinct,
302
319
batch_gen_rng,
303
320
array_gen_rng,
304
321
Int8Type
@@ -308,6 +325,7 @@ impl RecordBatchGenerator {
308
325
generate_primitive_array ! (
309
326
self ,
310
327
num_rows,
328
+ max_num_distinct,
311
329
batch_gen_rng,
312
330
array_gen_rng,
313
331
Int16Type
@@ -317,6 +335,7 @@ impl RecordBatchGenerator {
317
335
generate_primitive_array ! (
318
336
self ,
319
337
num_rows,
338
+ max_num_distinct,
320
339
batch_gen_rng,
321
340
array_gen_rng,
322
341
Int32Type
@@ -326,6 +345,7 @@ impl RecordBatchGenerator {
326
345
generate_primitive_array ! (
327
346
self ,
328
347
num_rows,
348
+ max_num_distinct,
329
349
batch_gen_rng,
330
350
array_gen_rng,
331
351
Int64Type
@@ -335,6 +355,7 @@ impl RecordBatchGenerator {
335
355
generate_primitive_array ! (
336
356
self ,
337
357
num_rows,
358
+ max_num_distinct,
338
359
batch_gen_rng,
339
360
array_gen_rng,
340
361
UInt8Type
@@ -344,6 +365,7 @@ impl RecordBatchGenerator {
344
365
generate_primitive_array ! (
345
366
self ,
346
367
num_rows,
368
+ max_num_distinct,
347
369
batch_gen_rng,
348
370
array_gen_rng,
349
371
UInt16Type
@@ -353,6 +375,7 @@ impl RecordBatchGenerator {
353
375
generate_primitive_array ! (
354
376
self ,
355
377
num_rows,
378
+ max_num_distinct,
356
379
batch_gen_rng,
357
380
array_gen_rng,
358
381
UInt32Type
@@ -362,6 +385,7 @@ impl RecordBatchGenerator {
362
385
generate_primitive_array ! (
363
386
self ,
364
387
num_rows,
388
+ max_num_distinct,
365
389
batch_gen_rng,
366
390
array_gen_rng,
367
391
UInt64Type
@@ -371,6 +395,7 @@ impl RecordBatchGenerator {
371
395
generate_primitive_array ! (
372
396
self ,
373
397
num_rows,
398
+ max_num_distinct,
374
399
batch_gen_rng,
375
400
array_gen_rng,
376
401
Float32Type
@@ -380,6 +405,7 @@ impl RecordBatchGenerator {
380
405
generate_primitive_array ! (
381
406
self ,
382
407
num_rows,
408
+ max_num_distinct,
383
409
batch_gen_rng,
384
410
array_gen_rng,
385
411
Float64Type
@@ -389,6 +415,7 @@ impl RecordBatchGenerator {
389
415
generate_primitive_array ! (
390
416
self ,
391
417
num_rows,
418
+ max_num_distinct,
392
419
batch_gen_rng,
393
420
array_gen_rng,
394
421
Date32Type
@@ -398,19 +425,34 @@ impl RecordBatchGenerator {
398
425
generate_primitive_array ! (
399
426
self ,
400
427
num_rows,
428
+ max_num_distinct,
401
429
batch_gen_rng,
402
430
array_gen_rng,
403
431
Date64Type
404
432
)
405
433
}
406
434
DataType :: Utf8 => {
407
- generate_string_array ! ( self , num_rows, batch_gen_rng, array_gen_rng, i32 )
435
+ generate_string_array ! (
436
+ self ,
437
+ num_rows,
438
+ max_num_distinct,
439
+ batch_gen_rng,
440
+ array_gen_rng,
441
+ i32
442
+ )
408
443
}
409
444
DataType :: LargeUtf8 => {
410
- generate_string_array ! ( self , num_rows, batch_gen_rng, array_gen_rng, i64 )
445
+ generate_string_array ! (
446
+ self ,
447
+ num_rows,
448
+ max_num_distinct,
449
+ batch_gen_rng,
450
+ array_gen_rng,
451
+ i64
452
+ )
411
453
}
412
454
_ => {
413
- panic ! ( "Unsupported data generator type: {data_type}" )
455
+ panic ! ( "Unsupported data generator type: {}" , col . column_type )
414
456
}
415
457
}
416
458
}
@@ -435,14 +477,8 @@ mod test {
435
477
// - Their row counts should be the same and fall within [16, 32]
436
478
let config = DatasetGeneratorConfig {
437
479
columns : vec ! [
438
- ColumnDescr {
439
- name: "a" . to_string( ) ,
440
- column_type: DataType :: Utf8 ,
441
- } ,
442
- ColumnDescr {
443
- name: "b" . to_string( ) ,
444
- column_type: DataType :: UInt32 ,
445
- } ,
480
+ ColumnDescr :: new( "a" , DataType :: Utf8 ) ,
481
+ ColumnDescr :: new( "b" , DataType :: UInt32 ) ,
446
482
] ,
447
483
rows_num_range : ( 16 , 32 ) ,
448
484
sort_keys_set : vec ! [ vec![ "b" . to_string( ) ] ] ,
0 commit comments