diff --git a/crates/polars-expr/src/expressions/apply.rs b/crates/polars-expr/src/expressions/apply.rs index 53579b763033..d6c37a5a004f 100644 --- a/crates/polars-expr/src/expressions/apply.rs +++ b/crates/polars-expr/src/expressions/apply.rs @@ -170,7 +170,11 @@ impl ApplyExpr { // })? let out: ListChunked = POOL.install(|| iter.collect::>())?; - debug_assert_eq!(out.dtype(), &DataType::List(Box::new(dtype))); + if self.function_returns_scalar { + debug_assert_eq!(&DataType::List(Box::new(dtype)), out.dtype()); + } else { + debug_assert_eq!(&dtype, out.dtype()); + } out } else { diff --git a/crates/polars-expr/src/planner.rs b/crates/polars-expr/src/planner.rs index c4006de0c8ec..beaa6b6cdae0 100644 --- a/crates/polars-expr/src/planner.rs +++ b/crates/polars-expr/src/planner.rs @@ -239,7 +239,7 @@ fn create_physical_expr_inner( // TODO! Order by let group_by = create_physical_expressions_from_nodes( partition_by, - Context::Default, + Context::Aggregation, expr_arena, schema, state, @@ -473,10 +473,9 @@ fn create_physical_expr_inner( options, } => { let is_scalar = is_scalar_ae(expression, expr_arena); - let output_dtype = - expr_arena - .get(expression) - .to_field(schema, Context::Default, expr_arena)?; + let output_field = expr_arena + .get(expression) + .to_field(schema, ctxt, expr_arena)?; let is_reducing_aggregation = options.flags.contains(FunctionFlags::RETURNS_SCALAR) && matches!(options.collect_groups, ApplyOptions::GroupWise); @@ -501,7 +500,7 @@ fn create_physical_expr_inner( *options, state.allow_threading, schema.clone(), - output_dtype, + output_field, is_scalar, ))) }, @@ -509,13 +508,11 @@ fn create_physical_expr_inner( input, function, options, - .. } => { let is_scalar = is_scalar_ae(expression, expr_arena); - let output_field = - expr_arena - .get(expression) - .to_field(schema, Context::Default, expr_arena)?; + let output_field = expr_arena + .get(expression) + .to_field(schema, ctxt, expr_arena)?; let is_reducing_aggregation = options.flags.contains(FunctionFlags::RETURNS_SCALAR) && matches!(options.collect_groups, ApplyOptions::GroupWise); // Will be reset in the function so get that here. @@ -568,6 +565,7 @@ fn create_physical_expr_inner( let field = expr_arena .get(expression) .to_field(schema, ctxt, expr_arena)?; + Ok(Arc::new(ApplyExpr::new( vec![input], function, diff --git a/crates/polars-plan/src/plans/aexpr/schema.rs b/crates/polars-plan/src/plans/aexpr/schema.rs index 8cb4b8cc2387..7105855636c5 100644 --- a/crates/polars-plan/src/plans/aexpr/schema.rs +++ b/crates/polars-plan/src/plans/aexpr/schema.rs @@ -38,7 +38,7 @@ impl AExpr { // col(foo: i64).sum() -> i64 // The `nested` keeps track of the nesting we need to add. let mut nested = matches!(ctx, Context::Aggregation) as u8; - let mut field = self.to_field_impl(schema, arena, &mut nested)?; + let mut field = self.to_field_impl(schema, ctx, arena, &mut nested)?; if nested >= 1 { field.coerce(field.dtype().clone().implode()); @@ -51,6 +51,7 @@ impl AExpr { pub fn to_field_impl( &self, schema: &Schema, + ctx: Context, arena: &Arena, nested: &mut u8, ) -> PolarsResult { @@ -68,11 +69,13 @@ impl AExpr { *nested += matches!(mapping, WindowMapping::Join) as u8; } let e = arena.get(*function); - e.to_field_impl(schema, arena, nested) + e.to_field_impl(schema, ctx, arena, nested) }, Explode(expr) => { - let field = arena.get(*expr).to_field_impl(schema, arena, nested)?; - *nested = nested.saturating_sub(1); + // `Explode` is a "flatten" operation, which is not the same as returning a scalar. + // Namely, it should be auto-imploded in the aggregation context, so we don't update + // the `nested` state here. + let field = arena.get(*expr).to_field_impl(schema, ctx, arena, &mut 0)?; if let List(inner) = field.dtype() { Ok(Field::new(field.name().clone(), *inner.clone())) @@ -82,7 +85,10 @@ impl AExpr { }, Alias(expr, name) => Ok(Field::new( name.clone(), - arena.get(*expr).to_field_impl(schema, arena, nested)?.dtype, + arena + .get(*expr) + .to_field_impl(schema, ctx, arena, nested)? + .dtype, )), Column(name) => schema .get_field(name) @@ -110,20 +116,23 @@ impl AExpr { | Operator::LogicalOr => { let out_field; let out_name = { - out_field = arena.get(*left).to_field_impl(schema, arena, nested)?; + out_field = + arena.get(*left).to_field_impl(schema, ctx, arena, nested)?; out_field.name() }; Field::new(out_name.clone(), Boolean) }, Operator::TrueDivide => { - return get_truediv_field(*left, *right, arena, schema, nested) + return get_truediv_field(*left, *right, arena, ctx, schema, nested) + }, + _ => { + return get_arithmetic_field(*left, *right, arena, *op, ctx, schema, nested) }, - _ => return get_arithmetic_field(*left, *right, arena, *op, schema, nested), }; Ok(field) }, - Sort { expr, .. } => arena.get(*expr).to_field_impl(schema, arena, nested), + Sort { expr, .. } => arena.get(*expr).to_field_impl(schema, ctx, arena, nested), Gather { expr, returns_scalar, @@ -132,10 +141,10 @@ impl AExpr { if *returns_scalar { *nested = nested.saturating_sub(1); } - arena.get(*expr).to_field_impl(schema, arena, nested) + arena.get(*expr).to_field_impl(schema, ctx, arena, nested) }, - SortBy { expr, .. } => arena.get(*expr).to_field_impl(schema, arena, nested), - Filter { input, .. } => arena.get(*input).to_field_impl(schema, arena, nested), + SortBy { expr, .. } => arena.get(*expr).to_field_impl(schema, ctx, arena, nested), + Filter { input, .. } => arena.get(*input).to_field_impl(schema, ctx, arena, nested), Agg(agg) => { use IRAggExpr::*; match agg { @@ -144,11 +153,12 @@ impl AExpr { | First(expr) | Last(expr) => { *nested = nested.saturating_sub(1); - arena.get(*expr).to_field_impl(schema, arena, nested) + arena.get(*expr).to_field_impl(schema, ctx, arena, nested) }, Sum(expr) => { *nested = nested.saturating_sub(1); - let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?; + let mut field = + arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?; let dt = match field.dtype() { Boolean => Some(IDX_DTYPE), UInt8 | Int8 | Int16 | UInt16 => Some(Int64), @@ -161,7 +171,8 @@ impl AExpr { }, Median(expr) => { *nested = nested.saturating_sub(1); - let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?; + let mut field = + arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?; match field.dtype { Date => field.coerce(Datetime(TimeUnit::Milliseconds, None)), _ => float_type(&mut field), @@ -170,7 +181,8 @@ impl AExpr { }, Mean(expr) => { *nested = nested.saturating_sub(1); - let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?; + let mut field = + arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?; match field.dtype { Date => field.coerce(Datetime(TimeUnit::Milliseconds, None)), _ => float_type(&mut field), @@ -178,57 +190,64 @@ impl AExpr { Ok(field) }, Implode(expr) => { - let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?; + let mut field = + arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?; field.coerce(DataType::List(field.dtype().clone().into())); Ok(field) }, Std(expr, _) => { *nested = nested.saturating_sub(1); - let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?; + let mut field = + arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?; float_type(&mut field); Ok(field) }, Var(expr, _) => { *nested = nested.saturating_sub(1); - let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?; + let mut field = + arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?; float_type(&mut field); Ok(field) }, NUnique(expr) => { *nested = 0; - let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?; + let mut field = + arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?; field.coerce(IDX_DTYPE); Ok(field) }, Count(expr, _) => { *nested = 0; - let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?; + let mut field = + arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?; field.coerce(IDX_DTYPE); Ok(field) }, AggGroups(expr) => { *nested = 1; - let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?; + let mut field = + arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?; field.coerce(List(IDX_DTYPE.into())); Ok(field) }, Quantile { expr, .. } => { *nested = nested.saturating_sub(1); - let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?; + let mut field = + arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?; float_type(&mut field); Ok(field) }, #[cfg(feature = "bitwise")] Bitwise(expr, _) => { *nested = nested.saturating_sub(1); - let field = arena.get(*expr).to_field_impl(schema, arena, nested)?; + let field = arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?; // @Q? Do we need to coerce here? Ok(field) }, } }, Cast { expr, dtype, .. } => { - let field = arena.get(*expr).to_field_impl(schema, arena, nested)?; + let field = arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?; Ok(Field::new(field.name().clone(), dtype.clone())) }, Ternary { truthy, falsy, .. } => { @@ -242,10 +261,11 @@ impl AExpr { let mut truthy = arena .get(*truthy) - .to_field_impl(schema, arena, &mut nested_truthy)?; - let falsy = arena - .get(*falsy) - .to_field_impl(schema, arena, &mut nested_falsy)?; + .to_field_impl(schema, ctx, arena, &mut nested_truthy)?; + let falsy = + arena + .get(*falsy) + .to_field_impl(schema, ctx, arena, &mut nested_falsy)?; let st = if let DataType::Null = *truthy.dtype() { falsy.dtype().clone() @@ -264,30 +284,43 @@ impl AExpr { options, .. } => { - *nested = nested - .saturating_sub(options.flags.contains(FunctionFlags::RETURNS_SCALAR) as _); - let fields = func_args_to_fields(input, schema, arena, nested)?; + let fields = func_args_to_fields(input, ctx, schema, arena, nested)?; polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", options.fmt_str); - output_type.get_field(schema, Context::Default, &fields) + let out = output_type.get_field(schema, ctx, &fields)?; + + if options.flags.contains(FunctionFlags::RETURNS_SCALAR) { + *nested = 0; + } else if matches!(ctx, Context::Aggregation) { + *nested += 1; + } + + Ok(out) }, Function { function, input, options, } => { - *nested = nested - .saturating_sub(options.flags.contains(FunctionFlags::RETURNS_SCALAR) as _); - let fields = func_args_to_fields(input, schema, arena, nested)?; + let fields = func_args_to_fields(input, ctx, schema, arena, nested)?; polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", function); - function.get_field(schema, Context::Default, &fields) + let out = function.get_field(schema, ctx, &fields)?; + + if options.flags.contains(FunctionFlags::RETURNS_SCALAR) { + *nested = 0; + } else if matches!(ctx, Context::Aggregation) { + *nested += 1; + } + + Ok(out) }, - Slice { input, .. } => arena.get(*input).to_field_impl(schema, arena, nested), + Slice { input, .. } => arena.get(*input).to_field_impl(schema, ctx, arena, nested), } } } fn func_args_to_fields( input: &[ExprIR], + ctx: Context, schema: &Schema, arena: &Arena, nested: &mut u8, @@ -308,7 +341,7 @@ fn func_args_to_fields( arena .get(e.node()) - .to_field_impl(schema, arena, nested) + .to_field_impl(schema, ctx, arena, nested) .map(|mut field| { field.name = e.output_name().clone(); field @@ -322,6 +355,7 @@ fn get_arithmetic_field( right: Node, arena: &Arena, op: Operator, + ctx: Context, schema: &Schema, nested: &mut u8, ) -> PolarsResult { @@ -337,11 +371,11 @@ fn get_arithmetic_field( // leading to quadratic behavior. # 4736 // // further right_type is only determined when needed. - let mut left_field = left_ae.to_field_impl(schema, arena, nested)?; + let mut left_field = left_ae.to_field_impl(schema, ctx, arena, nested)?; let super_type = match op { Operator::Minus => { - let right_type = right_ae.to_field_impl(schema, arena, nested)?.dtype; + let right_type = right_ae.to_field_impl(schema, ctx, arena, nested)?.dtype; match (&left_field.dtype, &right_type) { #[cfg(feature = "dtype-struct")] (Struct(_), Struct(_)) => { @@ -396,7 +430,7 @@ fn get_arithmetic_field( } }, Operator::Plus => { - let right_type = right_ae.to_field_impl(schema, arena, nested)?.dtype; + let right_type = right_ae.to_field_impl(schema, ctx, arena, nested)?.dtype; match (&left_field.dtype, &right_type) { (Duration(_), Datetime(_, _)) | (Datetime(_, _), Duration(_)) @@ -438,7 +472,7 @@ fn get_arithmetic_field( } }, _ => { - let right_type = right_ae.to_field_impl(schema, arena, nested)?.dtype; + let right_type = right_ae.to_field_impl(schema, ctx, arena, nested)?.dtype; match (&left_field.dtype, &right_type) { #[cfg(feature = "dtype-struct")] @@ -522,11 +556,12 @@ fn get_truediv_field( left: Node, right: Node, arena: &Arena, + ctx: Context, schema: &Schema, nested: &mut u8, ) -> PolarsResult { - let mut left_field = arena.get(left).to_field_impl(schema, arena, nested)?; - let right_field = arena.get(right).to_field_impl(schema, arena, nested)?; + let mut left_field = arena.get(left).to_field_impl(schema, ctx, arena, nested)?; + let right_field = arena.get(right).to_field_impl(schema, ctx, arena, nested)?; use DataType::*; // TODO: Re-investigate this. A lot of "_" is being used on the RHS match because this code diff --git a/py-polars/tests/unit/test_schema.py b/py-polars/tests/unit/test_schema.py index 9c5848382f42..78a277a3662f 100644 --- a/py-polars/tests/unit/test_schema.py +++ b/py-polars/tests/unit/test_schema.py @@ -1,9 +1,11 @@ import pickle from datetime import datetime +from typing import Any import pytest import polars as pl +from polars.testing.asserts.frame import assert_frame_equal def test_schema() -> None: @@ -110,7 +112,7 @@ def test_schema_in_map_elements_returns_scalar() -> None: ) .alias("irr") ) - assert (q.collect_schema()) == schema + assert q.collect_schema() == schema assert q.collect().schema == schema @@ -129,3 +131,118 @@ def test_schema_functions_in_agg_with_literal_arg_19011() -> None: assert q.collect_schema() == pl.Schema( [("idx", pl.Int64), ("a_1", pl.List(pl.Int64)), ("a_2", pl.List(pl.Float64))] ) + + +def test_lf_explode_in_agg_schema_19562() -> None: + def new_df_check_schema( + value: dict[str, Any], schema: dict[str, Any] + ) -> pl.DataFrame: + df = pl.DataFrame(value) + assert df.schema == schema + return df + + lf = pl.LazyFrame({"a": [1], "b": [[1]]}) + + q = lf.group_by("a").agg(pl.col("b")) + schema = {"a": pl.Int64, "b": pl.List(pl.List(pl.Int64))} + + assert q.collect_schema() == schema + assert_frame_equal( + q.collect(), new_df_check_schema({"a": [1], "b": [[[1]]]}, schema) + ) + + q = lf.group_by("a").agg(pl.col("b").explode()) + schema = {"a": pl.Int64, "b": pl.List(pl.Int64)} + + assert q.collect_schema() == schema + assert_frame_equal(q.collect(), new_df_check_schema({"a": [1], "b": [[1]]}, schema)) + + q = lf.group_by("a").agg(pl.col("b").explode().explode()) + schema = {"a": pl.Int64, "b": pl.List(pl.Int64)} + + assert q.collect_schema() == schema + assert_frame_equal(q.collect(), new_df_check_schema({"a": [1], "b": [[1]]}, schema)) + + # 2x nested + lf = pl.LazyFrame({"a": [1], "b": [[[1]]]}) + + q = lf.group_by("a").agg(pl.col("b")) + schema = { + "a": pl.Int64, + "b": pl.List(pl.List(pl.List(pl.Int64))), + } + + assert q.collect_schema() == schema + assert_frame_equal( + q.collect(), new_df_check_schema({"a": [1], "b": [[[[1]]]]}, schema) + ) + + q = lf.group_by("a").agg(pl.col("b").explode()) + schema = {"a": pl.Int64, "b": pl.List(pl.List(pl.Int64))} + + assert q.collect_schema() == schema + assert_frame_equal( + q.collect(), new_df_check_schema({"a": [1], "b": [[[1]]]}, schema) + ) + + q = lf.group_by("a").agg(pl.col("b").explode().explode()) + schema = {"a": pl.Int64, "b": pl.List(pl.Int64)} + + assert q.collect_schema() == schema + assert_frame_equal(q.collect(), new_df_check_schema({"a": [1], "b": [[1]]}, schema)) + + +def test_lf_nested_function_expr_agg_schema() -> None: + q = ( + pl.LazyFrame({"k": [1, 1, 2]}) + .group_by(pl.first(), maintain_order=True) + .agg(o=pl.int_range(pl.len()).reverse() < 1) + ) + + assert q.collect_schema() == {"k": pl.Int64, "o": pl.List(pl.Boolean)} + assert_frame_equal( + q.collect(), pl.DataFrame({"k": [1, 2], "o": [[False, True], [True]]}) + ) + + +def test_lf_agg_scalar_return_schema() -> None: + q = pl.LazyFrame({"k": [1]}).group_by("k").agg(pl.col("k").null_count().alias("o")) + + schema = {"k": pl.Int64, "o": pl.UInt32} + assert q.collect_schema() == schema + assert_frame_equal(q.collect(), pl.DataFrame({"k": 1, "o": 0}, schema=schema)) + + +def test_lf_agg_nested_expr_schema() -> None: + q = ( + pl.LazyFrame({"k": [1]}) + .group_by("k") + .agg( + ( + ( + (pl.col("k").reverse().shuffle() + 1) + + pl.col("k").shuffle().reverse() + ) + .shuffle() + .reverse() + .sum() + * 0 + ).alias("o") + ) + ) + + schema = {"k": pl.Int64, "o": pl.Int64} + assert q.collect_schema() == schema + assert_frame_equal(q.collect(), pl.DataFrame({"k": 1, "o": 0}, schema=schema)) + + +def test_lf_agg_lit_explode() -> None: + q = ( + pl.LazyFrame({"k": [1]}) + .group_by("k") + .agg(pl.lit(1, dtype=pl.Int64).explode().alias("o")) + ) + + schema = {"k": pl.Int64, "o": pl.List(pl.Int64)} + assert q.collect_schema() == schema + assert_frame_equal(q.collect(), pl.DataFrame({"k": 1, "o": [[1]]}, schema=schema)) # type: ignore[arg-type]