Skip to content

Commit

Permalink
fix: Fix Expr.over with order_by did not take effect if group key…
Browse files Browse the repository at this point in the history
…s were sorted (#18947)
  • Loading branch information
nameexhaustion authored Sep 27, 2024
1 parent d097d3c commit e2c7150
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 11 deletions.
12 changes: 3 additions & 9 deletions crates/polars-expr/src/expressions/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,7 +315,6 @@ impl WindowExpr {
fn determine_map_strategy(
&self,
agg_state: &AggState,
sorted_keys: bool,
gb: &GroupBy,
) -> PolarsResult<MapStrategy> {
match (self.mapping, agg_state) {
Expand All @@ -334,13 +333,8 @@ impl WindowExpr {
// no explicit aggregations, map over the groups
//`(col("x").sum() * col("y")).over("groups")`
(WindowMapping::GroupsToRows, AggState::AggregatedList(_)) => {
if sorted_keys {
if let GroupsProxy::Idx(g) = gb.get_groups() {
debug_assert!(g.is_sorted_flag())
}
// GroupsProxy::Slice is always sorted

// Note that group columns must be sorted for this to make sense!!!
if let GroupsProxy::Slice { .. } = gb.get_groups() {
// Result can be directly exploded if the input was sorted.
Ok(MapStrategy::Explode)
} else {
Ok(MapStrategy::Map)
Expand Down Expand Up @@ -516,7 +510,7 @@ impl PhysicalExpr for WindowExpr {
let mut ac = self.run_aggregation(df, state, &gb)?;

use MapStrategy::*;
match self.determine_map_strategy(ac.agg_state(), sorted_keys, &gb)? {
match self.determine_map_strategy(ac.agg_state(), &gb)? {
Nothing => {
let mut out = ac.flat_naive().into_owned();

Expand Down
2 changes: 0 additions & 2 deletions crates/polars/tests/it/lazy/expressions/window.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,7 @@ fn test_sort_by_in_groups() -> PolarsResult<()> {
col("cars"),
col("A")
.sort_by([col("B")], SortMultipleOptions::default())
.implode()
.over([col("cars")])
.explode()
.alias("sorted_A_by_B"),
])
.collect()?;
Expand Down
18 changes: 18 additions & 0 deletions py-polars/tests/unit/operations/test_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,3 +518,21 @@ def test_lit_window_broadcast() -> None:
assert pl.DataFrame({"a": [1, 1, 2]}).select(pl.lit(0).over("a").alias("a"))[
"a"
].to_list() == [0, 0, 0]


def test_order_by_sorted_keys_18943() -> None:
    """Regression test: `over(..., order_by=...)` must be applied even when the
    group key column carries the sorted flag.

    The cumulative sum of "x" is evaluated within group "g" in ascending "t"
    order and mapped back to the original row positions, so the result must be
    the same whether or not "g" has been marked as sorted.
    """
    frame = pl.DataFrame(
        {
            "g": [1, 1, 1, 1],
            "t": [4, 3, 2, 1],
            "x": [10, 20, 30, 40],
        }
    )
    expected = pl.DataFrame({"x": [100, 90, 70, 40]})
    windowed = pl.col("x").cum_sum().over("g", order_by="t")

    # First the plain frame, then the same frame with "g" flagged as sorted.
    for candidate in (frame, frame.set_sorted("g")):
        assert_frame_equal(candidate.select(windowed), expected)

0 comments on commit e2c7150

Please sign in to comment.