From 9056e6078a8b972803284d3038b7ea6efe64b892 Mon Sep 17 00:00:00 2001
From: ritchie
Date: Sat, 7 Sep 2024 16:26:34 +0200
Subject: [PATCH] fix: Fix group first value after group-by slice

---
 crates/polars-expr/src/expressions/slice.rs   | 11 +++++++++--
 py-polars/tests/unit/operations/test_slice.py | 15 +++++++++++++++
 2 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/crates/polars-expr/src/expressions/slice.rs b/crates/polars-expr/src/expressions/slice.rs
index d2bc9137a7d3..579c8d66635e 100644
--- a/crates/polars-expr/src/expressions/slice.rs
+++ b/crates/polars-expr/src/expressions/slice.rs
@@ -60,9 +60,16 @@ fn check_argument(arg: &Series, groups: &GroupsProxy, name: &str, expr: &Expr) -
     Ok(())
 }
 
-fn slice_groups_idx(offset: i64, length: usize, first: IdxSize, idx: &[IdxSize]) -> IdxItem {
+fn slice_groups_idx(offset: i64, length: usize, mut first: IdxSize, idx: &[IdxSize]) -> IdxItem {
     let (offset, len) = slice_offsets(offset, length, idx.len());
-    (first + offset as IdxSize, idx[offset..offset + len].into())
+
+    // If slice isn't out of bounds, we replace first.
+    // If slice is oob, the `idx` vec will be empty and `first` will be ignored
+    if let Some(f) = idx.get(offset) {
+        first = *f;
+    }
+    // This is a clone of the vec, which is unfortunate. Maybe we have a `sliceable` unitvec one day.
+    (first, idx[offset..offset + len].into())
 }
 
 fn slice_groups_slice(offset: i64, length: usize, first: IdxSize, len: IdxSize) -> [IdxSize; 2] {
diff --git a/py-polars/tests/unit/operations/test_slice.py b/py-polars/tests/unit/operations/test_slice.py
index 692fcb5634dc..94dc1e3283ff 100644
--- a/py-polars/tests/unit/operations/test_slice.py
+++ b/py-polars/tests/unit/operations/test_slice.py
@@ -273,3 +273,18 @@ def test_group_by_slice_all_keys() -> None:
     gb = df.group_by(["a", "b", "c"], maintain_order=True)
 
     assert_frame_equal(gb.tail(1), gb.head(1))
+
+
+def test_slice_first_in_agg_18551() -> None:
+    df = pl.DataFrame({"id": [1, 1, 2], "name": ["A", "B", "C"], "value": [31, 21, 32]})
+
+    assert df.group_by("id", maintain_order=True).agg(
+        sort_by=pl.col("name").sort_by("value"),
+        x=pl.col("name").sort_by("value").slice(0, 1).first(),
+        y=pl.col("name").sort_by("value").slice(1, 1).first(),
+    ).to_dict(as_series=False) == {
+        "id": [1, 2],
+        "sort_by": [["B", "A"], ["C"]],
+        "x": ["B", "C"],
+        "y": ["A", None],
+    }