Skip to content

Add take_in kernel #7325

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
56 changes: 55 additions & 1 deletion arrow-array/src/builder/generic_bytes_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,17 @@

use crate::builder::{ArrayBuilder, BufferBuilder, UInt8BufferBuilder};
use crate::types::{ByteArrayType, GenericBinaryType, GenericStringType};
use crate::{ArrayRef, GenericByteArray, OffsetSizeTrait};
use crate::{
Array, ArrayRef, ArrowPrimitiveType, GenericByteArray, OffsetSizeTrait, PrimitiveArray,
};
use arrow_buffer::NullBufferBuilder;
use arrow_buffer::{ArrowNativeType, Buffer, MutableBuffer};
use arrow_data::ArrayDataBuilder;
use std::any::Any;
use std::sync::Arc;

use super::take_in_utils;

/// Builder for [`GenericByteArray`]
///
/// For building strings, see docs on [`GenericStringBuilder`].
Expand Down Expand Up @@ -129,6 +133,56 @@ impl<T: ByteArrayType> GenericByteBuilder<T> {
self.offsets_builder.append(self.next_offset());
}

/// Take values at indices from array into the builder.
pub fn take_in<I>(&mut self, array: &GenericByteArray<T>, indices: &PrimitiveArray<I>)
where
I: ArrowPrimitiveType,
{
take_in_utils::take_in_nulls(&mut self.null_buffer_builder, array.nulls(), &indices);

let indices_has_nulls = indices.null_count() > 0;
let array_has_nulls = indices.null_count() > 0;

match (indices_has_nulls, array_has_nulls) {
(true, true) => {
for idx in indices.iter() {
match idx {
Some(idx) => {
if array.is_valid(idx.as_usize()) {
self.append_value(&array.value(idx.as_usize()))
} else {
self.append_null();
}
}
None => self.append_null(),
};
}
}
(true, false) => {
for idx in indices.iter() {
match idx {
Some(idx) => self.append_value(&array.value(idx.as_usize())),
None => self.append_null(),
};
}
}
(false, true) => {
for idx in indices.values().iter() {
if array.is_valid(idx.as_usize()) {
self.append_value(&array.value(idx.as_usize()));
} else {
self.append_null();
}
}
}
(false, false) => {
for idx in indices.values().iter() {
self.append_value(&array.value(idx.as_usize()));
}
}
}
}

/// Builds the [`GenericByteArray`] and reset this builder.
pub fn finish(&mut self) -> GenericByteArray<T> {
let array_type = T::DATA_TYPE;
Expand Down
76 changes: 74 additions & 2 deletions arrow-array/src/builder/generic_bytes_view_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use std::any::Any;
use std::marker::PhantomData;
use std::sync::Arc;

use arrow_buffer::{Buffer, BufferBuilder, NullBufferBuilder, ScalarBuffer};
use arrow_buffer::{ArrowNativeType, Buffer, BufferBuilder, NullBufferBuilder, ScalarBuffer};
use arrow_data::ByteView;
use arrow_schema::ArrowError;
use hashbrown::hash_table::Entry;
Expand All @@ -28,7 +28,9 @@ use hashbrown::HashTable;
use crate::builder::ArrayBuilder;
use crate::types::bytes::ByteArrayNativeType;
use crate::types::{BinaryViewType, ByteViewType, StringViewType};
use crate::{ArrayRef, GenericByteViewArray};
use crate::{Array, ArrayRef, ArrowPrimitiveType, GenericByteViewArray, PrimitiveArray};

use super::take_in_utils;

const STARTING_BLOCK_SIZE: u32 = 8 * 1024; // 8KiB
const MAX_BLOCK_SIZE: u32 = 2 * 1024 * 1024; // 2MiB
Expand Down Expand Up @@ -361,6 +363,76 @@ impl<T: ByteViewType + ?Sized> GenericByteViewBuilder<T> {
self.views_builder.append(0);
}

/// Take values at indices from array into the builder.
pub fn take_in<I>(&mut self, array: &GenericByteViewArray<T>, indices: &PrimitiveArray<I>)
where
I: ArrowPrimitiveType,
{
// lots of todos:
// - Maybe: Decide on GC based on indices to array length ratio
// - Maybe: GC only fully-empty buffers
// - Maybe: Use take_in_nulls / take_in_native in GC path

let capacity_needed: u32 = indices
.iter()
.map(|idx| match idx {
Some(idx) => {
let len = if array.is_valid(idx.as_usize()) {
let view = array.views().get(idx.as_usize()).unwrap();
*view as u32
} else {
0
};
if len <= 12 {
0
} else {
len
}
}
None => 0,
})
.sum();

if array.get_buffer_memory_size() > 2 * capacity_needed as usize {
for index in indices {
match index {
Some(index) => {
if array.is_valid(index.as_usize()) {
self.append_value(array.value(index.as_usize()));
} else {
self.append_null();
}
}
None => {
self.append_null();
}
}
}
} else {
self.flush_in_progress();

let start = self.len();
let start_buffers = self.completed.len();

take_in_utils::take_in_nulls(&mut self.null_buffer_builder, array.nulls(), indices);

take_in_utils::take_in_native::<u128, I>(
&mut self.views_builder,
&array.views(),
indices,
);

self.completed
.extend(array.data_buffers().into_iter().cloned());
for i in start..self.len() {
let mut view = ByteView::from(self.views_builder.as_slice_mut()[i]);
if view.length > 12 {
view.buffer_index += start_buffers as u32; // overflow check needed?
self.views_builder.as_slice_mut()[i] = view.as_u128();
}
}
}
}
/// Builds the [`GenericByteViewArray`] and reset this builder
pub fn finish(&mut self) -> GenericByteViewArray<T> {
self.flush_in_progress();
Expand Down
113 changes: 113 additions & 0 deletions arrow-array/src/builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,15 @@
//! assert_eq!(buffer.iter().collect::<Vec<_>>(), vec![true, true, true, true, true, true, true, false]);
//! ```

mod take_in_utils;

pub use arrow_buffer::BooleanBufferBuilder;
pub use arrow_buffer::NullBufferBuilder;

mod boolean_builder;
use arrow_schema::DataType;
use arrow_schema::IntervalUnit;
use arrow_schema::TimeUnit;
pub use boolean_builder::*;
mod buffer_builder;
pub use buffer_builder::*;
Expand Down Expand Up @@ -271,6 +276,8 @@ mod union_builder;

pub use union_builder::*;

use crate::types::BinaryViewType;
use crate::types::StringViewType;
use crate::ArrayRef;
use std::any::Any;

Expand Down Expand Up @@ -416,3 +423,109 @@ pub type StringBuilder = GenericStringBuilder<i32>;
///
/// See examples on [`GenericStringBuilder`]
pub type LargeStringBuilder = GenericStringBuilder<i64>;

/// Create an empty builder for `data_type` with capacity `capacity`
pub fn new_empty_builder(data_type: &DataType, capacity: usize) -> Box<dyn ArrayBuilder> {
match data_type {
DataType::Null => todo!(),
DataType::Boolean => todo!(),
DataType::Int8 => Box::new(Int8Builder::with_capacity(capacity)) as _,
DataType::Int16 => Box::new(Int16Builder::with_capacity(capacity)) as _,
DataType::Int32 => Box::new(Int32Builder::with_capacity(capacity)) as _,
DataType::Int64 => Box::new(Int64Builder::with_capacity(capacity)) as _,
DataType::UInt8 => Box::new(UInt8Builder::with_capacity(capacity)) as _,
DataType::UInt16 => Box::new(UInt16Builder::with_capacity(capacity)) as _,
DataType::UInt32 => Box::new(UInt32Builder::with_capacity(capacity)) as _,
DataType::UInt64 => Box::new(UInt64Builder::with_capacity(capacity)) as _,
DataType::Float16 => Box::new(Float16Builder::with_capacity(capacity)) as _,
DataType::Float32 => Box::new(Float32Builder::with_capacity(capacity)) as _,
DataType::Float64 => Box::new(Float64Builder::with_capacity(capacity)) as _,
DataType::Timestamp(TimeUnit::Second, _) => {
Box::new(TimestampSecondBuilder::with_capacity(capacity)) as _
}
DataType::Timestamp(TimeUnit::Millisecond, _) => {
Box::new(TimestampMillisecondBuilder::with_capacity(capacity)) as _
}
DataType::Timestamp(TimeUnit::Microsecond, _) => {
Box::new(TimestampMicrosecondBuilder::with_capacity(capacity)) as _
}
DataType::Timestamp(TimeUnit::Nanosecond, _) => {
Box::new(TimestampNanosecondBuilder::with_capacity(capacity)) as _
}
DataType::Date32 => Box::new(Date32Builder::with_capacity(capacity)) as _,
DataType::Date64 => Box::new(Date64Builder::with_capacity(capacity)) as _,
DataType::Time32(TimeUnit::Second) => {
Box::new(Time32SecondBuilder::with_capacity(capacity)) as _
}
DataType::Time32(TimeUnit::Millisecond) => {
Box::new(Time32MillisecondBuilder::with_capacity(capacity)) as _
}
DataType::Time64(TimeUnit::Microsecond) => {
Box::new(Time64MicrosecondBuilder::with_capacity(capacity)) as _
}
DataType::Time64(TimeUnit::Nanosecond) => {
Box::new(Time64NanosecondBuilder::with_capacity(capacity)) as _
}
DataType::Duration(TimeUnit::Second) => {
Box::new(DurationSecondBuilder::with_capacity(capacity)) as _
}
DataType::Duration(TimeUnit::Millisecond) => {
Box::new(DurationMillisecondBuilder::with_capacity(capacity)) as _
}
DataType::Duration(TimeUnit::Microsecond) => {
Box::new(DurationMicrosecondBuilder::with_capacity(capacity)) as _
}
DataType::Duration(TimeUnit::Nanosecond) => {
Box::new(DurationNanosecondBuilder::with_capacity(capacity)) as _
}
DataType::Interval(IntervalUnit::YearMonth) => {
Box::new(IntervalYearMonthBuilder::with_capacity(capacity)) as _
}
DataType::Interval(IntervalUnit::DayTime) => {
Box::new(IntervalDayTimeBuilder::with_capacity(capacity)) as _
}
DataType::Interval(IntervalUnit::MonthDayNano) => {
Box::new(IntervalMonthDayNanoBuilder::with_capacity(capacity)) as _
}
DataType::Binary => Box::new(GenericBinaryBuilder::<i32>::with_capacity(
capacity, capacity,
)) as _,
DataType::FixedSizeBinary(_) => todo!(),
DataType::LargeBinary => Box::new(GenericBinaryBuilder::<i64>::with_capacity(
capacity, capacity,
)) as _,
DataType::BinaryView => Box::new(GenericByteViewBuilder::<BinaryViewType>::with_capacity(
capacity,
)) as _,
DataType::Utf8 => Box::new(GenericStringBuilder::<i32>::with_capacity(
capacity, capacity,
)) as _,
DataType::LargeUtf8 => Box::new(GenericStringBuilder::<i64>::with_capacity(
capacity, capacity,
)) as _,
DataType::Utf8View => Box::new(GenericByteViewBuilder::<StringViewType>::with_capacity(
capacity,
)) as _,
DataType::List(_field) => todo!(),
DataType::ListView(_field) => todo!(),
DataType::FixedSizeList(_field, _) => todo!(),
DataType::LargeList(_field) => todo!(),
DataType::LargeListView(_field) => todo!(),
DataType::Struct(_fields) => todo!(),
DataType::Union(_union_fields, _union_mode) => todo!(),
DataType::Dictionary(_data_type, _data_type1) => todo!(),
DataType::Decimal128(precision, scale) => Box::new(
Decimal128Builder::with_capacity(capacity)
.with_precision_and_scale(*precision, *scale)
.expect("Invalid precision / scale for Decimal128"),
) as _,
DataType::Decimal256(precision, scale) => Box::new(
Decimal256Builder::with_capacity(capacity)
.with_precision_and_scale(*precision, *scale)
.expect("Invalid precision / scale for Decimal256"),
) as _,
DataType::Map(_field, _) => todo!(),
DataType::RunEndEncoded(_field, _field1) => todo!(),
dt => panic!("Unexpected data type {dt:?}"),
}
}
21 changes: 20 additions & 1 deletion arrow-array/src/builder/primitive_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
// under the License.

use crate::builder::{ArrayBuilder, BufferBuilder};
use crate::types::*;
use crate::{types::*, Array};
use crate::{ArrayRef, PrimitiveArray};
use arrow_buffer::NullBufferBuilder;
use arrow_buffer::{Buffer, MutableBuffer};
Expand All @@ -25,6 +25,8 @@ use arrow_schema::{ArrowError, DataType};
use std::any::Any;
use std::sync::Arc;

use super::take_in_utils;

/// A signed 8-bit integer array builder.
pub type Int8Builder = PrimitiveBuilder<Int8Type>;
/// A signed 16-bit integer array builder.
Expand Down Expand Up @@ -272,6 +274,23 @@ impl<T: ArrowPrimitiveType> PrimitiveBuilder<T> {
self.values_builder.append_trusted_len_iter(iter);
}

/// Takes values at indices from array into this array
///
/// Each index in indices needs to be in-bounds for array or null - Otherwise an incorrect
/// result may be returned
#[inline]
pub fn take_in<I>(&mut self, array: &PrimitiveArray<T>, indices: &PrimitiveArray<I>)
where
I: ArrowPrimitiveType,
{
take_in_utils::take_in_nulls(&mut self.null_buffer_builder, array.nulls(), &indices);

take_in_utils::take_in_native::<T::Native, I>(
&mut self.values_builder,
&array.values(),
&indices,
);
}
/// Builds the [`PrimitiveArray`] and reset this builder.
pub fn finish(&mut self) -> PrimitiveArray<T> {
let len = self.len();
Expand Down
Loading