diff --git a/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs b/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs index a5477fccb376..21a4bfb96a6a 100644 --- a/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs +++ b/crates/polars-core/src/chunked_array/logical/categorical/ops/unique.rs @@ -66,6 +66,9 @@ impl CategoricalChunked { counts.rename("counts"); let cols = vec![values.into_series(), counts.into_series()]; let df = unsafe { DataFrame::new_no_checks(cols) }; - df.sort(["counts"], true, false) + df.sort( + ["counts"], + SortMultipleOptions::default().with_order_descending(true), + ) } } diff --git a/crates/polars-core/src/chunked_array/ops/mod.rs b/crates/polars-core/src/chunked_array/ops/mod.rs index 9c272c8629c9..f4f81f4b7f42 100644 --- a/crates/polars-core/src/chunked_array/ops/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/mod.rs @@ -42,6 +42,7 @@ pub mod zip; #[cfg(feature = "serde-lazy")] use serde::{Deserialize, Serialize}; +pub use sort::options::*; use crate::series::IsSorted; @@ -357,35 +358,6 @@ pub trait ChunkUnique { } } -#[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)] -#[cfg_attr(feature = "serde-lazy", derive(Serialize, Deserialize))] -pub struct SortOptions { - pub descending: bool, - pub nulls_last: bool, - pub multithreaded: bool, - pub maintain_order: bool, -} - -#[derive(Clone)] -#[cfg_attr(feature = "serde-lazy", derive(Serialize, Deserialize))] -pub struct SortMultipleOptions { - pub other: Vec, - pub descending: Vec, - pub nulls_last: bool, - pub multithreaded: bool, -} - -impl Default for SortOptions { - fn default() -> Self { - Self { - descending: false, - nulls_last: false, - multithreaded: true, - maintain_order: false, - } - } -} - /// Sort operations on `ChunkedArray`. 
pub trait ChunkSort { #[allow(unused_variables)] @@ -398,7 +370,12 @@ pub trait ChunkSort { fn arg_sort(&self, options: SortOptions) -> IdxCa; /// Retrieve the indexes need to sort this and the other arrays. - fn arg_sort_multiple(&self, _options: &SortMultipleOptions) -> PolarsResult { + #[allow(unused_variables)] + fn arg_sort_multiple( + &self, + by: &[Series], + _options: &SortMultipleOptions, + ) -> PolarsResult { polars_bail!(opq = arg_sort_multiple, T::get_dtype()); } } diff --git a/crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs b/crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs new file mode 100644 index 000000000000..ab4b82e4b78a --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs @@ -0,0 +1,81 @@ +use polars_utils::iter::EnumerateIdxTrait; + +use super::*; + +#[derive(Eq)] +struct CompareRow<'a> { + idx: IdxSize, + bytes: &'a [u8], +} + +impl PartialEq for CompareRow<'_> { + fn eq(&self, other: &Self) -> bool { + self.bytes == other.bytes + } +} + +impl Ord for CompareRow<'_> { + fn cmp(&self, other: &Self) -> Ordering { + self.bytes.cmp(other.bytes) + } +} + +impl PartialOrd for CompareRow<'_> { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +pub fn _arg_bottom_k( + k: usize, + by_column: &[Series], + sort_options: &mut SortMultipleOptions, +) -> PolarsResult> { + let from_n_rows = by_column[0].len(); + _broadcast_descending(by_column.len(), &mut sort_options.descending); + let encoded = _get_rows_encoded(by_column, &sort_options.descending, sort_options.nulls_last)?; + let arr = encoded.into_array(); + let mut rows = arr + .values_iter() + .enumerate_idx() + .map(|(idx, bytes)| CompareRow { idx, bytes }) + .collect::>(); + + let sorted = if k >= from_n_rows { + match (sort_options.multithreaded, sort_options.maintain_order) { + (true, true) => POOL.install(|| { + rows.par_sort(); + }), + (true, false) => POOL.install(|| { + rows.par_sort_unstable(); + }), + 
(false, true) => rows.sort(), + (false, false) => rows.sort_unstable(), + } + &rows + } else if sort_options.maintain_order { + // todo: maybe there is some more efficient method, comparable to select_nth_unstable + if sort_options.multithreaded { + POOL.install(|| { + rows.par_sort(); + }) + } else { + rows.sort(); + } + &rows[..k] + } else { + // todo: possible multi threaded `select_nth_unstable`? + let (lower, _el, _upper) = rows.select_nth_unstable(k); + if sort_options.multithreaded { + POOL.install(|| { + lower.par_sort_unstable(); + }) + } else { + lower.sort_unstable(); + } + &*lower + }; + + let idx: NoNull = sorted.iter().map(|cmp_row| cmp_row.idx).collect(); + Ok(idx) +} diff --git a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs index 170639b364be..ca6651282259 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/arg_sort_multiple.rs @@ -24,46 +24,57 @@ pub(crate) fn args_validate( pub(crate) fn arg_sort_multiple_impl( mut vals: Vec<(IdxSize, T)>, + by: &[Series], options: &SortMultipleOptions, ) -> PolarsResult { let descending = &options.descending; - debug_assert_eq!(descending.len() - 1, options.other.len()); - let compare_inner: Vec<_> = options - .other + debug_assert_eq!(descending.len() - 1, by.len()); + let compare_inner: Vec<_> = by .iter() .map(|s| s.into_total_ord_inner()) .collect_trusted(); let first_descending = descending[0]; - POOL.install(|| { - vals.par_sort_by(|tpl_a, tpl_b| { - match ( - first_descending, - tpl_a - .1 - .null_order_cmp(&tpl_b.1, options.nulls_last ^ first_descending), - ) { - // if ordering is equal, we check the other arrays until we find a non-equal ordering - // if we have exhausted all arrays, we keep the equal ordering. 
- (_, Ordering::Equal) => { - let idx_a = tpl_a.0 as usize; - let idx_b = tpl_b.0 as usize; - unsafe { - ordering_other_columns( - &compare_inner, - descending.get_unchecked(1..), - options.nulls_last, - idx_a, - idx_b, - ) - } - }, - (true, Ordering::Less) => Ordering::Greater, - (true, Ordering::Greater) => Ordering::Less, - (_, ord) => ord, - } - }); - }); + + let compare = |tpl_a: &(_, T), tpl_b: &(_, T)| -> Ordering { + match ( + first_descending, + tpl_a + .1 + .null_order_cmp(&tpl_b.1, options.nulls_last ^ first_descending), + ) { + // if ordering is equal, we check the other arrays until we find a non-equal ordering + // if we have exhausted all arrays, we keep the equal ordering. + (_, Ordering::Equal) => { + let idx_a = tpl_a.0 as usize; + let idx_b = tpl_b.0 as usize; + unsafe { + ordering_other_columns( + &compare_inner, + descending.get_unchecked(1..), + options.nulls_last, + idx_a, + idx_b, + ) + } + }, + (true, Ordering::Less) => Ordering::Greater, + (true, Ordering::Greater) => Ordering::Less, + (_, ord) => ord, + } + }; + + match (options.multithreaded, options.maintain_order) { + (true, true) => POOL.install(|| { + vals.par_sort_by(compare); + }), + (true, false) => POOL.install(|| { + vals.par_sort_unstable_by(compare); + }), + (false, true) => vals.sort_by(compare), + (false, false) => vals.sort_unstable_by(compare), + } + let ca: NoNull = vals.into_iter().map(|(idx, _v)| idx).collect_trusted(); // Don't set to sorted. Argsort indices are not sorted. Ok(ca.into_inner()) diff --git a/crates/polars-core/src/chunked_array/ops/sort/categorical.rs b/crates/polars-core/src/chunked_array/ops/sort/categorical.rs index 28c5c2962616..0594421f62d3 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/categorical.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/categorical.rs @@ -79,9 +79,13 @@ impl CategoricalChunked { /// Retrieve the indexes need to sort this and the other arrays. 
- pub(crate) fn arg_sort_multiple(&self, options: &SortMultipleOptions) -> PolarsResult { + pub(crate) fn arg_sort_multiple( + &self, + by: &[Series], + options: &SortMultipleOptions, + ) -> PolarsResult { if self.uses_lexical_ordering() { - args_validate(self.physical(), &options.other, &options.descending)?; + args_validate(self.physical(), by, &options.descending)?; let mut count: IdxSize = 0; // we use bytes to save a monomorphisized str impl @@ -95,9 +99,9 @@ impl CategoricalChunked { }) .collect_trusted(); - arg_sort_multiple_impl(vals, options) + arg_sort_multiple_impl(vals, by, options) } else { - self.physical().arg_sort_multiple(options) + self.physical().arg_sort_multiple(by, options) } } } @@ -171,12 +175,18 @@ mod test { "vals" => [1, 1, 2, 2] ]?; - let out = df.sort(["cat", "vals"], vec![false, false], false)?; + let out = df.sort( + ["cat", "vals"], + SortMultipleOptions::default().with_order_descendings([false, false]), + )?; let out = out.column("cat")?; let cat = out.categorical()?; assert_order(cat, &["a", "a", "b", "c"]); - let out = df.sort(["vals", "cat"], vec![false, false], false)?; + let out = df.sort( + ["vals", "cat"], + SortMultipleOptions::default().with_order_descendings([false, false]), + )?; let out = out.column("cat")?; let cat = out.categorical()?; assert_order(cat, &["b", "c", "a", "a"]); diff --git a/crates/polars-core/src/chunked_array/ops/sort/mod.rs b/crates/polars-core/src/chunked_array/ops/sort/mod.rs index 2611227031f7..e79292b3ac33 100644 --- a/crates/polars-core/src/chunked_array/ops/sort/mod.rs +++ b/crates/polars-core/src/chunked_array/ops/sort/mod.rs @@ -1,6 +1,10 @@ mod arg_sort; pub mod arg_sort_multiple; + +pub mod arg_bottom_k; +pub mod options; + #[cfg(feature = "dtype-categorical")] mod categorical; @@ -207,9 +211,10 @@ where fn arg_sort_multiple_numeric( ca: &ChunkedArray, + by: &[Series], options: &SortMultipleOptions, ) -> PolarsResult { - args_validate(ca, &options.other, &options.descending)?; + 
args_validate(ca, by, &options.descending)?; let mut count: IdxSize = 0; let no_nulls = ca.null_count() == 0; @@ -223,7 +228,7 @@ fn arg_sort_multiple_numeric( (i, NonNull(*v)) })) } - arg_sort_multiple_impl(vals, options) + arg_sort_multiple_impl(vals, by, options) } else { let mut vals = Vec::with_capacity(ca.len()); for arr in ca.downcast_iter() { @@ -233,7 +238,7 @@ fn arg_sort_multiple_numeric( (i, v.copied()) })); } - arg_sort_multiple_impl(vals, options) + arg_sort_multiple_impl(vals, by, options) } } @@ -260,8 +265,12 @@ where /// /// This function is very opinionated. /// We assume that all numeric `Series` are of the same type, if not it will panic - fn arg_sort_multiple(&self, options: &SortMultipleOptions) -> PolarsResult { - arg_sort_multiple_numeric(self, options) + fn arg_sort_multiple( + &self, + by: &[Series], + options: &SortMultipleOptions, + ) -> PolarsResult { + arg_sort_multiple_numeric(self, by, options) } } @@ -312,8 +321,12 @@ impl ChunkSort for StringChunked { /// In this case we assume that all numeric `Series` are `f64` types. The caller needs to /// uphold this contract. If not, it will panic. 
/// - fn arg_sort_multiple(&self, options: &SortMultipleOptions) -> PolarsResult { - self.as_binary().arg_sort_multiple(options) + fn arg_sort_multiple( + &self, + by: &[Series], + options: &SortMultipleOptions, + ) -> PolarsResult { + self.as_binary().arg_sort_multiple(by, options) } } @@ -377,8 +390,12 @@ impl ChunkSort for BinaryChunked { ) } - fn arg_sort_multiple(&self, options: &SortMultipleOptions) -> PolarsResult { - args_validate(self, &options.other, &options.descending)?; + fn arg_sort_multiple( + &self, + by: &[Series], + options: &SortMultipleOptions, + ) -> PolarsResult { + args_validate(self, by, &options.descending)?; let mut count: IdxSize = 0; @@ -391,7 +408,7 @@ impl ChunkSort for BinaryChunked { } } - arg_sort_multiple_impl(vals, options) + arg_sort_multiple_impl(vals, by, options) } } @@ -511,8 +528,12 @@ impl ChunkSort for BinaryOffsetChunked { /// /// In this case we assume that all numeric `Series` are `f64` types. The caller needs to /// uphold this contract. If not, it will panic. 
- fn arg_sort_multiple(&self, options: &SortMultipleOptions) -> PolarsResult { - args_validate(self, &options.other, &options.descending)?; + fn arg_sort_multiple( + &self, + by: &[Series], + options: &SortMultipleOptions, + ) -> PolarsResult { + args_validate(self, by, &options.descending)?; let mut count: IdxSize = 0; @@ -525,7 +546,7 @@ impl ChunkSort for BinaryOffsetChunked { } } - arg_sort_multiple_impl(vals, options) + arg_sort_multiple_impl(vals, by, options) } } @@ -597,7 +618,11 @@ impl ChunkSort for BooleanChunked { self.len(), ) } - fn arg_sort_multiple(&self, options: &SortMultipleOptions) -> PolarsResult { + fn arg_sort_multiple( + &self, + by: &[Series], + options: &SortMultipleOptions, + ) -> PolarsResult { let mut vals = Vec::with_capacity(self.len()); let mut count: IdxSize = 0; for arr in self.downcast_iter() { @@ -607,7 +632,7 @@ impl ChunkSort for BooleanChunked { (i, v.map(|v| v as u8)) })); } - arg_sort_multiple_impl(vals, options) + arg_sort_multiple_impl(vals, by, options) } } @@ -655,8 +680,8 @@ pub fn _broadcast_descending(n_cols: usize, descending: &mut Vec) { pub(crate) fn prepare_arg_sort( columns: Vec, - mut descending: Vec, -) -> PolarsResult<(Series, Vec, Vec)> { + sort_options: &mut SortMultipleOptions, +) -> PolarsResult<(Series, Vec)> { let n_cols = columns.len(); let mut columns = columns @@ -667,8 +692,8 @@ pub(crate) fn prepare_arg_sort( let first = columns.remove(0); // broadcast ordering - _broadcast_descending(n_cols, &mut descending); - Ok((first, columns, descending)) + _broadcast_descending(n_cols, &mut sort_options.descending); + Ok((first, columns)) } #[cfg(test)] @@ -772,7 +797,7 @@ mod test { let c = StringChunked::new("c", &["a", "b", "c", "d", "e", "f", "g", "h"]); let df = DataFrame::new(vec![a.into_series(), b.into_series(), c.into_series()])?; - let out = df.sort(["a", "b", "c"], false, false)?; + let out = df.sort(["a", "b", "c"], SortMultipleOptions::default())?; assert_eq!( Vec::from(out.column("b")?.i64()?), 
&[ @@ -792,7 +817,7 @@ mod test { let b = Int32Chunked::new("b", &[5, 4, 2, 3, 4, 5]).into_series(); let df = DataFrame::new(vec![a, b])?; - let out = df.sort(["a", "b"], false, false)?; + let out = df.sort(["a", "b"], SortMultipleOptions::default())?; let expected = df!( "a" => ["a", "a", "b", "b", "c", "c"], "b" => [3, 5, 4, 4, 2, 5] @@ -804,14 +829,20 @@ mod test { "values" => ["a", "a", "b"] )?; - let out = df.sort(["groups", "values"], vec![true, false], false)?; + let out = df.sort( + ["groups", "values"], + SortMultipleOptions::default().with_order_descendings([true, false]), + )?; let expected = df!( "groups" => [3, 2, 1], "values" => ["b", "a", "a"] )?; assert!(out.equals(&expected)); - let out = df.sort(["values", "groups"], vec![false, true], false)?; + let out = df.sort( + ["values", "groups"], + SortMultipleOptions::default().with_order_descendings([false, true]), + )?; let expected = df!( "groups" => [2, 1, 3], "values" => ["a", "a", "b"] diff --git a/crates/polars-core/src/chunked_array/ops/sort/options.rs b/crates/polars-core/src/chunked_array/ops/sort/options.rs new file mode 100644 index 000000000000..d9bb5e89a884 --- /dev/null +++ b/crates/polars-core/src/chunked_array/ops/sort/options.rs @@ -0,0 +1,221 @@ +#[cfg(feature = "serde-lazy")] +use serde::{Deserialize, Serialize}; +pub use slice::*; + +use crate::prelude::*; + +/// Options for single series sorting. +/// +/// Indicating the order of sorting, nulls position, multithreading, and maintaining order. 
+/// +/// # Example +/// +/// ``` +/// # use polars_core::prelude::*; +/// let s = Series::new("a", [Some(5), Some(2), Some(3), Some(4), None].as_ref()); +/// let sorted = s +/// .sort( +/// SortOptions::default() +/// .with_order_descending(true) +/// .with_nulls_last(true) +/// .with_multithreaded(false), +/// ) +/// .unwrap(); +/// assert_eq!( +/// sorted, +/// Series::new("a", [Some(5), Some(4), Some(3), Some(2), None].as_ref()) +/// ); +/// ``` +#[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)] +#[cfg_attr(feature = "serde-lazy", derive(Serialize, Deserialize))] +pub struct SortOptions { + /// If true sort in descending order. + /// Default `false`. + pub descending: bool, + /// Whether place null values last. + /// Default `false`. + pub nulls_last: bool, + /// If true sort in multiple threads. + /// Default `true`. + pub multithreaded: bool, + /// If true maintain the order of equal elements. + /// Default `false`. + pub maintain_order: bool, +} + +/// Sort options for multi-series sorting. +/// +/// Indicating the order of sorting, nulls position, multithreading, and maintaining order. +/// +/// # Example +/// ``` +/// # use polars_core::prelude::*; +/// +/// # fn main() -> PolarsResult<()> { +/// let df = df! { +/// "a" => [Some(1), Some(2), None, Some(4), None], +/// "b" => [Some(5), None, Some(3), Some(2), Some(1)] +/// }?; +/// +/// let out = df +/// .sort( +/// ["a", "b"], +/// SortMultipleOptions::default() +/// .with_maintain_order(true) +/// .with_multithreaded(false) +/// .with_order_descendings([false, true]) +/// .with_nulls_last(true), +/// )?; +/// +/// let expected = df! { +/// "a" => [Some(1), Some(2), Some(4), None, None], +/// "b" => [Some(5), None, Some(2), Some(3), Some(1)] +/// }?; +/// +/// assert_eq!(out, expected); +/// +/// # Ok(()) +/// # } +#[derive(Clone, Debug, Eq, PartialEq, Hash)] +#[cfg_attr(feature = "serde-lazy", derive(Serialize, Deserialize))] +pub struct SortMultipleOptions { + /// Order of the columns. 
Default all `false``. + /// + /// If only one value is given, it will broadcast to all columns. + /// + /// Use [`SortMultipleOptions::with_order_descendings`] + /// or [`SortMultipleOptions::with_order_descending`] to modify. + /// + /// # Safety + /// + /// Len must matches the number of columns or equal to 1. + pub descending: Vec, + /// Whether place null values last. Default `false`. + pub nulls_last: bool, + /// Whether sort in multiple threads. Default `true`. + pub multithreaded: bool, + /// Whether maintain the order of equal elements. Default `false`. + pub maintain_order: bool, +} + +impl Default for SortOptions { + fn default() -> Self { + Self { + descending: false, + nulls_last: false, + multithreaded: true, + maintain_order: false, + } + } +} + +impl Default for SortMultipleOptions { + fn default() -> Self { + Self { + descending: vec![false], + nulls_last: false, + multithreaded: true, + maintain_order: false, + } + } +} + +impl SortMultipleOptions { + /// Create `SortMultipleOptions` with default values. + pub fn new() -> Self { + Self::default() + } + + /// Specify order for each columns. Default all `false`. + /// + /// # Safety + /// + /// Len must matches the number of columns or equal to 1. + pub fn with_order_descendings(mut self, descending: impl IntoIterator) -> Self { + self.descending = descending.into_iter().collect(); + self + } + + /// Implement order for all columns. Default `false`. + pub fn with_order_descending(mut self, descending: bool) -> Self { + self.descending = vec![descending]; + self + } + + /// Whether place null values last. Default `false`. + pub fn with_nulls_last(mut self, enabled: bool) -> Self { + self.nulls_last = enabled; + self + } + + /// Whether sort in multiple threads. Default `true`. + pub fn with_multithreaded(mut self, enabled: bool) -> Self { + self.multithreaded = enabled; + self + } + + /// Whether maintain the order of equal elements. Default `false`. 
+ pub fn with_maintain_order(mut self, enabled: bool) -> Self { + self.maintain_order = enabled; + self + } + + /// Reverse the order of sorting for each column. + pub fn with_order_reversed(mut self) -> Self { + self.descending.iter_mut().for_each(|x| *x = !*x); + self + } +} + +impl SortOptions { + /// Create `SortOptions` with default values. + pub fn new() -> Self { + Self::default() + } + + /// Specify sorting order for the column. Default `false`. + pub fn with_order_descending(mut self, enabled: bool) -> Self { + self.descending = enabled; + self + } + + /// Whether place null values last. Default `false`. + pub fn with_nulls_last(mut self, enabled: bool) -> Self { + self.nulls_last = enabled; + self + } + + /// Whether sort in multiple threads. Default `true`. + pub fn with_multithreaded(mut self, enabled: bool) -> Self { + self.multithreaded = enabled; + self + } + + /// Whether maintain the order of equal elements. Default `false`. + pub fn with_maintain_order(mut self, enabled: bool) -> Self { + self.maintain_order = enabled; + self + } +} + +impl From<&SortOptions> for SortMultipleOptions { + fn from(value: &SortOptions) -> Self { + SortMultipleOptions { + descending: vec![value.descending], + nulls_last: value.nulls_last, + multithreaded: value.multithreaded, + maintain_order: value.maintain_order, + } + } +} + +impl From<&SortMultipleOptions> for SortOptions { + fn from(value: &SortMultipleOptions) -> Self { + SortOptions { + descending: value.descending.first().copied().unwrap_or(false), + nulls_last: value.nulls_last, + multithreaded: value.multithreaded, + maintain_order: value.maintain_order, + } + } +} diff --git a/crates/polars-core/src/frame/group_by/mod.rs b/crates/polars-core/src/frame/group_by/mod.rs index b688e8a54656..f7218e1434d7 100644 --- a/crates/polars-core/src/frame/group_by/mod.rs +++ b/crates/polars-core/src/frame/group_by/mod.rs @@ -1079,7 +1079,7 @@ mod test { // Use of deprecated `sum()` for testing purposes #[allow(deprecated)] 
let res = df.group_by(["flt"]).unwrap().sum().unwrap(); - let res = res.sort(["flt"], false, false).unwrap(); + let res = res.sort(["flt"], SortMultipleOptions::default()).unwrap(); assert_eq!( Vec::from(res.column("val_sum").unwrap().i32().unwrap()), &[Some(2), Some(2), Some(1)] diff --git a/crates/polars-core/src/frame/mod.rs b/crates/polars-core/src/frame/mod.rs index 83c8292918f4..ede71de1e182 100644 --- a/crates/polars-core/src/frame/mod.rs +++ b/crates/polars-core/src/frame/mod.rs @@ -1759,37 +1759,33 @@ impl DataFrame { Ok(self) } - /// Sort [`DataFrame`] in place by a column. + /// Sort [`DataFrame`] in place. + /// + /// See [`DataFrame::sort`] for more instruction. pub fn sort_in_place( &mut self, - by_column: impl IntoVec, - descending: impl IntoVec, - maintain_order: bool, + by: impl IntoVec, + sort_options: SortMultipleOptions, ) -> PolarsResult<&mut Self> { - let by_column = self.select_series(by_column)?; - let descending = descending.into_vec(); - self.columns = self - .sort_impl(by_column, descending, false, maintain_order, None, true)? - .columns; + let by_column = self.select_series(by)?; + self.columns = self.sort_impl(by_column, sort_options, None)?.columns; Ok(self) } + #[doc(hidden)] /// This is the dispatch of Self::sort, and exists to reduce compile bloat by monomorphization. pub fn sort_impl( &self, by_column: Vec, - descending: Vec, - nulls_last: bool, - maintain_order: bool, + mut sort_options: SortMultipleOptions, slice: Option<(i64, usize)>, - parallel: bool, ) -> PolarsResult { // note that the by_column argument also contains evaluated expression from polars-lazy // that may not even be present in this dataframe. // therefore when we try to set the first columns as sorted, we ignore the error // as expressions are not present (they are renamed to _POLARS_SORT_COLUMN_i. 
- let first_descending = descending[0]; + let first_descending = sort_options.descending[0]; let first_by_column = by_column[0].name().to_string(); let set_sorted = |df: &mut DataFrame| { @@ -1815,7 +1811,7 @@ impl DataFrame { } if let Some((0, k)) = slice { - return self.top_k_impl(k, descending, by_column, nulls_last, maintain_order); + return self.bottom_k_impl(k, by_column, sort_options); } #[cfg(feature = "dtype-struct")] @@ -1834,10 +1830,10 @@ impl DataFrame { (1, false) => { let s = &by_column[0]; let options = SortOptions { - descending: descending[0], - nulls_last, - multithreaded: parallel, - maintain_order, + descending: sort_options.descending[0], + nulls_last: sort_options.nulls_last, + multithreaded: sort_options.multithreaded, + maintain_order: sort_options.maintain_order, }; // fast path for a frame with a single series // no need to compute the sort indices and then take by these indices @@ -1853,17 +1849,19 @@ impl DataFrame { s.arg_sort(options) }, _ => { - if nulls_last || has_struct || std::env::var("POLARS_ROW_FMT_SORT").is_ok() { - argsort_multiple_row_fmt(&by_column, descending, nulls_last, parallel)? + if sort_options.nulls_last + || has_struct + || std::env::var("POLARS_ROW_FMT_SORT").is_ok() + { + argsort_multiple_row_fmt( + &by_column, + sort_options.descending, + sort_options.nulls_last, + sort_options.multithreaded, + )? } else { - let (first, other, descending) = prepare_arg_sort(by_column, descending)?; - let options = SortMultipleOptions { - other, - descending, - nulls_last, - multithreaded: parallel, - }; - first.arg_sort_multiple(&options)? + let (first, other) = prepare_arg_sort(by_column, &mut sort_options)?; + first.arg_sort_multiple(&other, &sort_options)? 
} }, }; @@ -1874,7 +1872,7 @@ impl DataFrame { // SAFETY: // the created indices are in bounds - let mut df = unsafe { df.take_unchecked_impl(&take, parallel) }; + let mut df = unsafe { df.take_unchecked_impl(&take, sort_options.multithreaded) }; set_sorted(&mut df); Ok(df) } @@ -1883,42 +1881,45 @@ impl DataFrame { /// /// # Example /// + /// Sort by a single column with default options: /// ``` /// # use polars_core::prelude::*; - /// fn sort_example(df: &DataFrame, descending: bool) -> PolarsResult { - /// df.sort(["a"], descending, false) + /// fn sort_by_a(df: &DataFrame) -> PolarsResult { + /// df.sort(["a"], Default::default()) /// } - /// - /// fn sort_by_multiple_columns_example(df: &DataFrame) -> PolarsResult { - /// df.sort(&["a", "b"], vec![false, true], false) + /// ``` + /// Sort by a single column with specific order: + /// ``` + /// # use polars_core::prelude::*; + /// fn sort_with_specific_order(df: &DataFrame, descending: bool) -> PolarsResult { + /// df.sort( + /// ["a"], + /// SortMultipleOptions::new() + /// .with_order_descending(descending) + /// ) + /// } + /// ``` + /// Sort by multiple columns with specifying order for each column: + /// ``` + /// # use polars_core::prelude::*; + /// fn sort_by_multiple_columns_with_specific_order(df: &DataFrame) -> PolarsResult { + /// df.sort( + /// &["a", "b"], + /// SortMultipleOptions::new() + /// .with_order_descendings([false, true]) + /// ) /// } /// ``` + /// See [`SortMultipleOptions`] for more options. + /// + /// Also see [`DataFrame::sort_in_place`]. pub fn sort( &self, - by_column: impl IntoVec, - descending: impl IntoVec, - maintain_order: bool, + by: impl IntoVec, + sort_options: SortMultipleOptions, ) -> PolarsResult { let mut df = self.clone(); - df.sort_in_place(by_column, descending, maintain_order)?; - Ok(df) - } - - /// Sort the [`DataFrame`] by a single column with extra options. 
- pub fn sort_with_options(&self, by_column: &str, options: SortOptions) -> PolarsResult { - let mut df = self.clone(); - let by_column = vec![df.column(by_column)?.clone()]; - let descending = vec![options.descending]; - df.columns = df - .sort_impl( - by_column, - descending, - options.nulls_last, - options.maintain_order, - None, - options.multithreaded, - )? - .columns; + df.sort_in_place(by, sort_options)?; Ok(df) } @@ -3178,7 +3179,7 @@ mod test { let df = df .unique_stable(None, UniqueKeepStrategy::First, None) .unwrap() - .sort(["flt"], false, false) + .sort(["flt"], SortMultipleOptions::default()) .unwrap(); let valid = df! { "flt" => [1., 2., 3.], diff --git a/crates/polars-core/src/frame/top_k.rs b/crates/polars-core/src/frame/top_k.rs index e201d1abb40a..961f54704b55 100644 --- a/crates/polars-core/src/frame/top_k.rs +++ b/crates/polars-core/src/frame/top_k.rs @@ -1,91 +1,32 @@ -use std::cmp::Ordering; - -use polars_utils::iter::EnumerateIdxTrait; use smartstring::alias::String as SmartString; -use crate::prelude::sort::_broadcast_descending; -use crate::prelude::sort::arg_sort_multiple::_get_rows_encoded; -use crate::prelude::*; -use crate::series::IsSorted; -use crate::utils::NoNull; - -#[derive(Eq)] -struct CompareRow<'a> { - idx: IdxSize, - bytes: &'a [u8], -} - -impl PartialEq for CompareRow<'_> { - fn eq(&self, other: &Self) -> bool { - self.bytes == other.bytes - } -} - -impl Ord for CompareRow<'_> { - fn cmp(&self, other: &Self) -> Ordering { - self.bytes.cmp(other.bytes) - } -} - -impl PartialOrd for CompareRow<'_> { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} +use super::*; +use crate::prelude::sort::arg_bottom_k::_arg_bottom_k; impl DataFrame { pub fn top_k( &self, k: usize, - descending: impl IntoVec, by_column: impl IntoVec, + sort_options: SortMultipleOptions, ) -> PolarsResult { let by_column = self.select_series(by_column)?; - let descending = descending.into_vec(); - self.top_k_impl(k, descending, 
by_column, false, false) + self.bottom_k_impl(k, by_column, sort_options.with_order_reversed()) } - pub(crate) fn top_k_impl( + pub(crate) fn bottom_k_impl( &self, k: usize, - mut descending: Vec, by_column: Vec, - nulls_last: bool, - maintain_order: bool, + mut sort_options: SortMultipleOptions, ) -> PolarsResult { - _broadcast_descending(by_column.len(), &mut descending); - let encoded = _get_rows_encoded(&by_column, &descending, nulls_last)?; - let arr = encoded.into_array(); - let mut rows = arr - .values_iter() - .enumerate_idx() - .map(|(idx, bytes)| CompareRow { idx, bytes }) - .collect::>(); - - let sorted = if k >= self.height() { - if maintain_order { - rows.sort(); - } else { - rows.sort_unstable(); - } - &rows - } else if maintain_order { - // todo: maybe there is some more efficient method, comparable to select_nth_unstable - rows.sort(); - &rows[..k] - } else { - let (lower, _el, _upper) = rows.select_nth_unstable(k); - lower.sort_unstable(); - &*lower - }; + let first_descending = sort_options.descending[0]; + let first_by_column = by_column[0].name().to_string(); - let idx: NoNull = sorted.iter().map(|cmp_row| cmp_row.idx).collect(); + let idx = _arg_bottom_k(k, &by_column, &mut sort_options)?; let mut df = unsafe { self.take_unchecked(&idx.into_inner()) }; - let first_descending = descending[0]; - let first_by_column = by_column[0].name().to_string(); - // Mark the first sort column as sorted // if the column did not exists it is ok, because we sorted by an expression // not present in the dataframe diff --git a/crates/polars-core/src/series/implementations/binary.rs b/crates/polars-core/src/series/implementations/binary.rs index b0da78d5141a..b39f95a7203b 100644 --- a/crates/polars-core/src/series/implementations/binary.rs +++ b/crates/polars-core/src/series/implementations/binary.rs @@ -84,8 +84,12 @@ impl private::PrivateSeries for SeriesWrap { IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) } - fn arg_sort_multiple(&self, options: 
&SortMultipleOptions) -> PolarsResult { - self.0.arg_sort_multiple(options) + fn arg_sort_multiple( + &self, + by: &[Series], + options: &SortMultipleOptions, + ) -> PolarsResult { + self.0.arg_sort_multiple(by, options) } } diff --git a/crates/polars-core/src/series/implementations/binary_offset.rs b/crates/polars-core/src/series/implementations/binary_offset.rs index 8cea50d76212..66da638dc7ee 100644 --- a/crates/polars-core/src/series/implementations/binary_offset.rs +++ b/crates/polars-core/src/series/implementations/binary_offset.rs @@ -47,8 +47,12 @@ impl private::PrivateSeries for SeriesWrap { IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) } - fn arg_sort_multiple(&self, options: &SortMultipleOptions) -> PolarsResult { - self.0.arg_sort_multiple(options) + fn arg_sort_multiple( + &self, + by: &[Series], + options: &SortMultipleOptions, + ) -> PolarsResult { + self.0.arg_sort_multiple(by, options) } } diff --git a/crates/polars-core/src/series/implementations/boolean.rs b/crates/polars-core/src/series/implementations/boolean.rs index 06f9f0920243..08401055ff57 100644 --- a/crates/polars-core/src/series/implementations/boolean.rs +++ b/crates/polars-core/src/series/implementations/boolean.rs @@ -88,8 +88,12 @@ impl private::PrivateSeries for SeriesWrap { IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) } - fn arg_sort_multiple(&self, options: &SortMultipleOptions) -> PolarsResult { - self.0.arg_sort_multiple(options) + fn arg_sort_multiple( + &self, + by: &[Series], + options: &SortMultipleOptions, + ) -> PolarsResult { + self.0.arg_sort_multiple(by, options) } fn add_to(&self, rhs: &Series) -> PolarsResult { NumOpsDispatch::add_to(&self.0, rhs) diff --git a/crates/polars-core/src/series/implementations/categorical.rs b/crates/polars-core/src/series/implementations/categorical.rs index 1efd20432d94..056b6a435e3f 100644 --- a/crates/polars-core/src/series/implementations/categorical.rs +++ 
b/crates/polars-core/src/series/implementations/categorical.rs @@ -119,8 +119,12 @@ impl private::PrivateSeries for SeriesWrap { } } - fn arg_sort_multiple(&self, options: &SortMultipleOptions) -> PolarsResult { - self.0.arg_sort_multiple(options) + fn arg_sort_multiple( + &self, + by: &[Series], + options: &SortMultipleOptions, + ) -> PolarsResult { + self.0.arg_sort_multiple(by, options) } } diff --git a/crates/polars-core/src/series/implementations/dates_time.rs b/crates/polars-core/src/series/implementations/dates_time.rs index c0516e92ce28..355137112e3a 100644 --- a/crates/polars-core/src/series/implementations/dates_time.rs +++ b/crates/polars-core/src/series/implementations/dates_time.rs @@ -129,8 +129,12 @@ macro_rules! impl_dyn_series { self.0.group_tuples(multithreaded, sorted) } - fn arg_sort_multiple(&self, options: &SortMultipleOptions) -> PolarsResult { - self.0.deref().arg_sort_multiple(options) + fn arg_sort_multiple( + &self, + by: &[Series], + options: &SortMultipleOptions, + ) -> PolarsResult { + self.0.deref().arg_sort_multiple(by, options) } } diff --git a/crates/polars-core/src/series/implementations/datetime.rs b/crates/polars-core/src/series/implementations/datetime.rs index 29accf0ac33e..c237c63e1d9c 100644 --- a/crates/polars-core/src/series/implementations/datetime.rs +++ b/crates/polars-core/src/series/implementations/datetime.rs @@ -134,8 +134,12 @@ impl private::PrivateSeries for SeriesWrap { self.0.group_tuples(multithreaded, sorted) } - fn arg_sort_multiple(&self, options: &SortMultipleOptions) -> PolarsResult { - self.0.deref().arg_sort_multiple(options) + fn arg_sort_multiple( + &self, + by: &[Series], + options: &SortMultipleOptions, + ) -> PolarsResult { + self.0.deref().arg_sort_multiple(by, options) } } diff --git a/crates/polars-core/src/series/implementations/duration.rs b/crates/polars-core/src/series/implementations/duration.rs index 0dfecd55de9c..7249a3f92950 100644 --- 
a/crates/polars-core/src/series/implementations/duration.rs +++ b/crates/polars-core/src/series/implementations/duration.rs @@ -194,8 +194,12 @@ impl private::PrivateSeries for SeriesWrap { self.0.group_tuples(multithreaded, sorted) } - fn arg_sort_multiple(&self, options: &SortMultipleOptions) -> PolarsResult { - self.0.deref().arg_sort_multiple(options) + fn arg_sort_multiple( + &self, + by: &[Series], + options: &SortMultipleOptions, + ) -> PolarsResult { + self.0.deref().arg_sort_multiple(by, options) } } diff --git a/crates/polars-core/src/series/implementations/floats.rs b/crates/polars-core/src/series/implementations/floats.rs index 3f566463acfd..2c21aec09a63 100644 --- a/crates/polars-core/src/series/implementations/floats.rs +++ b/crates/polars-core/src/series/implementations/floats.rs @@ -116,8 +116,12 @@ macro_rules! impl_dyn_series { IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) } - fn arg_sort_multiple(&self, options: &SortMultipleOptions) -> PolarsResult { - self.0.arg_sort_multiple(options) + fn arg_sort_multiple( + &self, + by: &[Series], + options: &SortMultipleOptions, + ) -> PolarsResult { + self.0.arg_sort_multiple(by, options) } } diff --git a/crates/polars-core/src/series/implementations/mod.rs b/crates/polars-core/src/series/implementations/mod.rs index 6f589f1b61a1..d05797465617 100644 --- a/crates/polars-core/src/series/implementations/mod.rs +++ b/crates/polars-core/src/series/implementations/mod.rs @@ -189,8 +189,12 @@ macro_rules! 
impl_dyn_series { IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) } - fn arg_sort_multiple(&self, options: &SortMultipleOptions) -> PolarsResult { - self.0.arg_sort_multiple(options) + fn arg_sort_multiple( + &self, + by: &[Series], + options: &SortMultipleOptions, + ) -> PolarsResult { + self.0.arg_sort_multiple(by, options) } } diff --git a/crates/polars-core/src/series/implementations/string.rs b/crates/polars-core/src/series/implementations/string.rs index 38a60a1be615..28133cc38d20 100644 --- a/crates/polars-core/src/series/implementations/string.rs +++ b/crates/polars-core/src/series/implementations/string.rs @@ -85,8 +85,12 @@ impl private::PrivateSeries for SeriesWrap { IntoGroupsProxy::group_tuples(&self.0, multithreaded, sorted) } - fn arg_sort_multiple(&self, options: &SortMultipleOptions) -> PolarsResult { - self.0.arg_sort_multiple(options) + fn arg_sort_multiple( + &self, + by: &[Series], + options: &SortMultipleOptions, + ) -> PolarsResult { + self.0.arg_sort_multiple(by, options) } } diff --git a/crates/polars-core/src/series/implementations/struct_.rs b/crates/polars-core/src/series/implementations/struct_.rs index 861599cb91b6..346b65491b14 100644 --- a/crates/polars-core/src/series/implementations/struct_.rs +++ b/crates/polars-core/src/series/implementations/struct_.rs @@ -320,14 +320,10 @@ impl SeriesTrait for SeriesWrap { } else { vec![false; df.width()] }; - let out = df.sort_impl( - df.columns.clone(), - desc, - options.nulls_last, - options.maintain_order, - None, - options.multithreaded, - )?; + + let multi_options = SortMultipleOptions::from(&options).with_order_descendings(desc); + + let out = df.sort_impl(df.columns.clone(), multi_options, None)?; Ok(StructChunked::new_unchecked(self.name(), &out.columns).into_series()) } diff --git a/crates/polars-core/src/series/mod.rs b/crates/polars-core/src/series/mod.rs index 2eb60f08788b..59321706ab7b 100644 --- a/crates/polars-core/src/series/mod.rs +++ 
b/crates/polars-core/src/series/mod.rs @@ -285,12 +285,23 @@ impl Series { Ok(self) } - pub fn sort(&self, descending: bool, nulls_last: bool) -> PolarsResult { - self.sort_with(SortOptions { - descending, - nulls_last, - ..Default::default() - }) + /// Sort the series with specific options. + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// # fn main() -> PolarsResult<()> { + /// let s = Series::new("foo", [2, 1, 3]); + /// let sorted = s.sort(SortOptions::default())?; + /// assert_eq!(sorted, Series::new("foo", [1, 2, 3])); + /// # Ok(()) + /// } + /// ``` + /// + /// See [`SortOptions`] for more options. + pub fn sort(&self, sort_options: SortOptions) -> PolarsResult { + self.sort_with(sort_options) } /// Only implemented for numeric types diff --git a/crates/polars-core/src/series/series_trait.rs b/crates/polars-core/src/series/series_trait.rs index ff4a7eca65a6..583eeac2db11 100644 --- a/crates/polars-core/src/series/series_trait.rs +++ b/crates/polars-core/src/series/series_trait.rs @@ -163,7 +163,12 @@ pub(crate) mod private { invalid_operation_panic!(zip_with_same_type, self) } - fn arg_sort_multiple(&self, _options: &SortMultipleOptions) -> PolarsResult { + #[allow(unused_variables)] + fn arg_sort_multiple( + &self, + by: &[Series], + _options: &SortMultipleOptions, + ) -> PolarsResult { polars_bail!(opq = arg_sort_multiple, self._dtype()); } } diff --git a/crates/polars-lazy/src/frame/mod.rs b/crates/polars-lazy/src/frame/mod.rs index 8c1ff7baedee..e42d37efec28 100644 --- a/crates/polars-lazy/src/frame/mod.rs +++ b/crates/polars-lazy/src/frame/mod.rs @@ -260,29 +260,46 @@ impl LazyFrame { /// /// # Example /// + /// Sort DataFrame by 'sepal.width' column: /// ```rust - /// use polars_core::prelude::*; - /// use polars_lazy::prelude::*; - /// - /// /// Sort DataFrame by 'sepal.width' column - /// fn example(df: DataFrame) -> LazyFrame { - /// df.lazy() - /// .sort("sepal.width", Default::default()) + /// # use 
polars_core::prelude::*; + /// # use polars_lazy::prelude::*; + /// fn sort_by_a(df: DataFrame) -> LazyFrame { + /// df.lazy().sort(["a"], Default::default()) /// } /// ``` - pub fn sort(self, by_column: &str, options: SortOptions) -> Self { - let descending = options.descending; - let nulls_last = options.nulls_last; - let maintain_order = options.maintain_order; - + /// Sort by a single column with specific order: + /// ``` + /// # use polars_core::prelude::*; + /// # use polars_lazy::prelude::*; + /// fn sort_with_specific_order(df: DataFrame, descending: bool) -> LazyFrame { + /// df.lazy().sort( + /// ["a"], + /// SortMultipleOptions::new() + /// .with_order_descending(descending) + /// ) + /// } + /// ``` + /// Sort by multiple columns with specifying order for each column: + /// ``` + /// # use polars_core::prelude::*; + /// # use polars_lazy::prelude::*; + /// fn sort_by_multiple_columns_with_specific_order(df: DataFrame) -> LazyFrame { + /// df.lazy().sort( + /// &["a", "b"], + /// SortMultipleOptions::new() + /// .with_order_descendings([false, true]) + /// ) + /// } + /// ``` + /// See [`SortMultipleOptions`] for more options. + pub fn sort(self, by: impl IntoVec, sort_options: SortMultipleOptions) -> Self { let opt_state = self.get_opt_state(); let lp = self .get_plan_builder() .sort( - vec![col(by_column)], - vec![descending], - nulls_last, - maintain_order, + by.into_vec().into_iter().map(|x| col(&x)).collect(), + sort_options, ) .build(); Self::from_logical_plan(lp, opt_state) @@ -291,9 +308,9 @@ impl LazyFrame { /// Add a sort operation to the logical plan. /// /// Sorts the LazyFrame by the provided list of expressions, which will be turned into - /// concrete columns before sorting. `reverse` is a list of `bool`, the same length as - /// `by_exprs`, that specifies whether each corresponding expression will be sorted - /// ascending (`false`) or descending (`true`). + /// concrete columns before sorting. 
+ /// + /// See [`SortMultipleOptions`] for more options. /// /// # Example /// @@ -304,60 +321,43 @@ impl LazyFrame { /// /// Sort DataFrame by 'sepal.width' column /// fn example(df: DataFrame) -> LazyFrame { /// df.lazy() - /// .sort_by_exprs(vec![col("sepal.width")], vec![false], false, false) + /// .sort_by_exprs(vec![col("sepal.width")], Default::default()) /// } /// ``` - pub fn sort_by_exprs, B: AsRef<[bool]>>( + pub fn sort_by_exprs>( self, by_exprs: E, - descending: B, - nulls_last: bool, - maintain_order: bool, + sort_options: SortMultipleOptions, ) -> Self { let by_exprs = by_exprs.as_ref().to_vec(); - let descending = descending.as_ref().to_vec(); if by_exprs.is_empty() { self } else { let opt_state = self.get_opt_state(); - let lp = self - .get_plan_builder() - .sort(by_exprs, descending, nulls_last, maintain_order) - .build(); + let lp = self.get_plan_builder().sort(by_exprs, sort_options).build(); Self::from_logical_plan(lp, opt_state) } } - pub fn top_k, B: AsRef<[bool]>>( + pub fn top_k>( self, k: IdxSize, by_exprs: E, - descending: B, - nulls_last: bool, - maintain_order: bool, + sort_options: SortMultipleOptions, ) -> Self { - let mut descending = descending.as_ref().to_vec(); - // top-k is reverse from sort - for v in &mut descending { - *v = !*v; - } // this will optimize to top-k - self.sort_by_exprs(by_exprs, descending, nulls_last, maintain_order) + self.sort_by_exprs(by_exprs, sort_options.with_order_reversed()) .slice(0, k) } - pub fn bottom_k, B: AsRef<[bool]>>( + pub fn bottom_k>( self, k: IdxSize, by_exprs: E, - descending: B, - nulls_last: bool, - maintain_order: bool, + sort_options: SortMultipleOptions, ) -> Self { - let descending = descending.as_ref().to_vec(); // this will optimize to bottom-k - self.sort_by_exprs(by_exprs, descending, nulls_last, maintain_order) - .slice(0, k) + self.sort_by_exprs(by_exprs, sort_options).slice(0, k) } /// Reverse the `DataFrame` from top to bottom. 
diff --git a/crates/polars-lazy/src/lib.rs b/crates/polars-lazy/src/lib.rs index 0f846a7380ef..49396a0bb256 100644 --- a/crates/polars-lazy/src/lib.rs +++ b/crates/polars-lazy/src/lib.rs @@ -120,7 +120,7 @@ //! col("rain").sum().alias("sum_rain"), //! col("rain").quantile(lit(0.5), QuantileInterpolOptions::Nearest).alias("median_rain"), //! ]) -//! .sort("date", Default::default()) +//! .sort(["date"], Default::default()) //! .collect() //! } //! ``` diff --git a/crates/polars-lazy/src/physical_plan/executors/sort.rs b/crates/polars-lazy/src/physical_plan/executors/sort.rs index 54c6755fd953..2d32a63d6097 100644 --- a/crates/polars-lazy/src/physical_plan/executors/sort.rs +++ b/crates/polars-lazy/src/physical_plan/executors/sort.rs @@ -3,7 +3,8 @@ use super::*; pub(crate) struct SortExec { pub(crate) input: Box, pub(crate) by_column: Vec>, - pub(crate) args: SortArguments, + pub(crate) slice: Option<(i64, usize)>, + pub(crate) sort_options: SortMultipleOptions, } impl SortExec { @@ -32,14 +33,7 @@ impl SortExec { }) .collect::>>()?; - df.sort_impl( - by_columns, - std::mem::take(&mut self.args.descending), - self.args.nulls_last, - self.args.maintain_order, - self.args.slice, - true, - ) + df.sort_impl(by_columns, self.sort_options.clone(), self.slice) } } diff --git a/crates/polars-lazy/src/physical_plan/expressions/sortby.rs b/crates/polars-lazy/src/physical_plan/expressions/sortby.rs index e1b7e703a310..ae1bcadc0e7c 100644 --- a/crates/polars-lazy/src/physical_plan/expressions/sortby.rs +++ b/crates/polars-lazy/src/physical_plan/expressions/sortby.rs @@ -9,22 +9,22 @@ use crate::prelude::*; pub struct SortByExpr { pub(crate) input: Arc, pub(crate) by: Vec>, - pub(crate) descending: Vec, pub(crate) expr: Expr, + pub(crate) sort_options: SortMultipleOptions, } impl SortByExpr { pub fn new( input: Arc, by: Vec>, - descending: Vec, expr: Expr, + sort_options: SortMultipleOptions, ) -> Self { Self { input, by, - descending, expr, + sort_options, } } } @@ -123,6 
+123,8 @@ fn sort_by_groups_multiple_by( indicator: GroupsIndicator, sort_by_s: &[Series], descending: &[bool], + multithreaded: bool, + maintain_order: bool, ) -> PolarsResult<(IdxSize, IdxVec)> { let new_idx = match indicator { GroupsIndicator::Idx((_first, idx)) => { @@ -133,13 +135,13 @@ fn sort_by_groups_multiple_by( .collect::>(); let options = SortMultipleOptions { - other: groups[1..].to_vec(), descending: descending.to_owned(), nulls_last: false, - multithreaded: false, + multithreaded, + maintain_order, }; - let sorted_idx = groups[0].arg_sort_multiple(&options).unwrap(); + let sorted_idx = groups[0].arg_sort_multiple(&groups[1..], &options).unwrap(); map_sorted_indices_to_group_idx(&sorted_idx, idx) }, GroupsIndicator::Slice([first, len]) => { @@ -149,12 +151,12 @@ fn sort_by_groups_multiple_by( .collect::>(); let options = SortMultipleOptions { - other: groups[1..].to_vec(), descending: descending.to_owned(), nulls_last: false, - multithreaded: false, + multithreaded, + maintain_order, }; - let sorted_idx = groups[0].arg_sort_multiple(&options).unwrap(); + let sorted_idx = groups[0].arg_sort_multiple(&groups[1..], &options).unwrap(); map_sorted_indices_to_group_slice(&sorted_idx, first) }, }; @@ -171,15 +173,12 @@ impl PhysicalExpr for SortByExpr { } fn evaluate(&self, df: &DataFrame, state: &ExecutionState) -> PolarsResult { let series_f = || self.input.evaluate(df, state); - let descending = prepare_descending(&self.descending, self.by.len()); + let descending = prepare_descending(&self.sort_options.descending, self.by.len()); let (series, sorted_idx) = if self.by.len() == 1 { let sorted_idx_f = || { let s_sort_by = self.by[0].evaluate(df, state)?; - Ok(s_sort_by.arg_sort(SortOptions { - descending: descending[0], - ..Default::default() - })) + Ok(s_sort_by.arg_sort(SortOptions::from(&self.sort_options))) }; POOL.install(|| rayon::join(series_f, sorted_idx_f)) } else { @@ -196,13 +195,8 @@ impl PhysicalExpr for SortByExpr { }) .collect::>>()?; - let 
options = SortMultipleOptions { - other: s_sort_by[1..].to_vec(), - descending, - nulls_last: false, - multithreaded: true, - }; - s_sort_by[0].arg_sort_multiple(&options) + let options = self.sort_options.clone().with_order_descendings(descending); + s_sort_by[0].arg_sort_multiple(&s_sort_by[1..], &options) }; POOL.install(|| rayon::join(series_f, sorted_idx_f)) }; @@ -226,7 +220,7 @@ impl PhysicalExpr for SortByExpr { state: &ExecutionState, ) -> PolarsResult> { let mut ac_in = self.input.evaluate_on_groups(df, groups, state)?; - let descending = prepare_descending(&self.descending, self.by.len()); + let descending = prepare_descending(&self.sort_options.descending, self.by.len()); let mut ac_sort_by = self .by @@ -267,7 +261,7 @@ impl PhysicalExpr for SortByExpr { return sort_by_groups_no_match_single( ac_in, ac_sort_by, - self.descending[0], + self.sort_options.descending[0], &self.expr, ); }; @@ -295,7 +289,15 @@ impl PhysicalExpr for SortByExpr { let groups = POOL.install(|| { groups .par_iter() - .map(|indicator| sort_by_groups_multiple_by(indicator, &sort_by_s, &descending)) + .map(|indicator| { + sort_by_groups_multiple_by( + indicator, + &sort_by_s, + &descending, + self.sort_options.multithreaded, + self.sort_options.maintain_order, + ) + }) .collect::>() }); GroupsProxy::Idx(groups?) 
diff --git a/crates/polars-lazy/src/physical_plan/node_timer.rs b/crates/polars-lazy/src/physical_plan/node_timer.rs index 4926f7df8c59..95084eeb4fcb 100644 --- a/crates/polars-lazy/src/physical_plan/node_timer.rs +++ b/crates/polars-lazy/src/physical_plan/node_timer.rs @@ -59,6 +59,6 @@ impl NodeTimer { let columns = vec![nodes_s, start.into_series(), end.into_series()]; let df = unsafe { DataFrame::new_no_checks(columns) }; - df.sort(vec!["start"], vec![false], false) + df.sort(vec!["start"], SortMultipleOptions::default()) } } diff --git a/crates/polars-lazy/src/physical_plan/planner/expr.rs b/crates/polars-lazy/src/physical_plan/planner/expr.rs index 5d37f4588051..bdd4785df796 100644 --- a/crates/polars-lazy/src/physical_plan/planner/expr.rs +++ b/crates/polars-lazy/src/physical_plan/planner/expr.rs @@ -272,7 +272,7 @@ fn create_physical_expr_inner( SortBy { expr, by, - descending, + sort_options, } => { polars_ensure!(!by.is_empty(), InvalidOperation: "'sort_by' got an empty set"); let phys_expr = create_physical_expr_inner(expr, ctxt, expr_arena, schema, state)?; @@ -281,8 +281,8 @@ fn create_physical_expr_inner( Ok(Arc::new(SortByExpr::new( phys_expr, phys_by, - descending, node_to_expr(expression, expr_arena), + sort_options, ))) }, Filter { input, by } => { diff --git a/crates/polars-lazy/src/physical_plan/planner/lp.rs b/crates/polars-lazy/src/physical_plan/planner/lp.rs index 60722959c3a1..2cd03493f69a 100644 --- a/crates/polars-lazy/src/physical_plan/planner/lp.rs +++ b/crates/polars-lazy/src/physical_plan/planner/lp.rs @@ -330,7 +330,8 @@ pub fn create_physical_plan( Sort { input, by_column, - args, + slice, + sort_options, } => { let input_schema = lp_arena.get(input).schema(lp_arena); let by_column = create_physical_expressions_from_irs( @@ -344,7 +345,8 @@ pub fn create_physical_plan( Ok(Box::new(executors::SortExec { input, by_column, - args, + slice, + sort_options, })) }, Cache { diff --git 
a/crates/polars-lazy/src/physical_plan/streaming/checks.rs b/crates/polars-lazy/src/physical_plan/streaming/checks.rs index a781682d3366..757149174dfb 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/checks.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/checks.rs @@ -1,19 +1,19 @@ +use polars_core::chunked_array::ops::SortMultipleOptions; use polars_ops::prelude::*; use polars_plan::logical_plan::expr_ir::ExprIR; use polars_plan::prelude::*; -pub(super) fn is_streamable_sort(args: &SortArguments) -> bool { +pub(super) fn is_streamable_sort( + slice: &Option<(i64, usize)>, + sort_options: &SortMultipleOptions, +) -> bool { // check if slice is positive or maintain order is true - match args { - SortArguments { - maintain_order: true, - .. - } => false, - SortArguments { - slice: Some((offset, _)), - .. - } => *offset >= 0, - SortArguments { slice: None, .. } => true, + if sort_options.maintain_order { + false + } else if let Some((offset, _)) = slice { + *offset >= 0 + } else { + true } } diff --git a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs index 6b700a7c74a9..b6852313db57 100644 --- a/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs +++ b/crates/polars-lazy/src/physical_plan/streaming/convert_alp.rs @@ -202,8 +202,9 @@ pub(crate) fn insert_streaming_nodes( Sort { input, by_column, - args, - } if is_streamable_sort(args) && all_column(by_column, expr_arena) => { + slice, + sort_options, + } if is_streamable_sort(slice, sort_options) && all_column(by_column, expr_arena) => { state.streamable = true; state.operators_sinks.push(PipelineNode::Sink(root)); stack.push(StackFrame::new(*input, state, current_idx)) diff --git a/crates/polars-lazy/src/tests/aggregations.rs b/crates/polars-lazy/src/tests/aggregations.rs index 60c19bbd00eb..7df14711c5aa 100644 --- a/crates/polars-lazy/src/tests/aggregations.rs +++ b/crates/polars-lazy/src/tests/aggregations.rs 
@@ -53,7 +53,11 @@ fn test_agg_unique_first() -> PolarsResult<()> { .group_by_stable([col("g")]) .agg([ col("v").unique().first().alias("v_first"), - col("v").unique().sort(false).first().alias("true_first"), + col("v") + .unique() + .sort(Default::default()) + .first() + .alias("true_first"), col("v").unique().implode(), ]) .collect()?; @@ -87,7 +91,7 @@ fn test_cum_sum_agg_as_key() -> PolarsResult<()> { .cum_sum(false) .alias("key")]) .agg([col("depth").max().name().keep()]) - .sort("depth", SortOptions::default()) + .sort(["depth"], Default::default()) .collect()?; assert_eq!( @@ -183,11 +187,8 @@ fn test_power_in_agg_list1() -> PolarsResult<()> { .alias("foo"), ]) .sort( - "fruits", - SortOptions { - descending: true, - ..Default::default() - }, + ["fruits"], + SortMultipleOptions::default().with_order_descending(true), ) .collect()?; @@ -219,11 +220,8 @@ fn test_power_in_agg_list2() -> PolarsResult<()> { .sum() .alias("foo")]) .sort( - "fruits", - SortOptions { - descending: true, - ..Default::default() - }, + ["fruits"], + SortMultipleOptions::default().with_order_descending(true), ) .collect()?; @@ -402,7 +400,7 @@ fn test_shift_elementwise_issue_2509() -> PolarsResult<()> { // Don't use maintain order here! 
That hides the bug .group_by([col("x")]) .agg(&[(col("y").shift(lit(-1)) + col("x")).alias("sum")]) - .sort("x", Default::default()) + .sort(["x"], Default::default()) .collect()?; let out = out.explode(["sum"])?; @@ -430,7 +428,7 @@ fn take_aggregations() -> PolarsResult<()> { .lazy() .group_by([col("user")]) .agg([col("book").get(col("count").arg_max()).alias("fav_book")]) - .sort("user", Default::default()) + .sort(["user"], Default::default()) .collect()?; let s = out.column("fav_book")?; @@ -457,7 +455,7 @@ fn take_aggregations() -> PolarsResult<()> { ) .alias("ordered"), ]) - .sort("user", Default::default()) + .sort(["user"], Default::default()) .collect()?; let s = out.column("ordered")?; let flat = s.explode()?; @@ -469,7 +467,7 @@ fn take_aggregations() -> PolarsResult<()> { .lazy() .group_by([col("user")]) .agg([col("book").get(lit(0)).alias("take_lit")]) - .sort("user", Default::default()) + .sort(["user"], Default::default()) .collect()?; let taken = out.column("take_lit")?; @@ -563,7 +561,7 @@ fn test_take_in_groups() -> PolarsResult<()> { let out = df .lazy() - .sort("fruits", Default::default()) + .sort(["fruits"], Default::default()) .select([col("B").get(lit(0u32)).over([col("fruits")]).alias("taken")]) .collect()?; diff --git a/crates/polars-lazy/src/tests/logical.rs b/crates/polars-lazy/src/tests/logical.rs index 6fdef02c8b38..93babe7f2788 100644 --- a/crates/polars-lazy/src/tests/logical.rs +++ b/crates/polars-lazy/src/tests/logical.rs @@ -65,7 +65,7 @@ fn test_lazy_arithmetic() { let lf = df .lazy() .select(&[((col("sepal.width") * lit(100)).alias("super_wide"))]) - .sort("super_wide", SortOptions::default()); + .sort(["super_wide"], SortMultipleOptions::default()); print_plans(&lf); diff --git a/crates/polars-lazy/src/tests/optimization_checks.rs b/crates/polars-lazy/src/tests/optimization_checks.rs index 0eabd17fc824..e43cebae8e54 100644 --- a/crates/polars-lazy/src/tests/optimization_checks.rs +++ 
b/crates/polars-lazy/src/tests/optimization_checks.rs @@ -262,7 +262,9 @@ pub fn test_slice_pushdown_sort() -> PolarsResult<()> { let _guard = SINGLE_LOCK.lock().unwrap(); let q = scan_foods_parquet(false).limit(100); - let q = q.sort("category", SortOptions::default()).slice(1, 3); + let q = q + .sort(["category"], SortMultipleOptions::default()) + .slice(1, 3); // test if optimization continued beyond the sort node assert!(slice_at_scan(q.clone())); @@ -272,7 +274,7 @@ pub fn test_slice_pushdown_sort() -> PolarsResult<()> { assert!((&lp_arena).iter(lp).all(|(_, lp)| { use IR::*; match lp { - Sort { args, .. } => args.slice == Some((1, 3)), + Sort { slice, .. } => *slice == Some((1, 3)), Slice { .. } => false, _ => true, } diff --git a/crates/polars-lazy/src/tests/queries.rs b/crates/polars-lazy/src/tests/queries.rs index ce80a73c87a8..2525f99aff2b 100644 --- a/crates/polars-lazy/src/tests/queries.rs +++ b/crates/polars-lazy/src/tests/queries.rs @@ -21,7 +21,7 @@ fn test_lazy_exec() { .clone() .lazy() .select([col("sepal.width"), col("variety")]) - .sort("sepal.width", Default::default()) + .sort(["sepal.width"], Default::default()) .collect(); let new = df @@ -387,7 +387,7 @@ fn test_lazy_query_9() -> PolarsResult<()> { ) .group_by([col("Cities.Country")]) .agg([col("Sales.Amount").sum().alias("sum")]) - .sort("sum", Default::default()) + .sort(["sum"], Default::default()) .collect()?; let vals = out .column("sum")? 
@@ -673,7 +673,7 @@ fn test_lazy_partition_agg() { .lazy() .group_by([col("foo")]) .agg([col("bar").mean()]) - .sort("foo", Default::default()) + .sort(["foo"], Default::default()) .collect() .unwrap(); @@ -685,7 +685,7 @@ fn test_lazy_partition_agg() { let out = scan_foods_csv() .group_by([col("category")]) .agg([col("calories")]) - .sort("category", Default::default()) + .sort(["category"], Default::default()) .collect() .unwrap(); let cat_agg_list = out.select_at_idx(1).unwrap(); @@ -763,7 +763,7 @@ fn test_lazy_group_by() { .lazy() .group_by([col("groups")]) .agg([col("a").mean()]) - .sort("a", Default::default()) + .sort(["a"], Default::default()) .collect() .unwrap(); @@ -793,10 +793,10 @@ fn test_lazy_group_by_sort() { .clone() .lazy() .group_by([col("a")]) - .agg([col("b").sort(false).first()]) + .agg([col("b").sort(Default::default()).first()]) .collect() .unwrap() - .sort(["a"], false, false) + .sort(["a"], Default::default()) .unwrap(); assert_eq!( @@ -807,10 +807,10 @@ fn test_lazy_group_by_sort() { let out = df .lazy() .group_by([col("a")]) - .agg([col("b").sort(false).last()]) + .agg([col("b").sort(Default::default()).last()]) .collect() .unwrap() - .sort(["a"], false, false) + .sort(["a"], Default::default()) .unwrap(); assert_eq!( @@ -831,10 +831,15 @@ fn test_lazy_group_by_sort_by() { let out = df .lazy() .group_by([col("a")]) - .agg([col("b").sort_by([col("c")], [true]).first()]) + .agg([col("b") + .sort_by( + [col("c")], + SortMultipleOptions::default().with_order_descending(true), + ) + .first()]) .collect() .unwrap() - .sort(["a"], false, false) + .sort(["a"], Default::default()) .unwrap(); assert_eq!( @@ -876,7 +881,7 @@ fn test_lazy_group_by_binary_expr() { .lazy() .group_by([col("a")]) .agg([col("b").mean() * lit(2)]) - .sort("a", Default::default()) + .sort(["a"], Default::default()) .collect() .unwrap(); assert_eq!( @@ -913,15 +918,7 @@ fn test_lazy_group_by_filter() -> PolarsResult<()> { .last() .alias("b_last"), ]) - .sort( - "a", - 
SortOptions { - descending: false, - nulls_last: false, - multithreaded: true, - maintain_order: false, - }, - ) + .sort(["a"], SortMultipleOptions::default()) .collect()?; assert_eq!( @@ -988,22 +985,22 @@ fn test_group_by_sort_slice() -> PolarsResult<()> { .clone() .lazy() .sort( - "vals", - SortOptions { - descending: true, - ..Default::default() - }, + ["vals"], + SortMultipleOptions::default().with_order_descending(true), ) .group_by([col("groups")]) .agg([col("vals").head(Some(2)).alias("foo")]) - .sort("groups", SortOptions::default()) + .sort(["groups"], Default::default()) .collect()?; let out2 = df .lazy() .group_by([col("groups")]) - .agg([col("vals").sort(true).head(Some(2)).alias("foo")]) - .sort("groups", SortOptions::default()) + .agg([col("vals") + .sort(SortOptions::default().with_order_descending(true)) + .head(Some(2)) + .alias("foo")]) + .sort(["groups"], Default::default()) .collect()?; assert!(out1.column("foo")?.equals(out2.column("foo")?)); @@ -1022,7 +1019,7 @@ fn test_group_by_cum_sum() -> PolarsResult<()> { .lazy() .group_by([col("groups")]) .agg([col("vals").cum_sum(false)]) - .sort("groups", Default::default()) + .sort(["groups"], Default::default()) .collect()?; assert_eq!( @@ -1049,7 +1046,10 @@ fn test_arg_sort_multiple() -> PolarsResult<()> { let out = df .clone() .lazy() - .select([arg_sort_by([col("int"), col("flt")], &[true, false])]) + .select([arg_sort_by( + [col("int"), col("flt")], + SortMultipleOptions::default().with_order_descendings([true, false]), + )]) .collect()?; assert_eq!( @@ -1064,7 +1064,10 @@ fn test_arg_sort_multiple() -> PolarsResult<()> { // check if this runs let _out = df .lazy() - .select([arg_sort_by([col("str"), col("flt")], &[true, false])]) + .select([arg_sort_by( + [col("str"), col("flt")], + SortMultipleOptions::default().with_order_descendings([true, false]), + )]) .collect()?; Ok(()) } @@ -1265,7 +1268,7 @@ fn test_sort_by() -> PolarsResult<()> { let out = df .clone() .lazy() - 
.select([col("a").sort_by([col("b"), col("c")], [false])]) + .select([col("a").sort_by([col("b"), col("c")], SortMultipleOptions::default())]) .collect()?; let a = out.column("a")?; @@ -1279,7 +1282,7 @@ fn test_sort_by() -> PolarsResult<()> { .clone() .lazy() .group_by_stable([col("b")]) - .agg([col("a").sort_by([col("b"), col("c")], [false])]) + .agg([col("a").sort_by([col("b"), col("c")], SortMultipleOptions::default())]) .collect()?; let a = out.column("a")?.explode()?; assert_eq!( @@ -1291,7 +1294,7 @@ fn test_sort_by() -> PolarsResult<()> { let out = df .lazy() .group_by_stable([col("b")]) - .agg([col("a").sort_by([col("b"), col("c")], [false])]) + .agg([col("a").sort_by([col("b"), col("c")], SortMultipleOptions::default())]) .collect()?; let a = out.column("a")?.explode()?; @@ -1426,11 +1429,8 @@ fn test_group_by_small_ints() -> PolarsResult<()> { .group_by([col("id_16"), col("id_32")]) .agg([col("id_16").sum().alias("foo")]) .sort( - "foo", - SortOptions { - descending: true, - ..Default::default() - }, + ["foo"], + SortMultipleOptions::default().with_order_descending(true), ) .collect()?; @@ -1784,7 +1784,7 @@ fn test_partitioned_gb_1() -> PolarsResult<()> { (col("vals").eq(lit("a"))).sum().alias("eq_a"), (col("vals").eq(lit("b"))).sum().alias("eq_b"), ]) - .sort("keys", Default::default()) + .sort(["keys"], Default::default()) .collect()?; assert!(out.equals(&df![ @@ -1914,7 +1914,12 @@ fn test_sort_maintain_order_true() -> PolarsResult<()> { .lazy(); let res = q - .sort_by_exprs([col("A")], [false], false, true) + .sort_by_exprs( + [col("A")], + SortMultipleOptions::default() + .with_maintain_order(true) + .with_nulls_last(true), + ) .slice(0, 3) .collect()?; println!("{:?}", res); diff --git a/crates/polars-lazy/src/tests/streaming.rs b/crates/polars-lazy/src/tests/streaming.rs index a25a015a1e42..c320c162b3e2 100644 --- a/crates/polars-lazy/src/tests/streaming.rs +++ b/crates/polars-lazy/src/tests/streaming.rs @@ -39,7 +39,7 @@ fn 
test_streaming_parquet() -> PolarsResult<()> { let q = q .group_by([col("sugars_g")]) .agg([((lit(1) - col("fats_g")) + col("calories")).sum()]) - .sort("sugars_g", Default::default()); + .sort(["sugars_g"], Default::default()); assert_streaming_with_default(q, true, false); Ok(()) @@ -53,7 +53,7 @@ fn test_streaming_csv() -> PolarsResult<()> { .select([col("sugars_g"), col("calories")]) .group_by([col("sugars_g")]) .agg([col("calories").sum()]) - .sort("sugars_g", Default::default()); + .sort(["sugars_g"], Default::default()); assert_streaming_with_default(q, true, false); Ok(()) @@ -62,7 +62,7 @@ fn test_streaming_csv() -> PolarsResult<()> { #[test] fn test_streaming_glob() -> PolarsResult<()> { let q = get_csv_glob(); - let q = q.sort("sugars_g", Default::default()); + let q = q.sort(["sugars_g"], Default::default()); assert_streaming_with_default(q, true, false); Ok(()) @@ -102,9 +102,7 @@ fn test_streaming_multiple_keys_aggregate() -> PolarsResult<()> { ]) .sort_by_exprs( [col("sugars_g"), col("calories")], - [false, false], - false, - false, + SortMultipleOptions::default().with_order_descendings([false, false]), ); assert_streaming_with_default(q, true, false); @@ -122,7 +120,7 @@ fn test_streaming_first_sum() -> PolarsResult<()> { col("calories").sum(), col("calories").first().alias("calories_first"), ]) - .sort("sugars_g", Default::default()); + .sort(["sugars_g"], Default::default()); assert_streaming_with_default(q, true, false); Ok(()) @@ -135,7 +133,10 @@ fn test_streaming_unique() -> PolarsResult<()> { let q = q .select([col("sugars_g"), col("calories")]) .unique(None, Default::default()) - .sort_by_exprs([cols(["sugars_g", "calories"])], [false], false, false); + .sort_by_exprs( + [cols(["sugars_g", "calories"])], + SortMultipleOptions::default(), + ); assert_streaming_with_default(q, true, false); Ok(()) @@ -379,7 +380,12 @@ fn test_sort_maintain_order_streaming() -> PolarsResult<()> { .lazy(); let res = q - .sort_by_exprs([col("A")], [false], 
false, true) + .sort_by_exprs( + [col("A")], + SortMultipleOptions::default() + .with_nulls_last(true) + .with_maintain_order(true), + ) .slice(0, 3) .with_streaming(true) .collect()?; @@ -406,7 +412,7 @@ fn test_streaming_outer_join() -> PolarsResult<()> { let q = lf_left .outer_join(lf_right, col("a"), col("a")) - .sort_by_exprs([all()], [false], false, false); + .sort_by_exprs([all()], SortMultipleOptions::default()); // Toggle so that the join order is swapped. for toggle in [true, true] { diff --git a/crates/polars-lazy/src/tests/tpch.rs b/crates/polars-lazy/src/tests/tpch.rs index f6f4256df0aa..34447a0ef4f7 100644 --- a/crates/polars-lazy/src/tests/tpch.rs +++ b/crates/polars-lazy/src/tests/tpch.rs @@ -78,9 +78,9 @@ fn test_q2() -> PolarsResult<()> { ])]) .sort_by_exprs( [cols(["s_acctbal", "n_name", "s_name", "p_partkey"])], - [true, false, false, false], - false, - false, + SortMultipleOptions::default() + .with_order_descendings([true, false, false, false]) + .with_maintain_order(true), ) .limit(100) .with_comm_subplan_elim(true); diff --git a/crates/polars-ops/src/chunked_array/top_k.rs b/crates/polars-ops/src/chunked_array/top_k.rs index 68a8bf7d9a69..75f103c7c715 100644 --- a/crates/polars-ops/src/chunked_array/top_k.rs +++ b/crates/polars-ops/src/chunked_array/top_k.rs @@ -3,6 +3,7 @@ use std::cmp::Ordering; use arrow::array::{BooleanArray, MutableBooleanArray}; use arrow::bitmap::MutableBitmap; use either::Either; +use polars_core::chunked_array::ops::sort::arg_bottom_k::_arg_bottom_k; use polars_core::downcast_as_macro_arg_physical; use polars_core::prelude::*; use polars_utils::total_ord::TotalOrd; @@ -23,6 +24,23 @@ fn arg_partition Ordering>( } } +fn extract_target_and_k(s: &[Series]) -> PolarsResult<(usize, &Series)> { + let k_s = &s[1]; + + polars_ensure!( + k_s.len() == 1, + ComputeError: "`k` must be a single value for `top_k`." 
+ ); + + let Some(k) = k_s.cast(&IDX_DTYPE)?.idx()?.get(0) else { + polars_bail!(ComputeError: "`k` must be set for `top_k`") + }; + + let src = &s[0]; + + Ok((k as usize, src)) +} + fn top_k_num_impl(ca: &ChunkedArray, k: usize, descending: bool) -> ChunkedArray where T: PolarsNumericType, @@ -144,18 +162,7 @@ fn top_k_binary_impl( } pub fn top_k(s: &[Series], descending: bool) -> PolarsResult { - let k_s = &s[1]; - - polars_ensure!( - k_s.len() == 1, - ComputeError: "`k` must be a single value for `top_k`." - ); - - let Some(k) = k_s.cast(&IDX_DTYPE)?.idx()?.get(0) else { - polars_bail!(ComputeError: "`k` must be set for `top_k`") - }; - - let src = &s[0]; + let (k, src) = extract_target_and_k(s)?; if src.is_empty() { return Ok(src.clone()); @@ -164,12 +171,10 @@ pub fn top_k(s: &[Series], descending: bool) -> PolarsResult { match src.is_sorted_flag() { polars_core::series::IsSorted::Ascending => { // TopK is the k element in the bottom of ascending sorted array - return Ok(src - .slice((src.len() - k as usize) as i64, k as usize) - .reverse()); + return Ok(src.slice((src.len() - k) as i64, k).reverse()); }, polars_core::series::IsSorted::Descending => { - return Ok(src.slice(0, k as usize)); + return Ok(src.slice(0, k)); }, _ => {}, } @@ -179,24 +184,45 @@ pub fn top_k(s: &[Series], descending: bool) -> PolarsResult { let s = src.to_physical_repr(); match s.dtype() { - DataType::Boolean => { - Ok(top_k_bool_impl(s.bool().unwrap(), k as usize, descending).into_series()) - }, + DataType::Boolean => Ok(top_k_bool_impl(s.bool().unwrap(), k, descending).into_series()), DataType::String => { - let ca = top_k_binary_impl(&s.str().unwrap().as_binary(), k as usize, descending); + let ca = top_k_binary_impl(&s.str().unwrap().as_binary(), k, descending); let ca = unsafe { ca.to_string() }; Ok(ca.into_series()) }, - DataType::Binary => { - Ok(top_k_binary_impl(s.binary().unwrap(), k as usize, descending).into_series()) - }, + DataType::Binary => 
Ok(top_k_binary_impl(s.binary().unwrap(), k, descending).into_series()), _dt => { macro_rules! dispatch { ($ca:expr) => {{ - top_k_num_impl($ca, k as usize, descending).into_series() + top_k_num_impl($ca, k, descending).into_series() }}; } unsafe { downcast_as_macro_arg_physical!(&s, dispatch).cast_unchecked(origin_dtype) } }, } } + +pub fn top_k_by( + s: &[Series], + by: &[Series], + sort_options: SortMultipleOptions, +) -> PolarsResult { + let (k, src) = extract_target_and_k(s)?; + + if src.is_empty() { + return Ok(src.clone()); + } + + let multithreaded = sort_options.multithreaded; + + let idx = _arg_bottom_k(k, by, &mut sort_options.with_order_reversed())?; + + let result = unsafe { + if multithreaded { + src.take_unchecked_threaded(&idx.into_inner(), false) + } else { + src.take_unchecked(&idx.into_inner()) + } + }; + Ok(result) +} diff --git a/crates/polars-ops/src/series/ops/cut.rs b/crates/polars-ops/src/series/ops/cut.rs index 3b7d10dcb0a4..1e01ca08bada 100644 --- a/crates/polars-ops/src/series/ops/cut.rs +++ b/crates/polars-ops/src/series/ops/cut.rs @@ -103,7 +103,7 @@ pub fn qcut( include_breaks: bool, ) -> PolarsResult { let s = s.cast(&DataType::Float64)?; - let s2 = s.sort(false, false)?; + let s2 = s.sort(SortOptions::default())?; let ca = s2.f64()?; if ca.null_count() == ca.len() { diff --git a/crates/polars-ops/src/series/ops/various.rs b/crates/polars-ops/src/series/ops/various.rs index cad413816ced..b47b9e0bfc88 100644 --- a/crates/polars-ops/src/series/ops/various.rs +++ b/crates/polars-ops/src/series/ops/various.rs @@ -21,7 +21,12 @@ pub trait SeriesMethods: SeriesSealed { let cols = vec![values, counts.into_series()]; let df = unsafe { DataFrame::new_no_checks(cols) }; if sort { - df.sort(["count"], true, false) + df.sort( + ["count"], + SortMultipleOptions::default() + .with_order_descending(true) + .with_multithreaded(parallel), + ) } else { Ok(df) } diff --git a/crates/polars-pipe/src/executors/sinks/sort/sink.rs 
b/crates/polars-pipe/src/executors/sinks/sort/sink.rs index 01138240b6ce..2e19f27753b4 100644 --- a/crates/polars-pipe/src/executors/sinks/sort/sink.rs +++ b/crates/polars-pipe/src/executors/sinks/sort/sink.rs @@ -2,12 +2,12 @@ use std::any::Any; use std::sync::{Arc, RwLock}; use std::time::Instant; +use polars_core::chunked_array::ops::SortMultipleOptions; use polars_core::config::verbose; use polars_core::error::PolarsResult; use polars_core::frame::DataFrame; use polars_core::prelude::{AnyValue, SchemaRef, Series, SortOptions}; use polars_core::utils::accumulate_dataframes_vertical_unchecked; -use polars_plan::prelude::SortArguments; use crate::executors::sinks::io::{block_thread_until_io_thread_done, IOThread}; use crate::executors::sinks::memory::MemTracker; @@ -28,7 +28,8 @@ pub struct SortSink { io_thread: Arc>>, // location in the dataframe of the columns to sort by sort_idx: usize, - sort_args: SortArguments, + slice: Option<(i64, usize)>, + sort_options: SortMultipleOptions, // Statistics // sampled values so we can find the distribution. 
dist_sample: Vec>, @@ -41,7 +42,12 @@ pub struct SortSink { } impl SortSink { - pub(crate) fn new(sort_idx: usize, sort_args: SortArguments, schema: SchemaRef) -> Self { + pub(crate) fn new( + sort_idx: usize, + slice: Option<(i64, usize)>, + sort_options: SortMultipleOptions, + schema: SchemaRef, + ) -> Self { // for testing purposes let ooc = std::env::var(FORCE_OOC).is_ok(); let n_morsels_per_sink = morsels_per_sink(); @@ -53,7 +59,8 @@ impl SortSink { ooc, io_thread: Default::default(), sort_idx, - sort_args, + slice, + sort_options, dist_sample: vec![], current_chunk_rows: 0, current_chunks_size: 0, @@ -167,7 +174,8 @@ impl Sink for SortSink { ooc: self.ooc, io_thread: self.io_thread.clone(), sort_idx: self.sort_idx, - sort_args: self.sort_args.clone(), + slice: self.slice, + sort_options: self.sort_options.clone(), dist_sample: vec![], current_chunk_rows: 0, current_chunks_size: 0, @@ -183,12 +191,7 @@ impl Sink for SortSink { let io_thread = lock.take().unwrap(); let dist = Series::from_any_values("", &self.dist_sample, true).unwrap(); - let dist = dist.sort_with(SortOptions { - descending: self.sort_args.descending[0], - nulls_last: self.sort_args.nulls_last, - multithreaded: true, - maintain_order: self.sort_args.maintain_order, - })?; + let dist = dist.sort_with(SortOptions::from(&self.sort_options))?; let instant = self.ooc_start.unwrap(); if context.verbose { @@ -203,9 +206,9 @@ impl Sink for SortSink { io_thread, dist, self.sort_idx, - self.sort_args.descending[0], - self.sort_args.nulls_last, - self.sort_args.slice, + self.sort_options.descending[0], + self.sort_options.nulls_last, + self.slice, context.verbose, self.mem_track.clone(), instant, @@ -216,9 +219,8 @@ impl Sink for SortSink { let df = sort_accumulated( df, self.sort_idx, - self.sort_args.descending[0], - self.sort_args.slice, - self.sort_args.nulls_last, + self.slice, + SortOptions::from(&self.sort_options), )?; Ok(FinalizedSink::Finished(df)) } @@ -236,19 +238,15 @@ impl Sink for 
SortSink { pub(super) fn sort_accumulated( mut df: DataFrame, sort_idx: usize, - descending: bool, slice: Option<(i64, usize)>, - nulls_last: bool, + sort_options: SortOptions, ) -> PolarsResult { // This is needed because we can have empty blocks and we require chunks to have single chunks. df.as_single_chunk_par(); let sort_column = df.get_columns()[sort_idx].clone(); df.sort_impl( vec![sort_column], - vec![descending], - nulls_last, - false, + SortMultipleOptions::from(&sort_options), slice, - true, ) } diff --git a/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs b/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs index d31d6e77e3a8..b24439330b80 100644 --- a/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs +++ b/crates/polars-pipe/src/executors/sinks/sort/sink_multiple.rs @@ -5,7 +5,6 @@ use polars_core::prelude::sort::_broadcast_descending; use polars_core::prelude::sort::arg_sort_multiple::_get_rows_encoded_compat_array; use polars_core::prelude::*; use polars_core::series::IsSorted; -use polars_plan::prelude::*; use polars_row::decode::decode_rows_from_binary; use polars_row::EncodingField; @@ -15,12 +14,12 @@ use crate::operators::{ }; const POLARS_SORT_COLUMN: &str = "__POLARS_SORT_COLUMN"; -fn get_sort_fields(sort_idx: &[usize], sort_args: &SortArguments) -> Vec { - let mut descending = sort_args.descending.clone(); +fn get_sort_fields(sort_idx: &[usize], sort_options: &SortMultipleOptions) -> Vec { + let mut descending = sort_options.descending.clone(); _broadcast_descending(sort_idx.len(), &mut descending); descending .into_iter() - .map(|descending| EncodingField::new_sorted(descending, sort_args.nulls_last)) + .map(|descending| EncodingField::new_sorted(descending, sort_options.nulls_last)) .collect() } @@ -54,7 +53,7 @@ fn sort_by_idx(values: &[V], idx: &[usize]) -> Vec { fn finalize_dataframe( df: &mut DataFrame, sort_idx: &[usize], - sort_args: &SortArguments, + sort_options: &SortMultipleOptions, can_decode: 
bool, sort_dtypes: Option<&[ArrowDataType]>, rows: &mut Vec<&'static [u8]>, @@ -101,7 +100,7 @@ fn finalize_dataframe( } let first_sort_col = &mut cols[sort_idx[0]]; - let flag = if sort_args.descending[0] { + let flag = if sort_options.descending[0] { IsSorted::Descending } else { IsSorted::Ascending @@ -121,7 +120,8 @@ pub struct SortSinkMultiple { output_schema: SchemaRef, sort_idx: Arc<[usize]>, sort_sink: Box, - sort_args: SortArguments, + slice: Option<(i64, usize)>, + sort_options: SortMultipleOptions, // Needed for encoding sort_fields: Arc<[EncodingField]>, sort_dtypes: Option>, @@ -136,7 +136,8 @@ pub struct SortSinkMultiple { impl SortSinkMultiple { pub(crate) fn new( - sort_args: SortArguments, + slice: Option<(i64, usize)>, + sort_options: SortMultipleOptions, output_schema: SchemaRef, sort_idx: Vec, ) -> PolarsResult { @@ -161,25 +162,26 @@ impl SortSinkMultiple { sort_dtypes = Some(dtypes.into()); } schema.with_column(POLARS_SORT_COLUMN.into(), DataType::BinaryOffset); - let sort_fields = get_sort_fields(&sort_idx, &sort_args); + let sort_fields = get_sort_fields(&sort_idx, &sort_options); // don't set descending and nulls last as this // will be solved by the row encoding let sort_sink = Box::new(SortSink::new( // we will set the last column as sort column schema.len() - 1, - SortArguments { - descending: vec![false], - nulls_last: false, - slice: sort_args.slice, - maintain_order: false, - }, + slice, + sort_options + .clone() + .with_order_descending(false) + .with_nulls_last(false) + .with_maintain_order(false), Arc::new(schema), )); Ok(SortSinkMultiple { sort_sink, - sort_args, + slice, + sort_options, sort_idx: Arc::from(sort_idx), sort_fields: Arc::from(sort_fields), sort_dtypes, @@ -254,7 +256,8 @@ impl Sink for SortSinkMultiple { sort_idx: self.sort_idx.clone(), sort_sink, sort_fields: self.sort_fields.clone(), - sort_args: self.sort_args.clone(), + slice: self.slice, + sort_options: self.sort_options.clone(), sort_column: vec![], 
can_decode: self.can_decode, sort_dtypes: self.sort_dtypes.clone(), @@ -277,7 +280,7 @@ impl Sink for SortSinkMultiple { finalize_dataframe( &mut df, self.sort_idx.as_ref(), - &self.sort_args, + &self.sort_options, self.can_decode, sort_dtypes.as_deref(), &mut vec![], @@ -289,7 +292,7 @@ impl Sink for SortSinkMultiple { FinalizedSink::Source(source) => Ok(FinalizedSink::Source(Box::new(DropEncoded { source, sort_idx: self.sort_idx.clone(), - sort_args: std::mem::take(&mut self.sort_args), + sort_options: self.sort_options.clone(), can_decode: self.can_decode, sort_dtypes, rows: vec![], @@ -313,7 +316,7 @@ impl Sink for SortSinkMultiple { struct DropEncoded { source: Box, sort_idx: Arc<[usize]>, - sort_args: SortArguments, + sort_options: SortMultipleOptions, can_decode: bool, sort_dtypes: Option>, rows: Vec<&'static [u8]>, @@ -329,7 +332,7 @@ impl Source for DropEncoded { finalize_dataframe( &mut chunk.data, self.sort_idx.as_ref(), - &self.sort_args, + &self.sort_options, self.can_decode, self.sort_dtypes.as_deref(), &mut self.rows, diff --git a/crates/polars-pipe/src/executors/sinks/sort/source.rs b/crates/polars-pipe/src/executors/sinks/sort/source.rs index edc49130c636..4cf427c264a5 100644 --- a/crates/polars-pipe/src/executors/sinks/sort/source.rs +++ b/crates/polars-pipe/src/executors/sinks/sort/source.rs @@ -92,7 +92,17 @@ impl SortSource { let current_slice = self.slice; let mut df = match &mut self.slice { - None => sort_accumulated(df, self.sort_idx, self.descending, None, self.nulls_last), + None => sort_accumulated( + df, + self.sort_idx, + None, + SortOptions { + descending: self.descending, + nulls_last: self.nulls_last, + multithreaded: true, + maintain_order: false, + }, + ), Some((offset, len)) => { let df_len = df.height(); debug_assert!(*offset >= 0); @@ -103,9 +113,13 @@ impl SortSource { let out = sort_accumulated( df, self.sort_idx, - self.descending, current_slice, - self.nulls_last, + SortOptions { + descending: self.descending, + nulls_last: 
self.nulls_last, + multithreaded: true, + maintain_order: false, + }, ); *len = len.saturating_sub(df_len); *offset = 0; diff --git a/crates/polars-pipe/src/pipeline/convert.rs b/crates/polars-pipe/src/pipeline/convert.rs index 09bc3a73d0e1..5cd430989892 100644 --- a/crates/polars-pipe/src/pipeline/convert.rs +++ b/crates/polars-pipe/src/pipeline/convert.rs @@ -341,7 +341,8 @@ where Sort { input, by_column, - args, + slice, + sort_options, } => { let input_schema = lp_arena.get(*input).schema(lp_arena).into_owned(); @@ -351,7 +352,7 @@ where .unwrap(); let index = input_schema.try_index_of(by_column.as_ref())?; - let sort_sink = SortSink::new(index, args.clone(), input_schema); + let sort_sink = SortSink::new(index, *slice, sort_options.clone(), input_schema); Box::new(sort_sink) as Box } else { let sort_idx = by_column @@ -364,7 +365,8 @@ where }) .collect::>>()?; - let sort_sink = SortSinkMultiple::new(args.clone(), input_schema, sort_idx)?; + let sort_sink = + SortSinkMultiple::new(*slice, sort_options.clone(), input_schema, sort_idx)?; Box::new(sort_sink) as Box } }, diff --git a/crates/polars-plan/src/dsl/expr.rs b/crates/polars-plan/src/dsl/expr.rs index 5e8a31dd65f2..9f02bef27f92 100644 --- a/crates/polars-plan/src/dsl/expr.rs +++ b/crates/polars-plan/src/dsl/expr.rs @@ -94,7 +94,7 @@ pub enum Expr { SortBy { expr: Arc, by: Vec, - descending: Vec, + sort_options: SortMultipleOptions, }, Agg(AggExpr), /// A ternary operation @@ -234,11 +234,11 @@ impl Hash for Expr { Expr::SortBy { expr, by, - descending, + sort_options, } => { expr.hash(state); by.hash(state); - descending.hash(state); + sort_options.hash(state); }, Expr::Agg(input) => input.hash(state), Expr::Explode(input) => input.hash(state), diff --git a/crates/polars-plan/src/dsl/functions/index.rs b/crates/polars-plan/src/dsl/functions/index.rs index 20e7245d4021..d125ce571307 100644 --- a/crates/polars-plan/src/dsl/functions/index.rs +++ b/crates/polars-plan/src/dsl/functions/index.rs @@ -5,11 +5,11 
@@ use super::*; /// until duplicates are found. Once duplicates are found, the next `Series` will /// be used and so on. #[cfg(feature = "range")] -pub fn arg_sort_by>(by: E, descending: &[bool]) -> Expr { +pub fn arg_sort_by>(by: E, sort_options: SortMultipleOptions) -> Expr { let e = &by.as_ref()[0]; let name = expr_output_name(e).unwrap(); int_range(lit(0 as IdxSize), len().cast(IDX_DTYPE), 1, IDX_DTYPE) - .sort_by(by, descending) + .sort_by(by, sort_options) .alias(name.as_ref()) } diff --git a/crates/polars-plan/src/dsl/mod.rs b/crates/polars-plan/src/dsl/mod.rs index 3852999c9a89..406508bbdfa8 100644 --- a/crates/polars-plan/src/dsl/mod.rs +++ b/crates/polars-plan/src/dsl/mod.rs @@ -408,19 +408,36 @@ impl Expr { } } - /// Sort in increasing order. See [the eager implementation](Series::sort). - pub fn sort(self, descending: bool) -> Self { - Expr::Sort { - expr: Arc::new(self), - options: SortOptions { - descending, - ..Default::default() - }, - } - } - /// Sort with given options. - pub fn sort_with(self, options: SortOptions) -> Self { + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// # use polars_lazy::prelude::*; + /// # fn main() -> PolarsResult<()> { + /// let lf = df! { + /// "a" => [Some(5), Some(4), Some(3), Some(2), None] + /// }? + /// .lazy(); + /// + /// let sorted = lf + /// .select( + /// vec![col("a").sort(SortOptions::default())], + /// ) + /// .collect()?; + /// + /// assert_eq!( + /// sorted, + /// df! { + /// "a" => [None, Some(2), Some(3), Some(4), Some(5)] + /// }? + /// ); + /// # Ok(()) + /// # } + /// ``` + /// See [`SortOptions`] for more options. + pub fn sort(self, options: SortOptions) -> Self { Expr::Sort { expr: Arc::new(self), options, @@ -1071,19 +1088,42 @@ impl Expr { } } - /// Sort this column by the ordering of another column. + /// Sort this column by the ordering of another column evaluated from given expr. /// Can also be used in a group_by context to sort the groups. 
- pub fn sort_by, IE: Into + Clone, R: AsRef<[bool]>>( + /// + /// # Example + /// + /// ```rust + /// # use polars_core::prelude::*; + /// # use polars_lazy::prelude::*; + /// # fn main() -> PolarsResult<()> { + /// let lf = df! { + /// "a" => [1, 2, 3, 4, 5], + /// "b" => [5, 4, 3, 2, 1] + /// }?.lazy(); + /// + /// let sorted = lf + /// .select( + /// vec![col("a").sort_by(col("b"), SortOptions::default())], + /// ) + /// .collect()?; + /// + /// assert_eq!( + /// sorted, + /// df! { "a" => [5, 4, 3, 2, 1] }? + /// ); + /// # Ok(()) + /// # } + pub fn sort_by, IE: Into + Clone>( self, by: E, - descending: R, + sort_options: SortMultipleOptions, ) -> Expr { let by = by.as_ref().iter().map(|e| e.clone().into()).collect(); - let descending = descending.as_ref().to_vec(); Expr::SortBy { expr: Arc::new(self), by, - descending, + sort_options, } } diff --git a/crates/polars-plan/src/logical_plan/aexpr/hash.rs b/crates/polars-plan/src/logical_plan/aexpr/hash.rs index f9b6bcfcfb49..ec9e297faec6 100644 --- a/crates/polars-plan/src/logical_plan/aexpr/hash.rs +++ b/crates/polars-plan/src/logical_plan/aexpr/hash.rs @@ -25,7 +25,7 @@ impl Hash for AExpr { options.hash(state); }, AExpr::Agg(agg) => agg.hash(state), - AExpr::SortBy { descending, .. } => descending.hash(state), + AExpr::SortBy { sort_options, .. } => sort_options.hash(state), AExpr::Cast { strict, .. } => strict.hash(state), AExpr::Window { options, .. } => options.hash(state), AExpr::BinaryExpr { op, .. 
} => op.hash(state), diff --git a/crates/polars-plan/src/logical_plan/aexpr/mod.rs b/crates/polars-plan/src/logical_plan/aexpr/mod.rs index ae2245ce3721..ea61a8b25da7 100644 --- a/crates/polars-plan/src/logical_plan/aexpr/mod.rs +++ b/crates/polars-plan/src/logical_plan/aexpr/mod.rs @@ -148,7 +148,7 @@ pub enum AExpr { SortBy { expr: Node, by: Vec, - descending: Vec, + sort_options: SortMultipleOptions, }, Filter { input: Node, diff --git a/crates/polars-plan/src/logical_plan/alp/inputs.rs b/crates/polars-plan/src/logical_plan/alp/inputs.rs index 2fa6d0aaf967..497c26044c3b 100644 --- a/crates/polars-plan/src/logical_plan/alp/inputs.rs +++ b/crates/polars-plan/src/logical_plan/alp/inputs.rs @@ -69,11 +69,15 @@ impl IR { options: options.clone(), }, Sort { - by_column, args, .. + by_column, + slice, + sort_options, + .. } => Sort { input: inputs[0], by_column: by_column.clone(), - args: args.clone(), + slice: *slice, + sort_options: sort_options.clone(), }, Cache { id, cache_hits, .. } => Cache { input: inputs[0], diff --git a/crates/polars-plan/src/logical_plan/alp/mod.rs b/crates/polars-plan/src/logical_plan/alp/mod.rs index 8c2d1170c68e..0bb74eca267c 100644 --- a/crates/polars-plan/src/logical_plan/alp/mod.rs +++ b/crates/polars-plan/src/logical_plan/alp/mod.rs @@ -64,7 +64,8 @@ pub enum IR { Sort { input: Node, by_column: Vec, - args: SortArguments, + slice: Option<(i64, usize)>, + sort_options: SortMultipleOptions, }, Cache { input: Node, diff --git a/crates/polars-plan/src/logical_plan/builder.rs b/crates/polars-plan/src/logical_plan/builder.rs index d71d6d66c76e..7129be8534e5 100644 --- a/crates/polars-plan/src/logical_plan/builder.rs +++ b/crates/polars-plan/src/logical_plan/builder.rs @@ -815,24 +815,14 @@ impl LogicalPlanBuilder { .into() } - pub fn sort( - self, - by_column: Vec, - descending: Vec, - null_last: bool, - maintain_order: bool, - ) -> Self { + pub fn sort(self, by_column: Vec, sort_options: SortMultipleOptions) -> Self { let schema = 
try_delayed!(self.0.schema(), &self.0, into); let by_column = try_delayed!(rewrite_projections(by_column, &schema, &[]), &self.0, into); LogicalPlan::Sort { input: Arc::new(self.0), by_column, - args: SortArguments { - descending, - nulls_last: null_last, - slice: None, - maintain_order, - }, + slice: None, + sort_options, } .into() } diff --git a/crates/polars-plan/src/logical_plan/conversion.rs b/crates/polars-plan/src/logical_plan/conversion.rs index 40c938a0456d..4b1402459cc7 100644 --- a/crates/polars-plan/src/logical_plan/conversion.rs +++ b/crates/polars-plan/src/logical_plan/conversion.rs @@ -141,14 +141,14 @@ fn to_aexpr_impl(expr: Expr, arena: &mut Arena, state: &mut ConversionSta Expr::SortBy { expr, by, - descending, + sort_options, } => AExpr::SortBy { expr: to_aexpr_impl(owned(expr), arena, state), by: by .into_iter() .map(|e| to_aexpr_impl(e, arena, state)) .collect(), - descending, + sort_options, }, Expr::Filter { input, by } => AExpr::Filter { input: to_aexpr_impl(owned(input), arena, state), @@ -389,14 +389,16 @@ pub fn to_alp( LogicalPlan::Sort { input, by_column, - args, + slice, + sort_options, } => { let input = to_alp(owned(input), expr_arena, lp_arena)?; let by_column = to_expr_irs(by_column, expr_arena); IR::Sort { input, by_column, - args, + slice, + sort_options, } }, LogicalPlan::Cache { @@ -567,7 +569,7 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { AExpr::SortBy { expr, by, - descending, + sort_options, } => { let expr = node_to_expr(expr, expr_arena); let by = by @@ -577,7 +579,7 @@ pub fn node_to_expr(node: Node, expr_arena: &Arena) -> Expr { Expr::SortBy { expr: Arc::new(expr), by, - descending, + sort_options, } }, AExpr::Filter { input, by } => { @@ -856,14 +858,16 @@ impl IR { IR::Sort { input, by_column, - args, + slice, + sort_options, } => { let input = Arc::new(convert_to_lp(input, lp_arena)); let by_column = expr_irs_to_exprs(by_column, expr_arena); LogicalPlan::Sort { input, by_column, - args, + slice, + 
sort_options, } }, IR::Cache { diff --git a/crates/polars-plan/src/logical_plan/format.rs b/crates/polars-plan/src/logical_plan/format.rs index 098f3e8b4de7..14947cd70837 100644 --- a/crates/polars-plan/src/logical_plan/format.rs +++ b/crates/polars-plan/src/logical_plan/format.rs @@ -307,9 +307,12 @@ impl Debug for Expr { SortBy { expr, by, - descending, + sort_options, } => { - write!(f, "{expr:?}.sort_by(by={by:?}, descending={descending:?})",) + write!( + f, + "{expr:?}.sort_by(by={by:?}, sort_option={sort_options:?})", + ) }, Filter { input, by } => { write!(f, "{input:?}.filter({by:?})") diff --git a/crates/polars-plan/src/logical_plan/mod.rs b/crates/polars-plan/src/logical_plan/mod.rs index 3de72cc4de4f..e0ffddf24471 100644 --- a/crates/polars-plan/src/logical_plan/mod.rs +++ b/crates/polars-plan/src/logical_plan/mod.rs @@ -204,7 +204,8 @@ pub enum LogicalPlan { Sort { input: Arc, by_column: Vec, - args: SortArguments, + slice: Option<(i64, usize)>, + sort_options: SortMultipleOptions, }, /// Slice the table Slice { @@ -264,7 +265,7 @@ impl Clone for LogicalPlan { Self::Join { input_left, input_right, schema, left_on, right_on, options } => Self::Join { input_left: input_left.clone(), input_right: input_right.clone(), schema: schema.clone(), left_on: left_on.clone(), right_on: right_on.clone(), options: options.clone() }, Self::HStack { input, exprs, schema, options } => Self::HStack { input: input.clone(), exprs: exprs.clone(), schema: schema.clone(), options: options.clone() }, Self::Distinct { input, options } => Self::Distinct { input: input.clone(), options: options.clone() }, - Self::Sort { input, by_column, args } => Self::Sort { input: input.clone(), by_column: by_column.clone(), args: args.clone() }, + Self::Sort {input,by_column, slice, sort_options } => Self::Sort { input: input.clone(), by_column: by_column.clone(), slice: slice.clone(), sort_options: sort_options.clone() }, Self::Slice { input, offset, len } => Self::Slice { input: 
input.clone(), offset: offset.clone(), len: len.clone() }, Self::MapFunction { input, function } => Self::MapFunction { input: input.clone(), function: function.clone() }, Self::Union { inputs, options } => Self::Union { inputs: inputs.clone(), options: options.clone() }, diff --git a/crates/polars-plan/src/logical_plan/optimizer/collapse_and_project.rs b/crates/polars-plan/src/logical_plan/optimizer/collapse_and_project.rs index 88dc1e1dc575..1b7239ab6329 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/collapse_and_project.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/collapse_and_project.rs @@ -128,12 +128,14 @@ impl OptimizationRule for SimpleProjectionAndCollapse { Sort { input, by_column, - args, + slice, + sort_options, } => match lp_arena.get(*input) { Sort { input: inner, .. } => Some(Sort { input: *inner, by_column: by_column.clone(), - args: args.clone(), + slice: *slice, + sort_options: sort_options.clone(), }), _ => None, }, diff --git a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs index 7041346234e7..671f02c843d2 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/projection_pushdown/mod.rs @@ -451,7 +451,8 @@ impl ProjectionPushDown { Sort { input, by_column, - args, + slice, + sort_options, } => { if !acc_projections.is_empty() { // Make sure that the column(s) used for the sort is projected @@ -476,7 +477,8 @@ impl ProjectionPushDown { Ok(Sort { input, by_column, - args, + slice, + sort_options, }) }, Distinct { input, options } => { diff --git a/crates/polars-plan/src/logical_plan/optimizer/simplify_functions.rs b/crates/polars-plan/src/logical_plan/optimizer/simplify_functions.rs index 2dcc9847b9ac..47a32b5fb79a 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/simplify_functions.rs +++ 
b/crates/polars-plan/src/logical_plan/optimizer/simplify_functions.rs @@ -23,12 +23,17 @@ pub(super) fn optimize_functions( AExpr::SortBy { expr, by, - descending, - } => Some(AExpr::SortBy { - expr: *expr, - by: by.clone(), - descending: descending.iter().map(|r| !*r).collect(), - }), + sort_options, + } => { + let mut sort_options = sort_options.clone(); + let reversed_descending = sort_options.descending.iter().map(|x| !*x).collect(); + sort_options.descending = reversed_descending; + Some(AExpr::SortBy { + expr: *expr, + by: by.clone(), + sort_options, + }) + }, // TODO: add support for cum_sum and other operation that allow reversing. _ => None, } diff --git a/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs b/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs index 82750450e897..18b3c9d85631 100644 --- a/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs +++ b/crates/polars-plan/src/logical_plan/optimizer/slice_pushdown_lp.rs @@ -280,17 +280,19 @@ impl SlicePushDown { options, }) } - (Sort {input, by_column, mut args}, Some(state)) => { + (Sort {input, by_column, mut slice, + sort_options}, Some(state)) => { // first restart optimization in inputs and get the updated LP let input_lp = lp_arena.take(input); let input_lp = self.pushdown(input_lp, None, lp_arena, expr_arena)?; let input= lp_arena.add(input_lp); - args.slice = Some((state.offset, state.len as usize)); + slice = Some((state.offset, state.len as usize)); Ok(Sort { input, by_column, - args + slice, + sort_options }) } (Slice { diff --git a/crates/polars-plan/src/logical_plan/options.rs b/crates/polars-plan/src/logical_plan/options.rs index 76e35467822c..95c4e894e7f5 100644 --- a/crates/polars-plan/src/logical_plan/options.rs +++ b/crates/polars-plan/src/logical_plan/options.rs @@ -295,15 +295,6 @@ pub struct LogicalPlanUdfOptions { pub fmt_str: &'static str, } -#[derive(Clone, PartialEq, Eq, Debug, Default, Hash)] -#[cfg_attr(feature = "serde", 
derive(Serialize, Deserialize))] -pub struct SortArguments { - pub descending: Vec, - pub nulls_last: bool, - pub slice: Option<(i64, usize)>, - pub maintain_order: bool, -} - #[derive(Clone, PartialEq, Eq, Debug, Default)] #[cfg_attr(feature = "serde", derive(Serialize, Deserialize))] #[cfg(feature = "python")] diff --git a/crates/polars-plan/src/logical_plan/tree_format.rs b/crates/polars-plan/src/logical_plan/tree_format.rs index f62fd816dfef..9a12730dd61f 100644 --- a/crates/polars-plan/src/logical_plan/tree_format.rs +++ b/crates/polars-plan/src/logical_plan/tree_format.rs @@ -37,11 +37,16 @@ impl UpperExp for AExpr { ) }, AExpr::Gather { .. } => "gather", - AExpr::SortBy { descending, .. } => { + AExpr::SortBy { sort_options, .. } => { write!(f, "sort_by:")?; - for i in descending { + for i in &sort_options.descending { write!(f, "{}", *i as u8)?; } + write!( + f, + "{}{}", + sort_options.nulls_last as u8, sort_options.multithreaded as u8 + )?; return Ok(()); }, AExpr::Filter { .. 
} => "filter", diff --git a/crates/polars-plan/src/logical_plan/visitor/expr.rs b/crates/polars-plan/src/logical_plan/visitor/expr.rs index 404d18130cf3..8be5084d1314 100644 --- a/crates/polars-plan/src/logical_plan/visitor/expr.rs +++ b/crates/polars-plan/src/logical_plan/visitor/expr.rs @@ -51,7 +51,7 @@ impl TreeWalker for Expr { Cast { expr, data_type, strict } => Cast { expr: am(expr, f)?, data_type, strict }, Sort { expr, options } => Sort { expr: am(expr, f)?, options }, Gather { expr, idx, returns_scalar } => Gather { expr: am(expr, &mut f)?, idx: am(idx, f)?, returns_scalar }, - SortBy { expr, by, descending } => SortBy { expr: am(expr, &mut f)?, by: by.into_iter().map(f).collect::>()?, descending }, + SortBy { expr, by, sort_options } => SortBy { expr: am(expr, &mut f)?, by: by.into_iter().map(f).collect::>()?, sort_options }, Agg(agg_expr) => Agg(match agg_expr { Min { input, propagate_nans } => Min { input: am(input, f)?, propagate_nans }, Max { input, propagate_nans } => Max { input: am(input, f)?, propagate_nans }, @@ -179,7 +179,16 @@ impl AExpr { | (Len, Len) | (Slice { .. }, Slice { .. }) | (Explode(_), Explode(_)) => true, - (SortBy { descending: l, .. }, SortBy { descending: r, .. }) => l == r, + ( + SortBy { + sort_options: l_sort_options, + .. + }, + SortBy { + sort_options: r_sort_options, + .. 
+ }, + ) => l_sort_options == r_sort_options, (Agg(l), Agg(r)) => l.equal_nodes(r), ( Function { diff --git a/crates/polars-plan/src/logical_plan/visitor/hash.rs b/crates/polars-plan/src/logical_plan/visitor/hash.rs index 0c059f930475..c47ed12e3f11 100644 --- a/crates/polars-plan/src/logical_plan/visitor/hash.rs +++ b/crates/polars-plan/src/logical_plan/visitor/hash.rs @@ -118,10 +118,12 @@ impl Hash for HashableEqLP<'_> { IR::Sort { input: _, by_column, - args, + slice, + sort_options, } => { hash_exprs(by_column, self.expr_arena, state); - args.hash(state); + slice.hash(state); + sort_options.hash(state); }, IR::GroupBy { input: _, @@ -321,14 +323,19 @@ impl HashableEqLP<'_> { IR::Sort { input: _, by_column: cl, - args: al, + slice: l_slice, + sort_options: l_options, }, IR::Sort { input: _, by_column: cr, - args: ar, + slice: r_slice, + sort_options: r_options, }, - ) => al == ar && expr_irs_eq(cl, cr, self.expr_arena), + ) => { + (l_slice == r_slice && l_options == r_options) + && expr_irs_eq(cl, cr, self.expr_arena) + }, ( IR::GroupBy { input: _, diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs index c690042f18f5..a066bd91fd13 100644 --- a/crates/polars-sql/src/context.rs +++ b/crates/polars-sql/src/context.rs @@ -678,7 +678,12 @@ impl SQLContext { ); } - Ok(lf.sort_by_exprs(&by, descending, false, false)) + Ok(lf.sort_by_exprs( + &by, + SortMultipleOptions::default() + .with_order_descendings(descending) + .with_maintain_order(true), + )) } fn process_group_by( diff --git a/crates/polars-sql/src/functions.rs b/crates/polars-sql/src/functions.rs index cca0f7a85101..8ca60b2643d0 100644 --- a/crates/polars-sql/src/functions.rs +++ b/crates/polars-sql/src/functions.rs @@ -1,3 +1,4 @@ +use polars_core::chunked_array::ops::{SortMultipleOptions, SortOptions}; use polars_core::prelude::{polars_bail, polars_err, DataType, PolarsResult}; use polars_lazy::dsl::Expr; #[cfg(feature = "list_eval")] @@ -1137,7 +1138,15 @@ impl 
SQLFunctionVisitor<'_> { .collect::>>()? .into_iter() .unzip(); - self.visit_unary_no_window(|e| cumulative_f(e.sort_by(&order_by, &desc), false)) + self.visit_unary_no_window(|e| { + cumulative_f( + e.sort_by( + &order_by, + SortMultipleOptions::default().with_order_descendings(desc.clone()), + ), + false, + ) + }) } else { self.visit_unary(f) } @@ -1258,7 +1267,9 @@ impl SQLFunctionVisitor<'_> { .iter() .map(|o| { let e = parse_sql_expr(&o.expr, self.ctx)?; - Ok(o.asc.map_or(e.clone(), |b| e.sort(!b))) + Ok(o.asc.map_or(e.clone(), |b| { + e.sort(SortOptions::default().with_order_descending(!b)) + })) }) .collect::>>()?; expr.over(exprs) diff --git a/crates/polars-sql/src/sql_expr.rs b/crates/polars-sql/src/sql_expr.rs index 4d5c74982b1d..e5bc71f7858f 100644 --- a/crates/polars-sql/src/sql_expr.rs +++ b/crates/polars-sql/src/sql_expr.rs @@ -636,7 +636,10 @@ impl SQLExprVisitor<'_> { let mut base = self.visit_expr(&expr.expr)?; if let Some(order_by) = expr.order_by.as_ref() { let (order_by, descending) = self.visit_order_by(order_by)?; - base = base.sort_by(order_by, descending); + base = base.sort_by( + order_by, + SortMultipleOptions::default().with_order_descendings(descending), + ); } if let Some(limit) = &expr.limit { let limit = match self.visit_expr(limit)? 
{ diff --git a/crates/polars-sql/tests/functions_cumulative.rs b/crates/polars-sql/tests/functions_cumulative.rs index 861ad67b358f..41d0a8aed69d 100644 --- a/crates/polars-sql/tests/functions_cumulative.rs +++ b/crates/polars-sql/tests/functions_cumulative.rs @@ -30,7 +30,7 @@ fn create_expected(expr: Expr, sql: &str) -> (DataFrame, DataFrame) { let expected = df .clone() .select(&[expr.alias(alias)]) - .sort(alias, SortOptions::default()) + .sort([alias], Default::default()) .collect() .unwrap(); let mut ctx = SQLContext::new(); @@ -42,7 +42,9 @@ fn create_expected(expr: Expr, sql: &str) -> (DataFrame, DataFrame) { #[test] fn test_cumulative_sum() { - let expr = col("Sales").sort(true).cum_sum(false); + let expr = col("Sales") + .sort(SortOptions::default().with_order_descending(true)) + .cum_sum(false); let sql_expr = "SUM(Sales) OVER (ORDER BY Sales DESC)"; let (expected, actual) = create_expected(expr, sql_expr); @@ -52,7 +54,9 @@ fn test_cumulative_sum() { #[test] fn test_cumulative_min() { - let expr = col("Sales").sort(true).cum_min(false); + let expr = col("Sales") + .sort(SortOptions::default().with_order_descending(true)) + .cum_min(false); let sql_expr = "MIN(Sales) OVER (ORDER BY Sales DESC)"; let (expected, actual) = create_expected(expr, sql_expr); @@ -62,7 +66,9 @@ fn test_cumulative_min() { #[test] fn test_cumulative_max() { - let expr = col("Sales").sort(true).cum_max(false); + let expr = col("Sales") + .sort(SortOptions::default().with_order_descending(true)) + .cum_max(false); let sql_expr = "MAX(Sales) OVER (ORDER BY Sales DESC)"; let (expected, actual) = create_expected(expr, sql_expr); diff --git a/crates/polars-sql/tests/functions_string.rs b/crates/polars-sql/tests/functions_string.rs index 7e08e48c1e86..a1e56ea55134 100644 --- a/crates/polars-sql/tests/functions_string.rs +++ b/crates/polars-sql/tests/functions_string.rs @@ -118,7 +118,10 @@ fn test_array_to_string() { .group_by([col("b")]) .agg([col("a")]) .select(&[col("b"), 
col("a").list().join(lit(", "), true).alias("as")]) - .sort_by_exprs(vec![col("b"), col("as")], vec![false, false], false, true) + .sort_by_exprs( + vec![col("b"), col("as")], + SortMultipleOptions::default().with_maintain_order(true), + ) .collect() .unwrap(); diff --git a/crates/polars-sql/tests/iss_7437.rs b/crates/polars-sql/tests/iss_7437.rs index f8fd845e38d5..1db92a33d992 100644 --- a/crates/polars-sql/tests/iss_7437.rs +++ b/crates/polars-sql/tests/iss_7437.rs @@ -24,14 +24,14 @@ fn iss_7437() -> PolarsResult<()> { "#, )? .collect()? - .sort(["category"], vec![false], false)?; + .sort(["category"], SortMultipleOptions::default())?; let expected = LazyCsvReader::new("../../examples/datasets/foods1.csv") .finish()? .group_by(vec![col("category").alias("category")]) .agg(vec![]) .collect()? - .sort(["category"], vec![false], false)?; + .sort(["category"], Default::default())?; assert!(df_sql.equals(&expected)); Ok(()) diff --git a/crates/polars-sql/tests/iss_8395.rs b/crates/polars-sql/tests/iss_8395.rs index 06d74affb33e..24fcaa7de3b3 100644 --- a/crates/polars-sql/tests/iss_8395.rs +++ b/crates/polars-sql/tests/iss_8395.rs @@ -19,7 +19,7 @@ fn iss_8395() -> PolarsResult<()> { let df = res.collect()?; // assert that the df only contains [vegetables, seafood] - let s = df.column("category")?.unique()?.sort(false, false)?; + let s = df.column("category")?.unique()?.sort(Default::default())?; let expected = Series::new("category", &["seafood", "vegetables"]); assert!(s.equals(&expected)); Ok(()) diff --git a/crates/polars-sql/tests/iss_8419.rs b/crates/polars-sql/tests/iss_8419.rs index c3af9868d91f..d967eefbe487 100644 --- a/crates/polars-sql/tests/iss_8419.rs +++ b/crates/polars-sql/tests/iss_8419.rs @@ -18,11 +18,11 @@ fn iss_8419() { col("Country"), col("Sales"), col("Sales") - .sort(true) + .sort(SortOptions::default().with_order_descending(true)) .cum_sum(false) .alias("SalesCumulative"), ]) - .sort("SalesCumulative", SortOptions::default()) + 
.sort(["SalesCumulative"], Default::default()) .collect() .unwrap(); let mut ctx = SQLContext::new(); diff --git a/crates/polars-sql/tests/ops_distinct_on.rs b/crates/polars-sql/tests/ops_distinct_on.rs index d9016b24a9b0..77bc4652a9f8 100644 --- a/crates/polars-sql/tests/ops_distinct_on.rs +++ b/crates/polars-sql/tests/ops_distinct_on.rs @@ -1,3 +1,4 @@ +use polars_core::chunked_array::ops::SortMultipleOptions; use polars_core::df; use polars_lazy::prelude::*; use polars_sql::*; @@ -29,9 +30,9 @@ fn test_distinct_on() { let expected = df .sort_by_exprs( vec![col("Name"), col("Record Date")], - vec![false, true], - true, - false, + SortMultipleOptions::default() + .with_order_descendings([false, true]) + .with_maintain_order(true), ) .group_by_stable(vec![col("Name")]) .agg(vec![col("*").first()]); diff --git a/crates/polars-sql/tests/simple_exprs.rs b/crates/polars-sql/tests/simple_exprs.rs index 0dabdf1cbbff..a2d0a0a7a9db 100644 --- a/crates/polars-sql/tests/simple_exprs.rs +++ b/crates/polars-sql/tests/simple_exprs.rs @@ -68,14 +68,7 @@ fn test_group_by_simple() -> PolarsResult<()> { LIMIT 100 "#, )? 
- .sort( - "a", - SortOptions { - descending: false, - nulls_last: false, - ..Default::default() - }, - ) + .sort(["a"], Default::default()) .collect()?; let df_pl = df @@ -87,14 +80,7 @@ fn test_group_by_simple() -> PolarsResult<()> { col("a").count().alias("total_count"), ]) .limit(100) - .sort( - "a", - SortOptions { - descending: false, - nulls_last: false, - ..Default::default() - }, - ) + .sort(["a"], Default::default()) .collect()?; assert_eq!(df_sql, df_pl); Ok(()) @@ -417,7 +403,7 @@ fn test_arr_agg() { ( "SELECT ARRAY_AGG(a ORDER BY a) AS a FROM df", vec![col("a") - .sort_by(vec![col("a")], vec![false]) + .sort_by(vec![col("a")], SortMultipleOptions::default()) .implode() .alias("a")], ), @@ -432,7 +418,7 @@ fn test_arr_agg() { ( "SELECT ARRAY_AGG(a ORDER BY b LIMIT 2) FROM df", vec![col("a") - .sort_by(vec![col("b")], vec![false]) + .sort_by(vec![col("b")], SortMultipleOptions::default()) .head(Some(2)) .implode()], ), @@ -496,9 +482,7 @@ fn test_group_by_2() -> PolarsResult<()> { ]) .sort_by_exprs( vec![col("count"), col("category")], - vec![false, true], - false, - false, + SortMultipleOptions::default().with_order_descendings([false, true]), ) .limit(2); let expected = expected.collect()?; diff --git a/crates/polars/tests/it/lazy/aggregation.rs b/crates/polars/tests/it/lazy/aggregation.rs index eb8172e3af41..33662c442959 100644 --- a/crates/polars/tests/it/lazy/aggregation.rs +++ b/crates/polars/tests/it/lazy/aggregation.rs @@ -29,7 +29,7 @@ fn test_lazy_agg() { .quantile(lit(0.5), QuantileInterpolOptions::default()) .alias("median_rain"), ]) - .sort("date", Default::default()); + .sort(["date"], Default::default()); let new = lf.collect().unwrap(); let min = new.column("min").unwrap(); diff --git a/crates/polars/tests/it/lazy/expressions/apply.rs b/crates/polars/tests/it/lazy/expressions/apply.rs index e4996920460a..3c1cb8f46d83 100644 --- a/crates/polars/tests/it/lazy/expressions/apply.rs +++ b/crates/polars/tests/it/lazy/expressions/apply.rs @@ 
-52,7 +52,7 @@ fn test_groups_update_binary_shift_log() -> PolarsResult<()> { .lazy() .group_by([col("b")]) .agg([col("a") - col("a").shift(lit(1)).log(2.0)]) - .sort("b", Default::default()) + .sort(["b"], Default::default()) .explode([col("a")]) .collect()?; assert_eq!( diff --git a/crates/polars/tests/it/lazy/expressions/arity.rs b/crates/polars/tests/it/lazy/expressions/arity.rs index f51164b2862c..9e0acb248acc 100644 --- a/crates/polars/tests/it/lazy/expressions/arity.rs +++ b/crates/polars/tests/it/lazy/expressions/arity.rs @@ -218,7 +218,7 @@ fn test_when_then_otherwise_sum_in_agg() -> PolarsResult<()> { .agg([when(all().exclude(["groups"]).sum().eq(lit(1))) .then(all().exclude(["groups"]).sum()) .otherwise(lit(NULL))]) - .sort("groups", Default::default()); + .sort(["groups"], Default::default()); let expected = df![ "groups" => [1, 2], @@ -274,7 +274,7 @@ fn test_ternary_aggregation_set_literals() -> PolarsResult<()> { .agg([when(col("value").sum().eq(lit(3))) .then(col("value").rank(Default::default(), None)) .otherwise(lit(Series::new("", &[10 as IdxSize])))]) - .sort("name", Default::default()) + .sort(["name"], Default::default()) .collect()?; let out = out.column("value")?; @@ -294,7 +294,7 @@ fn test_ternary_aggregation_set_literals() -> PolarsResult<()> { .agg([when(col("value").sum().eq(lit(3))) .then(lit(Series::new("", &[10 as IdxSize])).alias("value")) .otherwise(col("value").rank(Default::default(), None))]) - .sort("name", Default::default()) + .sort(["name"], Default::default()) .collect()?; let out = out.column("value")?; @@ -314,7 +314,7 @@ fn test_ternary_aggregation_set_literals() -> PolarsResult<()> { .agg([when(col("value").sum().eq(lit(3))) .then(col("value").rank(Default::default(), None)) .otherwise(Null {}.lit())]) - .sort("name", Default::default()) + .sort(["name"], Default::default()) .collect()?; let out = out.column("value")?; @@ -328,7 +328,7 @@ fn test_ternary_aggregation_set_literals() -> PolarsResult<()> { 
.agg([when(col("value").sum().eq(lit(3))) .then(Null {}.lit().alias("value")) .otherwise(col("value").rank(Default::default(), None))]) - .sort("name", Default::default()) + .sort(["name"], Default::default()) .collect()?; let out = out.column("value")?; @@ -350,7 +350,7 @@ fn test_binary_group_consistency() -> PolarsResult<()> { let out = lf .group_by([col("category")]) .agg([col("name").filter(col("score").eq(col("score").max()))]) - .sort("category", Default::default()) + .sort(["category"], Default::default()) .collect()?; let out = out.column("name")?; diff --git a/crates/polars/tests/it/lazy/expressions/window.rs b/crates/polars/tests/it/lazy/expressions/window.rs index 32a46fe01929..4057e16bfbec 100644 --- a/crates/polars/tests/it/lazy/expressions/window.rs +++ b/crates/polars/tests/it/lazy/expressions/window.rs @@ -76,7 +76,7 @@ fn test_exploded_window_function() -> PolarsResult<()> { let out = df .clone() .lazy() - .sort("fruits", Default::default()) + .sort(["fruits"], Default::default()) .select([ col("fruits"), col("B") @@ -95,7 +95,7 @@ fn test_exploded_window_function() -> PolarsResult<()> { // we implicitly also test that a literal does not upcast a column let out = df .lazy() - .sort("fruits", Default::default()) + .sort(["fruits"], Default::default()) .select([ col("fruits"), col("B") @@ -119,7 +119,7 @@ fn test_reverse_in_groups() -> PolarsResult<()> { let out = df .lazy() - .sort("fruits", Default::default()) + .sort(["fruits"], Default::default()) .select([ col("B"), col("fruits"), @@ -140,12 +140,12 @@ fn test_sort_by_in_groups() -> PolarsResult<()> { let out = df .lazy() - .sort("cars", Default::default()) + .sort(["cars"], Default::default()) .select([ col("fruits"), col("cars"), col("A") - .sort_by([col("B")], [false]) + .sort_by([col("B")], SortMultipleOptions::default()) .implode() .over([col("cars")]) .explode() @@ -247,7 +247,7 @@ fn test_window_mapping() -> PolarsResult<()> { // now sorted // this will trigger a fast path - let df = 
df.sort(["fruits"], vec![false], false)?; + let df = df.sort(["fruits"], Default::default())?; let out = df .clone() diff --git a/crates/polars/tests/it/lazy/group_by.rs b/crates/polars/tests/it/lazy/group_by.rs index d8e20c804ca0..1ccb481d6ee0 100644 --- a/crates/polars/tests/it/lazy/group_by.rs +++ b/crates/polars/tests/it/lazy/group_by.rs @@ -21,10 +21,10 @@ fn test_filter_sort_diff_2984() -> PolarsResult<()> { .group_by([col("group")]) .agg([col("id") .filter(col("id").lt(lit(3))) - .sort(false) + .sort(Default::default()) .diff(1, Default::default()) .sum()]) - .sort("group", Default::default()) + .sort(["group"], Default::default()) .collect()?; assert_eq!(Vec::from(out.column("id")?.i32()?), &[Some(1), Some(0)]); @@ -72,7 +72,7 @@ fn test_filter_diff_arithmetic() -> PolarsResult<()> { .diff(1, Default::default()) * lit(2)) .alias("diff")]) - .sort("user", Default::default()) + .sort(["user"], Default::default()) .explode([col("diff")]) .collect()?; @@ -113,7 +113,7 @@ fn test_group_by_agg_list_with_not_aggregated() -> PolarsResult<()> { .agg([when(col("value").diff(1, NullBehavior::Ignore).gt_eq(0)) .then(col("value").diff(1, NullBehavior::Ignore)) .otherwise(col("value"))]) - .sort("group", Default::default()) + .sort(["group"], Default::default()) .collect()?; let out = out.column("value")?; @@ -140,7 +140,7 @@ fn test_logical_mean_partitioned_group_by_block() -> PolarsResult<()> { .with_column(col("duration").cast(DataType::Duration(TimeUnit::Microseconds))) .group_by([col("decimal")]) .agg([col("duration").mean()]) - .sort("duration", Default::default()) + .sort(["duration"], Default::default()) .collect()?; let duration = out.column("duration")?; @@ -167,7 +167,7 @@ fn test_filter_aggregated_expression() -> PolarsResult<()> { .lazy() .group_by([col("day")]) .agg([(col("x") - col("x").first()).filter(f)]) - .sort("day", Default::default()) + .sort(["day"], Default::default()) .collect() .unwrap(); let x = df.column("x")?; diff --git 
a/docs/src/rust/user-guide/concepts/contexts.rs b/docs/src/rust/user-guide/concepts/contexts.rs index ade4e61b5b67..1ff1114d4bf0 100644 --- a/docs/src/rust/user-guide/concepts/contexts.rs +++ b/docs/src/rust/user-guide/concepts/contexts.rs @@ -23,7 +23,7 @@ fn main() -> Result<(), Box> { .lazy() .select([ sum("nrs"), - col("names").sort(false), + col("names").sort(Default::default()), col("names").first().alias("first name"), (mean("nrs") * lit(10)).alias("10xnrs"), ]) diff --git a/docs/src/rust/user-guide/expressions/aggregation.rs b/docs/src/rust/user-guide/expressions/aggregation.rs index 532b89db9482..a0b6f7bf029d 100644 --- a/docs/src/rust/user-guide/expressions/aggregation.rs +++ b/docs/src/rust/user-guide/expressions/aggregation.rs @@ -49,12 +49,10 @@ fn main() -> Result<(), Box> { .group_by(["first_name"]) .agg([len(), col("gender"), col("last_name").first()]) .sort( - "len", - SortOptions { - descending: true, - nulls_last: true, - ..Default::default() - }, + ["len"], + SortMultipleOptions::default() + .with_order_descending(true) + .with_nulls_last(true), ) .limit(5) .collect()?; @@ -76,12 +74,8 @@ fn main() -> Result<(), Box> { .alias("pro"), ]) .sort( - "pro", - SortOptions { - descending: true, - nulls_last: false, - ..Default::default() - }, + ["pro"], + SortMultipleOptions::default().with_order_descending(true), ) .limit(5) .collect()?; @@ -101,12 +95,10 @@ fn main() -> Result<(), Box> { .or(col("party").eq(lit("Pro-Administration"))), ) .sort( - "count", - SortOptions { - descending: true, - nulls_last: true, - ..Default::default() - }, + ["count"], + SortMultipleOptions::default() + .with_order_descending(true) + .with_nulls_last(true), ) .limit(5) .collect()?; @@ -151,12 +143,10 @@ fn main() -> Result<(), Box> { .clone() .lazy() .sort( - "birthday", - SortOptions { - descending: true, - nulls_last: true, - ..Default::default() - }, + ["birthday"], + SortMultipleOptions::default() + .with_order_descending(true) + .with_nulls_last(true), ) 
.group_by(["state"]) .agg([ @@ -174,18 +164,19 @@ fn main() -> Result<(), Box> { .clone() .lazy() .sort( - "birthday", - SortOptions { - descending: true, - nulls_last: true, - ..Default::default() - }, + ["birthday"], + SortMultipleOptions::default() + .with_order_descending(true) + .with_nulls_last(true), ) .group_by(["state"]) .agg([ get_person().first().alias("youngest"), get_person().last().alias("oldest"), - get_person().sort(false).first().alias("alphabetical_first"), + get_person() + .sort(Default::default()) + .first() + .alias("alphabetical_first"), ]) .limit(5) .collect()?; @@ -198,24 +189,25 @@ fn main() -> Result<(), Box> { .clone() .lazy() .sort( - "birthday", - SortOptions { - descending: true, - nulls_last: true, - ..Default::default() - }, + ["birthday"], + SortMultipleOptions::default() + .with_order_descending(true) + .with_nulls_last(true), ) .group_by(["state"]) .agg([ get_person().first().alias("youngest"), get_person().last().alias("oldest"), - get_person().sort(false).first().alias("alphabetical_first"), + get_person() + .sort(Default::default()) + .first() + .alias("alphabetical_first"), col("gender") - .sort_by(["first_name"], [false]) + .sort_by(["first_name"], SortMultipleOptions::default()) .first() .alias("gender"), ]) - .sort("state", SortOptions::default()) + .sort(["state"], SortMultipleOptions::default()) .limit(5) .collect()?; diff --git a/docs/src/rust/user-guide/expressions/window.rs b/docs/src/rust/user-guide/expressions/window.rs index b64fd7e31406..b73e62b05490 100644 --- a/docs/src/rust/user-guide/expressions/window.rs +++ b/docs/src/rust/user-guide/expressions/window.rs @@ -54,7 +54,10 @@ fn main() -> Result<(), Box> { let out = filtered .lazy() .with_columns([cols(["Name", "Speed"]) - .sort_by(["Speed"], [true]) + .sort_by( + ["Speed"], + SortMultipleOptions::default().with_order_descending(true), + ) .over(["Type 1"])]) .collect()?; println!("{}", out); @@ -94,19 +97,25 @@ fn main() -> Result<(), Box> { .select([ 
col("Type 1").head(Some(3)).over(["Type 1"]).flatten(), col("Name") - .sort_by(["Speed"], [true]) + .sort_by( + ["Speed"], + SortMultipleOptions::default().with_order_descending(true), + ) .head(Some(3)) .over(["Type 1"]) .flatten() .alias("fastest/group"), col("Name") - .sort_by(["Attack"], [true]) + .sort_by( + ["Attack"], + SortMultipleOptions::default().with_order_descending(true), + ) .head(Some(3)) .over(["Type 1"]) .flatten() .alias("strongest/group"), col("Name") - .sort(false) + .sort(Default::default()) .head(Some(3)) .over(["Type 1"]) .flatten() diff --git a/docs/src/rust/user-guide/transformations/joins.rs b/docs/src/rust/user-guide/transformations/joins.rs index 80647431d0a1..5c0526bba90a 100644 --- a/docs/src/rust/user-guide/transformations/joins.rs +++ b/docs/src/rust/user-guide/transformations/joins.rs @@ -185,8 +185,18 @@ fn main() -> Result<(), Box> { // --8<-- [end:df8] // --8<-- [start:asofpre] - let df_trades = df_trades.sort(["time"], false, true).unwrap(); - let df_quotes = df_quotes.sort(["time"], false, true).unwrap(); + let df_trades = df_trades + .sort( + ["time"], + SortMultipleOptions::default().with_maintain_order(true), + ) + .unwrap(); + let df_quotes = df_quotes + .sort( + ["time"], + SortMultipleOptions::default().with_maintain_order(true), + ) + .unwrap(); // --8<-- [end:asofpre] // --8<-- [start:asof] diff --git a/docs/src/rust/user-guide/transformations/time-series/rolling.rs b/docs/src/rust/user-guide/transformations/time-series/rolling.rs index c9b7e58906cc..f8849ddabe41 100644 --- a/docs/src/rust/user-guide/transformations/time-series/rolling.rs +++ b/docs/src/rust/user-guide/transformations/time-series/rolling.rs @@ -11,7 +11,10 @@ fn main() -> Result<(), Box> { .with_try_parse_dates(true) .finish() .unwrap() - .sort(["Date"], false, true)?; + .sort( + ["Date"], + SortMultipleOptions::default().with_maintain_order(true), + )?; println!("{}", &df); // --8<-- [end:df] diff --git a/py-polars/polars/dataframe/frame.py 
b/py-polars/polars/dataframe/frame.py index 4c7bf0d9a6f6..fb2a40504808 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -4021,6 +4021,8 @@ def sort( *more_by: IntoExpr, descending: bool | Sequence[bool] = False, nulls_last: bool = False, + multithreaded: bool = True, + maintain_order: bool = False, ) -> DataFrame: """ Sort the dataframe by the given columns. @@ -4037,6 +4039,10 @@ def sort( per column by passing a sequence of booleans. nulls_last Place null values last. + multithreaded + Sort using multiple threads. + maintain_order + Whether the order should be maintained if elements are equal. Examples -------- @@ -4105,7 +4111,14 @@ def sort( """ return ( self.lazy() - .sort(by, *more_by, descending=descending, nulls_last=nulls_last) + .sort( + by, + *more_by, + descending=descending, + nulls_last=nulls_last, + multithreaded=multithreaded, + maintain_order=maintain_order, + ) .collect(_eager=True) ) diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index 75237213df47..357be09e5044 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -2264,6 +2264,9 @@ def sort_by( by: IntoExpr | Iterable[IntoExpr], *more_by: IntoExpr, descending: bool | Sequence[bool] = False, + nulls_last: bool = False, + multithreaded: bool = True, + maintain_order: bool = False, ) -> Self: """ Sort this column by the ordering of other columns. @@ -2281,6 +2284,12 @@ def sort_by( descending Sort in descending order. When sorting by multiple columns, can be specified per column by passing a sequence of booleans. + nulls_last + Place null values last. + multithreaded + Sort using multiple threads. + maintain_order + Whether the order should be maintained if elements are equal. 
Examples -------- @@ -2388,7 +2397,11 @@ def sort_by( elif len(by) != len(descending): msg = f"the length of `descending` ({len(descending)}) does not match the length of `by` ({len(by)})" raise ValueError(msg) - return self._from_pyexpr(self._pyexpr.sort_by(by, descending)) + return self._from_pyexpr( + self._pyexpr.sort_by( + by, descending, nulls_last, multithreaded, maintain_order + ) + ) def gather( self, indices: int | list[int] | Expr | Series | np.ndarray[Any, Any] diff --git a/py-polars/polars/functions/lazy.py b/py-polars/polars/functions/lazy.py index 9c135062dd6c..1a9059e00324 100644 --- a/py-polars/polars/functions/lazy.py +++ b/py-polars/polars/functions/lazy.py @@ -1537,6 +1537,9 @@ def arg_sort_by( exprs: IntoExpr | Iterable[IntoExpr], *more_exprs: IntoExpr, descending: bool | Sequence[bool] = False, + nulls_last: bool = False, + multithreaded: bool = True, + maintain_order: bool = False, ) -> Expr: """ Return the row indices that would sort the column(s). @@ -1551,6 +1554,12 @@ def arg_sort_by( descending Sort in descending order. When sorting by multiple columns, can be specified per column by passing a sequence of booleans. + nulls_last + Place null values last. + multithreaded + Sort using multiple threads. + maintain_order + Whether the order should be maintained if elements are equal. 
See Also -------- @@ -1619,7 +1628,9 @@ def arg_sort_by( elif len(exprs) != len(descending): msg = f"the length of `descending` ({len(descending)}) does not match the length of `exprs` ({len(exprs)})" raise ValueError(msg) - return wrap_expr(plr.arg_sort_by(exprs, descending)) + return wrap_expr( + plr.arg_sort_by(exprs, descending, nulls_last, multithreaded, maintain_order) + ) def collect_all( diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index bbdf3b605405..688ef8d40c54 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -1101,6 +1101,7 @@ def sort( descending: bool | Sequence[bool] = False, nulls_last: bool = False, maintain_order: bool = False, + multithreaded: bool = True, ) -> Self: """ Sort the LazyFrame by the given columns. @@ -1121,6 +1122,8 @@ def sort( Whether the order should be maintained if elements are equal. Note that if `true` streaming is not possible and performance might be worse since this requires a stable search. + multithreaded + Sort using multiple threads. 
Examples -------- @@ -1190,7 +1193,9 @@ def sort( # Fast path for sorting by a single existing column if isinstance(by, str) and not more_by: return self._from_pyldf( - self._ldf.sort(by, descending, nulls_last, maintain_order) + self._ldf.sort( + by, descending, nulls_last, maintain_order, multithreaded + ) ) by = parse_as_list_of_expressions(by, *more_by) @@ -1201,7 +1206,9 @@ def sort( msg = f"the length of `descending` ({len(descending)}) does not match the length of `by` ({len(by)})" raise ValueError(msg) return self._from_pyldf( - self._ldf.sort_by_exprs(by, descending, nulls_last, maintain_order) + self._ldf.sort_by_exprs( + by, descending, nulls_last, maintain_order, multithreaded + ) ) def top_k( @@ -1212,6 +1219,7 @@ def top_k( descending: bool | Sequence[bool] = False, nulls_last: bool = False, maintain_order: bool = False, + multithreaded: bool = True, ) -> Self: """ Return the `k` largest elements. @@ -1234,6 +1242,8 @@ def top_k( Whether the order should be maintained if elements are equal. Note that if `true` streaming is not possible and performance might be worse since this requires a stable search. + multithreaded + Sort using multiple threads. See Also -------- @@ -1285,7 +1295,9 @@ def top_k( msg = f"the length of `descending` ({len(descending)}) does not match the length of `by` ({len(by)})" raise ValueError(msg) return self._from_pyldf( - self._ldf.top_k(k, by, descending, nulls_last, maintain_order) + self._ldf.top_k( + k, by, descending, nulls_last, maintain_order, multithreaded + ) ) def bottom_k( @@ -1296,6 +1308,7 @@ def bottom_k( descending: bool | Sequence[bool] = False, nulls_last: bool = False, maintain_order: bool = False, + multithreaded: bool = True, ) -> Self: """ Return the `k` smallest elements. @@ -1318,6 +1331,8 @@ def bottom_k( Whether the order should be maintained if elements are equal. Note that if `true` streaming is not possible and performance might be worse since this requires a stable search. 
+ multithreaded + Sort using multiple threads. See Also -------- @@ -1366,7 +1381,9 @@ def bottom_k( if isinstance(descending, bool): descending = [descending] return self._from_pyldf( - self._ldf.bottom_k(k, by, descending, nulls_last, maintain_order) + self._ldf.bottom_k( + k, by, descending, nulls_last, maintain_order, multithreaded + ) ) def profile( diff --git a/py-polars/polars/series/array.py b/py-polars/polars/series/array.py index 7451f686ee92..90e48dd94afd 100644 --- a/py-polars/polars/series/array.py +++ b/py-polars/polars/series/array.py @@ -236,7 +236,13 @@ def all(self) -> Series: ] """ - def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Series: + def sort( + self, + *, + descending: bool = False, + nulls_last: bool = False, + multithreaded: bool = True, + ) -> Series: """ Sort the arrays in this column. @@ -246,6 +252,8 @@ def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Series: Sort in descending order. nulls_last Place null values last. + multithreaded + Sort using multiple threads. Examples -------- diff --git a/py-polars/polars/series/list.py b/py-polars/polars/series/list.py index 371d606e2935..762d6d55eccb 100644 --- a/py-polars/polars/series/list.py +++ b/py-polars/polars/series/list.py @@ -281,7 +281,13 @@ def var(self, ddof: int = 1) -> Series: ] """ - def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Series: + def sort( + self, + *, + descending: bool = False, + nulls_last: bool = False, + multithreaded: bool = True, + ) -> Series: """ Sort the arrays in this column. @@ -291,6 +297,8 @@ def sort(self, *, descending: bool = False, nulls_last: bool = False) -> Series: Sort in descending order. nulls_last Place null values last. + multithreaded + Sort using multiple threads. 
Examples -------- diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 31a5577b1d4a..91b6c8159cdf 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -3344,6 +3344,7 @@ def sort( *, descending: bool = False, nulls_last: bool = False, + multithreaded: bool = True, in_place: bool = False, ) -> Self: """ @@ -3355,6 +3356,8 @@ def sort( Sort in descending order. nulls_last Place null values last instead of first. + multithreaded + Sort using multiple threads. in_place Sort in-place. @@ -3381,10 +3384,12 @@ def sort( ] """ if in_place: - self._s = self._s.sort(descending, nulls_last) + self._s = self._s.sort(descending, nulls_last, multithreaded) return self else: - return self._from_pyseries(self._s.sort(descending, nulls_last)) + return self._from_pyseries( + self._s.sort(descending, nulls_last, multithreaded) + ) def top_k(self, k: int | IntoExprColumn = 5) -> Series: r""" diff --git a/py-polars/src/expr/general.rs b/py-polars/src/expr/general.rs index 8e21ab7eb16a..45eb9ef616e5 100644 --- a/py-polars/src/expr/general.rs +++ b/py-polars/src/expr/general.rs @@ -268,7 +268,7 @@ impl PyExpr { fn sort_with(&self, descending: bool, nulls_last: bool) -> Self { self.inner .clone() - .sort_with(SortOptions { + .sort(SortOptions { descending, nulls_last, multithreaded: true, @@ -332,9 +332,27 @@ impl PyExpr { self.inner.clone().get(idx.inner).into() } - fn sort_by(&self, by: Vec, descending: Vec) -> Self { + fn sort_by( + &self, + by: Vec, + descending: Vec, + nulls_last: bool, + multithreaded: bool, + maintain_order: bool, + ) -> Self { let by = by.into_iter().map(|e| e.inner).collect::>(); - self.inner.clone().sort_by(by, descending).into() + self.inner + .clone() + .sort_by( + by, + SortMultipleOptions { + descending, + nulls_last, + multithreaded, + maintain_order, + }, + ) + .into() } fn backward_fill(&self, limit: FillNullLimit) -> Self { diff --git a/py-polars/src/expr/list.rs 
b/py-polars/src/expr/list.rs index b00476c7bb3a..00b442cae1d3 100644 --- a/py-polars/src/expr/list.rs +++ b/py-polars/src/expr/list.rs @@ -132,11 +132,11 @@ impl PyExpr { self.inner .clone() .list() - .sort(SortOptions { - descending, - nulls_last, - ..Default::default() - }) + .sort( + SortOptions::default() + .with_order_descending(descending) + .with_nulls_last(nulls_last), + ) .into() } diff --git a/py-polars/src/functions/lazy.rs b/py-polars/src/functions/lazy.rs index 704658f9d8cc..1d0c4eaf4a73 100644 --- a/py-polars/src/functions/lazy.rs +++ b/py-polars/src/functions/lazy.rs @@ -57,11 +57,25 @@ pub fn rolling_cov( } #[pyfunction] -pub fn arg_sort_by(by: Vec, descending: Vec) -> PyExpr { +pub fn arg_sort_by( + by: Vec, + descending: Vec, + nulls_last: bool, + multithreaded: bool, + maintain_order: bool, +) -> PyExpr { let by = by.into_iter().map(|e| e.inner).collect::>(); - dsl::arg_sort_by(by, &descending).into() + dsl::arg_sort_by( + by, + SortMultipleOptions { + descending, + nulls_last, + multithreaded, + maintain_order, + }, + ) + .into() } - #[pyfunction] pub fn arg_where(condition: PyExpr) -> PyExpr { dsl::arg_where(condition.inner).into() diff --git a/py-polars/src/lazyframe/mod.rs b/py-polars/src/lazyframe/mod.rs index 3b0e9188a137..918f9c99c5f7 100644 --- a/py-polars/src/lazyframe/mod.rs +++ b/py-polars/src/lazyframe/mod.rs @@ -458,14 +458,15 @@ impl PyLazyFrame { descending: bool, nulls_last: bool, maintain_order: bool, + multithreaded: bool, ) -> Self { let ldf = self.ldf.clone(); ldf.sort( - by_column, - SortOptions { - descending, + [by_column], + SortMultipleOptions { + descending: vec![descending], nulls_last, - multithreaded: true, + multithreaded, maintain_order, }, ) @@ -478,11 +479,20 @@ impl PyLazyFrame { descending: Vec, nulls_last: bool, maintain_order: bool, + multithreaded: bool, ) -> Self { let ldf = self.ldf.clone(); let exprs = by.to_exprs(); - ldf.sort_by_exprs(exprs, descending, nulls_last, maintain_order) - .into() + 
ldf.sort_by_exprs( + exprs, + SortMultipleOptions { + descending, + nulls_last, + maintain_order, + multithreaded, + }, + ) + .into() } fn top_k( @@ -492,11 +502,21 @@ impl PyLazyFrame { descending: Vec, nulls_last: bool, maintain_order: bool, + multithreaded: bool, ) -> Self { let ldf = self.ldf.clone(); let exprs = by.to_exprs(); - ldf.top_k(k, exprs, descending, nulls_last, maintain_order) - .into() + ldf.top_k( + k, + exprs, + SortMultipleOptions { + descending, + nulls_last, + maintain_order, + multithreaded, + }, + ) + .into() } fn bottom_k( @@ -506,11 +526,21 @@ impl PyLazyFrame { descending: Vec, nulls_last: bool, maintain_order: bool, + multithreaded: bool, ) -> Self { let ldf = self.ldf.clone(); let exprs = by.to_exprs(); - ldf.bottom_k(k, exprs, descending, nulls_last, maintain_order) - .into() + ldf.bottom_k( + k, + exprs, + SortMultipleOptions { + descending, + nulls_last, + maintain_order, + multithreaded, + }, + ) + .into() } fn cache(&self) -> Self { diff --git a/py-polars/src/series/mod.rs b/py-polars/src/series/mod.rs index 7d33f8858152..773c39e3adfd 100644 --- a/py-polars/src/series/mod.rs +++ b/py-polars/src/series/mod.rs @@ -283,10 +283,15 @@ impl PySeries { } } - fn sort(&mut self, descending: bool, nulls_last: bool) -> PyResult { + fn sort(&mut self, descending: bool, nulls_last: bool, multithreaded: bool) -> PyResult { Ok(self .series - .sort(descending, nulls_last) + .sort( + SortOptions::default() + .with_order_descending(descending) + .with_nulls_last(nulls_last) + .with_multithreaded(multithreaded), + ) .map_err(PyPolarsErr::from)? 
.into()) } diff --git a/py-polars/tests/unit/operations/test_join.py b/py-polars/tests/unit/operations/test_join.py index ea8ec9e8e2fe..89c7ef6c917b 100644 --- a/py-polars/tests/unit/operations/test_join.py +++ b/py-polars/tests/unit/operations/test_join.py @@ -384,18 +384,30 @@ def test_jit_sort_joins() -> None: pd_result.columns = pd.Index(["a", "b", "b_right"]) # left key sorted right is not - pl_result = dfa_pl.join(dfb_pl, on="a", how=how).sort(["a", "b"]) + pl_result = dfa_pl.join(dfb_pl, on="a", how=how).sort( + ["a", "b"], maintain_order=True + ) - a = pl.from_pandas(pd_result).with_columns(pl.all().cast(int)).sort(["a", "b"]) + a = ( + pl.from_pandas(pd_result) + .with_columns(pl.all().cast(int)) + .sort(["a", "b"], maintain_order=True) + ) assert_frame_equal(a, pl_result) assert pl_result["a"].flags["SORTED_ASC"] # left key sorted right is not pd_result = dfb.merge(dfa, on="a", how=how) pd_result.columns = pd.Index(["a", "b", "b_right"]) - pl_result = dfb_pl.join(dfa_pl, on="a", how=how).sort(["a", "b"]) + pl_result = dfb_pl.join(dfa_pl, on="a", how=how).sort( + ["a", "b"], maintain_order=True + ) - a = pl.from_pandas(pd_result).with_columns(pl.all().cast(int)).sort(["a", "b"]) + a = ( + pl.from_pandas(pd_result) + .with_columns(pl.all().cast(int)) + .sort(["a", "b"], maintain_order=True) + ) assert_frame_equal(a, pl_result) assert pl_result["a"].flags["SORTED_ASC"] diff --git a/py-polars/tests/unit/operations/test_sort.py b/py-polars/tests/unit/operations/test_sort.py index 3fb6b0d7607f..a0cbd2b2a820 100644 --- a/py-polars/tests/unit/operations/test_sort.py +++ b/py-polars/tests/unit/operations/test_sort.py @@ -74,6 +74,39 @@ def test_sort_by() -> None: assert out["a"].to_list() == [1, 2, 3, 4, 5] +def test_expr_sort_by_nulls_last() -> None: + # nulls last + df = pl.DataFrame( + {"a": [1, 2, None, None, 5], "b": [None, 1, 1, 2, None], "c": [2, 3, 1, 2, 1]} + ) + + out = df.select(pl.all().sort_by("a", nulls_last=True, maintain_order=True)) + + expected =
pl.DataFrame( + { + "a": [1, 2, 5, None, None], + "b": [None, 1, None, 1, 2], + "c": [2, 3, 1, 1, 2], + } + ) + + assert_frame_equal(out, expected) + + # nulls first + + out = df.select(pl.all().sort_by("a", nulls_last=False, maintain_order=True)) + + expected = pl.DataFrame( + { + "a": [None, None, 1, 2, 5], + "b": [1, 2, None, 1, None], + "c": [1, 2, 2, 3, 1], + } + ) + + assert_frame_equal(out, expected) + + def test_sort_by_exprs() -> None: # make sure that the expression does not overwrite columns in the dataframe df = pl.DataFrame({"a": [1, 2, -1, -2]}) @@ -103,6 +136,22 @@ def test_arg_sort_nulls() -> None: ] +def test_expr_arg_sort_nulls_last() -> None: + df = pl.DataFrame( + {"a": [1, 2, None, None, 5], "b": [None, 1, 2, 1, None], "c": [2, 3, 1, 2, 1]} + ) + + out = ( + df.select(pl.arg_sort_by("a", "b", nulls_last=True, maintain_order=True)) + .to_series() + .to_list() + ) + + expected = [0, 1, 4, 3, 2] + + assert out == expected + + def test_arg_sort_window_functions() -> None: df = pl.DataFrame({"Id": [1, 1, 2, 2, 3, 3], "Age": [1, 2, 3, 4, 5, 6]}) out = df.select( diff --git a/py-polars/tests/unit/streaming/test_streaming_join.py b/py-polars/tests/unit/streaming/test_streaming_join.py index bc09443a045c..0dc613086ef4 100644 --- a/py-polars/tests/unit/streaming/test_streaming_join.py +++ b/py-polars/tests/unit/streaming/test_streaming_join.py @@ -67,11 +67,15 @@ def test_streaming_joins() -> None: pl_result = ( dfa_pl.lazy() .join(dfb_pl.lazy(), on="a", how=how) - .sort(["a", "b"]) + .sort(["a", "b"], maintain_order=True) .collect(streaming=True) ) - a = pl.from_pandas(pd_result).with_columns(pl.all().cast(int)).sort(["a", "b"]) + a = ( + pl.from_pandas(pd_result) + .with_columns(pl.all().cast(int)) + .sort(["a", "b"], maintain_order=True) + ) assert_frame_equal(a, pl_result, check_dtype=False) pd_result = dfa.merge(dfb, on=["a", "b"], how=how)