From 1e2b56e0cbca3e6a840f80eff16abbf610a09895 Mon Sep 17 00:00:00 2001 From: ritchie Date: Tue, 11 Jun 2024 11:02:40 +0200 Subject: [PATCH] perf: Optimize string/binary sort --- crates/polars-arrow/src/array/binview/mod.rs | 22 +++ .../src/chunked_array/ops/downcast.rs | 6 + .../src/chunked_array/ops/sort/categorical.rs | 7 +- .../src/chunked_array/ops/sort/mod.rs | 129 +++++++++++------- py-polars/tests/unit/operations/test_sort.py | 46 +++++++ 5 files changed, 155 insertions(+), 55 deletions(-) diff --git a/crates/polars-arrow/src/array/binview/mod.rs b/crates/polars-arrow/src/array/binview/mod.rs index 521044368caa..deeda0df6c08 100644 --- a/crates/polars-arrow/src/array/binview/mod.rs +++ b/crates/polars-arrow/src/array/binview/mod.rs @@ -209,6 +209,28 @@ impl BinaryViewArrayGeneric { self.views.make_mut() } + pub fn into_inner( + self, + ) -> ( + Buffer, + Arc<[Buffer]>, + Option, + usize, + usize, + ) { + let views = self.views; + let buffers = self.buffers; + let validity = self.validity; + + ( + views, + buffers, + validity, + self.total_bytes_len.load(Ordering::Relaxed) as usize, + self.total_buffer_len, + ) + } + pub fn try_new( data_type: ArrowDataType, views: Buffer, diff --git a/crates/polars-core/src/chunked_array/ops/downcast.rs b/crates/polars-core/src/chunked_array/ops/downcast.rs index e94984e3f894..a029f7f05cfb 100644 --- a/crates/polars-core/src/chunked_array/ops/downcast.rs +++ b/crates/polars-core/src/chunked_array/ops/downcast.rs @@ -106,6 +106,12 @@ impl ChunkedArray { unsafe { Some(&*(arr as *const dyn Array as *const T::Array)) } } + #[inline] + pub fn downcast_into_array(self) -> T::Array { + assert_eq!(self.chunks.len(), 1); + self.downcast_get(0).unwrap().clone() + } + #[inline] /// # Safety /// It is up to the caller to ensure the chunk idx is in-bounds diff --git a/crates/polars-core/src/chunked_array/ops/sort/categorical.rs b/crates/polars-core/src/chunked_array/ops/sort/categorical.rs index 3b2c67db8eb2..ff8133dcef7b 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/categorical.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/categorical.rs @@ -15,12 +15,7 @@ impl CategoricalChunked { .zip(self.iter_str()) .collect_trusted::>(); - sort_unstable_by_branch( - vals.as_mut_slice(), - options.descending, - |a, b| a.1.cmp(&b.1), - options.multithreaded, - ); + sort_unstable_by_branch(vals.as_mut_slice(), options, |a, b| a.1.cmp(&b.1)); let cats: UInt32Chunked = vals .into_iter() .map(|(idx, _v)| idx) diff --git a/crates/polars-core/src/chunked_array/ops/sort/mod.rs b/crates/polars-core/src/chunked_array/ops/sort/mod.rs index 857a65c1f257..f8689335b13f 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/mod.rs @@ -11,7 +11,7 @@ mod categorical; use std::cmp::Ordering; pub(crate) use arg_sort_multiple::argsort_multiple_row_fmt; -use arrow::bitmap::MutableBitmap; +use arrow::bitmap::{Bitmap, MutableBitmap}; use arrow::buffer::Buffer; use arrow::legacy::trusted_len::TrustedLenPush; use compare_inner::NonNull; @@ -43,18 +43,18 @@ where } } -fn sort_unstable_by_branch(slice: &mut [T], descending: bool, cmp: C, parallel: bool) +fn sort_unstable_by_branch(slice: &mut [T], options: SortOptions, cmp: C) where T: Send, C: Send + Sync + Fn(&T, &T) -> Ordering, { - if parallel { - POOL.install(|| match descending { + if options.multithreaded { + POOL.install(|| match options.descending { true => slice.par_sort_unstable_by(|a, b| cmp(b, a)), false => slice.par_sort_unstable_by(cmp), }) } else { - match descending { + match options.descending { true => slice.sort_unstable_by(|a, b| cmp(b, a)), false => slice.sort_unstable_by(cmp), } @@ -66,12 +66,19 @@ fn sort_impl_unstable(vals: &mut [T], options: SortOptions) where T: TotalOrd + Send + Sync, { - sort_unstable_by_branch( - vals, - options.descending, - TotalOrd::tot_cmp, - options.multithreaded, - ); + sort_unstable_by_branch(vals, options, TotalOrd::tot_cmp); +} + +fn create_validity(len: usize, null_count: usize, nulls_last: bool) -> Bitmap { + let mut validity = MutableBitmap::with_capacity(len); + if nulls_last { + validity.extend_constant(len - null_count, true); + validity.extend_constant(null_count, false); + } else { + validity.extend_constant(null_count, false); + validity.extend_constant(len - null_count, true); + } + validity.into() } macro_rules! sort_with_fast_path { @@ -148,20 +155,14 @@ where sort_impl_unstable(mut_slice, options); - let mut validity = MutableBitmap::with_capacity(len); if options.nulls_last { vals.extend(std::iter::repeat(T::Native::default()).take(ca.null_count())); - validity.extend_constant(len - null_count, true); - validity.extend_constant(null_count, false); - } else { - validity.extend_constant(null_count, false); - validity.extend_constant(len - null_count, true); - }; + } let arr = PrimitiveArray::new( T::get_dtype().to_arrow(true), vals.into(), - Some(validity.into()), + Some(create_validity(len, null_count, options.nulls_last)), ); let mut new_ca = ChunkedArray::with_chunk(ca.name(), arr); let s = if options.descending { @@ -314,37 +315,73 @@ impl ChunkSort for StringChunked { impl ChunkSort for BinaryChunked { fn sort_with(&self, options: SortOptions) -> ChunkedArray { sort_with_fast_path!(self, options); + // We will sort by the views and reconstruct with sorted views. We leave the buffers as is. + // We must rechunk to ensure that all views point into the proper buffers. + let ca = self.rechunk(); + let arr = ca.downcast_into_array(); + + let (views, buffers, mut validity, total_bytes_len, total_buffer_len) = arr.into_inner(); + let mut views = views.make_mut(); + + let partitioned_part = if let Some(bitmap) = &validity { + // Partition null last first + let mut out_len = 0; + for idx in bitmap.true_idx_iter() { + unsafe { *views.get_unchecked_mut(out_len) = *views.get_unchecked(idx) }; + out_len += 1; + } + let valid_count = out_len; + let null_count = views.len() - valid_count; + validity = Some(create_validity( + bitmap.len(), + bitmap.unset_bits(), + options.nulls_last, + )); + + // Views are correctly partitioned. + if options.nulls_last { + &mut views[..valid_count] + } + // We need to swap the ends. + else { + // swap nulls with end + let mut end = views.len() - 1; + + for i in 0..null_count { + unsafe { *views.get_unchecked_mut(end) = *views.get_unchecked(i) }; + end -= 1; + } + &mut views[null_count..] + } + } else { + views.as_mut_slice() + }; - let mut v: Vec<&[u8]> = Vec::with_capacity(self.len()); - for arr in self.downcast_iter() { - v.extend(arr.non_null_values_iter()); - } - sort_impl_unstable(v.as_mut_slice(), options); + sort_unstable_by_branch(partitioned_part, options, |a, b| unsafe { + a.get_slice_unchecked(&buffers) + .tot_cmp(&b.get_slice_unchecked(&buffers)) + }); - let len = self.len(); - let null_count = self.null_count(); - let mut mutable = MutableBinaryViewArray::with_capacity(len); + let array = unsafe { + BinaryViewArray::new_unchecked( + ArrowDataType::BinaryView, + views.into(), + buffers, + validity, + total_bytes_len, + total_buffer_len, + ) + }; - if options.nulls_last { - for row in v { - mutable.push_value_ignore_validity(row) - } - mutable.extend_null(null_count); - } else { - mutable.extend_null(null_count); - for row in v { - mutable.push_value(row) - } - } - let mut ca = ChunkedArray::with_chunk(self.name(), mutable.into()); + let mut out = Self::with_chunk_like(self, array); let s = if options.descending { IsSorted::Descending } else { IsSorted::Ascending }; - ca.set_sorted_flag(s); - ca + out.set_sorted_flag(s); + out } fn sort(&self, descending: bool) -> ChunkedArray { @@ -434,9 +471,6 @@ impl ChunkSort for BinaryOffsetChunked { length_so_far = values.len() as i64; offsets.push(length_so_far); } - let mut validity = MutableBitmap::with_capacity(len); - validity.extend_constant(len - null_count, true); - validity.extend_constant(null_count, false); offsets.extend(std::iter::repeat(length_so_far).take(null_count)); // SAFETY: offsets are correctly created. @@ -444,15 +478,12 @@ impl ChunkSort for BinaryOffsetChunked { BinaryArray::from_data_unchecked_default( offsets.into(), values.into(), - Some(validity.into()), + Some(create_validity(len, null_count, true)), ) }; ChunkedArray::with_chunk(self.name(), arr) }, (_, false) => { - let mut validity = MutableBitmap::with_capacity(len); - validity.extend_constant(null_count, false); - validity.extend_constant(len - null_count, true); offsets.extend(std::iter::repeat(length_so_far).take(null_count)); for val in v { @@ -466,7 +497,7 @@ impl ChunkSort for BinaryOffsetChunked { BinaryArray::from_data_unchecked_default( offsets.into(), values.into(), - Some(validity.into()), + Some(create_validity(len, null_count, false)), ) }; ChunkedArray::with_chunk(self.name(), arr) diff --git a/py-polars/tests/unit/operations/test_sort.py b/py-polars/tests/unit/operations/test_sort.py index 0ff9c240b6b5..3586c968255e 100644 --- a/py-polars/tests/unit/operations/test_sort.py +++ b/py-polars/tests/unit/operations/test_sort.py @@ -1012,3 +1012,49 @@ def test_sort_chunked_no_nulls() -> None: 0, 2, ] + + +def test_sort_string_nulls() -> None: + str_series = pl.Series( + "b", ["a", None, "c", None, "x", "z", "y", None], dtype=pl.String + ) + assert str_series.sort(descending=False, nulls_last=False).to_list() == [ + None, + None, + None, + "a", + "c", + "x", + "y", + "z", + ] + assert str_series.sort(descending=True, nulls_last=False).to_list() == [ + None, + None, + None, + "z", + "y", + "x", + "c", + "a", + ] + assert str_series.sort(descending=True, nulls_last=True).to_list() == [ + "z", + "y", + "x", + "c", + "a", + None, + None, + None, + ] + assert str_series.sort(descending=False, nulls_last=True).to_list() == [ + "a", + "c", + "x", + "y", + "z", + None, + None, + None, + ]