Skip to content

Commit

Permalink
fix: Ensure mean_horizontal raises on non-numeric input (#19648)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored Nov 6, 2024
1 parent 047e578 commit 10abada
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 30 deletions.
43 changes: 22 additions & 21 deletions crates/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2840,33 +2840,34 @@ impl DataFrame {
}
}

/// Compute the mean of all values horizontally across columns.
/// Compute the mean of all numeric values horizontally across columns.
pub fn mean_horizontal(&self, null_strategy: NullStrategy) -> PolarsResult<Option<Series>> {
match self.columns.len() {
let (numeric_columns, non_numeric_columns): (Vec<_>, Vec<_>) =
self.columns.iter().partition(|s| {
let dtype = s.dtype();
dtype.is_numeric() || dtype.is_decimal() || dtype.is_bool() || dtype.is_null()
});

if !non_numeric_columns.is_empty() {
let col = non_numeric_columns.first().cloned();
polars_bail!(
InvalidOperation: "'horizontal_mean' expects numeric expressions, found {:?} (dtype={})",
col.unwrap().name(),
col.unwrap().dtype(),
);
}
let columns = numeric_columns.into_iter().cloned().collect::<Vec<_>>();
match columns.len() {
0 => Ok(None),
1 => Ok(Some(match self.columns[0].dtype() {
dt if dt != &DataType::Float32 && (dt.is_numeric() || dt == &DataType::Boolean) => {
self.columns[0]
.as_materialized_series()
.cast(&DataType::Float64)?
},
_ => self.columns[0].as_materialized_series().clone(),
1 => Ok(Some(match columns[0].dtype() {
dt if dt != &DataType::Float32 && !dt.is_decimal() => columns[0]
.as_materialized_series()
.cast(&DataType::Float64)?,
_ => columns[0].as_materialized_series().clone(),
})),
_ => {
let columns = self
.columns
.iter()
.filter(|s| {
let dtype = s.dtype();
dtype.is_numeric() || matches!(dtype, DataType::Boolean)
})
.cloned()
.collect::<Vec<_>>();
polars_ensure!(!columns.is_empty(), InvalidOperation: "'horizontal_mean' expected at least 1 numerical column");
let numeric_df = unsafe { DataFrame::_new_no_checks_impl(self.height(), columns) };

let sum = || numeric_df.sum_horizontal(null_strategy);

let null_count = || {
numeric_df
.par_materialized_column_iter()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,12 +78,6 @@ def test_duration_aggs() -> None:
}


def test_mean_horizontal_with_str_column() -> None:
    # The expected values equal the row-wise means of the "int" and "bool"
    # columns, with the null in "bool" left out of the third row.
    frame = pl.DataFrame(
        {"int": [1, 2, 3], "bool": [True, True, None], "str": ["a", "b", "c"]}
    )
    row_means = frame.mean_horizontal().to_list()
    assert row_means == [1.0, 1.5, 3.0]


def test_list_aggregation_that_filters_all_data_6017() -> None:
out = (
pl.DataFrame({"col_to_group_by": [2], "flt": [1672740910.967138], "col3": [1]})
Expand Down
55 changes: 52 additions & 3 deletions py-polars/tests/unit/operations/aggregation/test_horizontal.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import polars as pl
import polars.selectors as cs
from polars.exceptions import ComputeError
from polars.exceptions import ComputeError, PolarsError
from polars.testing import assert_frame_equal, assert_series_equal

if TYPE_CHECKING:
Expand Down Expand Up @@ -398,10 +398,22 @@ def test_horizontal_broadcasting() -> None:

def test_mean_horizontal() -> None:
lf = pl.LazyFrame({"a": [1, 2, 3], "b": [2.0, 4.0, 6.0], "c": [3, None, 9]})
result = lf.select(pl.mean_horizontal(pl.all()).alias("mean"))

expected = pl.LazyFrame({"mean": [2.0, 3.0, 6.0]}, schema={"mean": pl.Float64})
assert_frame_equal(result, expected)

result = lf.select(pl.mean_horizontal(pl.all()))

expected = pl.LazyFrame({"a": [2.0, 3.0, 6.0]}, schema={"a": pl.Float64})
def test_mean_horizontal_bool() -> None:
    """Row-wise mean over boolean columns gives Float64 and skips nulls."""
    bool_frame = pl.DataFrame(
        {
            "a": [True, False, False],
            "b": [None, True, False],
            "c": [True, False, False],
        }
    )
    actual = bool_frame.select(mean=pl.mean_horizontal(pl.all()))

    # First row: two True values and one null -> 1.0 (the null is not counted).
    wanted = pl.DataFrame({"mean": [1.0, 1 / 3, 0.0]}, schema={"mean": pl.Float64})
    assert_frame_equal(actual, wanted)


Expand Down Expand Up @@ -475,3 +487,40 @@ def test_fold_all_schema() -> None:
# divide because of overflow
result = df.select(pl.sum_horizontal(pl.all().hash(seed=1) // int(1e8)))
assert result.dtypes == [pl.UInt64]


@pytest.mark.parametrize(
    "horizontal_func",
    [
        pl.all_horizontal,
        pl.any_horizontal,
        pl.max_horizontal,
        pl.min_horizontal,
        pl.mean_horizontal,
        pl.sum_horizontal,
    ],
)
def test_expected_horizontal_dtype_errors(horizontal_func: type[pl.Expr]) -> None:
    """Each horizontal reduction must raise on a mix of incompatible dtypes.

    The frame combines decimal, list, string and integer columns; passing all
    of them to any of the parametrized horizontal functions is expected to
    raise a ``PolarsError`` when the selection is evaluated.
    """
    from decimal import Decimal as D

    df = pl.DataFrame(
        {
            "cola": [D("1.5"), D("0.5"), D("5"), D("0"), D("-0.25")],
            "colb": [[0, 1], [2], [3, 4], [5], [6]],
            "colc": ["aa", "bb", "cc", "dd", "ee"],
            "cold": ["bb", "cc", "dd", "ee", "ff"],
            "cole": [1000, 2000, 3000, 4000, 5000],
        }
    )
    # `df` is an eager DataFrame, so `select` evaluates immediately and the
    # error surfaces inside the `raises` context.
    with pytest.raises(PolarsError):
        df.select(
            horizontal_func(  # type: ignore[call-arg]
                pl.col("cola"),
                pl.col("colb"),
                pl.col("colc"),
                pl.col("cold"),
                pl.col("cole"),
            )
        )

0 comments on commit 10abada

Please sign in to comment.