Skip to content

Commit ef227f4

Browse files
eejbyfeldt and alamb
authored
fix: Correct results for grouping sets when columns contain nulls (#12571)
* Fix grouping sets behavior when data contains nulls * PR suggestion comment * Update new test case * Add grouping_id to the logical plan * Add doc comment next to INTERNAL_GROUPING_ID * Fix unparsing of Aggregate with grouping sets --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent 134939a commit ef227f4

File tree

11 files changed

+359
-187
lines changed

11 files changed

+359
-187
lines changed

datafusion/core/src/dataframe/mod.rs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -535,9 +535,26 @@ impl DataFrame {
535535
group_expr: Vec<Expr>,
536536
aggr_expr: Vec<Expr>,
537537
) -> Result<DataFrame> {
538+
let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);
539+
let aggr_expr_len = aggr_expr.len();
538540
let plan = LogicalPlanBuilder::from(self.plan)
539541
.aggregate(group_expr, aggr_expr)?
540542
.build()?;
543+
let plan = if is_grouping_set {
544+
let grouping_id_pos = plan.schema().fields().len() - 1 - aggr_expr_len;
545+
// For grouping sets we do a project to not expose the internal grouping id
546+
let exprs = plan
547+
.schema()
548+
.columns()
549+
.into_iter()
550+
.enumerate()
551+
.filter(|(idx, _)| *idx != grouping_id_pos)
552+
.map(|(_, column)| Expr::Column(column))
553+
.collect::<Vec<_>>();
554+
LogicalPlanBuilder::from(plan).project(exprs)?.build()?
555+
} else {
556+
plan
557+
};
541558
Ok(DataFrame {
542559
session_state: self.session_state,
543560
plan,

datafusion/core/src/physical_planner.rs

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -692,10 +692,6 @@ impl DefaultPhysicalPlanner {
692692
physical_input_schema.clone(),
693693
)?);
694694

695-
// update group column indices based on partial aggregate plan evaluation
696-
let final_group: Vec<Arc<dyn PhysicalExpr>> =
697-
initial_aggr.output_group_expr();
698-
699695
let can_repartition = !groups.is_empty()
700696
&& session_state.config().target_partitions() > 1
701697
&& session_state.config().repartition_aggregations();
@@ -716,13 +712,7 @@ impl DefaultPhysicalPlanner {
716712
AggregateMode::Final
717713
};
718714

719-
let final_grouping_set = PhysicalGroupBy::new_single(
720-
final_group
721-
.iter()
722-
.enumerate()
723-
.map(|(i, expr)| (expr.clone(), groups.expr()[i].1.clone()))
724-
.collect(),
725-
);
715+
let final_grouping_set = initial_aggr.group_expr().as_final();
726716

727717
Arc::new(AggregateExec::try_new(
728718
next_partition_mode,
@@ -2345,7 +2335,7 @@ mod tests {
23452335
.expect("hash aggregate");
23462336
assert_eq!(
23472337
"sum(aggregate_test_100.c3)",
2348-
final_hash_agg.schema().field(2).name()
2338+
final_hash_agg.schema().field(3).name()
23492339
);
23502340
// we need access to the input to the partial aggregate so that other projects can
23512341
// implement serde

datafusion/expr/src/logical_plan/plan.rs

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use std::cmp::Ordering;
2121
use std::collections::{HashMap, HashSet};
2222
use std::fmt::{self, Debug, Display, Formatter};
2323
use std::hash::{Hash, Hasher};
24-
use std::sync::Arc;
24+
use std::sync::{Arc, OnceLock};
2525

2626
use super::dml::CopyTo;
2727
use super::DdlStatement;
@@ -2965,6 +2965,15 @@ impl Aggregate {
29652965
.into_iter()
29662966
.map(|(q, f)| (q, f.as_ref().clone().with_nullable(true).into()))
29672967
.collect::<Vec<_>>();
2968+
qualified_fields.push((
2969+
None,
2970+
Field::new(
2971+
Self::INTERNAL_GROUPING_ID,
2972+
Self::grouping_id_type(qualified_fields.len()),
2973+
false,
2974+
)
2975+
.into(),
2976+
));
29682977
}
29692978

29702979
qualified_fields.extend(exprlist_to_fields(aggr_expr.as_slice(), &input)?);
@@ -3016,9 +3025,19 @@ impl Aggregate {
30163025
})
30173026
}
30183027

3028+
fn is_grouping_set(&self) -> bool {
3029+
matches!(self.group_expr.as_slice(), [Expr::GroupingSet(_)])
3030+
}
3031+
30193032
/// Get the output expressions.
30203033
fn output_expressions(&self) -> Result<Vec<&Expr>> {
3034+
static INTERNAL_ID_EXPR: OnceLock<Expr> = OnceLock::new();
30213035
let mut exprs = grouping_set_to_exprlist(self.group_expr.as_slice())?;
3036+
if self.is_grouping_set() {
3037+
exprs.push(INTERNAL_ID_EXPR.get_or_init(|| {
3038+
Expr::Column(Column::from_name(Self::INTERNAL_GROUPING_ID))
3039+
}));
3040+
}
30223041
exprs.extend(self.aggr_expr.iter());
30233042
debug_assert!(exprs.len() == self.schema.fields().len());
30243043
Ok(exprs)
@@ -3030,6 +3049,41 @@ impl Aggregate {
30303049
pub fn group_expr_len(&self) -> Result<usize> {
30313050
grouping_set_expr_count(&self.group_expr)
30323051
}
3052+
3053+
/// Returns the data type of the grouping id.
3054+
/// The grouping ID value is a bitmask where each set bit
3055+
/// indicates that the corresponding grouping expression is
3056+
/// null
3057+
pub fn grouping_id_type(group_exprs: usize) -> DataType {
3058+
if group_exprs <= 8 {
3059+
DataType::UInt8
3060+
} else if group_exprs <= 16 {
3061+
DataType::UInt16
3062+
} else if group_exprs <= 32 {
3063+
DataType::UInt32
3064+
} else {
3065+
DataType::UInt64
3066+
}
3067+
}
3068+
3069+
/// Internal column used when the aggregation is a grouping set.
3070+
///
3071+
/// This column contains a bitmask where each bit represents a grouping
3072+
/// expression. The least significant bit corresponds to the rightmost
3073+
/// grouping expression. A bit value of 0 indicates that the corresponding
3074+
/// column is included in the grouping set, while a value of 1 means it is excluded.
3075+
///
3076+
/// For example, for the grouping expressions CUBE(a, b), the grouping ID
3077+
/// column will have the following values:
3078+
/// 0b00: Both `a` and `b` are included
3079+
/// 0b01: `b` is excluded
3080+
/// 0b10: `a` is excluded
3081+
/// 0b11: Both `a` and `b` are excluded
3082+
///
3083+
/// This internal column is necessary because excluded columns are replaced
3084+
/// with `NULL` values. To handle these cases correctly, we must distinguish
3085+
/// between an actual `NULL` value in a column and a column being excluded from the set.
3086+
pub const INTERNAL_GROUPING_ID: &'static str = "__grouping_id";
30333087
}
30343088

30353089
// Manual implementation needed because of `schema` field. Comparison excludes this field.

datafusion/expr/src/utils.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,17 @@ pub fn exprlist_to_columns(expr: &[Expr], accum: &mut HashSet<Column>) -> Result
6161
/// Count the number of distinct exprs in a list of group by expressions. If the
6262
/// first element is a `GroupingSet` expression then it must be the only expr.
6363
pub fn grouping_set_expr_count(group_expr: &[Expr]) -> Result<usize> {
64-
grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len())
64+
if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
65+
if group_expr.len() > 1 {
66+
return plan_err!(
67+
"Invalid group by expressions, GroupingSet must be the only expression"
68+
);
69+
}
70+
// Grouping sets have an additional internal column for the grouping id
71+
Ok(grouping_set.distinct_expr().len() + 1)
72+
} else {
73+
grouping_set_to_exprlist(group_expr).map(|exprs| exprs.len())
74+
}
6575
}
6676

6777
/// The [power set] (or powerset) of a set S is the set of all subsets of S, \

datafusion/optimizer/src/single_distinct_to_groupby.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,7 @@ mod tests {
355355
.build()?;
356356

357357
// Should not be optimized
358-
let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\
358+
let expected = "Aggregate: groupBy=[[GROUPING SETS ((test.a), (test.b))]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\
359359
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
360360

361361
assert_optimized_plan_equal(plan, expected)
@@ -373,7 +373,7 @@ mod tests {
373373
.build()?;
374374

375375
// Should not be optimized
376-
let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\
376+
let expected = "Aggregate: groupBy=[[CUBE (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\
377377
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
378378

379379
assert_optimized_plan_equal(plan, expected)
@@ -392,7 +392,7 @@ mod tests {
392392
.build()?;
393393

394394
// Should not be optimized
395-
let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, count(DISTINCT test.c):Int64]\
395+
let expected = "Aggregate: groupBy=[[ROLLUP (test.a, test.b)]], aggr=[[count(DISTINCT test.c)]] [a:UInt32;N, b:UInt32;N, __grouping_id:UInt8, count(DISTINCT test.c):Int64]\
396396
\n TableScan: test [a:UInt32, b:UInt32, c:UInt32]";
397397

398398
assert_optimized_plan_equal(plan, expected)

0 commit comments

Comments
 (0)