Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for dictionary stripes #68

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 82 additions & 2 deletions src/arrow_reader/column/mod.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use std::sync::Arc;

use arrow::datatypes::Field;
use arrow::datatypes::{DataType as ArrowDataType, TimeUnit, UnionMode};
use bytes::Bytes;
use snafu::ResultExt;

use crate::error::{IoSnafu, Result};
use crate::proto::column_encoding::Kind as ColumnEncodingKind;
use crate::proto::stream::Kind;
use crate::proto::{ColumnEncoding, StripeFooter};
use crate::reader::decode::boolean_rle::BooleanIter;
Expand All @@ -26,14 +28,14 @@ pub struct Column {

/// Converts an owned `Column` into an Arrow `Field`.
///
/// The resulting field is always marked nullable, since any ORC column may
/// contain nulls. The Arrow data type is derived via `arrow_data_type`, which
/// also accounts for dictionary-encoded stripes.
impl From<Column> for Field {
    fn from(value: Column) -> Self {
        // Compute the type before moving `value.name` into the new Field.
        let dt = value.arrow_data_type();
        Field::new(value.name, dt, true)
    }
}

/// Converts a borrowed `Column` into an Arrow `Field`.
///
/// Clones the column name since the conversion does not own the column.
/// The field is always marked nullable, matching ORC semantics.
impl From<&Column> for Field {
    fn from(value: &Column) -> Self {
        let dt = value.arrow_data_type();
        Field::new(value.name.clone(), dt, true)
    }
}
Expand Down Expand Up @@ -69,6 +71,84 @@ impl Column {
&self.data_type
}

/// Maps this column's ORC logical type to the corresponding Arrow `DataType`.
///
/// The mapping is computed in two steps: first the ORC type is translated to
/// a base Arrow type, then — if the column's stripe encoding is a dictionary
/// encoding — the base type is wrapped in `Dictionary(UInt64, base)` so the
/// reader can surface dictionary-encoded data without materializing values.
pub fn arrow_data_type(&self) -> ArrowDataType {
    // Step 1: base Arrow type from the ORC logical type alone.
    let value_type = match self.data_type {
        DataType::Boolean { .. } => ArrowDataType::Boolean,
        DataType::Byte { .. } => ArrowDataType::Int8,
        DataType::Short { .. } => ArrowDataType::Int16,
        DataType::Int { .. } => ArrowDataType::Int32,
        DataType::Long { .. } => ArrowDataType::Int64,
        DataType::Float { .. } => ArrowDataType::Float32,
        DataType::Double { .. } => ArrowDataType::Float64,
        // All ORC character types collapse to Utf8; Varchar/Char length
        // limits are not represented in the Arrow type.
        DataType::String { .. } | DataType::Varchar { .. } | DataType::Char { .. } => {
            ArrowDataType::Utf8
        }
        DataType::Binary { .. } => ArrowDataType::Binary,
        DataType::Decimal {
            precision, scale, ..
        } => ArrowDataType::Decimal128(precision as u8, scale as i8),
        // NOTE(review): `precision as u8` / `scale as i8` narrow silently;
        // assumes upstream validation keeps them in range — TODO confirm.
        DataType::Timestamp { .. } => ArrowDataType::Timestamp(TimeUnit::Nanosecond, None),
        DataType::TimestampWithLocalTimezone { .. } => {
            // TODO: get writer timezone
            ArrowDataType::Timestamp(TimeUnit::Nanosecond, None)
        }
        DataType::Date { .. } => ArrowDataType::Date32,
        DataType::Struct { .. } => {
            // Recurse into child columns; every child field is nullable.
            let children = self
                .children()
                .into_iter()
                .map(|col| {
                    let dt = col.arrow_data_type();
                    Field::new(col.name(), dt, true)
                })
                .collect();
            ArrowDataType::Struct(children)
        }
        DataType::List { .. } => {
            // An ORC list has exactly one child column: the element type.
            let children = self.children();
            assert_eq!(children.len(), 1);
            ArrowDataType::new_list(children[0].arrow_data_type(), true)
        }
        DataType::Map { .. } => {
            // An ORC map has exactly two children: key then value.
            // Keys are non-nullable; values may be null.
            let children = self.children();
            assert_eq!(children.len(), 2);
            let key = &children[0];
            let key = key.arrow_data_type();
            let key = Field::new("key", key, false);
            let value = &children[1];
            let value = value.arrow_data_type();
            let value = Field::new("value", value, true);

            // Arrow maps are modeled as a list of "entries" structs;
            // `false` here means keys are not guaranteed sorted.
            let dt = ArrowDataType::Struct(vec![key, value].into());
            let dt = Arc::new(Field::new("entries", dt, true));
            ArrowDataType::Map(dt, false)
        }
        DataType::Union { .. } => {
            let fields = self
                .children()
                .iter()
                .enumerate()
                .map(|(index, variant)| {
                    // Should be safe as limited to 256 variants total (in from_proto)
                    let index = index as u8 as i8;
                    let arrow_dt = variant.arrow_data_type();
                    // Name shouldn't matter here (only ORC struct types give names to subtypes anyway)
                    let field = Arc::new(Field::new(format!("{index}"), arrow_dt, true));
                    (index, field)
                })
                .collect();
            ArrowDataType::Union(fields, UnionMode::Sparse)
        }
    };

    // Step 2: wrap in a Dictionary type when the stripe encoding is
    // dictionary-based. UInt64 keys match the index decoder's key type;
    // NOTE(review): narrower keys may suffice — confirm against decoder.
    match self.encoding().kind() {
        ColumnEncodingKind::Direct | ColumnEncodingKind::DirectV2 => value_type,
        ColumnEncodingKind::Dictionary | ColumnEncodingKind::DictionaryV2 => {
            ArrowDataType::Dictionary(Box::new(ArrowDataType::UInt64), Box::new(value_type))
        }
    }
}

/// Returns the column's name as a borrowed string slice.
pub fn name(&self) -> &str {
    self.name.as_str()
}
Expand Down
8 changes: 2 additions & 6 deletions src/arrow_reader/decoder/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,9 @@ impl MapArrayDecoder {
let reader = stripe.stream_map.get(column, Kind::Length)?;
let lengths = get_rle_reader(column, reader)?;

let keys_field = Field::new("keys", keys_column.data_type().to_arrow_data_type(), false);
let keys_field = Field::new("keys", keys_column.arrow_data_type(), false);
let keys_field = Arc::new(keys_field);
let values_field = Field::new(
"values",
values_column.data_type().to_arrow_data_type(),
true,
);
let values_field = Field::new("values", values_column.arrow_data_type(), true);
let values_field = Arc::new(values_field);

let fields = Fields::from(vec![keys_field, values_field]);
Expand Down
14 changes: 6 additions & 8 deletions src/arrow_reader/decoder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use arrow::array::{ArrayRef, BooleanArray, BooleanBuilder, PrimitiveArray, Primi
use arrow::buffer::NullBuffer;
use arrow::datatypes::{ArrowPrimitiveType, UInt64Type};
use arrow::datatypes::{
Date32Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, SchemaRef,
Date32Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
TimestampNanosecondType,
};
use arrow::record_batch::RecordBatch;
Expand Down Expand Up @@ -230,7 +230,6 @@ fn create_null_buffer(present: Option<Vec<bool>>) -> Option<NullBuffer> {

pub struct NaiveStripeDecoder {
stripe: Stripe,
schema_ref: SchemaRef,
decoders: Vec<Box<dyn ArrayBatchDecoder>>,
index: usize,
batch_size: usize,
Expand Down Expand Up @@ -388,10 +387,10 @@ impl NaiveStripeDecoder {
} else {
//TODO(weny): any better way?
let fields = self
.schema_ref
.fields
.into_iter()
.map(|field| field.name())
.stripe
.columns
.iter()
.map(|col| col.name())
.zip(fields)
.collect::<Vec<_>>();

Expand All @@ -401,7 +400,7 @@ impl NaiveStripeDecoder {
}
}

pub fn new(stripe: Stripe, schema_ref: SchemaRef, batch_size: usize) -> Result<Self> {
pub fn new(stripe: Stripe, batch_size: usize) -> Result<Self> {
let mut decoders = Vec::with_capacity(stripe.columns.len());
let number_of_rows = stripe.number_of_rows;

Expand All @@ -412,7 +411,6 @@ impl NaiveStripeDecoder {

Ok(Self {
stripe,
schema_ref,
decoders,
index: 0,
batch_size,
Expand Down
2 changes: 2 additions & 0 deletions src/arrow_reader/decoder/string.rs
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ impl<T: ByteArrayType> ArrayBatchDecoder for GenericByteArrayDecoder<T> {
batch_size: usize,
parent_present: Option<&[bool]>,
) -> Result<ArrayRef> {
println!("GenericByteArrayDecoder::next_batch");
let array = self.next_byte_batch(batch_size, parent_present)?;
let array = Arc::new(array) as ArrayRef;
Ok(array)
Expand All @@ -169,6 +170,7 @@ impl ArrayBatchDecoder for DictionaryStringArrayDecoder {
batch_size: usize,
parent_present: Option<&[bool]>,
) -> Result<ArrayRef> {
println!("DictionaryStringArrayDecoder::next_batch");
let keys = self
.indexes
.next_primitive_batch(batch_size, parent_present)?;
Expand Down
11 changes: 10 additions & 1 deletion src/arrow_reader/decoder/struct_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,12 @@ impl StructArrayDecoder {
let fields = column
.children()
.into_iter()
.map(Field::from)
.map(|col| {
println!("col {:#?}", col);
let field = Field::from(col);
println!("field {:?}", field);
field
})
.map(Arc::new)
.collect::<Vec<_>>();
let fields = Fields::from(fields);
Expand Down Expand Up @@ -64,6 +69,10 @@ impl ArrayBatchDecoder for StructArrayDecoder {
.collect::<Result<Vec<_>>>()?;

let null_buffer = present.map(NullBuffer::from);
println!(
"next batch fields = {:?}, child_arrays = {:?}, nulls = {:?}",
self.fields, child_arrays, null_buffer
);
let array = StructArray::try_new(self.fields.clone(), child_arrays, null_buffer)
.context(ArrowSnafu)?;
let array = Arc::new(array);
Expand Down
29 changes: 3 additions & 26 deletions src/arrow_reader/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
use std::collections::HashMap;
use std::sync::Arc;

use arrow::datatypes::{Schema, SchemaRef};
use arrow::error::ArrowError;
use arrow::record_batch::{RecordBatch, RecordBatchReader};
use arrow::record_batch::RecordBatch;

pub use self::decoder::NaiveStripeDecoder;
use crate::error::Result;
Expand Down Expand Up @@ -73,10 +71,8 @@ impl<R: ChunkReader> ArrowReaderBuilder<R> {
projected_data_type,
stripe_index: 0,
};
let schema_ref = Arc::new(create_arrow_schema(&cursor));
ArrowReader {
cursor,
schema_ref,
current_stripe: None,
batch_size: self.batch_size,
}
Expand All @@ -101,14 +97,12 @@ impl<R: AsyncChunkReader + 'static> ArrowReaderBuilder<R> {
projected_data_type,
stripe_index: 0,
};
let schema_ref = Arc::new(create_arrow_schema(&cursor));
ArrowStreamReader::new(cursor, self.batch_size, schema_ref)
ArrowStreamReader::new(cursor, self.batch_size)
}
}

pub struct ArrowReader<R> {
cursor: Cursor<R>,
schema_ref: SchemaRef,
current_stripe: Option<Box<dyn Iterator<Item = Result<RecordBatch>> + Send>>,
batch_size: usize,
}
Expand All @@ -124,8 +118,7 @@ impl<R: ChunkReader> ArrowReader<R> {
let stripe = self.cursor.next().transpose()?;
match stripe {
Some(stripe) => {
let decoder =
NaiveStripeDecoder::new(stripe, self.schema_ref.clone(), self.batch_size)?;
let decoder = NaiveStripeDecoder::new(stripe, self.batch_size)?;
self.current_stripe = Some(Box::new(decoder));
self.next().transpose()
}
Expand All @@ -134,22 +127,6 @@ impl<R: ChunkReader> ArrowReader<R> {
}
}

pub fn create_arrow_schema<R>(cursor: &Cursor<R>) -> Schema {
let metadata = cursor
.file_metadata
.user_custom_metadata()
.iter()
.map(|(key, value)| (key.clone(), String::from_utf8_lossy(value).to_string()))
.collect::<HashMap<_, _>>();
cursor.projected_data_type.create_arrow_schema(&metadata)
}

impl<R: ChunkReader> RecordBatchReader for ArrowReader<R> {
fn schema(&self) -> SchemaRef {
self.schema_ref.clone()
}
}

impl<R: ChunkReader> Iterator for ArrowReader<R> {
type Item = std::result::Result<RecordBatch, ArrowError>;

Expand Down
15 changes: 2 additions & 13 deletions src/async_arrow_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use arrow::datatypes::SchemaRef;
use arrow::error::ArrowError;
use arrow::record_batch::RecordBatch;
use futures::future::BoxFuture;
Expand Down Expand Up @@ -61,7 +60,6 @@ pub struct StripeFactory<R> {
pub struct ArrowStreamReader<R: AsyncChunkReader> {
factory: Option<Box<StripeFactory<R>>>,
batch_size: usize,
schema_ref: SchemaRef,
state: StreamState<R>,
}

Expand Down Expand Up @@ -107,19 +105,14 @@ impl<R: AsyncChunkReader + 'static> StripeFactory<R> {
}

impl<R: AsyncChunkReader + 'static> ArrowStreamReader<R> {
pub fn new(cursor: Cursor<R>, batch_size: usize, schema_ref: SchemaRef) -> Self {
pub fn new(cursor: Cursor<R>, batch_size: usize) -> Self {
Self {
factory: Some(Box::new(cursor.into())),
batch_size,
schema_ref,
state: StreamState::Init,
}
}

pub fn schema(&self) -> SchemaRef {
self.schema_ref.clone()
}

fn poll_next_inner(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
Expand Down Expand Up @@ -149,11 +142,7 @@ impl<R: AsyncChunkReader + 'static> ArrowStreamReader<R> {
StreamState::Reading(f) => match ready!(f.poll_unpin(cx)) {
Ok((factory, Some(stripe))) => {
self.factory = Some(Box::new(factory));
match NaiveStripeDecoder::new(
stripe,
self.schema_ref.clone(),
self.batch_size,
) {
match NaiveStripeDecoder::new(stripe, self.batch_size) {
Ok(decoder) => {
self.state = StreamState::Decoding(Box::new(decoder));
}
Expand Down
Loading
Loading