Skip to content

Commit

Permalink
perf: Optimize array and list gather (#19327)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Oct 21, 2024
1 parent ab5e464 commit 01e801f
Show file tree
Hide file tree
Showing 14 changed files with 375 additions and 14 deletions.
108 changes: 107 additions & 1 deletion crates/polars-arrow/src/array/fixed_size_list/mod.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{new_empty_array, new_null_array, Array, Splitable};
use super::{new_empty_array, new_null_array, Array, ArrayRef, Splitable};
use crate::bitmap::Bitmap;
use crate::datatypes::{ArrowDataType, Field};

Expand All @@ -9,8 +9,11 @@ mod iterator;
mod mutable;
pub use mutable::*;
use polars_error::{polars_bail, polars_ensure, PolarsResult};
use polars_utils::format_tuple;
use polars_utils::pl_str::PlSmallStr;

use crate::datatypes::reshape::{Dimension, ReshapeDimension};

/// The Arrow's equivalent to an immutable `Vec<Option<[T; size]>>` where `T` is an Arrow type.
/// Cloning and slicing this struct is `O(1)`.
#[derive(Clone)]
Expand Down Expand Up @@ -120,6 +123,108 @@ impl FixedSizeListArray {
let values = new_null_array(field.dtype().clone(), length * size);
Self::new(dtype, length, values, Some(Bitmap::new_zeroed(length)))
}

pub fn from_shape(
leaf_array: ArrayRef,
dimensions: &[ReshapeDimension],
) -> PolarsResult<ArrayRef> {
polars_ensure!(
!dimensions.is_empty(),
InvalidOperation: "at least one dimension must be specified"
);
let size = leaf_array.len();

let mut total_dim_size = 1;
let mut num_infers = 0;
for &dim in dimensions {
match dim {
ReshapeDimension::Infer => num_infers += 1,
ReshapeDimension::Specified(dim) => total_dim_size *= dim.get() as usize,
}
}

polars_ensure!(num_infers <= 1, InvalidOperation: "can only specify one inferred dimension");

if size == 0 {
polars_ensure!(
num_infers > 0 || total_dim_size == 0,
InvalidOperation: "cannot reshape empty array into shape without zero dimension: {}",
format_tuple!(dimensions),
);

let mut prev_arrow_dtype = leaf_array.dtype().clone();
let mut prev_array = leaf_array;

// @NOTE: We need to collect the iterator here because it is lazily processed.
let mut current_length = dimensions[0].get_or_infer(0);
let len_iter = dimensions[1..]
.iter()
.map(|d| {
let length = current_length as usize;
current_length *= d.get_or_infer(0);
length
})
.collect::<Vec<_>>();

// We pop the outer dimension as that is the height of the series.
for (dim, length) in dimensions[1..].iter().zip(len_iter).rev() {
// Infer dimension if needed
let dim = dim.get_or_infer(0);
prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true);

prev_array =
FixedSizeListArray::new(prev_arrow_dtype.clone(), length, prev_array, None)
.boxed();
}

return Ok(prev_array);
}

polars_ensure!(
total_dim_size > 0,
InvalidOperation: "cannot reshape non-empty array into shape containing a zero dimension: {}",
format_tuple!(dimensions)
);

polars_ensure!(
size % total_dim_size == 0,
InvalidOperation: "cannot reshape array of size {} into shape {}", size, format_tuple!(dimensions)
);

let mut prev_arrow_dtype = leaf_array.dtype().clone();
let mut prev_array = leaf_array;

// We pop the outer dimension as that is the height of the series.
for dim in dimensions[1..].iter().rev() {
// Infer dimension if needed
let dim = dim.get_or_infer((size / total_dim_size) as u64);
prev_arrow_dtype = prev_arrow_dtype.to_fixed_size_list(dim as usize, true);

prev_array = FixedSizeListArray::new(
prev_arrow_dtype.clone(),
prev_array.len() / dim as usize,
prev_array,
None,
)
.boxed();
}
Ok(prev_array)
}

pub fn get_dims(&self) -> Vec<Dimension> {
let mut dims = vec![
Dimension::new(self.length as _),
Dimension::new(self.size as _),
];

let mut prev_array = &self.values;

while let Some(a) = prev_array.as_any().downcast_ref::<FixedSizeListArray>() {
dims.push(Dimension::new(a.size as _));
prev_array = &a.values;
}
dims
}
}

// must use
Expand All @@ -144,6 +249,7 @@ impl FixedSizeListArray {
/// # Safety
/// The caller must ensure that `offset + length <= self.len()`.
pub unsafe fn slice_unchecked(&mut self, offset: usize, length: usize) {
debug_assert!(offset + length <= self.len());
self.validity = self
.validity
.take()
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-arrow/src/array/growable/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ unsafe fn extend_offset_values<O: Offset>(
start: usize,
len: usize,
) {
let array = growable.arrays[index];
let array = growable.arrays.get_unchecked_release(index);
let offsets = array.offsets();

growable
Expand Down
179 changes: 175 additions & 4 deletions crates/polars-arrow/src/compute/take/fixed_size_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,18 @@
// specific language governing permissions and limitations
// under the License.

use polars_utils::itertools::Itertools;

use super::Index;
use crate::array::growable::{Growable, GrowableFixedSizeList};
use crate::array::{FixedSizeListArray, PrimitiveArray};
use crate::array::{Array, ArrayRef, FixedSizeListArray, PrimitiveArray};
use crate::bitmap::MutableBitmap;
use crate::datatypes::reshape::{Dimension, ReshapeDimension};
use crate::datatypes::{ArrowDataType, PhysicalType};
use crate::legacy::prelude::FromData;
use crate::with_match_primitive_type;

/// `take` implementation for FixedSizeListArrays
pub(super) unsafe fn take_unchecked<O: Index>(
pub(super) unsafe fn take_unchecked_slow<O: Index>(
values: &FixedSizeListArray,
indices: &PrimitiveArray<O>,
) -> FixedSizeListArray {
Expand All @@ -31,7 +37,7 @@ pub(super) unsafe fn take_unchecked<O: Index>(
.iter()
.map(|index| {
let index = index.to_usize();
let slice = values.clone().sliced(index, take_len);
let slice = values.clone().sliced_unchecked(index, take_len);
capacity += slice.len();
slice
})
Expand Down Expand Up @@ -62,3 +68,168 @@ pub(super) unsafe fn take_unchecked<O: Index>(
growable.into()
}
}

fn get_stride_and_leaf_type(dtype: &ArrowDataType, size: usize) -> (usize, &ArrowDataType) {
if let ArrowDataType::FixedSizeList(inner, size_inner) = dtype {
get_stride_and_leaf_type(inner.dtype(), *size_inner * size)
} else {
(size, dtype)
}
}

fn get_leaves(array: &FixedSizeListArray) -> &dyn Array {
if let Some(array) = array.values().as_any().downcast_ref::<FixedSizeListArray>() {
get_leaves(array)
} else {
&**array.values()
}
}

fn get_buffer_and_size(array: &dyn Array) -> (&[u8], usize) {
match array.dtype().to_physical_type() {
PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| {

let arr = array.as_any().downcast_ref::<PrimitiveArray<$T>>().unwrap();
let values = arr.values();
(bytemuck::cast_slice(values), size_of::<$T>())

}),
_ => {
unimplemented!()
},
}
}

unsafe fn from_buffer(mut buf: Vec<u8>, dtype: &ArrowDataType) -> ArrayRef {
assert_eq!(buf.as_ptr().align_offset(256), 0);

match dtype.to_physical_type() {
PhysicalType::Primitive(primitive) => with_match_primitive_type!(primitive, |$T| {

let ptr = buf.as_mut_ptr();
let len_units = buf.len();
let cap_units = buf.capacity();

std::mem::forget(buf);

let buf = Vec::from_raw_parts(
ptr as *mut $T,
len_units / size_of::<$T>(),
cap_units / size_of::<$T>(),
);

PrimitiveArray::from_data_default(buf.into(), None).boxed()

}),
_ => {
unimplemented!()
},
}
}

// Use an alignedvec so the alignment always fits the actual type
// That way we can operate on bytes and reduce monomorphization.
#[repr(C, align(256))]
struct Align256([u8; 256]);

unsafe fn aligned_vec(n_bytes: usize) -> Vec<u8> {
// Lazy math to ensure we always have enough.
let n_units = (n_bytes / size_of::<Align256>()) + 1;

let mut aligned: Vec<Align256> = Vec::with_capacity(n_units);

let ptr = aligned.as_mut_ptr();
let len_units = aligned.len();
let cap_units = aligned.capacity();

std::mem::forget(aligned);

Vec::from_raw_parts(
ptr as *mut u8,
len_units * size_of::<Align256>(),
cap_units * size_of::<Align256>(),
)
}

fn no_inner_validities(values: &ArrayRef) -> bool {
if let Some(arr) = values.as_any().downcast_ref::<FixedSizeListArray>() {
arr.validity().is_none() && no_inner_validities(arr.values())
} else {
values.validity().is_none()
}
}

/// `take` implementation for FixedSizeListArrays
pub(super) unsafe fn take_unchecked<O: Index>(
values: &FixedSizeListArray,
indices: &PrimitiveArray<O>,
) -> ArrayRef {
let (stride, leaf_type) = get_stride_and_leaf_type(values.dtype(), 1);
if leaf_type.to_physical_type().is_primitive() && no_inner_validities(values.values()) {
let leaves = get_leaves(values);

let (leaves_buf, leave_size) = get_buffer_and_size(leaves);
let bytes_per_element = leave_size * stride;

let n_idx = indices.len();
let total_bytes = bytes_per_element * n_idx;

let mut buf = aligned_vec(total_bytes);
let dst = buf.spare_capacity_mut();

let mut count = 0;
let validity = if indices.null_count() == 0 {
for i in indices.values().iter() {
let i = i.to_usize();

std::ptr::copy_nonoverlapping(
leaves_buf.as_ptr().add(i * bytes_per_element),
dst.as_mut_ptr().add(count * bytes_per_element) as *mut _,
bytes_per_element,
);
count += 1;
}
None
} else {
let mut new_validity = MutableBitmap::with_capacity(indices.len());
new_validity.extend_constant(indices.len(), true);
for i in indices.iter() {
if let Some(i) = i {
let i = i.to_usize();
std::ptr::copy_nonoverlapping(
leaves_buf.as_ptr().add(i * bytes_per_element),
dst.as_mut_ptr().add(count * bytes_per_element) as *mut _,
bytes_per_element,
);
} else {
new_validity.set_unchecked(count, false);
std::ptr::write_bytes(
dst.as_mut_ptr().add(count * bytes_per_element) as *mut _,
0,
bytes_per_element,
);
}

count += 1;
}
Some(new_validity.freeze())
};
assert_eq!(count * bytes_per_element, total_bytes);

buf.set_len(total_bytes);

let leaves = from_buffer(buf, leaves.dtype());
let mut shape = values.get_dims();
shape[0] = Dimension::new(indices.len() as _);
let shape = shape
.into_iter()
.map(ReshapeDimension::Specified)
.collect_vec();

FixedSizeListArray::from_shape(leaves.clone(), &shape)
.unwrap()
.with_validity(validity)
} else {
take_unchecked_slow(values, indices).boxed()
}
}
2 changes: 1 addition & 1 deletion crates/polars-arrow/src/compute/take/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ pub unsafe fn take_unchecked(values: &dyn Array, indices: &IdxArr) -> Box<dyn Ar
},
FixedSizeList => {
let array = values.as_any().downcast_ref().unwrap();
Box::new(fixed_size_list::take_unchecked(array, indices))
fixed_size_list::take_unchecked(array, indices)
},
BinaryView => {
take_binview_unchecked(values.as_any().downcast_ref().unwrap(), indices).boxed()
Expand Down
20 changes: 20 additions & 0 deletions crates/polars-arrow/src/datatypes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
mod field;
mod physical_type;
pub mod reshape;
mod schema;

use std::collections::BTreeMap;
Expand Down Expand Up @@ -365,6 +366,25 @@ impl ArrowDataType {
matches!(self, ArrowDataType::Utf8View | ArrowDataType::BinaryView)
}

pub fn is_numeric(&self) -> bool {
use ArrowDataType as D;
matches!(
self,
D::Int8
| D::Int16
| D::Int32
| D::Int64
| D::UInt8
| D::UInt16
| D::UInt32
| D::UInt64
| D::Float32
| D::Float64
| D::Decimal(_, _)
| D::Decimal256(_, _)
)
}

pub fn to_fixed_size_list(self, size: usize, is_nullable: bool) -> ArrowDataType {
ArrowDataType::FixedSizeList(
Box::new(Field::new(
Expand Down
4 changes: 4 additions & 0 deletions crates/polars-arrow/src/datatypes/physical_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ impl PhysicalType {
false
}
}

pub fn is_primitive(&self) -> bool {
matches!(self, Self::Primitive(_))
}
}

/// the set of valid indices types of a dictionary-encoded Array.
Expand Down
File renamed without changes.
Loading

0 comments on commit 01e801f

Please sign in to comment.