diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs index df03b85ff186..4feb8c27b938 100644 --- a/arrow-select/src/take.rs +++ b/arrow-select/src/take.rs @@ -24,9 +24,10 @@ use arrow_array::cast::AsArray; use arrow_array::types::*; use arrow_array::*; use arrow_buffer::{ - bit_util, ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, ScalarBuffer, + bit_util, ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, OffsetBuffer, + ScalarBuffer, }; -use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_data::ArrayDataBuilder; use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode}; use num::{One, Zero}; @@ -465,67 +466,79 @@ fn take_bytes( array: &GenericByteArray, indices: &PrimitiveArray, ) -> Result, ArrowError> { - let data_len = indices.len(); - - let bytes_offset = (data_len + 1) * std::mem::size_of::(); - let mut offsets = MutableBuffer::new(bytes_offset); + let mut offsets = Vec::with_capacity(indices.len() + 1); offsets.push(T::Offset::default()); - let mut values = MutableBuffer::new(0); + let input_offsets = array.value_offsets(); + let mut capacity = 0; + let nulls = take_nulls(array.nulls(), indices); - let nulls; - if array.null_count() == 0 && indices.null_count() == 0 { + let (offsets, values) = if array.null_count() == 0 && indices.null_count() == 0 { offsets.extend(indices.values().iter().map(|index| { - let s: &[u8] = array.value(index.as_usize()).as_ref(); - values.extend_from_slice(s); - T::Offset::usize_as(values.len()) + let index = index.as_usize(); + capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize(); + T::Offset::from_usize(capacity).expect("overflow") })); - nulls = None - } else if indices.null_count() == 0 { - let num_bytes = bit_util::ceil(data_len, 8); + let mut values = Vec::with_capacity(capacity); - let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true); - let null_slice = null_buf.as_slice_mut(); - offsets.extend(indices.values().iter().enumerate().map(|(i, index)| { + for index in indices.values() { + values.extend_from_slice(array.value(index.as_usize()).as_ref()); + } + (offsets, values) + } else if indices.null_count() == 0 { + offsets.extend(indices.values().iter().map(|index| { let index = index.as_usize(); if array.is_valid(index) { - let s: &[u8] = array.value(index).as_ref(); - values.extend_from_slice(s.as_ref()); - } else { - bit_util::unset_bit(null_slice, i); + capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize(); } - T::Offset::usize_as(values.len()) + T::Offset::from_usize(capacity).expect("overflow") })); - nulls = Some(null_buf.into()); + let mut values = Vec::with_capacity(capacity); + + for index in indices.values() { + let index = index.as_usize(); + if array.is_valid(index) { + values.extend_from_slice(array.value(index).as_ref()); + } + } + (offsets, values) } else if array.null_count() == 0 { offsets.extend(indices.values().iter().enumerate().map(|(i, index)| { + let index = index.as_usize(); if indices.is_valid(i) { - let s: &[u8] = array.value(index.as_usize()).as_ref(); - values.extend_from_slice(s); + capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize(); } - T::Offset::usize_as(values.len()) + T::Offset::from_usize(capacity).expect("overflow") })); - nulls = indices.nulls().map(|b| b.inner().sliced()); - } else { - let num_bytes = bit_util::ceil(data_len, 8); + let mut values = Vec::with_capacity(capacity); - let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true); - let null_slice = null_buf.as_slice_mut(); + for (i, index) in indices.values().iter().enumerate() { + if indices.is_valid(i) { + values.extend_from_slice(array.value(index.as_usize()).as_ref()); + } + } + (offsets, values) + } else { + let nulls = nulls.as_ref().unwrap(); offsets.extend(indices.values().iter().enumerate().map(|(i, index)| { + let index = index.as_usize(); + if nulls.is_valid(i) { + capacity += input_offsets[index + 1].as_usize() - input_offsets[index].as_usize(); + } + T::Offset::from_usize(capacity).expect("overflow") + })); + let mut values = Vec::with_capacity(capacity); + + for (i, index) in indices.values().iter().enumerate() { // check index is valid before using index. The value in // NULL index slots may not be within bounds of array let index = index.as_usize(); - if indices.is_valid(i) && array.is_valid(index) { - let s: &[u8] = array.value(index).as_ref(); - values.extend_from_slice(s); - } else { - // set null bit - bit_util::unset_bit(null_slice, i); + if nulls.is_valid(i) { + values.extend_from_slice(array.value(index).as_ref()); } - T::Offset::usize_as(values.len()) - })); - nulls = Some(null_buf.into()) - } + } + (offsets, values) + }; T::Offset::from_usize(values.len()).ok_or(ArrowError::ComputeError(format!( "Offset overflow for {}BinaryArray: {}", @@ -533,15 +546,12 @@ fn take_bytes( values.len() )))?; - let array_data = ArrayData::builder(T::DATA_TYPE) - .len(data_len) - .add_buffer(offsets.into()) - .add_buffer(values.into()) - .null_bit_buffer(nulls); - - let array_data = unsafe { array_data.build_unchecked() }; + let array = unsafe { + let offsets = OffsetBuffer::new_unchecked(offsets.into()); + GenericByteArray::::new_unchecked(offsets, values.into(), nulls) + }; - Ok(GenericByteArray::from(array_data)) + Ok(array) } /// `take` implementation for byte view arrays @@ -949,6 +959,7 @@ mod tests { use super::*; use arrow_array::builder::*; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; + use arrow_data::ArrayData; use arrow_schema::{Field, Fields, TimeUnit, UnionFields}; fn test_take_decimal_arrays(