Skip to content

Commit

Permalink
feat: Implement string, boolean and binary dtype in top_k (#15488)
Browse files Browse the repository at this point in the history

Co-authored-by: Ritchie Vink <[email protected]>
  • Loading branch information
CanglongCl and ritchie46 authored Apr 8, 2024
1 parent 968c0c6 commit 9c73063
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 26 deletions.
25 changes: 25 additions & 0 deletions crates/polars-arrow/src/array/boolean/mutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,31 @@ impl MutableBooleanArray {
}
}

/// Extends `MutableBooleanArray` by additional values of constant value.
#[inline]
pub fn extend_constant(&mut self, additional: usize, value: Option<bool>) {
match value {
Some(value) => {
self.values.extend_constant(additional, value);
if let Some(validity) = self.validity.as_mut() {
validity.extend_constant(additional, true);
}
},
None => {
self.values.extend_constant(additional, false);
if let Some(validity) = self.validity.as_mut() {
validity.extend_constant(additional, false)
} else {
self.init_validity();
self.validity
.as_mut()
.unwrap()
.extend_constant(additional, false)
};
},
};
}

fn init_validity(&mut self) {
let mut validity = MutableBitmap::with_capacity(self.values.capacity());
validity.extend_constant(self.len(), true);
Expand Down
177 changes: 151 additions & 26 deletions crates/polars-ops/src/chunked_array/top_k.rs
Original file line number Diff line number Diff line change
@@ -1,20 +1,29 @@
use std::cmp::Ordering;

use arrow::array::{BooleanArray, MutableBooleanArray};
use arrow::bitmap::MutableBitmap;
use either::Either;
use polars_core::downcast_as_macro_arg_physical;
use polars_core::prelude::*;
use polars_utils::total_ord::TotalOrd;

/// Partitions `v` around the element with rank `k` (under `cmp`) and returns one sorted side.
///
/// * `descending == true`: returns the `k` smallest elements, sorted ascending by `cmp`.
/// * `descending == false`: returns the elements ranked strictly above position `k`,
///   sorted descending by `cmp`.
///
/// The element at rank `k` itself is never part of the result. `v` is reordered in place.
///
/// When `k >= v.len()` there is no pivot to select. In that case the whole slice is returned
/// sorted ascending for `descending == true`, and an empty slice for `descending == false`
/// (nothing ranks above a position past the end).
fn arg_partition<T, C: Fn(&T, &T) -> Ordering>(
    v: &mut [T],
    k: usize,
    descending: bool,
    cmp: C,
) -> &[T] {
    if k >= v.len() {
        // `select_nth_unstable_by` panics for an out-of-range index, so answer directly.
        return if descending {
            v.sort_unstable_by(cmp);
            v
        } else {
            &v[v.len()..]
        };
    }

    let (lower, _el, upper) = v.select_nth_unstable_by(k, &cmp);
    if descending {
        lower.sort_unstable_by(cmp);
        lower
    } else {
        // Swap the arguments to sort the upper side from largest to smallest.
        upper.sort_unstable_by(|a, b| cmp(b, a));
        upper
    }
}

fn top_k_impl<T>(ca: &ChunkedArray<T>, k: usize, descending: bool) -> ChunkedArray<T>
fn top_k_num_impl<T>(ca: &ChunkedArray<T>, k: usize, descending: bool) -> ChunkedArray<T>
where
T: PolarsNumericType,
ChunkedArray<T>: ChunkSort<T>,
Expand All @@ -32,46 +41,162 @@ where

match ca.to_vec_null_aware() {
Either::Left(mut v) => {
let values = arg_partition(&mut v, k, descending);
let values = arg_partition(&mut v, k, descending, TotalOrd::tot_cmp);
ChunkedArray::from_slice(ca.name(), values)
},
Either::Right(mut v) => {
let values = arg_partition(&mut v, k, descending);
let values = arg_partition(&mut v, k, descending, TotalOrd::tot_cmp);
let mut out = ChunkedArray::from_iter(values.iter().copied());
out.rename(ca.name());
out
},
}
}

pub fn top_k(s: &[Series], descending: bool) -> PolarsResult<Series> {
let src = &s[0];
let k_s = &s[1];
fn top_k_bool_impl(
ca: &ChunkedArray<BooleanType>,
k: usize,
descending: bool,
) -> ChunkedArray<BooleanType> {
if ca.null_count() == 0 {
let true_count = ca.sum().unwrap() as usize;
let mut bitmap = MutableBitmap::with_capacity(k);
if !descending {
// true first
bitmap.extend_constant(std::cmp::min(k, true_count), true);
bitmap.extend_constant(k.saturating_sub(true_count), false);
} else {
let false_count = ca.len().saturating_sub(true_count);
bitmap.extend_constant(std::cmp::min(k, false_count), false);
bitmap.extend_constant(k.saturating_sub(false_count), true);
}
let arr = BooleanArray::from_data_default(bitmap.into(), None);
unsafe {
ChunkedArray::from_chunks_and_dtype(ca.name(), vec![Box::new(arr)], DataType::Boolean)
}
} else {
let null_count = ca.null_count();
let true_count = ca.sum().unwrap() as usize;
let false_count = ca.len() - true_count - null_count;
let mut remaining = k;

if src.is_empty() {
return Ok(src.clone());
fn extend_constant_check_remaining(
array: &mut MutableBooleanArray,
remaining: &mut usize,
additional: usize,
value: Option<bool>,
) {
array.extend_constant(std::cmp::min(additional, *remaining), value);
*remaining = remaining.saturating_sub(additional);
}

let mut array = MutableBooleanArray::with_capacity(k);
if !descending {
// Null -> True -> False
extend_constant_check_remaining(&mut array, &mut remaining, null_count, None);
extend_constant_check_remaining(&mut array, &mut remaining, true_count, Some(true));
extend_constant_check_remaining(&mut array, &mut remaining, false_count, Some(false));
} else {
// False -> True -> Null
extend_constant_check_remaining(&mut array, &mut remaining, false_count, Some(false));
extend_constant_check_remaining(&mut array, &mut remaining, true_count, Some(true));
extend_constant_check_remaining(&mut array, &mut remaining, null_count, None);
}
let mut new_ca: ChunkedArray<BooleanType> = BooleanArray::from(array).into();
new_ca.rename(ca.name());
new_ca
}
}

/// Top-k / bottom-k for binary columns.
///
/// If `k` covers the whole column the result is simply the column sorted with the opposite
/// direction flag (top-k wants the largest values first). Otherwise all values are gathered
/// into one buffer and `arg_partition` selects and orders the requested side. Columns that
/// contain nulls are handled as `Option<&[u8]>` so that nulls take part in the ordering.
fn top_k_binary_impl(
    ca: &ChunkedArray<BinaryType>,
    k: usize,
    descending: bool,
) -> ChunkedArray<BinaryType> {
    let len = ca.len();
    if k >= len {
        return ca.sort(!descending);
    }

    // Pivot rank handed to `arg_partition`: for `descending` the lower `k` elements are
    // wanted, otherwise the `k` elements above rank `len - k - 1`.
    let pivot = if descending {
        k.min(len)
    } else {
        len.saturating_sub(k + 1)
    };

    if ca.null_count() != 0 {
        // Null-aware path: keep every slot, nulls included.
        let mut slots = Vec::with_capacity(len);
        for chunk in ca.downcast_iter() {
            slots.extend(chunk.iter());
        }
        let picked = arg_partition(&mut slots, pivot, descending, TotalOrd::tot_cmp);
        let mut result = ChunkedArray::from_iter(picked.iter().copied());
        result.rename(ca.name());
        return result;
    }

    // No nulls: work on the raw byte slices directly.
    let mut raw: Vec<&[u8]> = Vec::with_capacity(len);
    for chunk in ca.downcast_iter() {
        raw.extend(chunk.non_null_values_iter());
    }
    let picked = arg_partition(&mut raw, pivot, descending, TotalOrd::tot_cmp);
    ChunkedArray::from_slice(ca.name(), picked)
}

pub fn top_k(s: &[Series], descending: bool) -> PolarsResult<Series> {
let k_s = &s[1];

polars_ensure!(
k_s.len() == 1,
ComputeError: "k must be a single value."
ComputeError: "`k` must be a single value for `top_k`."
);

let k_s = k_s.cast(&IDX_DTYPE)?;
let k = k_s.idx()?;
let Some(k) = k_s.cast(&IDX_DTYPE)?.idx()?.get(0) else {
polars_bail!(ComputeError: "`k` must be set for `top_k`")
};

let dtype = src.dtype();
let src = &s[0];

if let Some(k) = k.get(0) {
let s = src.to_physical_repr();
macro_rules! dispatch {
($ca:expr) => {{
top_k_impl($ca, k as usize, descending).into_series()
}};
}
if src.is_empty() {
return Ok(src.clone());
}

downcast_as_macro_arg_physical!(&s, dispatch).cast(dtype)
} else {
Ok(Series::full_null(src.name(), src.len(), dtype))
match src.is_sorted_flag() {
polars_core::series::IsSorted::Ascending => {
// TopK is the k element in the bottom of ascending sorted array
return Ok(src
.slice((src.len() - k as usize) as i64, k as usize)
.reverse());
},
polars_core::series::IsSorted::Descending => {
return Ok(src.slice(0, k as usize));
},
_ => {},
}

let origin_dtype = src.dtype();

let s = src.to_physical_repr();

match s.dtype() {
DataType::Boolean => {
Ok(top_k_bool_impl(s.bool().unwrap(), k as usize, descending).into_series())
},
DataType::String => {
let ca = top_k_binary_impl(&s.str().unwrap().as_binary(), k as usize, descending);
let ca = unsafe { ca.to_string() };
Ok(ca.into_series())
},
DataType::Binary => {
Ok(top_k_binary_impl(s.binary().unwrap(), k as usize, descending).into_series())
},
_dt => {
macro_rules! dispatch {
($ca:expr) => {{
top_k_num_impl($ca, k as usize, descending).into_series()
}};
}
unsafe { downcast_as_macro_arg_physical!(&s, dispatch).cast_unchecked(origin_dtype) }
},
}
}
27 changes: 27 additions & 0 deletions py-polars/tests/unit/operations/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import pytest

import polars as pl
from polars.exceptions import ComputeError
from polars.testing import assert_frame_equal, assert_series_equal


Expand Down Expand Up @@ -288,6 +289,8 @@ def test_top_k() -> None:
{
"test": [2, 4, 1, 3],
"val": [2, 4, 9, 3],
"bool_val": [False, True, True, False],
"str_value": ["d", "b", "a", "c"],
}
)
assert_frame_equal(
Expand All @@ -303,6 +306,30 @@ def test_top_k() -> None:
pl.DataFrame({"top_k": [4, 3], "bottom_k": [1, 2]}),
)

assert_frame_equal(
df.select(
pl.col("bool_val").top_k(2).alias("top_k"),
pl.col("bool_val").bottom_k(2).alias("bottom_k"),
),
pl.DataFrame({"top_k": [True, True], "bottom_k": [False, False]}),
)

assert_frame_equal(
df.select(
pl.col("str_value").top_k(2).alias("top_k"),
pl.col("str_value").bottom_k(2).alias("bottom_k"),
),
pl.DataFrame({"top_k": ["d", "c"], "bottom_k": ["a", "b"]}),
)

with pytest.raises(ComputeError, match="`k` must be set for `top_k`"):
df.select(
pl.col("bool_val").top_k(pl.lit(None)),
)

with pytest.raises(ComputeError, match="`k` must be a single value for `top_k`."):
df.select(pl.col("test").top_k(pl.lit(pl.Series("s", [1, 2]))))

# dataframe
df = pl.DataFrame(
{
Expand Down

0 comments on commit 9c73063

Please sign in to comment.