Skip to content

Commit

Permalink
fix: Fix literal slice in group by (#17242)
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 authored Jun 27, 2024
1 parent c45e5ec commit 01a65f0
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 6 deletions.
8 changes: 5 additions & 3 deletions crates/polars-expr/src/expressions/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,11 @@ impl<'a> AggregationContext<'a> {

/// Update the group tuples
pub(crate) fn with_groups(&mut self, groups: GroupsProxy) -> &mut Self {
// In case of new groups, a series always needs to be flattened
self.with_series(self.flat_naive().into_owned(), false, None)
.unwrap();
if let AggState::AggregatedList(_) = self.agg_state() {
// In case of new groups, a series always needs to be flattened
self.with_series(self.flat_naive().into_owned(), false, None)
.unwrap();
}
self.groups = Cow::Owned(groups);
// make sure that previous setting is not used
self.update_groups = UpdateGroups::No;
Expand Down
14 changes: 11 additions & 3 deletions crates/polars-expr/src/expressions/slice.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use polars_core::prelude::*;
use polars_core::utils::{slice_offsets, CustomIterTools};
use polars_core::utils::{slice_offsets, Container, CustomIterTools};
use polars_core::POOL;
use rayon::prelude::*;
use AnyValue::Null;
Expand Down Expand Up @@ -106,13 +106,18 @@ impl PhysicalExpr for SliceExpr {
let mut ac_length = results.pop().unwrap();
let mut ac_offset = results.pop().unwrap();

let groups = ac.groups();

use AggState::*;
let groups = match (&ac_offset.state, &ac_length.state) {
(Literal(offset), Literal(length)) => {
let (offset, length) = extract_args(offset, length, &self.expr)?;

if let Literal(s) = ac.agg_state() {
let s1 = s.slice(offset, length);
ac.with_literal(s1);
return Ok(ac);
}
let groups = ac.groups();

match groups.as_ref() {
GroupsProxy::Idx(groups) => {
let groups = groups
Expand All @@ -134,6 +139,7 @@ impl PhysicalExpr for SliceExpr {
}
},
(Literal(offset), _) => {
let groups = ac.groups();
let offset = extract_offset(offset, &self.expr)?;
let length = ac_length.aggregated();
check_argument(&length, groups, "length", &self.expr)?;
Expand Down Expand Up @@ -168,6 +174,7 @@ impl PhysicalExpr for SliceExpr {
}
},
(_, Literal(length)) => {
let groups = ac.groups();
let length = extract_length(length, &self.expr)?;
let offset = ac_offset.aggregated();
check_argument(&offset, groups, "offset", &self.expr)?;
Expand Down Expand Up @@ -202,6 +209,7 @@ impl PhysicalExpr for SliceExpr {
}
},
_ => {
let groups = ac.groups();
let length = ac_length.aggregated();
let offset = ac_offset.aggregated();
check_argument(&length, groups, "length", &self.expr)?;
Expand Down
9 changes: 9 additions & 0 deletions py-polars/tests/unit/operations/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -1130,3 +1130,12 @@ def sub_col_min(column: str, min_column: str) -> pl.Expr:
pl.List(pl.Float64),
pl.List(pl.Float64),
]


def test_grouped_slice_literals() -> None:
assert pl.DataFrame({"idx": [1, 2, 3]}).group_by(True).agg(
x=pl.lit([1, 2]).slice(
-1, 1
), # slices a list of 1 element, so remains the same element
x2=pl.lit(pl.Series([1, 2])).slice(-1, 1),
).to_dict(as_series=False) == {"literal": [True], "x": [[1, 2]], "x2": [2]}

0 comments on commit 01a65f0

Please sign in to comment.