perf: Optimize string/binary sort
ritchie46 committed Jun 11, 2024
1 parent 13d68ae commit 1e2b56e
Showing 5 changed files with 155 additions and 55 deletions.
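
The diff below never copies string data while sorting. The key idea, visible in the rewritten BinaryChunked::sort_with further down, is to sort the small fixed-size view structs that reference the existing data buffers instead of pushing every value into a freshly built array. A minimal, self-contained sketch of that cost model, using made-up Handle/heap types rather than Polars' actual View and Buffer types:

// Hypothetical illustration: sort fixed-size handles into a shared byte heap
// instead of moving the bytes themselves. Comparisons still dereference into
// `heap`, but every swap only moves a small fixed-size handle.
#[derive(Clone, Copy)]
struct Handle {
    offset: usize,
    len: usize,
}

fn sort_handles(handles: &mut [Handle], heap: &[u8], descending: bool) {
    handles.sort_unstable_by(|a, b| {
        let a = &heap[a.offset..a.offset + a.len];
        let b = &heap[b.offset..b.offset + b.len];
        if descending { b.cmp(a) } else { a.cmp(b) }
    });
}

fn main() {
    let heap = b"zebraantkoala".to_vec();
    let mut handles = vec![
        Handle { offset: 0, len: 5 },  // "zebra"
        Handle { offset: 5, len: 3 },  // "ant"
        Handle { offset: 8, len: 5 },  // "koala"
    ];
    sort_handles(&mut handles, &heap, false);
    for h in &handles {
        println!("{}", std::str::from_utf8(&heap[h.offset..h.offset + h.len]).unwrap());
    }
}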
22 changes: 22 additions & 0 deletions crates/polars-arrow/src/array/binview/mod.rs
@@ -209,6 +209,28 @@ impl<T: ViewType + ?Sized> BinaryViewArrayGeneric<T> {
         self.views.make_mut()
     }
 
+    pub fn into_inner(
+        self,
+    ) -> (
+        Buffer<View>,
+        Arc<[Buffer<u8>]>,
+        Option<Bitmap>,
+        usize,
+        usize,
+    ) {
+        let views = self.views;
+        let buffers = self.buffers;
+        let validity = self.validity;
+
+        (
+            views,
+            buffers,
+            validity,
+            self.total_bytes_len.load(Ordering::Relaxed) as usize,
+            self.total_buffer_len,
+        )
+    }
+
     pub fn try_new(
         data_type: ArrowDataType,
         views: Buffer<View>,
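
The new into_inner method hands back the raw parts of a BinaryViewArray — the view buffer, the shared data buffers, the validity bitmap, and the two cached byte lengths — so the sort code later in this commit can reorder only the views and reassemble the array with BinaryViewArray::new_unchecked. Reordering views is sufficient because each view is a fixed 16-byte struct that either stores a short value inline or points into one of the shared buffers. A rough sketch of that layout, assuming the Arrow BinaryView convention (field names here are illustrative, not the crate's actual definition):

// Illustrative 16-byte view layout following the Arrow BinaryView convention;
// the real polars_arrow View type differs in naming and in how it packs the
// inline case, but the size and the idea are the same.
#[allow(dead_code)]
#[repr(C)]
#[derive(Clone, Copy)]
struct View {
    /// Length of the value in bytes.
    length: u32,
    /// First four bytes of the value, used for cheap comparisons; values of
    /// 12 bytes or fewer are stored entirely inline in the view instead of
    /// using buffer_idx/offset.
    prefix: [u8; 4],
    /// Index of the shared buffer holding the full payload (long values).
    buffer_idx: u32,
    /// Byte offset of the payload inside that buffer (long values).
    offset: u32,
}

fn main() {
    // The whole point of sorting views: they are small and fixed-size,
    // so shuffling them never touches the underlying string data.
    assert_eq!(std::mem::size_of::<View>(), 16);
}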
6 changes: 6 additions & 0 deletions crates/polars-core/src/chunked_array/ops/downcast.rs
@@ -106,6 +106,12 @@ impl<T: PolarsDataType> ChunkedArray<T> {
         unsafe { Some(&*(arr as *const dyn Array as *const T::Array)) }
     }
 
+    #[inline]
+    pub fn downcast_into_array(self) -> T::Array {
+        assert_eq!(self.chunks.len(), 1);
+        self.downcast_get(0).unwrap().clone()
+    }
+
     #[inline]
     /// # Safety
     /// It is up to the caller to ensure the chunk idx is in-bounds
7 changes: 1 addition & 6 deletions crates/polars-core/src/chunked_array/ops/sort/categorical.rs
@@ -15,12 +15,7 @@ impl CategoricalChunked {
             .zip(self.iter_str())
             .collect_trusted::<Vec<_>>();
 
-        sort_unstable_by_branch(
-            vals.as_mut_slice(),
-            options.descending,
-            |a, b| a.1.cmp(&b.1),
-            options.multithreaded,
-        );
+        sort_unstable_by_branch(vals.as_mut_slice(), options, |a, b| a.1.cmp(&b.1));
         let cats: UInt32Chunked = vals
             .into_iter()
             .map(|(idx, _v)| idx)
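
The only change in the categorical path is that sort_unstable_by_branch now takes the whole SortOptions rather than separate descending/multithreaded flags, which shortens call sites like this one. A sequential-only, self-contained sketch of that call-site shape with a toy options struct (the real helper additionally dispatches into Polars' rayon pool when multithreaded is set):

use std::cmp::Ordering;

// Toy stand-in for SortOptions; the real struct also carries fields such as
// nulls_last and multithreaded, which the real helper uses to decide whether
// to sort on the thread pool.
#[derive(Clone, Copy, Default)]
struct SortOptions {
    descending: bool,
}

// Sequential-only sketch of sort_unstable_by_branch.
fn sort_unstable_by_branch<T, C>(slice: &mut [T], options: SortOptions, cmp: C)
where
    C: Fn(&T, &T) -> Ordering,
{
    if options.descending {
        slice.sort_unstable_by(|a, b| cmp(b, a));
    } else {
        slice.sort_unstable_by(cmp);
    }
}

fn main() {
    // Same shape as the categorical call site above: sort (index, value)
    // pairs by the value.
    let mut vals = vec![(2u32, "b"), (0, "c"), (1, "a")];
    sort_unstable_by_branch(&mut vals, SortOptions::default(), |a, b| a.1.cmp(&b.1));
    assert_eq!(vals, vec![(1, "a"), (2, "b"), (0, "c")]);
}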
129 changes: 80 additions & 49 deletions crates/polars-core/src/chunked_array/ops/sort/mod.rs
@@ -11,7 +11,7 @@ mod categorical;
 use std::cmp::Ordering;
 
 pub(crate) use arg_sort_multiple::argsort_multiple_row_fmt;
-use arrow::bitmap::MutableBitmap;
+use arrow::bitmap::{Bitmap, MutableBitmap};
 use arrow::buffer::Buffer;
 use arrow::legacy::trusted_len::TrustedLenPush;
 use compare_inner::NonNull;
@@ -43,18 +43,18 @@
     }
 }
 
-fn sort_unstable_by_branch<T, C>(slice: &mut [T], descending: bool, cmp: C, parallel: bool)
+fn sort_unstable_by_branch<T, C>(slice: &mut [T], options: SortOptions, cmp: C)
 where
     T: Send,
     C: Send + Sync + Fn(&T, &T) -> Ordering,
 {
-    if parallel {
-        POOL.install(|| match descending {
+    if options.multithreaded {
+        POOL.install(|| match options.descending {
             true => slice.par_sort_unstable_by(|a, b| cmp(b, a)),
             false => slice.par_sort_unstable_by(cmp),
         })
     } else {
-        match descending {
+        match options.descending {
             true => slice.sort_unstable_by(|a, b| cmp(b, a)),
             false => slice.sort_unstable_by(cmp),
         }
@@ -66,12 +66,19 @@ fn sort_impl_unstable<T>(vals: &mut [T], options: SortOptions)
 where
     T: TotalOrd + Send + Sync,
 {
-    sort_unstable_by_branch(
-        vals,
-        options.descending,
-        TotalOrd::tot_cmp,
-        options.multithreaded,
-    );
+    sort_unstable_by_branch(vals, options, TotalOrd::tot_cmp);
 }
 
+fn create_validity(len: usize, null_count: usize, nulls_last: bool) -> Bitmap {
+    let mut validity = MutableBitmap::with_capacity(len);
+    if nulls_last {
+        validity.extend_constant(len - null_count, true);
+        validity.extend_constant(null_count, false);
+    } else {
+        validity.extend_constant(null_count, false);
+        validity.extend_constant(len - null_count, true);
+    }
+    validity.into()
+}
+
 macro_rules! sort_with_fast_path {
@@ -148,20 +155,14 @@ where
 
     sort_impl_unstable(mut_slice, options);
 
-    let mut validity = MutableBitmap::with_capacity(len);
     if options.nulls_last {
         vals.extend(std::iter::repeat(T::Native::default()).take(ca.null_count()));
-        validity.extend_constant(len - null_count, true);
-        validity.extend_constant(null_count, false);
-    } else {
-        validity.extend_constant(null_count, false);
-        validity.extend_constant(len - null_count, true);
-    };
+    }
 
     let arr = PrimitiveArray::new(
         T::get_dtype().to_arrow(true),
         vals.into(),
-        Some(validity.into()),
+        Some(create_validity(len, null_count, options.nulls_last)),
     );
     let mut new_ca = ChunkedArray::with_chunk(ca.name(), arr);
     let s = if options.descending {
@@ -314,37 +315,73 @@ impl ChunkSort<StringType> for StringChunked {
 impl ChunkSort<BinaryType> for BinaryChunked {
     fn sort_with(&self, options: SortOptions) -> ChunkedArray<BinaryType> {
         sort_with_fast_path!(self, options);
+        // We will sort by the views and reconstruct with sorted views. We leave the buffers as is.
+        // We must rechunk to ensure that all views point into the proper buffers.
+        let ca = self.rechunk();
+        let arr = ca.downcast_into_array();
+
+        let (views, buffers, mut validity, total_bytes_len, total_buffer_len) = arr.into_inner();
+        let mut views = views.make_mut();
+
+        let partitioned_part = if let Some(bitmap) = &validity {
+            // Partition null last first
+            let mut out_len = 0;
+            for idx in bitmap.true_idx_iter() {
+                unsafe { *views.get_unchecked_mut(out_len) = *views.get_unchecked(idx) };
+                out_len += 1;
+            }
+            let valid_count = out_len;
+            let null_count = views.len() - valid_count;
+            validity = Some(create_validity(
+                bitmap.len(),
+                bitmap.unset_bits(),
+                options.nulls_last,
+            ));
+
+            // Views are correctly partitioned.
+            if options.nulls_last {
+                &mut views[..valid_count]
+            }
+            // We need to swap the ends.
+            else {
+                // swap nulls with end
+                let mut end = views.len() - 1;
+
+                for i in 0..null_count {
+                    unsafe { *views.get_unchecked_mut(end) = *views.get_unchecked(i) };
+                    end -= 1;
+                }
+                &mut views[null_count..]
+            }
+        } else {
+            views.as_mut_slice()
+        };
 
-        let mut v: Vec<&[u8]> = Vec::with_capacity(self.len());
-        for arr in self.downcast_iter() {
-            v.extend(arr.non_null_values_iter());
-        }
-        sort_impl_unstable(v.as_mut_slice(), options);
+        sort_unstable_by_branch(partitioned_part, options, |a, b| unsafe {
+            a.get_slice_unchecked(&buffers)
+                .tot_cmp(&b.get_slice_unchecked(&buffers))
+        });
 
-        let len = self.len();
-        let null_count = self.null_count();
-        let mut mutable = MutableBinaryViewArray::with_capacity(len);
+        let array = unsafe {
+            BinaryViewArray::new_unchecked(
+                ArrowDataType::BinaryView,
+                views.into(),
+                buffers,
+                validity,
+                total_bytes_len,
+                total_buffer_len,
+            )
+        };
 
-        if options.nulls_last {
-            for row in v {
-                mutable.push_value_ignore_validity(row)
-            }
-            mutable.extend_null(null_count);
-        } else {
-            mutable.extend_null(null_count);
-            for row in v {
-                mutable.push_value(row)
-            }
-        }
-        let mut ca = ChunkedArray::with_chunk(self.name(), mutable.into());
+        let mut out = Self::with_chunk_like(self, array);
 
         let s = if options.descending {
             IsSorted::Descending
         } else {
             IsSorted::Ascending
         };
-        ca.set_sorted_flag(s);
-        ca
+        out.set_sorted_flag(s);
+        out
     }
 
     fn sort(&self, descending: bool) -> ChunkedArray<BinaryType> {
@@ -434,25 +471,19 @@ impl ChunkSort<BinaryOffsetType> for BinaryOffsetChunked {
                     length_so_far = values.len() as i64;
                     offsets.push(length_so_far);
                 }
-                let mut validity = MutableBitmap::with_capacity(len);
-                validity.extend_constant(len - null_count, true);
-                validity.extend_constant(null_count, false);
                 offsets.extend(std::iter::repeat(length_so_far).take(null_count));
 
                 // SAFETY: offsets are correctly created.
                 let arr = unsafe {
                     BinaryArray::from_data_unchecked_default(
                         offsets.into(),
                         values.into(),
-                        Some(validity.into()),
+                        Some(create_validity(len, null_count, true)),
                     )
                 };
                 ChunkedArray::with_chunk(self.name(), arr)
             },
             (_, false) => {
-                let mut validity = MutableBitmap::with_capacity(len);
-                validity.extend_constant(null_count, false);
-                validity.extend_constant(len - null_count, true);
                 offsets.extend(std::iter::repeat(length_so_far).take(null_count));
 
                 for val in v {
@@ -466,7 +497,7 @@ impl ChunkSort<BinaryOffsetType> for BinaryOffsetChunked {
                     BinaryArray::from_data_unchecked_default(
                         offsets.into(),
                         values.into(),
-                        Some(validity.into()),
+                        Some(create_validity(len, null_count, false)),
                     )
                 };
                 ChunkedArray::with_chunk(self.name(), arr)
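
The null handling in the new BinaryChunked::sort_with above works by compacting valid views to the front, rebuilding the validity bitmap with create_validity, and then sorting only the slice of views that will hold valid values; the views left under null positions are garbage but are never read because the bitmap masks them. A self-contained sketch of that partition-then-sort scheme on plain vectors (hypothetical helper name, not the Polars code):

// Partition-then-sort sketch: `values` plays the role of the view buffer and
// `validity` the role of the bitmap. Entries whose validity bit is false are
// "nulls"; whatever value sits in a null slot afterwards is never read.
fn sort_with_nulls(values: &mut [u32], validity: &mut [bool], nulls_last: bool) {
    let len = values.len();

    // 1. Compact valid values to the front, preserving their relative order.
    let mut out_len = 0;
    for i in 0..len {
        if validity[i] {
            values[out_len] = values[i];
            out_len += 1;
        }
    }
    let valid_count = out_len;
    let null_count = len - valid_count;

    // 2. Rebuild validity: nulls either all at the end or all at the start.
    for (i, v) in validity.iter_mut().enumerate() {
        *v = if nulls_last { i < valid_count } else { i >= null_count };
    }

    // 3. Sort only the slice that holds valid values.
    let valid_slice = if nulls_last {
        &mut values[..valid_count]
    } else {
        // Move the first `null_count` valid values out of the way so the
        // valid range becomes values[null_count..], mirroring the swap loop
        // in the diff above.
        for i in 0..null_count {
            values[len - 1 - i] = values[i];
        }
        &mut values[null_count..]
    };
    valid_slice.sort_unstable();
}

fn main() {
    let mut values = vec![5, 0, 3, 0, 1]; // zeros sit in null slots
    let mut validity = vec![true, false, true, false, true];
    sort_with_nulls(&mut values, &mut validity, false);
    // Nulls first: the two leading slots are masked out, then 1, 3, 5.
    assert_eq!(&values[2..], &[1, 3, 5]);
    assert_eq!(validity, vec![false, false, true, true, true]);
}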
46 changes: 46 additions & 0 deletions py-polars/tests/unit/operations/test_sort.py
@@ -1012,3 +1012,49 @@ def test_sort_chunked_no_nulls() -> None:
         0,
         2,
     ]
+
+
+def test_sort_string_nulls() -> None:
+    str_series = pl.Series(
+        "b", ["a", None, "c", None, "x", "z", "y", None], dtype=pl.String
+    )
+    assert str_series.sort(descending=False, nulls_last=False).to_list() == [
+        None,
+        None,
+        None,
+        "a",
+        "c",
+        "x",
+        "y",
+        "z",
+    ]
+    assert str_series.sort(descending=True, nulls_last=False).to_list() == [
+        None,
+        None,
+        None,
+        "z",
+        "y",
+        "x",
+        "c",
+        "a",
+    ]
+    assert str_series.sort(descending=True, nulls_last=True).to_list() == [
+        "z",
+        "y",
+        "x",
+        "c",
+        "a",
+        None,
+        None,
+        None,
+    ]
+    assert str_series.sort(descending=False, nulls_last=True).to_list() == [
+        "a",
+        "c",
+        "x",
+        "y",
+        "z",
+        None,
+        None,
+        None,
+    ]
