Skip to content

Commit ebfc155

Browse files
authored
Macro for creating record batch from literal slice (#12846)
* Add macro for creating record batch, useful for unit test or rapid development * Update docstring * Add additional checks in unit test and rename macro per user input
1 parent 1582e8d commit ebfc155

File tree

1 file changed

+120
-0
lines changed

1 file changed

+120
-0
lines changed

datafusion/common/src/test_util.rs

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,8 +279,88 @@ pub fn get_data_dir(
279279
}
280280
}
281281

282+
/// Builds a reference-counted Arrow array from a type tag and a collection
/// of values.
///
/// The first argument is one of the tags `Boolean`, `Int8`, `Int16`, `Int32`,
/// `Int64`, `UInt8`, `UInt16`, `UInt32`, `UInt64`, `Float16`, `Float32`,
/// `Float64` or `Utf8`. The second argument is any expression accepted by
/// the `From` implementation of the matching `arrow::array` type.
///
/// Each arm expands to `std::sync::Arc::new(<array type>::from(<values>))`.
/// The `arrow` path is written unqualified, so it must resolve at the call
/// site.
#[macro_export]
macro_rules! create_array {
    // Boolean
    (Boolean, $data:expr) => {
        std::sync::Arc::new(arrow::array::BooleanArray::from($data))
    };

    // Signed integers
    (Int8, $data:expr) => {
        std::sync::Arc::new(arrow::array::Int8Array::from($data))
    };
    (Int16, $data:expr) => {
        std::sync::Arc::new(arrow::array::Int16Array::from($data))
    };
    (Int32, $data:expr) => {
        std::sync::Arc::new(arrow::array::Int32Array::from($data))
    };
    (Int64, $data:expr) => {
        std::sync::Arc::new(arrow::array::Int64Array::from($data))
    };

    // Unsigned integers
    (UInt8, $data:expr) => {
        std::sync::Arc::new(arrow::array::UInt8Array::from($data))
    };
    (UInt16, $data:expr) => {
        std::sync::Arc::new(arrow::array::UInt16Array::from($data))
    };
    (UInt32, $data:expr) => {
        std::sync::Arc::new(arrow::array::UInt32Array::from($data))
    };
    (UInt64, $data:expr) => {
        std::sync::Arc::new(arrow::array::UInt64Array::from($data))
    };

    // Floating point
    (Float16, $data:expr) => {
        std::sync::Arc::new(arrow::array::Float16Array::from($data))
    };
    (Float32, $data:expr) => {
        std::sync::Arc::new(arrow::array::Float32Array::from($data))
    };
    (Float64, $data:expr) => {
        std::sync::Arc::new(arrow::array::Float64Array::from($data))
    };

    // Strings
    (Utf8, $data:expr) => {
        std::sync::Arc::new(arrow::array::StringArray::from($data))
    };
}
324+
325+
/// Creates a record batch from literal slices of values, suitable for rapid
/// testing and development.
///
/// Each column is given as a `(name, type, values)` tuple, where `type` is
/// one of the tags understood by `create_array!` (and is also used as the
/// `arrow_schema::DataType` variant of the field). Every field is declared
/// nullable. A trailing comma after the last tuple is accepted.
///
/// The macro evaluates to the value returned by
/// `arrow_array::RecordBatch::try_new`, so the caller decides how to handle
/// a failed construction (for example with `?` or `unwrap`).
///
/// Example:
/// ```
/// use datafusion_common::record_batch;
/// let batch = record_batch!(
///     ("a", Int32, vec![1, 2, 3]),
///     ("b", Float64, vec![Some(4.0), None, Some(5.0)]),
///     ("c", Utf8, vec!["alpha", "beta", "gamma"]),
/// );
/// ```
#[macro_export]
macro_rules! record_batch {
    ($(($name:expr, $type:ident, $values:expr)),* $(,)?) => {
        {
            // One nullable field per column tuple, in declaration order.
            let schema = std::sync::Arc::new(arrow_schema::Schema::new(vec![
                $(
                    arrow_schema::Field::new($name, arrow_schema::DataType::$type, true),
                )*
            ]));

            // `$crate::` keeps the helper macro reachable even when the caller
            // has only imported `record_batch!`.
            arrow_array::RecordBatch::try_new(
                schema,
                vec![$(
                    $crate::create_array!($type, $values),
                )*]
            )
        }
    }
}
358+
282359
#[cfg(test)]
283360
mod tests {
361+
use crate::cast::{as_float64_array, as_int32_array, as_string_array};
362+
use crate::error::Result;
363+
284364
use super::*;
285365
use std::env;
286366

@@ -333,4 +413,44 @@ mod tests {
333413
let res = parquet_test_data();
334414
assert!(PathBuf::from(res).is_dir());
335415
}
416+
417+
#[test]
fn test_create_record_batch() -> Result<()> {
    use arrow_array::Array;

    // Three columns of four rows each: plain ints, floats with nulls, strings.
    let batch = record_batch!(
        ("a", Int32, vec![1, 2, 3, 4]),
        ("b", Float64, vec![Some(4.0), None, Some(5.0), None]),
        ("c", Utf8, vec!["alpha", "beta", "gamma", "delta"])
    )?;

    assert_eq!(3, batch.num_columns());
    assert_eq!(4, batch.num_rows());

    // Column "a": the raw values come back unchanged.
    let int_column = as_int32_array(batch.column(0))?;
    let int_values: Vec<_> = int_column.values().iter().copied().collect();
    assert_eq!(int_values, vec![1, 2, 3, 4]);

    // Column "b": the raw value buffer is expected to hold 0.0 in the null slots.
    let float_column = as_float64_array(batch.column(1))?;
    let float_values: Vec<_> = float_column.values().iter().copied().collect();
    assert_eq!(float_values, vec![4.0, 0.0, 5.0, 0.0]);

    // Column "b": the validity mask marks exactly the `None` entries as invalid.
    let validity: Vec<_> = float_column.nulls().unwrap().iter().collect();
    assert_eq!(validity, vec![true, false, true, false]);

    // Column "c": every string is present, in order.
    let strings: Vec<_> = as_string_array(batch.column(2))?
        .iter()
        .flatten()
        .collect();
    assert_eq!(strings, vec!["alpha", "beta", "gamma", "delta"]);

    Ok(())
}
336456
}

0 commit comments

Comments
 (0)