Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pref(rust!, python): Unify sort with SortOptions and SortMultipleOptions #15590

Merged
merged 40 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
9795849
move `by` out of sort multiple options
CanglongCl Apr 10, 2024
7d7f6a4
bring sort multi options to lazy frame
CanglongCl Apr 10, 2024
d674580
fix test build
CanglongCl Apr 10, 2024
6b1ce2b
fix test
CanglongCl Apr 10, 2024
9d2d4e5
delete unreachable branch
CanglongCl Apr 10, 2024
e4a5a73
impl maintain_order and multithreaded in sortby
CanglongCl Apr 10, 2024
7e285ff
impl topk mutlithreaded
CanglongCl Apr 10, 2024
615b2df
Revert "delete unreachable branch"
CanglongCl Apr 10, 2024
f9c7f0d
more fn use sort options
CanglongCl Apr 10, 2024
8307ee7
upcast sort options use reference
CanglongCl Apr 10, 2024
e7bc67a
impl top k by
CanglongCl Apr 10, 2024
7d1bfb7
fix docs
CanglongCl Apr 10, 2024
0b1ed53
add arg for python side
CanglongCl Apr 10, 2024
d54a147
impl for python
CanglongCl Apr 10, 2024
c05a473
add multi threaded param to sort
CanglongCl Apr 10, 2024
95ef562
maintain order for test
CanglongCl Apr 10, 2024
8b84400
deprecate sort arguement
CanglongCl Apr 10, 2024
8bf316a
improve options api
CanglongCl Apr 10, 2024
3459bca
impl test new api
CanglongCl Apr 10, 2024
22646f6
Fix maintain_order parameter in sort functions
CanglongCl Apr 10, 2024
d458664
refactor top k and bottom k
CanglongCl Apr 10, 2024
c13ba5e
fix: incorrect refactor
CanglongCl Apr 10, 2024
82ca572
unify fn sort in rust
CanglongCl Apr 11, 2024
6abe22f
impl in python side
CanglongCl Apr 11, 2024
0e33ca7
add test for expr sort by
CanglongCl Apr 11, 2024
b2d4b94
fix: expr sortby nulls last
CanglongCl Apr 11, 2024
7fda3ce
add comments for new parameters
CanglongCl Apr 11, 2024
d16f9c4
lint & format
CanglongCl Apr 11, 2024
e28a456
add docs and refactor func name api
CanglongCl Apr 11, 2024
e434515
add docs for lazy sort
CanglongCl Apr 11, 2024
dfe4c9e
format
CanglongCl Apr 11, 2024
86576b4
add docs for Expr::sort and LazyFrame::sort
CanglongCl Apr 11, 2024
ad88b57
add docs for Expr::sort_by
CanglongCl Apr 11, 2024
2f1746b
fix: merge upstream build
CanglongCl Apr 11, 2024
31a21a5
Fix typo
CanglongCl Apr 11, 2024
5ba316d
fix: rust doc build
CanglongCl Apr 11, 2024
4fd1c02
fix: docs nulls lasst
CanglongCl Apr 11, 2024
4e6474c
fix: list sort options
CanglongCl Apr 11, 2024
0447d8f
delete `from_n_rows` in `_arg_bottom_k`
CanglongCl Apr 12, 2024
31b6bbf
fix: merge expr equal
CanglongCl Apr 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ impl CategoricalChunked {
counts.rename("counts");
let cols = vec![values.into_series(), counts.into_series()];
let df = unsafe { DataFrame::new_no_checks(cols) };
df.sort(["counts"], true, false)
df.sort(
["counts"],
SortMultipleOptions::default().with_order_descending(true),
)
}
}
37 changes: 7 additions & 30 deletions crates/polars-core/src/chunked_array/ops/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub mod zip;

#[cfg(feature = "serde-lazy")]
use serde::{Deserialize, Serialize};
pub use sort::options::*;

use crate::series::IsSorted;

Expand Down Expand Up @@ -357,35 +358,6 @@ pub trait ChunkUnique<T: PolarsDataType> {
}
}

#[derive(Copy, Clone, Eq, PartialEq, Debug, Hash)]
#[cfg_attr(feature = "serde-lazy", derive(Serialize, Deserialize))]
pub struct SortOptions {
pub descending: bool,
pub nulls_last: bool,
pub multithreaded: bool,
pub maintain_order: bool,
}

#[derive(Clone)]
#[cfg_attr(feature = "serde-lazy", derive(Serialize, Deserialize))]
pub struct SortMultipleOptions {
pub other: Vec<Series>,
pub descending: Vec<bool>,
pub nulls_last: bool,
pub multithreaded: bool,
}

impl Default for SortOptions {
fn default() -> Self {
Self {
descending: false,
nulls_last: false,
multithreaded: true,
maintain_order: false,
}
}
}

/// Sort operations on `ChunkedArray`.
pub trait ChunkSort<T: PolarsDataType> {
#[allow(unused_variables)]
Expand All @@ -398,7 +370,12 @@ pub trait ChunkSort<T: PolarsDataType> {
fn arg_sort(&self, options: SortOptions) -> IdxCa;

/// Retrieve the indexes needed to sort this and the other arrays.
fn arg_sort_multiple(&self, _options: &SortMultipleOptions) -> PolarsResult<IdxCa> {
#[allow(unused_variables)]
fn arg_sort_multiple(
&self,
by: &[Series],
_options: &SortMultipleOptions,
) -> PolarsResult<IdxCa> {
polars_bail!(opq = arg_sort_multiple, T::get_dtype());
}
}
Expand Down
81 changes: 81 additions & 0 deletions crates/polars-core/src/chunked_array/ops/sort/arg_bottom_k.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
use polars_utils::iter::EnumerateIdxTrait;

use super::*;

#[derive(Eq)]
struct CompareRow<'a> {
idx: IdxSize,
bytes: &'a [u8],
}

impl PartialEq for CompareRow<'_> {
fn eq(&self, other: &Self) -> bool {
self.bytes == other.bytes
}
}

impl Ord for CompareRow<'_> {
fn cmp(&self, other: &Self) -> Ordering {
self.bytes.cmp(other.bytes)
}
}

impl PartialOrd for CompareRow<'_> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

/// Return the indices of the `k` rows that sort first under `sort_options`.
///
/// The `by_column` series are row-encoded together and the encoded rows are
/// compared bytewise, so one comparison covers every sort column.
///
/// # Arguments
/// * `k` - number of row indices to return. When `k` is at least the number of
///   encoded rows, the indices of all rows are returned in sorted order.
/// * `by_column` - the columns that define the ordering.
/// * `sort_options` - `descending` is handed mutably to `_broadcast_descending`
///   (presumably expanded to one flag per column -- confirm against that
///   helper); `nulls_last`, `multithreaded` and `maintain_order` are read.
///
/// # Errors
/// Propagates any error returned by `_get_rows_encoded`.
pub fn _arg_bottom_k(
    k: usize,
    by_column: &[Series],
    sort_options: &mut SortMultipleOptions,
) -> PolarsResult<NoNull<IdxCa>> {
    _broadcast_descending(by_column.len(), &mut sort_options.descending);
    let encoded = _get_rows_encoded(by_column, &sort_options.descending, sort_options.nulls_last)?;
    let arr = encoded.into_array();
    let mut rows = arr
        .values_iter()
        .enumerate_idx()
        .map(|(idx, bytes)| CompareRow { idx, bytes })
        .collect::<Vec<_>>();

    // Bound `k` by the rows that were actually materialised rather than by
    // `by_column[0].len()`: this no longer indexes into a possibly empty
    // `by_column`, and it is the length that `rows[..k]` and
    // `select_nth_unstable(k)` below are checked against.
    let n_rows = rows.len();

    let sorted = if k >= n_rows {
        // Everything is requested: a full sort is all that is needed.
        match (sort_options.multithreaded, sort_options.maintain_order) {
            (true, true) => POOL.install(|| {
                rows.par_sort();
            }),
            (true, false) => POOL.install(|| {
                rows.par_sort_unstable();
            }),
            (false, true) => rows.sort(),
            (false, false) => rows.sort_unstable(),
        }
        &rows
    } else if sort_options.maintain_order {
        // A stable result needs a stable full sort before truncating to `k`.
        // todo: maybe there is some more efficient method, comparable to select_nth_unstable
        if sort_options.multithreaded {
            POOL.install(|| {
                rows.par_sort();
            })
        } else {
            rows.sort();
        }
        &rows[..k]
    } else {
        // Partition so the `k` smallest rows come first, then sort only those.
        // `k < n_rows` holds here, so the index is in bounds.
        // todo: possible multi threaded `select_nth_unstable`?
        let (lower, _el, _upper) = rows.select_nth_unstable(k);
        if sort_options.multithreaded {
            POOL.install(|| {
                lower.par_sort_unstable();
            })
        } else {
            lower.sort_unstable();
        }
        &*lower
    };

    let idx: NoNull<IdxCa> = sorted.iter().map(|cmp_row| cmp_row.idx).collect();
    Ok(idx)
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,46 +24,57 @@ pub(crate) fn args_validate<T: PolarsDataType>(

pub(crate) fn arg_sort_multiple_impl<T: NullOrderCmp + Send + Copy>(
mut vals: Vec<(IdxSize, T)>,
by: &[Series],
options: &SortMultipleOptions,
) -> PolarsResult<IdxCa> {
let descending = &options.descending;
debug_assert_eq!(descending.len() - 1, options.other.len());
let compare_inner: Vec<_> = options
.other
debug_assert_eq!(descending.len() - 1, by.len());
let compare_inner: Vec<_> = by
.iter()
.map(|s| s.into_total_ord_inner())
.collect_trusted();

let first_descending = descending[0];
POOL.install(|| {
vals.par_sort_by(|tpl_a, tpl_b| {
match (
first_descending,
tpl_a
.1
.null_order_cmp(&tpl_b.1, options.nulls_last ^ first_descending),
) {
// if ordering is equal, we check the other arrays until we find a non-equal ordering
// if we have exhausted all arrays, we keep the equal ordering.
(_, Ordering::Equal) => {
let idx_a = tpl_a.0 as usize;
let idx_b = tpl_b.0 as usize;
unsafe {
ordering_other_columns(
&compare_inner,
descending.get_unchecked(1..),
options.nulls_last,
idx_a,
idx_b,
)
}
},
(true, Ordering::Less) => Ordering::Greater,
(true, Ordering::Greater) => Ordering::Less,
(_, ord) => ord,
}
});
});

let compare = |tpl_a: &(_, T), tpl_b: &(_, T)| -> Ordering {
match (
first_descending,
tpl_a
.1
.null_order_cmp(&tpl_b.1, options.nulls_last ^ first_descending),
) {
// if ordering is equal, we check the other arrays until we find a non-equal ordering
// if we have exhausted all arrays, we keep the equal ordering.
(_, Ordering::Equal) => {
let idx_a = tpl_a.0 as usize;
let idx_b = tpl_b.0 as usize;
unsafe {
ordering_other_columns(
&compare_inner,
descending.get_unchecked(1..),
options.nulls_last,
idx_a,
idx_b,
)
}
},
(true, Ordering::Less) => Ordering::Greater,
(true, Ordering::Greater) => Ordering::Less,
(_, ord) => ord,
}
};

match (options.multithreaded, options.maintain_order) {
(true, true) => POOL.install(|| {
vals.par_sort_by(compare);
}),
(true, false) => POOL.install(|| {
vals.par_sort_unstable_by(compare);
}),
(false, true) => vals.sort_by(compare),
(false, false) => vals.sort_unstable_by(compare),
}

let ca: NoNull<IdxCa> = vals.into_iter().map(|(idx, _v)| idx).collect_trusted();
// Don't set to sorted. Argsort indices are not sorted.
Ok(ca.into_inner())
Expand Down
22 changes: 16 additions & 6 deletions crates/polars-core/src/chunked_array/ops/sort/categorical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,13 @@ impl CategoricalChunked {

/// Retrieve the indexes needed to sort this and the other arrays.

pub(crate) fn arg_sort_multiple(&self, options: &SortMultipleOptions) -> PolarsResult<IdxCa> {
pub(crate) fn arg_sort_multiple(
&self,
by: &[Series],
options: &SortMultipleOptions,
) -> PolarsResult<IdxCa> {
if self.uses_lexical_ordering() {
args_validate(self.physical(), &options.other, &options.descending)?;
args_validate(self.physical(), by, &options.descending)?;
let mut count: IdxSize = 0;

// we use bytes to save a monomorphisized str impl
Expand All @@ -95,9 +99,9 @@ impl CategoricalChunked {
})
.collect_trusted();

arg_sort_multiple_impl(vals, options)
arg_sort_multiple_impl(vals, by, options)
} else {
self.physical().arg_sort_multiple(options)
self.physical().arg_sort_multiple(by, options)
}
}
}
Expand Down Expand Up @@ -171,12 +175,18 @@ mod test {
"vals" => [1, 1, 2, 2]
]?;

let out = df.sort(["cat", "vals"], vec![false, false], false)?;
let out = df.sort(
["cat", "vals"],
SortMultipleOptions::default().with_order_descendings([false, false]),
)?;
let out = out.column("cat")?;
let cat = out.categorical()?;
assert_order(cat, &["a", "a", "b", "c"]);

let out = df.sort(["vals", "cat"], vec![false, false], false)?;
let out = df.sort(
["vals", "cat"],
SortMultipleOptions::default().with_order_descendings([false, false]),
)?;
let out = out.column("cat")?;
let cat = out.categorical()?;
assert_order(cat, &["b", "c", "a", "a"]);
Expand Down
Loading
Loading