Skip to content

Commit d011e6a

Browse files
askoaask
and
ask
authored
perf: take_run improvements (#3705)
* take_run improvements * doc fix * test case update per pr comment --------- Co-authored-by: ask <ask@local>
1 parent e37e379 commit d011e6a

File tree

1 file changed

+66
-58
lines changed

1 file changed

+66
-58
lines changed

arrow-select/src/take.rs

Lines changed: 66 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -19,15 +19,14 @@
1919
2020
use std::sync::Arc;
2121

22+
use arrow_array::builder::BufferBuilder;
23+
use arrow_array::types::*;
2224
use arrow_array::*;
23-
use arrow_array::{builder::PrimitiveRunBuilder, types::*};
2425
use arrow_buffer::{bit_util, ArrowNativeType, Buffer, MutableBuffer};
2526
use arrow_data::{ArrayData, ArrayDataBuilder};
2627
use arrow_schema::{ArrowError, DataType, Field};
2728

28-
use arrow_array::cast::{
29-
as_generic_binary_array, as_largestring_array, as_primitive_array, as_string_array,
30-
};
29+
use arrow_array::cast::{as_generic_binary_array, as_largestring_array, as_string_array};
3130
use num::{ToPrimitive, Zero};
3231

3332
/// Take elements by index from [Array], creating a new [Array] from those indexes.
@@ -816,22 +815,14 @@ where
816815
Ok(DictionaryArray::<T>::from(data))
817816
}
818817

819-
macro_rules! primitive_run_take {
820-
($t:ty, $o:ty, $indices:ident, $value:ident) => {
821-
take_primitive_run_values::<$o, $t>(
822-
$indices,
823-
as_primitive_array::<$t>($value.values()),
824-
)
825-
};
826-
}
827-
828818
/// `take` implementation for run arrays
829819
///
830820
/// Finds physical indices for the given logical indices and builds output run array
831-
/// by taking values in the input run array at the physical indices.
832-
/// for e.g. an input `RunArray{ run_ends = [2,4,6,8], values=[1,2,1,2] }` and `indices=[2,7]`
833-
/// would be converted to `physical_indices=[1,3]` which will be used to build
834-
/// output `RunArray{ run_ends=[2], values=[2] }`
821+
/// by taking values in the input run_array.values at the physical indices.
822+
/// The output run array will be run encoded on the physical indices and not on output values.
823+
/// For e.g. an input `RunArray{ run_ends = [2,4,6,8], values=[1,2,1,2] }` and `logical_indices=[2,3,6,7]`
824+
/// would be converted to `physical_indices=[1,1,3,3]` which will be used to build
825+
/// output `RunArray{ run_ends=[2,4], values=[2,2] }`.
835826
fn take_run<T, I>(
836827
run_array: &RunArray<T>,
837828
logical_indices: &PrimitiveArray<I>,
@@ -842,43 +833,60 @@ where
842833
I: ArrowPrimitiveType,
843834
I::Native: ToPrimitive,
844835
{
845-
match run_array.data_type() {
846-
DataType::RunEndEncoded(_, fl) => {
847-
let physical_indices =
848-
run_array.get_physical_indices(logical_indices.values())?;
849-
850-
downcast_primitive! {
851-
fl.data_type() => (primitive_run_take, T, physical_indices, run_array),
852-
dt => Err(ArrowError::NotYetImplemented(format!("take_run is not implemented for {dt:?}")))
853-
}
836+
// get physical indices for the input logical indices
837+
let physical_indices = run_array.get_physical_indices(logical_indices.values())?;
838+
839+
// Run encode the physical indices into new_run_ends_builder
840+
// Keep track of the physical indices to take in take_value_indices
841+
// `unwrap` is used in this function because the unwrapped values are bounded by the corresponding `::Native`.
842+
let mut new_run_ends_builder = BufferBuilder::<T::Native>::new(1);
843+
let mut take_value_indices = BufferBuilder::<I::Native>::new(1);
844+
let mut new_physical_len = 1;
845+
for ix in 1..physical_indices.len() {
846+
if physical_indices[ix] != physical_indices[ix - 1] {
847+
take_value_indices
848+
.append(I::Native::from_usize(physical_indices[ix - 1]).unwrap());
849+
new_run_ends_builder.append(T::Native::from_usize(ix).unwrap());
850+
new_physical_len += 1;
854851
}
855-
dt => Err(ArrowError::InvalidArgumentError(format!(
856-
"Expected DataType::RunEndEncoded found {dt:?}"
857-
))),
858852
}
859-
}
853+
take_value_indices.append(
854+
I::Native::from_usize(physical_indices[physical_indices.len() - 1]).unwrap(),
855+
);
856+
new_run_ends_builder.append(T::Native::from_usize(physical_indices.len()).unwrap());
857+
let new_run_ends = unsafe {
858+
// Safety:
859+
// The function builds a valid run_ends array and hence need not be validated.
860+
ArrayDataBuilder::new(T::DATA_TYPE)
861+
.len(new_physical_len)
862+
.null_count(0)
863+
.add_buffer(new_run_ends_builder.finish())
864+
.build_unchecked()
865+
};
860866

861-
// Builds a `RunArray` by taking values from given array for the given indices.
862-
fn take_primitive_run_values<R, V>(
863-
physical_indices: Vec<usize>,
864-
values: &PrimitiveArray<V>,
865-
) -> Result<RunArray<R>, ArrowError>
866-
where
867-
R: RunEndIndexType,
868-
V: ArrowPrimitiveType,
869-
{
870-
let mut builder = PrimitiveRunBuilder::<R, V>::new();
871-
let values_len = values.len();
872-
for ix in physical_indices {
873-
if ix >= values_len {
874-
return Err(ArrowError::InvalidArgumentError("The requested index {ix} is out of bounds for values array with length {values_len}".to_string()));
875-
} else if values.is_null(ix) {
876-
builder.append_null()
877-
} else {
878-
builder.append_value(values.value(ix))
879-
}
880-
}
881-
Ok(builder.finish())
867+
let take_value_indices: PrimitiveArray<I> = unsafe {
868+
// Safety:
869+
// The function builds a valid take_value_indices array and hence need not be validated.
870+
ArrayDataBuilder::new(I::DATA_TYPE)
871+
.len(new_physical_len)
872+
.null_count(0)
873+
.add_buffer(take_value_indices.finish())
874+
.build_unchecked()
875+
.into()
876+
};
877+
878+
let new_values = take(run_array.values(), &take_value_indices, None)?;
879+
880+
let builder = ArrayDataBuilder::new(run_array.data_type().clone())
881+
.len(physical_indices.len())
882+
.add_child_data(new_run_ends)
883+
.add_child_data(new_values.into_data());
884+
let array_data = unsafe {
885+
// Safety:
886+
// This function builds a valid run array and hence can skip validation.
887+
builder.build_unchecked()
888+
};
889+
Ok(array_data.into())
882890
}
883891

884892
/// Takes/filters a list array's inner data using the offsets of the list array.
@@ -983,7 +991,7 @@ where
983991
#[cfg(test)]
984992
mod tests {
985993
use super::*;
986-
use arrow_array::builder::*;
994+
use arrow_array::{builder::*, cast::as_primitive_array};
987995
use arrow_schema::TimeUnit;
988996

989997
fn test_take_decimal_arrays(
@@ -2159,24 +2167,24 @@ mod tests {
21592167

21602168
#[test]
21612169
fn test_take_runs() {
2162-
let logical_array: Vec<i32> = vec![1_i32, 1, 2, 2, 1, 1, 2, 2, 1, 1, 2, 2];
2170+
let logical_array: Vec<i32> = vec![1_i32, 1, 2, 2, 1, 1, 1, 2, 2, 1, 1, 2, 2];
21632171

21642172
let mut builder = PrimitiveRunBuilder::<Int32Type, Int32Type>::new();
21652173
builder.extend(logical_array.into_iter().map(Some));
21662174
let run_array = builder.finish();
21672175

21682176
let take_indices: PrimitiveArray<Int32Type> =
2169-
vec![2, 7, 10].into_iter().collect();
2177+
vec![7, 2, 3, 7, 11, 4, 6].into_iter().collect();
21702178

21712179
let take_out = take_run(&run_array, &take_indices).unwrap();
21722180

2173-
assert_eq!(take_out.len(), 3);
2181+
assert_eq!(take_out.len(), 7);
21742182

2175-
assert_eq!(take_out.run_ends().len(), 1);
2176-
assert_eq!(take_out.run_ends().value(0), 3);
2183+
assert_eq!(take_out.run_ends().len(), 5);
2184+
assert_eq!(take_out.run_ends().values(), &[1_i32, 3, 4, 5, 7]);
21772185

21782186
let take_out_values = as_primitive_array::<Int32Type>(take_out.values());
2179-
assert_eq!(take_out_values.value(0), 2);
2187+
assert_eq!(take_out_values.values(), &[2, 2, 2, 2, 1]);
21802188
}
21812189

21822190
#[test]

0 commit comments

Comments
 (0)