From 104ee93b64e61876aa24676ed20f2bec070338c4 Mon Sep 17 00:00:00 2001
From: Orson Peters
Date: Wed, 11 Oct 2023 12:12:49 +0200
Subject: [PATCH] refactor: improve rank implementation, especially around nulls (#11651)

---
 crates/polars-ops/src/series/ops/rank.rs | 352 ++++++++---------------
 py-polars/tests/unit/test_exprs.py       |   5 +
 2 files changed, 119 insertions(+), 238 deletions(-)

diff --git a/crates/polars-ops/src/series/ops/rank.rs b/crates/polars-ops/src/series/ops/rank.rs
index 888c00a61922..41f9b4ca8eb9 100644
--- a/crates/polars-ops/src/series/ops/rank.rs
+++ b/crates/polars-ops/src/series/ops/rank.rs
@@ -1,4 +1,5 @@
-use polars_arrow::prelude::FromData;
+use arrow::array::BooleanArray;
+use arrow::compute::concatenate::concatenate_validities;
 use polars_core::prelude::*;
 #[cfg(feature = "random")]
 use rand::prelude::SliceRandom;
@@ -42,8 +43,29 @@ fn get_random_seed() -> u64 {
     rng.next_u64()
 }
 
+unsafe fn rank_impl<F: FnMut(&mut [IdxSize])>(idxs: &IdxCa, neq: &BooleanArray, mut flush_ties: F) {
+    let mut ties_indices = Vec::with_capacity(128);
+    let mut idx_it = idxs.downcast_iter().flat_map(|arr| arr.values_iter());
+    let Some(first_idx) = idx_it.next() else {
+        return;
+    };
+    ties_indices.push(*first_idx);
+
+    for (eq_idx, idx) in idx_it.enumerate() {
+        if neq.value_unchecked(eq_idx) {
+            flush_ties(&mut ties_indices);
+            ties_indices.clear();
+        }
+
+        ties_indices.push(*idx);
+    }
+    flush_ties(&mut ties_indices);
+}
+
 fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option<u64>) -> Series {
-    match s.len() {
+    let len = s.len();
+    let null_count = s.null_count();
+    match len {
         1 => {
             return match method {
                 Average => Series::new(s.name(), &[1.0f64]),
@@ -59,251 +81,105 @@ fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option<u64>) -> Series {
         _ => {},
     }
 
-    if s.null_count() > 0 {
-        let nulls = s.is_not_null().rechunk();
-        let arr = nulls.downcast_iter().next().unwrap();
-        let validity = arr.values();
-        // Currently, nulls tie with the minimum or maximum bound for a type, depending on descending.
-        // TODO: Need to expose nulls_last in arg_sort to prevent this.
-        // Fill using MaxBound/MinBound to give nulls the last rank;
-        // we will replace them later.
-        let null_strategy = if descending {
-            FillNullStrategy::MinBound
-        } else {
-            FillNullStrategy::MaxBound
+    if null_count == len {
+        return match method {
+            Average => Float64Chunked::full_null(s.name(), len).into_series(),
+            _ => IdxCa::full_null(s.name(), len).into_series(),
         };
-        let s = s.fill_null(null_strategy).unwrap();
-
-        let mut out = rank(&s, method, descending, seed);
-        unsafe {
-            let arr = &mut out.chunks_mut()[0];
-            *arr = arr.with_validity(Some(validity.clone()))
-        }
-        return out;
     }
 
-    // See: https://github.com/scipy/scipy/blob/v1.7.1/scipy/stats/stats.py#L8631-L8737
+    let sort_idx_ca = s
+        .arg_sort(SortOptions {
+            descending,
+            nulls_last: true,
+            ..Default::default()
+        })
+        .slice(0, len - null_count);
 
-    let len = s.len();
-    let null_count = s.null_count();
-    let sort_idx_ca = s.arg_sort(SortOptions {
-        descending,
-        ..Default::default()
-    });
-    let sort_idx = sort_idx_ca.downcast_iter().next().unwrap().values();
-
-    let mut inv: Vec<IdxSize> = Vec::with_capacity(len);
-    // Safety:
-    // Values will be filled next and there is only primitive data
-    #[allow(clippy::uninit_vec)]
-    unsafe {
-        inv.set_len(len)
-    }
-    let inv_values = inv.as_mut_slice();
-
-    #[cfg(feature = "random")]
-    let mut count = if let RankMethod::Ordinal | RankMethod::Random = method {
-        1 as IdxSize
-    } else {
-        0
-    };
-
-    #[cfg(not(feature = "random"))]
-    let mut count = if let RankMethod::Ordinal = method {
-        1 as IdxSize
-    } else {
-        0
-    };
-
-    // Safety:
-    // we are in bounds
-    unsafe {
-        sort_idx.iter().for_each(|&i| {
-            *inv_values.get_unchecked_mut(i as usize) = count;
-            count += 1;
-        });
-    }
+    let chunk_refs: Vec<_> = s.chunks().iter().map(|c| &**c).collect();
+    let validity = concatenate_validities(&chunk_refs);
 
     use RankMethod::*;
-    match method {
-        Ordinal => {
-            let inv_ca = IdxCa::from_vec(s.name(), inv);
-            inv_ca.into_series()
-        },
-        #[cfg(feature = "random")]
-        Random => {
-            // Safety:
-            // in bounds
-            let arr = unsafe { s.take_unchecked(&sort_idx_ca) };
-            let not_consecutive_same = arr
-                .slice(1, len - 1)
-                .not_equal(&arr.slice(0, len - 1))
-                .unwrap()
-                .rechunk();
-            let obs = not_consecutive_same.downcast_iter().next().unwrap();
-
-            // Collect slice indices for sort_idx which point to ties in the original series.
-            let mut ties_indices = Vec::with_capacity(len + 1);
-            let mut ties_index: usize = 0;
-
-            ties_indices.push(ties_index);
-            obs.iter().for_each(|b| {
-                if let Some(b) = b {
-                    ties_index += 1;
-                    if b {
-                        ties_indices.push(ties_index)
-                    }
-                }
-            });
-            // Close the last slice (if there were nulls in the original series, they will always be in the last slice).
-            ties_indices.push(len);
-
-            let mut sort_idx = sort_idx.to_vec();
-
-            let rng = &mut SmallRng::seed_from_u64(seed.unwrap_or_else(get_random_seed));
-
-            // Shuffle sort_idx positions which point to ties in the original series.
-            for i in 0..(ties_indices.len() - 1) {
-                let ties_index_start = ties_indices[i];
-                let ties_index_end = ties_indices[i + 1];
-                if ties_index_end - ties_index_start > 1 {
-                    sort_idx[ties_index_start..ties_index_end].shuffle(rng);
-                }
+    if let Ordinal = method {
+        let mut out = vec![0 as IdxSize; s.len()];
+        let mut rank = 0;
+        for arr in sort_idx_ca.downcast_iter() {
+            for i in arr.values_iter() {
+                out[*i as usize] = rank + 1;
+                rank += 1;
             }
-
-            // Recreate inv_ca (where ties are randomly shuffled compared with Ordinal).
-            let mut count = 1 as IdxSize;
-            unsafe {
-                sort_idx.iter().for_each(|&i| {
-                    *inv_values.get_unchecked_mut(i as usize) = count;
-                    count += 1;
+        }
+        IdxCa::new_from_owned_with_null_bitmap(s.name(), out, validity).into_series()
+    } else {
+        let sorted_values = unsafe { s.take_unchecked(&sort_idx_ca) };
+        let not_consecutive_same = sorted_values
+            .slice(1, sorted_values.len() - 1)
+            .not_equal(&sorted_values.slice(0, sorted_values.len() - 1))
+            .unwrap()
+            .rechunk();
+        let neq = not_consecutive_same.downcast_iter().next().unwrap();
+
+        let mut rank = 1;
+        match method {
+            #[cfg(feature = "random")]
+            Random => unsafe {
+                let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_random_seed));
+                let mut out = vec![0 as IdxSize; s.len()];
+                rank_impl(&sort_idx_ca, neq, |ties| {
+                    ties.shuffle(&mut rng);
+                    for i in ties {
+                        *out.get_unchecked_mut(*i as usize) = rank;
+                        rank += 1;
+                    }
                 });
-            }
-
-            let inv_ca = IdxCa::from_vec(s.name(), inv);
-            inv_ca.into_series()
-        },
-        _ => {
-            let inv_ca = IdxCa::from_vec(s.name(), inv);
-            // SAFETY: in bounds.
-            let arr = unsafe { s.take_unchecked(&sort_idx_ca) };
-            let validity = arr.chunks()[0].validity().cloned();
-            let not_consecutive_same = arr
-                .slice(1, len - 1)
-                .not_equal(&arr.slice(0, len - 1))
-                .unwrap()
-                .rechunk();
-            // This obs is shorter than that of scipy stats, because we can just start the cumsum at 1
-            // instead of 0.
-            let obs = not_consecutive_same.downcast_iter().next().unwrap();
-            let mut dense = Vec::with_capacity(len);
-
-            // This saves an offset on the whole column, which scipy does in:
-            //
-            // ```python
-            // if method == 'min':
-            //     return count[dense - 1] + 1
-            // ```
-            #[allow(clippy::bool_to_int_with_if)]
-            let mut cumsum: IdxSize = if let RankMethod::Min = method {
-                0
-            } else {
-                // Nulls will be the first rank, but we will replace them (with null);
-                // this ensures the second rank will be 1.
-                if matches!(method, RankMethod::Dense) && s.null_count() > 0 {
-                    0
-                } else {
-                    1
-                }
-            };
-
-            dense.push(cumsum);
-            obs.values_iter().for_each(|b| {
-                if b {
-                    cumsum += 1;
-                }
-                dense.push(cumsum)
-            });
-            let arr = IdxArr::from_data_default(dense.into(), validity);
-            let dense = IdxCa::with_chunk(s.name(), arr);
-
-            // SAFETY: in bounds.
-            let dense = unsafe { dense.take_unchecked(&inv_ca) };
-
-            if let RankMethod::Dense = method {
-                return if s.null_count() == 0 {
-                    dense.into_series()
-                } else {
-                    // Null will be the first rank. We restore the original nulls and shift all ranks by one.
-                    let validity = s.is_null().rechunk();
-                    let validity = validity.downcast_iter().next().unwrap();
-                    let validity = validity.values().clone();
-
-                    let arr = &dense.chunks()[0];
-                    let arr = arr.with_validity(Some(validity));
-                    let dtype = arr.data_type().clone();
-
-                    // SAFETY: given dtype is correct.
-                    unsafe {
-                        Series::_try_from_arrow_unchecked(s.name(), vec![arr], &dtype).unwrap()
+                IdxCa::new_from_owned_with_null_bitmap(s.name(), out, validity).into_series()
+            },
+            Average => unsafe {
+                let mut out = vec![0.0; s.len()];
+                rank_impl(&sort_idx_ca, neq, |ties| {
+                    let first = rank;
+                    rank += ties.len() as IdxSize;
+                    let last = rank - 1;
+                    let avg = 0.5 * (first as f64 + last as f64);
+                    for i in ties {
+                        *out.get_unchecked_mut(*i as usize) = avg;
                     }
-                };
-            }
-
-            let bitmap = obs.values();
-            let cap = bitmap.len() - bitmap.unset_bits();
-            let mut count = Vec::with_capacity(cap + 1);
-            let mut cnt: IdxSize = 0;
-            count.push(cnt);
-
-            if null_count > 0 {
-                obs.iter().for_each(|b| {
-                    if let Some(b) = b {
-                        cnt += 1;
-                        if b {
-                            count.push(cnt)
-                        }
+                });
+                Float64Chunked::new_from_owned_with_null_bitmap(s.name(), out, validity)
+                    .into_series()
+            },
+            Min => unsafe {
+                let mut out = vec![0 as IdxSize; s.len()];
+                rank_impl(&sort_idx_ca, neq, |ties| {
+                    for i in ties.iter() {
+                        *out.get_unchecked_mut(*i as usize) = rank;
                     }
+                    rank += ties.len() as IdxSize;
                 });
-            } else {
-                obs.values_iter().for_each(|b| {
-                    cnt += 1;
-                    if b {
-                        count.push(cnt)
+                IdxCa::new_from_owned_with_null_bitmap(s.name(), out, validity).into_series()
+            },
+            Max => unsafe {
+                let mut out = vec![0 as IdxSize; s.len()];
+                rank_impl(&sort_idx_ca, neq, |ties| {
+                    rank += ties.len() as IdxSize;
+                    for i in ties {
+                        *out.get_unchecked_mut(*i as usize) = rank - 1;
                     }
                 });
-            }
-
-            count.push((len - null_count) as IdxSize);
-            let count = IdxCa::from_vec(s.name(), count);
-
-            match method {
-                Max => {
-                    // SAFETY: in bounds.
-                    unsafe { count.take_unchecked(&dense).into_series() }
-                },
-                Min => {
-                    // SAFETY: in bounds.
-                    unsafe { (count.take_unchecked(&dense) + 1).into_series() }
-                },
-                Average => {
-                    // SAFETY: in bounds.
-                    let a = unsafe { count.take_unchecked(&dense) }
-                        .cast(&DataType::Float64)
-                        .unwrap();
-                    let b = unsafe { count.take_unchecked(&(dense - 1)) }
-                        .cast(&DataType::Float64)
-                        .unwrap()
-                        + 1.0;
-                    (&a + &b) * 0.5
-                },
-                #[cfg(feature = "random")]
-                Dense | Ordinal | Random => unimplemented!(),
-                #[cfg(not(feature = "random"))]
-                Dense | Ordinal => unimplemented!(),
-            }
-        },
+                IdxCa::new_from_owned_with_null_bitmap(s.name(), out, validity).into_series()
+            },
+            Dense => unsafe {
+                let mut out = vec![0 as IdxSize; s.len()];
+                rank_impl(&sort_idx_ca, neq, |ties| {
+                    for i in ties {
+                        *out.get_unchecked_mut(*i as usize) = rank;
+                    }
+                    rank += 1;
+                });
+                IdxCa::new_from_owned_with_null_bitmap(s.name(), out, validity).into_series()
+            },
+            Ordinal => unreachable!(),
+        }
     }
 }
 
@@ -429,14 +305,14 @@ mod test {
     let s = UInt32Chunked::new("", &[None, None, None]).into_series();
     let out = rank(&s, RankMethod::Average, false, None)
         .f64()?
-        .into_no_null_iter()
+        .into_iter()
         .collect::<Vec<_>>();
-    assert_eq!(out, &[2.0f64, 2.0, 2.0]);
+    assert_eq!(out, &[None, None, None]);
     let out = rank(&s, RankMethod::Dense, false, None)
         .idx()?
-        .into_no_null_iter()
+        .into_iter()
         .collect::<Vec<_>>();
-    assert_eq!(out, &[1, 1, 1]);
+    assert_eq!(out, &[None, None, None]);
 
     Ok(())
 }
diff --git a/py-polars/tests/unit/test_exprs.py b/py-polars/tests/unit/test_exprs.py
index 9bc7ceeaf0e6..8518755eb4f1 100644
--- a/py-polars/tests/unit/test_exprs.py
+++ b/py-polars/tests/unit/test_exprs.py
@@ -380,6 +380,11 @@ def test_rank_so_4109() -> None:
     }
 
 
+def test_rank_string_null_11252() -> None:
+    rank = pl.Series([None, "", "z", None, "a"]).rank()
+    assert rank.to_list() == [None, 1.0, 3.0, None, 2.0]
+
+
 def test_unique_empty() -> None:
    for dt in [pl.Utf8, pl.Boolean, pl.Int32, pl.UInt32]:
        s = pl.Series([], dtype=dt)
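
For reference, the tie-flushing pattern that rank_impl applies per group of equal values can be sketched on plain slices, outside Polars. This is a minimal standalone illustration of the "min" method only; the function names (rank_min_on_sorted, flush) and the toy inputs are illustrative and are not part of this patch.

// Standalone sketch of the tie-grouping idea: walk indices in sorted order,
// collect a group of tied positions, then "flush" the group by writing one
// rank value for all members before advancing the running rank.
fn rank_min_on_sorted(sort_idx: &[usize], neq: &[bool]) -> Vec<u64> {
    let mut out = vec![0u64; sort_idx.len()];
    let mut rank = 1u64;
    let mut ties: Vec<usize> = Vec::new();

    // Every member of a tie group gets the group's first rank ("min" method),
    // then the running rank is advanced past the whole group.
    fn flush(ties: &mut Vec<usize>, rank: &mut u64, out: &mut [u64]) {
        for &i in ties.iter() {
            out[i] = *rank;
        }
        *rank += ties.len() as u64;
        ties.clear();
    }

    for (pos, &i) in sort_idx.iter().enumerate() {
        // neq[pos - 1] is true when sorted element pos differs from element pos - 1,
        // mirroring the not_consecutive_same mask in the patch.
        if pos > 0 && neq[pos - 1] {
            flush(&mut ties, &mut rank, &mut out);
        }
        ties.push(i);
    }
    flush(&mut ties, &mut rank, &mut out);
    out
}

fn main() {
    // values = [30, 10, 10, 20] -> ascending sort visits indices 1, 2, 3, 0
    let sort_idx = [1usize, 2, 3, 0];
    // between consecutive sorted values: 10 == 10, 10 != 20, 20 != 30
    let neq = [false, true, true];
    // the two tied 10s share rank 1, 20 gets 3, 30 gets 4
    assert_eq!(rank_min_on_sorted(&sort_idx, &neq), vec![4, 1, 1, 3]);
}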