Skip to content

Commit

Permalink
fix: Mean of boolean in group_by incorrectly gave NULL
Browse files Browse the repository at this point in the history
  • Loading branch information
nameexhaustion committed Apr 12, 2024
1 parent c068e76 commit c43ea14
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,9 @@ where
);
}
let agg_fn = match logical_dtype.to_physical() {
dt if dt.is_integer() => AggregateFunction::MeanF64(MeanAgg::<f64>::new()),
dt if dt.is_integer() | dt.is_bool() => {
AggregateFunction::MeanF64(MeanAgg::<f64>::new())
},
DataType::Float32 => AggregateFunction::MeanF32(MeanAgg::<f32>::new()),
DataType::Float64 => AggregateFunction::MeanF64(MeanAgg::<f64>::new()),
dt => AggregateFunction::Null(NullAgg::new(dt)),
Expand Down
23 changes: 23 additions & 0 deletions py-polars/tests/unit/streaming/test_streaming_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,3 +488,26 @@ def test_streaming_group_by_convert_15380() -> None:
pl.DataFrame({"a": [1] * PARTITION_LIMIT}).group_by(b="a").len()["len"].item()
== PARTITION_LIMIT
)


@pytest.mark.parametrize("streaming", [True, False])
@pytest.mark.parametrize("n_rows", [PARTITION_LIMIT - 1, PARTITION_LIMIT])
def test_streaming_group_by_boolean_mean_15610(n_rows: int, streaming: bool) -> None:
    """Check that the group-by mean of a boolean column is numeric, not null.

    The input repeats the three-row pattern ``a=[T, F, T]``, ``b=[T, F, F]``,
    so every ``a == False`` row carries ``b == False`` (mean 0.0) while the
    ``a == True`` rows split evenly between ``b == True`` and ``b == False``
    (mean 0.5). The case is exercised just below and exactly at
    ``PARTITION_LIMIT`` rows, with and without ``streaming``.
    """
    n_repeats = n_rows // 3
    # Guard against a degenerate (empty) input frame.
    assert n_repeats > 0

    frame = pl.select(
        a=pl.repeat([True, False, True], n_repeats).explode(),
        b=pl.repeat([True, False, False], n_repeats).explode(),
    )

    # Sort on the key so the row order of the result is deterministic.
    query = frame.lazy().group_by("a").agg(c=pl.mean("b")).sort("a")
    result = query.collect(streaming=streaming)

    expected = pl.DataFrame({"a": [False, True], "c": [0.0, 0.5]})
    assert_frame_equal(result, expected)

0 comments on commit c43ea14

Please sign in to comment.