diff --git a/arrow-array/src/builder/generic_bytes_builder.rs b/arrow-array/src/builder/generic_bytes_builder.rs index e2be96615b61..e07031456590 100644 --- a/arrow-array/src/builder/generic_bytes_builder.rs +++ b/arrow-array/src/builder/generic_bytes_builder.rs @@ -17,13 +17,17 @@ use crate::builder::{ArrayBuilder, BufferBuilder, UInt8BufferBuilder}; use crate::types::{ByteArrayType, GenericBinaryType, GenericStringType}; -use crate::{ArrayRef, GenericByteArray, OffsetSizeTrait}; +use crate::{ + Array, ArrayRef, ArrowPrimitiveType, GenericByteArray, OffsetSizeTrait, PrimitiveArray, +}; use arrow_buffer::NullBufferBuilder; use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer}; use arrow_data::ArrayDataBuilder; use std::any::Any; use std::sync::Arc; +use super::take_in_utils; + /// Builder for [`GenericByteArray`] /// /// For building strings, see docs on [`GenericStringBuilder`]. @@ -129,6 +133,56 @@ impl GenericByteBuilder { self.offsets_builder.append(self.next_offset()); } + /// Take values at indices from array into the builder. 
+    ///
+    /// A null is appended for every null index and for every index that selects a null
+    /// slot of `array`; every other index appends a copy of the selected value.
+    ///
+    /// Every non-null index must be in bounds for `array`.
+    pub fn take_in<I>(&mut self, array: &GenericByteArray<T>, indices: &PrimitiveArray<I>)
+    where
+        I: ArrowPrimitiveType,
+    {
+        // The validity of all selected slots is appended in bulk here, so the loop below
+        // must only write to the value and offset buffers: going through `append_value` /
+        // `append_null` as well would record the validity of every slot a second time.
+        take_in_utils::take_in_nulls(&mut self.null_buffer_builder, array.nulls(), indices);
+
+        // A null buffer without any unset bit is treated like an absent one.
+        let array_nulls = array.nulls().filter(|nulls| nulls.null_count() > 0);
+        let indices_nulls = indices.nulls().filter(|nulls| nulls.null_count() > 0);
+
+        for (position, index) in indices.values().iter().enumerate() {
+            // The index validity is tested first: the value stored behind a null index
+            // is arbitrary and must not be used to address `array`.
+            let is_valid = indices_nulls.map_or(true, |nulls| nulls.is_valid(position))
+                && array_nulls.map_or(true, |nulls| nulls.is_valid(index.as_usize()));
+            if is_valid {
+                let bytes: &[u8] = array.value(index.as_usize()).as_ref();
+                self.value_builder.append_slice(bytes);
+            }
+            // A null slot appends no bytes, so it repeats the previous offset.
+            self.offsets_builder.append(self.next_offset());
+        }
+    }
+
     /// Builds the [`GenericByteArray`] and reset this builder.
pub fn finish(&mut self) -> GenericByteArray { let array_type = T::DATA_TYPE; diff --git a/arrow-array/src/builder/generic_bytes_view_builder.rs b/arrow-array/src/builder/generic_bytes_view_builder.rs index 7268e751b149..e2add136ffc8 100644 --- a/arrow-array/src/builder/generic_bytes_view_builder.rs +++ b/arrow-array/src/builder/generic_bytes_view_builder.rs @@ -19,7 +19,7 @@ use std::any::Any; use std::marker::PhantomData; use std::sync::Arc; -use arrow_buffer::{Buffer, BufferBuilder, NullBufferBuilder, ScalarBuffer}; +use arrow_buffer::{ArrowNativeType, Buffer, BufferBuilder, NullBufferBuilder, ScalarBuffer}; use arrow_data::ByteView; use arrow_schema::ArrowError; use hashbrown::hash_table::Entry; @@ -28,7 +28,9 @@ use hashbrown::HashTable; use crate::builder::ArrayBuilder; use crate::types::bytes::ByteArrayNativeType; use crate::types::{BinaryViewType, ByteViewType, StringViewType}; -use crate::{ArrayRef, GenericByteViewArray}; +use crate::{Array, ArrayRef, ArrowPrimitiveType, GenericByteViewArray, PrimitiveArray}; + +use super::take_in_utils; const STARTING_BLOCK_SIZE: u32 = 8 * 1024; // 8KiB const MAX_BLOCK_SIZE: u32 = 2 * 1024 * 1024; // 2MiB @@ -361,6 +363,76 @@ impl GenericByteViewBuilder { self.views_builder.append(0); } + /// Take values at indices from array into the builder. 
+    ///
+    /// A null is appended for every null index and for every index that selects a null
+    /// slot of `array`.
+    ///
+    /// Depending on how many bytes the selected values reference, the values are either
+    /// copied into this builder's own buffers, or the views are copied as they are and
+    /// all data buffers of `array` are appended to this builder.
+    ///
+    /// Every non-null index must be in bounds for `array`.
+    pub fn take_in<I>(&mut self, array: &GenericByteViewArray<T>, indices: &PrimitiveArray<I>)
+    where
+        I: ArrowPrimitiveType,
+    {
+        // TODO:
+        // - Maybe: Decide on GC based on indices to array length ratio
+        // - Maybe: GC only fully-empty buffers
+        // - Maybe: Use take_in_nulls / take_in_native in GC path
+
+        // Bytes the selected values keep in data buffers. Values of up to 12 bytes are
+        // skipped: their views never get `buffer_index` rewritten below, i.e. they do
+        // not reference a data buffer. Summed as `usize` because repeated indices can
+        // select more than `u32::MAX` bytes in total.
+        let capacity_needed: usize = indices
+            .iter()
+            .flatten()
+            .map(|index| index.as_usize())
+            .filter(|&index| array.is_valid(index))
+            .map(|index| ByteView::from(array.views()[index]).length as usize)
+            .filter(|&len| len > 12)
+            .sum();
+
+        if array.get_buffer_memory_size() > capacity_needed.saturating_mul(2) {
+            // The selection uses less than half of the source buffers: copy the values
+            // instead of keeping all of those buffers alive.
+            for index in indices {
+                match index {
+                    Some(index) if array.is_valid(index.as_usize()) => {
+                        self.append_value(array.value(index.as_usize()))
+                    }
+                    _ => self.append_null(),
+                }
+            }
+        } else {
+            self.flush_in_progress();
+
+            let start = self.len();
+            // Position at which the first data buffer of `array` is appended below.
+            let buffer_offset = u32::try_from(self.completed.len())
+                .expect("number of data buffers exceeds u32::MAX");
+
+            take_in_utils::take_in_nulls(&mut self.null_buffer_builder, array.nulls(), indices);
+            take_in_utils::take_in_native(&mut self.views_builder, array.views(), indices);
+
+            self.completed.extend(array.data_buffers().iter().cloned());
+
+            // The copied views still index into the buffer list of `array`; shift them
+            // to where those buffers now live in `self.completed`.
+            for raw_view in &mut self.views_builder.as_slice_mut()[start..] {
+                let mut view = ByteView::from(*raw_view);
+                if view.length > 12 {
+                    view.buffer_index = view
+                        .buffer_index
+                        .checked_add(buffer_offset)
+                        .expect("data buffer index exceeds u32::MAX");
+                    *raw_view = view.as_u128();
+                }
+            }
+        }
+    }
+
     /// Builds the [`GenericByteViewArray`] and reset this builder
     pub fn finish(&mut self) -> GenericByteViewArray<T> {
         self.flush_in_progress();
diff --git a/arrow-array/src/builder/mod.rs b/arrow-array/src/builder/mod.rs
index e859f3794ad4..31a383b475cb 100644
--- a/arrow-array/src/builder/mod.rs
+++ b/arrow-array/src/builder/mod.rs
@@ -232,10 +232,15 @@
 //!
assert_eq!(buffer.iter().collect::<Vec<_>>(), vec![true, true, true, true, true, true, true, false]);
 //! ```
 
+mod take_in_utils;
+
 pub use arrow_buffer::BooleanBufferBuilder;
 pub use arrow_buffer::NullBufferBuilder;
 
 mod boolean_builder;
+use arrow_schema::DataType;
+use arrow_schema::IntervalUnit;
+use arrow_schema::TimeUnit;
 pub use boolean_builder::*;
 mod buffer_builder;
 pub use buffer_builder::*;
@@ -271,6 +276,8 @@ mod union_builder;
 
 pub use union_builder::*;
 
+use crate::types::BinaryViewType;
+use crate::types::StringViewType;
 use crate::ArrayRef;
 use std::any::Any;
 
@@ -416,3 +423,124 @@ pub type StringBuilder = GenericStringBuilder<i32>;
 ///
 /// See examples on [`GenericStringBuilder`]
 pub type LargeStringBuilder = GenericStringBuilder<i64>;
+
+/// Creates an empty [`ArrayBuilder`] for `data_type` with room for `capacity` elements.
+///
+/// For `Binary`, `LargeBinary`, `Utf8` and `LargeUtf8`, `capacity` is used both as the
+/// item capacity and as the value (byte) capacity of the builder.
+///
+/// # Panics
+///
+/// Panics if `data_type` is a nested, dictionary or run-end encoded type (not supported
+/// yet), if it pairs a time type with a unit that has no arm below, or if a decimal
+/// precision / scale is rejected by the decimal builder.
+pub fn new_empty_builder(data_type: &DataType, capacity: usize) -> Box<dyn ArrayBuilder> {
+    match data_type {
+        DataType::Null => Box::new(NullBuilder::new()),
+        DataType::Boolean => Box::new(BooleanBuilder::with_capacity(capacity)),
+        DataType::Int8 => Box::new(Int8Builder::with_capacity(capacity)),
+        DataType::Int16 => Box::new(Int16Builder::with_capacity(capacity)),
+        DataType::Int32 => Box::new(Int32Builder::with_capacity(capacity)),
+        DataType::Int64 => Box::new(Int64Builder::with_capacity(capacity)),
+        DataType::UInt8 => Box::new(UInt8Builder::with_capacity(capacity)),
+        DataType::UInt16 => Box::new(UInt16Builder::with_capacity(capacity)),
+        DataType::UInt32 => Box::new(UInt32Builder::with_capacity(capacity)),
+        DataType::UInt64 => Box::new(UInt64Builder::with_capacity(capacity)),
+        DataType::Float16 => Box::new(Float16Builder::with_capacity(capacity)),
+        DataType::Float32 => Box::new(Float32Builder::with_capacity(capacity)),
+        DataType::Float64 => Box::new(Float64Builder::with_capacity(capacity)),
+        // Pass the requested data type on so that the finished array keeps its time
+        // zone; a builder created from `capacity` alone has no way of knowing it.
+        DataType::Timestamp(TimeUnit::Second, _) => Box::new(
+            TimestampSecondBuilder::with_capacity(capacity).with_data_type(data_type.clone()),
+        ),
+        DataType::Timestamp(TimeUnit::Millisecond, _) => Box::new(
+            TimestampMillisecondBuilder::with_capacity(capacity).with_data_type(data_type.clone()),
+        ),
+        DataType::Timestamp(TimeUnit::Microsecond, _) => Box::new(
+            TimestampMicrosecondBuilder::with_capacity(capacity).with_data_type(data_type.clone()),
+        ),
+        DataType::Timestamp(TimeUnit::Nanosecond, _) => Box::new(
+            TimestampNanosecondBuilder::with_capacity(capacity).with_data_type(data_type.clone()),
+        ),
+        DataType::Date32 => Box::new(Date32Builder::with_capacity(capacity)),
+        DataType::Date64 => Box::new(Date64Builder::with_capacity(capacity)),
+        DataType::Time32(TimeUnit::Second) => {
+            Box::new(Time32SecondBuilder::with_capacity(capacity))
+        }
+        DataType::Time32(TimeUnit::Millisecond) => {
+            Box::new(Time32MillisecondBuilder::with_capacity(capacity))
+        }
+        DataType::Time64(TimeUnit::Microsecond) => {
+            Box::new(Time64MicrosecondBuilder::with_capacity(capacity))
+        }
+        DataType::Time64(TimeUnit::Nanosecond) => {
+            Box::new(Time64NanosecondBuilder::with_capacity(capacity))
+        }
+        DataType::Duration(TimeUnit::Second) => {
+            Box::new(DurationSecondBuilder::with_capacity(capacity))
+        }
+        DataType::Duration(TimeUnit::Millisecond) => {
+            Box::new(DurationMillisecondBuilder::with_capacity(capacity))
+        }
+        DataType::Duration(TimeUnit::Microsecond) => {
+            Box::new(DurationMicrosecondBuilder::with_capacity(capacity))
+        }
+        DataType::Duration(TimeUnit::Nanosecond) => {
+            Box::new(DurationNanosecondBuilder::with_capacity(capacity))
+        }
+        DataType::Interval(IntervalUnit::YearMonth) => {
+            Box::new(IntervalYearMonthBuilder::with_capacity(capacity))
+        }
+        DataType::Interval(IntervalUnit::DayTime) => {
+            Box::new(IntervalDayTimeBuilder::with_capacity(capacity))
+        }
+        DataType::Interval(IntervalUnit::MonthDayNano) => {
+            Box::new(IntervalMonthDayNanoBuilder::with_capacity(capacity))
+        }
+        DataType::Binary => Box::new(GenericBinaryBuilder::<i32>::with_capacity(
+            capacity, capacity,
+        )),
+        DataType::FixedSizeBinary(byte_width) => {
+            Box::new(FixedSizeBinaryBuilder::with_capacity(capacity, *byte_width))
+        }
+        DataType::LargeBinary => Box::new(GenericBinaryBuilder::<i64>::with_capacity(
+            capacity, capacity,
+        )),
+        DataType::BinaryView => Box::new(
+            GenericByteViewBuilder::<BinaryViewType>::with_capacity(capacity),
+        ),
+        DataType::Utf8 => Box::new(GenericStringBuilder::<i32>::with_capacity(
+            capacity, capacity,
+        )),
+        DataType::LargeUtf8 => Box::new(GenericStringBuilder::<i64>::with_capacity(
+            capacity, capacity,
+        )),
+        DataType::Utf8View => Box::new(
+            GenericByteViewBuilder::<StringViewType>::with_capacity(capacity),
+        ),
+        DataType::Decimal128(precision, scale) => Box::new(
+            Decimal128Builder::with_capacity(capacity)
+                .with_precision_and_scale(*precision, *scale)
+                .expect("Invalid precision / scale for Decimal128"),
+        ),
+        DataType::Decimal256(precision, scale) => Box::new(
+            Decimal256Builder::with_capacity(capacity)
+                .with_precision_and_scale(*precision, *scale)
+                .expect("Invalid precision / scale for Decimal256"),
+        ),
+        DataType::List(_)
+        | DataType::ListView(_)
+        | DataType::FixedSizeList(_, _)
+        | DataType::LargeList(_)
+        | DataType::LargeListView(_)
+        | DataType::Struct(_)
+        | DataType::Union(_, _)
+        | DataType::Dictionary(_, _)
+        | DataType::Map(_, _)
+        | DataType::RunEndEncoded(_, _) => {
+            unimplemented!("new_empty_builder does not support {data_type:?} yet")
+        }
+        dt => panic!("Unexpected data type {dt:?}"),
+    }
+}
diff --git a/arrow-array/src/builder/primitive_builder.rs b/arrow-array/src/builder/primitive_builder.rs
index 3191fea6e407..3189957ea29a 100644
--- a/arrow-array/src/builder/primitive_builder.rs
+++ b/arrow-array/src/builder/primitive_builder.rs
@@ -16,7 +16,7 @@
 // under the License.
 
 use crate::builder::{ArrayBuilder, BufferBuilder};
-use crate::types::*;
+use crate::{types::*, Array};
 use crate::{ArrayRef, PrimitiveArray};
 use arrow_buffer::NullBufferBuilder;
 use arrow_buffer::{Buffer, MutableBuffer};
@@ -25,6 +25,8 @@ use arrow_schema::{ArrowError, DataType};
 use std::any::Any;
 use std::sync::Arc;
 
+use super::take_in_utils;
+
 /// A signed 8-bit integer array builder.
 pub type Int8Builder = PrimitiveBuilder<Int8Type>;
 /// A signed 16-bit integer array builder.
@@ -272,6 +274,21 @@ impl<T: ArrowPrimitiveType> PrimitiveBuilder<T> {
         self.values_builder.append_trusted_len_iter(iter);
     }
 
+    /// Appends the values of `array` selected by `indices` to this builder.
+    ///
+    /// A null index, or an index that selects a null slot of `array`, appends a null.
+    ///
+    /// Every non-null index must be in bounds for `array`. No bounds check is performed
+    /// here: an out-of-bounds index either panics or appends an unspecified value.
+    #[inline]
+    pub fn take_in<I>(&mut self, array: &PrimitiveArray<T>, indices: &PrimitiveArray<I>)
+    where
+        I: ArrowPrimitiveType,
+    {
+        take_in_utils::take_in_nulls(&mut self.null_buffer_builder, array.nulls(), indices);
+        take_in_utils::take_in_native(&mut self.values_builder, array.values(), indices);
+    }
+
     /// Builds the [`PrimitiveArray`] and reset this builder.
     pub fn finish(&mut self) -> PrimitiveArray<T> {
         let len = self.len();
diff --git a/arrow-array/src/builder/take_in_utils.rs b/arrow-array/src/builder/take_in_utils.rs
new file mode 100644
index 000000000000..5c438f6dec5a
--- /dev/null
+++ b/arrow-array/src/builder/take_in_utils.rs
@@ -0,0 +1,85 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Helpers shared by the `take_in` methods of the array builders.
+
+use arrow_buffer::{ArrowNativeType, BufferBuilder, NullBuffer, NullBufferBuilder, ScalarBuffer};
+
+use crate::{Array, ArrowPrimitiveType, PrimitiveArray};
+
+/// Appends the validity of the `indices.len()` taken slots to `null_buffer_builder`.
+///
+/// Slot `i` is valid iff `indices[i]` is non-null and, when `array_nulls` is present, the
+/// slot it selects is valid in `array_nulls`.
+pub(crate) fn take_in_nulls<I>(
+    null_buffer_builder: &mut NullBufferBuilder,
+    array_nulls: Option<&NullBuffer>,
+    indices: &PrimitiveArray<I>,
+) where
+    I: ArrowPrimitiveType,
+{
+    // A null buffer without any unset bit carries no information: treat it as absent.
+    let array_nulls = array_nulls.filter(|nulls| nulls.null_count() > 0);
+    let indices_nulls = indices.nulls().filter(|nulls| nulls.null_count() > 0);
+
+    match (array_nulls, indices_nulls) {
+        (None, None) => null_buffer_builder.append_n_non_nulls(indices.len()),
+        (None, Some(indices_nulls)) => null_buffer_builder.append_buffer(indices_nulls),
+        (Some(array_nulls), None) => {
+            let validity = indices
+                .values()
+                .iter()
+                .map(|index| array_nulls.is_valid(index.as_usize()));
+            null_buffer_builder.append_iter(validity);
+        }
+        (Some(array_nulls), Some(_)) => {
+            // `indices.iter()` yields `None` for null indices, so the value stored
+            // behind a null index is never used to address `array_nulls`.
+            let validity = indices.iter().map(|index| {
+                index.map_or(false, |index| array_nulls.is_valid(index.as_usize()))
+            });
+            null_buffer_builder.append_iter(validity);
+        }
+    }
+}
+
+/// Appends `array[index]` to `values_builder` for every entry of `indices`.
+///
+/// When `indices` contains nulls, an index that is out of bounds for `array` appends
+/// `T::default()` instead of panicking: the values behind null indices are arbitrary, and
+/// the slots they produce are expected to be masked by the caller's null buffer. Without
+/// nulls, an out-of-bounds index panics.
+pub(crate) fn take_in_native<T, I>(
+    values_builder: &mut BufferBuilder<T>,
+    array: &ScalarBuffer<T>,
+    indices: &PrimitiveArray<I>,
+) where
+    T: ArrowNativeType,
+    I: ArrowPrimitiveType,
+{
+    values_builder.reserve(indices.len());
+    if indices.null_count() > 0 {
+        values_builder.extend(
+            indices
+                .values()
+                .iter()
+                .map(|index| array.get(index.as_usize()).copied().unwrap_or_default()),
+        )
+    } else {
+        values_builder.extend(indices.values().iter().map(|index| array[index.as_usize()]));
+    }
+}
diff --git a/arrow-array/src/iterator.rs b/arrow-array/src/iterator.rs
index 6708da3d5dd6..eae0a4bf9033 100644
---
a/arrow-array/src/iterator.rs
+++ b/arrow-array/src/iterator.rs
@@ -44,7 +44,7 @@ use arrow_buffer::NullBuffer;
 /// [`PrimitiveArray`]: crate::PrimitiveArray
 /// [`compute::unary`]: https://docs.rs/arrow/latest/arrow/compute/fn.unary.html
 /// [`compute::try_unary`]: https://docs.rs/arrow/latest/arrow/compute/fn.try_unary.html
-#[derive(Debug)]
+#[derive(Debug, Clone)]
 pub struct ArrayIter<T: ArrayAccessor> {
     array: T,
     logical_nulls: Option<NullBuffer>,
diff --git a/arrow-array/src/lib.rs b/arrow-array/src/lib.rs
index 0fc9d30ab6e3..8661db7c741d 100644
--- a/arrow-array/src/lib.rs
+++ b/arrow-array/src/lib.rs
@@ -232,6 +232,9 @@ pub use record_batch::{
     RecordBatch, RecordBatchIterator, RecordBatchOptions, RecordBatchReader, RecordBatchWriter,
 };
 
+mod record_batch_builder;
+pub use record_batch_builder::RecordBatchBuilder;
+
 mod arithmetic;
 pub use arithmetic::ArrowNativeTypeOp;
 
diff --git a/arrow-array/src/record_batch_builder.rs b/arrow-array/src/record_batch_builder.rs
new file mode 100644
index 000000000000..7a4ba723eeb0
--- /dev/null
+++ b/arrow-array/src/record_batch_builder.rs
@@ -0,0 +1,86 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines [`RecordBatchBuilder`], a builder for whole [`RecordBatch`]es.
+
+use arrow_schema::{ArrowError, SchemaRef};
+
+use crate::{
+    builder::{new_empty_builder, ArrayBuilder},
+    RecordBatch,
+};
+
+/// Builder for an entire record batch with a schema.
+///
+/// Holds one [`ArrayBuilder`] per field of the schema, in field order.
+pub struct RecordBatchBuilder {
+    schema: SchemaRef,
+    builders: Vec<Box<dyn ArrayBuilder>>,
+}
+
+impl RecordBatchBuilder {
+    /// Creates a `RecordBatchBuilder` with the given schema.
+    pub fn new(schema: SchemaRef) -> Self {
+        Self::with_capacity(schema, 0)
+    }
+
+    /// Creates a `RecordBatchBuilder` with the given schema and initial capacity.
+    ///
+    /// Every column builder is created through [`new_empty_builder`] with `capacity`.
+    pub fn with_capacity(schema: SchemaRef, capacity: usize) -> Self {
+        let builders = schema
+            .fields()
+            .iter()
+            .map(|field| new_empty_builder(field.data_type(), capacity))
+            .collect();
+
+        Self { schema, builders }
+    }
+
+    /// Returns the column builders, in schema order, for mutation.
+    pub fn builders_mut(&mut self) -> &mut [Box<dyn ArrayBuilder>] {
+        &mut self.builders
+    }
+
+    /// Returns the number of rows in this `RecordBatchBuilder`.
+    ///
+    /// This is the length of the first column builder, or `0` when there are no columns.
+    pub fn len(&self) -> usize {
+        self.builders.first().map_or(0, |builder| builder.len())
+    }
+
+    /// Returns `true` if this `RecordBatchBuilder` holds no rows.
+    pub fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+
+    /// Produces a new [`RecordBatch`] from the data in this builder.
+    ///
+    /// Every column builder is finished, whether or not assembling the batch succeeds.
+    ///
+    /// # Errors
+    ///
+    /// Returns the error reported by [`RecordBatch::try_new`] for the finished columns.
+    pub fn finish(&mut self) -> Result<RecordBatch, ArrowError> {
+        let columns = self
+            .builders
+            .iter_mut()
+            .map(|builder| builder.finish())
+            .collect();
+        RecordBatch::try_new(self.schema.clone(), columns)
+    }
+}
diff --git a/arrow-buffer/src/builder/boolean.rs b/arrow-buffer/src/builder/boolean.rs
index 83d64ab8d8b3..e89524113340 100644
--- a/arrow-buffer/src/builder/boolean.rs
+++ b/arrow-buffer/src/builder/boolean.rs
@@ -197,12 +197,26 @@ impl BooleanBufferBuilder {
     /// Appends a slice of booleans into the buffer
     #[inline]
     pub fn append_slice(&mut self, slice: &[bool]) {
-        let additional = slice.len();
+        self.append_iter(slice.iter().copied())
+    }
+
+    /// Appends the booleans yielded by `iter` into the buffer.
+    ///
+    /// The buffer grows by the length `iter` reports up front; items yielded beyond
+    /// that length are ignored.
+    #[inline]
+    pub fn append_iter(&mut self, iter: impl ExactSizeIterator<Item = bool>) {
+        let additional = iter.len();
         self.advance(additional);
 
         let offset = self.len() - additional;
-        for (i, v) in slice.iter().enumerate() {
-            if *v {
+        // `ExactSizeIterator` is a safe trait, so the reported length cannot be relied
+        // on for memory safety: `take` keeps the raw writes below within the
+        // `additional` bits that `advance` just added.
+        for (i, v) in iter.take(additional).enumerate() {
+            if v {
+                // SAFETY: `i < additional`, hence `offset + i < self.len()`, and the
+                // buffer was grown by `advance` to hold `self.len()` bits.
                 unsafe { bit_util::set_bit_raw(self.buffer.as_mut_ptr(), offset + i) }
             }
         }
diff --git a/arrow-buffer/src/builder/null.rs b/arrow-buffer/src/builder/null.rs
index 3c762bdebcf3..ce125dd5e187 100644
--- a/arrow-buffer/src/builder/null.rs
+++ b/arrow-buffer/src/builder/null.rs
@@ -169,13 +169,36 @@ impl NullBufferBuilder {
     /// Appends a boolean slice into the builder
     /// to indicate the validations of these items.
     pub fn append_slice(&mut self, slice: &[bool]) {
-        if slice.iter().any(|v| !v) {
-            self.materialize_if_needed()
+        self.append_iter(slice.iter().copied());
+    }
+
+    /// Appends the validity bits yielded by `iter` into the builder.
+    ///
+    /// `iter` is cloned so that it can be scanned for nulls before it is consumed.
+    pub fn append_iter(&mut self, iter: impl ExactSizeIterator<Item = bool> + Clone) {
+        if iter.clone().any(|v| !v) {
+            self.materialize_if_needed();
         }
+
         if let Some(buf) = self.bitmap_builder.as_mut() {
-            buf.append_slice(slice)
+            buf.append_iter(iter);
+        } else {
+            // Nothing was consumed from `iter`, so `len` is the full item count.
+            self.len += iter.len();
+        }
+    }
+
+    /// Appends the validity bits of `buffer` into the builder.
+    pub fn append_buffer(&mut self, buffer: &NullBuffer) {
+        if buffer.null_count() == 0 {
+            // All slots are valid: record their number, not the (zero) null count.
+            self.append_n_non_nulls(buffer.len());
         } else {
-            self.len += slice.len();
+            self.materialize_if_needed();
+            self.bitmap_builder
+                .as_mut()
+                .expect("bitmap builder was materialized above")
+                .append_buffer(buffer.inner());
         }
     }
 
diff --git a/arrow-select/src/lib.rs b/arrow-select/src/lib.rs
index e8f45441c481..f29d99a4ef31 100644
--- a/arrow-select/src/lib.rs
+++ b/arrow-select/src/lib.rs
@@ -24,6 +24,7 @@ pub mod filter;
 pub mod interleave;
 pub mod nullif;
 pub mod take;
+pub mod take_in;
 pub mod union_extract;
 pub mod window;
 pub mod zip;
diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs
index 71a7c77a8f92..06939fdb7119 100644
--- a/arrow-select/src/take.rs
+++ b/arrow-select/src/take.rs
@@ -158,7 +158,7 @@ pub fn take_arrays(
 }
 
 /// Verifies that the non-null values of `indices` are all `< len`
-fn check_bounds<T: ArrowPrimitiveType>(
+pub(crate) fn check_bounds<T: ArrowPrimitiveType>(
     len: usize,
     indices: &PrimitiveArray<T>,
 ) -> Result<(), ArrowError> {
@@ -836,7 +836,7 @@ where
 
 /// To avoid generating take implementations for every index type, instead we
 /// only generate for UInt32 and UInt64 and coerce inputs to these types
-trait ToIndices {
+pub(crate) trait ToIndices {
     type T: ArrowPrimitiveType;
 
     fn to_indices(&self) -> PrimitiveArray<Self::T>;
diff --git a/arrow-select/src/take_in.rs b/arrow-select/src/take_in.rs
new file mode 100644
index 000000000000..a74b73fc7413
--- /dev/null
+++
b/arrow-select/src/take_in.rs
@@ -0,0 +1,177 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Defines the `take_in` kernel, which appends the values taken from an [`Array`] to an
+//! [`ArrayBuilder`]
+
+use arrow_array::builder::{
+    ArrayBuilder, GenericBinaryBuilder, GenericByteViewBuilder, GenericStringBuilder,
+    PrimitiveBuilder,
+};
+use arrow_array::cast::AsArray;
+use arrow_array::types::*;
+use arrow_array::*;
+use arrow_schema::{ArrowError, DataType};
+
+use crate::take::check_bounds;
+use crate::take::TakeOptions;
+use crate::take::ToIndices;
+
+/// Takes the values at `indices` from `values` and appends them to `builder`.
+///
+/// A null index, or an index that selects a null slot of `values`, appends a null.
+///
+/// # Errors
+///
+/// Returns an error if
+/// * `indices` is not an integer array,
+/// * bounds checking is enabled in `options` and an index is out of bounds,
+/// * `builder` is not the builder type matching the data type of `values`,
+/// * the data type of `values` is not supported yet.
+pub fn take_in(
+    values: &dyn Array,
+    builder: &mut dyn ArrayBuilder,
+    indices: &dyn Array,
+    options: Option<TakeOptions>,
+) -> Result<(), ArrowError> {
+    let options = options.unwrap_or_default();
+    macro_rules! helper {
+        ($t:ty, $values:expr, $builder:expr, $indices:expr, $options:expr) => {{
+            let indices = $indices.as_primitive::<$t>();
+            if $options.check_bounds {
+                check_bounds($values.len(), indices)?;
+            }
+            // Coerce to UInt32 / UInt64 indices so that only two index types are instantiated
+            let indices = indices.to_indices();
+            take_in_impl($values, $builder, &indices)
+        }};
+    }
+    downcast_integer! {
+        indices.data_type() => (helper, values, builder, indices, options),
+        d => Err(ArrowError::InvalidArgumentError(format!("Take only supported for integers, got {d:?}")))
+    }
+}
+
+/// Takes the rows at `indices` from `batch` and appends them to `builder`.
+///
+/// Column `i` of `batch` is appended to column builder `i` of `builder`.
+///
+/// # Errors
+///
+/// Returns an error if `batch` and `builder` differ in their number of columns, or if
+/// [`take_in`] fails for a column. In the latter case the builders of the preceding
+/// columns have already been appended to, leaving `builder` with ragged columns.
+pub fn take_in_batch(
+    batch: &RecordBatch,
+    builder: &mut RecordBatchBuilder,
+    indices: &dyn Array,
+    options: Option<TakeOptions>,
+) -> Result<(), ArrowError> {
+    let columns = batch.columns();
+    let builders = builder.builders_mut();
+
+    // `zip` stops at the shorter side, which would silently skip the surplus columns.
+    if columns.len() != builders.len() {
+        return Err(ArrowError::InvalidArgumentError(format!(
+            "take_in_batch expects one builder per column, got {} builders for {} columns",
+            builders.len(),
+            columns.len()
+        )));
+    }
+
+    for (column, builder) in columns.iter().zip(builders.iter_mut()) {
+        take_in(column, builder.as_mut(), indices, options.clone())?;
+    }
+
+    Ok(())
+}
+
+/// Downcasts `builder` to the concrete builder type `B` expected for `data_type`.
+fn downcast_builder<'a, B: ArrayBuilder>(
+    builder: &'a mut dyn ArrayBuilder,
+    data_type: &DataType,
+) -> Result<&'a mut B, ArrowError> {
+    builder.as_any_mut().downcast_mut::<B>().ok_or_else(|| {
+        ArrowError::InvalidArgumentError(format!(
+            "builder does not match array type {data_type:?}"
+        ))
+    })
+}
+
+/// Appends the values of the primitive array `values` selected by `indices` to `builder`.
+fn take_in_primitive<T, I>(
+    values: &PrimitiveArray<T>,
+    builder: &mut dyn ArrayBuilder,
+    indices: &PrimitiveArray<I>,
+) -> Result<(), ArrowError>
+where
+    T: ArrowPrimitiveType,
+    I: ArrowPrimitiveType,
+{
+    downcast_builder::<PrimitiveBuilder<T>>(builder, values.data_type())?
+        .take_in(values, indices);
+    Ok(())
+}
+
+/// Dispatches on the data type of `values` to the `take_in` of the matching builder.
+fn take_in_impl<IndexType>(
+    values: &dyn Array,
+    builder: &mut dyn ArrayBuilder,
+    indices: &PrimitiveArray<IndexType>,
+) -> Result<(), ArrowError>
+where
+    IndexType: ArrowPrimitiveType,
+{
+    downcast_primitive_array! {
+        values => {
+            take_in_primitive(values, builder, indices)
+        },
+        DataType::Utf8 => {
+            downcast_builder::<GenericStringBuilder<i32>>(builder, values.data_type())?
+                .take_in(values.as_string::<i32>(), indices);
+            Ok(())
+        }
+        DataType::LargeUtf8 => {
+            downcast_builder::<GenericStringBuilder<i64>>(builder, values.data_type())?
+                .take_in(values.as_string::<i64>(), indices);
+            Ok(())
+        }
+        DataType::Utf8View => {
+            downcast_builder::<GenericByteViewBuilder<StringViewType>>(builder, values.data_type())?
+                .take_in(values.as_string_view(), indices);
+            Ok(())
+        }
+        DataType::Binary => {
+            downcast_builder::<GenericBinaryBuilder<i32>>(builder, values.data_type())?
+                .take_in(values.as_binary::<i32>(), indices);
+            Ok(())
+        }
+        DataType::LargeBinary => {
+            downcast_builder::<GenericBinaryBuilder<i64>>(builder, values.data_type())?
+                .take_in(values.as_binary::<i64>(), indices);
+            Ok(())
+        }
+        DataType::BinaryView => {
+            downcast_builder::<GenericByteViewBuilder<BinaryViewType>>(builder, values.data_type())?
+                .take_in(values.as_binary_view(), indices);
+            Ok(())
+        }
+        t => Err(ArrowError::NotYetImplemented(format!(
+            "TakeIn not supported for data type {t:?}"
+        )))
+    }
+}
diff --git a/arrow/src/compute/kernels.rs b/arrow/src/compute/kernels.rs
index 86fdbe66c8ae..32b854ffd05b 100644
--- a/arrow/src/compute/kernels.rs
+++ b/arrow/src/compute/kernels.rs
@@ -21,7 +21,9 @@ pub use arrow_arith::{aggregate, arithmetic, arity, bitwise, boolean, numeric, t
 pub use arrow_cast::cast;
 pub use arrow_cast::parse as cast_utils;
 pub use arrow_ord::{cmp, partition, rank, sort};
-pub use arrow_select::{concat, filter, interleave, nullif, take, union_extract, window, zip};
+pub use arrow_select::{
+    concat, filter, interleave, nullif, take, take_in, union_extract, window, zip,
+};
 pub use arrow_string::{concat_elements, length, regexp, substring};
 
 /// Comparison kernels for `Array`s.
diff --git a/arrow/src/compute/mod.rs b/arrow/src/compute/mod.rs index bff7214718fc..7119c2730049 100644 --- a/arrow/src/compute/mod.rs +++ b/arrow/src/compute/mod.rs @@ -34,6 +34,7 @@ pub use self::kernels::rank::*; pub use self::kernels::regexp::*; pub use self::kernels::sort::*; pub use self::kernels::take::*; +pub use self::kernels::take_in::*; pub use self::kernels::temporal::*; pub use self::kernels::union_extract::*; pub use self::kernels::window::*;