diff --git a/crates/polars-sql/src/context.rs b/crates/polars-sql/src/context.rs
index c9dcd63ff56b..c690042f18f5 100644
--- a/crates/polars-sql/src/context.rs
+++ b/crates/polars-sql/src/context.rs
@@ -411,7 +411,7 @@ impl SQLContext {
             })
             .collect::<PolarsResult<Vec<_>>>()?;
 
-        // Check for group by (after projections since there might be numbers).
+        // Check for group by (after projections as there may be ordinal/position ints).
         let group_by_keys: Vec<Expr>;
         if let GroupByExpr::Expressions(group_by_exprs) = &select_stmt.group_by {
             group_by_keys = group_by_exprs.iter()
@@ -425,7 +425,8 @@ impl SQLContext {
                             )),
                             Ok(idx) => Ok(idx),
                         }?;
-                        Ok(projections[idx].clone())
+                        // note: sql queries are 1-indexed
+                        Ok(projections[idx - 1].clone())
                     },
                     SQLExpr::Value(_) => Err(polars_err!(
                         ComputeError:
@@ -694,7 +695,6 @@ impl SQLContext {
             ComputeError: "group_by error: can't process wildcard in group_by"
         );
         let schema_before = lf.schema()?;
-
         let group_by_keys_schema =
             expressions_to_schema(group_by_keys, &schema_before, Context::Default)?;
 
@@ -703,24 +703,21 @@ impl SQLContext {
         let mut aliases: BTreeSet<&str> = BTreeSet::new();
 
         for mut e in projections {
-            // If it is a simple expression & has alias,
-            // we must defer the aliasing until after the group_by.
+            // If simple aliased expression we defer aliasing until after the group_by.
             if e.clone().meta().is_simple_projection() {
                 if let Expr::Alias(expr, name) = e {
                     aliases.insert(name);
                     e = expr
                 }
             }
-
             let field = e.to_field(&schema_before, Context::Default)?;
             if group_by_keys_schema.get(&field.name).is_none() {
                 aggregation_projection.push(e.clone())
             }
         }
-
-        let aggregated = lf.group_by(group_by_keys).agg(&aggregation_projection);
         let projection_schema =
             expressions_to_schema(projections, &schema_before, Context::Default)?;
+
         // A final projection to get the proper order.
         let final_projection = projection_schema
             .iter_names()
@@ -734,6 +731,7 @@ impl SQLContext {
             })
             .collect::<Vec<_>>();
 
+        let aggregated = lf.group_by(group_by_keys).agg(&aggregation_projection);
         Ok(aggregated.select(&final_projection))
     }
 
diff --git a/crates/polars-sql/tests/functions_string.rs b/crates/polars-sql/tests/functions_string.rs
index f6ea4314a5de..7e08e48c1e86 100644
--- a/crates/polars-sql/tests/functions_string.rs
+++ b/crates/polars-sql/tests/functions_string.rs
@@ -83,7 +83,7 @@ fn test_string_functions() {
 }
 
 #[test]
-fn array_to_string() {
+fn test_array_to_string() {
     let df = df! {
         "a" => &["first", "first", "third"],
         "b" => &[1, 1, 42],
diff --git a/crates/polars-sql/tests/simple_exprs.rs b/crates/polars-sql/tests/simple_exprs.rs
index d5974d12312a..0dabdf1cbbff 100644
--- a/crates/polars-sql/tests/simple_exprs.rs
+++ b/crates/polars-sql/tests/simple_exprs.rs
@@ -53,6 +53,7 @@ fn test_nested_expr() -> PolarsResult<()> {
     assert_eq!(df_sql, df_pl);
     Ok(())
 }
+
 #[test]
 fn test_group_by_simple() -> PolarsResult<()> {
     let df = create_sample_df()?;
@@ -76,6 +77,7 @@ fn test_group_by_simple() -> PolarsResult<()> {
             },
         )
         .collect()?;
+
     let df_pl = df
         .lazy()
         .group_by(&[col("a")])
@@ -249,7 +251,7 @@ fn test_null_exprs_in_where() {
 }
 
 #[test]
-fn binary_functions() {
+fn test_binary_functions() {
     let df = create_sample_df().unwrap();
     let mut context = SQLContext::new();
     context.register("df", df.clone().lazy());
diff --git a/py-polars/tests/unit/sql/test_group_by.py b/py-polars/tests/unit/sql/test_group_by.py
index 280c310bc4cb..03f15778e605 100644
--- a/py-polars/tests/unit/sql/test_group_by.py
+++ b/py-polars/tests/unit/sql/test_group_by.py
@@ -5,6 +5,7 @@
 import pytest
 
 import polars as pl
+from polars.testing import assert_frame_equal
 
 
 @pytest.fixture()
@@ -62,3 +63,36 @@ def test_group_by(foods_ipc_path: Path) -> None:
         """
     )
     assert out.to_dict(as_series=False) == {"grp": ["c"], "n_dist_attr": [2]}
+
+
+def test_group_by_ordinal_position() -> None:
+    df = pl.DataFrame(
+        {
+            "a": ["xx", "yy", "xx", "yy", "xx", "zz"],
+            "b": [1, 2, 3, 4, 5, 6],
+            "c": [99, 99, 66, 66, 66, 66],
+        }
+    )
+    expected = pl.LazyFrame({"c": [66, 99], "total_b": [18, 3]})
+
+    with pl.SQLContext(frame=df) as ctx:
+        res1 = ctx.execute(
+            """
+            SELECT c, SUM(b) AS total_b
+            FROM frame
+            GROUP BY 1
+            ORDER BY c
+            """
+        )
+        assert_frame_equal(res1, expected)
+
+        res2 = ctx.execute(
+            """
+            WITH "grp" AS (
+              SELECT NULL::date as dt, c, SUM(b) AS total_b
+              FROM frame
+              GROUP BY 2, 1
+            )
+            SELECT c, total_b FROM grp ORDER BY c"""
+        )
+        assert_frame_equal(res2, expected)