Skip to content

Commit

Permalink
fix: rewrite top_k/bottom_k, variety of bugs (#16804)
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp authored Jun 7, 2024
1 parent 84cc62f commit d26b328
Show file tree
Hide file tree
Showing 21 changed files with 417 additions and 354 deletions.
28 changes: 6 additions & 22 deletions crates/polars-arrow/src/array/binview/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,10 @@ impl<T: ViewType + ?Sized> BinaryViewArrayGeneric<T> {
&self.views
}

pub fn into_views(self) -> Vec<View> {
self.views.make_mut()
}

pub fn try_new(
data_type: ArrowDataType,
views: Buffer<View>,
Expand Down Expand Up @@ -265,28 +269,8 @@ impl<T: ViewType + ?Sized> BinaryViewArrayGeneric<T> {
/// Assumes that the `i < self.len`.
#[inline]
pub unsafe fn value_unchecked(&self, i: usize) -> &T {
let v = *self.views.get_unchecked_release(i);
let len = v.length;

// view layout:
// length: 4 bytes
// prefix: 4 bytes
// buffer_index: 4 bytes
// offset: 4 bytes

// inlined layout:
// length: 4 bytes
// data: 12 bytes

let bytes = if len <= 12 {
let ptr = self.views.as_ptr() as *const u8;
std::slice::from_raw_parts(ptr.add(i * 16 + 4), len as usize)
} else {
let data = self.buffers.get_unchecked_release(v.buffer_idx as usize);
let offset = v.offset as usize;
data.get_unchecked_release(offset..offset + len as usize)
};
T::from_bytes_unchecked(bytes)
let v = self.views.get_unchecked_release(i);
T::from_bytes_unchecked(v.get_slice_unchecked(&self.buffers))
}

/// Returns an iterator of `Option<&T>` over every element of this array.
Expand Down
22 changes: 14 additions & 8 deletions crates/polars-arrow/src/array/binview/mutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,24 +343,30 @@ impl<T: ViewType + ?Sized> MutableBinaryViewArray<T> {
/// Assumes that the `i < self.len`.
#[inline]
pub unsafe fn value_unchecked(&self, i: usize) -> &T {
let v = *self.views.get_unchecked(i);
let len = v.length;
self.value_from_view_unchecked(self.views.get_unchecked(i))
}

// view layout:
/// Returns the element indicated by the given view.
///
/// # Safety
/// Assumes the View belongs to this MutableBinaryViewArray.
pub unsafe fn value_from_view_unchecked<'a>(&'a self, view: &'a View) -> &'a T {
// View layout:
// length: 4 bytes
// prefix: 4 bytes
// buffer_index: 4 bytes
// offset: 4 bytes

// inlined layout:
// Inlined layout:
// length: 4 bytes
// data: 12 bytes
let len = view.length;
let bytes = if len <= 12 {
let ptr = self.views.as_ptr() as *const u8;
std::slice::from_raw_parts(ptr.add(i * 16 + 4), len as usize)
let ptr = view as *const View as *const u8;
std::slice::from_raw_parts(ptr.add(4), len as usize)
} else {
let buffer_idx = v.buffer_idx as usize;
let offset = v.offset;
let buffer_idx = view.buffer_idx as usize;
let offset = view.offset;

let data = if buffer_idx == self.completed_buffers.len() {
self.in_progress_buffer.as_slice()
Expand Down
17 changes: 17 additions & 0 deletions crates/polars-arrow/src/array/binview/view.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,23 @@ impl View {
}
}
}

/// Constructs a byteslice from this view.
///
/// # Safety
/// Assumes that this view is valid for the given buffers.
pub unsafe fn get_slice_unchecked<'a>(&'a self, buffers: &'a [Buffer<u8>]) -> &'a [u8] {
unsafe {
if self.length <= 12 {
let ptr = self as *const View as *const u8;
std::slice::from_raw_parts(ptr.add(4), self.length as usize)
} else {
let data = buffers.get_unchecked_release(self.buffer_idx as usize);
let offset = self.offset as usize;
data.get_unchecked_release(offset..offset + self.length as usize)
}
}
}
}

impl IsNull for View {
Expand Down
45 changes: 21 additions & 24 deletions crates/polars-compute/src/filter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,53 +9,50 @@ mod avx512;
use arrow::array::growable::make_growable;
use arrow::array::{new_empty_array, Array, BinaryViewArray, BooleanArray, PrimitiveArray};
use arrow::bitmap::utils::SlicesIterator;
use arrow::datatypes::ArrowDataType;
use arrow::bitmap::Bitmap;
use arrow::with_match_primitive_type_full;
use polars_error::PolarsResult;

pub fn filter(array: &dyn Array, mask: &BooleanArray) -> PolarsResult<Box<dyn Array>> {
pub fn filter(array: &dyn Array, mask: &BooleanArray) -> Box<dyn Array> {
assert_eq!(array.len(), mask.len());

// Treat null mask values as false.
if let Some(validities) = mask.validity() {
let values = mask.values();
let new_values = values & validities;
let mask = BooleanArray::new(ArrowDataType::Boolean, new_values, None);
return filter(array, &mask);
let combined_mask = mask.values() & validities;
filter_with_bitmap(array, &combined_mask)
} else {
filter_with_bitmap(array, mask.values())
}
}

pub fn filter_with_bitmap(array: &dyn Array, mask: &Bitmap) -> Box<dyn Array> {
// Fast-path: completely empty or completely full mask.
let false_count = mask.values().unset_bits();
let false_count = mask.unset_bits();
if false_count == mask.len() {
return Ok(new_empty_array(array.data_type().clone()));
return new_empty_array(array.data_type().clone());
}
if false_count == 0 {
return Ok(array.to_boxed());
return array.to_boxed();
}

use arrow::datatypes::PhysicalType::*;
match array.data_type().to_physical_type() {
Primitive(primitive) => with_match_primitive_type_full!(primitive, |$T| {
let array: &PrimitiveArray<$T> = array.as_any().downcast_ref().unwrap();
let (values, validity) = primitive::filter_values_and_validity::<$T>(array.values(), array.validity(), mask.values());
Ok(Box::new(PrimitiveArray::from_vec(values).with_validity(validity)))
let (values, validity) = primitive::filter_values_and_validity::<$T>(array.values(), array.validity(), mask);
Box::new(PrimitiveArray::from_vec(values).with_validity(validity))
}),
Boolean => {
let array = array.as_any().downcast_ref::<BooleanArray>().unwrap();
let (values, validity) = boolean::filter_bitmap_and_validity(
array.values(),
array.validity(),
mask.values(),
);
Ok(BooleanArray::new(array.data_type().clone(), values, validity).boxed())
let (values, validity) =
boolean::filter_bitmap_and_validity(array.values(), array.validity(), mask);
BooleanArray::new(array.data_type().clone(), values, validity).boxed()
},
BinaryView => {
let array = array.as_any().downcast_ref::<BinaryViewArray>().unwrap();
let views = array.views();
let validity = array.validity();
let (views, validity) =
primitive::filter_values_and_validity(views, validity, mask.values());
Ok(unsafe {
let (views, validity) = primitive::filter_values_and_validity(views, validity, mask);
unsafe {
BinaryViewArray::new_unchecked_unknown_md(
array.data_type().clone(),
views.into(),
Expand All @@ -64,19 +61,19 @@ pub fn filter(array: &dyn Array, mask: &BooleanArray) -> PolarsResult<Box<dyn Ar
Some(array.total_buffer_len()),
)
}
.boxed())
.boxed()
},
// Should go via BinaryView
Utf8View => {
unreachable!()
},
_ => {
let iter = SlicesIterator::new(mask.values());
let iter = SlicesIterator::new(mask);
let mut mutable = make_growable(&[array], false, iter.slots());
// SAFETY:
// we are in bounds
iter.for_each(|(start, len)| unsafe { mutable.extend(0, start, len) });
Ok(mutable.as_box())
mutable.as_box()
},
}
}
22 changes: 16 additions & 6 deletions crates/polars-core/src/chunked_array/from.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,14 @@ where
unsafe { Self::from_chunks(name, vec![Box::new(arr)]) }
}

pub fn with_chunk_like<A>(ca: &Self, arr: A) -> Self
where
A: Array,
T: PolarsDataType<Array = A>,
{
Self::from_chunk_iter_like(ca, std::iter::once(arr))
}

pub fn from_chunk_iter<I>(name: &str, iter: I) -> Self
where
I: IntoIterator,
Expand Down Expand Up @@ -165,12 +173,14 @@ where
})
.collect();

ChunkedArray::new_with_dims(
field,
chunks,
length.try_into().expect(LENGTH_LIMIT_MSG),
null_count as IdxSize,
)
unsafe {
ChunkedArray::new_with_dims(
field,
chunks,
length.try_into().expect(LENGTH_LIMIT_MSG),
null_count as IdxSize,
)
}
}

/// Create a new [`ChunkedArray`] from existing chunks.
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-core/src/chunked_array/logical/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ impl<K: PolarsDataType, T: PolarsDataType> DerefMut for Logical<K, T> {
}

impl<K: PolarsDataType, T: PolarsDataType> Logical<K, T> {
pub(crate) fn new_logical<J: PolarsDataType>(ca: ChunkedArray<T>) -> Logical<J, T> {
pub fn new_logical<J: PolarsDataType>(ca: ChunkedArray<T>) -> Logical<J, T> {
Logical(ca, PhantomData, None)
}
}
Expand Down
39 changes: 35 additions & 4 deletions crates/polars-core/src/chunked_array/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::sync::Arc;

use arrow::array::*;
use arrow::bitmap::Bitmap;
use polars_compute::filter::filter_with_bitmap;

use crate::prelude::*;

Expand Down Expand Up @@ -148,16 +149,21 @@ impl<T: PolarsDataType> ChunkedArray<T> {
/// If you want to explicitly the `length` and `null_count`, look at
/// [`ChunkedArray::new_with_dims`]
pub fn new_with_compute_len(field: Arc<Field>, chunks: Vec<ArrayRef>) -> Self {
let mut chunked_arr = Self::new_with_dims(field, chunks, 0, 0);
chunked_arr.compute_len();
chunked_arr
unsafe {
let mut chunked_arr = Self::new_with_dims(field, chunks, 0, 0);
chunked_arr.compute_len();
chunked_arr
}
}

/// Create a new [`ChunkedArray`] and explicitly set its `length` and `null_count`.
///
/// If you want to compute the `length` and `null_count`, look at
/// [`ChunkedArray::new_with_compute_len`]
pub fn new_with_dims(
///
/// # Safety
/// The length and null_count must be correct.
pub unsafe fn new_with_dims(
field: Arc<Field>,
chunks: Vec<ArrayRef>,
length: IdxSize,
Expand Down Expand Up @@ -424,6 +430,31 @@ impl<T: PolarsDataType> ChunkedArray<T> {
}
}

pub fn drop_nulls(&self) -> Self {
if self.null_count() == 0 {
self.clone()
} else {
let chunks = self
.downcast_iter()
.map(|arr| {
if arr.null_count() == 0 {
arr.to_boxed()
} else {
filter_with_bitmap(arr, arr.validity().unwrap())
}
})
.collect();
unsafe {
Self::new_with_dims(
self.field.clone(),
chunks,
(self.len() - self.null_count()) as IdxSize,
0,
)
}
}
}

/// Get the buffer of bits representing null values
#[inline]
#[allow(clippy::type_complexity)]
Expand Down
10 changes: 7 additions & 3 deletions crates/polars-core/src/chunked_array/object/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ where

self.field.dtype = get_object_type::<T>();

ChunkedArray::new_with_dims(Arc::new(self.field), vec![arr], len as IdxSize, null_count)
unsafe {
ChunkedArray::new_with_dims(Arc::new(self.field), vec![arr], len as IdxSize, null_count)
}
}
}

Expand Down Expand Up @@ -141,7 +143,7 @@ where
len,
});

ObjectChunked::new_with_dims(field, vec![arr], len as IdxSize, 0)
unsafe { ObjectChunked::new_with_dims(field, vec![arr], len as IdxSize, 0) }
}

pub fn new_from_vec_and_validity(name: &str, v: Vec<T>, validity: Bitmap) -> Self {
Expand All @@ -155,7 +157,9 @@ where
len,
});

ObjectChunked::new_with_dims(field, vec![arr], len as IdxSize, null_count as IdxSize)
unsafe {
ObjectChunked::new_with_dims(field, vec![arr], len as IdxSize, null_count as IdxSize)
}
}

pub fn new_empty(name: &str) -> Self {
Expand Down
12 changes: 6 additions & 6 deletions crates/polars-core/src/chunked_array/ops/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ where
arity::binary_unchecked_same_type(
self,
filter,
|left, mask| filter_fn(left, mask).unwrap(),
|left, mask| filter_fn(left, mask),
true,
true,
)
Expand All @@ -53,7 +53,7 @@ impl ChunkFilter<BooleanType> for BooleanChunked {
arity::binary_unchecked_same_type(
self,
filter,
|left, mask| filter_fn(left, mask).unwrap(),
|left, mask| filter_fn(left, mask),
true,
true,
)
Expand Down Expand Up @@ -82,7 +82,7 @@ impl ChunkFilter<BinaryType> for BinaryChunked {
arity::binary_unchecked_same_type(
self,
filter,
|left, mask| filter_fn(left, mask).unwrap(),
|left, mask| filter_fn(left, mask),
true,
true,
)
Expand All @@ -104,7 +104,7 @@ impl ChunkFilter<BinaryOffsetType> for BinaryOffsetChunked {
arity::binary_unchecked_same_type(
self,
filter,
|left, mask| filter_fn(left, mask).unwrap(),
|left, mask| filter_fn(left, mask),
true,
true,
)
Expand All @@ -129,7 +129,7 @@ impl ChunkFilter<ListType> for ListChunked {
arity::binary_unchecked_same_type(
self,
filter,
|left, mask| filter_fn(left, mask).unwrap(),
|left, mask| filter_fn(left, mask),
true,
true,
)
Expand All @@ -155,7 +155,7 @@ impl ChunkFilter<FixedSizeListType> for ArrayChunked {
arity::binary_unchecked_same_type(
self,
filter,
|left, mask| filter_fn(left, mask).unwrap(),
|left, mask| filter_fn(left, mask),
true,
true,
)
Expand Down
Loading

0 comments on commit d26b328

Please sign in to comment.