fix: SQL interface "off-by-one' indexing error with GROUP BY clause…
Browse files Browse the repository at this point in the history
…s that use position ordinals (#15584)
  • Loading branch information
alexander-beedie authored Apr 11, 2024
1 parent f329b5c commit c8e26ca
Showing 4 changed files with 44 additions and 10 deletions.
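
For context, a minimal sketch of the behavior this commit fixes, using the Python SQLContext API exercised by the new test (table and column names here are illustrative): SQL position ordinals in GROUP BY are 1-indexed, so GROUP BY 1 must resolve to the first SELECT item; before this change the ordinal was used directly as a 0-indexed lookup (projections[idx] instead of projections[idx - 1]), picking the wrong projection.

import polars as pl

df = pl.DataFrame({"a": ["xx", "yy", "xx"], "b": [1, 2, 3]})
with pl.SQLContext(frame=df) as ctx:
    # "GROUP BY 1" refers to the first SELECT item ("a"); SQL ordinals are 1-indexed
    out = ctx.execute(
        "SELECT a, SUM(b) AS total_b FROM frame GROUP BY 1 ORDER BY a"
    ).collect()
print(out)  # two groups: ("xx", 4) and ("yy", 2)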
14 changes: 6 additions & 8 deletions crates/polars-sql/src/context.rs
@@ -411,7 +411,7 @@ impl SQLContext {
})
.collect::<PolarsResult<_>>()?;

// Check for group by (after projections since there might be numbers).
// Check for group by (after projections as there may be ordinal/position ints).
let group_by_keys: Vec<Expr>;
if let GroupByExpr::Expressions(group_by_exprs) = &select_stmt.group_by {
group_by_keys = group_by_exprs.iter()
@@ -425,7 +425,8 @@
)),
Ok(idx) => Ok(idx),
}?;
Ok(projections[idx].clone())
// note: sql queries are 1-indexed
Ok(projections[idx - 1].clone())
},
SQLExpr::Value(_) => Err(polars_err!(
ComputeError:
@@ -694,7 +695,6 @@ impl SQLContext {
ComputeError: "group_by error: can't process wildcard in group_by"
);
let schema_before = lf.schema()?;

let group_by_keys_schema =
expressions_to_schema(group_by_keys, &schema_before, Context::Default)?;

@@ -703,24 +703,21 @@
let mut aliases: BTreeSet<&str> = BTreeSet::new();

for mut e in projections {
// If it is a simple expression & has alias,
// we must defer the aliasing until after the group_by.
// If simple aliased expression we defer aliasing until after the group_by.
if e.clone().meta().is_simple_projection() {
if let Expr::Alias(expr, name) = e {
aliases.insert(name);
e = expr
}
}

let field = e.to_field(&schema_before, Context::Default)?;
if group_by_keys_schema.get(&field.name).is_none() {
aggregation_projection.push(e.clone())
}
}

let aggregated = lf.group_by(group_by_keys).agg(&aggregation_projection);
let projection_schema =
expressions_to_schema(projections, &schema_before, Context::Default)?;

// A final projection to get the proper order.
let final_projection = projection_schema
.iter_names()
@@ -734,6 +731,7 @@
})
.collect::<Vec<_>>();

let aggregated = lf.group_by(group_by_keys).agg(&aggregation_projection);
Ok(aggregated.select(&final_projection))
}

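
The second hunk above also preserves the handling of aliased GROUP BY keys described in the updated comment: a simple aliased projection is grouped on its underlying expression, and the alias is re-applied by the final projection. A hedged sketch of that query shape with the same Python API (names illustrative):

import polars as pl

df = pl.DataFrame({"a": ["x", "y", "x"], "b": [1, 2, 3]})
with pl.SQLContext(frame=df) as ctx:
    # "a AS grp" is a simple aliased projection: the group_by runs on "a",
    # and the "grp" alias is applied in the final projection step.
    out = ctx.execute(
        "SELECT a AS grp, SUM(b) AS total_b FROM frame GROUP BY a"
    ).collect()
print(out.columns)  # ["grp", "total_b"]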
2 changes: 1 addition & 1 deletion crates/polars-sql/tests/functions_string.rs
@@ -83,7 +83,7 @@ fn test_string_functions() {
}

#[test]
fn array_to_string() {
fn test_array_to_string() {
let df = df! {
"a" => &["first", "first", "third"],
"b" => &[1, 1, 42],
4 changes: 3 additions & 1 deletion crates/polars-sql/tests/simple_exprs.rs
@@ -53,6 +53,7 @@ fn test_nested_expr() -> PolarsResult<()> {
assert_eq!(df_sql, df_pl);
Ok(())
}

#[test]
fn test_group_by_simple() -> PolarsResult<()> {
let df = create_sample_df()?;
@@ -76,6 +77,7 @@ fn test_group_by_simple() -> PolarsResult<()> {
},
)
.collect()?;

let df_pl = df
.lazy()
.group_by(&[col("a")])
@@ -249,7 +251,7 @@ fn test_null_exprs_in_where() {
}

#[test]
fn binary_functions() {
fn test_binary_functions() {
let df = create_sample_df().unwrap();
let mut context = SQLContext::new();
context.register("df", df.clone().lazy());
34 changes: 34 additions & 0 deletions py-polars/tests/unit/sql/test_group_by.py
@@ -5,6 +5,7 @@
import pytest

import polars as pl
from polars.testing import assert_frame_equal


@pytest.fixture()
@@ -62,3 +63,36 @@ def test_group_by(foods_ipc_path: Path) -> None:
"""
)
assert out.to_dict(as_series=False) == {"grp": ["c"], "n_dist_attr": [2]}


def test_group_by_ordinal_position() -> None:
df = pl.DataFrame(
{
"a": ["xx", "yy", "xx", "yy", "xx", "zz"],
"b": [1, 2, 3, 4, 5, 6],
"c": [99, 99, 66, 66, 66, 66],
}
)
expected = pl.LazyFrame({"c": [66, 99], "total_b": [18, 3]})

with pl.SQLContext(frame=df) as ctx:
res1 = ctx.execute(
"""
SELECT c, SUM(b) AS total_b
FROM frame
GROUP BY 1
ORDER BY c
"""
)
assert_frame_equal(res1, expected)

res2 = ctx.execute(
"""
WITH "grp" AS (
SELECT NULL::date as dt, c, SUM(b) AS total_b
FROM frame
GROUP BY 2, 1
)
SELECT c, total_b FROM grp ORDER BY c"""
)
assert_frame_equal(res2, expected)
