Skip to content

Commit

Permalink
fix: list.mean fast path shouldn't produce NaN
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa committed Apr 15, 2024
1 parent 44f1097 commit b705e17
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 8 deletions.
26 changes: 18 additions & 8 deletions crates/polars-ops/src/chunked_array/list/sum_mean.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,10 +119,10 @@ pub(super) fn sum_with_nulls(ca: &ListChunked, inner_dtype: &DataType) -> Polars
Ok(out)
}

fn mean_between_offsets<T, S>(values: &[T], offset: &[i64]) -> Vec<S>
fn mean_between_offsets<T, S>(values: &[T], offset: &[i64]) -> PrimitiveArray<S>
where
T: NativeType + ToPrimitive,
S: NumCast + std::iter::Sum + Div<Output = S>,
S: NativeType + NumCast + std::iter::Sum + Div<Output = S>,
{
let mut running_offset = offset[0];

Expand All @@ -131,10 +131,15 @@ where
.map(|end| {
let current_offset = running_offset;
running_offset = *end;

if current_offset == *end {
return None;
}
let slice = unsafe { values.get_unchecked(current_offset as usize..*end as usize) };
unsafe {
sum_slice::<_, S>(slice) / NumCast::from(slice.len()).unwrap_unchecked_release()
Some(
sum_slice::<_, S>(slice)
/ NumCast::from(slice.len()).unwrap_unchecked_release(),
)
}
})
.collect_trusted()
Expand All @@ -147,10 +152,15 @@ where
{
let values = arr.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap();
let values = values.values().as_slice();
Box::new(PrimitiveArray::from_data_default(
mean_between_offsets::<_, S>(values, offsets).into(),
validity.cloned(),
)) as ArrayRef
let mut out = mean_between_offsets::<_, S>(values, offsets);
if let Some(validity) = validity {
if out.has_validity() {
out.apply_validity(|other_validity| validity & &other_validity)
} else {
out = out.with_validity(Some(validity.clone()));
}
}
Box::new(out)
}

pub(super) fn mean_list_numerical(ca: &ListChunked, inner_type: &DataType) -> Series {
Expand Down
10 changes: 10 additions & 0 deletions py-polars/tests/unit/datatypes/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,16 @@ def test_list_min_max() -> None:
}


def test_list_mean_fast_path_empty() -> None:
    """Check `list.mean` on a column holding an empty list and a non-empty list.

    The empty sub-list is expected to produce ``None`` and ``[1, 2, 3]`` is
    expected to produce ``2.0``.
    """
    frame = pl.DataFrame({"a": [[], [1, 2, 3]]})
    means = frame.select(pl.col("a").list.mean()).to_dict(as_series=False)
    assert means == {"a": [None, 2.0]}


def test_list_min_max_13978() -> None:
df = pl.DataFrame(
{
Expand Down

0 comments on commit b705e17

Please sign in to comment.