Skip to content

Commit

Permalink
refactor: improve rank implementation, especially around nulls (#11651)
Browse files Browse the repository at this point in the history
  • Loading branch information
orlp authored Oct 11, 2023
1 parent 32e3652 commit 104ee93
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 238 deletions.
352 changes: 114 additions & 238 deletions crates/polars-ops/src/series/ops/rank.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use polars_arrow::prelude::FromData;
use arrow::array::BooleanArray;
use arrow::compute::concatenate::concatenate_validities;
use polars_core::prelude::*;
#[cfg(feature = "random")]
use rand::prelude::SliceRandom;
Expand Down Expand Up @@ -42,8 +43,29 @@ fn get_random_seed() -> u64 {
rng.next_u64()
}

/// Walks the sorted row indices in `idxs` and groups consecutive entries into runs of ties,
/// handing every finished run to `flush_ties` as a mutable slice.
///
/// `neq` is read once for every index after the first: a `true` at position `k` closes the
/// current run before the `(k + 1)`-th index is added, so that index starts a new run.
/// The final run is flushed after the last index. If `idxs` yields no values,
/// `flush_ties` is never called.
///
/// # Safety
/// `neq` is read through `value_unchecked` for positions `0..n - 1`, where `n` is the number
/// of values in `idxs`. Presumably that accessor performs no bounds check, so the caller must
/// guarantee `neq` holds at least `n - 1` entries (confirm against the array API).
unsafe fn rank_impl<F: FnMut(&mut [IdxSize])>(idxs: &IdxCa, neq: &BooleanArray, mut flush_ties: F) {
    let mut current_run = Vec::with_capacity(128);
    let mut sorted = idxs.downcast_iter().flat_map(|arr| arr.values_iter());

    // Seed the first run; with no indices at all there is nothing to rank.
    match sorted.next() {
        Some(head) => current_run.push(*head),
        None => return,
    }

    for (pos, idx) in sorted.enumerate() {
        // `pos` counts from the second index, so it lines up with `neq`.
        let starts_new_run = neq.value_unchecked(pos);
        if starts_new_run {
            flush_ties(&mut current_run);
            current_run.clear();
        }

        current_run.push(*idx);
    }

    // The trailing run never sees a closing `true`, so emit it explicitly.
    flush_ties(&mut current_run);
}

fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option<u64>) -> Series {
match s.len() {
let len = s.len();
let null_count = s.null_count();
match len {
1 => {
return match method {
Average => Series::new(s.name(), &[1.0f64]),
Expand All @@ -59,251 +81,105 @@ fn rank(s: &Series, method: RankMethod, descending: bool, seed: Option<u64>) ->
_ => {},
}

if s.null_count() > 0 {
let nulls = s.is_not_null().rechunk();
let arr = nulls.downcast_iter().next().unwrap();
let validity = arr.values();
// Currently, nulls tie with the minimum or maximum bound for a type, depending on descending.
// TODO: Need to expose nulls_last in arg_sort to prevent this.
// Fill using MaxBound/MinBound to give nulls last rank.
// we will replace them later.
let null_strategy = if descending {
FillNullStrategy::MinBound
} else {
FillNullStrategy::MaxBound
if null_count == len {
return match method {
Average => Float64Chunked::full_null(s.name(), len).into_series(),
_ => IdxCa::full_null(s.name(), len).into_series(),
};
let s = s.fill_null(null_strategy).unwrap();

let mut out = rank(&s, method, descending, seed);
unsafe {
let arr = &mut out.chunks_mut()[0];
*arr = arr.with_validity(Some(validity.clone()))
}
return out;
}

// See: https://github.com/scipy/scipy/blob/v1.7.1/scipy/stats/stats.py#L8631-L8737
let sort_idx_ca = s
.arg_sort(SortOptions {
descending,
nulls_last: true,
..Default::default()
})
.slice(0, len - null_count);

let len = s.len();
let null_count = s.null_count();
let sort_idx_ca = s.arg_sort(SortOptions {
descending,
..Default::default()
});
let sort_idx = sort_idx_ca.downcast_iter().next().unwrap().values();

let mut inv: Vec<IdxSize> = Vec::with_capacity(len);
// Safety:
// Values will be filled next and there is only primitive data
#[allow(clippy::uninit_vec)]
unsafe {
inv.set_len(len)
}
let inv_values = inv.as_mut_slice();

#[cfg(feature = "random")]
let mut count = if let RankMethod::Ordinal | RankMethod::Random = method {
1 as IdxSize
} else {
0
};

#[cfg(not(feature = "random"))]
let mut count = if let RankMethod::Ordinal = method {
1 as IdxSize
} else {
0
};

// Safety:
// we are in bounds
unsafe {
sort_idx.iter().for_each(|&i| {
*inv_values.get_unchecked_mut(i as usize) = count;
count += 1;
});
}
let chunk_refs: Vec<_> = s.chunks().iter().map(|c| &**c).collect();
let validity = concatenate_validities(&chunk_refs);

use RankMethod::*;
match method {
Ordinal => {
let inv_ca = IdxCa::from_vec(s.name(), inv);
inv_ca.into_series()
},
#[cfg(feature = "random")]
Random => {
// Safety:
// in bounds
let arr = unsafe { s.take_unchecked(&sort_idx_ca) };
let not_consecutive_same = arr
.slice(1, len - 1)
.not_equal(&arr.slice(0, len - 1))
.unwrap()
.rechunk();
let obs = not_consecutive_same.downcast_iter().next().unwrap();

// Collect slice indices for sort_idx which point to ties in the original series.
let mut ties_indices = Vec::with_capacity(len + 1);
let mut ties_index: usize = 0;

ties_indices.push(ties_index);
obs.iter().for_each(|b| {
if let Some(b) = b {
ties_index += 1;
if b {
ties_indices.push(ties_index)
}
}
});
// Close last slice (if there were nulls in the original series, they will always be in the last slice).
ties_indices.push(len);

let mut sort_idx = sort_idx.to_vec();

let rng = &mut SmallRng::seed_from_u64(seed.unwrap_or_else(get_random_seed));

// Shuffle sort_idx positions which point to ties in the original series.
for i in 0..(ties_indices.len() - 1) {
let ties_index_start = ties_indices[i];
let ties_index_end = ties_indices[i + 1];
if ties_index_end - ties_index_start > 1 {
sort_idx[ties_index_start..ties_index_end].shuffle(rng);
}
if let Ordinal = method {
let mut out = vec![0 as IdxSize; s.len()];
let mut rank = 0;
for arr in sort_idx_ca.downcast_iter() {
for i in arr.values_iter() {
out[*i as usize] = rank + 1;
rank += 1;
}

// Recreate inv_ca (where ties are randomly shuffled compared with Ordinal).
let mut count = 1 as IdxSize;
unsafe {
sort_idx.iter().for_each(|&i| {
*inv_values.get_unchecked_mut(i as usize) = count;
count += 1;
}
IdxCa::new_from_owned_with_null_bitmap(s.name(), out, validity).into_series()
} else {
let sorted_values = unsafe { s.take_unchecked(&sort_idx_ca) };
let not_consecutive_same = sorted_values
.slice(1, sorted_values.len() - 1)
.not_equal(&sorted_values.slice(0, sorted_values.len() - 1))
.unwrap()
.rechunk();
let neq = not_consecutive_same.downcast_iter().next().unwrap();

let mut rank = 1;
match method {
#[cfg(feature = "random")]
Random => unsafe {
let mut rng = SmallRng::seed_from_u64(seed.unwrap_or_else(get_random_seed));
let mut out = vec![0 as IdxSize; s.len()];
rank_impl(&sort_idx_ca, neq, |ties| {
ties.shuffle(&mut rng);
for i in ties {
*out.get_unchecked_mut(*i as usize) = rank;
rank += 1;
}
});
}

let inv_ca = IdxCa::from_vec(s.name(), inv);
inv_ca.into_series()
},
_ => {
let inv_ca = IdxCa::from_vec(s.name(), inv);
// SAFETY: in bounds.
let arr = unsafe { s.take_unchecked(&sort_idx_ca) };
let validity = arr.chunks()[0].validity().cloned();
let not_consecutive_same = arr
.slice(1, len - 1)
.not_equal(&arr.slice(0, len - 1))
.unwrap()
.rechunk();
// This obs is shorter than that of scipy stats, because we can just start the cumsum by 1
// instead of 0.
let obs = not_consecutive_same.downcast_iter().next().unwrap();
let mut dense = Vec::with_capacity(len);

// This offset saves an offset on the whole column, which is what scipy does in:
//
// ```python
// if method == 'min':
// return count[dense - 1] + 1
// ```
// INVALID LINT REMOVE LATER
#[allow(clippy::bool_to_int_with_if)]
let mut cumsum: IdxSize = if let RankMethod::Min = method {
0
} else {
// Nulls will be first, rank, but we will replace them (with null),
// this ensures the second rank will be 1.
if matches!(method, RankMethod::Dense) && s.null_count() > 0 {
0
} else {
1
}
};

dense.push(cumsum);
obs.values_iter().for_each(|b| {
if b {
cumsum += 1;
}
dense.push(cumsum)
});
let arr = IdxArr::from_data_default(dense.into(), validity);
let dense = IdxCa::with_chunk(s.name(), arr);

// SAFETY: in bounds.
let dense = unsafe { dense.take_unchecked(&inv_ca) };

if let RankMethod::Dense = method {
return if s.null_count() == 0 {
dense.into_series()
} else {
// Null will be the first rank. We restore original nulls and shift all ranks by one.
let validity = s.is_null().rechunk();
let validity = validity.downcast_iter().next().unwrap();
let validity = validity.values().clone();

let arr = &dense.chunks()[0];
let arr = arr.with_validity(Some(validity));
let dtype = arr.data_type().clone();

// SAFETY: given dtype is correct.
unsafe {
Series::_try_from_arrow_unchecked(s.name(), vec![arr], &dtype).unwrap()
IdxCa::new_from_owned_with_null_bitmap(s.name(), out, validity).into_series()
},
Average => unsafe {
let mut out = vec![0.0; s.len()];
rank_impl(&sort_idx_ca, neq, |ties| {
let first = rank;
rank += ties.len() as IdxSize;
let last = rank - 1;
let avg = 0.5 * (first as f64 + last as f64);
for i in ties {
*out.get_unchecked_mut(*i as usize) = avg;
}
};
}

let bitmap = obs.values();
let cap = bitmap.len() - bitmap.unset_bits();
let mut count = Vec::with_capacity(cap + 1);
let mut cnt: IdxSize = 0;
count.push(cnt);

if null_count > 0 {
obs.iter().for_each(|b| {
if let Some(b) = b {
cnt += 1;
if b {
count.push(cnt)
}
});
Float64Chunked::new_from_owned_with_null_bitmap(s.name(), out, validity)
.into_series()
},
Min => unsafe {
let mut out = vec![0 as IdxSize; s.len()];
rank_impl(&sort_idx_ca, neq, |ties| {
for i in ties.iter() {
*out.get_unchecked_mut(*i as usize) = rank;
}
rank += ties.len() as IdxSize;
});
} else {
obs.values_iter().for_each(|b| {
cnt += 1;
if b {
count.push(cnt)
IdxCa::new_from_owned_with_null_bitmap(s.name(), out, validity).into_series()
},
Max => unsafe {
let mut out = vec![0 as IdxSize; s.len()];
rank_impl(&sort_idx_ca, neq, |ties| {
rank += ties.len() as IdxSize;
for i in ties {
*out.get_unchecked_mut(*i as usize) = rank - 1;
}
});
}

count.push((len - null_count) as IdxSize);
let count = IdxCa::from_vec(s.name(), count);

match method {
Max => {
// SAFETY: in bounds.
unsafe { count.take_unchecked(&dense).into_series() }
},
Min => {
// SAFETY: in bounds.
unsafe { (count.take_unchecked(&dense) + 1).into_series() }
},
Average => {
// SAFETY: in bounds.
let a = unsafe { count.take_unchecked(&dense) }
.cast(&DataType::Float64)
.unwrap();
let b = unsafe { count.take_unchecked(&(dense - 1)) }
.cast(&DataType::Float64)
.unwrap()
+ 1.0;
(&a + &b) * 0.5
},
#[cfg(feature = "random")]
Dense | Ordinal | Random => unimplemented!(),
#[cfg(not(feature = "random"))]
Dense | Ordinal => unimplemented!(),
}
},
IdxCa::new_from_owned_with_null_bitmap(s.name(), out, validity).into_series()
},
Dense => unsafe {
let mut out = vec![0 as IdxSize; s.len()];
rank_impl(&sort_idx_ca, neq, |ties| {
for i in ties {
*out.get_unchecked_mut(*i as usize) = rank;
}
rank += 1;
});
IdxCa::new_from_owned_with_null_bitmap(s.name(), out, validity).into_series()
},
Ordinal => unreachable!(),
}
}
}

Expand Down Expand Up @@ -429,14 +305,14 @@ mod test {
let s = UInt32Chunked::new("", &[None, None, None]).into_series();
let out = rank(&s, RankMethod::Average, false, None)
.f64()?
.into_no_null_iter()
.into_iter()
.collect::<Vec<_>>();
assert_eq!(out, &[2.0f64, 2.0, 2.0]);
assert_eq!(out, &[None, None, None]);
let out = rank(&s, RankMethod::Dense, false, None)
.idx()?
.into_no_null_iter()
.into_iter()
.collect::<Vec<_>>();
assert_eq!(out, &[1, 1, 1]);
assert_eq!(out, &[None, None, None]);
Ok(())
}

Expand Down
Loading

0 comments on commit 104ee93

Please sign in to comment.