diff --git a/arrow-select/src/concat.rs b/arrow-select/src/concat.rs index b4899847844..a7a6d072844 100644 --- a/arrow-select/src/concat.rs +++ b/arrow-select/src/concat.rs @@ -30,8 +30,12 @@ //! assert_eq!(arr.len(), 3); //! ``` -use crate::dictionary::{merge_dictionary_values, should_merge_dictionary_values}; -use arrow_array::builder::{BooleanBuilder, GenericByteBuilder, PrimitiveBuilder}; +use crate::dictionary::{ + merge_dictionary_values, should_merge_dictionary_values, ShouldMergeValues, +}; +use arrow_array::builder::{ + BooleanBuilder, GenericByteBuilder, PrimitiveBuilder, PrimitiveDictionaryBuilder, +}; use arrow_array::cast::AsArray; use arrow_array::types::*; use arrow_array::*; @@ -84,6 +88,7 @@ fn fixed_size_list_capacity(arrays: &[&dyn Array], data_type: &DataType) -> Capa } fn concat_dictionaries( + value_type: &DataType, arrays: &[&dyn Array], ) -> Result { let mut output_len = 0; @@ -93,11 +98,41 @@ fn concat_dictionaries( .inspect(|d| output_len += d.len()) .collect(); - if !should_merge_dictionary_values::(&dictionaries, output_len) { - return concat_fallback(arrays, Capacities::Array(output_len)); + let is_overflow = match should_merge_dictionary_values::(&dictionaries, output_len) { + ShouldMergeValues::ConcatWillOverflow => true, + ShouldMergeValues::Yes => false, + ShouldMergeValues::No => { + return concat_fallback(arrays, Capacities::Array(output_len)); + } + }; + + macro_rules! primitive_dict_helper { + ($t:ty) => { + merge_concat_primitive_dictionaries::(&dictionaries, output_len) + }; } - let merged = merge_dictionary_values(&dictionaries, None)?; + downcast_primitive! { + value_type => (primitive_dict_helper), + DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary => { + merge_concat_byte_dictionaries(&dictionaries, output_len) + }, + // merge not yet implemented for this type and it's not going to overflow, so fall back + // to concatenating values + _ if !is_overflow => concat_fallback(arrays, Capacities::Array(output_len)), + other => Err(ArrowError::NotYetImplemented(format!( + "concat of dictionaries would overflow key type {key_type:?} and \ + value type {other:?} not yet supported for merging", + key_type = K::DATA_TYPE, + ))) + } +} + +fn merge_concat_byte_dictionaries( + dictionaries: &[&DictionaryArray], + output_len: usize, +) -> Result { + let merged = merge_dictionary_values(dictionaries, None)?; // Recompute keys let mut key_values = Vec::with_capacity(output_len); @@ -113,7 +148,7 @@ fn concat_dictionaries( let nulls = has_nulls.then(|| { let mut nulls = BooleanBufferBuilder::new(output_len); - for d in &dictionaries { + for d in dictionaries { match d.nulls() { Some(n) => nulls.append_buffer(n.inner()), None => nulls.append_n(d.len(), true), @@ -130,6 +165,19 @@ fn concat_dictionaries( Ok(Arc::new(array)) } +fn merge_concat_primitive_dictionaries( + dictionaries: &[&DictionaryArray], + output_len: usize, +) -> Result { + let mut builder = PrimitiveDictionaryBuilder::::with_capacity(output_len, 0); + for dict in dictionaries { + for value in dict.downcast_dict::>().unwrap() { + builder.append_option(value); + } + } + Ok(Arc::new(builder.finish())) +} + fn concat_lists( arrays: &[&dyn Array], field: &FieldRef, @@ -231,8 +279,8 @@ fn concat_bytes(arrays: &[&dyn Array]) -> Result { - return Ok(Arc::new(concat_dictionaries::<$t>($arrays)?) as _) + ($t:ty, $value_type:expr, $arrays:expr) => { + concat_dictionaries::<$t>($value_type.as_ref(), $arrays) }; } @@ -300,9 +348,9 @@ pub fn concat(arrays: &[&dyn Array]) -> Result { downcast_primitive! { d => (primitive_concat, arrays), DataType::Boolean => concat_boolean(arrays), - DataType::Dictionary(k, _) => { + DataType::Dictionary(k, v) => { downcast_integer! { - k.as_ref() => (dict_helper, arrays), + k.as_ref() => (dict_helper, v, arrays), _ => unreachable!("illegal dictionary key type {k}") } } @@ -938,6 +986,69 @@ mod tests { assert!((30..40).contains(&values_len), "{values_len}") } + #[test] + fn test_concat_dictionary_overflows() { + // each array has length equal to the full dictionary key space + let len: usize = usize::try_from(i8::MAX).unwrap(); + + let a = DictionaryArray::::new( + Int8Array::from_value(0, len), + Arc::new(Int8Array::from_value(0, len)), + ); + let b = DictionaryArray::::new( + Int8Array::from_value(0, len), + Arc::new(Int8Array::from_value(1, len)), + ); + + // Case 1: with a single input array, should _never_ overflow + let values = concat(&[&a]).unwrap(); + let v = values.as_dictionary::(); + let vc = v.downcast_dict::().unwrap(); + let collected: Vec<_> = vc.into_iter().map(Option::unwrap).collect(); + assert_eq!(&collected, &vec![0; len]); + + // Case 2: two arrays + // Should still not overflow, there are only two values + let values = concat(&[&a, &b]).unwrap(); + let v = values.as_dictionary::(); + let vc = v.downcast_dict::().unwrap(); + let collected: Vec<_> = vc.into_iter().map(Option::unwrap).collect(); + assert_eq!( + &collected, + &vec![0; len] + .into_iter() + .chain(vec![1; len]) + .collect::>() + ); + } + + #[test] + fn test_unsupported_concat_dictionary_overflow() { + // each array has length equal to the full dictionary key space + let len: usize = usize::try_from(i8::MAX).unwrap(); + + let a = DictionaryArray::::new( + Int8Array::from_value(0, len), + Arc::new(NullArray::new(len)), + ); + let b = DictionaryArray::::new( + Int8Array::from_value(0, len), + Arc::new(NullArray::new(len)), + ); + + // Case 1: with a single input array, should _never_ overflow + concat(&[&a]).unwrap(); + + // Case 2: two arrays + // Will fail to merge values on unsupported datatype + let values = concat(&[&a, &b]).unwrap_err(); + assert_eq!( + values.to_string(), + "Not yet implemented: concat of dictionaries would overflow key type Int8 and \ + value type Null not yet supported for merging" + ); + } + #[test] fn test_concat_string_sizes() { let a: LargeStringArray = ((0..150).map(|_| Some("foo"))).collect(); diff --git a/arrow-select/src/dictionary.rs b/arrow-select/src/dictionary.rs index 57aed644fe0..bf1b0809c63 100644 --- a/arrow-select/src/dictionary.rs +++ b/arrow-select/src/dictionary.rs @@ -101,8 +101,19 @@ fn bytes_ptr_eq(a: &dyn Array, b: &dyn Array) -> bool { } } +/// Whether selection kernels should attempt to merge dictionary values +pub enum ShouldMergeValues { + /// Concatenation of the dictionary values will lead to overflowing + /// the key space; it's necessary to attempt to merge + ConcatWillOverflow, + /// The heuristic suggests that merging will be beneficial + Yes, + /// The heuristic suggests that merging is not necessary + No, +} + /// A type-erased function that compares two array for pointer equality -type PtrEq = dyn Fn(&dyn Array, &dyn Array) -> bool; +type PtrEq = fn(&dyn Array, &dyn Array) -> bool; /// A weak heuristic of whether to merge dictionary values that aims to only /// perform the expensive merge computation when it is likely to yield at least @@ -112,15 +123,15 @@ type PtrEq = dyn Fn(&dyn Array, &dyn Array) -> bool; pub fn should_merge_dictionary_values( dictionaries: &[&DictionaryArray], len: usize, -) -> bool { +) -> ShouldMergeValues { use DataType::*; let first_values = dictionaries[0].values().as_ref(); - let ptr_eq: Box = match first_values.data_type() { - Utf8 => Box::new(bytes_ptr_eq::), - LargeUtf8 => Box::new(bytes_ptr_eq::), - Binary => Box::new(bytes_ptr_eq::), - LargeBinary => Box::new(bytes_ptr_eq::), - _ => return false, + let ptr_eq: PtrEq = match first_values.data_type() { + Utf8 => bytes_ptr_eq::, + LargeUtf8 => bytes_ptr_eq::, + Binary => bytes_ptr_eq::, + LargeBinary => bytes_ptr_eq::, + _ => |_, _| false, }; let mut single_dictionary = true; @@ -136,7 +147,15 @@ pub fn should_merge_dictionary_values( let overflow = K::Native::from_usize(total_values).is_none(); let values_exceed_length = total_values >= len; - !single_dictionary && (overflow || values_exceed_length) + if single_dictionary { + ShouldMergeValues::No + } else if overflow { + ShouldMergeValues::ConcatWillOverflow + } else if values_exceed_length { + ShouldMergeValues::Yes + } else { + ShouldMergeValues::No + } } /// Given an array of dictionaries and an optional key mask compute a values array diff --git a/arrow-select/src/interleave.rs b/arrow-select/src/interleave.rs index b09de13fee6..4d1eaa0b15e 100644 --- a/arrow-select/src/interleave.rs +++ b/arrow-select/src/interleave.rs @@ -17,8 +17,12 @@ //! Interleave elements from multiple arrays -use crate::dictionary::{merge_dictionary_values, should_merge_dictionary_values}; -use arrow_array::builder::{BooleanBufferBuilder, BufferBuilder, PrimitiveBuilder}; +use crate::dictionary::{ + merge_dictionary_values, should_merge_dictionary_values, ShouldMergeValues, +}; +use arrow_array::builder::{ + BooleanBufferBuilder, BufferBuilder, PrimitiveBuilder, PrimitiveDictionaryBuilder, +}; use arrow_array::cast::AsArray; use arrow_array::types::*; use arrow_array::*; @@ -36,8 +40,8 @@ macro_rules! primitive_helper { } macro_rules! dict_helper { - ($t:ty, $values:expr, $indices:expr) => { - Ok(Arc::new(interleave_dictionaries::<$t>($values, $indices)?) as _) + ($t:ty, $value_type:expr, $values:expr, $indices:expr) => { + interleave_dictionaries::<$t>($value_type.as_ref(), $values, $indices) }; } @@ -101,8 +105,8 @@ pub fn interleave( DataType::LargeBinary => interleave_bytes::(values, indices), DataType::BinaryView => interleave_views::(values, indices), DataType::Utf8View => interleave_views::(values, indices), - DataType::Dictionary(k, _) => downcast_integer! { - k.as_ref() => (dict_helper, values, indices), + DataType::Dictionary(k, v) => downcast_integer! { + k.as_ref() => (dict_helper, v, values, indices), _ => unreachable!("illegal dictionary key type {k}") }, _ => interleave_fallback(values, indices) @@ -191,14 +195,45 @@ fn interleave_bytes( } fn interleave_dictionaries( + value_type: &DataType, arrays: &[&dyn Array], indices: &[(usize, usize)], ) -> Result { let dictionaries: Vec<_> = arrays.iter().map(|x| x.as_dictionary::()).collect(); - if !should_merge_dictionary_values::(&dictionaries, indices.len()) { - return interleave_fallback(arrays, indices); + let is_overflow = match should_merge_dictionary_values::(&dictionaries, indices.len()) { + ShouldMergeValues::ConcatWillOverflow => true, + ShouldMergeValues::Yes => false, + ShouldMergeValues::No => { + return interleave_fallback(arrays, indices); + } + }; + + macro_rules! primitive_dict_helper { + ($t:ty) => { + merge_interleave_primitive_dictionaries::(&dictionaries, indices) + }; } + downcast_primitive! { + value_type => (primitive_dict_helper), + DataType::Utf8 | DataType::LargeUtf8 | DataType::Binary | DataType::LargeBinary => { + merge_interleave_byte_dictionaries(&dictionaries, indices) + }, + // merge not yet implemented for this type and it's not going to overflow, so fall back + // to concatenating values + _ if !is_overflow => interleave_fallback(arrays, indices), + other => Err(ArrowError::NotYetImplemented(format!( + "interleave of dictionaries would overflow key type {key_type:?} and \ + value type {other:?} not yet supported for merging", + key_type = K::DATA_TYPE, + ))) + } +} + +fn merge_interleave_byte_dictionaries( + dictionaries: &[&DictionaryArray], + indices: &[(usize, usize)], +) -> Result { let masks: Vec<_> = dictionaries .iter() .enumerate() @@ -215,7 +250,7 @@ fn interleave_dictionaries( }) .collect(); - let merged = merge_dictionary_values(&dictionaries, Some(&masks))?; + let merged = merge_dictionary_values(dictionaries, Some(&masks))?; // Recompute keys let mut keys = PrimitiveBuilder::::with_capacity(indices.len()); @@ -233,6 +268,26 @@ fn interleave_dictionaries( Ok(Arc::new(array)) } +fn merge_interleave_primitive_dictionaries( + dictionaries: &[&DictionaryArray], + indices: &[(usize, usize)], +) -> Result { + let dict_accessors: Vec<_> = dictionaries + .iter() + .map(|d| d.downcast_dict::>().unwrap()) + .collect(); + let mut builder = PrimitiveDictionaryBuilder::::with_capacity(indices.len(), 0); + for (a, b) in indices { + let dict = dict_accessors[*a]; + if dict.is_valid(*b) { + builder.append_value(dict.value(*b)); + } else { + builder.append_null(); + } + } + Ok(Arc::new(builder.finish())) +} + fn interleave_views( values: &[&dyn Array], indices: &[(usize, usize)], @@ -463,6 +518,63 @@ mod tests { assert_eq!(actual, expected); } + #[test] + fn test_interleave_dictionary_overflows() { + // each array has length equal to the full dictionary key space + let len: usize = usize::try_from(i8::MAX).unwrap(); + + let a = DictionaryArray::::new( + Int8Array::from_value(0, len), + Arc::new(Int8Array::from_value(0, len)), + ); + let b = DictionaryArray::::new( + Int8Array::from_value(0, len), + Arc::new(Int8Array::from_value(1, len)), + ); + + // Case 1: with a single input array, should _never_ overflow + let values = interleave(&[&a], &[(0, 2), (0, 2)]).unwrap(); + let v = values.as_dictionary::(); + let vc = v.downcast_dict::().unwrap(); + let collected: Vec<_> = vc.into_iter().map(Option::unwrap).collect(); + assert_eq!(&collected, &[0, 0]); + + // Case 2: two arrays + // Should still not overflow, there are only two values + let values = interleave(&[&a, &b], &[(0, 2), (0, 2), (1, 1)]).unwrap(); + let v = values.as_dictionary::(); + let vc = v.downcast_dict::().unwrap(); + let collected: Vec<_> = vc.into_iter().map(Option::unwrap).collect(); + assert_eq!(&collected, &[0, 0, 1]); + } + + #[test] + fn test_unsupported_interleave_dictionary_overflow() { + // each array has length equal to the full dictionary key space + let len: usize = usize::try_from(i8::MAX).unwrap(); + + let a = DictionaryArray::::new( + Int8Array::from_value(0, len), + Arc::new(NullArray::new(len)), + ); + let b = DictionaryArray::::new( + Int8Array::from_value(0, len), + Arc::new(NullArray::new(len)), + ); + + // Case 1: with a single input array, should _never_ overflow + interleave(&[&a], &[(0, 2), (0, 2)]).unwrap(); + + // Case 2: two arrays + // Will fail to merge values on unsupported datatype + let values = interleave(&[&a, &b], &[(0, 2), (0, 2), (1, 1)]).unwrap_err(); + assert_eq!( + values.to_string(), + "Not yet implemented: qinterleave of dictionaries would overflow key type Int8 and \ + value type Null not yet supported for merging" + ); + } + #[test] fn test_lists() { // [[1, 2], null, [3]]