Skip to content

Commit cd36ee3

Browse files
jonahgao and alamb authored
fix: make columnize_expr resistant to display_name collisions (#10459)
* fix: make `columnize_expr` resistant to display_name collisions
* fix simple_window_function test
* remove Projection
* add tests
* retry ci
* fix DataFrame tests
* Update datafusion/expr/src/logical_plan/plan.rs

  Co-authored-by: Andrew Lamb <[email protected]>
* Remove copies

---------

Co-authored-by: Andrew Lamb <[email protected]>
1 parent b8fab5c commit cd36ee3

File tree

10 files changed

+168
-86
lines changed

10 files changed

+168
-86
lines changed

datafusion/core/src/dataframe/mod.rs

+61
Original file line numberDiff line numberDiff line change
@@ -2045,6 +2045,67 @@ mod tests {
20452045
Ok(())
20462046
}
20472047

2048+
#[tokio::test]
2049+
async fn test_aggregate_subexpr() -> Result<()> {
2050+
let df = test_table().await?;
2051+
2052+
let group_expr = col("c2") + lit(1);
2053+
let aggr_expr = sum(col("c3") + lit(2));
2054+
2055+
let df = df
2056+
// GROUP BY `c2 + 1`
2057+
.aggregate(vec![group_expr.clone()], vec![aggr_expr.clone()])?
2058+
// SELECT `c2 + 1` as c2 + 10, sum(c3 + 2) + 20
2059+
// SELECT expressions contain aggr_expr and group_expr as subexpressions
2060+
.select(vec![
2061+
group_expr.alias("c2") + lit(10),
2062+
(aggr_expr + lit(20)).alias("sum"),
2063+
])?;
2064+
2065+
let df_results = df.collect().await?;
2066+
2067+
#[rustfmt::skip]
2068+
assert_batches_sorted_eq!([
2069+
"+----------------+------+",
2070+
"| c2 + Int32(10) | sum |",
2071+
"+----------------+------+",
2072+
"| 12 | 431 |",
2073+
"| 13 | 248 |",
2074+
"| 14 | 453 |",
2075+
"| 15 | 95 |",
2076+
"| 16 | -146 |",
2077+
"+----------------+------+",
2078+
],
2079+
&df_results
2080+
);
2081+
2082+
Ok(())
2083+
}
2084+
2085+
#[tokio::test]
2086+
async fn test_aggregate_name_collision() -> Result<()> {
2087+
let df = test_table().await?;
2088+
2089+
let collided_alias = "aggregate_test_100.c2 + aggregate_test_100.c3";
2090+
let group_expr = lit(1).alias(collided_alias);
2091+
2092+
let df = df
2093+
// GROUP BY 1
2094+
.aggregate(vec![group_expr], vec![])?
2095+
// SELECT `aggregate_test_100.c2 + aggregate_test_100.c3`
2096+
.select(vec![
2097+
(col("aggregate_test_100.c2") + col("aggregate_test_100.c3")),
2098+
])
2099+
// The select expr has the same display_name as the group_expr,
2100+
// but since they are different expressions, it should fail.
2101+
.expect_err("Expected error");
2102+
let expected = "Schema error: No field named aggregate_test_100.c2. \
2103+
Valid fields are \"aggregate_test_100.c2 + aggregate_test_100.c3\".";
2104+
assert_eq!(df.strip_backtrace(), expected);
2105+
2106+
Ok(())
2107+
}
2108+
20482109
// Test issue: https://github.com/apache/datafusion/issues/10346
20492110
#[tokio::test]
20502111
async fn test_select_over_aggregate_schema() -> Result<()> {

datafusion/core/tests/dataframe/mod.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> {
210210
let sql_results = ctx
211211
.sql("select count(*) from t1")
212212
.await?
213-
.select(vec![count(wildcard())])?
213+
.select(vec![col("COUNT(*)")])?
214214
.explain(false, false)?
215215
.collect()
216216
.await?;

datafusion/core/tests/user_defined/user_defined_table_functions.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,8 @@ impl SimpleCsvTable {
156156
let logical_plan = Projection::try_new(
157157
vec![columnize_expr(
158158
normalize_col(self.exprs[0].clone(), &plan)?,
159-
plan.schema(),
160-
)],
159+
&plan,
160+
)?],
161161
Arc::new(plan),
162162
)
163163
.map(LogicalPlan::Projection)?;

datafusion/expr/src/expr.rs

+7-6
Original file line numberDiff line numberDiff line change
@@ -833,15 +833,16 @@ impl GroupingSet {
833833
/// Return all distinct exprs in the grouping set. For `CUBE` and `ROLLUP` this
834834
/// is just the underlying list of exprs. For `GROUPING SET` we need to deduplicate
835835
/// the exprs in the underlying sets.
836-
pub fn distinct_expr(&self) -> Vec<Expr> {
836+
pub fn distinct_expr(&self) -> Vec<&Expr> {
837837
match self {
838-
GroupingSet::Rollup(exprs) => exprs.clone(),
839-
GroupingSet::Cube(exprs) => exprs.clone(),
838+
GroupingSet::Rollup(exprs) | GroupingSet::Cube(exprs) => {
839+
exprs.iter().collect()
840+
}
840841
GroupingSet::GroupingSets(groups) => {
841-
let mut exprs: Vec<Expr> = vec![];
842+
let mut exprs: Vec<&Expr> = vec![];
842843
for exp in groups.iter().flatten() {
843-
if !exprs.contains(exp) {
844-
exprs.push(exp.clone());
844+
if !exprs.contains(&exp) {
845+
exprs.push(exp);
845846
}
846847
}
847848
exprs

datafusion/expr/src/logical_plan/builder.rs

+1-2
Original file line numberDiff line numberDiff line change
@@ -1435,8 +1435,7 @@ pub fn project(
14351435
input_schema,
14361436
None,
14371437
)?),
1438-
_ => projected_expr
1439-
.push(columnize_expr(normalize_col(e, &plan)?, input_schema)),
1438+
_ => projected_expr.push(columnize_expr(normalize_col(e, &plan)?, &plan)?),
14401439
}
14411440
}
14421441

datafusion/expr/src/logical_plan/plan.rs

+49-2
Original file line numberDiff line numberDiff line change
@@ -1214,6 +1214,45 @@ impl LogicalPlan {
12141214
.unwrap();
12151215
contains
12161216
}
1217+
1218+
/// Get the output expressions and their corresponding columns.
1219+
///
1220+
/// The parent node may reference the output columns of the plan by expressions, such as
1221+
/// projection over aggregate or window functions. This method helps to convert the
1222+
/// referenced expressions into columns.
1223+
///
1224+
/// See also: [`crate::utils::columnize_expr`]
1225+
pub(crate) fn columnized_output_exprs(&self) -> Result<Vec<(&Expr, Column)>> {
1226+
match self {
1227+
LogicalPlan::Aggregate(aggregate) => Ok(aggregate
1228+
.output_expressions()?
1229+
.into_iter()
1230+
.zip(self.schema().columns())
1231+
.collect()),
1232+
LogicalPlan::Window(Window {
1233+
window_expr,
1234+
input,
1235+
schema,
1236+
}) => {
1237+
// The input could be another Window, so the result should also include the input's. For example:
1238+
// `EXPLAIN SELECT RANK() OVER (PARTITION BY a ORDER BY b), SUM(b) OVER (PARTITION BY a) FROM t`
1239+
// Its plan is:
1240+
// Projection: RANK() PARTITION BY [t.a] ORDER BY [t.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, SUM(t.b) PARTITION BY [t.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING
1241+
// WindowAggr: windowExpr=[[SUM(CAST(t.b AS Int64)) PARTITION BY [t.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]]
1242+
// WindowAggr: windowExpr=[[RANK() PARTITION BY [t.a] ORDER BY [t.b ASC NULLS LAST] RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW]]
1243+
// TableScan: t projection=[a, b]
1244+
let mut output_exprs = input.columnized_output_exprs()?;
1245+
let input_len = input.schema().fields().len();
1246+
output_exprs.extend(
1247+
window_expr
1248+
.iter()
1249+
.zip(schema.columns().into_iter().skip(input_len)),
1250+
);
1251+
Ok(output_exprs)
1252+
}
1253+
_ => Ok(vec![]),
1254+
}
1255+
}
12171256
}
12181257

12191258
impl LogicalPlan {
@@ -2480,9 +2519,9 @@ impl Aggregate {
24802519

24812520
let is_grouping_set = matches!(group_expr.as_slice(), [Expr::GroupingSet(_)]);
24822521

2483-
let grouping_expr: Vec<Expr> = grouping_set_to_exprlist(group_expr.as_slice())?;
2522+
let grouping_expr: Vec<&Expr> = grouping_set_to_exprlist(group_expr.as_slice())?;
24842523

2485-
let mut qualified_fields = exprlist_to_fields(grouping_expr.as_slice(), &input)?;
2524+
let mut qualified_fields = exprlist_to_fields(grouping_expr, &input)?;
24862525

24872526
// Even columns that cannot be null will become nullable when used in a grouping set.
24882527
if is_grouping_set {
@@ -2538,6 +2577,14 @@ impl Aggregate {
25382577
})
25392578
}
25402579

2580+
/// Get the output expressions.
2581+
fn output_expressions(&self) -> Result<Vec<&Expr>> {
2582+
let mut exprs = grouping_set_to_exprlist(self.group_expr.as_slice())?;
2583+
exprs.extend(self.aggr_expr.iter());
2584+
debug_assert!(exprs.len() == self.schema.fields().len());
2585+
Ok(exprs)
2586+
}
2587+
25412588
/// Get the length of the group by expression in the output schema
25422589
/// This is not simply group by expression length. Expression may be
25432590
/// GroupingSet, etc. In these cases we need to get inner expression lengths.

datafusion/expr/src/utils.rs

+28-40
Original file line numberDiff line numberDiff line change
@@ -18,19 +18,20 @@
1818
//! Expression utilities
1919
2020
use std::cmp::Ordering;
21-
use std::collections::HashSet;
21+
use std::collections::{HashMap, HashSet};
2222
use std::sync::Arc;
2323

2424
use crate::expr::{Alias, Sort, WindowFunction};
2525
use crate::expr_rewriter::strip_outer_reference;
2626
use crate::signature::{Signature, TypeSignature};
2727
use crate::{
28-
and, BinaryExpr, Cast, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan,
29-
Operator, TryCast,
28+
and, BinaryExpr, Expr, ExprSchemable, Filter, GroupingSet, LogicalPlan, Operator,
3029
};
3130

3231
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
33-
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
32+
use datafusion_common::tree_node::{
33+
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
34+
};
3435
use datafusion_common::utils::get_at_indices;
3536
use datafusion_common::{
3637
internal_err, plan_datafusion_err, plan_err, Column, DFSchema, DFSchemaRef, Result,
@@ -247,7 +248,7 @@ pub fn enumerate_grouping_sets(group_expr: Vec<Expr>) -> Result<Vec<Expr>> {
247248

248249
/// Find all distinct exprs in a list of group by expressions. If the
249250
/// first element is a `GroupingSet` expression then it must be the only expr.
250-
pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<Expr>> {
251+
pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<&Expr>> {
251252
if let Some(Expr::GroupingSet(grouping_set)) = group_expr.first() {
252253
if group_expr.len() > 1 {
253254
return plan_err!(
@@ -256,7 +257,7 @@ pub fn grouping_set_to_exprlist(group_expr: &[Expr]) -> Result<Vec<Expr>> {
256257
}
257258
Ok(grouping_set.distinct_expr())
258259
} else {
259-
Ok(group_expr.to_vec())
260+
Ok(group_expr.iter().collect())
260261
}
261262
}
262263

@@ -725,13 +726,16 @@ pub fn from_plan(
725726
}
726727

727728
/// Create field meta-data from an expression, for use in a result set schema
728-
pub fn exprlist_to_fields(
729-
exprs: &[Expr],
729+
pub fn exprlist_to_fields<'a>(
730+
exprs: impl IntoIterator<Item = &'a Expr>,
730731
plan: &LogicalPlan,
731732
) -> Result<Vec<(Option<TableReference>, Arc<Field>)>> {
732733
// look for exact match in plan's output schema
733734
let input_schema = &plan.schema();
734-
exprs.iter().map(|e| e.to_field(input_schema)).collect()
735+
exprs
736+
.into_iter()
737+
.map(|e| e.to_field(input_schema))
738+
.collect()
735739
}
736740

737741
/// Convert an expression into a Column expression if it's already provided by the input plan.
@@ -749,37 +753,21 @@ pub fn exprlist_to_fields(
749753
/// .aggregate(vec![col("c1")], vec![sum(col("c2"))])?
750754
/// .project(vec![col("c1"), col("SUM(c2)")])?
751755
/// ```
752-
pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr {
753-
match e {
754-
Expr::Column(_) => e,
755-
Expr::OuterReferenceColumn(_, _) => e,
756-
Expr::Alias(Alias {
757-
expr,
758-
relation,
759-
name,
760-
}) => columnize_expr(*expr, input_schema).alias_qualified(relation, name),
761-
Expr::Cast(Cast { expr, data_type }) => Expr::Cast(Cast {
762-
expr: Box::new(columnize_expr(*expr, input_schema)),
763-
data_type,
764-
}),
765-
Expr::TryCast(TryCast { expr, data_type }) => Expr::TryCast(TryCast::new(
766-
Box::new(columnize_expr(*expr, input_schema)),
767-
data_type,
756+
pub fn columnize_expr(e: Expr, input: &LogicalPlan) -> Result<Expr> {
757+
let output_exprs = match input.columnized_output_exprs() {
758+
Ok(exprs) if !exprs.is_empty() => exprs,
759+
_ => return Ok(e),
760+
};
761+
let exprs_map: HashMap<&Expr, Column> = output_exprs.into_iter().collect();
762+
e.transform_down(|node: Expr| match exprs_map.get(&node) {
763+
Some(column) => Ok(Transformed::new(
764+
Expr::Column(column.clone()),
765+
true,
766+
TreeNodeRecursion::Jump,
768767
)),
769-
Expr::ScalarSubquery(_) => e.clone(),
770-
_ => match e.display_name() {
771-
Ok(name) => {
772-
match input_schema.qualified_field_with_unqualified_name(&name) {
773-
Ok((qualifier, field)) => {
774-
Expr::Column(Column::from((qualifier, field)))
775-
}
776-
// expression not provided as input, do not convert to a column reference
777-
Err(_) => e,
778-
}
779-
}
780-
Err(_) => e,
781-
},
782-
}
768+
None => Ok(Transformed::no(node)),
769+
})
770+
.data()
783771
}
784772

785773
/// Collect all deeply nested `Expr::Column`'s. They are returned in order of
@@ -1235,7 +1223,7 @@ mod tests {
12351223
use super::*;
12361224
use crate::{
12371225
col, cube, expr, expr_vec_fmt, grouping_set, lit, rollup, AggregateFunction,
1238-
WindowFrame, WindowFunctionDefinition,
1226+
Cast, WindowFrame, WindowFunctionDefinition,
12391227
};
12401228

12411229
#[test]

datafusion/optimizer/src/common_subexpr_eliminate.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1154,12 +1154,12 @@ mod test {
11541154
let table_scan = test_table_scan()?;
11551155

11561156
let plan = LogicalPlanBuilder::from(table_scan)
1157-
.project(vec![lit(1) + col("a")])?
1157+
.project(vec![lit(1) + col("a"), col("a")])?
11581158
.project(vec![lit(1) + col("a")])?
11591159
.build()?;
11601160

11611161
let expected = "Projection: Int32(1) + test.a\
1162-
\n Projection: Int32(1) + test.a\
1162+
\n Projection: Int32(1) + test.a, test.a\
11631163
\n TableScan: test";
11641164

11651165
assert_optimized_plan_eq(expected, &plan);

0 commit comments

Comments
 (0)