From 7af0ec3c7724bbb7a2be325f539054fe18836e74 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Mon, 15 Apr 2024 17:45:15 +0800 Subject: [PATCH] fix: `list.mean` fast path shouldn't produce NaN (#15652) --- .../src/chunked_array/list/sum_mean.rs | 26 +++++++++++++------ py-polars/tests/unit/datatypes/test_list.py | 10 +++++++ 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/list/sum_mean.rs b/crates/polars-ops/src/chunked_array/list/sum_mean.rs index e3a14e2340f7..e6628a3c0e73 100644 --- a/crates/polars-ops/src/chunked_array/list/sum_mean.rs +++ b/crates/polars-ops/src/chunked_array/list/sum_mean.rs @@ -119,10 +119,10 @@ pub(super) fn sum_with_nulls(ca: &ListChunked, inner_dtype: &DataType) -> Polars Ok(out) } -fn mean_between_offsets(values: &[T], offset: &[i64]) -> Vec +fn mean_between_offsets(values: &[T], offset: &[i64]) -> PrimitiveArray where T: NativeType + ToPrimitive, - S: NumCast + std::iter::Sum + Div, + S: NativeType + NumCast + std::iter::Sum + Div, { let mut running_offset = offset[0]; @@ -131,10 +131,15 @@ where .map(|end| { let current_offset = running_offset; running_offset = *end; - + if current_offset == *end { + return None; + } let slice = unsafe { values.get_unchecked(current_offset as usize..*end as usize) }; unsafe { - sum_slice::<_, S>(slice) / NumCast::from(slice.len()).unwrap_unchecked_release() + Some( + sum_slice::<_, S>(slice) + / NumCast::from(slice.len()).unwrap_unchecked_release(), + ) } }) .collect_trusted() @@ -147,10 +152,15 @@ where { let values = arr.as_any().downcast_ref::>().unwrap(); let values = values.values().as_slice(); - Box::new(PrimitiveArray::from_data_default( - mean_between_offsets::<_, S>(values, offsets).into(), - validity.cloned(), - )) as ArrayRef + let mut out = mean_between_offsets::<_, S>(values, offsets); + if let Some(validity) = validity { + if out.has_validity() { + out.apply_validity(|other_validity| validity & &other_validity) + } else { + out = out.with_validity(Some(validity.clone())); + } + } + Box::new(out) } pub(super) fn mean_list_numerical(ca: &ListChunked, inner_type: &DataType) -> Series { diff --git a/py-polars/tests/unit/datatypes/test_list.py b/py-polars/tests/unit/datatypes/test_list.py index fe0028d5067c..6dfd207df48f 100644 --- a/py-polars/tests/unit/datatypes/test_list.py +++ b/py-polars/tests/unit/datatypes/test_list.py @@ -443,6 +443,16 @@ def test_list_min_max() -> None: } +def test_list_mean_fast_path_empty() -> None: + df = pl.DataFrame( + { + "a": [[], [1, 2, 3]], + } + ) + output = df.select(pl.col("a").list.mean()) + assert output.to_dict(as_series=False) == {"a": [None, 2.0]} + + def test_list_min_max_13978() -> None: df = pl.DataFrame( {