diff --git a/arrow-array/src/array/extension_array.rs b/arrow-array/src/array/extension_array.rs new file mode 100644 index 000000000000..f1cdaf5b6d48 --- /dev/null +++ b/arrow-array/src/array/extension_array.rs @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, sync::Arc}; + +use arrow_data::ArrayData; +use arrow_schema::{extension::DynExtensionType, ArrowError, DataType}; + +use super::{make_array, Array, ArrayRef}; + +/// Array type for DataType::Extension +#[derive(Debug)] +pub struct ExtensionArray { + data_type: DataType, + storage: ArrayRef, +} + +impl ExtensionArray { + /// Try to create a new ExtensionArray + pub fn try_new( + extension: Arc, + storage: ArrayRef, + ) -> Result { + Ok(Self { + data_type: DataType::Extension(extension), + storage, + }) + } + + /// Create a new ExtensionArray + pub fn new(extension: Arc, storage: ArrayRef) -> Self { + Self::try_new(extension, storage).unwrap() + } + + /// Return the underlying storage array + pub fn storage(&self) -> &ArrayRef { + &self.storage + } + + /// Return a new array with new storage of the same type + pub fn with_storage(&self, new_storage: ArrayRef) -> Self { + assert_eq!(new_storage.data_type(), new_storage.data_type()); + Self { + data_type: self.data_type.clone(), + storage: new_storage, + } + } +} + +impl From for ExtensionArray { + fn from(data: ArrayData) -> Self { + if let DataType::Extension(extension) = data.data_type() { + let storage_data = ArrayData::try_new( + extension.storage_type().clone(), + data.len(), + data.nulls().map(|b| b.buffer()).cloned(), + data.offset(), + data.buffers().to_vec(), + data.child_data().to_vec(), + ) + .unwrap(); + + Self { + data_type: data.data_type().clone(), + storage: Arc::new(make_array(storage_data)) as ArrayRef, + } + } else { + panic!("{} is not Extension", data.data_type()) + } + } +} + +impl Array for ExtensionArray { + fn as_any(&self) -> &dyn Any { + self + } + + fn to_data(&self) -> ArrayData { + let storage_data = self.storage.to_data(); + ArrayData::try_new( + self.data_type.clone(), + storage_data.len(), + storage_data.nulls().map(|b| b.buffer()).cloned(), + storage_data.offset(), + storage_data.buffers().to_vec(), + storage_data.child_data().to_vec(), + ) + .unwrap() + } + + fn into_data(self) -> ArrayData { + self.to_data() + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn slice(&self, offset: usize, length: usize) -> ArrayRef { + Arc::new(Self { + data_type: self.data_type.clone(), + storage: self.storage.slice(offset, length), + }) + } + + fn len(&self) -> usize { + self.storage.len() + } + + fn is_empty(&self) -> bool { + self.storage.is_empty() + } + + fn offset(&self) -> usize { + self.storage.offset() + } + + fn nulls(&self) -> Option<&arrow_buffer::NullBuffer> { + self.storage.nulls() + } + + fn get_buffer_memory_size(&self) -> usize { + self.storage.get_buffer_memory_size() + } + + fn get_array_memory_size(&self) -> usize { + self.storage.get_array_memory_size() + } +} diff --git a/arrow-array/src/array/mod.rs b/arrow-array/src/array/mod.rs index e41a3a1d719a..a5894f47e401 100644 --- a/arrow-array/src/array/mod.rs +++ b/arrow-array/src/array/mod.rs @@ -76,6 +76,9 @@ mod list_view_array; pub use list_view_array::*; +mod extension_array; +pub use extension_array::*; + use crate::iterator::ArrayIter; /// An array in the [arrow columnar format](https://arrow.apache.org/docs/format/Columnar.html) @@ -829,6 +832,7 @@ pub fn make_array(data: ArrayData) -> ArrayRef { DataType::Null => Arc::new(NullArray::from(data)) as ArrayRef, DataType::Decimal128(_, _) => Arc::new(Decimal128Array::from(data)) as ArrayRef, DataType::Decimal256(_, _) => Arc::new(Decimal256Array::from(data)) as ArrayRef, + DataType::Extension(_) => Arc::new(ExtensionArray::from(data)) as ArrayRef, dt => panic!("Unexpected data type {dt:?}"), } } diff --git a/arrow-data/src/data.rs b/arrow-data/src/data.rs index 4c117184de79..7d01a27a6131 100644 --- a/arrow-data/src/data.rs +++ b/arrow-data/src/data.rs @@ -151,6 +151,7 @@ pub(crate) fn new_buffers(data_type: &DataType, capacity: usize) -> [MutableBuff } } } + DataType::Extension(extension) => new_buffers(extension.storage_type(), capacity), } } @@ -590,6 +591,12 @@ impl ArrayData { /// Returns a new [`ArrayData`] valid for `data_type` containing `len` null values pub fn new_null(data_type: &DataType, len: usize) -> Self { + if let DataType::Extension(extension) = data_type { + let mut storage_data = Self::new_null(extension.storage_type(), len); + storage_data.data_type = data_type.clone(); + return storage_data; + } + let bit_len = bit_util::ceil(len, 8); let zeroed = |len: usize| Buffer::from(MutableBuffer::from_len_zeroed(len)); @@ -1664,6 +1671,7 @@ pub fn layout(data_type: &DataType) -> DataTypeLayout { } } DataType::Dictionary(key_type, _value_type) => layout(key_type), + DataType::Extension(extension) => layout(extension.storage_type()), } } @@ -2119,7 +2127,7 @@ impl From for ArrayDataBuilder { #[cfg(test)] mod tests { use super::*; - use arrow_schema::{Field, Fields}; + use arrow_schema::{extension::TestExtension, Field, Fields}; // See arrow/tests/array_data_validation.rs for test of array validation @@ -2448,4 +2456,15 @@ mod tests { assert!(array.is_null(i)); } } + + #[test] + fn test_data_extension() { + let data_type = DataType::Extension(Arc::new(TestExtension { + storage_type: DataType::Utf8, + })); + let array_null = ArrayData::new_null(&data_type, 3); + assert_eq!(array_null.len(), 3); + assert_eq!(array_null.data_type(), &data_type); + assert_eq!(array_null.null_count(), 3); + } } diff --git a/arrow-data/src/equal/mod.rs b/arrow-data/src/equal/mod.rs index f24179b61700..e35ac30b62f9 100644 --- a/arrow-data/src/equal/mod.rs +++ b/arrow-data/src/equal/mod.rs @@ -123,6 +123,7 @@ fn equal_values( DataType::Float16 => primitive_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::Map(_, _) => list_equal::(lhs, rhs, lhs_start, rhs_start, len), DataType::RunEndEncoded(_, _) => run_equal(lhs, rhs, lhs_start, rhs_start, len), + DataType::Extension(_) => unimplemented!("Extension not implemented"), } } diff --git a/arrow-data/src/transform/mod.rs b/arrow-data/src/transform/mod.rs index 93b79e6a5eb8..5616e9efea4e 100644 --- a/arrow-data/src/transform/mod.rs +++ b/arrow-data/src/transform/mod.rs @@ -276,6 +276,7 @@ fn build_extend(array: &ArrayData) -> Extend { UnionMode::Dense => union::build_extend_dense(array), }, DataType::RunEndEncoded(_, _) => todo!(), + DataType::Extension(_) => unimplemented!("Extension not implemented"), } } @@ -332,6 +333,7 @@ fn build_extend_nulls(data_type: &DataType) -> ExtendNulls { UnionMode::Dense => union::extend_nulls_dense, }, DataType::RunEndEncoded(_, _) => todo!(), + DataType::Extension(_) => unimplemented!("ListView/LargeListView not implemented"), }) } @@ -590,6 +592,7 @@ impl<'a> MutableArrayData<'a> { MutableArrayData::new(child_arrays, use_nulls, array_capacity) }) .collect::>(), + DataType::Extension(_) => unimplemented!("Extension not implemented"), }; // Get the dictionary if any, and if it is a concatenation of multiple diff --git a/arrow-integration-test/src/datatype.rs b/arrow-integration-test/src/datatype.rs index 24e02c8430c7..aa41f24ff3a6 100644 --- a/arrow-integration-test/src/datatype.rs +++ b/arrow-integration-test/src/datatype.rs @@ -345,6 +345,7 @@ pub fn data_type_to_json(data_type: &DataType) -> serde_json::Value { json!({"name": "map", "keysSorted": keys_sorted}) } DataType::RunEndEncoded(_, _) => todo!(), + DataType::Extension(extension) => data_type_to_json(extension.storage_type()), } } diff --git a/arrow-ipc/src/convert.rs b/arrow-ipc/src/convert.rs index 79dd1726ed70..40e10362111c 100644 --- a/arrow-ipc/src/convert.rs +++ b/arrow-ipc/src/convert.rs @@ -18,6 +18,7 @@ //! Utilities for converting between IPC types and native Arrow types use arrow_buffer::Buffer; +use arrow_schema::extension::DynExtensionTypeFactory; use arrow_schema::*; use flatbuffers::{ FlatBufferBuilder, ForwardsUOffset, UnionWIPOffset, Vector, Verifiable, Verifier, @@ -194,8 +195,16 @@ impl From> for Field { } } -/// Deserialize an ipc [crate::Schema`] from flat buffers to an arrow [Schema]. +/// Deserialize an ipc [crate::Schema`] from flat buffers to an arrow [Schema] pub fn fb_to_schema(fb: crate::Schema) -> Schema { + fb_to_schema_with_extension_factory(fb, None).unwrap() +} + +/// Deserialize an ipc [crate::Schema`] from flat buffers to an arrow [Schema] with extension support +pub fn fb_to_schema_with_extension_factory( + fb: crate::Schema, + extension_factory: Option<&dyn DynExtensionTypeFactory>, +) -> Result { let mut fields: Vec = vec![]; let c_fields = fb.fields().unwrap(); let len = c_fields.len(); @@ -207,7 +216,15 @@ pub fn fb_to_schema(fb: crate::Schema) -> Schema { } _ => (), }; - fields.push(c_field.into()); + let field: Field = c_field.into(); + if let Some(factory) = extension_factory { + if let Some(extension) = factory.make_from_field(&field)? { + fields.push(field.clone().with_data_type(DataType::Extension(extension))); + continue; + } + } + + fields.push(field); } let mut metadata: HashMap = HashMap::default(); @@ -224,7 +241,8 @@ pub fn fb_to_schema(fb: crate::Schema) -> Schema { } } } - Schema::new_with_metadata(fields, metadata) + + Ok(Schema::new_with_metadata(fields, metadata)) } /// Try deserialize flat buffer format bytes into a schema @@ -514,7 +532,24 @@ pub(crate) fn build_field<'a>( ) -> WIPOffset> { // Optional custom metadata. let mut fb_metadata = None; - if !field.metadata().is_empty() { + + // Handle extension type metadata if applicable + if let DataType::Extension(extension) = field.data_type() { + let mut field_metadata = HashMap::from([ + ( + "ARROW:extension:name".to_string(), + extension.extension_name().to_string(), + ), + ( + "ARROW:extension:metadata".to_string(), + extension.serialized_metadata(), + ), + ]); + + for (k, v) in field.metadata() { + field_metadata.insert(k.clone(), v.clone()); + } + } else if !field.metadata().is_empty() { fb_metadata = Some(metadata_to_fb(fbb, field.metadata())); }; @@ -883,6 +918,9 @@ pub(crate) fn get_fb_field_type<'a>( children: Some(fbb.create_vector(&children[..])), } } + DataType::Extension(extension) => { + get_fb_field_type(extension.storage_type(), dictionary_tracker, fbb) + } } } diff --git a/arrow-ipc/src/reader.rs b/arrow-ipc/src/reader.rs index 83dc5702dc94..6e084f3d5e60 100644 --- a/arrow-ipc/src/reader.rs +++ b/arrow-ipc/src/reader.rs @@ -26,6 +26,7 @@ mod stream; +use arrow_schema::extension::DynExtensionTypeFactory; pub use stream::*; use flatbuffers::{VectorIter, VerifierOptions}; @@ -229,6 +230,12 @@ impl RecordBatchDecoder<'_> { .offset(0); self.create_array_from_builder(builder) } + Extension(extension) => self.create_array( + &field + .clone() + .with_data_type(extension.storage_type().clone()), + variadic_counts, + ), _ => { let field_node = self.next_node(field)?; let buffers = [self.next_buffer()?, self.next_buffer()?]; @@ -1173,7 +1180,7 @@ impl FileReader { /// Try to create a new file reader. /// /// There is no internal buffering. If buffered reads are needed you likely want to use - /// [`FileReader::try_new_buffered`] instead. + /// [`FileReader::try_new_buffered`] instead. /// /// # Errors /// @@ -1364,8 +1371,17 @@ impl StreamReader { /// An ['Err'](Result::Err) may be returned if the reader does not encounter a schema /// as the first message in the stream. pub fn try_new( + reader: R, + projection: Option>, + ) -> Result, ArrowError> { + Self::try_new_with_extension_factory(reader, projection, None) + } + + /// Create a stream reader with an extension factory + pub fn try_new_with_extension_factory( mut reader: R, projection: Option>, + extension_factory: Option<&dyn DynExtensionTypeFactory>, ) -> Result, ArrowError> { // determine metadata length let mut meta_size: [u8; 4] = [0; 4]; @@ -1389,7 +1405,9 @@ impl StreamReader { let ipc_schema: crate::Schema = message.header_as_schema().ok_or_else(|| { ArrowError::ParseError("Unable to read IPC message as schema".to_string()) })?; - let schema = crate::convert::fb_to_schema(ipc_schema); + + let schema = + crate::convert::fb_to_schema_with_extension_factory(ipc_schema, extension_factory)?; // Create an array of optional dictionary value arrays, one per field. let dictionaries_by_id = HashMap::new(); diff --git a/arrow-schema/src/datatype.rs b/arrow-schema/src/datatype.rs index 5c9073c4eeb6..87cdd6056540 100644 --- a/arrow-schema/src/datatype.rs +++ b/arrow-schema/src/datatype.rs @@ -19,6 +19,7 @@ use std::fmt; use std::str::FromStr; use std::sync::Arc; +use crate::extension::DynExtensionType; use crate::{ArrowError, Field, FieldRef, Fields, UnionFields}; /// Datatypes supported by this implementation of Apache Arrow. @@ -411,6 +412,9 @@ pub enum DataType { /// These child arrays are prescribed the standard names of "run_ends" and "values" /// respectively. RunEndEncoded(FieldRef, FieldRef), + /// An ExtensionType + #[cfg_attr(feature = "serde", serde(skip))] + Extension(Arc), } /// An absolute length of time in seconds, milliseconds, microseconds or nanoseconds. @@ -689,6 +693,7 @@ impl DataType { DataType::Union(_, _) => None, DataType::Dictionary(_, _) => None, DataType::RunEndEncoded(_, _) => None, + DataType::Extension(_) => None, } } @@ -740,6 +745,7 @@ impl DataType { run_ends.size() - std::mem::size_of_val(run_ends) + values.size() - std::mem::size_of_val(values) } + DataType::Extension(extension) => extension.size(), } } diff --git a/arrow-schema/src/extension/mod.rs b/arrow-schema/src/extension/mod.rs index c5119873af0c..1d29bfc89a21 100644 --- a/arrow-schema/src/extension/mod.rs +++ b/arrow-schema/src/extension/mod.rs @@ -24,7 +24,13 @@ mod canonical; #[cfg(feature = "canonical_extension_types")] pub use canonical::*; -use crate::{ArrowError, DataType}; +use crate::{ArrowError, DataType, Field}; +use std::any::Any; +use std::cmp::Ordering; +use std::fmt::Debug; +use std::hash::{Hash, Hasher}; +use std::panic::RefUnwindSafe; +use std::sync::Arc; /// The metadata key for the string name identifying an [`ExtensionType`]. pub const EXTENSION_TYPE_NAME_KEY: &str = "ARROW:extension:name"; @@ -258,3 +264,193 @@ pub trait ExtensionType: Sized { /// this extension type. fn try_new(data_type: &DataType, metadata: Self::Metadata) -> Result; } + +/// dyn-compatible ExtensionType +pub trait DynExtensionType: Debug + RefUnwindSafe { + /// For dyn-compatible comparison methods + fn as_any(&self) -> &dyn Any; + + /// Because DataType implements sized + fn size(&self) -> usize; + + /// Concrete storage type for this extension + fn storage_type(&self) -> &DataType; + + /// Name of the extension + fn extension_name(&self) -> &'static str; + + /// Extension metadata + fn serialized_metadata(&self) -> String; + + /// Because DataType implement Eq + fn extension_equals(&self, other: &dyn Any) -> bool; + + /// Because DataType implements Hash + fn extension_hash(&self, hasher: &mut dyn Hasher); + + /// Because DataType implements Ord + fn exension_cmp(&self, other: &dyn Any) -> Ordering; +} + +impl PartialEq for dyn DynExtensionType + Send + Sync { + fn eq(&self, other: &Self) -> bool { + self.extension_equals(other.as_any()) + } +} + +impl Eq for dyn DynExtensionType + Send + Sync {} + +impl Hash for dyn DynExtensionType + Send + Sync { + fn hash(&self, state: &mut H) { + self.extension_hash(state); + } +} + +impl PartialOrd for dyn DynExtensionType + Send + Sync { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for dyn DynExtensionType + Send + Sync { + fn cmp(&self, other: &Self) -> Ordering { + self.exension_cmp(other.as_any()) + } +} + +/// A way to create extension types for places where they might be imported +pub trait DynExtensionTypeFactory { + /// Create an extension type from name, storage type, and metadata + fn make_extension_type( + &self, + extension_name: &str, + storage_type: &DataType, + extension_metadata: Option<&String>, + ) -> Result>, ArrowError>; + + /// Create an extension type from a field + fn make_from_field( + &self, + field: &Field, + ) -> Result>, ArrowError> { + if let Some(extension_name) = field.metadata().get("ARROW:extension:name") { + self.make_extension_type( + extension_name, + field.data_type(), + field.metadata().get("ARROW:extension:metadata"), + ) + } else { + Ok(None) + } + } +} + +/// Simple factory with registered types +pub struct CanonicalExtensionTypeFactory {} + +#[cfg(feature = "canonical_extension_types")] +impl DynExtensionType for Uuid { + fn as_any(&self) -> &dyn Any { + self + } + + fn size(&self) -> usize { + size_of_val(self) + } + + fn storage_type(&self) -> &DataType { + &DataType::FixedSizeBinary(16) + } + + fn extension_name(&self) -> &'static str { + Self::NAME + } + + fn serialized_metadata(&self) -> String { + "".to_string() + } + + fn extension_equals(&self, other: &dyn Any) -> bool { + other.downcast_ref::().is_some() + } + + fn extension_hash(&self, hasher: &mut dyn Hasher) { + hasher.write("arrow.uuid".as_bytes()); + } + + fn exension_cmp(&self, other: &dyn Any) -> Ordering { + if self.extension_equals(other) { + Ordering::Equal + } else { + // Fishy... + Ordering::Less + } + } +} + +#[cfg(feature = "canonical_extension_types")] +impl DynExtensionTypeFactory for CanonicalExtensionTypeFactory { + fn make_extension_type( + &self, + extension_name: &str, + storage_type: &DataType, + extension_metadata: Option<&String>, + ) -> Result>, ArrowError> { + match extension_name { + "arrow.uuid" => { + let uuid = Uuid::try_new( + storage_type, + Uuid::deserialize_metadata(extension_metadata.map(|s| s.as_str()))?, + )?; + Ok(Some(Arc::new(uuid))) + } + _ => Ok(None), + } + } +} + +/// Extension for tests +#[derive(Debug)] +pub struct TestExtension { + /// Arbitrary storage type + pub storage_type: DataType, +} + +impl DynExtensionType for TestExtension { + fn as_any(&self) -> &dyn Any { + self + } + + fn size(&self) -> usize { + size_of_val(self) + } + + fn storage_type(&self) -> &DataType { + &self.storage_type + } + + fn extension_name(&self) -> &'static str { + "arrow.rs.test" + } + + fn serialized_metadata(&self) -> String { + "".to_string() + } + + fn extension_equals(&self, other: &dyn Any) -> bool { + other.downcast_ref::().is_some() + } + + fn extension_hash(&self, hasher: &mut dyn Hasher) { + hasher.write("arrow.rs.test".as_bytes()); + } + + fn exension_cmp(&self, other: &dyn Any) -> Ordering { + if self.extension_equals(other) { + Ordering::Equal + } else { + // Fishy... + Ordering::Less + } + } +} diff --git a/arrow-schema/src/field.rs b/arrow-schema/src/field.rs index dbd671a62a3a..cb385a59a0e5 100644 --- a/arrow-schema/src/field.rs +++ b/arrow-schema/src/field.rs @@ -727,7 +727,7 @@ impl Field { DataType::Null => { self.nullable = true; self.data_type = from.data_type.clone(); - } + }, | DataType::Boolean | DataType::Int8 | DataType::Int16 @@ -761,7 +761,8 @@ impl Field { | DataType::LargeUtf8 | DataType::Utf8View | DataType::Decimal128(_, _) - | DataType::Decimal256(_, _) => { + | DataType::Decimal256(_, _) + | DataType::Extension(_) => { if from.data_type == DataType::Null { self.nullable = true; } else if self.data_type != from.data_type { diff --git a/arrow-select/src/concat.rs b/arrow-select/src/concat.rs index b48998478442..bc62548e5739 100644 --- a/arrow-select/src/concat.rs +++ b/arrow-select/src/concat.rs @@ -297,6 +297,22 @@ pub fn concat(arrays: &[&dyn Array]) -> Result { return Err(ArrowError::InvalidArgumentError(error_message)); } + if let DataType::Extension(extension) = d { + let storage: Vec<_> = arrays + .iter() + .map(|array| { + let extension_array: ExtensionArray = array.to_data().into(); + extension_array.storage().clone() + }) + .collect(); + let storage_ref: Vec<_> = storage.iter().map(|array| array.as_ref()).collect(); + let storage_result = concat(&storage_ref)?; + return Ok(Arc::new(ExtensionArray::new( + extension.clone(), + storage_result, + ))); + } + downcast_primitive! { d => (primitive_concat, arrays), DataType::Boolean => concat_boolean(arrays), @@ -374,7 +390,7 @@ pub fn concat_batches<'a>( mod tests { use super::*; use arrow_array::builder::{GenericListBuilder, StringDictionaryBuilder}; - use arrow_schema::{Field, Schema}; + use arrow_schema::{extension::TestExtension, Field, Schema}; use std::fmt::Debug; #[test] @@ -1267,4 +1283,27 @@ mod tests { "There are duplicates in the value list (the value list here is sorted which is only for the assertion)" ); } + + #[test] + fn test_concat_extension() { + let storage = Arc::new(StringArray::from(vec!["one banana", "two banana"])); + + let array = ExtensionArray::new( + Arc::new(TestExtension { + storage_type: DataType::Utf8, + }), + storage.clone(), + ); + + let result_ref = concat(&[&array, &array]).unwrap(); + assert_eq!(result_ref.data_type(), array.data_type()); + assert_eq!(result_ref.len(), 4); + + let result_array: ExtensionArray = result_ref.to_data().into(); + let expected_storage = create_array!( + Utf8, + ["one banana", "two banana", "one banana", "two banana"] + ); + assert_eq!(**result_array.storage(), *expected_storage); + } } diff --git a/arrow-select/src/filter.rs b/arrow-select/src/filter.rs index 7bb140d37f51..91180b516020 100644 --- a/arrow-select/src/filter.rs +++ b/arrow-select/src/filter.rs @@ -393,6 +393,11 @@ fn filter_array(values: &dyn Array, predicate: &FilterPredicate) -> Result { Ok(Arc::new(filter_sparse_union(values.as_union(), predicate)?)) } + DataType::Extension(_) => { + let extension_array: ExtensionArray = values.to_data().into(); + let storage_result = filter_array(extension_array.storage(), predicate)?; + Ok(Arc::new(extension_array.with_storage(storage_result))) + } _ => { let data = values.to_data(); // fallback to using MutableArrayData @@ -864,6 +869,7 @@ mod tests { use arrow_array::builder::*; use arrow_array::cast::as_run_array; use arrow_array::types::*; + use arrow_schema::extension::TestExtension; use rand::distr::uniform::{UniformSampler, UniformUsize}; use rand::distr::{Alphanumeric, StandardUniform}; use rand::prelude::*; @@ -2045,4 +2051,28 @@ mod tests { assert_eq!(result.to_data(), expected.to_data()); } + + #[test] + fn test_filter_extension() { + let predicate = BooleanArray::from(vec![true, false, true, false]); + let storage = Arc::new(StringArray::from(vec![ + "one banana", + "two banana", + "three banana", + "four", + ])); + let array = ExtensionArray::new( + Arc::new(TestExtension { + storage_type: DataType::Utf8, + }), + storage.clone(), + ); + let result_ref = filter(&array, &predicate).unwrap(); + assert_eq!(result_ref.data_type(), array.data_type()); + assert_eq!(result_ref.len(), 2); + + let result_array: ExtensionArray = result_ref.to_data().into(); + let expected = create_array!(Utf8, ["one banana", "three banana"]); + assert_eq!(result_array.storage().to_data(), expected.to_data()); + } } diff --git a/arrow-select/src/interleave.rs b/arrow-select/src/interleave.rs index 5fc019da78f1..39c317c18e25 100644 --- a/arrow-select/src/interleave.rs +++ b/arrow-select/src/interleave.rs @@ -93,6 +93,22 @@ pub fn interleave( return Ok(new_empty_array(data_type)); } + if let DataType::Extension(extension) = data_type { + let storage: Vec<_> = values + .iter() + .map(|array| { + let extension_array: ExtensionArray = array.to_data().into(); + extension_array.storage().clone() + }) + .collect(); + let storage_ref: Vec<_> = storage.iter().map(|array| array.as_ref()).collect(); + let storage_result = interleave(&storage_ref, indices)?; + return Ok(Arc::new(ExtensionArray::new( + extension.clone(), + storage_result, + ))); + } + downcast_primitive! { data_type => (primitive_helper, values, indices, data_type), DataType::Utf8 => interleave_bytes::(values, indices), @@ -369,6 +385,7 @@ pub fn interleave_record_batch( mod tests { use super::*; use arrow_array::builder::{Int32Builder, ListBuilder}; + use arrow_schema::extension::TestExtension; #[test] fn test_primitive() { @@ -729,4 +746,30 @@ mod tests { ] ); } + + #[test] + fn test_interleave_extension() { + let indices = [(0, 0), (1, 3), (0, 2)]; + let storage = Arc::new(StringArray::from(vec![ + "one banana", + "two banana", + "three banana", + "four", + ])); + + let array = ExtensionArray::new( + Arc::new(TestExtension { + storage_type: DataType::Utf8, + }), + storage.clone(), + ); + + let result_ref = interleave(&[&array, &array], &indices).unwrap(); + assert_eq!(result_ref.data_type(), array.data_type()); + assert_eq!(result_ref.len(), 3); + + let result_array: ExtensionArray = result_ref.to_data().into(); + let expected_storage = create_array!(Utf8, ["one banana", "four", "three banana"]); + assert_eq!(**result_array.storage(), *expected_storage); + } } diff --git a/arrow-select/src/nullif.rs b/arrow-select/src/nullif.rs index dc729da7e6c3..caf4c6d6d919 100644 --- a/arrow-select/src/nullif.rs +++ b/arrow-select/src/nullif.rs @@ -113,12 +113,17 @@ pub fn nullif(left: &dyn Array, right: &BooleanArray) -> Result( values: &dyn Array, indices: &PrimitiveArray, ) -> Result { + if let DataType::Extension(_) = values.data_type() { + let extension_array: ExtensionArray = values.to_data().into(); + let storage_result = take_impl(extension_array.storage(), indices)?; + return Ok(Arc::new(extension_array.with_storage(storage_result))); + } + downcast_primitive_array! { values => Ok(Arc::new(take_primitive(values, indices)?)), DataType::Boolean => { @@ -949,7 +955,7 @@ mod tests { use super::*; use arrow_array::builder::*; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; - use arrow_schema::{Field, Fields, TimeUnit, UnionFields}; + use arrow_schema::{extension::TestExtension, Field, Fields, TimeUnit, UnionFields}; fn test_take_decimal_arrays( data: Vec>, @@ -2400,4 +2406,28 @@ mod tests { let array = take(&array, &indicies, None).unwrap(); assert_eq!(array.len(), 3); } + + #[test] + fn test_take_extension() { + let indices = Int32Array::from(vec![1, 3]); + let storage = Arc::new(StringArray::from(vec![ + "one banana", + "two banana", + "three banana", + "four", + ])); + let array = ExtensionArray::new( + Arc::new(TestExtension { + storage_type: DataType::Utf8, + }), + storage.clone(), + ); + let result_ref = take(&array, &indices, None).unwrap(); + assert_eq!(result_ref.data_type(), array.data_type()); + assert_eq!(result_ref.len(), 2); + + let result_array: ExtensionArray = result_ref.to_data().into(); + let expected = create_array!(Utf8, ["two banana", "four"]); + assert_eq!(result_array.storage().to_data(), expected.to_data()); + } } diff --git a/arrow-select/src/zip.rs b/arrow-select/src/zip.rs index 2efd2e749921..747f59baf30d 100644 --- a/arrow-select/src/zip.rs +++ b/arrow-select/src/zip.rs @@ -17,10 +17,12 @@ //! [`zip`]: Combine values from two arrays based on boolean mask +use std::sync::Arc; + use crate::filter::SlicesIterator; use arrow_array::*; use arrow_data::transform::MutableArrayData; -use arrow_schema::ArrowError; +use arrow_schema::{ArrowError, DataType}; /// Zip two arrays by some boolean mask. /// @@ -116,6 +118,16 @@ pub fn zip( )); } + if let DataType::Extension(extension) = truthy.data_type() { + let truthy_extension: ExtensionArray = truthy.to_data().into(); + let falsy_extension: ExtensionArray = falsy.to_data().into(); + let storage_result = zip(mask, truthy_extension.storage(), falsy_extension.storage())?; + return Ok(Arc::new(ExtensionArray::new( + extension.clone(), + storage_result, + ))); + } + let falsy = falsy.to_data(); let truthy = truthy.to_data(); @@ -168,6 +180,10 @@ pub fn zip( #[cfg(test)] mod test { + use std::sync::Arc; + + use arrow_schema::{extension::TestExtension, DataType}; + use super::*; #[test] @@ -279,4 +295,33 @@ mod test { let expected = Int32Array::from(vec![None, None, Some(42), Some(42), None]); assert_eq!(actual, &expected); } + + #[test] + fn test_zip_extension() { + let mask = BooleanArray::from(vec![true, false, true, false]); + let truthy_storage = Arc::new(StringArray::from(vec![ + "one banana", + "two banana", + "three banana", + "four", + ])); + let falsy_storage = Arc::new(StringArray::from(vec![ + "five banana", + "six banana", + "seven banana", + "more", + ])); + let extension = Arc::new(TestExtension { + storage_type: DataType::Utf8, + }); + let truthy = ExtensionArray::new(extension.clone(), truthy_storage.clone()); + let falsy = ExtensionArray::new(extension.clone(), falsy_storage.clone()); + let result_ref = zip(&mask, &truthy, &falsy).unwrap(); + assert_eq!(result_ref.data_type(), truthy.data_type()); + assert_eq!(result_ref.len(), 4); + + let result_array: ExtensionArray = result_ref.to_data().into(); + let expected = create_array!(Utf8, ["one banana", "six banana", "three banana", "more"]); + assert_eq!(result_array.storage().to_data(), expected.to_data()); + } } diff --git a/parquet/src/arrow/arrow_reader/statistics.rs b/parquet/src/arrow/arrow_reader/statistics.rs index 09f8ec7cc274..26c17afad0e6 100644 --- a/parquet/src/arrow/arrow_reader/statistics.rs +++ b/parquet/src/arrow/arrow_reader/statistics.rs @@ -535,7 +535,8 @@ macro_rules! get_statistics { DataType::LargeListView(_) | DataType::Struct(_) | DataType::Union(_, _) | - DataType::RunEndEncoded(_, _) => { + DataType::RunEndEncoded(_, _) | + DataType::Extension(_) => { let len = $iterator.count(); // don't know how to extract statistics, so return a null array Ok(new_null_array($data_type, len)) @@ -1056,7 +1057,8 @@ macro_rules! get_data_page_statistics { DataType::Struct(_) | DataType::Union(_, _) | DataType::Map(_, _) | - DataType::RunEndEncoded(_, _) => { + DataType::RunEndEncoded(_, _) | + DataType::Extension(_) => { let len = $iterator.count(); // don't know how to extract statistics, so return a null array Ok(new_null_array($data_type, len)) diff --git a/parquet/src/arrow/schema/mod.rs b/parquet/src/arrow/schema/mod.rs index 89c42f5eaf92..2519145c9a66 100644 --- a/parquet/src/arrow/schema/mod.rs +++ b/parquet/src/arrow/schema/mod.rs @@ -767,6 +767,12 @@ fn arrow_to_parquet_type(field: &Field, coerce_types: bool) -> Result { DataType::RunEndEncoded(_, _) => Err(arrow_err!( "Converting RunEndEncodedType to parquet not supported", )), + DataType::Extension(extension) => arrow_to_parquet_type( + &field + .clone() + .with_data_type(extension.storage_type().clone()), + coerce_types, + ), } }