@@ -28,7 +28,7 @@ use datafusion_common::tree_node::{
28
28
TreeNodeVisitor ,
29
29
} ;
30
30
use datafusion_common:: {
31
- internal_err , qualified_name, Column , DFSchema , DFSchemaRef , DataFusionError , Result ,
31
+ qualified_name, Column , DFSchema , DFSchemaRef , DataFusionError , Result ,
32
32
} ;
33
33
use datafusion_expr:: expr:: Alias ;
34
34
use datafusion_expr:: logical_plan:: { Aggregate , LogicalPlan , Projection , Window } ;
@@ -166,6 +166,15 @@ impl CommonSubexprEliminate {
166
166
) -> Result < ( Vec < Vec < Expr > > , LogicalPlan ) > {
167
167
let mut common_exprs = IndexMap :: new ( ) ;
168
168
169
+ input. schema ( ) . iter ( ) . for_each ( |( qualifier, field) | {
170
+ let name = field. name ( ) ;
171
+ if name. starts_with ( '#' ) {
172
+ common_exprs. insert ( name. clone ( ) , Expr :: from ( ( qualifier, field) ) ) ;
173
+ }
174
+ } ) ;
175
+
176
+ let input_cse_len = common_exprs. len ( ) ;
177
+
169
178
let rewrite_exprs = self . rewrite_exprs_list (
170
179
exprs_list,
171
180
arrays_list,
@@ -176,9 +185,9 @@ impl CommonSubexprEliminate {
176
185
let mut new_input = self
177
186
. try_optimize ( input, config) ?
178
187
. unwrap_or_else ( || input. clone ( ) ) ;
179
- if !common_exprs . is_empty ( ) {
180
- new_input =
181
- build_common_expr_project_plan ( new_input, common_exprs, expr_stats ) ?;
188
+
189
+ if common_exprs . len ( ) > input_cse_len {
190
+ new_input = build_common_expr_project_plan ( new_input, common_exprs) ?;
182
191
}
183
192
184
193
Ok ( ( rewrite_exprs, new_input) )
@@ -517,18 +526,15 @@ fn to_arrays(
517
526
fn build_common_expr_project_plan (
518
527
input : LogicalPlan ,
519
528
common_exprs : CommonExprs ,
520
- expr_stats : & ExprStats ,
521
529
) -> Result < LogicalPlan > {
522
530
let mut fields_set = BTreeSet :: new ( ) ;
523
531
let mut project_exprs = common_exprs
524
532
. into_iter ( )
525
533
. enumerate ( )
526
- . map ( |( index, ( expr_id, expr) ) | {
527
- let Some ( ( _, data_type) ) = expr_stats. get ( & expr_id) else {
528
- return internal_err ! ( "expr_stats invalid state" ) ;
529
- } ;
534
+ . map ( |( index, ( _, expr) ) | {
530
535
let alias = format ! ( "#{}" , index + 1 ) ;
531
- let field = Field :: new ( & alias, data_type. clone ( ) , true ) ;
536
+ let ( dt, nullable) = expr. data_type_and_nullable ( input. schema ( ) ) ?;
537
+ let field = Field :: new ( & alias, dt, nullable) ;
532
538
fields_set. insert ( field. name ( ) . to_owned ( ) ) ;
533
539
Ok ( expr. alias ( alias) )
534
540
} )
@@ -1225,28 +1231,16 @@ mod test {
1225
1231
#[ test]
1226
1232
fn redundant_project_fields ( ) {
1227
1233
let table_scan = test_table_scan ( ) . unwrap ( ) ;
1228
- let expr_stats_1 = ExprStats :: from ( [
1229
- ( "c+a" . to_string ( ) , ( 1 , DataType :: UInt32 ) ) ,
1230
- ( "b+a" . to_string ( ) , ( 1 , DataType :: UInt32 ) ) ,
1231
- ] ) ;
1232
1234
let common_exprs_1 = CommonExprs :: from ( [
1233
1235
( "c+a" . to_string ( ) , col ( "c" ) + col ( "a" ) ) ,
1234
1236
( "b+a" . to_string ( ) , col ( "b" ) + col ( "a" ) ) ,
1235
1237
] ) ;
1236
- let exprs_stats_2 = ExprStats :: from ( [
1237
- ( "c+a" . to_string ( ) , ( 1 , DataType :: UInt32 ) ) ,
1238
- ( "b+a" . to_string ( ) , ( 1 , DataType :: UInt32 ) ) ,
1239
- ] ) ;
1240
1238
let common_exprs_2 = CommonExprs :: from ( [
1241
1239
( "c+a" . to_string ( ) , col ( "#1" ) ) ,
1242
1240
( "b+a" . to_string ( ) , col ( "#2" ) ) ,
1243
1241
] ) ;
1244
- let project =
1245
- build_common_expr_project_plan ( table_scan, common_exprs_1, & expr_stats_1)
1246
- . unwrap ( ) ;
1247
- let project_2 =
1248
- build_common_expr_project_plan ( project, common_exprs_2, & exprs_stats_2)
1249
- . unwrap ( ) ;
1242
+ let project = build_common_expr_project_plan ( table_scan, common_exprs_1) . unwrap ( ) ;
1243
+ let project_2 = build_common_expr_project_plan ( project, common_exprs_2) . unwrap ( ) ;
1250
1244
1251
1245
let mut field_set = BTreeSet :: new ( ) ;
1252
1246
for name in project_2. schema ( ) . field_names ( ) {
@@ -1263,10 +1257,6 @@ mod test {
1263
1257
. unwrap ( )
1264
1258
. build ( )
1265
1259
. unwrap ( ) ;
1266
- let expr_stats_1 = ExprStats :: from ( [
1267
- ( "test1.c+test1.a" . to_string ( ) , ( 1 , DataType :: UInt32 ) ) ,
1268
- ( "test1.b+test1.a" . to_string ( ) , ( 1 , DataType :: UInt32 ) ) ,
1269
- ] ) ;
1270
1260
let common_exprs_1 = CommonExprs :: from ( [
1271
1261
(
1272
1262
"test1.c+test1.a" . to_string ( ) ,
@@ -1277,19 +1267,12 @@ mod test {
1277
1267
col ( "test1.b" ) + col ( "test1.a" ) ,
1278
1268
) ,
1279
1269
] ) ;
1280
- let expr_stats_2 = ExprStats :: from ( [
1281
- ( "test1.c+test1.a" . to_string ( ) , ( 1 , DataType :: UInt32 ) ) ,
1282
- ( "test1.b+test1.a" . to_string ( ) , ( 1 , DataType :: UInt32 ) ) ,
1283
- ] ) ;
1284
1270
let common_exprs_2 = CommonExprs :: from ( [
1285
1271
( "test1.c+test1.a" . to_string ( ) , col ( "#1" ) ) ,
1286
1272
( "test1.b+test1.a" . to_string ( ) , col ( "#2" ) ) ,
1287
1273
] ) ;
1288
- let project =
1289
- build_common_expr_project_plan ( join, common_exprs_1, & expr_stats_1) . unwrap ( ) ;
1290
- let project_2 =
1291
- build_common_expr_project_plan ( project, common_exprs_2, & expr_stats_2)
1292
- . unwrap ( ) ;
1274
+ let project = build_common_expr_project_plan ( join, common_exprs_1) . unwrap ( ) ;
1275
+ let project_2 = build_common_expr_project_plan ( project, common_exprs_2) . unwrap ( ) ;
1293
1276
1294
1277
let mut field_set = BTreeSet :: new ( ) ;
1295
1278
for name in project_2. schema ( ) . field_names ( ) {
@@ -1402,6 +1385,30 @@ mod test {
1402
1385
Ok ( ( ) )
1403
1386
}
1404
1387
1388
+ #[ test]
1389
+ fn test_alias_collision ( ) -> Result < ( ) > {
1390
+ let table_scan = test_table_scan ( ) ?;
1391
+
1392
+ let plan = LogicalPlanBuilder :: from ( table_scan. clone ( ) )
1393
+ . project ( vec ! [ ( col( "a" ) + col( "b" ) ) . alias( "#1" ) , col( "c" ) ] ) ?
1394
+ . project ( vec ! [
1395
+ col( "#1" ) . alias( "c1" ) ,
1396
+ col( "#1" ) . alias( "c2" ) ,
1397
+ ( col( "c" ) + lit( 2 ) ) . alias( "c3" ) ,
1398
+ ( col( "c" ) + lit( 2 ) ) . alias( "c4" ) ,
1399
+ ] ) ?
1400
+ . build ( ) ?;
1401
+
1402
+ let expected = "Projection: #1 AS c1, #1 AS c2, #2 AS c3, #2 AS c4\
1403
+ \n Projection: #1 AS #1, test.c + Int32(2) AS #2, test.c\
1404
+ \n Projection: test.a + test.b AS #1, test.c\
1405
+ \n TableScan: test";
1406
+
1407
+ assert_optimized_plan_eq ( expected, & plan) ;
1408
+
1409
+ Ok ( ( ) )
1410
+ }
1411
+
1405
1412
#[ test]
1406
1413
fn test_extract_expressions_from_col ( ) -> Result < ( ) > {
1407
1414
let mut result = Vec :: with_capacity ( 1 ) ;
0 commit comments