@@ -279,8 +279,88 @@ pub fn get_data_dir(
279
279
}
280
280
}
281
281
282
+ #[ macro_export]
283
+ macro_rules! create_array {
284
+ ( Boolean , $values: expr) => {
285
+ std:: sync:: Arc :: new( arrow:: array:: BooleanArray :: from( $values) )
286
+ } ;
287
+ ( Int8 , $values: expr) => {
288
+ std:: sync:: Arc :: new( arrow:: array:: Int8Array :: from( $values) )
289
+ } ;
290
+ ( Int16 , $values: expr) => {
291
+ std:: sync:: Arc :: new( arrow:: array:: Int16Array :: from( $values) )
292
+ } ;
293
+ ( Int32 , $values: expr) => {
294
+ std:: sync:: Arc :: new( arrow:: array:: Int32Array :: from( $values) )
295
+ } ;
296
+ ( Int64 , $values: expr) => {
297
+ std:: sync:: Arc :: new( arrow:: array:: Int64Array :: from( $values) )
298
+ } ;
299
+ ( UInt8 , $values: expr) => {
300
+ std:: sync:: Arc :: new( arrow:: array:: UInt8Array :: from( $values) )
301
+ } ;
302
+ ( UInt16 , $values: expr) => {
303
+ std:: sync:: Arc :: new( arrow:: array:: UInt16Array :: from( $values) )
304
+ } ;
305
+ ( UInt32 , $values: expr) => {
306
+ std:: sync:: Arc :: new( arrow:: array:: UInt32Array :: from( $values) )
307
+ } ;
308
+ ( UInt64 , $values: expr) => {
309
+ std:: sync:: Arc :: new( arrow:: array:: UInt64Array :: from( $values) )
310
+ } ;
311
+ ( Float16 , $values: expr) => {
312
+ std:: sync:: Arc :: new( arrow:: array:: Float16Array :: from( $values) )
313
+ } ;
314
+ ( Float32 , $values: expr) => {
315
+ std:: sync:: Arc :: new( arrow:: array:: Float32Array :: from( $values) )
316
+ } ;
317
+ ( Float64 , $values: expr) => {
318
+ std:: sync:: Arc :: new( arrow:: array:: Float64Array :: from( $values) )
319
+ } ;
320
+ ( Utf8 , $values: expr) => {
321
+ std:: sync:: Arc :: new( arrow:: array:: StringArray :: from( $values) )
322
+ } ;
323
+ }
324
+
325
+ /// Creates a record batch from literal slice of values, suitable for rapid
326
+ /// testing and development.
327
+ ///
328
+ /// Example:
329
+ /// ```
330
+ /// use datafusion_common::{record_batch, create_array};
331
+ /// let batch = record_batch!(
332
+ /// ("a", Int32, vec![1, 2, 3]),
333
+ /// ("b", Float64, vec![Some(4.0), None, Some(5.0)]),
334
+ /// ("c", Utf8, vec!["alpha", "beta", "gamma"])
335
+ /// );
336
+ /// ```
337
+ #[ macro_export]
338
+ macro_rules! record_batch {
339
+ ( $( ( $name: expr, $type: ident, $values: expr) ) ,* ) => {
340
+ {
341
+ let schema = std:: sync:: Arc :: new( arrow_schema:: Schema :: new( vec![
342
+ $(
343
+ arrow_schema:: Field :: new( $name, arrow_schema:: DataType :: $type, true ) ,
344
+ ) *
345
+ ] ) ) ;
346
+
347
+ let batch = arrow_array:: RecordBatch :: try_new(
348
+ schema,
349
+ vec![ $(
350
+ create_array!( $type, $values) ,
351
+ ) * ]
352
+ ) ;
353
+
354
+ batch
355
+ }
356
+ }
357
+ }
358
+
282
359
#[ cfg( test) ]
283
360
mod tests {
361
+ use crate :: cast:: { as_float64_array, as_int32_array, as_string_array} ;
362
+ use crate :: error:: Result ;
363
+
284
364
use super :: * ;
285
365
use std:: env;
286
366
@@ -333,4 +413,44 @@ mod tests {
333
413
let res = parquet_test_data ( ) ;
334
414
assert ! ( PathBuf :: from( res) . is_dir( ) ) ;
335
415
}
416
+
417
+ #[ test]
418
+ fn test_create_record_batch ( ) -> Result < ( ) > {
419
+ use arrow_array:: Array ;
420
+
421
+ let batch = record_batch ! (
422
+ ( "a" , Int32 , vec![ 1 , 2 , 3 , 4 ] ) ,
423
+ ( "b" , Float64 , vec![ Some ( 4.0 ) , None , Some ( 5.0 ) , None ] ) ,
424
+ ( "c" , Utf8 , vec![ "alpha" , "beta" , "gamma" , "delta" ] )
425
+ ) ?;
426
+
427
+ assert_eq ! ( 3 , batch. num_columns( ) ) ;
428
+ assert_eq ! ( 4 , batch. num_rows( ) ) ;
429
+
430
+ let values: Vec < _ > = as_int32_array ( batch. column ( 0 ) ) ?
431
+ . values ( )
432
+ . iter ( )
433
+ . map ( |v| v. to_owned ( ) )
434
+ . collect ( ) ;
435
+ assert_eq ! ( values, vec![ 1 , 2 , 3 , 4 ] ) ;
436
+
437
+ let values: Vec < _ > = as_float64_array ( batch. column ( 1 ) ) ?
438
+ . values ( )
439
+ . iter ( )
440
+ . map ( |v| v. to_owned ( ) )
441
+ . collect ( ) ;
442
+ assert_eq ! ( values, vec![ 4.0 , 0.0 , 5.0 , 0.0 ] ) ;
443
+
444
+ let nulls: Vec < _ > = as_float64_array ( batch. column ( 1 ) ) ?
445
+ . nulls ( )
446
+ . unwrap ( )
447
+ . iter ( )
448
+ . collect ( ) ;
449
+ assert_eq ! ( nulls, vec![ true , false , true , false ] ) ;
450
+
451
+ let values: Vec < _ > = as_string_array ( batch. column ( 2 ) ) ?. iter ( ) . flatten ( ) . collect ( ) ;
452
+ assert_eq ! ( values, vec![ "alpha" , "beta" , "gamma" , "delta" ] ) ;
453
+
454
+ Ok ( ( ) )
455
+ }
336
456
}
0 commit comments