diff --git a/examples/parquet_rw.rs b/examples/parquet_rw.rs
index 968a1e0..7e646e6 100644
--- a/examples/parquet_rw.rs
+++ b/examples/parquet_rw.rs
@@ -1,6 +1,6 @@
 use std::path::PathBuf;
 
-use arrow::{array::UInt32Array, datatypes::{DataType, Field, Float64Type, Int32Type, Schema, UInt32Type}};
+use arrow::datatypes::{DataType, Field, Float64Type, Schema, UInt32Type, Utf8Type};
 use renoir::prelude::*;
 use serde::{Deserialize, Serialize};
 
@@ -8,6 +8,7 @@
 struct Output {
     pub value: u32,
     pub root: f64,
+    pub text: String,
 }
 
 impl Output {
@@ -15,10 +16,18 @@
         Schema::new(vec![
             Field::new("value", DataType::UInt32, false),
             Field::new("root", DataType::Float64, false),
+            Field::new("text", DataType::Utf8, false),
         ])
     }
 }
 
+fn gen(i: u32) -> Output {
+    Output {
+        value: i,
+        root: (i as f64).sqrt(),
+        text: format!("{:x}", i * i * i),
+    }
+}
 fn main() {
     let conf = RuntimeConfig::local(4).unwrap();
 
@@ -38,7 +47,7 @@ fn main() {
     // Write to multiple files in parallel
     let path = dir_path.clone();
     ctx.stream_par_iter(0..100u32)
-        .map(|i| Output { value: i, root: (i as f64).sqrt()})
+        .map(gen)
         .write_parquet_seq(path, Output::schema());
 
     ctx.execute_blocking();
@@ -48,19 +57,18 @@
     let mut path = dir_path.clone();
     path.push("one.parquet");
     ctx.stream_par_iter(0..100u32)
-        .map(|i| Output { value: i, root: (i as f64).sqrt()})
+        .map(gen)
         .write_parquet_one(path, Output::schema());
 
     ctx.execute_blocking();
-
-    eprintln!("Reading from parquet is not supported yet.");
+    eprintln!("Reading from parquet");
 
     let ctx = StreamContext::new(conf);
     let mut path = dir_path;
     path.push("one.parquet");
 
     ctx.stream_parquet_one(path)
-        .to_rows::<(UInt32Type, Float64Type)>()
+        .to_rows::<(UInt32Type, Float64Type, Utf8Type)>()
         .for_each(|t| println!("{t:?}"));
 
     ctx.execute_blocking();
diff --git a/src/operator/source/parquet.rs b/src/operator/source/parquet.rs
index 66f5756..3ab7663 100644
--- a/src/operator/source/parquet.rs
+++ b/src/operator/source/parquet.rs
@@ -1,6 +1,9 @@
 use std::{fs::File, path::PathBuf};
 
-use arrow::array::{cast::AsArray, types::ArrowPrimitiveType, Array, RecordBatch};
+use arrow::{
+    array::{cast::AsArray, types::ArrowPrimitiveType, Array, RecordBatch},
+    datatypes::*,
+};
 use parquet::arrow::arrow_reader::{ArrowReaderBuilder, ParquetRecordBatchReader};
 
 use crate::{
@@ -111,7 +114,7 @@ where
     }
 }
 
-pub trait FromRecordBatchRow: Sized {
+pub trait FromRecordBatchRow {
     type Native;
     fn from_record_batch_row(
         batch: &RecordBatch,
@@ -120,38 +123,114 @@
     fn is_compatible(batch: &RecordBatch) -> bool;
 }
 
-// Helper function to get a value from a column
-fn get_value<T: ArrowPrimitiveType>(
-    batch: &RecordBatch,
-    col: usize,
-    row: usize,
-) -> Result<T::Native, FromRecordBatchError>
-where
-    T::Native: Clone,
-{
-    batch
-        .column(col)
-        .as_primitive_opt::<T>()
-        .ok_or_else(|| FromRecordBatchError::IncompatibleTypes(col))
-        .map(|array| array.value(row).clone())
+pub trait ArrowTypeCast {
+    type Native;
+    const DATA_TYPE: DataType;
+
+    fn value_at(
+        batch: &RecordBatch,
+        col: usize,
+        row: usize,
+    ) -> Result<Self::Native, FromRecordBatchError>;
+}
+
+macro_rules! impl_arrow_cast_primitive {
+    ($t:ty) => {
+        impl ArrowTypeCast for $t {
+            type Native = <$t as ArrowPrimitiveType>::Native;
+            const DATA_TYPE: DataType = <$t as ArrowPrimitiveType>::DATA_TYPE;
+
+            #[inline]
+            fn value_at(
+                batch: &RecordBatch,
+                col: usize,
+                row: usize,
+            ) -> Result<Self::Native, FromRecordBatchError> {
+                batch
+                    .column(col)
+                    .as_primitive_opt::<$t>()
+                    .ok_or_else(|| FromRecordBatchError::IncompatibleTypes(col))
+                    .map(|array| array.value(row).clone())
+            }
+        }
+    };
 }
 
-macro_rules! from_record_batch_tuple {
+impl_arrow_cast_primitive!(Date32Type);
+impl_arrow_cast_primitive!(Date64Type);
+impl_arrow_cast_primitive!(Decimal128Type);
+impl_arrow_cast_primitive!(Decimal256Type);
+impl_arrow_cast_primitive!(DurationMicrosecondType);
+impl_arrow_cast_primitive!(DurationMillisecondType);
+impl_arrow_cast_primitive!(DurationNanosecondType);
+impl_arrow_cast_primitive!(DurationSecondType);
+impl_arrow_cast_primitive!(Float16Type);
+impl_arrow_cast_primitive!(Float32Type);
+impl_arrow_cast_primitive!(Float64Type);
+impl_arrow_cast_primitive!(Int8Type);
+impl_arrow_cast_primitive!(Int16Type);
+impl_arrow_cast_primitive!(Int32Type);
+impl_arrow_cast_primitive!(Int64Type);
+impl_arrow_cast_primitive!(IntervalDayTimeType);
+impl_arrow_cast_primitive!(IntervalMonthDayNanoType);
+impl_arrow_cast_primitive!(IntervalYearMonthType);
+impl_arrow_cast_primitive!(Time32MillisecondType);
+impl_arrow_cast_primitive!(Time32SecondType);
+impl_arrow_cast_primitive!(Time64MicrosecondType);
+impl_arrow_cast_primitive!(Time64NanosecondType);
+impl_arrow_cast_primitive!(TimestampMicrosecondType);
+impl_arrow_cast_primitive!(TimestampMillisecondType);
+impl_arrow_cast_primitive!(TimestampNanosecondType);
+impl_arrow_cast_primitive!(TimestampSecondType);
+impl_arrow_cast_primitive!(UInt8Type);
+impl_arrow_cast_primitive!(UInt16Type);
+impl_arrow_cast_primitive!(UInt32Type);
+impl_arrow_cast_primitive!(UInt64Type);
+
+macro_rules! impl_arrow_cast_bytes {
+    ($t:ty) => {
+        impl ArrowTypeCast for $t {
+            type Native = <<$t as ByteArrayType>::Native as ToOwned>::Owned;
+            const DATA_TYPE: DataType = <$t as ByteArrayType>::DATA_TYPE;
+
+            #[inline]
+            fn value_at(
+                batch: &RecordBatch,
+                col: usize,
+                row: usize,
+            ) -> Result<Self::Native, FromRecordBatchError> {
+                batch
+                    .column(col)
+                    .as_bytes_opt::<$t>()
+                    .ok_or_else(|| FromRecordBatchError::IncompatibleTypes(col))
+                    .map(|array| array.value(row).to_owned())
+            }
+        }
+    };
+}
+
+impl_arrow_cast_bytes!(Utf8Type);
+impl_arrow_cast_bytes!(LargeUtf8Type);
+impl_arrow_cast_bytes!(BinaryType);
+
+macro_rules! impl_from_record_batch_tuple {
     ($($id:ident, )+) => {
         // Implementations for tuples
-        impl<$($id: ArrowPrimitiveType, )+> FromRecordBatchRow for ($($id, )+)
+        impl<$($id: ArrowTypeCast, )+> FromRecordBatchRow for ($($id, )+)
         {
             type Native = ($($id::Native, )+);
 
+            #[inline]
             fn from_record_batch_row(batch: &RecordBatch, row: usize) -> Result<Self::Native, FromRecordBatchError> {
                 let mut idx = 0;
                 Ok((
-                    $(get_value::<$id>(batch, { idx += 1; idx - 1}, row)?,)+
+                    $($id::value_at(batch, { idx += 1; idx - 1}, row)?,)+
                 ))
             }
 
+            #[inline]
             fn is_compatible(batch: &RecordBatch) -> bool {
                 let count = const {
                     let mut cnt = 0;
@@ -170,19 +249,19 @@ macro_rules! from_record_batch_tuple {
     };
 }
 
-from_record_batch_tuple!(A0,);
-from_record_batch_tuple!(A0, A1,);
-from_record_batch_tuple!(A0, A1, A2,);
-from_record_batch_tuple!(A0, A1, A2, A3,);
-from_record_batch_tuple!(A0, A1, A2, A3, A4,);
-from_record_batch_tuple!(A0, A1, A2, A3, A4, A5,);
-from_record_batch_tuple!(A0, A1, A2, A3, A4, A5, A6,);
-from_record_batch_tuple!(A0, A1, A2, A3, A4, A5, A6, A7,);
-from_record_batch_tuple!(A0, A1, A2, A3, A4, A5, A6, A7, A8,);
-from_record_batch_tuple!(A0, A1, A2, A3, A4, A5, A6, A7, A8, A9,);
-from_record_batch_tuple!(A0, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10,);
-from_record_batch_tuple!(A0, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11,);
-from_record_batch_tuple!(A0, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12,);
+impl_from_record_batch_tuple!(A0,);
+impl_from_record_batch_tuple!(A0, A1,);
+impl_from_record_batch_tuple!(A0, A1, A2,);
+impl_from_record_batch_tuple!(A0, A1, A2, A3,);
+impl_from_record_batch_tuple!(A0, A1, A2, A3, A4,);
+impl_from_record_batch_tuple!(A0, A1, A2, A3, A4, A5,);
+impl_from_record_batch_tuple!(A0, A1, A2, A3, A4, A5, A6,);
+impl_from_record_batch_tuple!(A0, A1, A2, A3, A4, A5, A6, A7,);
+impl_from_record_batch_tuple!(A0, A1, A2, A3, A4, A5, A6, A7, A8,);
+impl_from_record_batch_tuple!(A0, A1, A2, A3, A4, A5, A6, A7, A8, A9,);
+impl_from_record_batch_tuple!(A0, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10,);
+impl_from_record_batch_tuple!(A0, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11,);
+impl_from_record_batch_tuple!(A0, A1, A2, A3, A4, A5, A6, A7, A8, A9, A10, A11, A12,);
 
 #[derive(Debug, thiserror::Error)]
 pub enum FromRecordBatchError {
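
Note: the new ArrowTypeCast trait is only wired up for primitive types (via impl_arrow_cast_primitive) and byte-array types (via impl_arrow_cast_bytes). Boolean columns are neither: arrow-rs backs them with BooleanArray rather than a PrimitiveArray or GenericByteArray, so neither macro can cover them. Below is a minimal sketch of how a manual impl could look; it is illustrative only and not part of the patch above (BooleanType and AsArray::as_boolean_opt are existing arrow-rs items, but this impl is a hypothetical extension).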
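// Hypothetical extension, not included in the diff above: support bool
// columns by implementing ArrowTypeCast for BooleanType by hand.
impl ArrowTypeCast for BooleanType {
    // bool is Copy, so no clone() is needed when reading values out.
    type Native = bool;
    const DATA_TYPE: DataType = DataType::Boolean;

    #[inline]
    fn value_at(
        batch: &RecordBatch,
        col: usize,
        row: usize,
    ) -> Result<Self::Native, FromRecordBatchError> {
        batch
            .column(col)
            // BooleanArray has its own downcast helper on AsArray,
            // separate from as_primitive_opt/as_bytes_opt.
            .as_boolean_opt()
            .ok_or_else(|| FromRecordBatchError::IncompatibleTypes(col))
            .map(|array| array.value(row))
    }
}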
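
With such an impl in place, a boolean column could then appear in a row tuple, e.g. .to_rows::<(UInt32Type, Float64Type, BooleanType)>(). The DATA_TYPE constant mirrors the two macros, presumably so is_compatible can check the batch schema purely from the tuple's type parameters.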