Skip to content

Commit

Permalink
fix: Fix incorrect lazy schema for explode() in agg() (pola-rs#19629
Browse files Browse the repository at this point in the history
)
  • Loading branch information
nameexhaustion authored and tylerriccio33 committed Nov 8, 2024
1 parent 3ffad45 commit 8d8ae34
Show file tree
Hide file tree
Showing 4 changed files with 212 additions and 58 deletions.
6 changes: 5 additions & 1 deletion crates/polars-expr/src/expressions/apply.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,11 @@ impl ApplyExpr {
// })?
let out: ListChunked = POOL.install(|| iter.collect::<PolarsResult<_>>())?;

debug_assert_eq!(out.dtype(), &DataType::List(Box::new(dtype)));
if self.function_returns_scalar {
debug_assert_eq!(&DataType::List(Box::new(dtype)), out.dtype());
} else {
debug_assert_eq!(&dtype, out.dtype());
}

out
} else {
Expand Down
20 changes: 9 additions & 11 deletions crates/polars-expr/src/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ fn create_physical_expr_inner(
// TODO! Order by
let group_by = create_physical_expressions_from_nodes(
partition_by,
Context::Default,
Context::Aggregation,
expr_arena,
schema,
state,
Expand Down Expand Up @@ -473,10 +473,9 @@ fn create_physical_expr_inner(
options,
} => {
let is_scalar = is_scalar_ae(expression, expr_arena);
let output_dtype =
expr_arena
.get(expression)
.to_field(schema, Context::Default, expr_arena)?;
let output_field = expr_arena
.get(expression)
.to_field(schema, ctxt, expr_arena)?;

let is_reducing_aggregation = options.flags.contains(FunctionFlags::RETURNS_SCALAR)
&& matches!(options.collect_groups, ApplyOptions::GroupWise);
Expand All @@ -501,21 +500,19 @@ fn create_physical_expr_inner(
*options,
state.allow_threading,
schema.clone(),
output_dtype,
output_field,
is_scalar,
)))
},
Function {
input,
function,
options,
..
} => {
let is_scalar = is_scalar_ae(expression, expr_arena);
let output_field =
expr_arena
.get(expression)
.to_field(schema, Context::Default, expr_arena)?;
let output_field = expr_arena
.get(expression)
.to_field(schema, ctxt, expr_arena)?;
let is_reducing_aggregation = options.flags.contains(FunctionFlags::RETURNS_SCALAR)
&& matches!(options.collect_groups, ApplyOptions::GroupWise);
// Will be reset in the function so get that here.
Expand Down Expand Up @@ -568,6 +565,7 @@ fn create_physical_expr_inner(
let field = expr_arena
.get(expression)
.to_field(schema, ctxt, expr_arena)?;

Ok(Arc::new(ApplyExpr::new(
vec![input],
function,
Expand Down
125 changes: 80 additions & 45 deletions crates/polars-plan/src/plans/aexpr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl AExpr {
// col(foo: i64).sum() -> i64
// The `nested` keeps track of the nesting we need to add.
let mut nested = matches!(ctx, Context::Aggregation) as u8;
let mut field = self.to_field_impl(schema, arena, &mut nested)?;
let mut field = self.to_field_impl(schema, ctx, arena, &mut nested)?;

if nested >= 1 {
field.coerce(field.dtype().clone().implode());
Expand All @@ -51,6 +51,7 @@ impl AExpr {
pub fn to_field_impl(
&self,
schema: &Schema,
ctx: Context,
arena: &Arena<AExpr>,
nested: &mut u8,
) -> PolarsResult<Field> {
Expand All @@ -68,11 +69,13 @@ impl AExpr {
*nested += matches!(mapping, WindowMapping::Join) as u8;
}
let e = arena.get(*function);
e.to_field_impl(schema, arena, nested)
e.to_field_impl(schema, ctx, arena, nested)
},
Explode(expr) => {
let field = arena.get(*expr).to_field_impl(schema, arena, nested)?;
*nested = nested.saturating_sub(1);
// `Explode` is a "flatten" operation, which is not the same as returning a scalar.
// Namely, it should be auto-imploded in the aggregation context, so we don't update
// the `nested` state here.
let field = arena.get(*expr).to_field_impl(schema, ctx, arena, &mut 0)?;

if let List(inner) = field.dtype() {
Ok(Field::new(field.name().clone(), *inner.clone()))
Expand All @@ -82,7 +85,10 @@ impl AExpr {
},
Alias(expr, name) => Ok(Field::new(
name.clone(),
arena.get(*expr).to_field_impl(schema, arena, nested)?.dtype,
arena
.get(*expr)
.to_field_impl(schema, ctx, arena, nested)?
.dtype,
)),
Column(name) => schema
.get_field(name)
Expand Down Expand Up @@ -110,20 +116,23 @@ impl AExpr {
| Operator::LogicalOr => {
let out_field;
let out_name = {
out_field = arena.get(*left).to_field_impl(schema, arena, nested)?;
out_field =
arena.get(*left).to_field_impl(schema, ctx, arena, nested)?;
out_field.name()
};
Field::new(out_name.clone(), Boolean)
},
Operator::TrueDivide => {
return get_truediv_field(*left, *right, arena, schema, nested)
return get_truediv_field(*left, *right, arena, ctx, schema, nested)
},
_ => {
return get_arithmetic_field(*left, *right, arena, *op, ctx, schema, nested)
},
_ => return get_arithmetic_field(*left, *right, arena, *op, schema, nested),
};

Ok(field)
},
Sort { expr, .. } => arena.get(*expr).to_field_impl(schema, arena, nested),
Sort { expr, .. } => arena.get(*expr).to_field_impl(schema, ctx, arena, nested),
Gather {
expr,
returns_scalar,
Expand All @@ -132,10 +141,10 @@ impl AExpr {
if *returns_scalar {
*nested = nested.saturating_sub(1);
}
arena.get(*expr).to_field_impl(schema, arena, nested)
arena.get(*expr).to_field_impl(schema, ctx, arena, nested)
},
SortBy { expr, .. } => arena.get(*expr).to_field_impl(schema, arena, nested),
Filter { input, .. } => arena.get(*input).to_field_impl(schema, arena, nested),
SortBy { expr, .. } => arena.get(*expr).to_field_impl(schema, ctx, arena, nested),
Filter { input, .. } => arena.get(*input).to_field_impl(schema, ctx, arena, nested),
Agg(agg) => {
use IRAggExpr::*;
match agg {
Expand All @@ -144,11 +153,12 @@ impl AExpr {
| First(expr)
| Last(expr) => {
*nested = nested.saturating_sub(1);
arena.get(*expr).to_field_impl(schema, arena, nested)
arena.get(*expr).to_field_impl(schema, ctx, arena, nested)
},
Sum(expr) => {
*nested = nested.saturating_sub(1);
let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?;
let mut field =
arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?;
let dt = match field.dtype() {
Boolean => Some(IDX_DTYPE),
UInt8 | Int8 | Int16 | UInt16 => Some(Int64),
Expand All @@ -161,7 +171,8 @@ impl AExpr {
},
Median(expr) => {
*nested = nested.saturating_sub(1);
let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?;
let mut field =
arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?;
match field.dtype {
Date => field.coerce(Datetime(TimeUnit::Milliseconds, None)),
_ => float_type(&mut field),
Expand All @@ -170,65 +181,73 @@ impl AExpr {
},
Mean(expr) => {
*nested = nested.saturating_sub(1);
let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?;
let mut field =
arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?;
match field.dtype {
Date => field.coerce(Datetime(TimeUnit::Milliseconds, None)),
_ => float_type(&mut field),
}
Ok(field)
},
Implode(expr) => {
let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?;
let mut field =
arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?;
field.coerce(DataType::List(field.dtype().clone().into()));
Ok(field)
},
Std(expr, _) => {
*nested = nested.saturating_sub(1);
let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?;
let mut field =
arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?;
float_type(&mut field);
Ok(field)
},
Var(expr, _) => {
*nested = nested.saturating_sub(1);
let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?;
let mut field =
arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?;
float_type(&mut field);
Ok(field)
},
NUnique(expr) => {
*nested = 0;
let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?;
let mut field =
arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?;
field.coerce(IDX_DTYPE);
Ok(field)
},
Count(expr, _) => {
*nested = 0;
let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?;
let mut field =
arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?;
field.coerce(IDX_DTYPE);
Ok(field)
},
AggGroups(expr) => {
*nested = 1;
let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?;
let mut field =
arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?;
field.coerce(List(IDX_DTYPE.into()));
Ok(field)
},
Quantile { expr, .. } => {
*nested = nested.saturating_sub(1);
let mut field = arena.get(*expr).to_field_impl(schema, arena, nested)?;
let mut field =
arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?;
float_type(&mut field);
Ok(field)
},
#[cfg(feature = "bitwise")]
Bitwise(expr, _) => {
*nested = nested.saturating_sub(1);
let field = arena.get(*expr).to_field_impl(schema, arena, nested)?;
let field = arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?;
// @Q? Do we need to coerce here?
Ok(field)
},
}
},
Cast { expr, dtype, .. } => {
let field = arena.get(*expr).to_field_impl(schema, arena, nested)?;
let field = arena.get(*expr).to_field_impl(schema, ctx, arena, nested)?;
Ok(Field::new(field.name().clone(), dtype.clone()))
},
Ternary { truthy, falsy, .. } => {
Expand All @@ -242,10 +261,11 @@ impl AExpr {
let mut truthy =
arena
.get(*truthy)
.to_field_impl(schema, arena, &mut nested_truthy)?;
let falsy = arena
.get(*falsy)
.to_field_impl(schema, arena, &mut nested_falsy)?;
.to_field_impl(schema, ctx, arena, &mut nested_truthy)?;
let falsy =
arena
.get(*falsy)
.to_field_impl(schema, ctx, arena, &mut nested_falsy)?;

let st = if let DataType::Null = *truthy.dtype() {
falsy.dtype().clone()
Expand All @@ -264,30 +284,43 @@ impl AExpr {
options,
..
} => {
*nested = nested
.saturating_sub(options.flags.contains(FunctionFlags::RETURNS_SCALAR) as _);
let fields = func_args_to_fields(input, schema, arena, nested)?;
let fields = func_args_to_fields(input, ctx, schema, arena, nested)?;
polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", options.fmt_str);
output_type.get_field(schema, Context::Default, &fields)
let out = output_type.get_field(schema, ctx, &fields)?;

if options.flags.contains(FunctionFlags::RETURNS_SCALAR) {
*nested = 0;
} else if matches!(ctx, Context::Aggregation) {
*nested += 1;
}

Ok(out)
},
Function {
function,
input,
options,
} => {
*nested = nested
.saturating_sub(options.flags.contains(FunctionFlags::RETURNS_SCALAR) as _);
let fields = func_args_to_fields(input, schema, arena, nested)?;
let fields = func_args_to_fields(input, ctx, schema, arena, nested)?;
polars_ensure!(!fields.is_empty(), ComputeError: "expression: '{}' didn't get any inputs", function);
function.get_field(schema, Context::Default, &fields)
let out = function.get_field(schema, ctx, &fields)?;

if options.flags.contains(FunctionFlags::RETURNS_SCALAR) {
*nested = 0;
} else if matches!(ctx, Context::Aggregation) {
*nested += 1;
}

Ok(out)
},
Slice { input, .. } => arena.get(*input).to_field_impl(schema, arena, nested),
Slice { input, .. } => arena.get(*input).to_field_impl(schema, ctx, arena, nested),
}
}
}

fn func_args_to_fields(
input: &[ExprIR],
ctx: Context,
schema: &Schema,
arena: &Arena<AExpr>,
nested: &mut u8,
Expand All @@ -308,7 +341,7 @@ fn func_args_to_fields(

arena
.get(e.node())
.to_field_impl(schema, arena, nested)
.to_field_impl(schema, ctx, arena, nested)
.map(|mut field| {
field.name = e.output_name().clone();
field
Expand All @@ -322,6 +355,7 @@ fn get_arithmetic_field(
right: Node,
arena: &Arena<AExpr>,
op: Operator,
ctx: Context,
schema: &Schema,
nested: &mut u8,
) -> PolarsResult<Field> {
Expand All @@ -337,11 +371,11 @@ fn get_arithmetic_field(
// leading to quadratic behavior. # 4736
//
// further right_type is only determined when needed.
let mut left_field = left_ae.to_field_impl(schema, arena, nested)?;
let mut left_field = left_ae.to_field_impl(schema, ctx, arena, nested)?;

let super_type = match op {
Operator::Minus => {
let right_type = right_ae.to_field_impl(schema, arena, nested)?.dtype;
let right_type = right_ae.to_field_impl(schema, ctx, arena, nested)?.dtype;
match (&left_field.dtype, &right_type) {
#[cfg(feature = "dtype-struct")]
(Struct(_), Struct(_)) => {
Expand Down Expand Up @@ -396,7 +430,7 @@ fn get_arithmetic_field(
}
},
Operator::Plus => {
let right_type = right_ae.to_field_impl(schema, arena, nested)?.dtype;
let right_type = right_ae.to_field_impl(schema, ctx, arena, nested)?.dtype;
match (&left_field.dtype, &right_type) {
(Duration(_), Datetime(_, _))
| (Datetime(_, _), Duration(_))
Expand Down Expand Up @@ -438,7 +472,7 @@ fn get_arithmetic_field(
}
},
_ => {
let right_type = right_ae.to_field_impl(schema, arena, nested)?.dtype;
let right_type = right_ae.to_field_impl(schema, ctx, arena, nested)?.dtype;

match (&left_field.dtype, &right_type) {
#[cfg(feature = "dtype-struct")]
Expand Down Expand Up @@ -522,11 +556,12 @@ fn get_truediv_field(
left: Node,
right: Node,
arena: &Arena<AExpr>,
ctx: Context,
schema: &Schema,
nested: &mut u8,
) -> PolarsResult<Field> {
let mut left_field = arena.get(left).to_field_impl(schema, arena, nested)?;
let right_field = arena.get(right).to_field_impl(schema, arena, nested)?;
let mut left_field = arena.get(left).to_field_impl(schema, ctx, arena, nested)?;
let right_field = arena.get(right).to_field_impl(schema, ctx, arena, nested)?;
use DataType::*;

// TODO: Re-investigate this. A lot of "_" is being used on the RHS match because this code
Expand Down
Loading

0 comments on commit 8d8ae34

Please sign in to comment.