diff --git a/src/arrow_reader/column/mod.rs b/src/arrow_reader/column/mod.rs index 36741328..87da2e94 100644 --- a/src/arrow_reader/column/mod.rs +++ b/src/arrow_reader/column/mod.rs @@ -1,10 +1,12 @@ use std::sync::Arc; use arrow::datatypes::Field; +use arrow::datatypes::{DataType as ArrowDataType, TimeUnit, UnionMode}; use bytes::Bytes; use snafu::ResultExt; use crate::error::{IoSnafu, Result}; +use crate::proto::column_encoding::Kind as ColumnEncodingKind; use crate::proto::stream::Kind; use crate::proto::{ColumnEncoding, StripeFooter}; use crate::reader::decode::boolean_rle::BooleanIter; @@ -26,14 +28,14 @@ pub struct Column { impl From for Field { fn from(value: Column) -> Self { - let dt = value.data_type.to_arrow_data_type(); + let dt = value.arrow_data_type(); Field::new(value.name, dt, true) } } impl From<&Column> for Field { fn from(value: &Column) -> Self { - let dt = value.data_type.to_arrow_data_type(); + let dt = value.arrow_data_type(); Field::new(value.name.clone(), dt, true) } } @@ -69,6 +71,84 @@ impl Column { &self.data_type } + pub fn arrow_data_type(&self) -> ArrowDataType { + let value_type = match self.data_type { + DataType::Boolean { .. } => ArrowDataType::Boolean, + DataType::Byte { .. } => ArrowDataType::Int8, + DataType::Short { .. } => ArrowDataType::Int16, + DataType::Int { .. } => ArrowDataType::Int32, + DataType::Long { .. } => ArrowDataType::Int64, + DataType::Float { .. } => ArrowDataType::Float32, + DataType::Double { .. } => ArrowDataType::Float64, + DataType::String { .. } | DataType::Varchar { .. } | DataType::Char { .. } => { + ArrowDataType::Utf8 + } + DataType::Binary { .. } => ArrowDataType::Binary, + DataType::Decimal { + precision, scale, .. + } => ArrowDataType::Decimal128(precision as u8, scale as i8), + DataType::Timestamp { .. } => ArrowDataType::Timestamp(TimeUnit::Nanosecond, None), + DataType::TimestampWithLocalTimezone { .. 
} => { + // TODO: get writer timezone + ArrowDataType::Timestamp(TimeUnit::Nanosecond, None) + } + DataType::Date { .. } => ArrowDataType::Date32, + DataType::Struct { .. } => { + let children = self + .children() + .into_iter() + .map(|col| { + let dt = col.arrow_data_type(); + Field::new(col.name(), dt, true) + }) + .collect(); + ArrowDataType::Struct(children) + } + DataType::List { .. } => { + let children = self.children(); + assert_eq!(children.len(), 1); + ArrowDataType::new_list(children[0].arrow_data_type(), true) + } + DataType::Map { .. } => { + let children = self.children(); + assert_eq!(children.len(), 2); + let key = &children[0]; + let key = key.arrow_data_type(); + let key = Field::new("key", key, false); + let value = &children[1]; + let value = value.arrow_data_type(); + let value = Field::new("value", value, true); + + let dt = ArrowDataType::Struct(vec![key, value].into()); + let dt = Arc::new(Field::new("entries", dt, true)); + ArrowDataType::Map(dt, false) + } + DataType::Union { .. 
} => { + let fields = self + .children() + .iter() + .enumerate() + .map(|(index, variant)| { + // Should be safe as limited to 256 variants total (in from_proto) + let index = index as u8 as i8; + let arrow_dt = variant.arrow_data_type(); + // Name shouldn't matter here (only ORC struct types give names to subtypes anyway) + let field = Arc::new(Field::new(format!("{index}"), arrow_dt, true)); + (index, field) + }) + .collect(); + ArrowDataType::Union(fields, UnionMode::Sparse) + } + }; + + match self.encoding().kind() { + ColumnEncodingKind::Direct | ColumnEncodingKind::DirectV2 => value_type, + ColumnEncodingKind::Dictionary | ColumnEncodingKind::DictionaryV2 => { + ArrowDataType::Dictionary(Box::new(ArrowDataType::UInt64), Box::new(value_type)) + } + } + } + pub fn name(&self) -> &str { &self.name } diff --git a/src/arrow_reader/decoder/map.rs b/src/arrow_reader/decoder/map.rs index 492d9bf5..4fcf88fe 100644 --- a/src/arrow_reader/decoder/map.rs +++ b/src/arrow_reader/decoder/map.rs @@ -36,13 +36,9 @@ impl MapArrayDecoder { let reader = stripe.stream_map.get(column, Kind::Length)?; let lengths = get_rle_reader(column, reader)?; - let keys_field = Field::new("keys", keys_column.data_type().to_arrow_data_type(), false); + let keys_field = Field::new("keys", keys_column.arrow_data_type(), false); let keys_field = Arc::new(keys_field); - let values_field = Field::new( - "values", - values_column.data_type().to_arrow_data_type(), - true, - ); + let values_field = Field::new("values", values_column.arrow_data_type(), true); let values_field = Arc::new(values_field); let fields = Fields::from(vec![keys_field, values_field]); diff --git a/src/arrow_reader/decoder/mod.rs b/src/arrow_reader/decoder/mod.rs index 0febbdb0..7edd4f98 100644 --- a/src/arrow_reader/decoder/mod.rs +++ b/src/arrow_reader/decoder/mod.rs @@ -4,7 +4,7 @@ use arrow::array::{ArrayRef, BooleanArray, BooleanBuilder, PrimitiveArray, Primi use arrow::buffer::NullBuffer; use 
arrow::datatypes::{ArrowPrimitiveType, UInt64Type}; use arrow::datatypes::{ - Date32Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, SchemaRef, + Date32Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, TimestampNanosecondType, }; use arrow::record_batch::RecordBatch; @@ -230,7 +230,6 @@ fn create_null_buffer(present: Option>) -> Option { pub struct NaiveStripeDecoder { stripe: Stripe, - schema_ref: SchemaRef, decoders: Vec>, index: usize, batch_size: usize, @@ -388,10 +387,10 @@ impl NaiveStripeDecoder { } else { //TODO(weny): any better way? let fields = self - .schema_ref - .fields - .into_iter() - .map(|field| field.name()) + .stripe + .columns + .iter() + .map(|col| col.name()) .zip(fields) .collect::>(); @@ -401,7 +400,7 @@ impl NaiveStripeDecoder { } } - pub fn new(stripe: Stripe, schema_ref: SchemaRef, batch_size: usize) -> Result { + pub fn new(stripe: Stripe, batch_size: usize) -> Result { let mut decoders = Vec::with_capacity(stripe.columns.len()); let number_of_rows = stripe.number_of_rows; @@ -412,7 +411,6 @@ impl NaiveStripeDecoder { Ok(Self { stripe, - schema_ref, decoders, index: 0, batch_size, diff --git 
a/src/arrow_reader/decoder/struct_decoder.rs b/src/arrow_reader/decoder/struct_decoder.rs index b6f38426..48bac5a0 100644 --- a/src/arrow_reader/decoder/struct_decoder.rs +++ b/src/arrow_reader/decoder/struct_decoder.rs @@ -36,7 +36,7 @@ impl StructArrayDecoder { let fields = column .children() .into_iter() .map(Field::from) .map(Arc::new) .collect::>(); let fields = Fields::from(fields); @@ -64,6 +64,6 @@ impl ArrayBatchDecoder for StructArrayDecoder { .collect::>>()?; let null_buffer = present.map(NullBuffer::from); let array = StructArray::try_new(self.fields.clone(), child_arrays, null_buffer) .context(ArrowSnafu)?; let array = Arc::new(array); diff --git a/src/arrow_reader/mod.rs b/src/arrow_reader/mod.rs index 5e5c6346..d93a23c7 100644 --- a/src/arrow_reader/mod.rs +++ b/src/arrow_reader/mod.rs @@ -1,9 +1,7 @@ -use std::collections::HashMap; use std::sync::Arc; -use arrow::datatypes::{Schema, SchemaRef}; use arrow::error::ArrowError; -use arrow::record_batch::{RecordBatch, RecordBatchReader}; +use arrow::record_batch::RecordBatch; pub use self::decoder::NaiveStripeDecoder; use crate::error::Result; @@ -73,10 +71,8 @@ impl ArrowReaderBuilder { projected_data_type, stripe_index: 0, }; - let schema_ref = Arc::new(create_arrow_schema(&cursor)); ArrowReader { cursor, - schema_ref, current_stripe: None, batch_size: self.batch_size, } @@ -101,14 +97,12 @@ impl ArrowReaderBuilder { projected_data_type, stripe_index: 0, }; - let schema_ref = Arc::new(create_arrow_schema(&cursor)); - ArrowStreamReader::new(cursor, self.batch_size, schema_ref) + ArrowStreamReader::new(cursor, self.batch_size) } } pub struct ArrowReader { cursor: Cursor, - schema_ref: SchemaRef, current_stripe: Option> + Send>>, batch_size: usize, } @@ -124,8 
+118,7 @@ impl ArrowReader { let stripe = self.cursor.next().transpose()?; match stripe { Some(stripe) => { - let decoder = - NaiveStripeDecoder::new(stripe, self.schema_ref.clone(), self.batch_size)?; + let decoder = NaiveStripeDecoder::new(stripe, self.batch_size)?; self.current_stripe = Some(Box::new(decoder)); self.next().transpose() } @@ -134,22 +127,6 @@ impl ArrowReader { } } -pub fn create_arrow_schema(cursor: &Cursor) -> Schema { - let metadata = cursor - .file_metadata - .user_custom_metadata() - .iter() - .map(|(key, value)| (key.clone(), String::from_utf8_lossy(value).to_string())) - .collect::>(); - cursor.projected_data_type.create_arrow_schema(&metadata) -} - -impl RecordBatchReader for ArrowReader { - fn schema(&self) -> SchemaRef { - self.schema_ref.clone() - } -} - impl Iterator for ArrowReader { type Item = std::result::Result; diff --git a/src/async_arrow_reader.rs b/src/async_arrow_reader.rs index 16d191a0..3529e654 100644 --- a/src/async_arrow_reader.rs +++ b/src/async_arrow_reader.rs @@ -4,7 +4,6 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use arrow::datatypes::SchemaRef; use arrow::error::ArrowError; use arrow::record_batch::RecordBatch; use futures::future::BoxFuture; @@ -61,7 +60,6 @@ pub struct StripeFactory { pub struct ArrowStreamReader { factory: Option>>, batch_size: usize, - schema_ref: SchemaRef, state: StreamState, } @@ -107,19 +105,14 @@ impl StripeFactory { } impl ArrowStreamReader { - pub fn new(cursor: Cursor, batch_size: usize, schema_ref: SchemaRef) -> Self { + pub fn new(cursor: Cursor, batch_size: usize) -> Self { Self { factory: Some(Box::new(cursor.into())), batch_size, - schema_ref, state: StreamState::Init, } } - pub fn schema(&self) -> SchemaRef { - self.schema_ref.clone() - } - fn poll_next_inner( mut self: Pin<&mut Self>, cx: &mut Context<'_>, @@ -149,11 +142,7 @@ impl ArrowStreamReader { StreamState::Reading(f) => match ready!(f.poll_unpin(cx)) { Ok((factory, Some(stripe))) => { 
self.factory = Some(Box::new(factory)); - match NaiveStripeDecoder::new( - stripe, - self.schema_ref.clone(), - self.batch_size, - ) { + match NaiveStripeDecoder::new(stripe, self.batch_size) { Ok(decoder) => { self.state = StreamState::Decoding(Box::new(decoder)); } diff --git a/src/schema.rs b/src/schema.rs index ae8c96e7..07db7b84 100644 --- a/src/schema.rs +++ b/src/schema.rs @@ -1,6 +1,4 @@ -use std::collections::HashMap; use std::fmt::Display; -use std::sync::Arc; use snafu::{ensure, OptionExt}; @@ -8,8 +6,6 @@ use crate::error::{NoTypesSnafu, Result, UnexpectedSnafu}; use crate::projection::ProjectionMask; use crate::proto; -use arrow::datatypes::{DataType as ArrowDataType, Field, Schema, TimeUnit, UnionMode}; - /// Represents the root data type of the ORC file. Contains multiple named child types /// which map to the columns available. Allows projecting only specific columns from /// the base schema. @@ -37,19 +33,6 @@ impl RootDataType { &self.children } - /// Convert into an Arrow schema. - pub fn create_arrow_schema(&self, user_metadata: &HashMap) -> Schema { - let fields = self - .children - .iter() - .map(|col| { - let dt = col.data_type().to_arrow_data_type(); - Field::new(col.name(), dt, true) - }) - .collect::>(); - Schema::new_with_metadata(fields, user_metadata.clone()) - } - /// Create new root data type based on mask of columns to project. 
pub fn project(&self, mask: &ProjectionMask) -> Self { // TODO: fix logic here to account for nested projection @@ -304,7 +287,7 @@ Kind::Long => Self::Long { column_index }, Kind::Float => Self::Float { column_index }, Kind::Double => Self::Double { column_index }, Kind::String => Self::String { column_index }, Kind::Binary => Self::Binary { column_index }, Kind::Timestamp => Self::Timestamp { column_index }, Kind::List => { @@ -395,70 +381,6 @@ }; Ok(dt) } - - pub fn to_arrow_data_type(&self) -> ArrowDataType { - match self { - DataType::Boolean { .. } => ArrowDataType::Boolean, - DataType::Byte { .. } => ArrowDataType::Int8, - DataType::Short { .. } => ArrowDataType::Int16, - DataType::Int { .. } => ArrowDataType::Int32, - DataType::Long { .. } => ArrowDataType::Int64, - DataType::Float { .. } => ArrowDataType::Float32, - DataType::Double { .. } => ArrowDataType::Float64, - DataType::String { .. } | DataType::Varchar { .. } | DataType::Char { .. } => { - ArrowDataType::Utf8 - } - DataType::Binary { .. } => ArrowDataType::Binary, - DataType::Decimal { - precision, scale, .. - } => ArrowDataType::Decimal128(*precision as u8, *scale as i8), - DataType::Timestamp { .. } => ArrowDataType::Timestamp(TimeUnit::Nanosecond, None), - DataType::TimestampWithLocalTimezone { .. } => { - // TODO: get writer timezone - ArrowDataType::Timestamp(TimeUnit::Nanosecond, None) - } - DataType::Date { .. } => ArrowDataType::Date32, - DataType::Struct { children, .. } => { - let children = children - .iter() - .map(|col| { - let dt = col.data_type().to_arrow_data_type(); - Field::new(col.name(), dt, true) - }) - .collect(); - ArrowDataType::Struct(children) - } - DataType::List { child, .. } => { - let child = child.to_arrow_data_type(); - ArrowDataType::new_list(child, true) - } - DataType::Map { key, value, .. 
} => { - let key = key.to_arrow_data_type(); - let key = Field::new("key", key, false); - let value = value.to_arrow_data_type(); - let value = Field::new("value", value, true); - - let dt = ArrowDataType::Struct(vec![key, value].into()); - let dt = Arc::new(Field::new("entries", dt, true)); - ArrowDataType::Map(dt, false) - } - DataType::Union { variants, .. } => { - let fields = variants - .iter() - .enumerate() - .map(|(index, variant)| { - // Should be safe as limited to 256 variants total (in from_proto) - let index = index as u8 as i8; - let arrow_dt = variant.to_arrow_data_type(); - // Name shouldn't matter here (only ORC struct types give names to subtypes anyway) - let field = Arc::new(Field::new(format!("{index}"), arrow_dt, true)); - (index, field) - }) - .collect(); - ArrowDataType::Union(fields, UnionMode::Sparse) - } - } - } } impl Display for DataType { diff --git a/tests/integration/main.rs b/tests/integration/main.rs index 261262b3..1293798f 100644 --- a/tests/integration/main.rs +++ b/tests/integration/main.rs @@ -85,7 +85,7 @@ fn metaData() { test_expected_file("TestOrcFile.metaData"); } #[test] -#[ignore] // TODO: Why? +#[ignore] // TODO: {} instead of [{}] and decimal representation differs fn test1() { test_expected_file("TestOrcFile.test1"); }