Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Implement string, boolean and binary dtype in top_k #15488

Merged
merged 13 commits into from
Apr 8, 2024
25 changes: 25 additions & 0 deletions crates/polars-arrow/src/array/boolean/mutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,31 @@ impl MutableBooleanArray {
}
}

/// Extends `MutableBooleanArray` by additional values of constant value.
#[inline]
pub fn extend_constant(&mut self, additional: usize, value: Option<bool>) {
match value {
Some(value) => {
self.values.extend_constant(additional, value);
if let Some(validity) = self.validity.as_mut() {
validity.extend_constant(additional, true);
}
},
None => {
self.values.extend_constant(additional, false);
if let Some(validity) = self.validity.as_mut() {
validity.extend_constant(additional, false)
} else {
self.init_validity();
self.validity
.as_mut()
.unwrap()
.extend_constant(additional, false)
};
},
};
}

fn init_validity(&mut self) {
let mut validity = MutableBitmap::with_capacity(self.values.capacity());
validity.extend_constant(self.len(), true);
Expand Down
176 changes: 151 additions & 25 deletions crates/polars-ops/src/chunked_array/top_k.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,29 @@
use std::cmp::Ordering;

use arrow::array::{BooleanArray, MutableBooleanArray};
use arrow::bitmap::MutableBitmap;
use either::Either;
use polars_core::downcast_as_macro_arg_physical;
use polars_core::prelude::*;
use polars_utils::total_ord::TotalOrd;

/// Partitions `v` around its `k`-th element (per `cmp`) and returns one sorted side.
///
/// * `descending == true`: returns the elements that order *before* index `k`
///   (i.e. the `k` smallest), sorted ascending by `cmp`.
/// * `descending == false`: returns the elements that order *after* index `k`
///   (i.e. the `v.len() - k - 1` largest), sorted descending by `cmp`.
///
/// The element at index `k` itself is never part of the returned slice, and `v` is
/// reordered in place.
///
/// # Panics
///
/// Panics if `k >= v.len()` (including when `v` is empty), as
/// `select_nth_unstable_by` does.
fn arg_partition<T, C: Fn(&T, &T) -> Ordering>(
    v: &mut [T],
    k: usize,
    descending: bool,
    cmp: C,
) -> &[T] {
    // Borrow the comparator here so it can still be used for the final sort.
    let (lower, _el, upper) = v.select_nth_unstable_by(k, &cmp);
    if descending {
        lower.sort_unstable_by(cmp);
        lower
    } else {
        // Swap the arguments to sort the upper side from largest to smallest.
        upper.sort_unstable_by(|a, b| cmp(b, a));
        upper
    }
}

fn top_k_impl<T>(ca: &ChunkedArray<T>, k: usize, descending: bool) -> ChunkedArray<T>
fn top_k_num_impl<T>(ca: &ChunkedArray<T>, k: usize, descending: bool) -> ChunkedArray<T>
where
T: PolarsNumericType,
ChunkedArray<T>: ChunkSort<T>,
Expand All @@ -32,46 +41,163 @@ where

match ca.to_vec_null_aware() {
Either::Left(mut v) => {
let values = arg_partition(&mut v, k, descending);
let values = arg_partition(&mut v, k, descending, TotalOrd::tot_cmp);
ChunkedArray::from_slice(ca.name(), values)
},
Either::Right(mut v) => {
let values = arg_partition(&mut v, k, descending);
let values = arg_partition(&mut v, k, descending, TotalOrd::tot_cmp);
let mut out = ChunkedArray::from_iter(values.iter().copied());
out.rename(ca.name());
out
},
}
}

/// Top-k / bottom-k for a boolean column, computed by counting instead of sorting.
///
/// With `descending == false` the output lists `true` before `false`; with
/// `descending == true` it lists `false` before `true`. When the column contains nulls the
/// emitted order is `null, true, false` for `descending == false` and `false, true, null`
/// otherwise.
///
/// The result never holds more than `ca.len()` values: `k` is clamped to the column length,
/// so asking for more values than exist returns every value once instead of padding the
/// output with values that are not in the input.
fn top_k_bool_impl(
    ca: &ChunkedArray<BooleanType>,
    k: usize,
    descending: bool,
) -> ChunkedArray<BooleanType> {
    // Clamp first: the arithmetic below assumes `k <= true_count + false_count (+ null_count)`,
    // and the clamp also keeps the up-front allocation bounded by the input size.
    let k = std::cmp::min(k, ca.len());

    if ca.null_count() == 0 {
        let true_count = ca.sum().unwrap() as usize;
        let mut bitmap = MutableBitmap::with_capacity(k);
        if !descending {
            // true first
            bitmap.extend_constant(std::cmp::min(k, true_count), true);
            bitmap.extend_constant(k.saturating_sub(true_count), false);
        } else {
            // false first
            let false_count = ca.len().saturating_sub(true_count);
            bitmap.extend_constant(std::cmp::min(k, false_count), false);
            bitmap.extend_constant(k.saturating_sub(false_count), true);
        }
        let arr = BooleanArray::from_data_default(bitmap.into(), None);
        // SAFETY: the only chunk is a `BooleanArray`, which matches `DataType::Boolean`.
        unsafe {
            ChunkedArray::from_chunks_and_dtype(ca.name(), vec![Box::new(arr)], DataType::Boolean)
        }
    } else {
        let null_count = ca.null_count();
        let true_count = ca.sum().unwrap() as usize;
        let false_count = ca.len() - true_count - null_count;
        let mut remaining = k;

        /// Appends up to `additional` copies of `value`, never exceeding the `remaining`
        /// budget, and reduces the budget accordingly.
        fn extend_constant_check_remaining(
            array: &mut MutableBooleanArray,
            remaining: &mut usize,
            additional: usize,
            value: Option<bool>,
        ) {
            array.extend_constant(std::cmp::min(additional, *remaining), value);
            *remaining = remaining.saturating_sub(additional);
        }

        let mut array = MutableBooleanArray::with_capacity(k);
        if !descending {
            // Null -> True -> False
            extend_constant_check_remaining(&mut array, &mut remaining, null_count, None);
            extend_constant_check_remaining(&mut array, &mut remaining, true_count, Some(true));
            extend_constant_check_remaining(&mut array, &mut remaining, false_count, Some(false));
        } else {
            // False -> True -> Null
            extend_constant_check_remaining(&mut array, &mut remaining, false_count, Some(false));
            extend_constant_check_remaining(&mut array, &mut remaining, true_count, Some(true));
            extend_constant_check_remaining(&mut array, &mut remaining, null_count, None);
        }
        let mut new_ca: ChunkedArray<BooleanType> = BooleanArray::from(array).into();
        new_ca.rename(ca.name());
        new_ca
    }
}

/// Top-k / bottom-k for a binary column.
///
/// When `k` covers the whole column the result is simply the column sorted with
/// `ca.sort(!descending)`. Otherwise the values are gathered into a vector, partitioned
/// around a pivot index with `arg_partition`, and the selected side is turned back into a
/// `ChunkedArray` that carries the name of `ca`.
fn top_k_binary_impl(
    ca: &ChunkedArray<BinaryType>,
    k: usize,
    descending: bool,
) -> ChunkedArray<BinaryType> {
    let len = ca.len();
    if k >= len {
        return ca.sort(!descending);
    }

    // The sort direction is the opposite of `descending`, because top-k returns the largest
    // values: the smallest `k` sit below index `k`, the largest `k` sit above `len - k - 1`.
    let pivot = match descending {
        true => std::cmp::min(k, len),
        false => len.saturating_sub(k + 1),
    };

    if ca.null_count() != 0 {
        // Nulls present: keep each slot as yielded by the chunk iterator.
        let mut slots = Vec::with_capacity(len);
        slots.extend(ca.downcast_iter().flat_map(|chunk| chunk.iter()));
        let picked = arg_partition(&mut slots, pivot, descending, TotalOrd::tot_cmp);
        let mut out: ChunkedArray<BinaryType> = ChunkedArray::from_iter(picked.iter().copied());
        out.rename(ca.name());
        return out;
    }

    // No nulls: work on the raw byte slices directly.
    let mut slots: Vec<&[u8]> = Vec::with_capacity(len);
    slots.extend(
        ca.downcast_iter()
            .flat_map(|chunk| chunk.non_null_values_iter()),
    );
    let picked = arg_partition(&mut slots, pivot, descending, TotalOrd::tot_cmp);
    ChunkedArray::from_slice(ca.name(), picked)
}

pub fn top_k(s: &[Series], descending: bool) -> PolarsResult<Series> {
let src = &s[0];
let k_s = &s[1];

polars_ensure!(
k_s.len() == 1,
ComputeError: "`k` must be a single value for `top_k`."
);

let Some(k) = k_s.cast(&IDX_DTYPE)?.idx()?.get(0) else {
polars_bail!(ComputeError: "`k` must be set for `top_k`")
};

let src = &s[0];

if src.is_empty() {
return Ok(src.clone());
}

polars_ensure!(
k_s.len() == 1,
ComputeError: "k must be a single value."
);
match src.is_sorted_flag() {
CanglongCl marked this conversation as resolved.
Show resolved Hide resolved
polars_core::series::IsSorted::Ascending => {
// TopK is the k element in the bottom of ascending sorted array
return Ok(src
.slice((src.len() - k as usize) as i64, k as usize)
.reverse());
},
polars_core::series::IsSorted::Descending => {
return Ok(src.slice(0, k as usize));
},
_ => {},
}

let k_s = k_s.cast(&IDX_DTYPE)?;
let k = k_s.idx()?;
let origin_dtype = src.dtype();

let dtype = src.dtype();
let s = src.to_physical_repr();

if let Some(k) = k.get(0) {
let s = src.to_physical_repr();
macro_rules! dispatch {
($ca:expr) => {{
top_k_impl($ca, k as usize, descending).into_series()
}};
}
match s.dtype() {
DataType::Boolean => {
Ok(top_k_bool_impl(s.bool().unwrap(), k as usize, descending).into_series())
},
DataType::String => {
top_k_binary_impl(&s.str().unwrap().as_binary(), k as usize, descending)
CanglongCl marked this conversation as resolved.
Show resolved Hide resolved
.into_series()
.cast(origin_dtype)
},
DataType::Binary => {
Ok(top_k_binary_impl(s.binary().unwrap(), k as usize, descending).into_series())
},
_dt => {
macro_rules! dispatch {
($ca:expr) => {{
top_k_num_impl($ca, k as usize, descending).into_series()
}};
}

downcast_as_macro_arg_physical!(&s, dispatch).cast(dtype)
} else {
Ok(Series::full_null(src.name(), src.len(), dtype))
downcast_as_macro_arg_physical!(&s, dispatch).cast(origin_dtype)
CanglongCl marked this conversation as resolved.
Show resolved Hide resolved
},
}
}
27 changes: 27 additions & 0 deletions py-polars/tests/unit/operations/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

import polars as pl
from polars.exceptions import ComputeError
from polars.testing import assert_frame_equal, assert_series_equal


Expand Down Expand Up @@ -288,6 +289,8 @@ def test_top_k() -> None:
{
"test": [2, 4, 1, 3],
"val": [2, 4, 9, 3],
"bool_val": [False, True, True, False],
"str_value": ["d", "b", "a", "c"],
}
)
assert_frame_equal(
Expand All @@ -303,6 +306,30 @@ def test_top_k() -> None:
pl.DataFrame({"top_k": [4, 3], "bottom_k": [1, 2]}),
)

assert_frame_equal(
df.select(
pl.col("bool_val").top_k(2).alias("top_k"),
pl.col("bool_val").bottom_k(2).alias("bottom_k"),
),
pl.DataFrame({"top_k": [True, True], "bottom_k": [False, False]}),
)

assert_frame_equal(
df.select(
pl.col("str_value").top_k(2).alias("top_k"),
pl.col("str_value").bottom_k(2).alias("bottom_k"),
),
pl.DataFrame({"top_k": ["d", "c"], "bottom_k": ["a", "b"]}),
)

with pytest.raises(ComputeError, match="`k` must be set for `top_k`"):
df.select(
pl.col("bool_val").top_k(pl.lit(None)),
)

with pytest.raises(ComputeError, match="`k` must be a single value for `top_k`."):
df.select(pl.col("test").top_k(pl.lit(pl.Series("s", [1, 2]))))

# dataframe
df = pl.DataFrame(
{
Expand Down
Loading