diff --git a/arrow-avro/Cargo.toml b/arrow-avro/Cargo.toml index c103c2ecc0f3..a9c237008140 100644 --- a/arrow-avro/Cargo.toml +++ b/arrow-avro/Cargo.toml @@ -34,7 +34,7 @@ path = "src/lib.rs" bench = false [features] -default = ["deflate", "snappy", "zstd"] +default = ["deflate", "snappy", "zstd", "bzip2", "xz"] deflate = ["flate2"] snappy = ["snap", "crc"] @@ -42,14 +42,19 @@ snappy = ["snap", "crc"] arrow-schema = { workspace = true } arrow-buffer = { workspace = true } arrow-array = { workspace = true } +arrow-data = { workspace = true } serde_json = { version = "1.0", default-features = false, features = ["std"] } serde = { version = "1.0.188", features = ["derive"] } flate2 = { version = "1.0", default-features = false, features = ["rust_backend"], optional = true } snap = { version = "1.0", default-features = false, optional = true } zstd = { version = "0.13", default-features = false, optional = true } +bzip2 = { version = "0.4.4", default-features = false, optional = true } +xz = { version = "0.1.0", default-features = false, optional = true } crc = { version = "3.0", optional = true } [dev-dependencies] -rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] } - +bytes = "1.4" +futures = "0.3" +tokio = { version = "1.27", default-features = false, features = ["io-util", "macros", "rt-multi-thread"] } +rand = { version = "0.9", default-features = false, features = ["std", "std_rng", "thread_rng"] } \ No newline at end of file diff --git a/arrow-avro/src/codec.rs b/arrow-avro/src/codec.rs index 2ac1ad038bd7..76e09dc3f644 100644 --- a/arrow-avro/src/codec.rs +++ b/arrow-avro/src/codec.rs @@ -15,48 +15,63 @@ // specific language governing permissions and limitations // under the License. -use crate::schema::{Attributes, ComplexType, PrimitiveType, Record, Schema, TypeName}; +use crate::schema::{ + Array, Attributes, ComplexType, Enum, Fixed, Map, PrimitiveType, Record, RecordField, Schema, + Type, TypeName, +}; use arrow_schema::{ - ArrowError, DataType, Field, FieldRef, IntervalUnit, SchemaBuilder, SchemaRef, TimeUnit, + ArrowError, DataType, Field, Fields, IntervalUnit, TimeUnit, DECIMAL128_MAX_PRECISION, + DECIMAL128_MAX_SCALE, }; -use std::borrow::Cow; use std::collections::HashMap; use std::sync::Arc; /// Avro types are not nullable, with nullability instead encoded as a union /// where one of the variants is the null type. -/// -/// To accommodate this we special case two-variant unions where one of the -/// variants is the null type, and use this to derive arrow's notion of nullability -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum Nullability { - /// The nulls are encoded as the first union variant + /// The nulls are encoded as the first union variant => `[ "null", T ]` NullFirst, - /// The nulls are encoded as the second union variant + /// The nulls are encoded as the second union variant => `[ T, "null" ]` NullSecond, } /// An Avro datatype mapped to the arrow data model #[derive(Debug, Clone)] pub struct AvroDataType { - nullability: Option, - metadata: HashMap, - codec: Codec, + pub nullability: Option, + pub metadata: HashMap, + pub codec: Codec, } impl AvroDataType { - /// Returns an arrow [`Field`] with the given name - pub fn field_with_name(&self, name: &str) -> Field { - let d = self.codec.data_type(); - Field::new(name, d, self.nullability.is_some()).with_metadata(self.metadata.clone()) + /// Create a new AvroDataType with the given parts. 
+    pub fn new(
+        codec: Codec,
+        nullability: Option<Nullability>,
+        metadata: HashMap<String, String>,
+    ) -> Self {
+        AvroDataType {
+            codec,
+            nullability,
+            metadata,
+        }
+    }
+
+    /// Create a new AvroDataType from a `Codec`, with default (no) nullability and empty metadata.
+    pub fn from_codec(codec: Codec) -> Self {
+        Self::new(codec, None, Default::default())
     }
 
-    pub fn codec(&self) -> &Codec {
-        &self.codec
+    /// Converts this datatype into its Avro [`Schema`] representation
+    fn to_schema<'a>(&self) -> Schema<'a> {
+        self.codec.schema(self.metadata.clone(), self.nullability).unwrap()
     }
 
-    pub fn nullability(&self) -> Option<Nullability> {
-        self.nullability
+    /// Returns an arrow [`Field`] with the given name, applying `nullability` if present.
+    pub fn field_with_name(&self, name: &str) -> Field {
+        let is_nullable = self.nullability.is_some();
+        Field::new(name, self.codec.data_type(), is_nullable).with_metadata(self.metadata.clone())
     }
 }
@@ -65,12 +80,26 @@ impl AvroDataType {
 pub struct AvroField {
     name: String,
     data_type: AvroDataType,
+    default: Option<serde_json::Value>,
 }
 
 impl AvroField {
     /// Returns the arrow [`Field`]
     pub fn field(&self) -> Field {
-        self.data_type.field_with_name(&self.name)
+        let mut fld = self.data_type.field_with_name(&self.name);
+        if let Some(def_val) = &self.default {
+            if !def_val.is_null() {
+                let mut md = fld.metadata().clone();
+                md.insert("avro.default".to_string(), def_val.to_string());
+                fld = fld.with_metadata(md);
+            }
+        }
+        fld
+    }
+
+    /// Returns the Avro [`Schema`] describing this field's datatype
+    pub fn to_schema<'a>(&self) -> Schema<'a> {
+        self.data_type.to_schema()
     }
 
     /// Returns the [`AvroDataType`]
@@ -78,6 +107,7 @@ impl AvroField {
         &self.data_type
     }
 
+    /// Returns the name of this field
    pub fn name(&self) -> &str {
         &self.name
     }
@@ -91,9 +121,10 @@ impl<'a> TryFrom<&Schema<'a>> for AvroField {
             Schema::Complex(ComplexType::Record(r)) => {
                 let mut resolver = Resolver::default();
                 let data_type = make_data_type(schema, None, &mut resolver)?;
-                Ok(AvroField {
+                Ok(Self {
                     data_type,
                     name: r.name.to_string(),
+                    default: None,
                 })
             }
             _ => Err(ArrowError::ParseError(format!(
@@ -103,11 +134,119 @@ impl<'a> TryFrom<&Schema<'a>> for AvroField {
     }
 }
 
+/// Builds an Avro [`RecordField`] from an arrow [`Field`]
+#[derive(Debug, Clone)]
+pub struct RecordFieldBuilder<'a> {
+    pub impala_mode: &'a bool,
+    pub field: &'a Field,
+}
+
+impl<'a> RecordFieldBuilder<'a> {
+
+    /// Sets Impala-compatible null ordering, encoding nullable fields as `[T, "null"]` rather than `["null", T]`
+    pub fn with_impala_mode(mut self, impala_mode: &'a bool) -> Self {
+        self.impala_mode = impala_mode;
+        self
+    }
+
+    /// Consume this builder, yielding the Avro [`RecordField`]
+    pub fn finish(self) -> Result<RecordField<'a>, ArrowError> {
+        let nullability = if *self.impala_mode {
+            Nullability::NullSecond
+        } else {
+            Nullability::NullFirst
+        };
+        let nullable = if self.field.is_nullable() {
+            Some(nullability)
+        } else {
+            None
+        };
+        let schema = Codec::from_field(self.field, nullability)?
+            .schema(self.field.metadata().clone(), nullable)?;
+        let default_val = self.field
+            .metadata()
+            .get("avro.default")
+            .and_then(|s| serde_json::from_str(s).ok());
+        Ok(RecordField {
+            name: self.field.name(),
+            doc: None,
+            aliases: vec![],
+            r#type: schema,
+            default: default_val,
+        })
+    }
+}
+
+impl<'a> From<&'a Field> for RecordFieldBuilder<'a> {
+    fn from(field: &'a Field) -> Self {
+        Self {
+            impala_mode: &false,
+            field,
+        }
+    }
+}
+
+/// Builds an Avro record [`Schema`] from an arrow [`arrow_schema::Schema`]
+#[derive(Debug, Clone)]
+pub struct SchemaBuilder<'a> {
+    pub impala_mode: &'a bool,
+    pub fields: &'a Fields,
+    pub name: &'a str,
+    pub namespace: Option<&'a str>,
+}
+
+impl<'a> SchemaBuilder<'a> {
+
+    /// Sets Impala-compatible null ordering, encoding nullable fields as `[T, "null"]` rather than `["null", T]`
+    pub fn with_impala_mode(mut self, impala_mode: &'a bool) -> Self {
+        self.impala_mode = impala_mode;
+        self
+    }
+
+    /// Sets the record name of the generated Avro schema
+    pub fn with_name(mut self, name: &'a str) -> Self {
+        self.name = name;
+        self
+    }
+
+    /// Sets the record namespace of the generated Avro schema
+    pub fn with_namespace(mut self, namespace: Option<&'a str>) -> Self {
+        self.namespace = namespace;
+        self
+    }
+
+    /// Consume this [`SchemaBuilder`] yielding the final Avro [`Schema`]
+    pub fn finish(self) -> Result<Schema<'a>, ArrowError> {
+        let record_fields = self
+            .fields
+            .iter()
+            .map(|fref| RecordFieldBuilder::from(fref.as_ref())
+                .with_impala_mode(&self.impala_mode)
+                .finish()
+            ).collect::<Result<Vec<_>, _>>()?;
+        let record = Record {
+            name: self.name,
+            namespace: self.namespace,
+            doc: None,
+            aliases: vec![],
+            fields: record_fields,
+            attributes: Default::default(),
+        };
+        Ok(Schema::Complex(ComplexType::Record(record)))
+    }
+}
+
+impl<'a> From<&'a arrow_schema::Schema> for SchemaBuilder<'a> {
+    fn from(schema: &'a arrow_schema::Schema) -> Self {
+        Self {
+            impala_mode: &false,
+            fields: &schema.fields(),
+            name: "",
+            namespace: Some(""),
+        }
+    }
+}
+
 /// An Avro encoding
-///
-///
 #[derive(Debug, Clone)]
 pub enum Codec {
+    /// Primitive
     Null,
     Boolean,
     Int32,
@@ -115,22 +254,178 @@ pub enum Codec {
     Float32,
     Float64,
     Binary,
-    Utf8,
+    String,
+    /// Complex
+    Record(Arc<[AvroField]>),
+    Enum(Arc<[String]>, Arc<[i32]>),
+    Array(Arc<AvroDataType>),
+    Map(Arc<AvroDataType>),
+    Fixed(i32),
+    /// Logical
+    Decimal(usize, Option<usize>, Option<usize>),
+    Uuid,
     Date32,
     TimeMillis,
     TimeMicros,
-    /// TimestampMillis(is_utc)
     TimestampMillis(bool),
-    /// TimestampMicros(is_utc)
     TimestampMicros(bool),
-    Fixed(i32),
-    List(Arc<AvroDataType>),
-    Struct(Arc<[AvroField]>),
-    Interval,
+    Duration,
 }
 
 impl Codec {
-    fn data_type(&self) -> DataType {
+
+    /// Attempts to derive the Avro [`Codec`] for an arrow [`Field`]
+    fn from_field(field: &Field, nullability_type: Nullability) -> Result<Self, ArrowError> {
+        let metadata = field.metadata();
+        match field.data_type() {
+            // Primitive Types
+            DataType::Null => Ok(Self::Null),
+            DataType::Boolean => Ok(Self::Boolean),
+            DataType::Int8 | DataType::Int16 | DataType::Int32 => Ok(Self::Int32),
+            DataType::Int64 => Ok(Self::Int64),
+            DataType::Float32 => Ok(Self::Float32),
+            DataType::Float64 => Ok(Self::Float64),
+            DataType::Binary | DataType::LargeBinary => Ok(Self::Binary),
+            DataType::Utf8 | DataType::LargeUtf8 => Ok(Self::String),
+            // Complex Types
+            DataType::Struct(fields) => {
+                let avro_fields: Vec<AvroField> = fields
+                    .iter()
+                    .map(|fref| {
+                        let child_codec = Codec::from_field(fref.as_ref(), nullability_type)?;
+                        let default_val = fref
+                            .metadata()
+                            .get("avro.default")
+                            .and_then(|s| serde_json::from_str(s).ok());
+                        let nullability = if fref.is_nullable() {
+                            Some(nullability_type)
+                        } else {
+                            None
+                        };
+                        Ok(AvroField {
+                            name: fref.name().clone(),
+                            data_type: AvroDataType::new(child_codec, nullability, fref.metadata().clone()),
+                            default: default_val,
+                        })
+                    })
+                    .collect::<Result<_, ArrowError>>()?;
+                Ok(Self::Record(Arc::from(avro_fields)))
+            }
+            DataType::Dictionary(key_type, value_type) => {
+                let valid_key = matches!(
+                    key_type.as_ref(),
+                    DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64
+                );
+                let valid_val = matches!(
+                    value_type.as_ref(),
+                    DataType::Utf8 | DataType::LargeUtf8
+                );
+                match (valid_key && valid_val, metadata.get("avro.enum.symbols")) {
+                    (false, _) => Ok(Self::String),
+                    (true, None) => Ok(Self::String),
+                    (true, Some(sym_json_str)) => {
+                        let parsed: serde_json::Value = serde_json::from_str(sym_json_str)
+                            .map_err(|e| ArrowError::ParseError(format!(
+                                "Invalid JSON in avro.enum.symbols: {e}"
+                            )))?;
+                        if let Some(arr) = parsed.as_array() {
+                            let symbols: Vec<String> = arr
+                                .iter()
.filter_map(|v| v.as_str()) + .map(|s| s.to_string()) + .collect(); + Ok(Self::Enum(Arc::from(symbols), Arc::from(vec![]))) + } else { + Err(ArrowError::ParseError( + "Expected JSON array for avro.enum.symbols".to_string(), + )) + } + } + } + } + DataType::List(child_field) | DataType::LargeList(child_field) => { + let nullability = if child_field.is_nullable() { + Some(nullability_type) + } else { + None + }; + let child_codec = Codec::from_field(child_field.as_ref(), nullability_type)?; + Ok(Self::Array( + Arc::new(AvroDataType::new( + child_codec, + nullability, + child_field.metadata().clone()), + ) + )) + } + DataType::FixedSizeList(child_field, _sz) => { + let nullability = if child_field.is_nullable() { + Some(nullability_type) + } else { + None + }; + let child_codec = Codec::from_field(child_field.as_ref(), nullability_type)?; + Ok(Self::Array( + Arc::new(AvroDataType::new( + child_codec, + nullability, + child_field.metadata().clone()), + ) + )) + } + DataType::Map(entry_field, _keys_sorted) => match entry_field.data_type() { + DataType::Struct(children) if children.len() == 2 => { + let value_field = &children[1]; + let nullability = if value_field.is_nullable() { + Some(nullability_type) + } else { + None + }; + let val_codec = Codec::from_field(value_field, nullability_type)?; + Ok(Self::Map(Arc::new(AvroDataType::new( + val_codec, + nullability, + value_field.metadata().clone()), + ))) + } + _ => Ok(Self::String), + }, + DataType::FixedSizeBinary(n) => { + let logical_type = metadata.get("logicalType").map(|s| s.as_str()); + match (*n, logical_type) { + (16, Some("uuid")) => Ok(Self::Uuid), + (12, Some("duration")) => Ok(Self::Duration), + _ => Ok(Self::Fixed(*n)), + } + } + // Logical Types + DataType::Interval(IntervalUnit::MonthDayNano) => Ok(Self::Duration), + DataType::Decimal128(p, s) => { + Ok(Self::Decimal(*p as usize, Some(*s as usize), Some(16))) + } + DataType::Decimal256(p, s) => { + Ok(Self::Decimal(*p as usize, Some(*s as usize), Some(32))) + } + DataType::Date32 => Ok(Self::Date32), + DataType::Time32(TimeUnit::Millisecond) => Ok(Self::TimeMillis), + DataType::Time64(TimeUnit::Microsecond) => Ok(Self::TimeMicros), + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { + let is_utc = tz_opt.as_deref() == Some("+00:00"); + Ok(Self::TimestampMillis(is_utc)) + } + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { + let is_utc = tz_opt.as_deref() == Some("+00:00"); + Ok(Self::TimestampMicros(is_utc)) + } + other => { + Err(ArrowError::AvroError(format!( + "Unrecognized Avro logicalType={other}") + )) + } + } + } + + /// Convert this to an Arrow `DataType` + pub(crate) fn data_type(&self) -> DataType { match self { Self::Null => DataType::Null, Self::Boolean => DataType::Boolean, @@ -139,22 +434,333 @@ impl Codec { Self::Float32 => DataType::Float32, Self::Float64 => DataType::Float64, Self::Binary => DataType::Binary, - Self::Utf8 => DataType::Utf8, + Self::String => DataType::Utf8, + Self::Record(fields) => { + let arrow_fields: Vec = fields.iter().map(|f| f.field()).collect(); + DataType::Struct(arrow_fields.into()) + } + Self::Enum(_, _) => { + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)) + } + Self::Array(child_type) => { + let child_dt = child_type.codec.data_type(); + let child_md = child_type.metadata.to_owned(); + let child_field = + Field::new(Field::LIST_FIELD_DEFAULT_NAME, child_dt, true).with_metadata(child_md); + DataType::List(Arc::new(child_field)) + } + Self::Map(value_type) => { + let val_dt = 
value_type.codec.data_type(); + let val_md = value_type.metadata.to_owned(); + let val_field = Field::new("value", val_dt, true).with_metadata(val_md); + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + val_field, + ])), + false, + )), + false, + ) + } + Self::Fixed(sz) => DataType::FixedSizeBinary(*sz), + Self::Decimal(precision, scale, size_opt) => { + let p = *precision as u8; + let s = scale.unwrap_or(0) as i8; + let too_large_for_128 = match *size_opt { + Some(sz) => sz > 16, + None => { + (p as usize) > DECIMAL128_MAX_PRECISION as usize + || (s as usize) > DECIMAL128_MAX_SCALE as usize + } + }; + if too_large_for_128 { + DataType::Decimal256(p, s) + } else { + DataType::Decimal128(p, s) + } + } + Self::Uuid => DataType::FixedSizeBinary(16), Self::Date32 => DataType::Date32, Self::TimeMillis => DataType::Time32(TimeUnit::Millisecond), Self::TimeMicros => DataType::Time64(TimeUnit::Microsecond), + Self::TimestampMillis(is_utc) => DataType::Timestamp( + TimeUnit::Millisecond, + is_utc.then(|| "+00:00".into()), + ), + Self::TimestampMicros(is_utc) => DataType::Timestamp( + TimeUnit::Microsecond, + is_utc.then(|| "+00:00".into()), + ), + Self::Duration => DataType::Interval(IntervalUnit::MonthDayNano), + } + } + pub(crate) fn schema<'a>( + &self, + metadata: HashMap, + nullability: Option, + ) -> Result, ArrowError> { + let base = match self { + Self::Null => Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + Self::Boolean => Schema::TypeName(TypeName::Primitive(PrimitiveType::Boolean)), + Self::Int32 => Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + Self::Int64 => Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)), + Self::Float32 => Schema::TypeName(TypeName::Primitive(PrimitiveType::Float)), + Self::Float64 => Schema::TypeName(TypeName::Primitive(PrimitiveType::Double)), + Self::Binary => Schema::TypeName(TypeName::Primitive(PrimitiveType::Bytes)), + Self::String => Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), + Self::Record(fields) => { + let record_fields = fields + .iter() + .map(|field| RecordField { + name: Box::leak(field.name().to_string().into_boxed_str()), + doc: None, + aliases: vec![], + r#type: field.data_type.to_schema(), + default: field.default.clone(), + }) + .collect::>(); + // TODO: Make metadata part of lifetime + let record_name = metadata + .get("avro.record.name") + .cloned() + .unwrap_or_else(|| "record".to_string()); + let record_namespace = metadata + .get("avro.record.namespace") + .cloned(); + let mut attributes = Attributes::default(); + copy_metadata_to_attributes(&metadata, &mut attributes); + Schema::Complex(ComplexType::Record(Record { + name: Box::leak(record_name.into_boxed_str()), + namespace: record_namespace.map(|ns| { + let leaked = Box::leak(ns.into_boxed_str()); + leaked as &'a str + }), + doc: None, + aliases: vec![], + fields: record_fields, + attributes, + })) + } + Self::Enum(symbols, _ordinals) => { + let enum_name = metadata + .get("avro.enum.name") + .cloned() + .unwrap_or_else(|| "enum".to_string()); + let enum_namespace = metadata + .get("avro.enum.namespace") + .cloned(); + let mut attributes = Attributes::default(); + copy_metadata_to_attributes(&metadata, &mut attributes); + let mut leaked_syms = Vec::with_capacity(symbols.len()); + for sym in symbols.iter() { + let leaked: &'a str = Box::leak(sym.clone().into_boxed_str()); + leaked_syms.push(leaked); + } + Schema::Complex(ComplexType::Enum(Enum { + 
name: Box::leak(enum_name.into_boxed_str()), + namespace: enum_namespace.map(|ns| { + let leaked = Box::leak(ns.into_boxed_str()); + leaked as &'a str + }), + doc: None, + aliases: vec![], + symbols: leaked_syms, + default: None, + attributes, + })) + } + Self::Array(child) => { + let items_schema = child.to_schema(); + let mut attributes = Attributes::default(); + copy_metadata_to_attributes(&metadata, &mut attributes); + Schema::Complex(ComplexType::Array(Array { + items: Box::new(items_schema), + attributes, + })) + } + Self::Map(value_type) => { + let value_schema = value_type.to_schema(); + let mut attributes = Attributes::default(); + copy_metadata_to_attributes(&metadata, &mut attributes); + Schema::Complex(ComplexType::Map(Map { + values: Box::new(value_schema), + attributes, + })) + } + Self::Fixed(size) => { + let fixed_name = metadata + .get("avro.fixed.name") + .cloned() + .unwrap_or_else(|| format!("fixed_{size}")); + let fixed_namespace = metadata + .get("avro.fixed.namespace") + .cloned(); + let mut attributes = Attributes::default(); + copy_metadata_to_attributes(&metadata, &mut attributes); + Schema::Complex(ComplexType::Fixed(Fixed { + name: Box::leak(fixed_name.into_boxed_str()), + namespace: fixed_namespace.map(|ns| { + let leaked = Box::leak(ns.into_boxed_str()); + leaked as &'a str + }), + aliases: vec![], + size: *size as usize, + attributes, + })) + } + Self::Decimal(precision, scale, size_opt) => { + let p = *precision; + let s = scale.unwrap_or(0); + let mut attrs = Attributes { + logical_type: Some("decimal"), + additional: HashMap::from([ + ("precision", serde_json::Value::Number(p.into())), + ("scale", serde_json::Value::Number(s.into())), + ]), + }; + copy_metadata_to_attributes(&metadata, &mut attrs); + if let Some(size) = size_opt { + let fixed_name = metadata + .get("avro.fixed.name") + .cloned() + .unwrap_or_else(|| format!("decimal_fixed_{size}_{p}_{s}")); + let fixed_namespace = metadata + .get("avro.fixed.namespace") + .cloned(); + Schema::Complex(ComplexType::Fixed(Fixed { + name: Box::leak(fixed_name.into_boxed_str()), + namespace: fixed_namespace.map(|ns| { + let leaked = Box::leak(ns.into_boxed_str()); + leaked as &'a str + }), + aliases: vec![], + size: *size, + attributes: attrs, + })) + } else { + Schema::Type(Type { + r#type: TypeName::Primitive(PrimitiveType::Bytes), + attributes: attrs, + }) + } + } + Self::Uuid => { + let mut attrs = Attributes::default(); + attrs.logical_type = Some("uuid"); + copy_metadata_to_attributes(&metadata, &mut attrs); + let fixed_name = metadata + .get("avro.fixed.name") + .cloned() + .unwrap_or_else(|| "fixed_16_uuid".to_string()); + let fixed_namespace = metadata + .get("avro.fixed.namespace") + .cloned(); + Schema::Complex(ComplexType::Fixed(Fixed { + name: Box::leak(fixed_name.into_boxed_str()), + namespace: fixed_namespace.map(|ns| { + let leaked = Box::leak(ns.into_boxed_str()); + leaked as &'a str + }), + aliases: vec![], + size: 16, + attributes: attrs, + })) + } + Self::Date32 => { + let mut attrs = Attributes::default(); + attrs.logical_type = Some("date"); + copy_metadata_to_attributes(&metadata, &mut attrs); + Schema::Type(Type { + r#type: TypeName::Primitive(PrimitiveType::Int), + attributes: attrs, + }) + } + Self::TimeMillis => { + let mut attrs = Attributes::default(); + attrs.logical_type = Some("time-millis"); + copy_metadata_to_attributes(&metadata, &mut attrs); + Schema::Type(Type { + r#type: TypeName::Primitive(PrimitiveType::Int), + attributes: attrs, + }) + } + Self::TimeMicros => { + let mut 
attrs = Attributes::default(); + attrs.logical_type = Some("time-micros"); + copy_metadata_to_attributes(&metadata, &mut attrs); + Schema::Type(Type { + r#type: TypeName::Primitive(PrimitiveType::Long), + attributes: attrs, + }) + } Self::TimestampMillis(is_utc) => { - DataType::Timestamp(TimeUnit::Millisecond, is_utc.then(|| "+00:00".into())) + let mut attrs = Attributes::default(); + let lt = if *is_utc { + "timestamp-millis" + } else { + "local-timestamp-millis" + }; + attrs.logical_type = Some(lt); + copy_metadata_to_attributes(&metadata, &mut attrs); + Schema::Type(Type { + r#type: TypeName::Primitive(PrimitiveType::Long), + attributes: attrs, + }) } Self::TimestampMicros(is_utc) => { - DataType::Timestamp(TimeUnit::Microsecond, is_utc.then(|| "+00:00".into())) + let mut attrs = Attributes::default(); + let lt = if *is_utc { + "timestamp-micros" + } else { + "local-timestamp-micros" + }; + attrs.logical_type = Some(lt); + copy_metadata_to_attributes(&metadata, &mut attrs); + Schema::Type(Type { + r#type: TypeName::Primitive(PrimitiveType::Long), + attributes: attrs, + }) } - Self::Interval => DataType::Interval(IntervalUnit::MonthDayNano), - Self::Fixed(size) => DataType::FixedSizeBinary(*size), - Self::List(f) => { - DataType::List(Arc::new(f.field_with_name(Field::LIST_FIELD_DEFAULT_NAME))) + Self::Duration => { + let mut attrs = Attributes::default(); + attrs.logical_type = Some("duration"); + copy_metadata_to_attributes(&metadata, &mut attrs); + let fixed_name = metadata + .get("avro.fixed.name") + .cloned() + .unwrap_or_else(|| "fixed_12_duration".to_string()); + let fixed_namespace = metadata + .get("avro.fixed.namespace") + .cloned(); + Schema::Complex(ComplexType::Fixed(Fixed { + name: Box::leak(fixed_name.into_boxed_str()), + namespace: fixed_namespace.map(|ns| { + let leaked = Box::leak(ns.into_boxed_str()); + leaked as &'a str + }), + aliases: vec![], + size: 12, + attributes: attrs, + })) } - Self::Struct(f) => DataType::Struct(f.iter().map(|x| x.field()).collect()), + }; + if let Some(nul) = nullability { + let union = match nul { + Nullability::NullFirst => vec![ + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + base, + ], + Nullability::NullSecond => vec![ + base, + Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), + ], + }; + Ok(Schema::Union(union)) + } else { + Ok(base) } } } @@ -169,42 +775,40 @@ impl From for Codec { PrimitiveType::Float => Self::Float32, PrimitiveType::Double => Self::Float64, PrimitiveType::Bytes => Self::Binary, - PrimitiveType::String => Self::Utf8, + PrimitiveType::String => Self::String, } } } /// Resolves Avro type names to [`AvroDataType`] -/// -/// See -#[derive(Debug, Default)] +#[derive(Default, Debug)] struct Resolver<'a> { map: HashMap<(&'a str, &'a str), AvroDataType>, } impl<'a> Resolver<'a> { - fn register(&mut self, name: &'a str, namespace: Option<&'a str>, schema: AvroDataType) { - self.map.insert((name, namespace.unwrap_or("")), schema); + fn register(&mut self, name: &'a str, namespace: Option<&'a str>, dt: AvroDataType) { + let ns = namespace.unwrap_or(""); + self.map.insert((name, ns), dt); } - fn resolve(&self, name: &str, namespace: Option<&'a str>) -> Result { - let (namespace, name) = name - .rsplit_once('.') - .unwrap_or_else(|| (namespace.unwrap_or(""), name)); - + fn resolve( + &self, + full_name: &str, + namespace: Option<&'a str>, + ) -> Result { + let (ns, nm) = match full_name.rsplit_once('.') { + Some((a, b)) => (a, b), + None => (namespace.unwrap_or(""), full_name), + }; self.map - 
.get(&(namespace, name)) - .ok_or_else(|| ArrowError::ParseError(format!("Failed to resolve {namespace}.{name}"))) + .get(&(nm, ns)) .cloned() + .ok_or_else(|| ArrowError::ParseError(format!("Failed to resolve {ns}.{nm}"))) } } -/// Parses a [`AvroDataType`] from the provided [`Schema`] and the given `name` and `namespace` -/// -/// `name`: is name used to refer to `schema` in its parent -/// `namespace`: an optional qualifier used as part of a type hierarchy -/// -/// See [`Resolver`] for more information +/// Parses a [`AvroDataType`] from the provided [`Schema`], plus optional `namespace`. fn make_data_type<'a>( schema: &Schema<'a>, namespace: Option<&'a str>, @@ -217,113 +821,1124 @@ fn make_data_type<'a>( codec: (*p).into(), }), Schema::TypeName(TypeName::Ref(name)) => resolver.resolve(name, namespace), - Schema::Union(f) => { - // Special case the common case of nullable primitives - let null = f + Schema::Union(u) => { + let null_count = u .iter() - .position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))); - match (f.len() == 2, null) { - (true, Some(0)) => { - let mut field = make_data_type(&f[1], namespace, resolver)?; - field.nullability = Some(Nullability::NullFirst); - Ok(field) - } - (true, Some(1)) => { - let mut field = make_data_type(&f[0], namespace, resolver)?; - field.nullability = Some(Nullability::NullSecond); - Ok(field) - } - _ => Err(ArrowError::NotYetImplemented(format!( - "Union of {f:?} not currently supported" - ))), + .filter(|x| *x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))) + .count(); + if null_count == 1 && u.len() == 2 { + let null_idx = u + .iter() + .position(|x| x == &Schema::TypeName(TypeName::Primitive(PrimitiveType::Null))) + .unwrap(); + let other_idx = if null_idx == 0 { 1 } else { 0 }; + let mut dt = make_data_type(&u[other_idx], namespace, resolver)?; + dt.nullability = if null_idx == 0 { + Some(Nullability::NullFirst) + } else { + Some(Nullability::NullSecond) + }; + Ok(dt) + } else { + Err(ArrowError::NotYetImplemented(format!( + "Union of {u:?} not currently supported" + ))) } } Schema::Complex(c) => match c { ComplexType::Record(r) => { - let namespace = r.namespace.or(namespace); + let ns = r.namespace.or(namespace); let fields = r .fields .iter() - .map(|field| { - Ok(AvroField { - name: field.name.to_string(), - data_type: make_data_type(&field.r#type, namespace, resolver)?, + .map(|f| { + let data_type = make_data_type(&f.r#type, ns, resolver)?; + Ok::(AvroField { + name: f.name.to_string(), + data_type, + default: f.default.clone(), }) }) - .collect::>()?; - - let field = AvroDataType { + .collect::, ArrowError>>()?; + let rec_dt = AvroDataType { nullability: None, - codec: Codec::Struct(fields), metadata: r.attributes.field_metadata(), + codec: Codec::Record(Arc::from(fields)), }; - resolver.register(r.name, namespace, field.clone()); - Ok(field) + resolver.register(r.name, ns, rec_dt.clone()); + Ok(rec_dt) + } + ComplexType::Enum(e) => { + let mut md = e.attributes.field_metadata(); + if let Ok(symbols_json) = serde_json::to_string(&e.symbols) { + md.insert("avro.enum.symbols".to_string(), symbols_json); + } + let en = AvroDataType { + nullability: None, + metadata: md, + codec: Codec::Enum( + Arc::from( + e.symbols + .iter() + .map(|s| s.to_string()) + .collect::>(), + ), + Arc::from(vec![]), + ), + }; + resolver.register(e.name, namespace, en.clone()); + Ok(en) } ComplexType::Array(a) => { - let mut field = make_data_type(a.items.as_ref(), namespace, resolver)?; + let child = 
make_data_type(&a.items, namespace, resolver)?; Ok(AvroDataType { nullability: None, metadata: a.attributes.field_metadata(), - codec: Codec::List(Arc::new(field)), + codec: Codec::Array(Arc::new(child)), }) } - ComplexType::Fixed(f) => { - let size = f.size.try_into().map_err(|e| { - ArrowError::ParseError(format!("Overflow converting size to i32: {e}")) - })?; - - let field = AvroDataType { + ComplexType::Map(m) => { + let val = make_data_type(&m.values, namespace, resolver)?; + Ok(AvroDataType { nullability: None, - metadata: f.attributes.field_metadata(), - codec: Codec::Fixed(size), + metadata: m.attributes.field_metadata(), + codec: Codec::Map(Arc::new(val)), + }) + } + ComplexType::Fixed(fx) => { + let size = fx.size as i32; + let md = fx.attributes.field_metadata(); + let dt = match fx.attributes.logical_type.as_deref() { + Some("decimal") => { + let (precision, scale, _) = + parse_decimal_attributes(&fx.attributes, Some(size as usize), true)?; + AvroDataType { + nullability: None, + metadata: md, + codec: Codec::Decimal(precision, Some(scale), Some(size as usize)), + } + } + Some("duration") if fx.size == 12 => AvroDataType { + nullability: None, + metadata: md, + codec: Codec::Duration, + }, + Some("uuid") if fx.size == 16 => AvroDataType { + nullability: None, + metadata: md, + codec: Codec::Uuid, + }, + _ => fixed_fallback(md, size), }; - resolver.register(f.name, namespace, field.clone()); - Ok(field) + resolver.register(fx.name, namespace, dt.clone()); + Ok(dt) } - ComplexType::Enum(e) => Err(ArrowError::NotYetImplemented(format!( - "Enum of {e:?} not currently supported" - ))), - ComplexType::Map(m) => Err(ArrowError::NotYetImplemented(format!( - "Map of {m:?} not currently supported" - ))), }, Schema::Type(t) => { - let mut field = + let mut dt = make_data_type(&Schema::TypeName(t.r#type.clone()), namespace, resolver)?; - - // https://avro.apache.org/docs/1.11.1/specification/#logical-types - match (t.attributes.logical_type, &mut field.codec) { - (Some("decimal"), c @ Codec::Fixed(_)) => { - return Err(ArrowError::NotYetImplemented( - "Decimals are not currently supported".to_string(), - )) - } - (Some("date"), c @ Codec::Int32) => *c = Codec::Date32, - (Some("time-millis"), c @ Codec::Int32) => *c = Codec::TimeMillis, - (Some("time-micros"), c @ Codec::Int64) => *c = Codec::TimeMicros, - (Some("timestamp-millis"), c @ Codec::Int64) => *c = Codec::TimestampMillis(true), - (Some("timestamp-micros"), c @ Codec::Int64) => *c = Codec::TimestampMicros(true), - (Some("local-timestamp-millis"), c @ Codec::Int64) => { - *c = Codec::TimestampMillis(false) - } - (Some("local-timestamp-micros"), c @ Codec::Int64) => { - *c = Codec::TimestampMicros(false) - } - (Some("duration"), c @ Codec::Fixed(12)) => *c = Codec::Interval, - (Some(logical), _) => { - // Insert unrecognized logical type into metadata map - field.metadata.insert("logicalType".into(), logical.into()); + match (t.attributes.logical_type, &mut dt.codec) { + (Some("decimal"), Codec::Fixed(size)) => { + let (precision, scale, size_opt) = + parse_decimal_attributes(&t.attributes, Some(*size as usize), false)?; + if let Some(sz_actual) = size_opt { + *size = sz_actual as i32; + } + dt.codec = Codec::Decimal(precision, Some(scale), Some(*size as usize)); + } + (Some("decimal"), Codec::Binary) => { + let (precision, scale, _) = parse_decimal_attributes(&t.attributes, None, false)?; + dt.codec = Codec::Decimal(precision, Some(scale), None); + } + (Some("uuid"), Codec::String) => { + dt.codec = Codec::Uuid; + } + (Some("date"), 
Codec::Int32) => { + dt.codec = Codec::Date32; + } + (Some("time-millis"), Codec::Int32) => { + dt.codec = Codec::TimeMillis; + } + (Some("time-micros"), Codec::Int64) => { + dt.codec = Codec::TimeMicros; + } + (Some("timestamp-millis"), Codec::Int64) => { + dt.codec = Codec::TimestampMillis(true); + } + (Some("timestamp-micros"), Codec::Int64) => { + dt.codec = Codec::TimestampMicros(true); + } + (Some("local-timestamp-millis"), Codec::Int64) => { + dt.codec = Codec::TimestampMillis(false); + } + (Some("local-timestamp-micros"), Codec::Int64) => { + dt.codec = Codec::TimestampMicros(false); + } + (Some("duration"), Codec::Fixed(12)) => { + dt.codec = Codec::Duration; + } + (Some(other), _) => { + if !dt.metadata.contains_key("logicalType") { + dt.metadata.insert("logicalType".into(), other.into()); + } } (None, _) => {} } + for (k, v) in &t.attributes.additional { + dt.metadata.insert(k.to_string(), v.to_string()); + } + Ok(dt) + } + } +} + +fn fixed_fallback(md: HashMap, size: i32) -> AvroDataType { + AvroDataType { + nullability: None, + metadata: md, + codec: Codec::Fixed(size), + } +} + +fn copy_metadata_to_attributes( + source: &HashMap, + target: &mut Attributes, +) { + for (k, v) in source { + if k == "type" + || k == "name" + || k == "fields" + || k == "aliases" + || k == "namespace" + || k == "doc" + || k.starts_with("avro.") + { + continue; + } + // For "precision" or "scale", try parsing as an integer. + let maybe_parsed_value = if k == "precision" || k == "scale" { + match v.parse::() { + Ok(parsed_int) => serde_json::Value::Number(parsed_int.into()), + Err(_) => serde_json::Value::String(v.clone()), + } + } else { + serde_json::Value::String(v.clone()) + }; + target.additional.insert( + Box::leak(k.clone().into_boxed_str()), + maybe_parsed_value, + ); + } +} + +fn parse_decimal_attributes( + attributes: &Attributes, + fallback_size: Option, + precision_required: bool, +) -> Result<(usize, usize, Option), ArrowError> { + let precision = attributes + .additional + .get("precision") + .and_then(|v| v.as_u64()) + .or(if precision_required { None } else { Some(10) }) + .ok_or_else(|| ArrowError::ParseError("Decimal requires precision".to_string()))? 
+ as usize; + let scale = attributes + .additional + .get("scale") + .and_then(|v| v.as_u64()) + .unwrap_or(0) as usize; + let size = attributes + .additional + .get("size") + .and_then(|v| v.as_u64()) + .map(|s| s as usize) + .or(fallback_size); + Ok((precision, scale, size)) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::schema::{Schema, ComplexType}; + use arrow_schema::{ArrowError, DataType, Field, TimeUnit, Schema as ArrowSchema}; + use serde_json::json; + use std::collections::HashMap; + use std::sync::Arc; + + fn arrow_schema_round_trip(schema: &ArrowSchema) -> Result { + let record_name = schema + .metadata() + .get("avro.record.name") + .cloned() + .unwrap_or_else(|| "topLevelRecord".to_string()); + let record_namespace = schema + .metadata() + .get("avro.record.namespace").map(|s| s.as_str()); + let avro_schema = SchemaBuilder::from(schema) + .with_impala_mode(&false) + .with_name(record_name.as_str()) + .with_namespace(record_namespace) + .finish()?; + let mut resolver = Resolver::default(); + let avro_dt = make_data_type(&avro_schema, None, &mut resolver)?; + Ok(avro_dt) + } + + fn single_field_codec(avro_dt: &AvroDataType) -> &Codec { + match &avro_dt.codec { + Codec::Record(fields) => { + if fields.len() != 1 { + panic!("Expected exactly 1 field in record, got {}", fields.len()); + } + &fields[0].data_type().codec + } + other => panic!("Expected top-level record, got {other:?}"), + } + } + + #[test] + fn test_field_to_schema_uuid() { + let mut md = HashMap::new(); + md.insert("logicalType".to_string(), "uuid".to_string()); + let arrow_field = + Field::new("uuid_col", DataType::FixedSizeBinary(16), false).with_metadata(md); + let arrow_schema = ArrowSchema::new(vec![arrow_field]); + let avro_dt = arrow_schema_round_trip(&arrow_schema).unwrap(); + match &avro_dt.codec { + Codec::Record(fields) => { + assert_eq!(fields.len(), 1); + let col = &fields[0]; + assert_eq!(col.name(), "uuid_col"); + match col.data_type().codec { + Codec::Uuid => {} + ref other => panic!("Expected Codec::Uuid, got {other:?}"), + } + } + ref other => panic!("Expected top-level record, got {other:?}"), + } + } + + #[test] + fn test_field_to_schema_duration() -> Result<(), ArrowError> { + let mut md = HashMap::new(); + md.insert("logicalType".to_string(), "duration".to_string()); + let arrow_field = Field::new("duration_col", DataType::FixedSizeBinary(12), true) + .with_metadata(md); + let arrow_schema = ArrowSchema::new(vec![arrow_field]); + let avro_dt = arrow_schema_round_trip(&arrow_schema)?; + let f0 = match &avro_dt.codec { + Codec::Record(fields) => &fields[0], + other => panic!("Expected record, got {other:?}"), + }; + assert_eq!(f0.name(), "duration_col"); + match f0.data_type().codec { + Codec::Duration => {} + ref other => panic!("Expected Codec::Duration, got {other:?}"), + }; + assert_eq!(f0.data_type().nullability, Some(Nullability::NullFirst)); + Ok(()) + } + + #[test] + fn test_field_to_schema_enum_dictionary_with_symbols() -> Result<(), ArrowError> { + let mut md = HashMap::new(); + md.insert("avro.enum.symbols".to_string(), r#"["RED","GREEN","BLUE"]"#.to_string()); + let dict_type = DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8)); + let arrow_field = Field::new("enum_col", dict_type, false).with_metadata(md); + let arrow_schema = ArrowSchema::new(vec![arrow_field]); + let avro_dt = arrow_schema_round_trip(&arrow_schema)?; + let codec = single_field_codec(&avro_dt); + match codec { + Codec::Enum(symbols, _defaults) => { + assert_eq!(symbols.len(), 3); 
+ assert_eq!(symbols[0], "RED"); + assert_eq!(symbols[1], "GREEN"); + assert_eq!(symbols[2], "BLUE"); + } + other => panic!("Expected Codec::Enum, got {other:?}"), + } + Ok(()) + } + + #[test] + fn test_field_to_schema_enum_dictionary_no_symbols() -> Result<(), ArrowError> { + let dict_type = DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let arrow_field = Field::new("maybe_enum_col", dict_type, true); + let arrow_schema = ArrowSchema::new(vec![arrow_field]); + let avro_dt = arrow_schema_round_trip(&arrow_schema)?; + let codec = single_field_codec(&avro_dt); + assert!(matches!(codec, Codec::String)); + Ok(()) + } + + #[test] + fn test_field_to_schema_date32() -> Result<(), ArrowError> { + let arrow_field = Field::new("d32", DataType::Date32, false); + let arrow_schema = ArrowSchema::new(vec![arrow_field]); + let avro_dt = arrow_schema_round_trip(&arrow_schema)?; + let codec = single_field_codec(&avro_dt); + match codec { + Codec::Date32 => {} + other => panic!("Expected Codec::Date32, got {other:?}"), + } + Ok(()) + } + + #[test] + fn test_field_to_schema_time_millis() -> Result<(), ArrowError> { + let arrow_field = Field::new("tmillis", DataType::Time32(TimeUnit::Millisecond), true); + let arrow_schema = ArrowSchema::new(vec![arrow_field]); + let avro_dt = arrow_schema_round_trip(&arrow_schema)?; + let f0_codec = single_field_codec(&avro_dt); + match f0_codec { + Codec::TimeMillis => {} + other => panic!("Expected Codec::TimeMillis, got {other:?}"), + } + let f0 = match &avro_dt.codec { + Codec::Record(fields) => &fields[0], + other => panic!("Expected record, got {other:?}"), + }; + assert_eq!(f0.data_type().nullability, Some(Nullability::NullFirst)); + Ok(()) + } + + #[test] + fn test_field_to_schema_time_micros() -> Result<(), ArrowError> { + let arrow_field = Field::new("tmicros", DataType::Time64(TimeUnit::Microsecond), false); + let arrow_schema = ArrowSchema::new(vec![arrow_field]); + let avro_dt = arrow_schema_round_trip(&arrow_schema)?; + let codec = single_field_codec(&avro_dt); + match codec { + Codec::TimeMicros => {} + other => panic!("Expected Codec::TimeMicros, got {other:?}"), + } + Ok(()) + } + + #[test] + fn test_field_to_schema_timestamp_millis_utc() -> Result<(), ArrowError> { + let arrow_field = Field::new( + "tsmillis_utc", + DataType::Timestamp(TimeUnit::Millisecond, Some(Arc::from("+00:00"))), + true, + ); + let arrow_schema = ArrowSchema::new(vec![arrow_field]); + let avro_dt = arrow_schema_round_trip(&arrow_schema)?; + let f0_codec = single_field_codec(&avro_dt); + match f0_codec { + Codec::TimestampMillis(is_utc) => assert!(*is_utc), + other => panic!("Expected Codec::TimestampMillis(true), got {other:?}"), + } + let f0 = match &avro_dt.codec { + Codec::Record(fields) => &fields[0], + _ => panic!("Expected record"), + }; + assert_eq!(f0.data_type().nullability, Some(Nullability::NullFirst)); + Ok(()) + } + + #[test] + fn test_field_to_schema_timestamp_micros_local() -> Result<(), ArrowError> { + let arrow_field = Field::new( + "tsmicros_local", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ); + let arrow_schema = ArrowSchema::new(vec![arrow_field]); + let avro_dt = arrow_schema_round_trip(&arrow_schema)?; + let codec = single_field_codec(&avro_dt); + match codec { + Codec::TimestampMicros(is_utc) => assert!(!*is_utc), + other => panic!("Expected Codec::TimestampMicros(false), got {other:?}"), + } + Ok(()) + } + + #[test] + fn test_arrow_map_to_avro_schema() -> Result<(), ArrowError> { + let key_field = 
Field::new("key", DataType::Utf8, false); + let value_field = Field::new("value", DataType::Int32, true); + let entries_struct_field = Field::new( + "entries", + DataType::Struct(vec![key_field.clone(), value_field.clone()].into()), + false, + ); + let map_field = Field::new("my_map", DataType::Map(Arc::new(entries_struct_field), false), true); + let arrow_schema = ArrowSchema::new(vec![map_field.clone()]); + let avro_dt = arrow_schema_round_trip(&arrow_schema)?; + let f0 = match &avro_dt.codec { + Codec::Record(fields) => &fields[0], + other => panic!("Expected a record, got {other:?}"), + }; + match f0.data_type().codec { + Codec::Map(ref val_type) => { + match val_type.codec { + Codec::Int32 => {} + ref other => panic!("Expected int or union, got {other:?}"), + } + } + ref other => panic!("Expected a Map codec, got {other:?}"), + } + let avro_sch = SchemaBuilder::from(&arrow_schema) + .with_impala_mode(&false) + .finish()?; + if let Schema::Complex(ComplexType::Record(r)) = avro_sch { + assert_eq!(r.fields.len(), 1); + let top_f = &r.fields[0]; + if let Schema::Union(u) = &top_f.r#type { + assert_eq!(u.len(), 2); + } else { + panic!("Expected union for a nullable field"); + } + } else { + panic!("Expected record"); + } + Ok(()) + } + + #[test] + fn test_avro_map_round_trip() -> Result<(), ArrowError> { + let key_field = Field::new("key", DataType::Utf8, false); + let value_field = Field::new("value", DataType::Int32, false); + let entries_struct = Field::new( + "entries", + DataType::Struct(vec![key_field, value_field].into()), + false, + ); + let map_field = Field::new( + "example_map", + DataType::Map(Arc::new(entries_struct), false), + false, + ); + let arrow_schema = ArrowSchema::new(vec![map_field]); + let avro_dt = arrow_schema_round_trip(&arrow_schema)?; + let f0 = match &avro_dt.codec { + Codec::Record(fields) => &fields[0], + other => panic!("Expected a record, got {other:?}"), + }; + match &f0.data_type().codec { + Codec::Map(val_type) => match val_type.codec { + Codec::Int32 => { /* as expected */ } + ref other => panic!("Unexpected map value type: {:?}", other), + }, + other => panic!("Expected map codec, got {other:?}"), + } + Ok(()) + } + + #[test] + fn test_field_to_schema_list_of_int() -> Result<(), ArrowError> { + let item_field = Field::new("item", DataType::Int32, false); + let list_field = Field::new("list_col", DataType::List(Arc::new(item_field)), false); + let arrow_schema = ArrowSchema::new(vec![list_field]); + let avro_dt = arrow_schema_round_trip(&arrow_schema)?; + let child_codec = match &avro_dt.codec { + Codec::Record(fields) => &fields[0].data_type().codec, + other => panic!("Expected record, got {other:?}"), + }; + match child_codec { + Codec::Array(child_at) => match child_at.codec { + Codec::Int32 => {} + ref other => panic!("Expected child=Int32, got {other:?}"), + }, + other => panic!("Expected Codec::Array(...), got {other:?}"), + } + Ok(()) + } + + #[test] + fn test_field_to_schema_fixedsizelist_of_strings_nullable() { + let item_field = Field::new("sub", DataType::Utf8, true); + let fsl_field = Field::new("fsl_col", DataType::FixedSizeList(Arc::new(item_field), 2), true); + let arrow_schema = ArrowSchema::new(vec![fsl_field]); + let avro_sch = SchemaBuilder::from(&arrow_schema) + .with_impala_mode(&false) + .finish().unwrap(); + let mut resolver = Resolver::default(); + let dt = make_data_type(&avro_sch, None, &mut resolver).unwrap(); + match &dt.codec { + Codec::Record(fields) => { + assert_eq!(fields.len(), 1); + let child_dt = 
&fields[0].data_type().codec; + match child_dt { + Codec::Array(child2) => { + assert!(matches!(child2.codec, Codec::String | Codec::Record(_))); + } + other => panic!("Expected array for fixedSizeList => {other:?}"), + } + } + other => panic!("Expected record => {other:?}"), + } + } + + #[test] + fn test_field_to_schema_record_simple() { + let child_a = Field::new("child_a", DataType::Int32, false); + let mut md_b = HashMap::new(); + md_b.insert("avro.default".to_string(), "true".to_string()); + let child_b = Field::new("child_b", DataType::Boolean, false).with_metadata(md_b); + let struct_type = DataType::Struct(vec![child_a.clone(), child_b.clone()].into()); + let top_field = Field::new("my_struct", struct_type, false); + let arrow_schema = ArrowSchema::new(vec![top_field]); + let avro_sch = SchemaBuilder::from(&arrow_schema) + .with_impala_mode(&false) + .finish().unwrap(); + let mut resolver = Resolver::default(); + let dt = make_data_type(&avro_sch, None, &mut resolver).unwrap(); + match &dt.codec { + Codec::Record(fields) => { + assert_eq!(fields.len(), 1); + let struct_avro = &fields[0]; + assert_eq!(struct_avro.name(), "my_struct"); + match &struct_avro.data_type().codec { + Codec::Record(child_fields) => { + assert_eq!(child_fields.len(), 2); + assert_eq!(child_fields[0].name(), "child_a"); + match child_fields[0].data_type().codec { + Codec::Int32 => {} + ref other => panic!("Expected Int32 for child_a => {other:?}"), + } + assert_eq!(child_fields[1].name(), "child_b"); + match child_fields[1].data_type().codec { + Codec::Boolean => {} + ref other => panic!("Expected Boolean for child_b => {other:?}"), + } + if let Some(def_val) = &child_fields[1].default { + assert_eq!(def_val, &json!(true)); + } else { + panic!("Expected default=true for child_b"); + } + } + ref other => panic!("Expected inner Codec::Record => {other:?}"), + } + } + other => panic!("Expected top-level record => {other:?}"), + } + } + + #[test] + fn test_decimal_arrow_field_to_schema() -> Result<(), ArrowError> { + let field = Field::new("decimal_col", DataType::Decimal128(10, 2), false); + let arrow_schema = ArrowSchema::new(vec![field]); + let avro_dt = arrow_schema_round_trip(&arrow_schema)?; + let f0 = match &avro_dt.codec { + Codec::Record(fields) => &fields[0], + other => panic!("Expected record, got {other:?}"), + }; + match f0.data_type().codec { + Codec::Decimal(prec, sc, sz) => { + assert_eq!(prec, 10); + assert_eq!(sc, Some(2)); + assert_eq!(sz, Some(16), "Default for decimal128 => 16 bytes"); + } + ref other => panic!("Expected decimal, got {other:?}"), + } + let avro_sch = SchemaBuilder::from(&arrow_schema) + .with_impala_mode(&false) + .finish()?; + match avro_sch { + Schema::Complex(ComplexType::Record(r)) => { + assert_eq!(r.fields.len(), 1); + let df = &r.fields[0]; + match &df.r#type { + Schema::Complex(ComplexType::Fixed(fx)) => { + assert_eq!(fx.size, 16); + let lt = fx.attributes.logical_type; + assert_eq!(lt, Some("decimal")); + let extra = &fx.attributes.additional; + let prec_val = extra.get("precision").unwrap(); + let scale_val = extra.get("scale").unwrap(); + assert_eq!(prec_val, &serde_json::Value::Number(10.into())); + assert_eq!(scale_val, &serde_json::Value::Number(2.into())); + } + _ => panic!("Expected a fixed decimal schema"), + } + } + other => panic!("Expected a record, got {other:?}"), + } + Ok(()) + } + + #[test] + fn test_field_to_schema_boolean() -> Result<(), ArrowError> { + let field = Field::new("bool_col", DataType::Boolean, false); + let arrow_schema = 
ArrowSchema::new(vec![field]); + let avro_dt = arrow_schema_round_trip(&arrow_schema)?; + let f0_codec = single_field_codec(&avro_dt); + assert!(matches!(f0_codec, Codec::Boolean)); + Ok(()) + } + + #[test] + fn test_field_to_schema_int32() -> Result<(), ArrowError> { + let field = Field::new("int_col", DataType::Int32, true); + let arrow_schema = ArrowSchema::new(vec![field]); + let avro_dt = arrow_schema_round_trip(&arrow_schema)?; + let f0 = match &avro_dt.codec { + Codec::Record(fields) => &fields[0], + _ => panic!("Expected record"), + }; + match f0.data_type().codec { + Codec::Int32 => {} + ref other => panic!("Expected Codec::Int32, got {other:?}"), + } + assert_eq!(f0.data_type().nullability, Some(Nullability::NullFirst)); + Ok(()) + } + + #[test] + fn test_field_to_schema_int64() -> Result<(), ArrowError> { + let field = Field::new("long_col", DataType::Int64, false); + let arrow_schema = ArrowSchema::new(vec![field]); + let avro_dt = arrow_schema_round_trip(&arrow_schema)?; + let c0 = single_field_codec(&avro_dt); + assert!(matches!(c0, Codec::Int64)); + Ok(()) + } + + #[test] + fn test_field_to_schema_float32() -> Result<(), ArrowError> { + let field = Field::new("float_col", DataType::Float32, false); + let arrow_schema = ArrowSchema::new(vec![field]); + let avro_dt = arrow_schema_round_trip(&arrow_schema)?; + match single_field_codec(&avro_dt) { + Codec::Float32 => {} + ref other => panic!("Expected float32, got {other:?}"), + } + Ok(()) + } + + #[test] + fn test_field_to_schema_float64() -> Result<(), ArrowError> { + let field = Field::new("double_col", DataType::Float64, false); + let arrow_schema = ArrowSchema::new(vec![field]); + let avro_dt = arrow_schema_round_trip(&arrow_schema)?; + match single_field_codec(&avro_dt) { + Codec::Float64 => {} + ref other => panic!("Expected float64, got {other:?}"), + } + Ok(()) + } + + #[test] + fn test_field_to_schema_binary() -> Result<(), ArrowError> { + let field = Field::new("bin_col", DataType::Binary, true); + let arrow_schema = ArrowSchema::new(vec![field]); + let avro_dt = arrow_schema_round_trip(&arrow_schema)?; + let f0 = match &avro_dt.codec { + Codec::Record(fields) => &fields[0], + other => panic!("Expected record, got {other:?}"), + }; + match f0.data_type().codec { + Codec::Binary => {} + ref other => panic!("Expected Codec::Binary, got {other:?}"), + } + assert_eq!(f0.data_type().nullability, Some(Nullability::NullFirst)); + Ok(()) + } + + #[test] + fn test_field_to_schema_string() -> Result<(), ArrowError> { + let field = Field::new("str_col", DataType::Utf8, false); + let arrow_schema = ArrowSchema::new(vec![field]); + let avro_dt = arrow_schema_round_trip(&arrow_schema)?; + match single_field_codec(&avro_dt) { + Codec::String => {} + ref other => panic!("Expected string, got {other:?}"), + } + Ok(()) + } + + #[test] + fn test_field_to_schema_fixedsizebinary() -> Result<(), ArrowError> { + let field = Field::new("fixed8_col", DataType::FixedSizeBinary(8), false); + let arrow_schema = ArrowSchema::new(vec![field]); + let avro_dt = arrow_schema_round_trip(&arrow_schema)?; + match single_field_codec(&avro_dt) { + Codec::Fixed(sz) => { + assert_eq!(*sz, 8); + } + other => panic!("Expected Codec::Fixed(8), got {other:?}"), + } + Ok(()) + } - if !t.attributes.additional.is_empty() { - for (k, v) in &t.attributes.additional { - field.metadata.insert(k.to_string(), v.to_string()); + #[test] + fn test_arrow_schema_to_avro_schema_all_supported() -> Result<(), ArrowError> { + let arrow_schema = ArrowSchema::new(vec![ + 
Field::new("bool_col", DataType::Boolean, false), + Field::new("int_col", DataType::Int32, true), + Field::new("long_col", DataType::Int64, false), + Field::new("float_col", DataType::Float32, false), + Field::new("double_col", DataType::Float64, true), + Field::new("bin_col", DataType::Binary, true), + Field::new("str_col", DataType::Utf8, false), + Field::new("fixed4_col", DataType::FixedSizeBinary(4), true), + ]); + let avro_sch = SchemaBuilder::from(&arrow_schema) + .with_impala_mode(&false) + .finish()?; + let mut resolver = Resolver::default(); + let top_dt = make_data_type(&avro_sch, None, &mut resolver)?; + match &top_dt.codec { + Codec::Record(fields) => { + assert_eq!(fields.len(), 8); + match fields[0].data_type().codec { + Codec::Boolean => {} + ref other => panic!("Expected bool => {other:?}"), + } + match fields[1].data_type().codec { + Codec::Int32 => {} + ref other => panic!("Expected int => {other:?}"), + } + assert_eq!(fields[1].data_type().nullability, Some(Nullability::NullFirst)); + match fields[2].data_type().codec { + Codec::Int64 => {} + ref other => panic!("Expected long => {other:?}"), + } + match fields[3].data_type().codec { + Codec::Float32 => {} + ref other => panic!("Expected float => {other:?}"), + } + match fields[4].data_type().codec { + Codec::Float64 => {} + ref other => panic!("Expected double => {other:?}"), + } + assert_eq!(fields[4].data_type().nullability, Some(Nullability::NullFirst)); + match fields[5].data_type().codec { + Codec::Binary => {} + ref other => panic!("Expected bytes => {other:?}"), + } + match fields[6].data_type().codec { + Codec::String => {} + ref other => panic!("Expected string => {other:?}"), + } + match fields[7].data_type().codec { + Codec::Fixed(sz) => { + assert_eq!(sz, 4); + } + ref other => panic!("Expected fixed => {other:?}"), + } + assert_eq!( + fields[7].data_type().nullability, + Some(Nullability::NullFirst) + ); + } + ref other => panic!("Expected top-level record => {other:?}"), + } + Ok(()) + } + + #[test] + fn test_skip_avro_default_null_in_metadata() { + let dt = AvroDataType::from_codec(Codec::Int32); + let field = AvroField { + name: "test_col".into(), + data_type: dt, + default: Some(json!(null)), + }; + let arrow_field = field.field(); + assert!(arrow_field.metadata().get("avro.default").is_none()); + } + + #[test] + fn test_store_avro_default_nonnull_in_metadata() { + let dt = AvroDataType::from_codec(Codec::Int32); + let field = AvroField { + name: "test_col".into(), + data_type: dt, + default: Some(json!(42)), + }; + let arrow_field = field.field(); + let metadata = arrow_field.metadata(); + let got = metadata.get("avro.default").cloned(); + assert_eq!(got, Some("42".to_string())); + } + + #[test] + fn test_no_default_metadata_if_none() { + let dt = AvroDataType::from_codec(Codec::String); + let field = AvroField { + name: "col".to_string(), + data_type: dt, + default: None, + }; + let arrow_field = field.field(); + assert!(arrow_field.metadata().get("avro.default").is_none()); + } + + #[test] + fn test_avro_field() { + let field_codec = AvroDataType::from_codec(Codec::Int64); + let avro_field = AvroField { + name: "long_col".to_string(), + data_type: field_codec.clone(), + default: None, + }; + assert_eq!(avro_field.name(), "long_col"); + let arrow_field = avro_field.field(); + assert_eq!(arrow_field.name(), "long_col"); + assert_eq!(arrow_field.data_type(), &DataType::Int64); + assert!(!arrow_field.is_nullable()); + } + + #[test] + fn test_avro_field_with_default() { + let field_codec = 
AvroDataType::from_codec(Codec::Int32); + let default_value = json!(123); + let avro_field = AvroField { + name: "int_col".to_string(), + data_type: field_codec.clone(), + default: Some(default_value.clone()), + }; + let arrow_field = avro_field.field(); + let metadata = arrow_field.metadata(); + assert_eq!( + metadata.get("avro.default").unwrap(), + &default_value.to_string() + ); + } + + #[test] + fn test_codec_fixedsizebinary() { + let codec = Codec::Fixed(12); + let dt = codec.data_type(); + match dt { + DataType::FixedSizeBinary(n) => assert_eq!(n, 12), + _ => panic!("Expected FixedSizeBinary(12)"), + } + } + + #[test] + fn test_union_long_null() -> Result<(), ArrowError> { + let json_schema = r#" + { + "type": "record", + "name": "test_long_null", + "fields": [ + {"name": "f0", "type": ["long", "null"]} + ] + } + "#; + let schema: Schema = serde_json::from_str(json_schema).unwrap(); + let avro_field = AvroField::try_from(&schema)?; + match &avro_field.data_type().codec { + Codec::Record(fields) => { + assert_eq!(fields.len(), 1); + assert_eq!(fields[0].name(), "f0"); + let child_dt = fields[0].data_type(); + assert_eq!(child_dt.nullability, Some(Nullability::NullSecond)); + assert!(matches!(child_dt.codec, Codec::Int64)); + } + _ => panic!("Expected record with a single [long,null] field"), + } + Ok(()) + } + + #[test] + fn test_union_array_of_int_null() -> Result<(), ArrowError> { + let json_schema = r#" + { + "type":"record", + "name":"test_array_int_null", + "fields":[ + {"name":"arr","type":[{"type":"array","items":["int","null"]},"null"]} + ] + } + "#; + let schema: Schema = serde_json::from_str(json_schema).unwrap(); + let avro_field = AvroField::try_from(&schema)?; + match &avro_field.data_type().codec { + Codec::Record(fields) => { + assert_eq!(fields.len(), 1); + let arr_dt = fields[0].data_type(); + assert_eq!(arr_dt.nullability, Some(Nullability::NullSecond)); + match &arr_dt.codec { + Codec::Array(child_dt) => { + assert_eq!(child_dt.nullability, Some(Nullability::NullSecond)); + assert!(matches!(child_dt.codec, Codec::Int32)); + } + other => panic!("Expected Array, got {other:?}"), + } + } + other => panic!("Expected record, got {other:?}"), + } + Ok(()) + } + + #[test] + fn test_union_nested_array_of_int_null() -> Result<(), ArrowError> { + let json_schema = r#" + { + "type":"record", + "name":"test_nested_array_int_null", + "fields":[ + { + "name":"nested_arr", + "type":[ + { + "type":"array", + "items":[ + { + "type":"array", + "items":["int","null"] + }, + "null" + ] + }, + "null" + ] + } + ] + } + "#; + let schema: Schema = serde_json::from_str(json_schema).unwrap(); + let avro_field = AvroField::try_from(&schema)?; + match &avro_field.data_type().codec { + Codec::Record(fields) => { + assert_eq!(fields.len(), 1); + let outer = fields[0].data_type(); + assert_eq!(outer.nullability, Some(Nullability::NullSecond)); + match &outer.codec { + Codec::Array(mid) => { + assert_eq!(mid.nullability, Some(Nullability::NullSecond)); + match &mid.codec { + Codec::Array(inner) => { + assert_eq!(inner.nullability, Some(Nullability::NullSecond)); + assert!(matches!(inner.codec, Codec::Int32)); + } + other => panic!("Expected inner array => {other:?}"), + } + } + other => panic!("Expected outer array => {other:?}"), + } + } + other => panic!("Expected record => {other:?}"), + } + Ok(()) + } + + #[test] + fn test_union_map_of_int_null() -> Result<(), ArrowError> { + let json_schema = r#" + { + "type":"record", + "name":"test_map_int_null", + "fields":[ + 
{"name":"map_field","type":[{"type":"map","values":["int","null"]},"null"]} + ] + } + "#; + let schema: Schema = serde_json::from_str(json_schema).unwrap(); + let avro_field = AvroField::try_from(&schema)?; + match &avro_field.data_type().codec { + Codec::Record(fields) => { + let map_dt = fields[0].data_type(); + assert_eq!(map_dt.nullability, Some(Nullability::NullSecond)); + match map_dt.codec { + Codec::Map(ref val_dt) => { + assert_eq!(val_dt.nullability, Some(Nullability::NullSecond)); + assert!(matches!(val_dt.codec, Codec::Int32)); + } + ref other => panic!("Expected Map => {other:?}"), + } + } + other => panic!("Expected record => {other:?}"), + } + Ok(()) + } + + #[test] + fn test_union_map_array_of_int_null() -> Result<(), ArrowError> { + let json_schema = r#" + { + "type":"record", + "name":"test_map_array_int_null", + "fields":[ + { + "name":"map_arr", + "type":[ + { + "type":"array", + "items":[ + { + "type":"map", + "values":["int","null"] + }, + "null" + ] + }, + "null" + ] + } + ] + } + "#; + let schema: Schema = serde_json::from_str(json_schema).unwrap(); + let avro_field = AvroField::try_from(&schema)?; + match &avro_field.data_type().codec { + Codec::Record(fields) => { + let outer_dt = fields[0].data_type(); + assert_eq!(outer_dt.nullability, Some(Nullability::NullSecond)); + match &outer_dt.codec { + Codec::Array(map_dt) => { + assert_eq!(map_dt.nullability, Some(Nullability::NullSecond)); + match &map_dt.codec { + Codec::Map(val_dt) => { + assert_eq!(val_dt.nullability, Some(Nullability::NullSecond)); + assert!(matches!(val_dt.codec, Codec::Int32)); + } + other => panic!("Expected map => {other:?}"), + } + } + other => panic!("Expected array => {other:?}"), + } + } + other => panic!("Expected record => {other:?}"), + } + Ok(()) + } + + #[test] + fn test_union_nested_struct_out_of_spec() -> Result<(), ArrowError> { + let json_schema = r#" + { + "type":"record","name":"topLevelRecord","fields":[ + {"name":"nested_struct","type":[ + { + "type":"record", + "name":"nested_struct", + "namespace":"topLevelRecord", + "fields":[ + {"name":"A","type":["int","null"]}, + {"name":"b","type":[{"type":"array","items":["int","null"]},"null"]} + ] + }, + "null" + ]} + ] + } + "#; + let schema: Schema = serde_json::from_str(json_schema).unwrap(); + let avro_field = AvroField::try_from(&schema)?; + match &avro_field.data_type().codec { + Codec::Record(fields) => { + assert_eq!(fields.len(), 1); + let nested_dt = fields[0].data_type(); + assert_eq!(nested_dt.nullability, Some(Nullability::NullSecond)); + match nested_dt.codec { + Codec::Record(ref subfields) => { + assert_eq!(subfields.len(), 2); + let f_a = &subfields[0]; + assert_eq!(f_a.data_type().nullability, Some(Nullability::NullSecond)); + assert!(matches!(f_a.data_type().codec, Codec::Int32)); + } + ref other => panic!("Expected record => {other:?}"), } } - Ok(field) + other => panic!("Expected top-level record => {other:?}"), } + Ok(()) } } diff --git a/arrow-avro/src/compression.rs b/arrow-avro/src/compression.rs index f29b8dd07606..605f08c5bd25 100644 --- a/arrow-avro/src/compression.rs +++ b/arrow-avro/src/compression.rs @@ -16,20 +16,30 @@ // under the License. 
use arrow_schema::ArrowError; -use std::io; use std::io::Read; /// The metadata key used for storing the JSON encoded [`CompressionCodec`] pub const CODEC_METADATA_KEY: &str = "avro.codec"; #[derive(Debug, Copy, Clone, Eq, PartialEq)] +/// CompressionCodec includes the enumerated types for each supported compression +/// type pub enum CompressionCodec { + /// Deflate - compression Deflate, + /// Snappy - compression Snappy, + /// ZStandard - compression ZStandard, + /// Bzip2 - compression + Bzip2, + /// Xz - compression + Xz, } impl CompressionCodec { + /// Decompress an Avro block that was encoded with this codec. + /// Used by the **reader** to decode block data from an Avro container file. pub(crate) fn decompress(&self, block: &[u8]) -> Result, ArrowError> { match self { #[cfg(feature = "deflate")] @@ -43,16 +53,19 @@ impl CompressionCodec { CompressionCodec::Deflate => Err(ArrowError::ParseError( "Deflate codec requires deflate feature".to_string(), )), + #[cfg(feature = "snappy")] CompressionCodec::Snappy => { - // Each compressed block is followed by the 4-byte, big-endian CRC32 - // checksum of the uncompressed data in the block. + if block.len() < 4 { + return Err(ArrowError::ParseError( + "Snappy block too short to contain trailing crc".to_string(), + )); + } let crc = &block[block.len() - 4..]; - let block = &block[..block.len() - 4]; - + let block_data = &block[..block.len() - 4]; let mut decoder = snap::raw::Decoder::new(); let decoded = decoder - .decompress_vec(block) + .decompress_vec(block_data) .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; let checksum = crc::Crc::::new(&crc::CRC_32_ISO_HDLC).checksum(&decoded); @@ -77,6 +90,126 @@ impl CompressionCodec { CompressionCodec::ZStandard => Err(ArrowError::ParseError( "ZStandard codec requires zstd feature".to_string(), )), + + #[cfg(feature = "bzip2")] + CompressionCodec::Bzip2 => { + let mut decoder = bzip2::read::BzDecoder::new(block); + let mut out = Vec::new(); + decoder.read_to_end(&mut out)?; + Ok(out) + } + #[cfg(not(feature = "bzip2"))] + CompressionCodec::Bzip2 => Err(ArrowError::ParseError( + "Bzip2 codec requires bzip2 feature".to_string(), + )), + + #[cfg(feature = "xz")] + CompressionCodec::Xz => { + let mut decoder = xz::read::XzDecoder::new(block); + let mut out = Vec::new(); + decoder.read_to_end(&mut out)?; + Ok(out) + } + #[cfg(not(feature = "xz"))] + CompressionCodec::Xz => Err(ArrowError::ParseError( + "XZ codec requires xz feature".to_string(), + )), + } + } + + /// Compress a block using this Avro codec. + /// Used by the **writer** to encode block data before writing it. + /// + /// Snappy: Avro requires a 4-byte big-endian CRC32 of the *uncompressed* data appended. 
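+ ///
+ /// A minimal crate-internal round-trip sketch (assuming the `snappy` feature is enabled):
+ ///
+ /// ```ignore
+ /// let codec = CompressionCodec::Snappy;
+ /// let compressed = codec.compress_block(b"raw avro block bytes")?;
+ /// // `decompress` splits off the trailing 4-byte CRC32 appended here and decodes the rest
+ /// assert_eq!(codec.decompress(&compressed)?, b"raw avro block bytes");
+ /// ```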
+ pub(crate) fn compress_block(&self, data: &[u8]) -> Result, ArrowError> { + match self { + #[cfg(feature = "deflate")] + CompressionCodec::Deflate => { + use flate2::{write::DeflateEncoder, Compression}; + use std::io::Write; + let mut encoder = DeflateEncoder::new(Vec::new(), Compression::default()); + encoder + .write_all(data) + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + let compressed = encoder + .finish() + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + Ok(compressed) + } + #[cfg(not(feature = "deflate"))] + CompressionCodec::Deflate => Err(ArrowError::ParseError( + "Deflate codec requires deflate feature".to_string(), + )), + + #[cfg(feature = "snappy")] + CompressionCodec::Snappy => { + let mut encoder = snap::raw::Encoder::new(); + let compressed = encoder + .compress_vec(data) + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + let crc = crc::Crc::::new(&crc::CRC_32_ISO_HDLC).checksum(data); + let mut out = Vec::with_capacity(compressed.len() + 4); + out.extend_from_slice(&compressed); + out.extend_from_slice(&crc.to_be_bytes()); + Ok(out) + } + #[cfg(not(feature = "snappy"))] + CompressionCodec::Snappy => Err(ArrowError::ParseError( + "Snappy codec requires snappy feature".to_string(), + )), + + #[cfg(feature = "zstd")] + CompressionCodec::ZStandard => { + use std::io::Write; + let mut encoder = zstd::Encoder::new(Vec::new(), 0) + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + encoder + .write_all(data) + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + let compressed = encoder + .finish() + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + Ok(compressed) + } + #[cfg(not(feature = "zstd"))] + CompressionCodec::ZStandard => Err(ArrowError::ParseError( + "ZStandard codec requires zstd feature".to_string(), + )), + + #[cfg(feature = "bzip2")] + CompressionCodec::Bzip2 => { + use std::io::Write; + use bzip2::{write::BzEncoder, Compression}; + let mut encoder = BzEncoder::new(Vec::new(), Compression::default()); + encoder + .write_all(data) + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + let compressed = encoder + .finish() + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + Ok(compressed) + } + #[cfg(not(feature = "bzip2"))] + CompressionCodec::Bzip2 => Err(ArrowError::ParseError( + "Bzip2 codec requires bzip2 feature".to_string(), + )), + + #[cfg(feature = "xz")] + CompressionCodec::Xz => { + use std::io::Write; + let mut encoder = xz::write::XzEncoder::new(Vec::new(), 6); + encoder + .write_all(data) + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + let compressed = encoder + .finish() + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + Ok(compressed) + } + #[cfg(not(feature = "xz"))] + CompressionCodec::Xz => Err(ArrowError::ParseError( + "XZ codec requires xz feature".to_string(), + )), } } } diff --git a/arrow-avro/src/lib.rs b/arrow-avro/src/lib.rs index d01d681b7af0..8a3915f26fd7 100644 --- a/arrow-avro/src/lib.rs +++ b/arrow-avro/src/lib.rs @@ -21,7 +21,6 @@ //! 
[Apache Avro]: https://avro.apache.org/ #![warn(missing_docs)] -#![allow(unused)] // Temporary pub mod reader; mod schema; @@ -29,6 +28,11 @@ mod schema; mod compression; mod codec; +pub mod writer; + +pub use self::reader::{Decoder, Reader, ReaderBuilder}; +pub use self::compression::{CompressionCodec}; + #[cfg(test)] mod test_util { diff --git a/arrow-avro/src/reader/block.rs b/arrow-avro/src/reader/block.rs index 479f0ef90909..43722da23938 100644 --- a/arrow-avro/src/reader/block.rs +++ b/arrow-avro/src/reader/block.rs @@ -86,7 +86,6 @@ impl BlockDecoder { "Block count cannot be negative, got {c}" )) })?; - self.state = BlockDecoderState::Size; } } @@ -114,15 +113,18 @@ impl BlockDecoder { } BlockDecoderState::Sync => { let to_decode = buf.len().min(self.bytes_remaining); - let write = &mut self.in_progress.sync[16 - to_decode..]; - write[..to_decode].copy_from_slice(&buf[..to_decode]); + let start = 16 - self.bytes_remaining; + let end = start + to_decode; + self.in_progress.sync[start..end].copy_from_slice(&buf[..to_decode]); self.bytes_remaining -= to_decode; buf = &buf[to_decode..]; if self.bytes_remaining == 0 { self.state = BlockDecoderState::Finished; } } - BlockDecoderState::Finished => return Ok(max_read - buf.len()), + BlockDecoderState::Finished => { + return Ok(max_read - buf.len()); + } } } Ok(max_read) @@ -139,3 +141,217 @@ impl BlockDecoder { } } } + +#[cfg(test)] +mod tests { + use super::*; + use arrow_schema::ArrowError; + use std::convert::TryFrom; + + fn encode_vlq(value: i64) -> Vec { + let mut buf = vec![]; + let mut ux = ((value << 1) ^ (value >> 63)) as u64; // ZigZag + + loop { + let mut byte = (ux & 0x7F) as u8; + ux >>= 7; + if ux != 0 { + byte |= 0x80; + } + buf.push(byte); + if ux == 0 { + break; + } + } + buf + } + + #[test] + fn test_empty_input() { + let mut decoder = BlockDecoder::default(); + let buf = []; + let read = decoder.decode(&buf).unwrap(); + assert_eq!(read, 0); + assert!(decoder.flush().is_none()); + } + + #[test] + fn test_single_block_full_buffer() { + let mut decoder = BlockDecoder::default(); + let count_encoded = encode_vlq(10); + let size_encoded = encode_vlq(4); + let data = vec![1u8, 2, 3, 4]; + let sync_marker = vec![0xAB; 16]; + let mut input = Vec::new(); + input.extend_from_slice(&count_encoded); + input.extend_from_slice(&size_encoded); + input.extend_from_slice(&data); + input.extend_from_slice(&sync_marker); + let read = decoder.decode(&input).unwrap(); + assert_eq!(read, input.len()); + let block = decoder.flush().expect("Should produce a finished block"); + assert_eq!(block.count, 10); + assert_eq!(block.data, data); + let expected_sync: [u8; 16] = <[u8; 16]>::try_from(&sync_marker[..16]).unwrap(); + assert_eq!(block.sync, expected_sync); + } + + #[test] + fn test_single_block_partial_buffer() { + let mut decoder = BlockDecoder::default(); + let count_encoded = encode_vlq(2); + let size_encoded = encode_vlq(3); + let data = vec![10u8, 20, 30]; + let sync_marker = vec![0xCD; 16]; + let mut input = Vec::new(); + input.extend_from_slice(&count_encoded); + input.extend_from_slice(&size_encoded); + input.extend_from_slice(&data); + input.extend_from_slice(&sync_marker); + // Split into 3 parts + let part1 = &input[0..1]; + let part2 = &input[1..2]; + let part3 = &input[2..]; + let read = decoder.decode(part1).unwrap(); + assert_eq!(read, 1); + assert!(decoder.flush().is_none()); + let read = decoder.decode(part2).unwrap(); + assert_eq!(read, 1); + assert!(decoder.flush().is_none()); + let read = decoder.decode(part3).unwrap(); + 
assert_eq!(read, part3.len()); + let block = decoder.flush().expect("Should produce a finished block"); + assert_eq!(block.count, 2); + assert_eq!(block.data, data); + let expected_sync: [u8; 16] = <[u8; 16]>::try_from(&sync_marker[..16]).unwrap(); + assert_eq!(block.sync, expected_sync); + } + + #[test] + fn test_multiple_blocks_in_one_buffer() { + let mut decoder = BlockDecoder::default(); + // Block1 + let block1_count = encode_vlq(1); + let block1_size = encode_vlq(2); + let block1_data = vec![0x01, 0x02]; + let block1_sync = vec![0xAA; 16]; + // Block2 + let block2_count = encode_vlq(3); + let block2_size = encode_vlq(1); + let block2_data = vec![0x99]; + let block2_sync = vec![0xBB; 16]; + let mut input = Vec::new(); + input.extend_from_slice(&block1_count); + input.extend_from_slice(&block1_size); + input.extend_from_slice(&block1_data); + input.extend_from_slice(&block1_sync); + input.extend_from_slice(&block2_count); + input.extend_from_slice(&block2_size); + input.extend_from_slice(&block2_data); + input.extend_from_slice(&block2_sync); + let read1 = decoder.decode(&input).unwrap(); + let block1 = decoder.flush().expect("First block should be complete"); + assert_eq!(block1.count, 1); + assert_eq!(block1.data, block1_data); + let expected_sync1: [u8; 16] = <[u8; 16]>::try_from(&block1_sync[..16]).unwrap(); + assert_eq!(block1.sync, expected_sync1); + let remainder = &input[read1..]; + decoder.decode(remainder).unwrap(); + let block2 = decoder.flush().expect("Second block should be complete"); + assert_eq!(block2.count, 3); + assert_eq!(block2.data, block2_data); + let expected_sync2: [u8; 16] = <[u8; 16]>::try_from(&block2_sync[..16]).unwrap(); + assert_eq!(block2.sync, expected_sync2); + } + + #[test] + fn test_negative_count_should_error() { + let mut decoder = BlockDecoder::default(); + let bad_count = encode_vlq(-1); + let size = encode_vlq(5); + let mut input = Vec::new(); + input.extend_from_slice(&bad_count); + input.extend_from_slice(&size); + let err = decoder.decode(&input).unwrap_err(); + match err { + ArrowError::ParseError(msg) => { + assert!( + msg.contains("Block count cannot be negative"), + "Expected negative count parse error, got: {msg}" + ); + } + _ => panic!("Unexpected error type: {err:?}"), + } + } + + #[test] + fn test_negative_size_should_error() { + let mut decoder = BlockDecoder::default(); + let count = encode_vlq(5); + let bad_size = encode_vlq(-10); + let mut input = Vec::new(); + input.extend_from_slice(&count); + input.extend_from_slice(&bad_size); + let err = decoder.decode(&input).unwrap_err(); + match err { + ArrowError::ParseError(msg) => { + assert!( + msg.contains("Block size cannot be negative"), + "Expected negative size parse error, got: {msg}" + ); + } + _ => panic!("Unexpected error type: {err:?}"), + } + } + + #[test] + fn test_partial_sync_across_multiple_calls() { + let mut decoder = BlockDecoder::default(); + let count_encoded = encode_vlq(1); + let size_encoded = encode_vlq(2); + let data = vec![0x01, 0x02]; + let sync_marker = vec![0xCC; 16]; + let mut input = Vec::new(); + input.extend_from_slice(&count_encoded); + input.extend_from_slice(&size_encoded); + input.extend_from_slice(&data); + input.extend_from_slice(&sync_marker); + let split_point = input.len() - 4; + let part1 = &input[..split_point]; + let part2 = &input[split_point..]; + let read1 = decoder.decode(part1).unwrap(); + assert_eq!(read1, part1.len()); + assert!(decoder.flush().is_none()); + let read2 = decoder.decode(part2).unwrap(); + assert_eq!(read2, part2.len()); 
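+ // Both halves have been consumed; the partially received sync marker was buffered
+ // internally across the two `decode` calls, so `flush` should now yield the finished block.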
+ let block = decoder.flush().expect("Block should be complete now"); + assert_eq!(block.count, 1); + assert_eq!(block.data, data); + let expected_sync: [u8; 16] = <[u8; 16]>::try_from(&sync_marker[..16]).unwrap(); + assert_eq!(block.sync, expected_sync, "Should match [0xCC; 16]"); + } + + #[test] + fn test_already_finished_state() { + let mut decoder = BlockDecoder::default(); + let count_encoded = encode_vlq(2); + let size_encoded = encode_vlq(1); + let data = vec![0xAB]; + let sync_marker = vec![0xFF; 16]; + let mut input = Vec::new(); + input.extend_from_slice(&count_encoded); + input.extend_from_slice(&size_encoded); + input.extend_from_slice(&data); + input.extend_from_slice(&sync_marker); + let read = decoder.decode(&input).unwrap(); + assert_eq!(read, input.len()); + let block = decoder.flush().expect("Should have a block"); + assert_eq!(block.count, 2); + assert_eq!(block.data, data); + let expected_sync: [u8; 16] = <[u8; 16]>::try_from(&sync_marker[..16]).unwrap(); + assert_eq!(block.sync, expected_sync); + let read2 = decoder.decode(&[]).unwrap(); + assert_eq!(read2, 0); + assert!(decoder.flush().is_none()); + } +} diff --git a/arrow-avro/src/reader/cursor.rs b/arrow-avro/src/reader/cursor.rs index 4b6a5a4d65db..5eab86b04697 100644 --- a/arrow-avro/src/reader/cursor.rs +++ b/arrow-avro/src/reader/cursor.rs @@ -65,6 +65,7 @@ impl<'a> AvroCursor<'a> { Ok(val) } + /// Decode a zig-zag encoded Avro int (32-bit). #[inline] pub(crate) fn get_int(&mut self) -> Result { let varint = self.read_vlq()?; @@ -74,18 +75,20 @@ impl<'a> AvroCursor<'a> { Ok((val >> 1) as i32 ^ -((val & 1) as i32)) } + /// Decode a zig-zag encoded Avro long (64-bit). #[inline] pub(crate) fn get_long(&mut self) -> Result { let val = self.read_vlq()?; Ok((val >> 1) as i64 ^ -((val & 1) as i64)) } + /// Read a variable-length byte array from Avro (where the length is stored as an Avro long). pub(crate) fn get_bytes(&mut self) -> Result<&'a [u8], ArrowError> { let len: usize = self.get_long()?.try_into().map_err(|_| { ArrowError::ParseError("offset overflow reading avro bytes".to_string()) })?; - if (self.buf.len() < len) { + if self.buf.len() < len { return Err(ArrowError::ParseError( "Unexpected EOF reading bytes".to_string(), )); @@ -95,9 +98,10 @@ impl<'a> AvroCursor<'a> { Ok(ret) } + /// Read a little-endian 32-bit float #[inline] pub(crate) fn get_float(&mut self) -> Result { - if (self.buf.len() < 4) { + if self.buf.len() < 4 { return Err(ArrowError::ParseError( "Unexpected EOF reading float".to_string(), )); @@ -107,15 +111,225 @@ impl<'a> AvroCursor<'a> { Ok(ret) } + /// Read a little-endian 64-bit float #[inline] pub(crate) fn get_double(&mut self) -> Result { - if (self.buf.len() < 8) { + if self.buf.len() < 8 { return Err(ArrowError::ParseError( - "Unexpected EOF reading float".to_string(), + "Unexpected EOF reading double".to_string(), )); } let ret = f64::from_le_bytes(self.buf[..8].try_into().unwrap()); self.buf = &self.buf[8..]; Ok(ret) } + + /// Read exactly `n` bytes from the buffer (e.g. for Avro `fixed`). 
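+ /// Returns a `ParseError` if fewer than `n` bytes remain in the buffer.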
+ pub(crate) fn get_fixed(&mut self, n: usize) -> Result<&'a [u8], ArrowError> { + if self.buf.len() < n { + return Err(ArrowError::ParseError( + "Unexpected EOF reading fixed".to_string(), + )); + } + let ret = &self.buf[..n]; + self.buf = &self.buf[n..]; + Ok(ret) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_schema::ArrowError; + + #[test] + fn test_new_and_position() { + let data = [1, 2, 3, 4]; + let cursor = AvroCursor::new(&data); + assert_eq!(cursor.position(), 0); + } + + #[test] + fn test_get_u8_ok() { + let data = [0x12, 0x34, 0x56]; + let mut cursor = AvroCursor::new(&data); + assert_eq!(cursor.get_u8().unwrap(), 0x12); + assert_eq!(cursor.position(), 1); + assert_eq!(cursor.get_u8().unwrap(), 0x34); + assert_eq!(cursor.position(), 2); + assert_eq!(cursor.get_u8().unwrap(), 0x56); + assert_eq!(cursor.position(), 3); + } + + #[test] + fn test_get_u8_eof() { + let data = []; + let mut cursor = AvroCursor::new(&data); + let result = cursor.get_u8(); + assert!( + matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("Unexpected EOF")) + ); + } + + #[test] + fn test_get_bool_ok() { + let data = [0x00, 0x01, 0xFF]; + let mut cursor = AvroCursor::new(&data); + assert!(!cursor.get_bool().unwrap()); // 0x00 -> false + assert!(cursor.get_bool().unwrap()); // 0x01 -> true + assert!(cursor.get_bool().unwrap()); // 0xFF -> true (non-zero) + } + + #[test] + fn test_get_bool_eof() { + let data = []; + let mut cursor = AvroCursor::new(&data); + let result = cursor.get_bool(); + assert!( + matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("Unexpected EOF")) + ); + } + + #[test] + fn test_read_vlq_ok() { + let data = [0x80, 0x01, 0x05]; + let mut cursor = AvroCursor::new(&data); + let val1 = cursor.read_vlq().unwrap(); + assert_eq!(val1, 128); + let val2 = cursor.read_vlq().unwrap(); + assert_eq!(val2, 5); + } + + #[test] + fn test_read_vlq_bad_varint() { + let data = [0x80]; + let mut cursor = AvroCursor::new(&data); + let result = cursor.read_vlq(); + assert!(matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("bad varint"))); + } + + #[test] + fn test_get_int_ok() { + let data = [0x04, 0x03]; // encodes +2, -2 + let mut cursor = AvroCursor::new(&data); + assert_eq!(cursor.get_int().unwrap(), 2); + assert_eq!(cursor.get_int().unwrap(), -2); + } + + #[test] + fn test_get_int_overflow() { + let data = [0x80, 0x80, 0x80, 0x80, 0x10]; + let mut cursor = AvroCursor::new(&data); + let result = cursor.get_int(); + assert!( + matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("varint overflow")) + ); + } + + #[test] + fn test_get_long_ok() { + let data = [0x04, 0x03, 0xAC, 0x02]; + let mut cursor = AvroCursor::new(&data); + assert_eq!(cursor.get_long().unwrap(), 2); + assert_eq!(cursor.get_long().unwrap(), -2); + assert_eq!(cursor.get_long().unwrap(), 150); + } + + #[test] + fn test_get_long_eof() { + let data = [0x80]; + let mut cursor = AvroCursor::new(&data); + let result = cursor.get_long(); + assert!(matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("bad varint"))); + } + + #[test] + fn test_get_bytes_ok() { + let data = [0x06, 0xAA, 0xBB, 0xCC, 0x05, 0x01]; + let mut cursor = AvroCursor::new(&data); + let bytes = cursor.get_bytes().unwrap(); + assert_eq!(bytes, [0xAA, 0xBB, 0xCC]); + assert_eq!(cursor.position(), 4); + } + + #[test] + fn test_get_bytes_overflow() { + let data = [0xAC, 0x02, 0x01, 0x02, 0x03]; + let mut cursor = AvroCursor::new(&data); + let result = cursor.get_bytes(); + assert!( + 
matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("Unexpected EOF reading bytes")) + ); + } + + #[test] + fn test_get_bytes_negative_length() { + let data = [0x01]; + let mut cursor = AvroCursor::new(&data); + let result = cursor.get_bytes(); + assert!( + matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("offset overflow")) + ); + } + + #[test] + fn test_get_float_ok() { + let data = [0x00, 0x00, 0x80, 0x3F, 0x01]; + let mut cursor = AvroCursor::new(&data); + let val = cursor.get_float().unwrap(); + assert!((val - 1.0).abs() < f32::EPSILON); + assert_eq!(cursor.position(), 4); + } + + #[test] + fn test_get_float_eof() { + let data = [0x00, 0x00, 0x80]; + let mut cursor = AvroCursor::new(&data); + let result = cursor.get_float(); + assert!( + matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("Unexpected EOF reading float")) + ); + } + + #[test] + fn test_get_double_ok() { + let data = [0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xF0, 0x3F, 0x99]; + let mut cursor = AvroCursor::new(&data); + let val = cursor.get_double().unwrap(); + assert!((val - 1.0).abs() < f64::EPSILON); + assert_eq!(cursor.position(), 8); + } + + #[test] + fn test_get_double_eof() { + let data = [0x00, 0x00, 0x00, 0x00]; // only 4 bytes + let mut cursor = AvroCursor::new(&data); + let result = cursor.get_double(); + assert!( + matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("Unexpected EOF reading double")) + ); + } + + #[test] + fn test_get_fixed_ok() { + let data = [0x11, 0x22, 0x33, 0x44]; + let mut cursor = AvroCursor::new(&data); + let val = cursor.get_fixed(2).unwrap(); + assert_eq!(val, [0x11, 0x22]); + assert_eq!(cursor.position(), 2); + + let val = cursor.get_fixed(2).unwrap(); + assert_eq!(val, [0x33, 0x44]); + assert_eq!(cursor.position(), 4); + } + + #[test] + fn test_get_fixed_eof() { + let data = [0x11, 0x22]; + let mut cursor = AvroCursor::new(&data); + let result = cursor.get_fixed(3); + assert!( + matches!(result, Err(ArrowError::ParseError(msg)) if msg.contains("Unexpected EOF reading fixed")) + ); + } } diff --git a/arrow-avro/src/reader/header.rs b/arrow-avro/src/reader/header.rs index 98c285171bf3..93f4617ef32d 100644 --- a/arrow-avro/src/reader/header.rs +++ b/arrow-avro/src/reader/header.rs @@ -74,17 +74,18 @@ impl Header { self.sync } - /// Returns the [`CompressionCodec`] if any + /// Returns the [`CompressionCodec`] if any. 
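+ /// Recognized values of the `avro.codec` metadata entry are `null`, `deflate`, `snappy`,
+ /// `zstandard`, `bzip2`, and `xz`; any other value is reported as a `ParseError`.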
pub fn compression(&self) -> Result, ArrowError> { let v = self.get(CODEC_METADATA_KEY); - match v { None | Some(b"null") => Ok(None), Some(b"deflate") => Ok(Some(CompressionCodec::Deflate)), Some(b"snappy") => Ok(Some(CompressionCodec::Snappy)), Some(b"zstandard") => Ok(Some(CompressionCodec::ZStandard)), + Some(b"bzip2") => Ok(Some(CompressionCodec::Bzip2)), + Some(b"xz") => Ok(Some(CompressionCodec::Xz)), Some(v) => Err(ArrowError::ParseError(format!( - "Unrecognized compression codec \'{}\'", + "Unrecognized compression codec '{}'", String::from_utf8_lossy(v) ))), } @@ -147,8 +148,6 @@ impl HeaderDecoder { /// This method can be called multiple times with consecutive chunks of data, allowing /// integration with chunked IO systems like [`BufRead::fill_buf`] /// - /// All errors should be considered fatal, and decoding aborted - /// /// Once the entire [`Header`] has been decoded this method will not read any further /// input bytes, and the header can be obtained with [`Self::flush`] /// @@ -209,7 +208,6 @@ impl HeaderDecoder { buf = &buf[to_read..]; if self.bytes_remaining == 0 { self.meta_offsets.push(self.meta_buf.len()); - self.tuples_remaining -= 1; match self.tuples_remaining { 0 => self.state = HeaderDecoderState::BlockCount, @@ -264,13 +262,13 @@ impl HeaderDecoder { #[cfg(test)] mod test { use super::*; - use crate::codec::{AvroDataType, AvroField}; + use crate::codec::AvroField; use crate::reader::read_header; use crate::schema::SCHEMA_METADATA_KEY; use crate::test_util::arrow_test_data; use arrow_schema::{DataType, Field, Fields, TimeUnit}; use std::fs::File; - use std::io::{BufRead, BufReader}; + use std::io::BufReader; #[test] fn test_header_decode() { @@ -353,4 +351,35 @@ mod test { 325166208089902833952788552656412487328 ); } + #[test] + fn test_header_schema_default() { + let json_schema = r#" + { + "type": "record", + "name": "TestRecord", + "fields": [ + {"name": "a", "type": "int", "default": 10} + ] + } + "#; + let key = "avro.schema"; + let key_bytes = key.as_bytes(); + let value_bytes = json_schema.as_bytes(); + let mut meta_buf = Vec::new(); + meta_buf.extend_from_slice(key_bytes); + meta_buf.extend_from_slice(value_bytes); + let meta_offsets = vec![key_bytes.len(), key_bytes.len() + value_bytes.len()]; + let header = Header { + meta_offsets, + meta_buf, + sync: [0; 16], + }; + let schema = header.schema().unwrap().unwrap(); + if let Schema::Complex(crate::schema::ComplexType::Record(record)) = schema { + assert_eq!(record.fields.len(), 1); + assert_eq!(record.fields[0].default, Some(serde_json::json!(10))); + } else { + panic!("Expected record schema"); + } + } } diff --git a/arrow-avro/src/reader/mod.rs b/arrow-avro/src/reader/mod.rs index 12fa67d9c8e3..4f5c0a8f174d 100644 --- a/arrow-avro/src/reader/mod.rs +++ b/arrow-avro/src/reader/mod.rs @@ -15,22 +15,132 @@ // specific language governing permissions and limitations // under the License. -//! Read Avro data to Arrow +//! Avro reader +//! +//! This module provides facilities to read Apache Avro-encoded files or streams +//! into Arrow's [`RecordBatch`] format. In particular, it introduces: +//! +//! * [`ReaderBuilder`]: Configures Avro reading, e.g., batch size +//! * [`Reader`]: Yields [`RecordBatch`] values, implementing [`Iterator`] +//! * [`Decoder`]: A low-level push-based decoder for Avro records +//! +//! # Basic Usage +//! +//! [`Reader`] can be used directly with synchronous data sources, such as [`std::fs::File`]. +//! +//! ## Reading a Single Batch +//! +//! ``` +//! # use std::fs::File; +//! 
# use std::io::BufReader; +//! +//! let file = File::open("test/data/simple_enum.avro").unwrap(); +//! let mut avro = arrow_avro::ReaderBuilder::new().build(BufReader::new(file)).unwrap(); +//! let batch = avro.next().unwrap().unwrap(); +//! ``` +//! +//! # Async Usage +//! +//! The lower-level [`Decoder`] can be integrated with various forms of async data streams, +//! and is designed to be agnostic to different async IO primitives within +//! the Rust ecosystem. It works by incrementally decoding Avro data from byte slices. +//! +//! For example, see below for how it could be used with an arbitrary `Stream` of `Bytes`: +//! +//! ``` +//! # use std::task::{Poll, ready}; +//! # use bytes::{Buf, Bytes}; +//! # use arrow_schema::ArrowError; +//! # use futures::stream::{Stream, StreamExt}; +//! # use arrow_array::RecordBatch; +//! # use arrow_avro::reader::Decoder; +//! # +//! fn decode_stream + Unpin>( +//! mut decoder: Decoder, +//! mut input: S, +//! ) -> impl Stream> { +//! let mut buffered = Bytes::new(); +//! futures::stream::poll_fn(move |cx| { +//! loop { +//! if buffered.is_empty() { +//! buffered = match ready!(input.poll_next_unpin(cx)) { +//! Some(b) => b, +//! None => break, +//! }; +//! } +//! let decoded = match decoder.decode(buffered.as_ref()) { +//! Ok(decoded) => decoded, +//! Err(e) => return Poll::Ready(Some(Err(e))), +//! }; +//! let read = buffered.len(); +//! buffered.advance(decoded); +//! if decoded != read { +//! break +//! } +//! } +//! // Convert any fully-decoded rows to a RecordBatch, if available +//! Poll::Ready(decoder.flush().transpose()) +//! }) +//! } +//! ``` +//! +//! In a similar vein, it can also be used with tokio-based IO primitives +//! +//! ``` +//! # use std::sync::Arc; +//! # use arrow_schema::{DataType, Field, Schema}; +//! # use std::pin::Pin; +//! # use std::task::{Poll, ready}; +//! # use futures::{Stream, TryStreamExt}; +//! # use tokio::io::AsyncBufRead; +//! # use arrow_array::RecordBatch; +//! # use arrow_avro::reader::Decoder; +//! # use arrow_schema::ArrowError; +//! fn decode_stream( +//! mut decoder: Decoder, +//! mut reader: R, +//! ) -> impl Stream> { +//! futures::stream::poll_fn(move |cx| { +//! loop { +//! let b = match ready!(Pin::new(&mut reader).poll_fill_buf(cx)) { +//! Ok(b) if b.is_empty() => break, +//! Ok(b) => b, +//! Err(e) => return Poll::Ready(Some(Err(e.into()))), +//! }; +//! let read = b.len(); +//! let decoded = match decoder.decode(b) { +//! Ok(decoded) => decoded, +//! Err(e) => return Poll::Ready(Some(Err(e))), +//! }; +//! Pin::new(&mut reader).consume(decoded); +//! if decoded != read { +//! break; +//! } +//! } +//! +//! Poll::Ready(decoder.flush().transpose()) +//! }) +//! } +//! ``` +//! -use crate::reader::block::{Block, BlockDecoder}; -use crate::reader::header::{Header, HeaderDecoder}; -use arrow_schema::ArrowError; +use arrow_array::{RecordBatch, RecordBatchReader}; +use arrow_schema::{ArrowError, SchemaRef}; use std::io::BufRead; -mod header; - mod block; - mod cursor; +mod header; mod record; mod vlq; -/// Read a [`Header`] from the provided [`BufRead`] +use crate::codec::AvroField; +use crate::schema::Schema as AvroSchema; +use block::BlockDecoder; +use header::{Header, HeaderDecoder}; +use record::RecordDecoder; + +/// Read the Avro file header (magic, metadata, sync marker) from `reader`. 
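+/// Repeatedly fills the read buffer and feeds it to a [`HeaderDecoder`] until the decoder
+/// stops consuming bytes, then flushes the finished header, erroring if the input is truncated.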
fn read_header(mut reader: R) -> Result { let mut decoder = HeaderDecoder::default(); loop { @@ -45,75 +155,409 @@ fn read_header(mut reader: R) -> Result { break; } } + decoder.flush().ok_or_else(|| { + ArrowError::ParseError("Unexpected EOF while reading Avro header".to_string()) + }) +} - decoder - .flush() - .ok_or_else(|| ArrowError::ParseError("Unexpected EOF".to_string())) +/// A low-level interface for decoding Avro-encoded bytes into Arrow [`RecordBatch`]. +#[derive(Debug)] +pub struct Decoder { + record_decoder: RecordDecoder, + batch_size: usize, + decoded_rows: usize, } -/// Return an iterator of [`Block`] from the provided [`BufRead`] -fn read_blocks(mut reader: R) -> impl Iterator> { - let mut decoder = BlockDecoder::default(); +impl Decoder { + /// Create a new [`Decoder`], wrapping an existing [`RecordDecoder`]. + pub fn new(record_decoder: RecordDecoder, batch_size: usize) -> Self { + Self { + record_decoder, + batch_size, + decoded_rows: 0, + } + } - let mut try_next = move || { - loop { - let buf = reader.fill_buf()?; - if buf.is_empty() { + /// Return the Arrow schema for the rows decoded by this decoder + pub fn schema(&self) -> SchemaRef { + self.record_decoder.schema().clone() + } + + /// Return the configured maximum number of rows per batch + pub fn batch_size(&self) -> usize { + self.batch_size + } + + /// Feed `data` into the decoder row by row until we either: + /// - consume all bytes in `data`, or + /// - reach `batch_size` decoded rows. + /// + /// Returns the number of bytes consumed. + pub fn decode(&mut self, data: &[u8]) -> Result { + let mut total_consumed = 0usize; + while total_consumed < data.len() && self.decoded_rows < self.batch_size { + let consumed = self.record_decoder.decode(&data[total_consumed..], 1)?; + if consumed == 0 { break; } - let read = buf.len(); - let decoded = decoder.decode(buf)?; - reader.consume(decoded); - if decoded != read { - break; + total_consumed += consumed; + self.decoded_rows += 1; + } + Ok(total_consumed) + } + + /// Produce a [`RecordBatch`] if at least one row is fully decoded, returning + /// `Ok(None)` if no new rows are available. + pub fn flush(&mut self) -> Result, ArrowError> { + if self.decoded_rows == 0 { + Ok(None) + } else { + let batch = self.record_decoder.flush()?; + self.decoded_rows = 0; + Ok(Some(batch)) + } + } +} + +/// A builder to create an [`Avro Reader`](Reader) that reads Avro data +/// into Arrow [`RecordBatch`]. +#[derive(Debug)] +pub struct ReaderBuilder { + batch_size: usize, + strict_mode: bool, +} + +impl Default for ReaderBuilder { + fn default() -> Self { + Self { + batch_size: 1024, + strict_mode: false, + } + } +} + +impl ReaderBuilder { + /// Creates a new [`ReaderBuilder`] with default settings: + /// - `batch_size` = 1024 + /// - `strict_mode` = false + pub fn new() -> Self { + Self::default() + } + + /// Sets the row-based batch size + pub fn with_batch_size(mut self, batch_size: usize) -> Self { + self.batch_size = batch_size; + self + } + + /// Controls whether certain Avro unions of the form `[T, "null"]` should produce an error. 
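+ /// The flag is forwarded to the internal `RecordDecoder` when `build` or `build_decoder`
+ /// constructs the row decoder.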
+ pub fn with_strict_mode(mut self, strict_mode: bool) -> Self { + self.strict_mode = strict_mode; + self + } + + /// Create a [`Reader`] from this builder and a `BufRead` + pub fn build(self, mut reader: R) -> Result, ArrowError> { + let header = read_header(&mut reader)?; + let compression = header.compression()?; + let avro_schema: Option> = header + .schema() + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + let avro_schema = avro_schema.ok_or_else(|| { + ArrowError::ParseError("No Avro schema present in file header".to_string()) + })?; + let root_field = AvroField::try_from(&avro_schema)?; + let record_decoder = RecordDecoder::try_new(root_field.data_type(), self.strict_mode)?; + let decoder = Decoder::new(record_decoder, self.batch_size); + Ok(Reader { + reader, + header, + compression, + decoder, + block_decoder: BlockDecoder::default(), + block_data: Vec::new(), + finished: false, + }) + } + + /// Create a [`Decoder`] from this builder and a `BufRead` by + /// reading and parsing the Avro file's header. This will + /// not create a full [`Reader`]. + pub fn build_decoder(self, mut reader: R) -> Result { + let header = read_header(&mut reader)?; + let avro_schema: Option> = header + .schema() + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + + let avro_schema = avro_schema.ok_or_else(|| { + ArrowError::ParseError("No Avro schema present in file header".to_string()) + })?; + let root_field = AvroField::try_from(&avro_schema)?; + let record_decoder = RecordDecoder::try_new(root_field.data_type(), self.strict_mode)?; + Ok(Decoder::new(record_decoder, self.batch_size)) + } +} + +/// A high-level Avro `Reader` that reads container-file blocks +/// and feeds them into a row-level [`Decoder`]. +#[derive(Debug)] +pub struct Reader { + reader: R, + header: Header, + compression: Option, + decoder: Decoder, + block_decoder: BlockDecoder, + block_data: Vec, + finished: bool, +} + +impl Reader { + /// Return the Arrow schema discovered from the Avro file header + pub fn schema(&self) -> SchemaRef { + self.decoder.schema() + } + + /// Return the Avro container-file header + pub fn avro_header(&self) -> &Header { + &self.header + } +} + +impl Reader { + /// Reads the next [`RecordBatch`] from the Avro file or `Ok(None)` on EOF + fn read(&mut self) -> Result, ArrowError> { + if self.finished { + return Ok(None); + } + loop { + if !self.block_data.is_empty() { + let consumed = self.decoder.decode(&self.block_data)?; + if consumed > 0 { + self.block_data.drain(..consumed); + } + match self.decoder.flush()? { + None => { + if !self.block_data.is_empty() { + break; + } + } + Some(batch) => { + return Ok(Some(batch)); + } + } + } + let maybe_block = { + let buf = self.reader.fill_buf()?; + if buf.is_empty() { + None + } else { + let read_len = buf.len(); + let consumed_len = self.block_decoder.decode(buf)?; + self.reader.consume(consumed_len); + if consumed_len == 0 && read_len != 0 { + return Err(ArrowError::ParseError( + "Could not decode next Avro block from partial data".to_string(), + )); + } + self.block_decoder.flush() + } + }; + match maybe_block { + Some(block) => { + let block_data = if let Some(ref codec) = self.compression { + codec.decompress(&block.data)? 
+ } else { + block.data + }; + self.block_data = block_data; + } + None => { + self.finished = true; + if !self.block_data.is_empty() { + let consumed = self.decoder.decode(&self.block_data)?; + self.block_data.drain(..consumed); + } + return self.decoder.flush(); + } } } - Ok(decoder.flush()) - }; - std::iter::from_fn(move || try_next().transpose()) + self.decoder.flush() + } +} + +impl Iterator for Reader { + type Item = Result; + + fn next(&mut self) -> Option { + match self.read() { + Ok(Some(batch)) => Some(Ok(batch)), + Ok(None) => None, + Err(e) => Some(Err(e)), + } + } +} + +impl RecordBatchReader for Reader { + fn schema(&self) -> SchemaRef { + self.schema() + } } #[cfg(test)] mod test { - use crate::codec::AvroField; - use crate::compression::CompressionCodec; - use crate::reader::record::RecordDecoder; - use crate::reader::{read_blocks, read_header}; + use super::*; + use crate::reader::vlq::VLQDecoder; use crate::test_util::arrow_test_data; - use arrow_array::*; + use arrow_array::builder::{ + ArrayBuilder, BooleanBuilder, Float32Builder, Float64Builder, Int32Builder, Int64Builder, + ListBuilder, MapBuilder, MapFieldNames, StringBuilder, StructBuilder, + }; + use arrow_array::types::Int32Type; + use arrow_array::{ + Array, BinaryArray, BooleanArray, Decimal128Array, DictionaryArray, FixedSizeBinaryArray, + Float32Array, Float64Array, Int32Array, Int64Array, ListArray, RecordBatch, StringArray, + StructArray, TimestampMicrosecondArray, + }; + use arrow_buffer::{Buffer, NullBuffer, OffsetBuffer, ScalarBuffer}; + use arrow_data::ArrayDataBuilder; + use arrow_schema::{DataType, Field, Fields, Schema}; + use bytes::{Buf, Bytes}; + use futures::{stream, Stream, StreamExt, TryStreamExt}; + use std::collections::HashMap; + use std::fs; use std::fs::File; - use std::io::BufReader; + use std::io::{BufReader, Cursor}; use std::sync::Arc; + use std::task::{ready, Poll}; - fn read_file(file: &str, batch_size: usize) -> RecordBatch { - let file = File::open(file).unwrap(); - let mut reader = BufReader::new(file); - let header = read_header(&mut reader).unwrap(); - let compression = header.compression().unwrap(); - let schema = header.schema().unwrap().unwrap(); - let root = AvroField::try_from(&schema).unwrap(); - let mut decoder = RecordDecoder::try_new(root.data_type()).unwrap(); - - for result in read_blocks(reader) { - let block = result.unwrap(); - assert_eq!(block.sync, header.sync()); - if let Some(c) = compression { - let decompressed = c.decompress(&block.data).unwrap(); - - let mut offset = 0; - let mut remaining = block.count; - while remaining > 0 { - let to_read = remaining.max(batch_size); - offset += decoder - .decode(&decompressed[offset..], block.count) - .unwrap(); + fn read_file(path: &str, _schema: Option) -> super::Reader> { + let file = File::open(path).unwrap(); + let reader = BufReader::new(file); + let builder = ReaderBuilder::new().with_batch_size(64); + builder.build(reader).unwrap() + } - remaining -= to_read; + fn decode_stream + Unpin>( + mut decoder: Decoder, + mut input: S, + ) -> impl Stream> { + let mut buffered = Bytes::new(); + futures::stream::poll_fn(move |cx| { + loop { + if buffered.is_empty() { + buffered = match ready!(input.poll_next_unpin(cx)) { + Some(b) => b, + None => break, + }; + } + let decoded = match decoder.decode(buffered.as_ref()) { + Ok(decoded) => decoded, + Err(e) => return Poll::Ready(Some(Err(e))), + }; + let read = buffered.len(); + buffered.advance(decoded); + if decoded != read { + break; } - assert_eq!(offset, 
decompressed.len()); } + Poll::Ready(decoder.flush().transpose()) + }) + } + + #[test] + fn test_basic_usage_single_batch() { + let file = File::open(arrow_test_data("avro/simple_enum.avro")) + .expect("Failed to open test/data/simple_enum.avro"); + let mut avro = ReaderBuilder::new() + .build(BufReader::new(file)) + .expect("Failed to build Avro Reader"); + + let batch = avro + .next() + .expect("No batch found?") + .expect("Error reading batch"); + + assert!(batch.num_rows() > 0, "Expected at least 1 row"); + assert!(batch.num_columns() > 0, "Expected at least 1 column"); + } + + #[test] + fn test_reader_read() -> Result<(), ArrowError> { + let file_path = "test/data/simple_enum.avro"; + let file = File::open(file_path).expect("Failed to open Avro file"); + let mut reader_direct = ReaderBuilder::new() + .build(BufReader::new(file)) + .expect("Failed to build Reader"); + let mut direct_batches = Vec::new(); + while let Some(batch) = reader_direct.read()? { + direct_batches.push(batch); } - decoder.flush().unwrap() + let file = File::open(file_path).expect("Failed to open Avro file"); + let reader_iter = ReaderBuilder::new() + .build(BufReader::new(file)) + .expect("Failed to build Reader"); + let iter_batches: Result, _> = reader_iter.collect(); + let iter_batches = iter_batches?; + assert_eq!(direct_batches, iter_batches); + Ok(()) + } + + #[tokio::test] + async fn test_async_decoder_with_bytes_stream() -> Result<(), ArrowError> { + let path = arrow_test_data("avro/simple_enum.avro"); + let data = fs::read(&path).expect("Failed to read .avro file"); + let mut cursor = Cursor::new(&data); + let decoder: Decoder = ReaderBuilder::new().build_decoder(&mut cursor)?; + let header_consumed = cursor.position() as usize; + let mut remainder = &data[header_consumed..]; + let mut vlq_dec = VLQDecoder::default(); + let _block_count_i64 = vlq_dec + .long(&mut remainder) + .ok_or_else(|| ArrowError::ParseError("EOF reading block count".to_string()))?; + let block_size_i64 = vlq_dec + .long(&mut remainder) + .ok_or_else(|| ArrowError::ParseError("EOF reading block size".to_string()))?; + let block_size = block_size_i64 as usize; + if remainder.len() < block_size { + return Err(ArrowError::ParseError(format!( + "File truncated: Needed {} bytes for block data, got {}", + block_size, + remainder.len() + ))); + } + let block_data = &remainder[..block_size]; + remainder = &remainder[block_size..]; + if remainder.len() < 16 { + return Err(ArrowError::ParseError( + "Missing sync marker in Avro block".to_string(), + )); + } + let _sync_marker = &remainder[..16]; + let _remainder = &remainder[16..]; + let chunks = block_data + .chunks(16) + .map(Bytes::copy_from_slice) + .collect::>(); + let input_stream = stream::iter(chunks); + let record_batch_stream = decode_stream(decoder, input_stream); + let batches: Vec<_> = record_batch_stream.try_collect().await?; + assert!( + !batches.is_empty(), + "Should decode at least one batch from the block" + ); + let file = File::open(&path).unwrap(); + let mut sync_reader = ReaderBuilder::new() + .build(BufReader::new(file)) + .expect("Could not build sync_reader"); + let expected_batch = sync_reader + .next() + .expect("No batch in file") + .expect("Sync decode failed"); + assert_eq!( + batches[0], expected_batch, + "Async decode differs from sync decode" + ); + Ok(()) } #[test] @@ -122,6 +566,8 @@ mod test { "avro/alltypes_plain.avro", "avro/alltypes_plain.snappy.avro", "avro/alltypes_plain.zstandard.avro", + "avro/alltypes_plain.bzip2.avro", + 
"avro/alltypes_plain.xz.avro", ]; let expected = RecordBatch::try_from_iter_with_nullable([ @@ -172,14 +618,14 @@ mod test { ( "date_string_col", Arc::new(BinaryArray::from_iter_values([ - [48, 51, 47, 48, 49, 47, 48, 57], - [48, 51, 47, 48, 49, 47, 48, 57], - [48, 52, 47, 48, 49, 47, 48, 57], - [48, 52, 47, 48, 49, 47, 48, 57], - [48, 50, 47, 48, 49, 47, 48, 57], - [48, 50, 47, 48, 49, 47, 48, 57], - [48, 49, 47, 48, 49, 47, 48, 57], - [48, 49, 47, 48, 49, 47, 48, 57], + b"03/01/09", + b"03/01/09", + b"04/01/09", + b"04/01/09", + b"02/01/09", + b"02/01/09", + b"01/01/09", + b"01/01/09", ])) as _, true, ), @@ -207,12 +653,1155 @@ mod test { ), ]) .unwrap(); - for file in files { let file = arrow_test_data(file); + let mut reader = read_file(&file, None); + let batch_large = reader.next().unwrap().unwrap(); + assert_eq!(batch_large, expected); + let mut reader_small = read_file(&file, None); + let batch_small = reader_small.next().unwrap().unwrap(); + assert_eq!(batch_small, expected); + } + } - assert_eq!(read_file(&file, 8), expected); - assert_eq!(read_file(&file, 3), expected); + #[test] + fn test_alltypes_dictionary() { + let file = "avro/alltypes_dictionary.avro"; + let expected = RecordBatch::try_from_iter_with_nullable([ + ("id", Arc::new(Int32Array::from(vec![0, 1])) as _, true), + ( + "bool_col", + Arc::new(BooleanArray::from(vec![Some(true), Some(false)])) as _, + true, + ), + ( + "tinyint_col", + Arc::new(Int32Array::from(vec![0, 1])) as _, + true, + ), + ( + "smallint_col", + Arc::new(Int32Array::from(vec![0, 1])) as _, + true, + ), + ("int_col", Arc::new(Int32Array::from(vec![0, 1])) as _, true), + ( + "bigint_col", + Arc::new(Int64Array::from(vec![0, 10])) as _, + true, + ), + ( + "float_col", + Arc::new(Float32Array::from(vec![0.0, 1.1])) as _, + true, + ), + ( + "double_col", + Arc::new(Float64Array::from(vec![0.0, 10.1])) as _, + true, + ), + ( + "date_string_col", + Arc::new(BinaryArray::from_iter_values([b"01/01/09", b"01/01/09"])) as _, + true, + ), + ( + "string_col", + Arc::new(BinaryArray::from_iter_values([b"0", b"1"])) as _, + true, + ), + ( + "timestamp_col", + Arc::new( + TimestampMicrosecondArray::from_iter_values([ + 1230768000000000, // 2009-01-01T00:00:00.000 + 1230768060000000, // 2009-01-01T00:01:00.000 + ]) + .with_timezone("+00:00"), + ) as _, + true, + ), + ]) + .unwrap(); + let file_path = arrow_test_data(file); + let mut reader = read_file(&file_path, None); + let batch_large = reader.next().unwrap().unwrap(); + assert_eq!( + batch_large, expected, + "Decoded RecordBatch does not match for file {}", + file + ); + let mut reader_small = read_file(&file_path, None); + let batch_small = reader_small.next().unwrap().unwrap(); + assert_eq!( + batch_small, expected, + "Decoded RecordBatch (batch size 64) does not match for file {}", + file + ); + } + + #[test] + fn test_alltypes_nulls_plain() { + let file = "avro/alltypes_nulls_plain.avro"; + let expected = RecordBatch::try_from_iter_with_nullable([ + ( + "string_col", + Arc::new(StringArray::from(vec![None::<&str>])) as _, + true, + ), + ("int_col", Arc::new(Int32Array::from(vec![None])) as _, true), + ( + "bool_col", + Arc::new(BooleanArray::from(vec![None])) as _, + true, + ), + ( + "bigint_col", + Arc::new(Int64Array::from(vec![None])) as _, + true, + ), + ( + "float_col", + Arc::new(Float32Array::from(vec![None])) as _, + true, + ), + ( + "double_col", + Arc::new(Float64Array::from(vec![None])) as _, + true, + ), + ( + "bytes_col", + Arc::new(BinaryArray::from(vec![None::<&[u8]>])) as _, + true, + ), + 
]) + .unwrap(); + let file_path = arrow_test_data(file); + let mut reader = read_file(&file_path, None); + let batch_large = reader.next().unwrap().unwrap(); + assert_eq!( + batch_large, expected, + "Decoded RecordBatch does not match for file {}", + file + ); + let mut reader_small = read_file(&file_path, None); + let batch_small = reader_small.next().unwrap().unwrap(); + assert_eq!( + batch_small, expected, + "Decoded RecordBatch does not match for file {}", + file + ); + } + + #[test] + fn test_binary() { + let file = arrow_test_data("avro/binary.avro"); + let mut reader = read_file(&file, None); + let batch = reader.next().unwrap().unwrap(); + let expected = RecordBatch::try_from_iter_with_nullable([( + "foo", + Arc::new(BinaryArray::from_iter_values(vec![ + b"\x00".as_ref(), + b"\x01".as_ref(), + b"\x02".as_ref(), + b"\x03".as_ref(), + b"\x04".as_ref(), + b"\x05".as_ref(), + b"\x06".as_ref(), + b"\x07".as_ref(), + b"\x08".as_ref(), + b"\t".as_ref(), + b"\n".as_ref(), + b"\x0b".as_ref(), + ])) as Arc, + true, + )]) + .unwrap(); + assert_eq!(batch, expected); + } + + #[test] + fn test_decimal() { + let files = [ + ("avro/fixed_length_decimal.avro", 25, 2), + ("avro/fixed_length_decimal_legacy.avro", 13, 2), + ("avro/int32_decimal.avro", 4, 2), + ("avro/int64_decimal.avro", 10, 2), + ]; + let decimal_values: Vec = (1..=24).map(|n| n as i128 * 100).collect(); + + for (file, precision, scale) in files { + let file_path = arrow_test_data(file); + let mut reader = read_file(&file_path, None); + let actual_batch = reader.next().unwrap().unwrap(); + + let expected_array = Decimal128Array::from_iter_values(decimal_values.clone()) + .with_precision_and_scale(precision, scale) + .unwrap(); + + let mut meta = HashMap::new(); + meta.insert("precision".to_string(), precision.to_string()); + meta.insert("scale".to_string(), scale.to_string()); + let field_with_meta = Field::new("value", DataType::Decimal128(precision, scale), true) + .with_metadata(meta); + + let expected_schema = Arc::new(Schema::new(vec![field_with_meta])); + let expected_batch = + RecordBatch::try_new(expected_schema.clone(), vec![Arc::new(expected_array)]) + .expect("Failed to build expected RecordBatch"); + + assert_eq!( + actual_batch, expected_batch, + "Decoded RecordBatch does not match the expected Decimal128 data for file {}", + file + ); + } + } + + #[test] + fn test_datapage_v2() { + let file = arrow_test_data("avro/datapage_v2.snappy.avro"); + let mut reader = read_file(&file, None); + let batch = reader.next().unwrap().unwrap(); + let a = StringArray::from(vec![ + Some("abc"), + Some("abc"), + Some("abc"), + None, + Some("abc"), + ]); + let b = Int32Array::from(vec![Some(1), Some(2), Some(3), Some(4), Some(5)]); + let c = Float64Array::from(vec![Some(2.0), Some(3.0), Some(4.0), Some(5.0), Some(2.0)]); + let d = BooleanArray::from(vec![ + Some(true), + Some(true), + Some(true), + Some(false), + Some(true), + ]); + let e_values = Int32Array::from(vec![ + Some(1), + Some(2), + Some(3), + Some(1), + Some(2), + Some(3), + Some(1), + Some(2), + ]); + let e_offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0i32, 3, 3, 3, 6, 8])); + let e_validity = Some(NullBuffer::from(vec![true, false, false, true, true])); + let field_e = Arc::new(Field::new("item", DataType::Int32, true)); + let e = ListArray::new(field_e, e_offsets, Arc::new(e_values), e_validity); + let expected = RecordBatch::try_from_iter_with_nullable([ + ("a", Arc::new(a) as Arc, true), + ("b", Arc::new(b) as Arc, true), + ("c", Arc::new(c) as Arc, true), + ("d", 
Arc::new(d) as Arc, true), + ("e", Arc::new(e) as Arc, true), + ]) + .unwrap(); + assert_eq!(batch, expected); + } + + #[test] + fn test_dict_pages_offset_zero() { + let file = arrow_test_data("avro/dict-page-offset-zero.avro"); + let mut reader = read_file(&file, None); + let batch = reader.next().unwrap().unwrap(); + let num_rows = batch.num_rows(); + + let expected_field = Int32Array::from(vec![Some(1552); num_rows]); + let expected = RecordBatch::try_from_iter_with_nullable([( + "l_partkey", + Arc::new(expected_field) as Arc, + true, + )]) + .unwrap(); + assert_eq!(batch, expected); + } + + #[test] + fn test_list_columns() { + let file = arrow_test_data("avro/list_columns.avro"); + let mut reader = read_file(&file, None); + let mut int64_list_builder = ListBuilder::new(Int64Builder::new()); + { + { + let values = int64_list_builder.values(); + values.append_value(1); + values.append_value(2); + values.append_value(3); + } + int64_list_builder.append(true); + } + { + { + let values = int64_list_builder.values(); + values.append_null(); + values.append_value(1); + } + int64_list_builder.append(true); + } + { + { + let values = int64_list_builder.values(); + values.append_value(4); + } + int64_list_builder.append(true); + } + let int64_list = int64_list_builder.finish(); + let mut utf8_list_builder = ListBuilder::new(StringBuilder::new()); + { + { + let values = utf8_list_builder.values(); + values.append_value("abc"); + values.append_value("efg"); + values.append_value("hij"); + } + utf8_list_builder.append(true); + } + { + utf8_list_builder.append(false); + } + { + { + let values = utf8_list_builder.values(); + values.append_value("efg"); + values.append_null(); + values.append_value("hij"); + values.append_value("xyz"); + } + utf8_list_builder.append(true); + } + let utf8_list = utf8_list_builder.finish(); + let expected = RecordBatch::try_from_iter_with_nullable([ + ("int64_list", Arc::new(int64_list) as Arc, true), + ("utf8_list", Arc::new(utf8_list) as Arc, true), + ]) + .unwrap(); + let batch = reader.next().unwrap().unwrap(); + assert_eq!(batch, expected); + } + + #[test] + fn test_nested_lists() { + let file = arrow_test_data("avro/nested_lists.snappy.avro"); + let mut reader = read_file(&file, None); + let left = reader.next().unwrap().unwrap(); + let inner_values = StringArray::from(vec![ + Some("a"), + Some("b"), + Some("c"), + Some("d"), + Some("a"), + Some("b"), + Some("c"), + Some("d"), + Some("e"), + Some("a"), + Some("b"), + Some("c"), + Some("d"), + Some("e"), + Some("f"), + ]); + let inner_offsets = Buffer::from_slice_ref([0, 2, 3, 3, 4, 6, 8, 8, 9, 11, 13, 14, 14, 15]); + let inner_validity = [ + true, true, false, true, true, true, false, true, true, true, true, false, true, + ]; + let inner_null_buffer = Buffer::from_iter(inner_validity.iter().copied()); + let inner_field = Field::new("item", DataType::Utf8, true); + let inner_list_data = ArrayDataBuilder::new(DataType::List(Arc::new(inner_field))) + .len(13) + .add_buffer(inner_offsets) + .add_child_data(inner_values.to_data()) + .null_bit_buffer(Some(inner_null_buffer)) + .build() + .unwrap(); + let inner_list_array = ListArray::from(inner_list_data); + let middle_offsets = Buffer::from_slice_ref([0, 2, 4, 6, 8, 11, 13]); + let middle_validity = [true; 6]; + let middle_null_buffer = Buffer::from_iter(middle_validity.iter().copied()); + let middle_field = Field::new("item", inner_list_array.data_type().clone(), true); + let middle_list_data = ArrayDataBuilder::new(DataType::List(Arc::new(middle_field))) + .len(6) + 
.add_buffer(middle_offsets) + .add_child_data(inner_list_array.to_data()) + .null_bit_buffer(Some(middle_null_buffer)) + .build() + .unwrap(); + let middle_list_array = ListArray::from(middle_list_data); + let outer_offsets = Buffer::from_slice_ref([0, 2, 4, 6]); + let outer_null_buffer = Buffer::from_slice_ref([0b111]); // all valid + let outer_field = Field::new("item", middle_list_array.data_type().clone(), true); + let outer_list_data = ArrayDataBuilder::new(DataType::List(Arc::new(outer_field))) + .len(3) + .add_buffer(outer_offsets) + .add_child_data(middle_list_array.to_data()) + .null_bit_buffer(Some(outer_null_buffer)) + .build() + .unwrap(); + let a_expected = ListArray::from(outer_list_data); + let b_expected = Int32Array::from(vec![1, 1, 1]); + let expected = RecordBatch::try_from_iter_with_nullable([ + ("a", Arc::new(a_expected) as Arc, true), + ("b", Arc::new(b_expected) as Arc, true), + ]) + .unwrap(); + assert_eq!(left, expected, "Mismatch for batch size=64"); + } + + #[test] + fn test_nested_records() { + let file = arrow_test_data("avro/nested_records.avro"); + let mut reader = read_file(&file, None); + let batch = reader.next().unwrap().unwrap(); + let f1_f1_1 = StringArray::from(vec!["aaa", "bbb"]); + let f1_f1_2 = Int32Array::from(vec![10, 20]); + let rounded_pi = (std::f64::consts::PI * 100.0).round() / 100.0; + let f1_f1_3_1 = Float64Array::from(vec![rounded_pi, rounded_pi]); + let f1_f1_3 = StructArray::from(vec![( + Arc::new(Field::new("f1_3_1", DataType::Float64, false)), + Arc::new(f1_f1_3_1) as Arc, + )]); + + let f1_expected = StructArray::from(vec![ + ( + Arc::new(Field::new("f1_1", DataType::Utf8, false)), + Arc::new(f1_f1_1) as Arc, + ), + ( + Arc::new(Field::new("f1_2", DataType::Int32, false)), + Arc::new(f1_f1_2) as Arc, + ), + ( + Arc::new(Field::new( + "f1_3", + DataType::Struct(Fields::from(vec![Field::new( + "f1_3_1", + DataType::Float64, + false, + )])), + false, + )), + Arc::new(f1_f1_3) as Arc, + ), + ]); + let f2_fields = vec![ + Field::new("f2_1", DataType::Boolean, false), + Field::new("f2_2", DataType::Float32, false), + ]; + let f2_struct_builder = StructBuilder::new( + f2_fields + .iter() + .map(|f| Arc::new(f.clone())) + .collect::>>(), + vec![ + Box::new(BooleanBuilder::new()) as Box, + Box::new(Float32Builder::new()) as Box, + ], + ); + let mut f2_list_builder = ListBuilder::new(f2_struct_builder); + { + let struct_builder = f2_list_builder.values(); + struct_builder.append(true); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_value(true); + } + { + let b = struct_builder.field_builder::(1).unwrap(); + b.append_value(1.2_f32); + } + struct_builder.append(true); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_value(true); + } + { + let b = struct_builder.field_builder::(1).unwrap(); + b.append_value(2.2_f32); + } + f2_list_builder.append(true); + } + { + let struct_builder = f2_list_builder.values(); + struct_builder.append(true); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_value(false); + } + { + let b = struct_builder.field_builder::(1).unwrap(); + b.append_value(10.2_f32); + } + f2_list_builder.append(true); + } + let f2_expected = f2_list_builder.finish(); + let mut f3_struct_builder = StructBuilder::new( + vec![Arc::new(Field::new("f3_1", DataType::Utf8, false))], + vec![Box::new(StringBuilder::new()) as Box], + ); + f3_struct_builder.append(true); + { + let b = f3_struct_builder.field_builder::(0).unwrap(); + b.append_value("xyz"); + } + 
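+ // Row 2: the struct value itself is null; `StructBuilder` still requires an entry in each
+ // child builder, so a placeholder null is appended to the inner string builder.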
f3_struct_builder.append(false); + { + let b = f3_struct_builder.field_builder::(0).unwrap(); + b.append_null(); + } + let f3_expected = f3_struct_builder.finish(); + let f4_fields = [Field::new("f4_1", DataType::Int64, false)]; + let f4_struct_builder = StructBuilder::new( + f4_fields + .iter() + .map(|f| Arc::new(f.clone())) + .collect::>>(), + vec![Box::new(Int64Builder::new()) as Box], + ); + let mut f4_list_builder = ListBuilder::new(f4_struct_builder); + { + let struct_builder = f4_list_builder.values(); + struct_builder.append(true); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_value(200); + } + struct_builder.append(false); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_null(); + } + f4_list_builder.append(true); + } + { + let struct_builder = f4_list_builder.values(); + struct_builder.append(false); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_null(); + } + struct_builder.append(true); + { + let b = struct_builder.field_builder::(0).unwrap(); + b.append_value(300); + } + f4_list_builder.append(true); + } + let f4_expected = f4_list_builder.finish(); + let expected = RecordBatch::try_from_iter_with_nullable([ + ("f1", Arc::new(f1_expected) as Arc, false), + ("f2", Arc::new(f2_expected) as Arc, false), + ("f3", Arc::new(f3_expected) as Arc, true), + ("f4", Arc::new(f4_expected) as Arc, false), + ]) + .unwrap(); + assert_eq!(batch, expected, "Mismatch in nested_records.avro contents"); + } + + #[test] + fn test_nonnullable_impala() { + let file = arrow_test_data("avro/nonnullable.impala.avro"); + let mut reader = read_file(&file, None); + let id = Int64Array::from(vec![Some(8)]); + let mut int_array_builder = ListBuilder::new(Int32Builder::new()); + { + let vb = int_array_builder.values(); + vb.append_value(-1); + } + int_array_builder.append(true); + let int_array = int_array_builder.finish(); + let mut iaa_builder = ListBuilder::new(ListBuilder::new(Int32Builder::new())); + { + let inner_list_builder = iaa_builder.values(); + { + let vb = inner_list_builder.values(); + vb.append_value(-1); + vb.append_value(-2); + } + inner_list_builder.append(true); + inner_list_builder.append(true); + } + iaa_builder.append(true); + let int_array_array = iaa_builder.finish(); + let field_names = MapFieldNames { + entry: "entries".to_string(), + key: "key".to_string(), + value: "value".to_string(), + }; + let mut int_map_builder = + MapBuilder::new(Some(field_names), StringBuilder::new(), Int32Builder::new()); + { + let (keys, vals) = int_map_builder.entries(); + keys.append_value("k1"); + vals.append_value(-1); + } + int_map_builder.append(true).unwrap(); + let int_map = int_map_builder.finish(); + let field_names2 = MapFieldNames { + entry: "entries".to_string(), + key: "key".to_string(), + value: "value".to_string(), + }; + let mut ima_builder = ListBuilder::new(MapBuilder::new( + Some(field_names2), + StringBuilder::new(), + Int32Builder::new(), + )); + { + let map_builder = ima_builder.values(); + map_builder.append(true).unwrap(); + { + let (keys, vals) = map_builder.entries(); + keys.append_value("k1"); + vals.append_value(1); + } + map_builder.append(true).unwrap(); + map_builder.append(true).unwrap(); + map_builder.append(true).unwrap(); } + ima_builder.append(true); + let int_map_array_ = ima_builder.finish(); + let nested_schema_fields = vec![ + Field::new("a", DataType::Int32, true), + Field::new( + "B", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + ), + Field::new( + "c", + 
DataType::Struct(Fields::from(vec![Field::new( + "D", + DataType::List(Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new( + "item", + DataType::Struct(Fields::from(vec![ + Field::new("e", DataType::Int32, true), + Field::new("f", DataType::Utf8, true), + ])), + true, + ))), + true, + ))), + true, + )])), + true, + ), + Field::new( + "G", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new( + "value", + DataType::Struct(Fields::from(vec![Field::new( + "h", + DataType::Struct(Fields::from(vec![Field::new( + "i", + DataType::List(Arc::new(Field::new( + "item", + DataType::Float64, + true, + ))), + true, + )])), + true, + )])), + true, + ), + ])), + false, + )), + false, + ), + true, + ), + ]; + let nested_schema = Arc::new(Schema::new(nested_schema_fields.clone())); + let mut nested_sb = StructBuilder::new( + nested_schema_fields + .iter() + .map(|f| Arc::new(f.clone())) + .collect::>(), + vec![ + Box::new(Int32Builder::new()), + Box::new(ListBuilder::new(Int32Builder::new())), + { + let d_list_field = Field::new( + "D", + DataType::List(Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new( + "item", + DataType::Struct(Fields::from(vec![ + Field::new("e", DataType::Int32, true), + Field::new("f", DataType::Utf8, true), + ])), + true, + ))), + true, + ))), + true, + ); + let struct_c_builder = StructBuilder::new( + vec![Arc::new(d_list_field)], + vec![Box::new(ListBuilder::new(ListBuilder::new( + StructBuilder::new( + vec![ + Arc::new(Field::new("e", DataType::Int32, true)), + Arc::new(Field::new("f", DataType::Utf8, true)), + ], + vec![ + Box::new(Int32Builder::new()), + Box::new(StringBuilder::new()), + ], + ), + )))], + ); + Box::new(struct_c_builder) + }, + { + Box::new(MapBuilder::new( + Some(MapFieldNames { + entry: "entries".to_string(), + key: "key".to_string(), + value: "value".to_string(), + }), + StringBuilder::new(), + StructBuilder::new( + vec![Arc::new(Field::new( + "h", + DataType::Struct(Fields::from(vec![Field::new( + "i", + DataType::List(Arc::new(Field::new( + "item", + DataType::Float64, + true, + ))), + true, + )])), + true, + ))], + vec![Box::new(StructBuilder::new( + vec![Arc::new(Field::new( + "i", + DataType::List(Arc::new(Field::new( + "item", + DataType::Float64, + true, + ))), + true, + ))], + vec![Box::new(ListBuilder::new(Float64Builder::new()))], + ))], + ), + )) + }, + ], + ); + nested_sb.append(true); + { + let a_builder = nested_sb.field_builder::(0).unwrap(); + a_builder.append_value(-1); + let b_builder = nested_sb + .field_builder::>(1) + .unwrap(); + { + let vb = b_builder.values(); + vb.append_value(-1); + } + b_builder.append(true); + let c_sb = nested_sb.field_builder::(2).unwrap(); + c_sb.append(true); + { + let d_list_builder = c_sb + .field_builder::>>(0) + .unwrap(); + { + let sub_list_builder = d_list_builder.values(); + { + let ef_struct_builder = sub_list_builder.values(); + ef_struct_builder.append(true); + { + let e_b = ef_struct_builder.field_builder::(0).unwrap(); + e_b.append_value(-1); + let f_b = ef_struct_builder.field_builder::(1).unwrap(); + f_b.append_value("nonnullable"); + } + sub_list_builder.append(true); + } + d_list_builder.append(true); + } + } + let g_map_builder = nested_sb + .field_builder::>(3) + .unwrap(); + g_map_builder.append(true).unwrap(); + { + let (keys, values) = g_map_builder.entries(); + keys.append_value("k1"); + values.append(true); + let h_struct_builder = 
values.field_builder::(0).unwrap(); + h_struct_builder.append(true); + { + let i_list_builder = h_struct_builder + .field_builder::>(0) + .unwrap(); + i_list_builder.append(true); + } + } + } + let nested_struct = nested_sb.finish(); + let schema = Arc::new(Schema::new(vec![ + Field::new("ID", DataType::Int64, true), + Field::new( + "Int_Array", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + ), + Field::new( + "int_array_array", + DataType::List(Arc::new(Field::new( + "item", + DataType::List(Arc::new(Field::new("item", DataType::Int32, true))), + true, + ))), + true, + ), + Field::new( + "Int_Map", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, true), + ])), + false, + )), + false, + ), + true, + ), + Field::new( + "int_map_array", + DataType::List(Arc::new(Field::new( + "item", + DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct(Fields::from(vec![ + Field::new("key", DataType::Utf8, false), + Field::new("value", DataType::Int32, true), + ])), + false, + )), + false, + ), + true, + ))), + true, + ), + Field::new( + "nested_Struct", + DataType::Struct(nested_schema.as_ref().fields.clone()), + true, + ), + ])); + let expected = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(id) as Arc, + Arc::new(int_array), + Arc::new(int_array_array), + Arc::new(int_map), + Arc::new(int_map_array_), + Arc::new(nested_struct), + ], + ) + .unwrap(); + let batch = reader.next().unwrap().unwrap(); + assert_eq!(batch, expected, "nonnullable impala avro data mismatch"); + } + + #[test] + fn test_nullable_impala() { + use arrow_array::{Int64Array, ListArray, StructArray}; + let file = arrow_test_data("avro/nullable.impala.avro"); + let mut r1 = read_file(&file, None); + let batch1 = r1.next().unwrap().unwrap(); + let mut r2 = read_file(&file, None); + let batch2 = r2.next().unwrap().unwrap(); + assert_eq!( + batch1, batch2, + "Reading file multiple times should produce the same data" + ); + let batch = batch1; + assert_eq!(batch.num_rows(), 7); + let id_array = batch + .column(0) + .as_any() + .downcast_ref::() + .expect("id column should be an Int64Array"); + let expected_ids = [1, 2, 3, 4, 5, 6, 7]; + for (i, &expected_id) in expected_ids.iter().enumerate() { + assert_eq!(id_array.value(i), expected_id, "Mismatch in id at row {i}"); + } + let int_array = batch + .column(1) + .as_any() + .downcast_ref::() + .expect("int_array column should be a ListArray"); + + { + let offsets = int_array.value_offsets(); + let start = offsets[0] as usize; + let end = offsets[1] as usize; + let values = int_array + .values() + .as_any() + .downcast_ref::() + .expect("Values of int_array should be an Int32Array"); + let row0: Vec> = (start..end).map(|idx| Some(values.value(idx))).collect(); + assert_eq!( + row0, + vec![Some(1), Some(2), Some(3)], + "Mismatch in int_array row 0" + ); + } + let nested_struct = batch + .column(5) + .as_any() + .downcast_ref::() + .expect("nested_struct column should be a StructArray"); + let a_array = nested_struct + .column_by_name("A") + .expect("Field A should exist in nested_struct") + .as_any() + .downcast_ref::() + .expect("Field A should be an Int32Array"); + assert_eq!(a_array.value(0), 1, "Mismatch in nested_struct.A at row 0"); + assert!( + !a_array.is_valid(1), + "Expected null in nested_struct.A at row 1" + ); + assert!( + !a_array.is_valid(3), + "Expected null in nested_struct.A at row 3" + ); + 
assert_eq!(a_array.value(6), 7, "Mismatch in nested_struct.A at row 6"); + } + + #[test] + fn test_nulls_snappy() { + let file = arrow_test_data("avro/nulls.snappy.avro"); + let mut reader = read_file(&file, None); + let batch = reader.next().unwrap().unwrap(); + let b_c_int = Int32Array::from(vec![None; 8]); + let b_c_int_data = b_c_int.into_data(); + let b_struct_field = Field::new("b_c_int", DataType::Int32, true); + let b_struct_type = DataType::Struct(vec![b_struct_field].into()); + let struct_validity = arrow_buffer::Buffer::from_iter((0..8).map(|_| true)); + let b_struct_data = ArrayDataBuilder::new(b_struct_type) + .len(8) + .null_bit_buffer(Some(struct_validity)) + .child_data(vec![b_c_int_data]) + .build() + .unwrap(); + let b_struct_array = StructArray::from(b_struct_data); + + let expected = RecordBatch::try_from_iter_with_nullable([( + "b_struct", + Arc::new(b_struct_array) as _, + true, + )]) + .unwrap(); + assert_eq!(batch, expected); + } + + #[test] + fn test_repeated_no_annotation() { + let file = arrow_test_data("avro/repeated_no_annotation.avro"); + let mut reader = read_file(&file, None); + let batch = reader.next().unwrap().unwrap(); + use arrow_array::{Int32Array, Int64Array, ListArray, StringArray, StructArray}; + use arrow_buffer::Buffer; + use arrow_data::ArrayDataBuilder; + use arrow_schema::{DataType, Field, Fields}; + let id_array = Int32Array::from(vec![1, 2, 3, 4, 5, 6]); + let number_array = Int64Array::from(vec![ + Some(5555555555), + Some(1111111111), + Some(1111111111), + Some(2222222222), + Some(3333333333), + ]); + let kind_array = + StringArray::from(vec![None, Some("home"), Some("home"), None, Some("mobile")]); + let phone_fields = Fields::from(vec![ + Field::new("number", DataType::Int64, true), + Field::new("kind", DataType::Utf8, true), + ]); + let phone_struct_data = ArrayDataBuilder::new(DataType::Struct(phone_fields)) + .len(5) + .child_data(vec![number_array.into_data(), kind_array.into_data()]) + .build() + .unwrap(); + let phone_struct_array = StructArray::from(phone_struct_data); + let phone_list_offsets = Buffer::from_slice_ref([0, 0, 0, 0, 1, 2, 5]); + let phone_list_validity = Buffer::from_iter([false, false, true, true, true, true]); + let phone_item_field = Field::new("item", phone_struct_array.data_type().clone(), true); + let phone_list_data = ArrayDataBuilder::new(DataType::List(Arc::new(phone_item_field))) + .len(6) + .add_buffer(phone_list_offsets) + .null_bit_buffer(Some(phone_list_validity)) + .child_data(vec![phone_struct_array.into_data()]) + .build() + .unwrap(); + let phone_list_array = ListArray::from(phone_list_data); + let phone_numbers_validity = Buffer::from_iter([false, false, true, true, true, true]); + let phone_numbers_field = Field::new("phone", phone_list_array.data_type().clone(), true); + let phone_numbers_struct_data = + ArrayDataBuilder::new(DataType::Struct(Fields::from(vec![phone_numbers_field]))) + .len(6) + .null_bit_buffer(Some(phone_numbers_validity)) + .child_data(vec![phone_list_array.into_data()]) + .build() + .unwrap(); + let phone_numbers_struct_array = StructArray::from(phone_numbers_struct_data); + let expected = RecordBatch::try_from_iter_with_nullable([ + ("id", Arc::new(id_array) as _, true), + ( + "phoneNumbers", + Arc::new(phone_numbers_struct_array) as _, + true, + ), + ]) + .unwrap(); + assert_eq!(batch, expected); + } + + #[test] + fn test_simple() { + + fn build_expected_enum() -> RecordBatch { + // Build the DictionaryArrays for f1, f2, f3 + let keys_f1 = Int32Array::from(vec![0, 1, 2, 
3]); + let vals_f1 = StringArray::from(vec!["a", "b", "c", "d"]); + let f1_dict = + DictionaryArray::::try_new(keys_f1, Arc::new(vals_f1)).unwrap(); + let keys_f2 = Int32Array::from(vec![2, 3, 0, 1]); + let vals_f2 = StringArray::from(vec!["e", "f", "g", "h"]); + let f2_dict = + DictionaryArray::::try_new(keys_f2, Arc::new(vals_f2)).unwrap(); + let keys_f3 = Int32Array::from(vec![Some(1), Some(2), None, Some(0)]); + let vals_f3 = StringArray::from(vec!["i", "j", "k"]); + let f3_dict = + DictionaryArray::::try_new(keys_f3, Arc::new(vals_f3)).unwrap(); + let dict_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let mut md_f1 = HashMap::new(); + md_f1.insert( + "avro.enum.symbols".to_string(), + r#"["a","b","c","d"]"#.to_string(), + ); + let f1_field = Field::new("f1", dict_type.clone(), false).with_metadata(md_f1); + let mut md_f2 = HashMap::new(); + md_f2.insert( + "avro.enum.symbols".to_string(), + r#"["e","f","g","h"]"#.to_string(), + ); + let f2_field = Field::new("f2", dict_type.clone(), false).with_metadata(md_f2); + let mut md_f3 = HashMap::new(); + md_f3.insert( + "avro.enum.symbols".to_string(), + r#"["i","j","k"]"#.to_string(), + ); + let f3_field = Field::new("f3", dict_type.clone(), true).with_metadata(md_f3); + let expected_schema = Arc::new(Schema::new(vec![ + f1_field, + f2_field, + f3_field, + ])); + RecordBatch::try_new( + expected_schema, + vec![ + Arc::new(f1_dict) as Arc, + Arc::new(f2_dict) as Arc, + Arc::new(f3_dict) as Arc, + ], + ) + .unwrap() + } + + fn build_expected_fixed() -> RecordBatch { + let f1 = FixedSizeBinaryArray::try_from_iter( + vec![b"abcde", b"12345"].into_iter() + ).unwrap(); + let f2 = FixedSizeBinaryArray::try_from_iter( + vec![b"fghijklmno", b"1234567890"].into_iter() + ).unwrap(); + let f3 = FixedSizeBinaryArray::try_from_sparse_iter_with_size( + vec![Some(b"ABCDEF" as &[u8]), None].into_iter(), + 6, + ).unwrap(); + let expected_schema = Arc::new(Schema::new(vec![ + Field::new("f1", DataType::FixedSizeBinary(5), false), + Field::new("f2", DataType::FixedSizeBinary(10), false), + Field::new("f3", DataType::FixedSizeBinary(6), true), + ])); + + RecordBatch::try_new( + expected_schema, + vec![ + Arc::new(f1) as Arc, + Arc::new(f2) as Arc, + Arc::new(f3) as Arc, + ], + ) + .unwrap() + } + let tests = [ + ("avro/simple_enum.avro", build_expected_enum()), + ("avro/simple_fixed.avro", build_expected_fixed()), + ]; + + for (file_name, expected) in tests { + let file = arrow_test_data(file_name); + let mut reader = read_file(&file, None); + let actual = reader + .next() + .expect("Should have a batch") + .expect("Error reading batch"); + + assert_eq!(actual, expected, "Mismatch for file {file_name}"); + } + } + + + #[test] + fn test_single_nan() { + let file = arrow_test_data("avro/single_nan.avro"); + let mut reader = read_file(&file, None); + let batch = reader + .next() + .expect("Should have a batch") + .expect("Error reading single_nan batch"); + let schema = Arc::new(Schema::new(vec![Field::new( + "mycol", + DataType::Float64, + true, + )])); + let col = arrow_array::Float64Array::from(vec![None]); + let expected = RecordBatch::try_new(schema.clone(), vec![Arc::new(col)]).unwrap(); + assert_eq!(batch, expected, "Mismatch in single_nan.avro data"); } } diff --git a/arrow-avro/src/reader/record.rs b/arrow-avro/src/reader/record.rs index 52a58cf63303..801212523ef1 100644 --- a/arrow-avro/src/reader/record.rs +++ b/arrow-avro/src/reader/record.rs @@ -16,44 +16,56 @@ // under the License. 
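// A minimal sketch (annotation only, not part of this patch) of the enum -> dictionary
// mapping that the `test_simple` expectations above rely on: Avro stores an enum value as
// its zigzag-encoded index into the schema's symbol list, and the reader exposes the column
// as `Dictionary<Int32, Utf8>` whose values are those symbols. The helper name below is
// illustrative and uses only public arrow-array APIs.

use std::sync::Arc;
use arrow_array::{types::Int32Type, DictionaryArray, Int32Array, StringArray};

fn enum_to_dictionary(indices: Vec<i32>, symbols: &[&str]) -> DictionaryArray<Int32Type> {
    let keys = Int32Array::from(indices);              // decoded Avro enum indices
    let values = StringArray::from(symbols.to_vec());  // symbol list from the Avro schema
    DictionaryArray::<Int32Type>::try_new(keys, Arc::new(values)).unwrap()
}

// e.g. enum_to_dictionary(vec![0, 1, 2, 3], &["a", "b", "c", "d"]) reproduces the
// expected `f1` column asserted against `avro/simple_enum.avro` above.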
use crate::codec::{AvroDataType, Codec, Nullability}; -use crate::reader::block::{Block, BlockDecoder}; use crate::reader::cursor::AvroCursor; -use crate::reader::header::Header; -use crate::schema::*; +use arrow_array::builder::{Decimal128Builder, Decimal256Builder, PrimitiveBuilder}; use arrow_array::types::*; use arrow_array::*; use arrow_buffer::*; use arrow_schema::{ - ArrowError, DataType, Field as ArrowField, FieldRef, Fields, Schema as ArrowSchema, SchemaRef, + ArrowError, DataType, Field as ArrowField, FieldRef, Fields, IntervalUnit, + Schema as ArrowSchema, SchemaRef, DECIMAL128_MAX_PRECISION, DECIMAL256_MAX_PRECISION, }; -use std::collections::HashMap; -use std::io::Read; +use std::cmp::Ordering; use std::sync::Arc; -/// Decodes avro encoded data into [`RecordBatch`] +const DEFAULT_CAPACITY: usize = 1024; + +/// A decoder that converts Avro-encoded data into an Arrow [`RecordBatch`]. +#[derive(Debug)] pub struct RecordDecoder { schema: SchemaRef, fields: Vec, } impl RecordDecoder { - pub fn try_new(data_type: &AvroDataType) -> Result { - match Decoder::try_new(data_type)? { - Decoder::Record(fields, encodings) => Ok(Self { + /// Create a new [`RecordDecoder`] from an [`AvroDataType`] that must be a `Record`. + /// + /// - `strict_mode`: if `true`, we will throw an error if we encounter + /// a union of the form `[T, "null"]` (i.e. `Nullability::NullSecond`). + pub fn try_new( + data_type: &AvroDataType, + strict_mode: bool, + ) -> Result { + match Decoder::try_new(data_type, strict_mode)? { + Decoder::Record(fields, decoders) => Ok(Self { schema: Arc::new(ArrowSchema::new(fields)), - fields: encodings, + fields: decoders, }), - encoding => Err(ArrowError::ParseError(format!( - "Expected record got {encoding:?}" + other => Err(ArrowError::ParseError(format!( + "Expected record got {other:?}" ))), } } + /// Return the [`SchemaRef`] describing the Arrow schema of rows produced by this decoder. pub fn schema(&self) -> &SchemaRef { &self.schema } - /// Decode `count` records from `buf` + /// Decode `count` Avro records from `buf`. + /// + /// This accumulates data in internal buffers. Once done reading, call + /// [`Self::flush`] to yield an Arrow [`RecordBatch`]. pub fn decode(&mut self, buf: &[u8], count: usize) -> Result { let mut cursor = AvroCursor::new(buf); for _ in 0..count { @@ -64,43 +76,66 @@ impl RecordDecoder { Ok(cursor.position()) } - /// Flush the decoded records into a [`RecordBatch`] + /// Flush into a [`RecordBatch`], + /// + /// We collect arrays from each `Decoder` and build a new [`RecordBatch`]. pub fn flush(&mut self) -> Result { let arrays = self .fields .iter_mut() - .map(|x| x.flush(None)) + .map(|d| d.flush(None)) .collect::, _>>()?; - RecordBatch::try_new(self.schema.clone(), arrays) } } +/// For 2-branch unions we store either `[null, T]` or `[T, null]`. 
+/// +/// - `NullFirst`: `[null, T]` => branch=0 => null, branch=1 => decode T +/// - `NullSecond`: `[T, null]` => branch=0 => decode T, branch=1 => null +#[derive(Debug, Copy, Clone)] +enum UnionOrder { + NullFirst, + NullSecond, +} + #[derive(Debug)] enum Decoder { + /// Primitive Types Null(usize), Boolean(BooleanBufferBuilder), Int32(Vec), Int64(Vec), Float32(Vec), Float64(Vec), + Binary(OffsetBufferBuilder, Vec), + String(OffsetBufferBuilder, Vec), + /// Complex Types + Record(Fields, Vec), + Enum(Arc<[String]>, Vec), + List(FieldRef, OffsetBufferBuilder, Box), + Map( + FieldRef, + OffsetBufferBuilder, + OffsetBufferBuilder, + Vec, + Box, + ), + Nullable(UnionOrder, NullBufferBuilder, Box), + Fixed(i32, Vec), + /// Logical Types + Decimal(usize, Option, Option, DecimalBuilder), Date32(Vec), TimeMillis(Vec), TimeMicros(Vec), TimestampMillis(bool, Vec), TimestampMicros(bool, Vec), - Binary(OffsetBufferBuilder, Vec), - String(OffsetBufferBuilder, Vec), - List(FieldRef, OffsetBufferBuilder, Box), - Record(Fields, Vec), - Nullable(Nullability, NullBufferBuilder, Box), + Interval(Vec), } impl Decoder { - fn try_new(data_type: &AvroDataType) -> Result { - let nyi = |s: &str| Err(ArrowError::NotYetImplemented(s.to_string())); - - let decoder = match data_type.codec() { + fn try_new(data_type: &AvroDataType, strict_mode: bool) -> Result { + let base = match &data_type.codec { Codec::Null => Self::Null(0), Codec::Boolean => Self::Boolean(BooleanBufferBuilder::new(DEFAULT_CAPACITY)), Codec::Int32 => Self::Int32(Vec::with_capacity(DEFAULT_CAPACITY)), @@ -111,182 +146,1539 @@ impl Decoder { OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::with_capacity(DEFAULT_CAPACITY), ), - Codec::Utf8 => Self::String( + Codec::String => Self::String( OffsetBufferBuilder::new(DEFAULT_CAPACITY), Vec::with_capacity(DEFAULT_CAPACITY), ), - Codec::Date32 => Self::Date32(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::TimeMillis => Self::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::TimeMicros => Self::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)), - Codec::TimestampMillis(is_utc) => { - Self::TimestampMillis(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) + Codec::Record(avro_fields) => { + let mut fields = Vec::with_capacity(avro_fields.len()); + let mut children = Vec::with_capacity(avro_fields.len()); + for f in avro_fields.iter() { + // Recursively build a Decoder for each child + let child = Self::try_new(f.data_type(), strict_mode)?; + fields.push(f.field()); + children.push(child); + } + Self::Record(fields.into(), children) } - Codec::TimestampMicros(is_utc) => { - Self::TimestampMicros(*is_utc, Vec::with_capacity(DEFAULT_CAPACITY)) + Codec::Enum(syms, _) => { + Self::Enum(Arc::clone(syms), Vec::with_capacity(DEFAULT_CAPACITY)) } - Codec::Fixed(_) => return nyi("decoding fixed"), - Codec::Interval => return nyi("decoding interval"), - Codec::List(item) => { - let decoder = Self::try_new(item)?; + Codec::Array(child) => { + let child_dec = Self::try_new(child, strict_mode)?; + let item_field = child.field_with_name("item").with_nullable(true); Self::List( - Arc::new(item.field_with_name("item")), + Arc::new(item_field), + OffsetBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(child_dec), + ) + } + Codec::Map(child) => { + let val_field = child.field_with_name("value").with_nullable(true); + let map_field = Arc::new(ArrowField::new( + "entries", + DataType::Struct(Fields::from(vec![ + ArrowField::new("key", DataType::Utf8, false), + val_field, + ])), + false, + )); + let valdec = 
Self::try_new(child, strict_mode)?; + Self::Map( + map_field, + OffsetBufferBuilder::new(DEFAULT_CAPACITY), OffsetBufferBuilder::new(DEFAULT_CAPACITY), - Box::new(decoder), + Vec::with_capacity(DEFAULT_CAPACITY), + Box::new(valdec), ) } - Codec::Struct(fields) => { - let mut arrow_fields = Vec::with_capacity(fields.len()); - let mut encodings = Vec::with_capacity(fields.len()); - for avro_field in fields.iter() { - let encoding = Self::try_new(avro_field.data_type())?; - arrow_fields.push(avro_field.field()); - encodings.push(encoding); + Codec::Fixed(sz) => Self::Fixed(*sz, Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Decimal(p, s, size) => { + let b = DecimalBuilder::new(*p, *s, *size)?; + Self::Decimal(*p, *s, *size, b) + } + Codec::Uuid => Self::Fixed(16, Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::Date32 => Self::Date32(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::TimeMillis => Self::TimeMillis(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::TimeMicros => Self::TimeMicros(Vec::with_capacity(DEFAULT_CAPACITY)), + Codec::TimestampMillis(utc) => { + Self::TimestampMillis(*utc, Vec::with_capacity(DEFAULT_CAPACITY)) + } + Codec::TimestampMicros(utc) => { + Self::TimestampMicros(*utc, Vec::with_capacity(DEFAULT_CAPACITY)) + } + Codec::Duration => Self::Interval(Vec::with_capacity(DEFAULT_CAPACITY)), + }; + let union_order = match data_type.nullability { + None => None, + Some(Nullability::NullFirst) => Some(UnionOrder::NullFirst), + Some(Nullability::NullSecond) => { + if strict_mode { + return Err(ArrowError::ParseError( + "Found Avro union of the form ['T','null'], which is disallowed in strict_mode" + .to_string(), + )); } - Self::Record(arrow_fields.into(), encodings) + Some(UnionOrder::NullSecond) } }; - - Ok(match data_type.nullability() { - Some(nullability) => Self::Nullable( - nullability, + let decoder = match union_order { + Some(order) => Decoder::Nullable( + order, NullBufferBuilder::new(DEFAULT_CAPACITY), - Box::new(decoder), + Box::new(base), ), - None => decoder, - }) + None => base, + }; + Ok(decoder) } - /// Append a null record fn append_null(&mut self) { match self { - Self::Null(count) => *count += 1, + Self::Null(n) => *n += 1, Self::Boolean(b) => b.append(false), Self::Int32(v) | Self::Date32(v) | Self::TimeMillis(v) => v.push(0), Self::Int64(v) | Self::TimeMicros(v) | Self::TimestampMillis(_, v) | Self::TimestampMicros(_, v) => v.push(0), - Self::Float32(v) => v.push(0.), - Self::Float64(v) => v.push(0.), - Self::Binary(offsets, _) | Self::String(offsets, _) => offsets.push_length(0), - Self::List(_, offsets, e) => { - offsets.push_length(0); - e.append_null(); + Self::Float32(v) => v.push(0.0), + Self::Float64(v) => v.push(0.0), + Self::Binary(off, _) | Self::String(off, _) => off.push_length(0), + Self::Record(_, children) => { + for c in children { + c.append_null(); + } + } + Self::Enum(_, idxs) => idxs.push(0), + Self::List(_, off, _) => { + off.push_length(0); + } + Self::Map(_, _koff, moff, _kdata, _valdec) => { + moff.push_length(0); + } + Self::Nullable(_, nb, child) => { + nb.append(false); + child.append_null(); + } + Self::Fixed(sz, accum) => { + accum.extend(std::iter::repeat(0u8).take(*sz as usize)); + } + Self::Decimal(_, _, _, db) => { + let _ = db.append_null(); + } + Self::Interval(ivals) => { + ivals.push(IntervalMonthDayNano { + months: 0, + days: 0, + nanoseconds: 0, + }); } - Self::Record(_, e) => e.iter_mut().for_each(|e| e.append_null()), - Self::Nullable(_, _, _) => unreachable!("Nulls cannot be nested"), } } - /// Decode a single 
record from `buf` - fn decode(&mut self, buf: &mut AvroCursor<'_>) -> Result<(), ArrowError> { + fn decode(&mut self, buf: &mut AvroCursor) -> Result<(), ArrowError> { match self { - Self::Null(x) => *x += 1, - Self::Boolean(values) => values.append(buf.get_bool()?), - Self::Int32(values) | Self::Date32(values) | Self::TimeMillis(values) => { - values.push(buf.get_int()?) - } - Self::Int64(values) - | Self::TimeMicros(values) - | Self::TimestampMillis(_, values) - | Self::TimestampMicros(_, values) => values.push(buf.get_long()?), - Self::Float32(values) => values.push(buf.get_float()?), - Self::Float64(values) => values.push(buf.get_double()?), - Self::Binary(offsets, values) | Self::String(offsets, values) => { - let data = buf.get_bytes()?; - offsets.push_length(data.len()); - values.extend_from_slice(data); - } - Self::List(_, _, _) => { - return Err(ArrowError::NotYetImplemented( - "Decoding ListArray".to_string(), - )) - } - Self::Record(_, encodings) => { - for encoding in encodings { - encoding.decode(buf)?; + Self::Null(n) => { + *n += 1; + } + Self::Boolean(b) => { + b.append(buf.get_bool()?); + } + Self::Int32(v) => { + v.push(buf.get_int()?); + } + Self::Int64(v) => { + v.push(buf.get_long()?); + } + Self::Float32(vals) => { + vals.push(buf.get_float()?); + } + Self::Float64(vals) => { + vals.push(buf.get_double()?); + } + Self::Binary(off, data) | Self::String(off, data) => { + let bytes = buf.get_bytes()?; + off.push_length(bytes.len()); + data.extend_from_slice(bytes); + } + Self::Record(_, children) => { + for c in children { + c.decode(buf)?; } } - Self::Nullable(nullability, nulls, e) => { - let is_valid = buf.get_bool()? == matches!(nullability, Nullability::NullFirst); - nulls.append(is_valid); - match is_valid { - true => e.decode(buf)?, - false => e.append_null(), + Self::Enum(_, idxs) => { + idxs.push(buf.get_int()?); + } + Self::List(_, off, child) => { + let total_items = read_array_blocks(buf, |cursor| child.decode(cursor))?; + off.push_length(total_items); + } + Self::Map(_, koff, moff, kdata, valdec) => { + let newly_added = read_map_blocks(buf, |cur| { + let kb = cur.get_bytes()?; + koff.push_length(kb.len()); + kdata.extend_from_slice(kb); + valdec.decode(cur) + })?; + moff.push_length(newly_added); + } + Self::Nullable(order, nb, child) => { + let branch = buf.get_int()?; + match order { + UnionOrder::NullFirst => { + if branch == 0 { + nb.append(false); + child.append_null(); + } else { + nb.append(true); + child.decode(buf)?; + } + } + UnionOrder::NullSecond => { + if branch == 0 { + nb.append(true); + child.decode(buf)?; + } else { + nb.append(false); + child.append_null(); + } + } } } + Self::Fixed(sz, accum) => { + let fx = buf.get_fixed(*sz as usize)?; + accum.extend_from_slice(fx); + } + Self::Decimal(_, _, fsz, db) => { + let raw = match *fsz { + Some(n) => buf.get_fixed(n)?, + None => buf.get_bytes()?, + }; + db.append_bytes(raw)?; + } + Self::Date32(vals) => vals.push(buf.get_int()?), + Self::TimeMillis(vals) => vals.push(buf.get_int()?), + Self::TimeMicros(vals) => vals.push(buf.get_long()?), + Self::TimestampMillis(_, vals) => vals.push(buf.get_long()?), + Self::TimestampMicros(_, vals) => vals.push(buf.get_long()?), + Self::Interval(ivals) => { + let x = buf.get_fixed(12)?; + let months = i32::from_le_bytes(x[0..4].try_into().unwrap()); + let days = i32::from_le_bytes(x[4..8].try_into().unwrap()); + let ms = i32::from_le_bytes(x[8..12].try_into().unwrap()); + let nanos = ms as i64 * 1_000_000; + ivals.push(IntervalMonthDayNano { + months, + days, 
+ nanoseconds: nanos, + }); + } } Ok(()) } - /// Flush decoded records to an [`ArrayRef`] - fn flush(&mut self, nulls: Option) -> Result { - Ok(match self { - Self::Nullable(_, n, e) => e.flush(n.finish())?, - Self::Null(size) => Arc::new(NullArray::new(std::mem::replace(size, 0))), - Self::Boolean(b) => Arc::new(BooleanArray::new(b.finish(), nulls)), - Self::Int32(values) => Arc::new(flush_primitive::(values, nulls)), - Self::Date32(values) => Arc::new(flush_primitive::(values, nulls)), - Self::Int64(values) => Arc::new(flush_primitive::(values, nulls)), - Self::TimeMillis(values) => { - Arc::new(flush_primitive::(values, nulls)) - } - Self::TimeMicros(values) => { - Arc::new(flush_primitive::(values, nulls)) - } - Self::TimestampMillis(is_utc, values) => Arc::new( - flush_primitive::(values, nulls) - .with_timezone_opt(is_utc.then(|| "+00:00")), - ), - Self::TimestampMicros(is_utc, values) => Arc::new( - flush_primitive::(values, nulls) - .with_timezone_opt(is_utc.then(|| "+00:00")), - ), - Self::Float32(values) => Arc::new(flush_primitive::(values, nulls)), - Self::Float64(values) => Arc::new(flush_primitive::(values, nulls)), - - Self::Binary(offsets, values) => { - let offsets = flush_offsets(offsets); - let values = flush_values(values).into(); - Arc::new(BinaryArray::new(offsets, values, nulls)) + fn flush(&mut self, nulls: Option) -> Result, ArrowError> { + match self { + Self::Null(count) => { + let c = std::mem::replace(count, 0); + Ok(Arc::new(NullArray::new(c)) as Arc) + } + Self::Boolean(b) => { + let bits = b.finish(); + Ok(Arc::new(BooleanArray::new(bits, nulls)) as Arc) + } + Self::Int32(v) => { + let arr = flush_primitive::(v, nulls); + Ok(Arc::new(arr) as Arc) + } + Self::Date32(v) => { + let arr = flush_primitive::(v, nulls); + Ok(Arc::new(arr) as Arc) } - Self::String(offsets, values) => { - let offsets = flush_offsets(offsets); - let values = flush_values(values).into(); - Arc::new(StringArray::new(offsets, values, nulls)) + Self::Int64(v) => { + let arr = flush_primitive::(v, nulls); + Ok(Arc::new(arr) as Arc) } - Self::List(field, offsets, values) => { - let values = values.flush(None)?; - let offsets = flush_offsets(offsets); - Arc::new(ListArray::new(field.clone(), offsets, values, nulls)) + Self::Float32(v) => { + let arr = flush_primitive::(v, nulls); + Ok(Arc::new(arr) as Arc) } - Self::Record(fields, encodings) => { - let arrays = encodings - .iter_mut() - .map(|x| x.flush(None)) - .collect::, _>>()?; - Arc::new(StructArray::new(fields.clone(), arrays, nulls)) + Self::Float64(v) => { + let arr = flush_primitive::(v, nulls); + Ok(Arc::new(arr) as Arc) } - }) + Self::Binary(off, data) => { + let offsets = flush_offsets(off); + let vals = flush_values(data).into(); + let arr = BinaryArray::new(offsets, vals, nulls); + Ok(Arc::new(arr) as Arc) + } + Self::String(off, data) => { + let offsets = flush_offsets(off); + let vals = flush_values(data).into(); + let arr = StringArray::new(offsets, vals, nulls); + Ok(Arc::new(arr) as Arc) + } + Self::Record(fields, children) => { + let mut child_arrays = Vec::with_capacity(children.len()); + for c in children { + child_arrays.push(c.flush(None)?); + } + let first_len = match child_arrays.first() { + Some(a) => a.len(), + None => 0, + }; + for (i, arr) in child_arrays.iter().enumerate() { + if arr.len() != first_len { + return Err(ArrowError::InvalidArgumentError(format!( + "Inconsistent struct child length for field #{i}. 
Expected {first_len}, got {}", + arr.len() + ))); + } + } + if let Some(n) = &nulls { + if n.len() != first_len { + return Err(ArrowError::InvalidArgumentError(format!( + "Struct null buffer length {} != struct fields length {first_len}", + n.len() + ))); + } + } + let sarr = StructArray::new(fields.clone(), child_arrays, nulls); + Ok(Arc::new(sarr) as Arc) + } + Self::Enum(symbols, idxs) => { + let dict_vals = StringArray::from_iter_values(symbols.iter()); + let i32arr = match nulls { + Some(nb) => { + let buff = Buffer::from_slice_ref(&*idxs); + PrimitiveArray::::try_new( + arrow_buffer::ScalarBuffer::from(buff), + Some(nb), + )? + } + None => Int32Array::from_iter_values(idxs.iter().cloned()), + }; + idxs.clear(); + let d = DictionaryArray::::try_new(i32arr, Arc::new(dict_vals))?; + Ok(Arc::new(d) as Arc) + } + Self::List(item_field, off, child) => { + let c = child.flush(None)?; + let offsets = flush_offsets(off); + let final_len = offsets.len() - 1; + if let Some(n) = &nulls { + if n.len() != final_len { + return Err(ArrowError::InvalidArgumentError(format!( + "List array null buffer length {} != final list length {final_len}", + n.len() + ))); + } + } + let larr = ListArray::new(item_field.clone(), offsets, c, nulls); + Ok(Arc::new(larr) as Arc) + } + Self::Map(map_field, k_off, m_off, kdata, valdec) => { + let moff = flush_offsets(m_off); + let koff = flush_offsets(k_off); + let kd = flush_values(kdata).into(); + let val_arr = valdec.flush(None)?; + let key_arr = StringArray::new(koff, kd, None); + if key_arr.len() != val_arr.len() { + return Err(ArrowError::InvalidArgumentError(format!( + "Map keys length ({}) != map values length ({})", + key_arr.len(), + val_arr.len() + ))); + } + let final_len = moff.len() - 1; + if let Some(n) = &nulls { + if n.len() != final_len { + return Err(ArrowError::InvalidArgumentError(format!( + "Map array null buffer length {} != final map length {final_len}", + n.len() + ))); + } + } + let entries_struct = StructArray::new( + Fields::from(vec![ + Arc::new(ArrowField::new("key", DataType::Utf8, false)), + Arc::new(ArrowField::new("value", val_arr.data_type().clone(), true)), + ]), + vec![Arc::new(key_arr), val_arr], + None, + ); + let map_arr = MapArray::new(map_field.clone(), moff, entries_struct, nulls, false); + Ok(Arc::new(map_arr) as Arc) + } + Self::Nullable(_, nb_builder, child) => { + let mask = nb_builder.finish(); + child.flush(mask) + } + Self::Fixed(sz, accum) => { + let b: Buffer = flush_values(accum).into(); + let arr = FixedSizeBinaryArray::try_new(*sz, b, nulls) + .map_err(|e| ArrowError::ParseError(e.to_string()))?; + Ok(Arc::new(arr) as Arc) + } + Self::Decimal(precision, scale, sz, builder) => { + let p = *precision; + let s = scale.unwrap_or(0); + let new_b = DecimalBuilder::new(p, *scale, *sz)?; + let old = std::mem::replace(builder, new_b); + let arr = old.finish(nulls, p, s)?; + Ok(arr) + } + Self::TimeMillis(vals) => { + let arr = flush_primitive::(vals, nulls); + Ok(Arc::new(arr) as Arc) + } + Self::TimeMicros(vals) => { + let arr = flush_primitive::(vals, nulls); + Ok(Arc::new(arr) as Arc) + } + Self::TimestampMillis(is_utc, vals) => { + let arr = flush_primitive::(vals, nulls) + .with_timezone_opt::>(is_utc.then(|| "+00:00".into())); + Ok(Arc::new(arr) as Arc) + } + Self::TimestampMicros(is_utc, vals) => { + let arr = flush_primitive::(vals, nulls) + .with_timezone_opt::>(is_utc.then(|| "+00:00".into())); + Ok(Arc::new(arr) as Arc) + } + Self::Interval(ivals) => { + let len = ivals.len(); + let mut b = 
PrimitiveBuilder::::with_capacity(len); + for v in ivals.drain(..) { + b.append_value(v); + } + let arr = b + .finish() + .with_data_type(DataType::Interval(IntervalUnit::MonthDayNano)); + if let Some(nb) = nulls { + let arr_data = arr.into_data().into_builder().nulls(Some(nb)); + let arr_data = arr_data.build()?; + Ok( + Arc::new(PrimitiveArray::::from(arr_data)) + as Arc, + ) + } else { + Ok(Arc::new(arr) as Arc) + } + } + } } } -#[inline] -fn flush_values(values: &mut Vec) -> Vec { - std::mem::replace(values, Vec::with_capacity(DEFAULT_CAPACITY)) +fn read_array_blocks( + buf: &mut AvroCursor, + decode_item: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, +) -> Result { + read_blockwise_items(buf, true, decode_item) +} + +fn read_map_blocks( + buf: &mut AvroCursor, + decode_entry: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, +) -> Result { + read_blockwise_items(buf, true, decode_entry) } -#[inline] -fn flush_offsets(offsets: &mut OffsetBufferBuilder) -> OffsetBuffer { - std::mem::replace(offsets, OffsetBufferBuilder::new(DEFAULT_CAPACITY)).finish() +fn read_blockwise_items( + buf: &mut AvroCursor, + read_size_after_negative: bool, + mut decode_fn: impl FnMut(&mut AvroCursor) -> Result<(), ArrowError>, +) -> Result { + let mut total = 0usize; + loop { + let blk = buf.get_long()?; + match blk.cmp(&0) { + Ordering::Equal => break, + Ordering::Less => { + let cnt = (-blk) as usize; + if read_size_after_negative { + let _size_in_bytes = buf.get_long()?; + } + for _ in 0..cnt { + decode_fn(buf)?; + } + total += cnt; + } + Ordering::Greater => { + let cnt = blk as usize; + for _i in 0..cnt { + decode_fn(buf)?; + } + total += cnt; + } + } + } + Ok(total) } -#[inline] fn flush_primitive( - values: &mut Vec, - nulls: Option, + vals: &mut Vec, + nb: Option, ) -> PrimitiveArray { - PrimitiveArray::new(flush_values(values).into(), nulls) + PrimitiveArray::new(std::mem::take(vals).into(), nb) } -const DEFAULT_CAPACITY: usize = 1024; +fn flush_offsets(ob: &mut OffsetBufferBuilder) -> OffsetBuffer { + std::mem::replace(ob, OffsetBufferBuilder::new(DEFAULT_CAPACITY)).finish() +} + +fn flush_values(vec: &mut Vec) -> Vec { + std::mem::replace(vec, Vec::with_capacity(DEFAULT_CAPACITY)) +} + +/// A builder for Avro decimal, either 128-bit or 256-bit. 
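// A minimal sketch (annotation only, not part of this patch) of the byte handling the
// decimal builder below depends on: Avro serializes a decimal as a big-endian
// two's-complement integer, minimally sized for the `bytes` type or padded to the declared
// width for `fixed`, so short payloads must be sign-extended before being interpreted as
// an i128 (Decimal128) or i256 (Decimal256). The helper name is illustrative only.

fn decimal_bytes_to_i128(raw: &[u8]) -> i128 {
    // Assumes raw.len() <= 16; pad with 0xFF for negative values, 0x00 otherwise.
    let pad = if raw.first().map_or(false, |b| b & 0x80 != 0) { 0xFFu8 } else { 0x00 };
    let mut buf = [pad; 16];
    buf[16 - raw.len()..].copy_from_slice(raw);
    i128::from_be_bytes(buf)
}

// e.g. decimal_bytes_to_i128(&[0xFF, 0x38]) == -200, i.e. -2.00 at scale 2, matching the
// sign_extend_to_16 path used by DecimalBuilder::append_bytes below.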
+#[derive(Debug)] +enum DecimalBuilder { + Decimal128(Decimal128Builder), + Decimal256(Decimal256Builder), +} + +impl DecimalBuilder { + fn new( + precision: usize, + scale: Option, + size: Option, + ) -> Result { + let prec = precision as u8; + let scl = scale.unwrap_or(0) as i8; + if let Some(s) = size { + if s <= 16 { + return Ok(Self::Decimal128( + Decimal128Builder::new().with_precision_and_scale(prec, scl)?, + )); + } + if s <= 32 { + return Ok(Self::Decimal256( + Decimal256Builder::new().with_precision_and_scale(prec, scl)?, + )); + } + return Err(ArrowError::ParseError(format!( + "Unsupported decimal size: {s:?}" + ))); + } + if precision <= DECIMAL128_MAX_PRECISION as usize { + Ok(Self::Decimal128( + Decimal128Builder::new().with_precision_and_scale(prec, scl)?, + )) + } else if precision <= DECIMAL256_MAX_PRECISION as usize { + Ok(Self::Decimal256( + Decimal256Builder::new().with_precision_and_scale(prec, scl)?, + )) + } else { + Err(ArrowError::ParseError(format!( + "Decimal precision {} exceeds maximum supported", + precision + ))) + } + } + + fn append_bytes(&mut self, raw: &[u8]) -> Result<(), ArrowError> { + match self { + Self::Decimal128(b) => { + let ext = sign_extend_to_16(raw)?; + let val = i128::from_be_bytes(ext); + b.append_value(val); + } + Self::Decimal256(b) => { + let ext = sign_extend_to_32(raw)?; + let val = i256::from_be_bytes(ext); + b.append_value(val); + } + } + Ok(()) + } + + fn append_null(&mut self) -> Result<(), ArrowError> { + match self { + Self::Decimal128(b) => { + let zero = [0u8; 16]; + b.append_value(i128::from_be_bytes(zero)); + } + Self::Decimal256(b) => { + let zero = [0u8; 32]; + b.append_value(i256::from_be_bytes(zero)); + } + } + Ok(()) + } + + fn finish( + self, + nb: Option, + precision: usize, + scale: usize, + ) -> Result, ArrowError> { + match self { + Self::Decimal128(mut b) => { + let arr = b.finish(); + let vals = arr.values().clone(); + let dec = Decimal128Array::new(vals, nb) + .with_precision_and_scale(precision as u8, scale as i8)?; + Ok(Arc::new(dec)) + } + Self::Decimal256(mut b) => { + let arr = b.finish(); + let vals = arr.values().clone(); + let dec = Decimal256Array::new(vals, nb) + .with_precision_and_scale(precision as u8, scale as i8)?; + Ok(Arc::new(dec)) + } + } + } +} + +fn sign_extend_to_16(raw: &[u8]) -> Result<[u8; 16], ArrowError> { + let ext = sign_extend(raw, 16); + if ext.len() != 16 { + return Err(ArrowError::ParseError(format!( + "Failed to extend to 16 bytes, got {} bytes", + ext.len() + ))); + } + let mut arr = [0u8; 16]; + arr.copy_from_slice(&ext); + Ok(arr) +} + +fn sign_extend_to_32(raw: &[u8]) -> Result<[u8; 32], ArrowError> { + let ext = sign_extend(raw, 32); + if ext.len() != 32 { + return Err(ArrowError::ParseError(format!( + "Failed to extend to 32 bytes, got {} bytes", + ext.len() + ))); + } + let mut arr = [0u8; 32]; + arr.copy_from_slice(&ext); + Ok(arr) +} + +fn sign_extend(raw: &[u8], target_len: usize) -> Vec { + if raw.is_empty() { + return vec![0; target_len]; + } + let sign_bit = raw[0] & 0x80; + let mut out = Vec::with_capacity(target_len); + if sign_bit != 0 { + out.resize(target_len - raw.len(), 0xFF); + } else { + out.resize(target_len - raw.len(), 0x00); + } + out.extend_from_slice(raw); + out +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::codec::AvroField; + use crate::schema::Schema; + use arrow_array::{cast::AsArray, Array, ListArray, MapArray, StructArray}; + use std::sync::Arc; + + fn encode_avro_int(value: i32) -> Vec { + let mut buf = Vec::new(); + let mut v = 
(value << 1) ^ (value >> 31); + while v & !0x7F != 0 { + buf.push(((v & 0x7F) | 0x80) as u8); + v >>= 7; + } + buf.push(v as u8); + buf + } + + fn encode_avro_long(value: i64) -> Vec { + let mut buf = Vec::new(); + let mut v = (value << 1) ^ (value >> 63); + while v & !0x7F != 0 { + buf.push(((v & 0x7F) | 0x80) as u8); + v >>= 7; + } + buf.push(v as u8); + buf + } + + fn encode_avro_bytes(bytes: &[u8]) -> Vec { + let mut out = encode_avro_long(bytes.len() as i64); + out.extend_from_slice(bytes); + out + } + + fn encode_union_branch(branch_idx: i32) -> Vec { + encode_avro_int(branch_idx) + } + + fn encode_array(items: &[T], mut encode_item: impl FnMut(&T) -> Vec) -> Vec { + let mut out = Vec::new(); + if !items.is_empty() { + out.extend_from_slice(&encode_avro_long(items.len() as i64)); + for it in items { + out.extend_from_slice(&encode_item(it)); + } + } + out.extend_from_slice(&encode_avro_long(0)); + out + } + + fn encode_map(entries: &[(&str, Vec)]) -> Vec { + let mut out = Vec::new(); + if !entries.is_empty() { + out.extend_from_slice(&encode_avro_long(entries.len() as i64)); + for (k, val) in entries { + out.extend_from_slice(&encode_avro_bytes(k.as_bytes())); + out.extend_from_slice(val); + } + } + out.extend_from_slice(&encode_avro_long(0)); + out + } + + #[test] + fn test_union_primitive_long_null_record_decoder() { + let json_schema = r#" + { + "type": "record", + "name": "topLevelRecord", + "fields": [ + { + "name": "id", + "type": ["long","null"] + } + ] + } + "#; + let schema: Schema = serde_json::from_str(json_schema).unwrap(); + let avro_record = AvroField::try_from(&schema).unwrap(); + let mut record_decoder = RecordDecoder::try_new(avro_record.data_type(), false).unwrap(); + let mut data = Vec::new(); + data.extend_from_slice(&encode_union_branch(0)); + data.extend_from_slice(&encode_avro_long(1)); + data.extend_from_slice(&encode_union_branch(1)); + let used = record_decoder.decode(&data, 2).unwrap(); + assert_eq!(used, data.len()); + let batch = record_decoder.flush().unwrap(); + assert_eq!(batch.num_rows(), 2); + let arr = batch.column(0).as_primitive::(); + assert_eq!(arr.value(0), 1); + assert!(arr.is_null(1)); + } + + #[test] + fn test_union_array_of_int_null_record_decoder() { + let json_schema = r#" + { + "type":"record", + "name":"topLevelRecord", + "fields":[ + { + "name":"int_array", + "type":[ + { + "type":"array", + "items":[ "int", "null" ] + }, + "null" + ] + } + ] + } + "#; + let schema: Schema = serde_json::from_str(json_schema).unwrap(); + let avro_record = AvroField::try_from(&schema).unwrap(); + let mut record_decoder = RecordDecoder::try_new(avro_record.data_type(), false).unwrap(); + let mut data = Vec::new(); + + fn encode_int_or_null(opt_val: &Option) -> Vec { + match opt_val { + Some(v) => { + let mut out = encode_union_branch(0); + out.extend_from_slice(&encode_avro_int(*v)); + out + } + None => encode_union_branch(1), + } + } + + data.extend_from_slice(&encode_union_branch(0)); + let row1_values = vec![Some(1), Some(2), Some(3)]; + data.extend_from_slice(&encode_array(&row1_values, encode_int_or_null)); + data.extend_from_slice(&encode_union_branch(0)); + let row2_values = vec![None, Some(1), Some(2), None, Some(3), None]; + data.extend_from_slice(&encode_array(&row2_values, encode_int_or_null)); + data.extend_from_slice(&encode_union_branch(0)); + data.extend_from_slice(&encode_avro_long(0)); + data.extend_from_slice(&encode_union_branch(1)); + record_decoder.decode(&data, 4).unwrap(); + let batch = record_decoder.flush().unwrap(); + 
assert_eq!(batch.num_rows(), 4); + let list_arr = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(list_arr.is_null(3)); + { + let start = list_arr.value_offsets()[0] as usize; + let end = list_arr.value_offsets()[1] as usize; + let child = list_arr.values().as_primitive::(); + assert_eq!(end - start, 3); + assert_eq!(child.value(start), 1); + assert_eq!(child.value(start + 1), 2); + assert_eq!(child.value(start + 2), 3); + } + { + let start = list_arr.value_offsets()[1] as usize; + let end = list_arr.value_offsets()[2] as usize; + let child = list_arr.values().as_primitive::(); + assert_eq!(end - start, 6); + // index-by-index + assert!(child.is_null(start)); // None + assert_eq!(child.value(start + 1), 1); // Some(1) + assert_eq!(child.value(start + 2), 2); + assert!(child.is_null(start + 3)); + assert_eq!(child.value(start + 4), 3); + assert!(child.is_null(start + 5)); + } + { + let start = list_arr.value_offsets()[2] as usize; + let end = list_arr.value_offsets()[3] as usize; + assert_eq!(end - start, 0); + } + } + + #[test] + fn test_union_nested_array_of_int_null_record_decoder() { + let json_schema = r#" + { + "type":"record", + "name":"topLevelRecord", + "fields":[ + { + "name":"int_array_Array", + "type":[ + { + "type":"array", + "items":[ + { + "type":"array", + "items":[ + "int", + "null" + ] + }, + "null" + ] + }, + "null" + ] + } + ] + } + "#; + let schema: Schema = serde_json::from_str(json_schema).unwrap(); + let avro_record = AvroField::try_from(&schema).unwrap(); + let mut record_decoder = RecordDecoder::try_new(avro_record.data_type(), false).unwrap(); + let mut data = Vec::new(); + + fn encode_inner(vals: &[Option]) -> Vec { + encode_array(vals, |o| match o { + Some(v) => { + let mut out = encode_union_branch(0); + out.extend_from_slice(&encode_avro_int(*v)); + out + } + None => encode_union_branch(1), + }) + } + + data.extend_from_slice(&encode_union_branch(0)); + { + let outer_vals: Vec>>> = + vec![Some(vec![Some(1), Some(2)]), Some(vec![Some(3), None])]; + data.extend_from_slice(&encode_array(&outer_vals, |maybe_arr| match maybe_arr { + Some(vlist) => { + let mut out = encode_union_branch(0); + out.extend_from_slice(&encode_inner(vlist)); + out + } + None => encode_union_branch(1), + })); + } + data.extend_from_slice(&encode_union_branch(0)); + { + let outer_vals: Vec>>> = vec![None]; + data.extend_from_slice(&encode_array(&outer_vals, |maybe_arr| match maybe_arr { + Some(vlist) => { + let mut out = encode_union_branch(0); + out.extend_from_slice(&encode_inner(vlist)); + out + } + None => encode_union_branch(1), + })); + } + data.extend_from_slice(&encode_union_branch(1)); + record_decoder.decode(&data, 3).unwrap(); + let batch = record_decoder.flush().unwrap(); + assert_eq!(batch.num_rows(), 3); + let outer_list = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(outer_list.is_null(2)); + assert!(!outer_list.is_null(0)); + let start = outer_list.value_offsets()[0] as usize; + let end = outer_list.value_offsets()[1] as usize; + assert_eq!(end - start, 2); + let start2 = outer_list.value_offsets()[1] as usize; + let end2 = outer_list.value_offsets()[2] as usize; + assert_eq!(end2 - start2, 1); + let subitem_arr = outer_list.value(1); + let sub_list = subitem_arr.as_any().downcast_ref::().unwrap(); + assert_eq!(sub_list.len(), 1); + assert!(sub_list.is_null(0)); + } + + #[test] + fn test_union_map_of_int_null_record_decoder() { + let json_schema = r#" + { + "type":"record", + "name":"topLevelRecord", + "fields":[ + { + 
"name":"int_map", + "type":[ + { + "type":"map", + "values":[ + "int", + "null" + ] + }, + "null" + ] + } + ] + } + "#; + let schema: Schema = serde_json::from_str(json_schema).unwrap(); + let avro_record = AvroField::try_from(&schema).unwrap(); + let mut record_decoder = RecordDecoder::try_new(avro_record.data_type(), false).unwrap(); + let mut data = Vec::new(); + data.extend_from_slice(&encode_union_branch(0)); + let row1_map = vec![ + ("k1", { + let mut out = encode_union_branch(0); + out.extend_from_slice(&encode_avro_int(1)); + out + }), + ("k2", encode_union_branch(1)), + ]; + data.extend_from_slice(&encode_map(&row1_map)); + data.extend_from_slice(&encode_union_branch(0)); + let empty: [(&str, Vec); 0] = []; + data.extend_from_slice(&encode_map(&empty)); + data.extend_from_slice(&encode_union_branch(1)); + record_decoder.decode(&data, 3).unwrap(); + let batch = record_decoder.flush().unwrap(); + assert_eq!(batch.num_rows(), 3); + let map_arr = batch.column(0).as_any().downcast_ref::().unwrap(); + assert_eq!(map_arr.len(), 3); + assert!(map_arr.is_null(2)); + assert_eq!(map_arr.value_length(0), 2); + let binding = map_arr.value(0); + let struct_arr = binding.as_any().downcast_ref::().unwrap(); + let keys = struct_arr.column(0).as_string::(); + let vals = struct_arr.column(1).as_primitive::(); + assert_eq!(keys.value(0), "k1"); + assert_eq!(vals.value(0), 1); + assert_eq!(keys.value(1), "k2"); + assert!(vals.is_null(1)); + assert_eq!(map_arr.value_length(1), 0); + } + + #[test] + fn test_union_map_array_of_int_null_record_decoder() { + let json_schema = r#" + { + "type": "record", + "name": "topLevelRecord", + "fields": [ + { + "name": "int_Map_Array", + "type": [ + { + "type": "array", + "items": [ + { + "type": "map", + "values": [ + "int", + "null" + ] + }, + "null" + ] + }, + "null" + ] + } + ] + } + "#; + let schema: Schema = serde_json::from_str(json_schema).unwrap(); + let avro_record = AvroField::try_from(&schema).unwrap(); + let mut record_decoder = RecordDecoder::try_new(avro_record.data_type(), false).unwrap(); + let mut data = Vec::new(); + fn encode_map_int_null(entries: &[(&str, Option)]) -> Vec { + let items: Vec<(&str, Vec)> = entries + .iter() + .map(|(k, v)| { + let val = match v { + Some(x) => { + let mut out = encode_union_branch(0); + out.extend_from_slice(&encode_avro_int(*x)); + out + } + None => encode_union_branch(1), + }; + (*k, val) + }) + .collect(); + encode_map(&items) + } + data.extend_from_slice(&encode_union_branch(0)); + { + let mut arr_buf = encode_avro_long(1); + { + let mut item_buf = encode_union_branch(0); + item_buf.extend_from_slice(&encode_map_int_null(&[("k1", Some(1))])); + arr_buf.extend_from_slice(&item_buf); + } + arr_buf.extend_from_slice(&encode_avro_long(0)); + data.extend_from_slice(&arr_buf); + } + data.extend_from_slice(&encode_union_branch(0)); + { + let mut arr_buf = encode_avro_long(2); // 2 items + arr_buf.extend_from_slice(&encode_union_branch(1)); + { + let mut item1 = encode_union_branch(0); + item1.extend_from_slice(&encode_map_int_null(&[("k2", None)])); + arr_buf.extend_from_slice(&item1); + } + arr_buf.extend_from_slice(&encode_avro_long(0)); // end + data.extend_from_slice(&arr_buf); + } + data.extend_from_slice(&encode_union_branch(1)); + record_decoder.decode(&data, 3).unwrap(); + let batch = record_decoder.flush().unwrap(); + assert_eq!(batch.num_rows(), 3); + let outer_list = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(outer_list.is_null(2)); + { + let start = 
outer_list.value_offsets()[0] as usize; + let end = outer_list.value_offsets()[1] as usize; + assert_eq!(end - start, 1); + let subarr = outer_list.value(0); + let sublist = subarr.as_any().downcast_ref::().unwrap(); + assert_eq!(sublist.len(), 1); + assert!(!sublist.is_null(0)); + let sub_value_0 = sublist.value(0); + let struct_arr = sub_value_0.as_any().downcast_ref::().unwrap(); + let keys = struct_arr.column(0).as_string::(); + let vals = struct_arr.column(1).as_primitive::(); + assert_eq!(keys.value(0), "k1"); + assert_eq!(vals.value(0), 1); + } + } + + #[test] + fn test_union_nested_struct_out_of_spec_record_decoder() { + let json_schema = r#" + { + "type":"record", + "name":"topLevelRecord", + "fields":[ + { + "name":"nested_struct", + "type":[ + { + "type":"record", + "name":"nested_struct", + "namespace":"topLevelRecord", + "fields":[ + { + "name":"A", + "type":[ + "int", + "null" + ] + }, + { + "name":"b", + "type":[ + { + "type":"array", + "items":[ + "int", + "null" + ] + }, + "null" + ] + } + ] + }, + "null" + ] + } + ] + } + "#; + let schema: Schema = serde_json::from_str(json_schema).unwrap(); + let avro_record = AvroField::try_from(&schema).unwrap(); + let mut record_decoder = RecordDecoder::try_new(avro_record.data_type(), false).unwrap(); + let mut data = Vec::new(); + data.extend_from_slice(&encode_union_branch(0)); + data.extend_from_slice(&encode_union_branch(0)); + data.extend_from_slice(&encode_avro_int(7)); + data.extend_from_slice(&encode_union_branch(0)); + let row1_b = [Some(1), Some(2)]; + data.extend_from_slice(&encode_array(&row1_b, |val| match val { + Some(x) => { + let mut out = encode_union_branch(0); + out.extend_from_slice(&encode_avro_int(*x)); + out + } + None => encode_union_branch(1), + })); + data.extend_from_slice(&encode_union_branch(0)); + data.extend_from_slice(&encode_union_branch(1)); + data.extend_from_slice(&encode_union_branch(1)); + data.extend_from_slice(&encode_union_branch(1)); + record_decoder.decode(&data, 3).unwrap(); + let batch = record_decoder.flush().unwrap(); + assert_eq!(batch.num_rows(), 3); + let col = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(col.is_null(2)); + let field_a = col.column(0).as_primitive::(); + let field_b = col.column(1).as_any().downcast_ref::().unwrap(); + assert_eq!(field_a.value(0), 7); + { + let start = field_b.value_offsets()[0] as usize; + let end = field_b.value_offsets()[1] as usize; + let values = field_b.values().as_primitive::(); + assert_eq!(end - start, 2); + assert_eq!(values.value(start), 1); + assert_eq!(values.value(start + 1), 2); + } + assert!(field_a.is_null(1)); + assert!(field_b.is_null(1)); + } + + #[test] + fn test_record_decoder_default_metadata() { + use crate::codec::AvroField; + use crate::schema::Schema; + let json_schema = r#" + { + "type": "record", + "name": "TestRecord", + "fields": [ + {"name": "default_int", "type": "int", "default": 42} + ] + } + "#; + let schema: Schema = serde_json::from_str(json_schema).unwrap(); + let avro_record = AvroField::try_from(&schema).unwrap(); + let record_decoder = RecordDecoder::try_new(avro_record.data_type(), true).unwrap(); + let arrow_schema = record_decoder.schema(); + assert_eq!(arrow_schema.fields().len(), 1); + let field = arrow_schema.field(0); + let metadata = field.metadata(); + assert_eq!(metadata.get("avro.default").unwrap(), "42"); + } + + #[test] + fn test_fixed_decoding() { + let dt = AvroDataType::from_codec(Codec::Fixed(4)); + let mut dec = Decoder::try_new(&dt, true).unwrap(); + let row1 = 
[0xDE, 0xAD, 0xBE, 0xEF]; + let row2 = [0x01, 0x23, 0x45, 0x67]; + let mut data = Vec::new(); + data.extend_from_slice(&row1); + data.extend_from_slice(&row2); + let mut cursor = AvroCursor::new(&data); + dec.decode(&mut cursor).unwrap(); + dec.decode(&mut cursor).unwrap(); + let arr = dec.flush(None).unwrap(); + let fsb = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(fsb.len(), 2); + assert_eq!(fsb.value_length(), 4); + assert_eq!(fsb.value(0), row1); + assert_eq!(fsb.value(1), row2); + } + + #[test] + fn test_fixed_with_nulls() { + let dt = AvroDataType::from_codec(Codec::Fixed(2)); + let child = Decoder::try_new(&dt, true).unwrap(); + let mut dec = Decoder::Nullable( + UnionOrder::NullSecond, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(child), + ); + let row1 = [0x11, 0x22]; + let row3 = [0x55, 0x66]; + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&row1); + data.extend_from_slice(&encode_avro_int(1)); + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&row3); + let mut cursor = AvroCursor::new(&data); + dec.decode(&mut cursor).unwrap(); // Row1 + dec.decode(&mut cursor).unwrap(); // Row2 (null) + dec.decode(&mut cursor).unwrap(); // Row3 + let arr = dec.flush(None).unwrap(); + let fsb = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(fsb.len(), 3); + assert!(fsb.is_valid(0)); + assert!(!fsb.is_valid(1)); + assert!(fsb.is_valid(2)); + assert_eq!(fsb.value_length(), 2); + assert_eq!(fsb.value(0), row1); + assert_eq!(fsb.value(2), row3); + } + + #[test] + fn test_interval_decoding() { + let dt = AvroDataType::from_codec(Codec::Duration); + let mut dec = Decoder::try_new(&dt, true).unwrap(); + let row1 = [ + 0x01, 0x00, 0x00, 0x00, // months=1 + 0x02, 0x00, 0x00, 0x00, // days=2 + 0x64, 0x00, 0x00, 0x00, // ms=100 + ]; + let row2 = [ + 0xFF, 0xFF, 0xFF, 0xFF, // months=-1 + 0x0A, 0x00, 0x00, 0x00, // days=10 + 0x0F, 0x27, 0x00, 0x00, // ms=9999 + ]; + let mut data = Vec::new(); + data.extend_from_slice(&row1); + data.extend_from_slice(&row2); + let mut cursor = AvroCursor::new(&data); + dec.decode(&mut cursor).unwrap(); + dec.decode(&mut cursor).unwrap(); + let arr = dec.flush(None).unwrap(); + let intervals = arr + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(intervals.len(), 2); + let val0 = intervals.value(0); + assert_eq!(val0.months, 1); + assert_eq!(val0.days, 2); + assert_eq!(val0.nanoseconds, 100_000_000); + let val1 = intervals.value(1); + assert_eq!(val1.months, -1); + assert_eq!(val1.days, 10); + assert_eq!(val1.nanoseconds, 9_999_000_000); + } + + #[test] + fn test_interval_decoding_with_nulls() { + let dt = AvroDataType::from_codec(Codec::Duration); + let child = Decoder::try_new(&dt, true).unwrap(); + let mut dec = Decoder::Nullable( + UnionOrder::NullSecond, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(child), + ); + let row1 = [ + 0x02, 0x00, 0x00, 0x00, // months=2 + 0x03, 0x00, 0x00, 0x00, // days=3 + 0xF4, 0x01, 0x00, 0x00, // ms=500 + ]; + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&row1); + data.extend_from_slice(&encode_avro_int(1)); + let mut cursor = AvroCursor::new(&data); + dec.decode(&mut cursor).unwrap(); // Row1 + dec.decode(&mut cursor).unwrap(); // Row2 (null) + let arr = dec.flush(None).unwrap(); + let intervals = arr + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(intervals.len(), 2); + assert!(intervals.is_valid(0)); + assert!(!intervals.is_valid(1)); + let val0 = 
intervals.value(0); + assert_eq!(val0.months, 2); + assert_eq!(val0.days, 3); + assert_eq!(val0.nanoseconds, 500_000_000); + } + + #[test] + fn test_enum_decoding() { + let symbols = Arc::new(["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]); + let enum_dt = AvroDataType::from_codec(Codec::Enum(symbols, Arc::new([]))); + let mut decoder = Decoder::try_new(&enum_dt, true).unwrap(); + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(1)); + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&encode_avro_int(2)); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + let array = decoder.flush(None).unwrap(); + let dict_arr = array + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(dict_arr.len(), 3); + let keys = dict_arr.keys(); + assert_eq!(keys.value(0), 1); + assert_eq!(keys.value(1), 0); + assert_eq!(keys.value(2), 2); + let dict_values = dict_arr.values().as_string::(); + assert_eq!(dict_values.value(0), "RED"); + assert_eq!(dict_values.value(1), "GREEN"); + assert_eq!(dict_values.value(2), "BLUE"); + } + + #[test] + fn test_enum_decoding_with_nulls() { + let symbols = ["RED".to_string(), "GREEN".to_string(), "BLUE".to_string()]; + let enum_dt = AvroDataType::from_codec(Codec::Enum(Arc::new(symbols), Arc::new([]))); + let inner_decoder = Decoder::try_new(&enum_dt, true).unwrap(); + let mut nullable_decoder = Decoder::Nullable( + UnionOrder::NullSecond, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(inner_decoder), + ); + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&encode_avro_int(1)); + data.extend_from_slice(&encode_avro_int(1)); + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&encode_avro_int(0)); + let mut cursor = AvroCursor::new(&data); + nullable_decoder.decode(&mut cursor).unwrap(); + nullable_decoder.decode(&mut cursor).unwrap(); + nullable_decoder.decode(&mut cursor).unwrap(); + let array = nullable_decoder.flush(None).unwrap(); + let dict_arr = array + .as_any() + .downcast_ref::>() + .unwrap(); + assert_eq!(dict_arr.len(), 3); + assert!(dict_arr.is_valid(0)); + assert!(!dict_arr.is_valid(1)); + assert!(dict_arr.is_valid(2)); + let dict_values = dict_arr.values().as_string::(); + assert_eq!(dict_values.value(0), "RED"); + assert_eq!(dict_values.value(1), "GREEN"); + assert_eq!(dict_values.value(2), "BLUE"); + } + + #[test] + fn test_map_decoding_one_entry() { + let value_type = AvroDataType::from_codec(Codec::String); + let map_type = AvroDataType::from_codec(Codec::Map(Arc::new(value_type))); + let mut decoder = Decoder::try_new(&map_type, true).unwrap(); + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_long(1)); + data.extend_from_slice(&encode_avro_bytes(b"hello")); + data.extend_from_slice(&encode_avro_bytes(b"world")); + data.extend_from_slice(&encode_avro_long(0)); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + let array = decoder.flush(None).unwrap(); + let map_arr = array.as_any().downcast_ref::().unwrap(); + assert_eq!(map_arr.len(), 1); + assert_eq!(map_arr.value_length(0), 1); + let struct_arr = map_arr.value(0); + assert_eq!(struct_arr.len(), 1); + let keys = struct_arr.column(0).as_string::(); + let vals = struct_arr.column(1).as_string::(); + assert_eq!(keys.value(0), "hello"); + assert_eq!(vals.value(0), "world"); + } + + #[test] + fn 
test_map_decoding_empty() { + let value_type = AvroDataType::from_codec(Codec::String); + let map_type = AvroDataType::from_codec(Codec::Map(Arc::new(value_type))); + let mut decoder = Decoder::try_new(&map_type, true).unwrap(); + let data = encode_avro_long(0); + decoder.decode(&mut AvroCursor::new(&data)).unwrap(); + let array = decoder.flush(None).unwrap(); + let map_arr = array.as_any().downcast_ref::().unwrap(); + assert_eq!(map_arr.len(), 1); + assert_eq!(map_arr.value_length(0), 0); + } + + #[test] + fn test_decimal_decoding_fixed128() { + let dt = AvroDataType::from_codec(Codec::Decimal(5, Some(2), Some(16))); + let mut decoder = Decoder::try_new(&dt, true).unwrap(); + let row1 = [ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x30, 0x39, + ]; + let row2 = [ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, + 0xFF, 0x85, + ]; + let mut data = Vec::new(); + data.extend_from_slice(&row1); + data.extend_from_slice(&row2); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + let arr = decoder.flush(None).unwrap(); + let dec = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec.len(), 2); + assert_eq!(dec.value_as_string(0), "123.45"); + assert_eq!(dec.value_as_string(1), "-1.23"); + } + + #[test] + fn test_decimal_decoding_bytes_with_nulls() { + let dt = AvroDataType::from_codec(Codec::Decimal(4, Some(1), None)); + let inner = Decoder::try_new(&dt, true).unwrap(); + let mut decoder = Decoder::Nullable( + UnionOrder::NullSecond, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(inner), + ); + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&encode_avro_bytes(&[0x04, 0xD2])); + data.extend_from_slice(&encode_avro_int(1)); + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&encode_avro_bytes(&[0xFB, 0x2E])); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); // row1 + decoder.decode(&mut cursor).unwrap(); // row2 + decoder.decode(&mut cursor).unwrap(); // row3 + let arr = decoder.flush(None).unwrap(); + let dec_arr = arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec_arr.len(), 3); + assert!(dec_arr.is_valid(0)); + assert!(!dec_arr.is_valid(1)); + assert!(dec_arr.is_valid(2)); + assert_eq!(dec_arr.value_as_string(0), "123.4"); + assert_eq!(dec_arr.value_as_string(2), "-123.4"); + } + + #[test] + fn test_decimal_decoding_bytes_with_nulls_fixed_size() { + let dt = AvroDataType::from_codec(Codec::Decimal(6, Some(2), Some(16))); + let inner = Decoder::try_new(&dt, true).unwrap(); + let mut decoder = Decoder::Nullable( + UnionOrder::NullSecond, + NullBufferBuilder::new(DEFAULT_CAPACITY), + Box::new(inner), + ); + let row1 = [ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0xE2, 0x40, + ]; + let row3 = [ + 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFE, + 0x1D, 0xC0, + ]; + let mut data = Vec::new(); + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&row1); + data.extend_from_slice(&encode_avro_int(1)); + data.extend_from_slice(&encode_avro_int(0)); + data.extend_from_slice(&row3); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + decoder.decode(&mut cursor).unwrap(); + let arr = decoder.flush(None).unwrap(); + let dec_arr = 
arr.as_any().downcast_ref::().unwrap(); + assert_eq!(dec_arr.len(), 3); + assert!(dec_arr.is_valid(0)); + assert!(!dec_arr.is_valid(1)); + assert!(dec_arr.is_valid(2)); + assert_eq!(dec_arr.value_as_string(0), "1234.56"); + assert_eq!(dec_arr.value_as_string(2), "-1234.56"); + } + + #[test] + fn test_list_decoding() { + let item_dt = AvroDataType::from_codec(Codec::Int32); + let list_dt = AvroDataType::from_codec(Codec::Array(Arc::new(item_dt))); + let mut decoder = Decoder::try_new(&list_dt, true).unwrap(); + let mut row1 = Vec::new(); + row1.extend_from_slice(&encode_avro_long(2)); + row1.extend_from_slice(&encode_avro_int(10)); + row1.extend_from_slice(&encode_avro_int(20)); + row1.extend_from_slice(&encode_avro_long(0)); + let row2 = encode_avro_long(0); + let mut cursor = AvroCursor::new(&row1); + decoder.decode(&mut cursor).unwrap(); + let mut cursor2 = AvroCursor::new(&row2); + decoder.decode(&mut cursor2).unwrap(); + let array = decoder.flush(None).unwrap(); + let list_arr = array.as_any().downcast_ref::().unwrap(); + assert_eq!(list_arr.len(), 2); + let offsets = list_arr.value_offsets(); + assert_eq!(offsets, &[0, 2, 2]); + let values = list_arr.values(); + let int_arr = values.as_primitive::(); + assert_eq!(int_arr.len(), 2); + assert_eq!(int_arr.value(0), 10); + assert_eq!(int_arr.value(1), 20); + } + + #[test] + fn test_list_decoding_with_negative_block_count() { + let item_dt = AvroDataType::from_codec(Codec::Int32); + let list_dt = AvroDataType::from_codec(Codec::Array(Arc::new(item_dt))); + let mut decoder = Decoder::try_new(&list_dt, true).unwrap(); + let mut data = encode_avro_long(-3); + data.extend_from_slice(&encode_avro_long(12)); + data.extend_from_slice(&encode_avro_int(1)); + data.extend_from_slice(&encode_avro_int(2)); + data.extend_from_slice(&encode_avro_int(3)); + data.extend_from_slice(&encode_avro_long(0)); + let mut cursor = AvroCursor::new(&data); + decoder.decode(&mut cursor).unwrap(); + let array = decoder.flush(None).unwrap(); + let list_arr = array.as_any().downcast_ref::().unwrap(); + assert_eq!(list_arr.len(), 1); + assert_eq!(list_arr.value_length(0), 3); + let values = list_arr.values().as_primitive::(); + assert_eq!(values.len(), 3); + assert_eq!(values.value(0), 1); + assert_eq!(values.value(1), 2); + assert_eq!(values.value(2), 3); + } +} diff --git a/arrow-avro/src/reader/vlq.rs b/arrow-avro/src/reader/vlq.rs index b198a0d66f24..818c1f53cc0a 100644 --- a/arrow-avro/src/reader/vlq.rs +++ b/arrow-avro/src/reader/vlq.rs @@ -84,7 +84,7 @@ fn read_varint_array(buf: [u8; 10]) -> Option<(u64, usize)> { #[cold] fn read_varint_slow(buf: &[u8]) -> Option<(u64, usize)> { let mut value = 0; - for (count, byte) in buf.iter().take(10).enumerate() { + for (count, _) in buf.iter().take(10).enumerate() { let byte = buf[count]; value |= u64::from(byte & 0x7F) << (count * 7); if byte <= 0x7F { diff --git a/arrow-avro/src/schema.rs b/arrow-avro/src/schema.rs index a9d91e47948b..843545fda79b 100644 --- a/arrow-avro/src/schema.rs +++ b/arrow-avro/src/schema.rs @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use serde::{Deserialize, Serialize}; +use serde::{Deserialize, Deserializer, Serialize}; use std::collections::HashMap; /// The metadata key used for storing the JSON encoded [`Schema`] @@ -123,29 +123,48 @@ pub enum ComplexType<'a> { pub struct Record<'a> { #[serde(borrow)] pub name: &'a str, - #[serde(borrow, default)] + #[serde(borrow, default, skip_serializing_if = "Option::is_none")] pub namespace: Option<&'a str>, - #[serde(borrow, default)] + #[serde(borrow, default, skip_serializing_if = "Option::is_none")] pub doc: Option<&'a str>, #[serde(borrow, default)] pub aliases: Vec<&'a str>, #[serde(borrow)] - pub fields: Vec>, + pub fields: Vec>, #[serde(flatten)] pub attributes: Attributes<'a>, } /// A field within a [`Record`] +/// #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] -pub struct Field<'a> { +pub struct RecordField<'a> { #[serde(borrow)] pub name: &'a str, - #[serde(borrow, default)] + #[serde(borrow, default, skip_serializing_if = "Option::is_none")] pub doc: Option<&'a str>, + #[serde(borrow, default)] + pub aliases: Vec<&'a str>, #[serde(borrow)] pub r#type: Schema<'a>, - #[serde(borrow, default)] - pub default: Option<&'a str>, + #[serde( + default, + skip_serializing_if = "Option::is_none", + deserialize_with = "allow_out_of_spec_default" + )] + pub default: Option, +} + +/// Custom parse logic that stores *any* default as raw JSON +/// (including "null" for non-null-first unions). +fn allow_out_of_spec_default<'de, D>(deserializer: D) -> Result, D::Error> +where + D: Deserializer<'de>, +{ + match serde_json::Value::deserialize(deserializer) { + Ok(v) => Ok(Some(v)), + Err(_) => Ok(None), + } } /// An enumeration @@ -155,16 +174,16 @@ pub struct Field<'a> { pub struct Enum<'a> { #[serde(borrow)] pub name: &'a str, - #[serde(borrow, default)] + #[serde(borrow, default, skip_serializing_if = "Option::is_none")] pub namespace: Option<&'a str>, - #[serde(borrow, default)] + #[serde(borrow, default, skip_serializing_if = "Option::is_none")] pub doc: Option<&'a str>, #[serde(borrow, default)] pub aliases: Vec<&'a str>, #[serde(borrow)] pub symbols: Vec<&'a str>, - #[serde(borrow, default)] - pub default: Option<&'a str>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub default: Option, #[serde(flatten)] pub attributes: Attributes<'a>, } @@ -198,7 +217,7 @@ pub struct Map<'a> { pub struct Fixed<'a> { #[serde(borrow)] pub name: &'a str, - #[serde(borrow, default)] + #[serde(borrow, default, skip_serializing_if = "Option::is_none")] pub namespace: Option<&'a str>, #[serde(borrow, default)] pub aliases: Vec<&'a str>, @@ -210,7 +229,7 @@ pub struct Fixed<'a> { #[cfg(test)] mod tests { use super::*; - use crate::codec::{AvroDataType, AvroField}; + use crate::codec::AvroField; use arrow_schema::{DataType, Fields, TimeUnit}; use serde_json::json; @@ -254,6 +273,7 @@ mod tests { "type":"fixed", "name":"fixed", "namespace":"topLevelRecord.value", + "aliases":[], "size":11, "logicalType":"decimal", "precision":25, @@ -309,9 +329,10 @@ mod tests { namespace: None, doc: None, aliases: vec![], - fields: vec![Field { + fields: vec![RecordField { name: "value", doc: None, + aliases: vec![], r#type: Schema::Union(vec![ Schema::Complex(decimal), Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), @@ -343,15 +364,17 @@ mod tests { doc: None, aliases: vec!["LinkedLongs"], fields: vec![ - Field { + RecordField { name: "value", doc: None, + aliases: vec![], r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::Long)), default: None, }, - Field { + 
RecordField { name: "next", doc: None, + aliases: vec![], r#type: Schema::Union(vec![ Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), Schema::TypeName(TypeName::Ref("LongList")), @@ -359,7 +382,7 @@ mod tests { default: None, } ], - attributes: Attributes::default(), + attributes: Default::default(), })) ); @@ -402,18 +425,20 @@ mod tests { doc: None, aliases: vec![], fields: vec![ - Field { + RecordField { name: "id", doc: None, + aliases: vec![], r#type: Schema::Union(vec![ Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), ]), default: None, }, - Field { + RecordField { name: "timestamp_col", doc: None, + aliases: vec![], r#type: Schema::Union(vec![ Schema::Type(timestamp), Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), @@ -463,9 +488,10 @@ mod tests { doc: None, aliases: vec![], fields: vec![ - Field { + RecordField { name: "clientHash", doc: None, + aliases: vec![], r#type: Schema::Complex(ComplexType::Fixed(Fixed { name: "MD5", namespace: None, @@ -475,27 +501,30 @@ mod tests { })), default: None, }, - Field { + RecordField { name: "clientProtocol", doc: None, + aliases: vec![], r#type: Schema::Union(vec![ Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), Schema::TypeName(TypeName::Primitive(PrimitiveType::String)), ]), default: None, }, - Field { + RecordField { name: "serverHash", doc: None, + aliases: vec![], r#type: Schema::TypeName(TypeName::Ref("MD5")), default: None, }, - Field { + RecordField { name: "meta", doc: None, + aliases: vec![], r#type: Schema::Union(vec![ Schema::TypeName(TypeName::Primitive(PrimitiveType::Null)), - Schema::Complex(ComplexType::Map(Map { + Schema::Complex(ComplexType::Map(crate::schema::Map { values: Box::new(Schema::TypeName(TypeName::Primitive( PrimitiveType::Bytes ))), @@ -508,5 +537,230 @@ mod tests { attributes: Default::default(), })) ); + + let t: Type = serde_json::from_str( + r#"{ + "type":"string", + "logicalType":"uuid" + }"#, + ) + .unwrap(); + + let uuid = Type { + r#type: TypeName::Primitive(PrimitiveType::String), + attributes: Attributes { + logical_type: Some("uuid"), + additional: Default::default(), + }, + }; + + assert_eq!(t, uuid); + + // Ensure aliases are parsed + let schema: Schema = serde_json::from_str( + r#"{ + "type": "record", + "name": "Foo", + "aliases": ["Bar"], + "fields" : [ + {"name":"id","aliases":["uid"],"type":"int"} + ] + }"#, + ) + .unwrap(); + + let with_aliases = Schema::Complex(ComplexType::Record(Record { + name: "Foo", + namespace: None, + doc: None, + aliases: vec!["Bar"], + fields: vec![RecordField { + name: "id", + aliases: vec!["uid"], + doc: None, + r#type: Schema::TypeName(TypeName::Primitive(PrimitiveType::Int)), + default: None, + }], + attributes: Default::default(), + })); + + assert_eq!(schema, with_aliases); + } + + #[test] + fn test_default_parsing() { + // Test that a default value is correctly parsed for a record field. 
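+        // Per the Avro spec, a field "default" supplies the value a reader uses when its
+        // schema has a field the writer's data lacks; `allow_out_of_spec_default` keeps the
+        // default as a raw serde_json::Value, so numeric, string, and null defaults are all
+        // preserved as written in the schema JSON.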
+ let json_schema = r#" + { + "type": "record", + "name": "TestRecord", + "fields": [ + {"name": "a", "type": "int", "default": 10}, + {"name": "b", "type": "string", "default": "default_str"}, + {"name": "c", "type": "boolean"} + ] + } + "#; + let schema: Schema = serde_json::from_str(json_schema).unwrap(); + if let Schema::Complex(ComplexType::Record(rec)) = schema { + assert_eq!(rec.fields.len(), 3); + assert_eq!(rec.fields[0].default, Some(json!(10))); + assert_eq!(rec.fields[1].default, Some(json!("default_str"))); + assert_eq!(rec.fields[2].default, None); + } else { + panic!("Expected record schema"); + } + } + + #[test] + fn test_union_int_null_with_default_null() { + let json_schema = r#" + { + "type": "record", + "name": "ImpalaNullableRecord", + "fields": [ + {"name": "i", "type": ["int","null"], "default": null} + ] + } + "#; + let schema: Schema = serde_json::from_str(json_schema).unwrap(); + if let Schema::Complex(ComplexType::Record(rec)) = schema { + assert_eq!(rec.fields.len(), 1); + assert_eq!(rec.fields[0].name, "i"); + assert_eq!(rec.fields[0].default, Some(json!(null))); + let field_codec = + AvroField::try_from(&Schema::Complex(ComplexType::Record(rec))).unwrap(); + use arrow_schema::{DataType, Field, Fields}; + assert_eq!( + field_codec.field(), + Field::new( + "ImpalaNullableRecord", + DataType::Struct(Fields::from(vec![Field::new("i", DataType::Int32, true),])), + false + ) + ); + } else { + panic!("Expected record schema with union int|null, default null"); + } + } + + #[test] + fn test_union_impala_null_with_default_null() { + let json_schema = r#" + { + "type":"record","name":"topLevelRecord","fields":[ + {"name":"id","type":["long","null"]}, + {"name":"int_array","type":[{"type":"array","items":["int","null"]},"null"]}, + {"name":"int_array_Array","type":[{"type":"array","items":[{"type":"array","items":["int","null"]},"null"]},"null"]}, + {"name":"int_map","type":[{"type":"map","values":["int","null"]},"null"]}, + {"name":"int_Map_Array","type":[{"type":"array","items":[{"type":"map","values":["int","null"]},"null"]},"null"]}, + { + "name":"nested_struct", + "type":[ + { + "type":"record", + "name":"nested_struct", + "namespace":"topLevelRecord", + "fields":[ + {"name":"A","type":["int","null"]}, + {"name":"b","type":[{"type":"array","items":["int","null"]},"null"]}, + { + "name":"C", + "type":[ + { + "type":"record", + "name":"C", + "namespace":"topLevelRecord.nested_struct", + "fields":[ + { + "name":"d", + "type":[ + { + "type":"array", + "items":[ + { + "type":"array", + "items":[ + { + "type":"record", + "name":"d", + "namespace":"topLevelRecord.nested_struct.C", + "fields":[ + {"name":"E","type":["int","null"]}, + {"name":"F","type":["string","null"]} + ] + }, + "null" + ] + }, + "null" + ] + }, + "null" + ] + } + ] + }, + "null" + ] + }, + { + "name":"g", + "type":[ + { + "type":"map", + "values":[ + { + "type":"record", + "name":"g", + "namespace":"topLevelRecord.nested_struct", + "fields":[ + { + "name":"H", + "type":[ + { + "type":"record", + "name":"H", + "namespace":"topLevelRecord.nested_struct.g", + "fields":[ + { + "name":"i", + "type":[ + { + "type":"array", + "items":["double","null"] + }, + "null" + ] + } + ] + }, + "null" + ] + } + ] + }, + "null" + ] + }, + "null" + ] + } + ] + }, + "null" + ] + } + ] + } + "#; + let schema: Schema = serde_json::from_str(json_schema).unwrap(); + if let Schema::Complex(ComplexType::Record(rec)) = &schema { + assert_eq!(rec.name, "topLevelRecord"); + assert_eq!(rec.fields.len(), 6); + let _field_codec = 
AvroField::try_from(&schema).unwrap(); + } else { + panic!("Expected top-level record schema"); + } } } diff --git a/arrow-avro/src/writer/block.rs b/arrow-avro/src/writer/block.rs new file mode 100644 index 000000000000..9a9a542f1ce7 --- /dev/null +++ b/arrow-avro/src/writer/block.rs @@ -0,0 +1,104 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Manages Avro data blocks (accumulating row data, compressing, and flushing them to an output sink). + +use std::io::Write; + +use arrow_schema::ArrowError; + +use crate::compression::CompressionCodec; +use crate::writer::utils::to_arrow_io_err; +use crate::writer::zigzag::write_zigzag_long; + +/// Handles buffering Avro-encoded rows, compressing them if needed, +/// and writing them in Avro block format. +#[derive(Debug)] +pub struct BlockEncoder { + block_buf: Vec<u8>, + block_count: usize, + max_block_size: usize, + compression: Option<CompressionCodec>, + sync_marker: [u8; 16], +} + +impl BlockEncoder { + /// Create a new `BlockEncoder`. + /// + /// * `compression`: Optional compression codec + /// * `sync_marker`: The 16-byte sync marker used in the Avro container file + /// * `max_block_size`: Threshold in bytes; once `block_buf` reaches this size, + /// the block is flushed to the sink. + pub fn new( + compression: Option<CompressionCodec>, + sync_marker: [u8; 16], + max_block_size: usize, + ) -> Self { + Self { + block_buf: Vec::new(), + block_count: 0, + max_block_size, + compression, + sync_marker, + } + } + + /// Appends encoded bytes for a row (or partial row chunk) into the internal buffer. + pub fn append_encoded(&mut self, data: &[u8]) { + self.block_buf.extend_from_slice(data); + } + + /// Increments the row count for the next flush. + pub fn inc_count(&mut self) { + self.block_count += 1; + } + + /// Flush the current buffer if it exceeds `max_block_size`. + pub fn maybe_flush(&mut self, sink: &mut dyn Write) -> Result<(), ArrowError> { + if self.block_buf.len() >= self.max_block_size { + self.flush_block(sink)?; + } + Ok(()) + } + + /// Force a flush of the current block, if any rows are present. + pub fn flush_block(&mut self, sink: &mut dyn Write) -> Result<(), ArrowError> { + if self.block_count == 0 { + return Ok(()); + } + write_zigzag_long(self.block_count as i64, sink)?; + let payload = if let Some(codec) = self.compression { + codec.compress_block(&self.block_buf)? + } else { + self.block_buf.clone() + }; + write_zigzag_long(payload.len() as i64, sink)?; + sink.write_all(&payload) + .map_err(|e| to_arrow_io_err(e, "Writing block data"))?; + sink.write_all(&self.sync_marker) + .map_err(|e| to_arrow_io_err(e, "Writing sync marker"))?; + self.block_buf.clear(); + self.block_count = 0; + Ok(()) + } + + /// Flush any leftover data. Called when finishing or closing the file.
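+    /// This is equivalent to one final `flush_block` call: any buffered rows are written
+    /// out as a last block (row count, optionally compressed payload, then the sync marker).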
+ pub fn close(&mut self, sink: &mut dyn Write) -> Result<(), ArrowError> { + self.flush_block(sink)?; + Ok(()) + } +} diff --git a/arrow-avro/src/writer/encoder.rs b/arrow-avro/src/writer/encoder.rs new file mode 100644 index 000000000000..841a4a49fb91 --- /dev/null +++ b/arrow-avro/src/writer/encoder.rs @@ -0,0 +1,1495 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Row-based encoder logic for Avro container files. + +use arrow_array::{ + Array, BinaryArray, BooleanArray, Decimal128Array, Decimal256Array, DictionaryArray, + FixedSizeBinaryArray, FixedSizeListArray, Float32Array, Float64Array, Int32Array, + Int64Array, LargeListArray, ListArray, MapArray, PrimitiveArray, StringArray, + StructArray, TimestampMicrosecondArray, TimestampMillisecondArray, +}; +use arrow_array::builder::{Decimal128Builder, Decimal256Builder}; +use arrow_array::types::{ + Int16Type, Int32Type, Int64Type, Int8Type, IntervalMonthDayNanoType, + Time32MillisecondType, Time64MicrosecondType, +}; +use arrow_buffer::{i256, IntervalMonthDayNano}; +use arrow_schema::{ArrowError, DataType, Field, IntervalUnit, TimeUnit}; + +use crate::codec::Nullability; +use crate::writer::zigzag::write_zigzag_long; + +/// A `RecordEncoder` converts entire Arrow rows into Avro bytes by +/// calling `encode_one` on each column for the given row. +#[derive(Debug)] +pub struct RecordEncoder { + fields: Vec<Encoder>, +} + +impl RecordEncoder { + /// Create a new `RecordEncoder` from an Arrow schema, + /// specifying `impala_mode=true` if we want `[ T, "null" ]` ordering + /// for nullable fields. + pub fn try_new( + schema: &arrow_schema::Schema, + impala_mode: bool, + ) -> Result<Self, ArrowError> { + let mut fields = Vec::with_capacity(schema.fields().len()); + for f in schema.fields() { + fields.push(Encoder::try_new(f, impala_mode)?); + } + Ok(Self { fields }) + } + + /// Encode one row from a `RecordBatch` into `out`. + pub fn encode_row( + &mut self, + batch: &arrow_array::RecordBatch, + row_idx: usize, + out: &mut Vec<u8>, + ) -> Result<(), ArrowError> { + for (col_idx, field_enc) in self.fields.iter_mut().enumerate() { + let col = batch.column(col_idx); + field_enc.encode_one(col.as_ref(), row_idx, out)?; + } + Ok(()) + } + + /// Convenience to encode a single row into a new `Vec<u8>`. + pub fn encode_row_to_vec( + &mut self, + batch: &arrow_array::RecordBatch, + row_idx: usize, + ) -> Result<Vec<u8>, ArrowError> { + let mut buf = Vec::new(); + self.encode_row(batch, row_idx, &mut buf)?; + Ok(buf) + } +} + +/// An `Encoder` is responsible for writing a single Arrow column +/// to Avro bytes.
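+/// Each variant corresponds to the Avro encoding of one Arrow type; nullable columns are
+/// wrapped in `Encoder::Nullable`, which writes the union branch index first and only then
+/// delegates to the inner encoder for non-null rows.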
+#[derive(Debug)] +enum Encoder { + /// Primitives + Null, + Boolean, + Int32, + Int64, + Float32, + Float64, + Binary, + Utf8, + /// Complex + Record(Vec), + Enum(DictKeyEnc), + Array(Box), + Map(Box), + Fixed(usize), + /// Logical + Decimal128(usize, usize, Option), + Decimal256(usize, usize, Option), + Date32, + TimeMillis, + TimeMicros, + TimestampMillis(bool), + TimestampMicros(bool), + Uuid, + Duration, + /// For union-encoded columns. The second param is the union ordering. + /// + /// - `NullFirst` => `[ "null", T ]` => branch=0 => null, branch=1 => T + /// - `NullSecond` => `[ T, "null" ]` => branch=0 => T, branch=1 => null + Nullable(Box, Nullability), +} + +/// The integral type used for dictionary keys in an Avro enum. +#[derive(Debug, Clone, Copy)] +enum DictKeyEnc { + Int8, + Int16, + Int32, + Int64, +} + +impl Encoder { + /// Build an `Encoder` for the given Arrow `Field`. + /// + /// If `impala` is true, and the field is nullable, we produce a union + /// that uses `[ T, "null" ]` ordering (`Nullability::NullSecond`). + pub fn try_new(field: &Field, impala: bool) -> Result { + let dt = field.data_type(); + let enc = match dt { + DataType::Null => Self::Null, + DataType::Boolean => Self::Boolean, + DataType::Int8 | DataType::Int16 | DataType::Int32 => Self::Int32, + DataType::Int64 => Self::Int64, + DataType::Float32 => Self::Float32, + DataType::Float64 => Self::Float64, + DataType::Binary | DataType::LargeBinary => Self::Binary, + DataType::Utf8 | DataType::LargeUtf8 => Self::Utf8, + DataType::Struct(fields) => { + let mut child_encoders = Vec::with_capacity(fields.len()); + for child_field in fields { + child_encoders.push(Self::try_new(child_field.as_ref(), impala)?); + } + Self::Record(child_encoders) + } + DataType::Dictionary(key_dt, _value_dt) => { + let valid_key = matches!( + key_dt.as_ref(), + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 + ); + if !valid_key { + Self::Utf8 + } else if let Some(sym_json_str) = field.metadata().get("avro.enum.symbols") { + let parsed: serde_json::Value = serde_json::from_str(sym_json_str) + .map_err(|e| { + ArrowError::ParseError(format!( + "Invalid JSON in avro.enum.symbols: {e}" + )) + })?; + if !parsed.is_array() { + Self::Utf8 + } else { + let key_enc = match key_dt.as_ref() { + DataType::Int8 => DictKeyEnc::Int8, + DataType::Int16 => DictKeyEnc::Int16, + DataType::Int32 => DictKeyEnc::Int32, + DataType::Int64 => DictKeyEnc::Int64, + _ => DictKeyEnc::Int32, + }; + Self::Enum(key_enc) + } + } else { + Self::Utf8 + } + } + DataType::List(child_field) | DataType::LargeList(child_field) => { + let child_enc = Self::try_new(child_field.as_ref(), impala)?; + Self::Array(Box::new(child_enc)) + } + DataType::FixedSizeList(child_field, _sz) => { + let child_enc = Self::try_new(child_field.as_ref(), impala)?; + Self::Array(Box::new(child_enc)) + } + DataType::Map(entry_field, _keys_sorted) => match entry_field.data_type() { + DataType::Struct(fs) if fs.len() == 2 => { + let val_field = &fs[1]; + let val_enc = Self::try_new(val_field, impala)?; + Self::Map(Box::new(val_enc)) + } + _ => Self::Null, + }, + DataType::FixedSizeBinary(n) => { + let md = field.metadata(); + match md.get("logicalType").map(|s| s.as_str()) { + Some("uuid") if *n == 16 => Self::Uuid, + Some("duration") if *n == 12 => Self::Duration, + _ => Self::Fixed(*n as usize), + } + } + DataType::Decimal128(p, s) => { + Self::Decimal128(*p as usize, *s as usize, Some(16)) + } + DataType::Decimal256(p, s) => { + Self::Decimal256(*p as usize, *s as usize, 
Some(32)) + } + DataType::Date32 => Self::Date32, + DataType::Time32(TimeUnit::Millisecond) => Self::TimeMillis, + DataType::Time64(TimeUnit::Microsecond) => Self::TimeMicros, + DataType::Timestamp(TimeUnit::Millisecond, tz_opt) => { + let is_utc = tz_opt.as_deref() == Some("+00:00"); + Self::TimestampMillis(is_utc) + } + DataType::Timestamp(TimeUnit::Microsecond, tz_opt) => { + let is_utc = tz_opt.as_deref() == Some("+00:00"); + Self::TimestampMicros(is_utc) + } + DataType::Interval(IntervalUnit::MonthDayNano) => Self::Duration, + other => { + eprintln!("WARN: unhandled Arrow type {other:?}, encoding as Null"); + Self::Null + } + }; + if field.is_nullable() && !matches!(enc, Self::Null) { + let nullability = if impala { + Nullability::NullSecond + } else { + Nullability::NullFirst + }; + Ok(Self::Nullable(Box::new(enc), nullability)) + } else { + Ok(enc) + } + } + + /// Encode a row from an Arrow array into Avro bytes. + pub fn encode_one( + &self, + array: &dyn Array, + row_idx: usize, + out: &mut Vec, + ) -> Result<(), ArrowError> { + match self { + Self::Null => Ok(()), + Self::Boolean => { + let bool_arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::ParseError("Not a Boolean array".to_string()) + })?; + let val = bool_arr.value(row_idx); + out.push(val as u8); + Ok(()) + } + Self::Int32 => { + let int_arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::ParseError("Not an Int32 array".to_string()) + })?; + let val = int_arr.value(row_idx); + write_zigzag_long(val as i64, out)?; + Ok(()) + } + Self::Int64 => { + let int_arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::ParseError("Not an Int64 array".to_string()) + })?; + let val = int_arr.value(row_idx); + write_zigzag_long(val, out)?; + Ok(()) + } + Self::Float32 => { + let float_arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::ParseError("Not a Float32 array".to_string()) + })?; + let val = float_arr.value(row_idx).to_le_bytes(); + out.extend_from_slice(&val); + Ok(()) + } + Self::Float64 => { + let float_arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::ParseError("Not a Float64 array".to_string()) + })?; + let val = float_arr.value(row_idx).to_le_bytes(); + out.extend_from_slice(&val); + Ok(()) + } + Self::Binary => { + let bin_arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::ParseError("Not a Binary array".to_string()) + })?; + let val = bin_arr.value(row_idx); + write_zigzag_long(val.len() as i64, out)?; + out.extend_from_slice(val); + Ok(()) + } + Self::Utf8 => { + let str_arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::ParseError("Not a String array".to_string()) + })?; + let val = str_arr.value(row_idx); + write_zigzag_long(val.len() as i64, out)?; + out.extend_from_slice(val.as_bytes()); + Ok(()) + } + Self::Record(child_encoders) => { + let struct_arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::ParseError("Not a Struct array".to_string()) + })?; + for (i, child_enc) in child_encoders.iter().enumerate() { + let col = struct_arr.column(i); + child_enc.encode_one(col.as_ref(), row_idx, out)?; + } + Ok(()) + } + Self::Enum(key_enc) => { + match key_enc { + DictKeyEnc::Int8 => { + let dict_arr = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + ArrowError::ParseError( + "Not a Dictionary array".to_string(), + ) + })?; + let key_usize = dict_arr.key(row_idx).unwrap_or(0); + write_zigzag_long(key_usize as 
i64, out)?; + } + DictKeyEnc::Int16 => { + let dict_arr = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + ArrowError::ParseError( + "Not a Dictionary array".to_string(), + ) + })?; + let key_usize = dict_arr.key(row_idx).unwrap_or(0); + write_zigzag_long(key_usize as i64, out)?; + } + DictKeyEnc::Int32 => { + let dict_arr = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + ArrowError::ParseError( + "Not a Dictionary array".to_string(), + ) + })?; + let key_usize = dict_arr.key(row_idx).unwrap_or(0); + write_zigzag_long(key_usize as i64, out)?; + } + DictKeyEnc::Int64 => { + let dict_arr = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + ArrowError::ParseError( + "Not a Dictionary array".to_string(), + ) + })?; + let key_usize = dict_arr.key(row_idx).unwrap_or(0); + write_zigzag_long(key_usize as i64, out)?; + } + } + Ok(()) + } + Self::Array(child_enc) => { + match array.data_type() { + DataType::List(_) => { + let list_arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::ParseError("Not a List array".to_string()) + })?; + + let offset = list_arr.value_offsets()[row_idx] as usize; + let offset_next = list_arr.value_offsets()[row_idx + 1] as usize; + let length = offset_next - offset; + write_zigzag_long(length as i64, out)?; + let values = list_arr.values(); + for i in offset..offset_next { + child_enc.encode_one(values.as_ref(), i, out)?; + } + if length > 0 { + write_zigzag_long(0, out)?; + } + } + DataType::LargeList(_) => { + let ll_arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::ParseError("Not a LargeList array".to_string()) + })?; + let offset = ll_arr.value_offsets()[row_idx] as usize; + let offset_next = ll_arr.value_offsets()[row_idx + 1] as usize; + let length = offset_next - offset; + write_zigzag_long(length as i64, out)?; + let values = ll_arr.values(); + for i in offset..offset_next { + child_enc.encode_one(values.as_ref(), i, out)?; + } + if length > 0 { + write_zigzag_long(0, out)?; + } + } + DataType::FixedSizeList(_, size) => { + let fsl_arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::ParseError( + "Not a FixedSizeList array".to_string(), + ) + })?; + let length = *size; + let start = row_idx * *size as usize; + let end = start + *size as usize; + write_zigzag_long(*size as i64, out)?; + let values = fsl_arr.values(); + for i in start..end { + child_enc.encode_one(values.as_ref(), i, out)?; + } + // Avro array termination + if length > 0 { + write_zigzag_long(0, out)?; + } + } + dt => { + return Err(ArrowError::NotYetImplemented(format!( + "Array writer for arrow type {dt:?} not supported" + ))); + } + } + Ok(()) + } + Self::Map(val_enc) => { + let map_array = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::ParseError("Not a Map array".to_string()) + })?; + let offset = map_array.value_offsets()[row_idx] as usize; + let offset_next = map_array.value_offsets()[row_idx + 1] as usize; + let length = offset_next - offset; + write_zigzag_long(length as i64, out)?; + let entries_struct = map_array.entries(); + let key_arr = entries_struct + .column(0) + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::ParseError("Map keys not a String array".to_string()) + })?; + let val_arr = entries_struct.column(1); + for i in offset..offset_next { + let key_val = key_arr.value(i); + write_zigzag_long(key_val.len() as i64, out)?; + out.extend_from_slice(key_val.as_bytes()); + val_enc.encode_one(val_arr.as_ref(), i, out)?; + } + if length > 0 
{ + write_zigzag_long(0, out)?; + } + Ok(()) + } + Self::Fixed(n) => { + let fsb_arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::ParseError( + "Not a FixedSizeBinary array".to_string(), + ) + })?; + let val = fsb_arr.value(row_idx); + if val.len() != *n { + return Err(ArrowError::InvalidArgumentError(format!( + "FixedSizeBinary length mismatch: expected {n}, got {}", + val.len() + ))); + } + out.extend_from_slice(val); + Ok(()) + } + Self::Decimal128(_p, _s, size_opt) => { + let dec_arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::ParseError( + "Not a Decimal128 array".to_string(), + ) + })?; + let value = dec_arr.value(row_idx); + let be_bytes = value.to_be_bytes(); + let sign_byte = if value >= 0 { 0x00 } else { 0xFF }; + // Trim sign extension + let mut first_non_extend = 0usize; + while first_non_extend + 1 < be_bytes.len() + && be_bytes[first_non_extend] == sign_byte + && (be_bytes[first_non_extend + 1] & 0x80) == (sign_byte & 0x80) + { + first_non_extend += 1; + } + let trimmed = &be_bytes[first_non_extend..]; + if let Some(sz) = size_opt { + // fixed-size decimal + if trimmed.len() > *sz { + return Err(ArrowError::InvalidArgumentError( + "Decimal128 value doesn't fit fixed size".to_string(), + )); + } + let mut buf = vec![sign_byte; *sz]; + let start = sz - trimmed.len(); + buf[start..].copy_from_slice(trimmed); + out.extend_from_slice(&buf); + } else { + // variable-size decimal => length-prefix + write_zigzag_long(trimmed.len() as i64, out)?; + out.extend_from_slice(trimmed); + } + Ok(()) + } + Self::Decimal256(_p, _s, size_opt) => { + let dec_arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::ParseError( + "Not a Decimal256 array".to_string(), + ) + })?; + let val_i256 = dec_arr.value(row_idx); + // Convert to big-endian + let mut be_bytes = val_i256.to_le_bytes(); + be_bytes.reverse(); + let sign_byte = if val_i256.is_negative() { 0xFF } else { 0x00 }; + // Trim sign extension + let mut first_non_extend = 0usize; + while first_non_extend + 1 < be_bytes.len() + && be_bytes[first_non_extend] == sign_byte + && (be_bytes[first_non_extend + 1] & 0x80) + == (sign_byte & 0x80) + { + first_non_extend += 1; + } + let trimmed = &be_bytes[first_non_extend..]; + if let Some(sz) = size_opt { + // fixed-size + if trimmed.len() > *sz { + return Err(ArrowError::InvalidArgumentError( + "Decimal256 value doesn't fit fixed size".to_string(), + )); + } + let mut buf = vec![sign_byte; *sz]; + let start = sz - trimmed.len(); + buf[start..].copy_from_slice(trimmed); + out.extend_from_slice(&buf); + } else { + // variable-size => length + data + write_zigzag_long(trimmed.len() as i64, out)?; + out.extend_from_slice(trimmed); + } + Ok(()) + } + Self::Date32 => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::ParseError("Not a Date32 array".to_string()) + })?; + let val = arr.value(row_idx); + write_zigzag_long(val as i64, out)?; + Ok(()) + } + Self::TimeMillis => { + let arr = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + ArrowError::ParseError( + "Not a Time32(Millis) array".to_string(), + ) + })?; + let val = arr.value(row_idx); + write_zigzag_long(val as i64, out)?; + Ok(()) + } + Self::TimeMicros => { + let arr = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + ArrowError::ParseError( + "Not a Time64(Micros) array".to_string(), + ) + })?; + let val = arr.value(row_idx); + write_zigzag_long(val, out)?; + Ok(()) + } + Self::TimestampMillis(_is_utc) => { + let 
arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::ParseError( + "Not a Timestamp(Millis) array".to_string(), + ) + })?; + let val = arr.value(row_idx); + write_zigzag_long(val, out)?; + Ok(()) + } + Self::TimestampMicros(_is_utc) => { + let arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::ParseError( + "Not a Timestamp(Micros) array".to_string(), + ) + })?; + let val = arr.value(row_idx); + write_zigzag_long(val, out)?; + Ok(()) + } + Self::Uuid => { + let fsb_arr = array + .as_any() + .downcast_ref::() + .ok_or_else(|| { + ArrowError::ParseError( + "Not a FixedSizeBinary(16) array".to_string(), + ) + })?; + let val = fsb_arr.value(row_idx); + if val.len() != 16 { + return Err(ArrowError::InvalidArgumentError(format!( + "UUID field must be 16 bytes, got {}", + val.len() + ))); + } + out.extend_from_slice(val); + Ok(()) + } + Self::Duration => { + let arr = array + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + ArrowError::ParseError( + "Not IntervalMonthDayNano array".to_string(), + ) + })?; + let val: IntervalMonthDayNano = arr.value(row_idx); + let months = val.months; + let days = val.days; + // Convert total nanoseconds to milliseconds + let ms = (val.nanoseconds / 1_000_000) as i32; + out.extend_from_slice(&months.to_le_bytes()); + out.extend_from_slice(&days.to_le_bytes()); + out.extend_from_slice(&ms.to_le_bytes()); + Ok(()) + } + Self::Nullable(inner, nb) => { + match nb { + Nullability::NullFirst => { + if array.is_null(row_idx) { + write_zigzag_long(0, out)?; + } else { + write_zigzag_long(1, out)?; + inner.encode_one(array, row_idx, out)?; + } + } + Nullability::NullSecond => { + if array.is_null(row_idx) { + write_zigzag_long(1, out)?; + } else { + write_zigzag_long(0, out)?; + inner.encode_one(array, row_idx, out)?; + } + } + } + Ok(()) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::{BinaryArray, BooleanArray, Date32Array, FixedSizeBinaryArray, Float32Array, Float64Array, Int32Array, Int64Array, Int8Array, RecordBatch, StringArray, Time32MillisecondArray, Time64MicrosecondArray}; + use arrow_schema::{ArrowError, DataType, Field, FieldRef, Fields, Schema as ArrowSchema}; + use std::sync::Arc; + use arrow_array::builder::{Int32Builder, MapBuilder, StringBuilder}; + use arrow_buffer::{OffsetBuffer, ScalarBuffer, Buffer as ArrowBuffer}; + use arrow_data::ArrayData; + + fn decode_avro_boolean(data: &[u8]) -> bool { + assert_eq!(data.len(), 1, "boolean => exactly 1 byte"); + data[0] != 0 + } + + fn decode_varint_u64(data: &[u8]) -> (u64, usize) { + let mut val: u64 = 0; + let mut shift = 0; + let mut i = 0; + for b in data { + let lower = (b & 0x7F) as u64; + val |= lower << shift; + shift += 7; + i += 1; + if b & 0x80 == 0 { + break; + } + } + (val, i) + } + + fn decode_zigzag_long(data: &[u8]) -> (i64, usize) { + let (zz, consumed) = decode_varint_u64(data); + let val = ((zz >> 1) as i64) ^ -((zz & 1) as i64); + (val, consumed) + } + + fn decode_avro_f32(data: &[u8]) -> f32 { + assert_eq!(data.len(), 4); + f32::from_le_bytes(data.try_into().unwrap()) + } + + fn decode_avro_f64(data: &[u8]) -> f64 { + assert_eq!(data.len(), 8); + f64::from_le_bytes(data.try_into().unwrap()) + } + + fn decode_avro_bytes(data: &[u8]) -> Vec { + let (len, consumed) = decode_zigzag_long(data); + let len = len as usize; + data[consumed..consumed + len].to_vec() + } + + fn decode_avro_string(data: &[u8]) -> (String, usize) { + let (len, used) = decode_zigzag_long(data); + let len_usize = len as usize; + let start 
= used; + let end = used + len_usize; + let bytes = &data[start..end]; + let s = std::str::from_utf8(bytes).unwrap().to_string(); + (s, end) + } + + #[test] + fn test_encode_date32() -> Result<(), ArrowError> { + let date_arr = Date32Array::from(vec![Some(10), Some(-3)]); + let field = Field::new("date_col", DataType::Date32, false); + let schema = ArrowSchema::new(vec![field.clone()]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(date_arr)])?; + let mut encoder = RecordEncoder::try_new(&schema, false)?; + let mut out = Vec::new(); + encoder.encode_row(&batch, 0, &mut out)?; + encoder.encode_row(&batch, 1, &mut out)?; + let mut offset = 0; + let (val0, used0) = decode_zigzag_long(&out[offset..]); + offset += used0; + assert_eq!(val0, 10, "Expected day=10 in row0"); + let (val1, used1) = decode_zigzag_long(&out[offset..]); + offset += used1; + assert_eq!(val1, -3, "Expected day=-3 in row1"); + assert_eq!(offset, out.len(), "Consumed all bytes"); + Ok(()) + } + + #[test] + fn test_encode_time_millis() -> Result<(), ArrowError> { + let arr = Time32MillisecondArray::from(vec![Some(1234), Some(99999)]); + let field = Field::new("time_ms", DataType::Time32(TimeUnit::Millisecond), false); + let schema = ArrowSchema::new(vec![field.clone()]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(arr)])?; + let mut encoder = RecordEncoder::try_new(&schema, false)?; + let mut out = Vec::new(); + encoder.encode_row(&batch, 0, &mut out)?; + encoder.encode_row(&batch, 1, &mut out)?; + let mut offset = 0; + let (row0_val, used0) = decode_zigzag_long(&out[offset..]); + offset += used0; + assert_eq!(row0_val, 1234); + let (row1_val, used1) = decode_zigzag_long(&out[offset..]); + offset += used1; + assert_eq!(row1_val, 99999); + assert_eq!(offset, out.len()); + Ok(()) + } + + #[test] + fn test_encode_time_micros() -> Result<(), ArrowError> { + let arr = Time64MicrosecondArray::from(vec![Some(50_000), Some(1_000_000)]); + let field = Field::new("time_us", DataType::Time64(TimeUnit::Microsecond), false); + let schema = ArrowSchema::new(vec![field.clone()]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(arr)])?; + let mut encoder = RecordEncoder::try_new(&schema, false)?; + let mut out = Vec::new(); + encoder.encode_row(&batch, 0, &mut out)?; + encoder.encode_row(&batch, 1, &mut out)?; + let mut offset = 0; + let (row0_val, used0) = decode_zigzag_long(&out[offset..]); + offset += used0; + assert_eq!(row0_val, 50_000); + let (row1_val, used1) = decode_zigzag_long(&out[offset..]); + offset += used1; + assert_eq!(row1_val, 1_000_000); + assert_eq!(offset, out.len()); + Ok(()) + } + + #[test] + fn test_encode_timestamp_millis_utc() -> Result<(), ArrowError> { + let data_type = DataType::Timestamp(TimeUnit::Millisecond, Some(Arc::from("+00:00".to_string()))); + let values = [1000i64, 86400000i64]; + let buf = ArrowBuffer::from_slice_ref(&values); + let array_data = ArrayData::builder(data_type.clone()) + .len(2) + .add_buffer(buf) + .build()?; + let arr = TimestampMillisecondArray::from(array_data); + let field = Field::new("ts_ms_utc", data_type, false); + let schema = ArrowSchema::new(vec![field.clone()]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(arr)])?; + let mut encoder = RecordEncoder::try_new(&schema, false)?; + let mut out = Vec::new(); + encoder.encode_row(&batch, 0, &mut out)?; + encoder.encode_row(&batch, 1, &mut out)?; + let mut offset = 0; + let (row0_val, used0) = 
decode_zigzag_long(&out[offset..]); + offset += used0; + assert_eq!(row0_val, 1000); + let (row1_val, used1) = decode_zigzag_long(&out[offset..]); + offset += used1; + assert_eq!(row1_val, 86400000); + assert_eq!(offset, out.len()); + Ok(()) + } + + #[test] + fn test_encode_timestamp_millis_local() -> Result<(), ArrowError> { + let arr = TimestampMillisecondArray::from(vec![Some(5000), Some(1577836800000)]); + let field = Field::new( + "ts_ms_local", + DataType::Timestamp(TimeUnit::Millisecond, None), + false, + ); + let schema = ArrowSchema::new(vec![field.clone()]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(arr)])?; + let mut encoder = RecordEncoder::try_new(&schema, false)?; + let mut out = Vec::new(); + encoder.encode_row(&batch, 0, &mut out)?; + encoder.encode_row(&batch, 1, &mut out)?; + let mut offset = 0; + let (row0_val, used0) = decode_zigzag_long(&out[offset..]); + offset += used0; + assert_eq!(row0_val, 5000); + let (row1_val, used1) = decode_zigzag_long(&out[offset..]); + offset += used1; + assert_eq!(row1_val, 1_577_836_800_000); + assert_eq!(offset, out.len()); + Ok(()) + } + + #[test] + fn test_encode_timestamp_micros_utc() -> Result<(), ArrowError> { + let data_type = DataType::Timestamp(TimeUnit::Microsecond, Some("+00:00".into())); + let values = [123456i64, 2_000_000i64]; // 2 rows + let buf = ArrowBuffer::from_slice_ref(&values); + let array_data = ArrayData::builder(data_type.clone()) + .len(2) + .add_buffer(buf) + .build()?; + let arr = TimestampMicrosecondArray::from(array_data); + let field = Field::new("ts_us_utc", data_type, false); + let schema = ArrowSchema::new(vec![field]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(arr)])?; + let mut encoder = RecordEncoder::try_new(&schema, false)?; + let mut out = Vec::new(); + encoder.encode_row(&batch, 0, &mut out)?; + encoder.encode_row(&batch, 1, &mut out)?; + let mut offset = 0; + let (row0_val, used0) = decode_zigzag_long(&out[offset..]); + offset += used0; + assert_eq!(row0_val, 123456, "Expected microseconds=123456 for row0"); + let (row1_val, used1) = decode_zigzag_long(&out[offset..]); + offset += used1; + assert_eq!(row1_val, 2_000_000, "Expected microseconds=2_000_000 for row1"); + assert_eq!(offset, out.len()); + Ok(()) + } + + #[test] + fn test_encode_timestamp_micros_local() -> Result<(), ArrowError> { + let arr = TimestampMicrosecondArray::from(vec![Some(42), Some(9999999999999)]); + let field = Field::new( + "ts_us_local", + DataType::Timestamp(TimeUnit::Microsecond, None), + false, + ); + let schema = ArrowSchema::new(vec![field.clone()]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(arr)])?; + let mut encoder = RecordEncoder::try_new(&schema, false)?; + let mut out = Vec::new(); + encoder.encode_row(&batch, 0, &mut out)?; + encoder.encode_row(&batch, 1, &mut out)?; + let mut offset = 0; + let (row0_val, used0) = decode_zigzag_long(&out[offset..]); + offset += used0; + assert_eq!(row0_val, 42); + let (row1_val, used1) = decode_zigzag_long(&out[offset..]); + offset += used1; + assert_eq!(row1_val, 9_999_999_999_999); + assert_eq!(offset, out.len()); + Ok(()) + } + + #[test] + fn test_encode_uuid() -> Result<(), ArrowError> { + let val0 = b"1234567890ABCDEF"; + let val1 = b"abcdefghijklmnop"; + let arr = FixedSizeBinaryArray::from(vec![Some(&val0[..]), Some(&val1[..])]); + let field = Field::new("uuid_col", DataType::FixedSizeBinary(16), false) + .with_metadata(std::collections::HashMap::from([ + 
("logicalType".to_string(), "uuid".to_string()), + ])); + let schema = ArrowSchema::new(vec![field.clone()]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(arr)])?; + let mut encoder = RecordEncoder::try_new(&schema, false)?; + let mut out = Vec::new(); + encoder.encode_row(&batch, 0, &mut out)?; + encoder.encode_row(&batch, 1, &mut out)?; + assert_eq!(out.len(), 32, "2 rows * 16 bytes each = 32 total"); + assert_eq!(&out[0..16], b"1234567890ABCDEF"); + assert_eq!(&out[16..32], b"abcdefghijklmnop"); + Ok(()) + } + + #[test] + fn test_encode_duration_interval() { + let data = vec![Some(IntervalMonthDayNano::new(2, 3, 1_000_000))]; + let arr = PrimitiveArray::::from(data); + let field = Field::new( + "duration_test", + DataType::Interval(IntervalUnit::MonthDayNano), + false, + ); + let schema = ArrowSchema::new(vec![field.clone()]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(arr)]) + .unwrap(); + let mut encoder = RecordEncoder::try_new(&schema, false).unwrap(); + let mut out = Vec::new(); + encoder.encode_row(&batch, 0, &mut out).unwrap(); + assert_eq!( + &out, + &[ + 0x02, 0x00, 0x00, 0x00, + 0x03, 0x00, 0x00, 0x00, + 0x01, 0x00, 0x00, 0x00 + ] + ); + } + + #[test] + fn test_encode_enum_dictionary_int8_nullable() -> Result<(), ArrowError> { + use std::collections::HashMap; + let keys = Int8Array::from(vec![Some(1i8), Some(0), None, Some(2), Some(2)]); + let values = StringArray::from(vec!["GREEN", "RED", "BLUE"]); + let dict_array = DictionaryArray::try_new(keys, Arc::new(values))?; + let mut md = HashMap::new(); + md.insert( + "avro.enum.symbols".to_string(), + "[\"GREEN\",\"RED\",\"BLUE\"]".to_string() + ); + let field = Field::new("enum_col", dict_array.data_type().clone(), true) + .with_metadata(md); + let schema = ArrowSchema::new(vec![field]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(dict_array)], + )?; + let mut encoder = RecordEncoder::try_new(&schema, false)?; + let mut out = Vec::new(); + for row_idx in 0..5 { + encoder.encode_row(&batch, row_idx, &mut out)?; + } + let mut offset = 0; + let (branch0, used0) = decode_zigzag_long(&out[offset..]); + offset += used0; + assert_eq!(branch0, 1, "Expected branch=1 => non-null"); + let (val0, used_val0) = decode_zigzag_long(&out[offset..]); + offset += used_val0; + assert_eq!(val0, 1); + let (branch1, used1) = decode_zigzag_long(&out[offset..]); + offset += used1; + assert_eq!(branch1, 1); + let (val1, used_val1) = decode_zigzag_long(&out[offset..]); + offset += used_val1; + assert_eq!(val1, 0); + let (branch2, used2) = decode_zigzag_long(&out[offset..]); + offset += used2; + assert_eq!(branch2, 0, "Expected branch=0 => null row"); + let (branch3, used3) = decode_zigzag_long(&out[offset..]); + offset += used3; + assert_eq!(branch3, 1); + let (val3, used_val3) = decode_zigzag_long(&out[offset..]); + offset += used_val3; + assert_eq!(val3, 2); + let (branch4, used4) = decode_zigzag_long(&out[offset..]); + offset += used4; + assert_eq!(branch4, 1); + let (val4, used_val4) = decode_zigzag_long(&out[offset..]); + offset += used_val4; + assert_eq!(val4, 2); + assert_eq!(offset, out.len(), "All encoded data consumed"); + Ok(()) + } + + #[test] + fn test_encode_enum_dictionary_int64_in_range() -> Result<(), ArrowError> { + use std::collections::HashMap; + let keys = Int64Array::from(vec![Some(0), Some(2)]); + let values = StringArray::from(vec!["FISH", "DOG", "CAT"]); + let dict_array = DictionaryArray::try_new(keys, Arc::new(values))?; + let mut 
md = HashMap::new(); + md.insert( + "avro.enum.symbols".to_string(), + "[\"FISH\",\"DOG\",\"CAT\"]".to_string() + ); + let field = Field::new("enum_col", dict_array.data_type().clone(), false) + .with_metadata(md); + let schema = ArrowSchema::new(vec![field]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(dict_array)], + )?; + let mut encoder = RecordEncoder::try_new(&schema, false)?; + let mut out = Vec::new(); + encoder.encode_row(&batch, 0, &mut out)?; + encoder.encode_row(&batch, 1, &mut out)?; + let mut offset = 0; + let (row0_val, used0) = decode_zigzag_long(&out[offset..]); + offset += used0; + assert_eq!(row0_val, 0, "Expected ordinal=0 in row 0"); + let (row1_val, used1) = decode_zigzag_long(&out[offset..]); + offset += used1; + assert_eq!(row1_val, 2, "Expected ordinal=2 in row 1"); + assert_eq!(offset, out.len(), "All encoded data should be consumed"); + Ok(()) + } + + #[test] + fn test_map_field_encoder() { + let key_builder = StringBuilder::new(); + let value_builder = Int32Builder::new(); + let mut map_builder = MapBuilder::new(None, key_builder, value_builder); + map_builder.keys().append_value("apple"); + map_builder.values().append_value(10); + map_builder.keys().append_value("banana"); + map_builder.values().append_value(20); + let _ = map_builder.append(true); + map_builder.keys().append_value("hello"); + map_builder.values().append_value(42); + let _ = map_builder.append(true); + let map_array = map_builder.finish(); + assert_eq!(map_array.len(), 2); + let field = Field::new("my_map", map_array.data_type().clone(), true); + let map_encoder = Encoder::try_new(&field, false).expect("Failed to build FieldEncoder"); + let mut encoded = Vec::new(); + map_encoder.encode_one(&map_array, 0, &mut encoded).unwrap(); + map_encoder.encode_one(&map_array, 1, &mut encoded).unwrap(); + assert!(!encoded.is_empty()); + } + + #[test] + fn test_map_encoder_null() { + let key_builder = StringBuilder::new(); + let value_builder = Int32Builder::new(); + let mut map_builder = MapBuilder::new(None, key_builder, value_builder); + let _ = map_builder.append(false); + let map_array = map_builder.finish(); + assert_eq!(map_array.len(), 1, "Expected 1 row"); + assert!(map_array.is_null(0), "Row 0 should be null"); + let field = Field::new("nullable_map", map_array.data_type().clone(), true); + let enc = Encoder::try_new(&field, false).unwrap(); + let mut buf = Vec::new(); + enc.encode_one(&map_array, 0, &mut buf).unwrap(); + assert_eq!(buf, vec![0x00], "Expected union=0 => null for a null map"); + } + + #[test] + fn test_encode_largelist_of_strings() -> Result<(), ArrowError> { + let child_field = Field::new("str_item", DataType::Utf8, false); + let ll_type = DataType::LargeList(Arc::new(child_field.clone())); + let offsets = vec![0i64, 2, 3]; + let child_vals = StringArray::from(vec!["hello", "arrow", "avro"]); + let ll_arr = LargeListArray::new( + FieldRef::from(child_field), + OffsetBuffer::new(ScalarBuffer::from(offsets)), + Arc::new(child_vals), + None, + ); + let schema = ArrowSchema::new(vec![Field::new("ll_col", ll_type, false)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(ll_arr)])?; + let mut encoder = RecordEncoder::try_new(&schema, false)?; + let mut out = Vec::new(); + encoder.encode_row(&batch, 0, &mut out)?; + encoder.encode_row(&batch, 1, &mut out)?; + let mut offset = 0; + let (length0, used0) = decode_zigzag_long(&out[offset..]); + offset += used0; + assert_eq!(length0, 2); + let (s0, used_s0) = 
decode_avro_string(&out[offset..]); + offset += used_s0; + assert_eq!(s0, "hello"); + let (s1, used_s1) = decode_avro_string(&out[offset..]); + offset += used_s1; + assert_eq!(s1, "arrow"); + let (block_term0, used_bt0) = decode_zigzag_long(&out[offset..]); + offset += used_bt0; + assert_eq!(block_term0, 0); + let (length1, used1) = decode_zigzag_long(&out[offset..]); + offset += used1; + assert_eq!(length1, 1); + let (s2, used_s2) = decode_avro_string(&out[offset..]); + offset += used_s2; + assert_eq!(s2, "avro"); + let (block_term1, used_bt1) = decode_zigzag_long(&out[offset..]); + offset += used_bt1; + assert_eq!(block_term1, 0); + assert_eq!(offset, out.len()); + Ok(()) + } + + #[test] + fn test_encode_list_of_int32() -> Result<(), ArrowError> { + let child_field = Field::new("items", DataType::Int32, false); + let list_type = DataType::List(Arc::new(child_field.clone())); + let offsets = vec![0i32, 2, 2]; + let values = Int32Array::from(vec![10, 20]); + let list_arr = ListArray::new( + Arc::new(child_field), + OffsetBuffer::new(ScalarBuffer::from(offsets)), + Arc::new(values), + None, + ); + let schema = ArrowSchema::new(vec![Field::new("list_col", list_type, false)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(list_arr)])?; + let mut encoder = RecordEncoder::try_new(&schema, false)?; + let mut out = Vec::new(); + encoder.encode_row(&batch, 0, &mut out)?; + encoder.encode_row(&batch, 1, &mut out)?; + let mut offset = 0; + let (length0, used0) = decode_zigzag_long(&out[offset..]); + offset += used0; + assert_eq!(length0, 2); + let (val0, used_v0) = decode_zigzag_long(&out[offset..]); + offset += used_v0; + assert_eq!(val0, 10); + let (val1, used_v1) = decode_zigzag_long(&out[offset..]); + offset += used_v1; + assert_eq!(val1, 20); + let (blk_term0, used_term0) = decode_zigzag_long(&out[offset..]); + offset += used_term0; + assert_eq!(blk_term0, 0); + let (length1, used1) = decode_zigzag_long(&out[offset..]); + offset += used1; + assert_eq!(length1, 0); + let (blk_term1, used_term1) = decode_zigzag_long(&out[offset..]); + offset += used_term1; + assert_eq!(blk_term1, 0); + assert_eq!(offset, out.len()); + Ok(()) + } + + #[test] + fn test_encode_fixedsizelist_of_bools() -> Result<(), ArrowError> { + let size = 3; + let child_data = BooleanArray::from(vec![ + Some(true), Some(false), Some(true), + Some(false), Some(false), Some(false), + ]); + let child_field = Arc::new(Field::new("fsl_item", DataType::Boolean, false)); + let fsl_arr = FixedSizeListArray::new( + child_field.clone(), + size, + Arc::new(child_data), + None, + ); + let top_level_field = Field::new( + "fsl_col", + DataType::FixedSizeList(child_field.clone(), size), + false, + ); + let schema = ArrowSchema::new(vec![top_level_field]); + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(fsl_arr)])?; + let mut encoder = RecordEncoder::try_new(&schema, false)?; + let mut out = Vec::new(); + encoder.encode_row(&batch, 0, &mut out)?; + encoder.encode_row(&batch, 1, &mut out)?; + let mut offset = 0; + let (arr_len0, used0) = decode_zigzag_long(&out[offset..]); + offset += used0; + assert_eq!(arr_len0, 3); + assert_eq!(out[offset], 1); + offset += 1; + assert_eq!(out[offset], 0); + offset += 1; + assert_eq!(out[offset], 1); + offset += 1; + let (blk_term0, used_bt0) = decode_zigzag_long(&out[offset..]); + offset += used_bt0; + assert_eq!(blk_term0, 0); + let (arr_len1, used1) = decode_zigzag_long(&out[offset..]); + offset += used1; + assert_eq!(arr_len1, 3); + for _ in 0..3 { + 
assert_eq!(out[offset], 0); + offset += 1; + } + let (blk_term1, used_bt1) = decode_zigzag_long(&out[offset..]); + offset += used_bt1; + assert_eq!(blk_term1, 0); + assert_eq!(offset, out.len(), "Consumed all bytes"); + Ok(()) + } + + #[test] + fn test_encode_nested_struct() -> Result<(), ArrowError> { + let child_a = Field::new("a", DataType::Int32, false); + let child_b = Field::new("b", DataType::Boolean, true); + let struct_type = DataType::Struct(vec![ + Field::new("a", DataType::Int32, false), + Field::new("b", DataType::Boolean, true), + ].into()); + let a_data = Arc::new(Int32Array::from(vec![100, 200])) as Arc; + let b_data = Arc::new(BooleanArray::from(vec![Some(true), None])) as Arc; + let struct_array = StructArray::new( + Fields::from(vec![child_a, child_b]), + vec![a_data, b_data], + None, + ); + let schema = ArrowSchema::new(vec![Field::new("nested_rec", struct_type, false)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(struct_array)])?; + let mut encoder = RecordEncoder::try_new(&schema, false)?; + let mut out = Vec::new(); + encoder.encode_row(&batch, 0, &mut out)?; + encoder.encode_row(&batch, 1, &mut out)?; + let mut offset = 0; + let (val_a_0, consumed) = decode_zigzag_long(&out[offset..]); + offset += consumed; + assert_eq!(val_a_0, 100); + let (branch, used) = decode_zigzag_long(&out[offset..]); + offset += used; + assert_eq!(branch, 1); + let b0 = decode_avro_boolean(&out[offset..offset + 1]); + offset += 1; + assert_eq!(b0, true); + let (val_a_1, consumed) = decode_zigzag_long(&out[offset..]); + offset += consumed; + assert_eq!(val_a_1, 200); + let (branch1, used) = decode_zigzag_long(&out[offset..]); + offset += used; + assert_eq!(branch1, 0); + assert_eq!(offset, out.len(), "All bytes should be consumed"); + Ok(()) + } + + #[test] + fn test_encode_decimal128_fixed() -> Result<(), ArrowError> { + let mut builder = Decimal128Builder::new() + .with_precision_and_scale(10, 2)?; + builder.append_value(12345); + let decimal_arr = builder.finish(); + let field = Field::new("dec128", DataType::Decimal128(10, 2), false); + let schema = ArrowSchema::new(vec![field]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(decimal_arr)])?; + let mut encoder = RecordEncoder::try_new(&schema, false)?; + let mut out = Vec::new(); + encoder.encode_row(&batch, 0, &mut out)?; + assert_eq!(out.len(), 16); + Ok(()) + } + + #[test] + fn test_encode_decimal256_fixed() -> Result<(), ArrowError> { + let mut builder = Decimal256Builder::new() + .with_precision_and_scale(12, 2)?; + builder.append_value(i256::from_i128(99900)); + let decimal_arr = builder.finish(); + let field = Field::new("dec256", DataType::Decimal256(12, 2), false); + let schema = ArrowSchema::new(vec![field]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(decimal_arr)])?; + let mut encoder = RecordEncoder::try_new(&schema, false)?; + let mut out = Vec::new(); + encoder.encode_row(&batch, 0, &mut out)?; + assert_eq!(out.len(), 32); + Ok(()) + } + + #[test] + fn test_encode_decimal128_as_bytes() -> Result<(), ArrowError> { + let mut builder = Decimal128Builder::new() + .with_precision_and_scale(10, 2)?; + builder.append_value(-250); + let decimal_arr = builder.finish(); + let field = Field::new("dec_col", DataType::Decimal128(10, 2), false); + let schema = ArrowSchema::new(vec![field]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(decimal_arr)])?; + let mut encoder = RecordEncoder::try_new(&schema, false)?; + if 
let Encoder::Decimal128(_, _, ref mut size_opt) = encoder.fields[0] { + *size_opt = None; + } + let mut out = Vec::new(); + encoder.encode_row(&batch, 0, &mut out)?; + let (llen, used) = decode_zigzag_long(&out); + let data = &out[used..]; + assert_eq!(llen as usize, data.len()); + Ok(()) + } + + #[test] + fn test_encode_boolean() { + let bool_arr = BooleanArray::from(vec![true]); + let schema = ArrowSchema::new(vec![Field::new("bool_col", DataType::Boolean, false)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(bool_arr)]) + .unwrap(); + let mut encoder = RecordEncoder::try_new(&schema, false).unwrap(); + let mut out = Vec::new(); + encoder.encode_row(&batch, 0, &mut out).unwrap(); + let got = decode_avro_boolean(&out); + assert_eq!(got, true); + } + + #[test] + fn test_encode_int32() { + let int_arr = Int32Array::from(vec![42]); + let schema = ArrowSchema::new(vec![Field::new("int_col", DataType::Int32, false)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(int_arr)]) + .unwrap(); + let mut encoder = RecordEncoder::try_new(&schema, false).unwrap(); + let mut out = Vec::new(); + encoder.encode_row(&batch, 0, &mut out).unwrap(); + let (decoded, consumed) = decode_zigzag_long(&out); + assert_eq!(decoded, 42); + assert_eq!(consumed, out.len()); + } + + #[test] + fn test_encode_int64() { + let long_arr = Int64Array::from(vec![-1_i64]); + let schema = ArrowSchema::new(vec![Field::new("long_col", DataType::Int64, false)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(long_arr)]) + .unwrap(); + let mut encoder = RecordEncoder::try_new(&schema, false).unwrap(); + let mut out = Vec::new(); + encoder.encode_row(&batch, 0, &mut out).unwrap(); + let (decoded, consumed) = decode_zigzag_long(&out); + assert_eq!(decoded, -1); + assert_eq!(consumed, out.len()); + } + + #[test] + fn test_encode_float32() { + let float_arr = Float32Array::from(vec![3.14_f32]); + let schema = ArrowSchema::new(vec![Field::new("float_col", DataType::Float32, false)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(float_arr)]) + .unwrap(); + let mut encoder = RecordEncoder::try_new(&schema, false).unwrap(); + let mut out = Vec::new(); + encoder.encode_row(&batch, 0, &mut out).unwrap(); + let got = decode_avro_f32(&out); + assert!((got - 3.14).abs() < 1e-7); + } + + #[test] + fn test_encode_float64() { + let double_arr = Float64Array::from(vec![std::f64::consts::E]); + let schema = ArrowSchema::new(vec![Field::new("double_col", DataType::Float64, false)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(double_arr)]) + .unwrap(); + let mut encoder = RecordEncoder::try_new(&schema, false).unwrap(); + let mut out = Vec::new(); + encoder.encode_row(&batch, 0, &mut out).unwrap(); + let got = decode_avro_f64(&out); + assert!((got - std::f64::consts::E).abs() < 1e-14); + } + + #[test] + fn test_encode_binary() { + let bin_arr = BinaryArray::from(vec![Some(&b"hello"[..])]); + let schema = ArrowSchema::new(vec![Field::new("bin_col", DataType::Binary, false)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(bin_arr)]) + .unwrap(); + let mut encoder = RecordEncoder::try_new(&schema, false).unwrap(); + let mut out = Vec::new(); + encoder.encode_row(&batch, 0, &mut out).unwrap(); + let got = decode_avro_bytes(&out); + assert_eq!(got, b"hello"); + } + + #[test] + fn test_encode_utf8() { + let str_arr = StringArray::from(vec![Some("Avro!")]); + let schema = 
ArrowSchema::new(vec![Field::new("str_col", DataType::Utf8, false)]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(str_arr)]) + .unwrap(); + let mut encoder = RecordEncoder::try_new(&schema, false).unwrap(); + let mut out = Vec::new(); + encoder.encode_row(&batch, 0, &mut out).unwrap(); + let got = decode_avro_string(&out); + assert_eq!(got.0, "Avro!"); + } + + #[test] + fn test_encode_fixed() { + let arr = FixedSizeBinaryArray::from(vec![ + Some(&b"ABCDE"[..]), + ]); + let schema = ArrowSchema::new(vec![Field::new( + "fixed_col", + DataType::FixedSizeBinary(5), + false, + )]); + let batch = RecordBatch::try_new( + Arc::new(schema.clone()), + vec![Arc::new(arr)], + ).unwrap(); + let mut encoder = RecordEncoder::try_new(&schema, false).unwrap(); + let mut out = Vec::new(); + encoder.encode_row(&batch, 0, &mut out).unwrap(); + assert_eq!(out.len(), 5); + assert_eq!(&out, b"ABCDE"); + } + + #[test] + fn test_encode_nullable() -> Result<(), ArrowError> { + let str_arr = StringArray::from(vec![None, Some("non-null here")]); + let field = Field::new("maybe_str", DataType::Utf8, true); + let schema = ArrowSchema::new(vec![field]); + let batch = RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(str_arr)])?; + let mut encoder = RecordEncoder::try_new(&schema, false)?; + let mut out0 = Vec::new(); + encoder.encode_row(&batch, 0, &mut out0)?; + let (branch, consumed) = decode_zigzag_long(&out0); + assert_eq!(branch, 0, "Expected union branch=0 => null"); + assert_eq!(consumed, out0.len(), "No payload after branch=0"); + let mut out1 = Vec::new(); + encoder.encode_row(&batch, 1, &mut out1)?; + let (branch, used) = decode_zigzag_long(&out1); + assert_eq!(branch, 1, "Expected branch=1 => string"); + let got_str = decode_avro_string(&out1[used..]); + assert_eq!(got_str.0, "non-null here"); + Ok(()) + } +} diff --git a/arrow-avro/src/writer/header.rs b/arrow-avro/src/writer/header.rs new file mode 100644 index 000000000000..773d0c519255 --- /dev/null +++ b/arrow-avro/src/writer/header.rs @@ -0,0 +1,94 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Avro file header logic, including magic bytes, metadata map, sync marker. + +use std::io::Write; + +use arrow_schema::{ArrowError, Schema}; +use crate::codec::{SchemaBuilder}; +use crate::compression::CompressionCodec; +use crate::writer::utils::{ + to_arrow_io_err, write_bytes, write_string, +}; +use crate::writer::zigzag::write_zigzag_long; + +/// Holds information needed to write the Avro container-file header. 
+#[derive(Debug)]
+pub struct AvroHeader {
+    /// The Arrow schema (JSON-serialized into metadata)
+    pub arrow_schema: Schema,
+    /// If `true`, nullable fields use Impala-style `[ T, "null" ]` union ordering
+    pub impala_mode: bool,
+    /// Optional compression codec
+    pub compression: Option<CompressionCodec>,
+    /// Additional metadata key-value pairs
+    pub extra_meta: Vec<(String, Vec<u8>)>,
+    /// The 16-byte sync marker used to separate file blocks
+    pub sync_marker: [u8; 16],
+}
+
+impl AvroHeader {
+    /// Writes the Avro container file header
+    pub fn write_header(&self, sink: &mut dyn Write) -> Result<(), ArrowError> {
+        sink.write_all(b"Obj\x01")
+            .map_err(|e| to_arrow_io_err(e, "Writing Avro magic"))?;
+        let mut meta_entries = self.extra_meta.len() + 1;
+        if self.compression.is_some() {
+            meta_entries += 1;
+        }
+        write_zigzag_long(meta_entries as i64, sink)?;
+        //let schema = make_schema(&self.arrow_schema, &self.impala_mode)?;
+        let record_name = self.arrow_schema
+            .metadata()
+            .get("avro.record.name")
+            .cloned()
+            .unwrap_or_else(|| "topLevelRecord".to_string());
+        let record_namespace = self.arrow_schema
+            .metadata()
+            .get("avro.record.namespace").map(|s| s.as_str());
+
+        let schema = SchemaBuilder::from(&self.arrow_schema)
+            .with_impala_mode(&self.impala_mode)
+            .with_name(record_name.as_str())
+            .with_namespace(record_namespace)
+            .finish()?;
+        let schema_json = serde_json::to_vec(&schema)
+            .map_err(|e| ArrowError::ExternalError(Box::new(e)))?;
+        write_string("avro.schema", sink)?;
+        write_bytes(&schema_json, sink)?;
+        if let Some(codec) = self.compression {
+            write_string(crate::compression::CODEC_METADATA_KEY, sink)?;
+            let codec_str: &[u8] = match codec {
+                CompressionCodec::Snappy => b"snappy",
+                CompressionCodec::Deflate => b"deflate",
+                CompressionCodec::ZStandard => b"zstandard",
+                CompressionCodec::Bzip2 => b"bzip2",
+                CompressionCodec::Xz => b"xz",
+            };
+            write_bytes(codec_str, sink)?;
+        }
+        for (k, v) in &self.extra_meta {
+            write_string(k, sink)?;
+            write_bytes(v, sink)?;
+        }
+        write_zigzag_long(0, sink)?;
+        sink.write_all(&self.sync_marker)
+            .map_err(|e| to_arrow_io_err(e, "Writing sync marker"))?;
+        Ok(())
+    }
+}
diff --git a/arrow-avro/src/writer/mod.rs b/arrow-avro/src/writer/mod.rs
new file mode 100644
index 000000000000..b58190266fa2
--- /dev/null
+++ b/arrow-avro/src/writer/mod.rs
@@ -0,0 +1,777 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//!
This module provides writing capabilities for Avro container files, + +mod block; +mod header; +pub mod encoder; +pub mod zigzag; +mod utils; + +use std::io::Write; +use std::sync::Arc; + +use arrow_array::{ArrayRef, RecordBatch}; +use arrow_schema::{ArrowError, SchemaRef}; + +use crate::compression::CompressionCodec; +use crate::schema::Schema as AvroSchema; + +use block::BlockEncoder; +use encoder::RecordEncoder; +use header::AvroHeader; + +/// A builder for creating an Avro [`Writer`] for Arrow data. +/// +/// Use this builder to configure options like compression, max block size, +/// custom Avro schema, out-of-spec Impala union ordering, etc., before creating the writer. +pub struct WriterBuilder { + /// The underlying output [`Write`] to which Avro data is written. + writer: W, + + /// The Arrow schema used for converting columns to Avro. + arrow_schema: SchemaRef, + + /// Optional user-supplied Avro schema that overrides the Arrow-generated schema. + avro_schema: Option>, + + /// If present, enables block-level compression (e.g. Snappy). + compression: Option, + + /// The threshold in bytes at which the writer flushes a data block. + max_block_size: usize, + + /// Arbitrary additional metadata key-value pairs stored in the Avro file header. + extra_meta: Vec<(String, Vec)>, + + /// If provided, the 16-byte sync marker used in the file; otherwise generated. + sync_marker: Option<[u8; 16]>, + + /// If `true`, produce unions with `[ T, "null" ]` ordering (Impala style) + /// rather than the standard `[ "null", T ]`. + impala_mode: bool, +} + +impl WriterBuilder { + /// Create a new `WriterBuilder` with the specified `writer` and Arrow schema. + pub fn new(writer: W, arrow_schema: SchemaRef) -> Self { + Self { + writer, + arrow_schema, + avro_schema: None, + compression: None, + max_block_size: 16 * 1024 * 1024, + extra_meta: vec![], + sync_marker: None, + impala_mode: false, + } + } + + /// Provide a custom Avro schema, overriding the derived Arrow-to-Avro conversion. + pub fn with_avro_schema(mut self, avro_schema: AvroSchema<'static>) -> Self { + self.avro_schema = Some(avro_schema); + self + } + + /// Enable block-level compression (e.g. Snappy, Deflate, etc.). + pub fn with_compression(mut self, codec: CompressionCodec) -> Self { + self.compression = Some(codec); + self + } + + /// Set the maximum in-memory block size (in bytes). Once the block buffer + /// exceeds this size, the writer flushes the block to disk. + pub fn with_max_block_size(mut self, size: usize) -> Self { + self.max_block_size = size; + self + } + + /// Add a key-value pair to the Avro file header's metadata map. + pub fn with_metadata(mut self, key: &str, value: &[u8]) -> Self { + self.extra_meta.push((key.to_string(), value.to_vec())); + self + } + + /// Specify a sync marker to be used in the file. If not set, a default + /// or random marker is used instead. + pub fn with_sync_marker(mut self, marker: [u8; 16]) -> Self { + self.sync_marker = Some(marker); + self + } + + /// **New:** If `impala_mode` is true, produce `[ T, "null" ]` union ordering for nullable fields + /// (matching Impala's out-of-spec union ordering), instead of the typical `[ "null", T ]`. + pub fn with_impala_mode(mut self, impala: bool) -> Self { + self.impala_mode = impala; + self + } + + /// Finalize the configuration and construct the [`Writer`], immediately + /// writing the Avro file header to the underlying output. 
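+    ///
+    /// A minimal usage sketch (illustrative; assumes `schema` is a `SchemaRef` and
+    /// `batch` a matching `RecordBatch`, mirroring the round-trip tests below):
+    ///
+    /// ```ignore
+    /// let mut buf: Vec<u8> = Vec::new();
+    /// let mut writer = WriterBuilder::new(&mut buf, schema).build()?;
+    /// writer.write(&batch)?;
+    /// writer.finish()?;
+    /// ```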
+ /// + /// # Errors + /// + /// Returns an error if the Arrow schema cannot be converted to Avro + /// or if writing the header fails. + pub fn build(mut self) -> Result, ArrowError> { + let sync_marker = self.sync_marker.unwrap_or([0xAA; 16]); + let header = AvroHeader { + arrow_schema: self.arrow_schema.as_ref().clone(), + impala_mode: self.impala_mode, + compression: self.compression, + extra_meta: self.extra_meta, + sync_marker, + }; + header.write_header(&mut self.writer)?; + let block_encoder = BlockEncoder::new( + self.compression, + sync_marker, + self.max_block_size, + ); + let record_encoder = RecordEncoder::try_new(self.arrow_schema.as_ref(), self.impala_mode)?; + Ok(Writer { + sink: self.writer, + header_written: true, + header, + block_encoder, + record_encoder, + arrow_schema: self.arrow_schema, + finished: false, + }) + } +} + +/// An Avro writer that produces container files from Arrow batches or single rows. +/// +/// Use [`WriterBuilder`] to construct a writer and then call: +/// * [`write`](Self::write) or [`write_batches`](Self::write_batches) to write entire [`RecordBatch`]es. +/// * [`finish`](Self::finish) to finalize the file. +pub struct Writer { + /// The output sink for all Avro bytes. + pub(crate) sink: W, + + #[allow(dead_code)] + /// Indicates we wrote the Avro file header already (unused in logic). + pub(crate) header_written: bool, + + #[allow(dead_code)] + /// The Avro file header data, including schema and compression. + pub(crate) header: AvroHeader, + + /// Handles buffering rows into blocks, compression, etc. + pub(crate) block_encoder: BlockEncoder, + + /// Encodes a single row from an Arrow `RecordBatch` into Avro bytes. + pub(crate) record_encoder: RecordEncoder, + + /// The Arrow schema used by this writer. + pub(crate) arrow_schema: SchemaRef, + + /// Whether this writer has been finished (no more data allowed). + pub(crate) finished: bool, +} + +impl Writer { + /// Write all rows in `batch` to the Avro file. + /// + /// # Errors + /// + /// Returns an error if the columns do not match the schema or if an + /// underlying I/O or encoding error occurs. + pub fn write(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> { + if batch.num_columns() != self.arrow_schema.fields().len() { + return Err(ArrowError::InvalidArgumentError( + "Number of columns mismatch".into(), + )); + } + let row_count = batch.num_rows(); + for row_idx in 0..row_count { + let encoded_row = self + .record_encoder + .encode_row_to_vec(batch, row_idx)?; + self.block_encoder.append_encoded(&encoded_row); + self.block_encoder.inc_count(); + self.block_encoder.maybe_flush(&mut self.sink)?; + } + Ok(()) + } + + /// Writes multiple [`RecordBatch`] in succession. + /// + /// # Errors + /// + /// Returns an error if any batch fails to encode/write. + pub fn write_batches(&mut self, batches: &[&RecordBatch]) -> Result<(), ArrowError> { + for b in batches { + self.write(b)?; + } + Ok(()) + } + + /// Finish the file, flushing the final block if needed. + /// + /// After calling `finish()`, no further data can be written. + /// + /// # Errors + /// + /// If flushing fails, returns an error. + pub fn finish(&mut self) -> Result<(), ArrowError> { + if self.finished { + return Ok(()); + } + self.block_encoder.close(&mut self.sink)?; + self.finished = true; + Ok(()) + } + + /// Consume this writer and return the underlying output. 
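+    ///
+    /// Note that this does not flush rows still buffered in the current block;
+    /// call [`finish`](Self::finish) first so the final block is written.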
+ pub fn into_inner(self) -> W { + self.sink + } +} + +/// A convenience trait that some Arrow writers implement for a standard +/// interface to write [`RecordBatch`] or close the writer. +pub trait BatchWriter { + /// Write a single [`RecordBatch`] to the writer. + fn write_batch(&mut self, batch: &RecordBatch) -> Result<(), ArrowError>; + + /// Close or finalize the writer. + fn close(&mut self) -> Result<(), ArrowError>; +} + +impl BatchWriter for Writer { + fn write_batch(&mut self, batch: &RecordBatch) -> Result<(), ArrowError> { + self.write(batch) + } + + fn close(&mut self) -> Result<(), ArrowError> { + self.finish() + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + use std::fs::File; + use super::*; + use arrow_array::{Array, Decimal128Array, Decimal256Array, DictionaryArray, Int32Array, Int8Array, ListArray, MapArray, PrimitiveArray, StringArray, StructArray}; + use arrow_schema::{DataType, Field, Fields, IntervalUnit, Schema}; + use crate::reader::{ReaderBuilder}; + use std::io::{BufReader, Cursor}; + use arrow_array::builder::{Decimal128Builder, Decimal256Builder, Int32Builder, MapBuilder, StringBuilder}; + use arrow_array::types::IntervalMonthDayNanoType; + use arrow_buffer::{i256, Buffer, IntervalMonthDayNano}; + use arrow_data::ArrayData; + use crate::test_util::arrow_test_data; + + #[test] + fn test_round_trip_files() -> Result<(), ArrowError> { + let files = [ + ("avro/alltypes_plain.avro", false), + ("avro/alltypes_plain.snappy.avro", false), + ("avro/alltypes_plain.zstandard.avro", false), + ("avro/alltypes_plain.bzip2.avro", false), + ("avro/alltypes_plain.xz.avro", false), + ("avro/alltypes_dictionary.avro", false), + ("avro/alltypes_nulls_plain.avro", false), + ("avro/binary.avro", false), + ("avro/fixed_length_decimal.avro", false), + ("avro/fixed_length_decimal_legacy.avro", false), + ("avro/int32_decimal.avro", false), + ("avro/int64_decimal.avro", false), + ("avro/datapage_v2.snappy.avro", false), + ("avro/dict-page-offset-zero.avro", false), + ("avro/list_columns.avro", false), + ("avro/nested_lists.snappy.avro", false), + ("avro/nested_records.avro", false), + ("avro/nonnullable.impala.avro", true), + ("avro/nullable.impala.avro", true), + ("avro/nulls.snappy.avro", false), + ("avro/repeated_no_annotation.avro", false), + ("avro/simple_enum.avro", false), + ("avro/simple_fixed.avro", false), + ("avro/single_nan.avro", false), + ]; + for (file, mode) in files { + let file_path = arrow_test_data(file); + let mut original_reader = { + let f = File::open(&file_path).unwrap(); + let bf = BufReader::new(f); + ReaderBuilder::new().with_batch_size(64).build(bf)? 
+ }; + let mut original_batches = Vec::new(); + while let Some(batch) = original_reader.next() { + original_batches.push(batch?); + } + let mut buffer = Vec::new(); + if !original_batches.is_empty() { + let schema = original_batches[0].schema(); + let mut writer = WriterBuilder::new(&mut buffer, schema.clone()) + .with_impala_mode(mode) + .build()?; + for batch in &original_batches { + writer.write(batch)?; + } + writer.finish()?; + } + let mut roundtrip_reader = ReaderBuilder::new().build(Cursor::new(&buffer))?; + let mut roundtrip_batches = Vec::new(); + while let Some(batch) = roundtrip_reader.next() { + roundtrip_batches.push(batch?); + } + assert_eq!(original_batches.len(), roundtrip_batches.len(), + "Mismatch in number of batches for file '{}'", file); + for (i, (original, roundtrip)) in + original_batches.iter().zip(roundtrip_batches.iter()).enumerate() + { + assert_eq!( + original.num_rows(), roundtrip.num_rows(), + "Row count mismatch in file '{}' batch {}", + file, i + ); + assert_eq!( + original.num_columns(), roundtrip.num_columns(), + "Column count mismatch in file '{}' batch {}", + file, i + ); + assert_eq!( + original, roundtrip, + "Mismatch in file '{}' batch {} after round-trip", + file, i + ); + } + } + Ok(()) + } + + fn round_trip( + batches: &[RecordBatch], + compression: Option, + ) -> Result, ArrowError> { + if batches.is_empty() { + return Ok(vec![]); + } + let schema = batches[0].schema(); + let mut buffer = Vec::new(); + { + let mut writer = WriterBuilder::new(&mut buffer, schema.clone()); + if let Some(codec) = compression { + writer = writer.with_compression(codec); + } + let mut writer = writer.build()?; + for b in batches { + writer.write(b)?; + } + writer.finish()?; + } + let mut reader = ReaderBuilder::new() + .with_batch_size(64) + .build(Cursor::new(buffer))?; + let mut out = Vec::new(); + while let Some(batch) = reader.next() { + let batch = batch?; + out.push(batch); + } + Ok(out) + } + + #[test] + fn test_round_trip_duration() -> Result<(), ArrowError> { + let row0 = IntervalMonthDayNano::new(0, 0, 0); + let row1 = IntervalMonthDayNano::new(0, 1, 500_000_000); + let data = vec![Some(row0), Some(row1)]; + let interval_arr = PrimitiveArray::::from(data); + let field = Field::new( + "duration_col", + DataType::Interval(IntervalUnit::MonthDayNano), + true, + ); + let schema = Arc::new(Schema::new(vec![field])); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(interval_arr) as ArrayRef])?; + let out_batches = round_trip(&[batch], None)?; + assert_eq!(out_batches.len(), 1); + let out_batch = &out_batches[0]; + assert_eq!(out_batch.num_rows(), 2); + assert_eq!(out_batch.num_columns(), 1); + let out_arr = out_batch.column(0); + assert_eq!( + out_arr.data_type(), + &DataType::Interval(IntervalUnit::MonthDayNano), + ); + let out_mdn = out_arr + .as_any() + .downcast_ref::>() + .unwrap(); + let row0_val: IntervalMonthDayNano = out_mdn.value(0); + assert_eq!(row0_val.months, 0); + assert_eq!(row0_val.days, 0); + assert_eq!(row0_val.nanoseconds, 0); + let row1_val: IntervalMonthDayNano = out_mdn.value(1); + assert_eq!(row1_val.months, 0); + assert_eq!(row1_val.days, 1); + assert_eq!(row1_val.nanoseconds, 500_000_000); + Ok(()) + } + + #[test] + fn test_round_trip_nested_array_non_nullable() -> Result<(), ArrowError> { + let int_values = Int32Array::from(vec![1, 2, 3]); + let int_data = int_values.into_data(); + let offsets_child = [0i32, 2, 3]; + let child_list_data_type = DataType::List(Arc::new(Field::new( + "item", + DataType::Int32, + false, + 
))); + let child_list_data = ArrayData::builder(child_list_data_type.clone()) + .len(2) + .add_buffer(Buffer::from_slice_ref(&offsets_child)) + .add_child_data(int_data) + .build()?; + let child_list_arr = ListArray::from(child_list_data); + let offsets_outer = [0i32, 1, 2]; + let outer_list_data_type = DataType::List(Arc::new(Field::new( + "item", + child_list_data_type, + false + ))); + let outer_list_data = ArrayData::builder(outer_list_data_type.clone()) + .len(2) + .add_buffer(Buffer::from_slice_ref(&offsets_outer)) + .add_child_data(child_list_arr.into_data()) + .build()?; + let outer_list_arr = ListArray::from(outer_list_data); + let field = Field::new("outer_list_col", outer_list_data_type.clone(), false); + let schema = Arc::new(Schema::new(vec![field])); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(outer_list_arr) as ArrayRef])?; + let result_batches = round_trip(&[batch.clone()], None)?; + assert_eq!(result_batches.len(), 1, "Expected 1 roundtrip batch"); + let out_batch = &result_batches[0]; + assert_eq!(out_batch.num_rows(), 2); + assert_eq!(out_batch.num_columns(), 1); + let out_arr = out_batch + .column(0) + .as_any() + .downcast_ref::() + .expect("outer list array"); + assert_eq!(out_arr.len(), 2); + let row0_list = out_arr.value(0); + let row0_list = row0_list.as_any().downcast_ref::().unwrap(); + assert_eq!(row0_list.len(), 1, "row0 has 1 sub-list in this example"); + let sublist0 = row0_list.value(0); + let sublist0_ints = sublist0.as_any().downcast_ref::().unwrap(); + assert_eq!(sublist0_ints.len(), 2); + assert_eq!(sublist0_ints.value(0), 1); + assert_eq!(sublist0_ints.value(1), 2); + let row1_list = out_arr.value(1); + let row1_list = row1_list.as_any().downcast_ref::().unwrap(); + assert_eq!(row1_list.len(), 1, "row1 has 1 sub-list"); + let sublist1 = row1_list.value(0); + let sublist1_ints = sublist1.as_any().downcast_ref::().unwrap(); + assert_eq!(sublist1_ints.len(), 1); + assert_eq!(sublist1_ints.value(0), 3); + Ok(()) + } + + #[test] + fn test_round_trip_nested_record() -> Result<(), ArrowError> { + let address_fields = Fields::from(vec![ + Field::new("street", DataType::Utf8, false), + Field::new("city", DataType::Utf8, false), + ]); + let street_array = StringArray::from(vec![Some("Sunset Blvd"), Some("2nd Street")]); + let city_array = StringArray::from(vec![Some("LA"), Some("NYC")]); + let address_struct_array = StructArray::try_new( + address_fields.clone(), + vec![ + Arc::new(street_array) as ArrayRef, + Arc::new(city_array) as ArrayRef, + ], + None, + )?; + let person_fields = Fields::from(vec![ + Field::new("name", DataType::Utf8, false), + Field::new("age", DataType::Int32, false), + Field::new( + "address", + DataType::Struct(address_fields.clone()), + false, + ), + ]); + let name_array = StringArray::from(vec![Some("Alice"), Some("Bob")]); + let age_array = Int32Array::from(vec![Some(30), Some(40)]); + let person_struct_array = StructArray::try_new( + person_fields.clone(), + vec![ + Arc::new(name_array) as ArrayRef, + Arc::new(age_array) as ArrayRef, + Arc::new(address_struct_array) as ArrayRef, + ], + None, + )?; + let schema = Arc::new(Schema::new(vec![Field::new( + "person", + DataType::Struct(person_fields), + true, + )])); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(person_struct_array)])?; + let result_batches = round_trip(&[batch.clone()], None)?; + assert_eq!(result_batches.len(), 1, "Expected 1 output batch"); + let out_batch = &result_batches[0]; + assert_eq!(out_batch.num_rows(), batch.num_rows()); 
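+        // Verify the nested person/address struct survived the Avro round trip.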
+ assert_eq!(out_batch.num_columns(), 1); + let person_col = out_batch.column(0); + let person_struct = person_col + .as_any() + .downcast_ref::() + .expect("person column should be StructArray"); + assert_eq!(person_struct.len(), 2); + let name_arr = person_struct + .column_by_name("name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let age_arr = person_struct + .column_by_name("age") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let address_arr = person_struct + .column_by_name("address") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let street_arr = address_arr + .column_by_name("street") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let city_arr = address_arr + .column_by_name("city") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(name_arr.value(0), "Alice"); + assert_eq!(age_arr.value(0), 30); + assert_eq!(street_arr.value(0), "Sunset Blvd"); + assert_eq!(city_arr.value(0), "LA"); + assert_eq!(name_arr.value(1), "Bob"); + assert_eq!(age_arr.value(1), 40); + assert_eq!(street_arr.value(1), "2nd Street"); + assert_eq!(city_arr.value(1), "NYC"); + Ok(()) + } + + #[test] + fn test_round_trip_enum_dictionary() -> Result<(), ArrowError> { + let dictionary_values = StringArray::from(vec!["RED", "GREEN", "BLUE"]); + let keys = Int8Array::from(vec![Some(0), Some(1), Some(2), Some(1), None, Some(2)]); + let dict_array = DictionaryArray::try_new(keys, Arc::new(dictionary_values))?; + let mut md = HashMap::new(); + md.insert( + "avro.enum.symbols".to_string(), + "[\"RED\",\"GREEN\",\"BLUE\"]".to_string(), + ); + let field = arrow_schema::Field::new("enum_col", dict_array.data_type().clone(), true) + .with_metadata(md); + let schema = Arc::new(arrow_schema::Schema::new(vec![field])); + let batch = arrow_array::RecordBatch::try_new(schema.clone(), vec![Arc::new(dict_array)])?; + let result_batches = round_trip(&[batch.clone()], None)?; + assert_eq!(result_batches.len(), 1, "Expected 1 batch after roundtrip"); + let out_batch = &result_batches[0]; + assert_eq!(out_batch.num_rows(), batch.num_rows()); + assert_eq!(out_batch.num_columns(), 1); + let out_col = out_batch.column(0); + let out_dict_arr = out_col + .as_any() + .downcast_ref::>() + .expect("Expected a Dictionary after reading Avro enum"); + let out_values = out_dict_arr + .values() + .as_any() + .downcast_ref::() + .expect("Dictionary values should be a StringArray"); + assert_eq!(out_values.len(), 3, "Should have 3 enum symbols"); + assert_eq!(out_values.value(0), "RED"); + assert_eq!(out_values.value(1), "GREEN"); + assert_eq!(out_values.value(2), "BLUE"); + let out_keys = out_dict_arr.keys(); + assert_eq!(out_keys.len(), 6); + assert_eq!(out_keys.is_null(4), true, "Row 4 was null in the original data"); + let out_keys_i64: Vec> = out_keys + .iter() + .map(|k| k.map(|x| x as i64)) + .collect(); + assert_eq!(out_keys_i64[0], Some(0)); + assert_eq!(out_keys_i64[1], Some(1)); + assert_eq!(out_keys_i64[2], Some(2)); + assert_eq!(out_keys_i64[3], Some(1)); + assert_eq!(out_keys_i64[4], None); + assert_eq!(out_keys_i64[5], Some(2)); + Ok(()) + } + + #[test] + fn test_round_trip_map() -> Result<(), ArrowError> { + let key_builder = StringBuilder::new(); + let value_builder = Int32Builder::new(); + let mut map_builder = MapBuilder::new(None, key_builder, value_builder); + map_builder.keys().append_value("apple"); + map_builder.values().append_value(10); + map_builder.keys().append_value("banana"); + map_builder.values().append_value(20); + map_builder.append(true)?; + 
map_builder.keys().append_value("hello"); + map_builder.values().append_value(42); + map_builder.append(true)?; + let map_array = map_builder.finish(); + let field = Field::new("my_map", map_array.data_type().clone(), true); + let schema = Arc::new(Schema::new(vec![field])); + let batch = RecordBatch::try_new(schema.clone(), vec![Arc::new(map_array) as ArrayRef])?; + let result = round_trip(&[batch.clone()], None)?; + assert_eq!(result.len(), 1, "Expected 1 batch after round trip"); + let out_batch = &result[0]; + assert_eq!(out_batch.num_rows(), batch.num_rows()); + assert_eq!(out_batch.num_columns(), 1); + let out_map = out_batch + .column(0) + .as_any() + .downcast_ref::() + .expect("Failed to downcast to MapArray"); + assert_eq!(out_map.len(), 2, "Should still have 2 rows"); + assert_eq!(out_map.value_length(0), 2); + assert_eq!(out_map.value_length(1), 1); + let entries_struct = out_map.entries(); + let key_arr = entries_struct + .column(0) + .as_any() + .downcast_ref::() + .expect("Map keys should be a StringArray"); + let val_arr = entries_struct + .column(1) + .as_any() + .downcast_ref::() + .expect("Map values should be Int32Array"); + assert_eq!(key_arr.value(0), "apple"); + assert_eq!(val_arr.value(0), 10); + assert_eq!(key_arr.value(1), "banana"); + assert_eq!(val_arr.value(1), 20); + assert_eq!(key_arr.value(2), "hello"); + assert_eq!(val_arr.value(2), 42); + Ok(()) + } + + #[test] + fn test_round_trip_decimal128() -> Result<(), ArrowError> { + let mut builder = Decimal128Builder::new() + .with_precision_and_scale(4, 2)?; + builder.append_value(12345); + builder.append_value(-9999); + builder.append_null(); + builder.append_value(5000); + let decimal_arr = builder.finish(); + let field = Field::new("decimal_col", DataType::Decimal128(4, 2), true); + let schema = Arc::new(Schema::new(vec![field])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(decimal_arr) as ArrayRef], + )?; + let result = round_trip(&[batch.clone()], None)?; + assert_eq!(result.len(), 1, "Expected exactly 1 batch returned"); + let out_batch = &result[0]; + assert_eq!(out_batch.num_rows(), batch.num_rows()); + assert_eq!(out_batch.num_columns(), 1); + let out_decimal = out_batch + .column(0) + .as_any() + .downcast_ref::() + .expect("Failed to downcast to Decimal128Array"); + assert_eq!(out_decimal.len(), 4); + assert_eq!(out_decimal.null_count(), 1); + assert_eq!(out_decimal.value(0), 12345); + assert_eq!(out_decimal.value(1), -9999); + assert!(out_decimal.is_null(2)); + assert_eq!(out_decimal.value(3), 5000); + Ok(()) + } + + #[test] + fn test_round_trip_decimal256() -> Result<(), ArrowError> { + let mut builder = Decimal256Builder::new() + .with_precision_and_scale(38, 6)?; + builder.append_value(i256::from_i128(123_456_789)); + builder.append_value(i256::from_i128(-1_000_000)); + builder.append_null(); + builder.append_value(i256::from_i128(999_999_999_999)); + let decimal_arr = builder.finish(); + let field = Field::new("decimal256_col", DataType::Decimal256(38, 6), true); + let schema = Arc::new(Schema::new(vec![field])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![Arc::new(decimal_arr) as ArrayRef], + )?; + let result = round_trip(&[batch.clone()], None)?; + assert_eq!(result.len(), 1, "Expected exactly 1 batch returned"); + let out_batch = &result[0]; + assert_eq!(out_batch.num_rows(), batch.num_rows()); + assert_eq!(out_batch.num_columns(), 1); + let out_decimal = out_batch + .column(0) + .as_any() + .downcast_ref::() + .expect("Failed to downcast to 
Decimal256Array"); + assert_eq!(out_decimal.len(), 4); + assert_eq!(out_decimal.null_count(), 1); + assert_eq!(out_decimal.value(0), i256::from_i128(123_456_789)); + assert_eq!(out_decimal.value(1), i256::from_i128(-1_000_000)); + assert!(out_decimal.is_null(2)); + assert_eq!(out_decimal.value(3), i256::from_i128(999_999_999_999)); + Ok(()) + } + + #[test] + fn test_round_trip_simple() -> Result<(), ArrowError> { + let schema = Arc::new(Schema::new(vec![ + Field::new("int_field", DataType::Int32, true), + Field::new("str_field", DataType::Utf8, true), + ])); + let ints = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)])) as ArrayRef; + let strs = Arc::new(StringArray::from(vec![None, Some("x"), Some("y")])) as ArrayRef; + let batch = RecordBatch::try_new(schema.clone(), vec![ints, strs])?; + let result = round_trip(&[batch.clone()], None)?; + assert_eq!(result.len(), 1); + let out_batch = &result[0]; + assert_eq!(out_batch.num_rows(), 3); + assert_eq!(out_batch.num_columns(), 2); + let out_ints = out_batch.column(0).as_any().downcast_ref::().unwrap(); + let out_strs = out_batch.column(1).as_any().downcast_ref::().unwrap(); + assert_eq!(out_ints.is_null(1), true); + assert_eq!(out_ints.value(0), 1); + assert_eq!(out_ints.value(2), 3); + assert_eq!(out_strs.is_null(0), true); + assert_eq!(out_strs.value(1), "x"); + assert_eq!(out_strs.value(2), "y"); + Ok(()) + } +} \ No newline at end of file diff --git a/arrow-avro/src/writer/utils.rs b/arrow-avro/src/writer/utils.rs new file mode 100644 index 000000000000..907a5a34332e --- /dev/null +++ b/arrow-avro/src/writer/utils.rs @@ -0,0 +1,41 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::io::{Error as IoError, Write}; + +use arrow_schema::ArrowError; + +use crate::writer::zigzag::write_zigzag_long; + +/// Wraps an `std::io::Error` with a helpful context, returning an `ArrowError`. 
+pub fn to_arrow_io_err(e: IoError, context: &str) -> ArrowError {
+    let msg = format!("{context}: {e}");
+    ArrowError::ExternalError(Box::new(std::io::Error::new(e.kind(), msg)))
+}
+
+/// Write a UTF-8 string in Avro format: zigzag-encoded length + raw bytes
+pub fn write_string(s: &str, w: &mut dyn Write) -> Result<(), ArrowError> {
+    write_bytes(s.as_bytes(), w)
+}
+
+/// Write raw bytes in Avro format: zigzag-encoded length + raw bytes
+pub fn write_bytes(b: &[u8], w: &mut dyn Write) -> Result<(), ArrowError> {
+    write_zigzag_long(b.len() as i64, w)?;
+    w.write_all(b)
+        .map_err(|e| ArrowError::ExternalError(Box::new(e)))?;
+    Ok(())
+}
diff --git a/arrow-avro/src/writer/zigzag.rs b/arrow-avro/src/writer/zigzag.rs
new file mode 100644
index 000000000000..acf292340152
--- /dev/null
+++ b/arrow-avro/src/writer/zigzag.rs
@@ -0,0 +1,102 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Zigzag variable-length encoding for Avro "long" and "int" types
+
+use std::io::Write;
+use arrow_schema::ArrowError;
+
+/// Write an Avro "long" (i64) in **zigzag** variable-length format into `writer`.
+///
+pub fn write_zigzag_long(n: i64, writer: &mut dyn Write) -> Result<(), ArrowError> {
+    let zz = ((n << 1) ^ (n >> 63)) as u64;
+    write_varint_u64(zz, writer)
+}
+
+/// Write an Avro "int" (i32) in **zigzag** variable-length format into `writer`.
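+///
+/// Zigzag maps values of small magnitude to short unsigned varints, e.g.
+/// `0 -> 0x00`, `-1 -> 0x01`, `1 -> 0x02`, `-3 -> 0x05`, `3 -> 0x06`, `5 -> 0x0A`,
+/// matching the test vectors at the bottom of this module.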
+/// +pub fn write_zigzag_int(n: i32, writer: &mut dyn Write) -> Result<(), ArrowError> { + let i64_val = n as i64; + write_zigzag_long(i64_val, writer) +} + +fn write_varint_u64(mut val: u64, writer: &mut dyn Write) -> Result<(), ArrowError> { + let mut buf = [0u8; 10]; + let mut i = 0; + loop { + let b = (val & 0x7F) as u8; + val >>= 7; + if val != 0 { + buf[i] = b | 0x80; + i += 1; + } else { + buf[i] = b; + i += 1; + break; + } + if i >= buf.len() { + return Err(ArrowError::ParseError( + "Varint exceeded 10 bytes, not a valid Avro long".to_string(), + )); + } + } + writer + .write_all(&buf[..i]) + .map_err(|e| ArrowError::ExternalError(Box::new(e)))?; + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::Cursor; + + #[test] + fn test_zigzag_long_positive() { + let mut buffer = Cursor::new(Vec::new()); + write_zigzag_long(5, &mut buffer).unwrap(); + assert_eq!(buffer.into_inner(), vec![0x0A]); + } + + #[test] + fn test_zigzag_long_negative() { + let mut buffer = Cursor::new(Vec::new()); + write_zigzag_long(-1, &mut buffer).unwrap(); + assert_eq!(buffer.into_inner(), vec![0x01]); + } + + #[test] + fn test_zigzag_long_zero() { + let mut buffer = Cursor::new(Vec::new()); + write_zigzag_long(0, &mut buffer).unwrap(); + assert_eq!(buffer.into_inner(), vec![0x00]); + } + + #[test] + fn test_zigzag_int_positive() { + let mut buffer = Cursor::new(Vec::new()); + write_zigzag_int(3, &mut buffer).unwrap(); + assert_eq!(buffer.into_inner(), vec![0x06]); + } + + #[test] + fn test_zigzag_int_negative() { + let mut buffer = Cursor::new(Vec::new()); + write_zigzag_int(-3, &mut buffer).unwrap(); + assert_eq!(buffer.into_inner(), vec![0x05]); + } +} diff --git a/arrow-avro/test/data/nested_lists.snappy.avro b/arrow-avro/test/data/nested_lists.snappy.avro new file mode 100644 index 000000000000..6cbff89610a7 Binary files /dev/null and b/arrow-avro/test/data/nested_lists.snappy.avro differ diff --git a/arrow-avro/test/data/simple_enum.avro b/arrow-avro/test/data/simple_enum.avro new file mode 100644 index 000000000000..dbf0a42baae4 Binary files /dev/null and b/arrow-avro/test/data/simple_enum.avro differ diff --git a/arrow-schema/src/error.rs b/arrow-schema/src/error.rs index 982dd026a04d..76e8a2efbc03 100644 --- a/arrow-schema/src/error.rs +++ b/arrow-schema/src/error.rs @@ -54,6 +54,8 @@ pub enum ArrowError { InvalidArgumentError(String), /// Error during Parquet operations. ParquetError(String), + /// Error during Avro operations. + AvroError(String), /// Error during import or export to/from the C Data Interface CDataInterface(String), /// Error when a dictionary key is bigger than the key type @@ -117,6 +119,9 @@ impl Display for ArrowError { ArrowError::ParquetError(desc) => { write!(f, "Parquet argument error: {desc}") } + ArrowError::AvroError(desc) => { + write!(f, "Avro argument error: {desc}") + } ArrowError::CDataInterface(desc) => { write!(f, "C Data interface error: {desc}") }