@@ -21,11 +21,11 @@ use crate::{OptimizerConfig, OptimizerRule};
21
21
22
22
use datafusion_common:: tree_node:: Transformed ;
23
23
use datafusion_common:: { internal_err, Column , Result } ;
24
+ use datafusion_expr:: expr:: AggregateFunction ;
24
25
use datafusion_expr:: expr_rewriter:: normalize_cols;
25
26
use datafusion_expr:: utils:: expand_wildcard;
26
27
use datafusion_expr:: { col, LogicalPlanBuilder } ;
27
28
use datafusion_expr:: { Aggregate , Distinct , DistinctOn , Expr , LogicalPlan } ;
28
- use datafusion_functions_aggregate:: first_last:: first_value;
29
29
30
30
/// Optimizer that replaces logical [[Distinct]] with a logical [[Aggregate]]
31
31
///
@@ -73,7 +73,7 @@ impl OptimizerRule for ReplaceDistinctWithAggregate {
73
73
fn rewrite (
74
74
& self ,
75
75
plan : LogicalPlan ,
76
- _config : & dyn OptimizerConfig ,
76
+ config : & dyn OptimizerConfig ,
77
77
) -> Result < Transformed < LogicalPlan > > {
78
78
match plan {
79
79
LogicalPlan :: Distinct ( Distinct :: All ( input) ) => {
@@ -95,9 +95,18 @@ impl OptimizerRule for ReplaceDistinctWithAggregate {
95
95
let expr_cnt = on_expr. len ( ) ;
96
96
97
97
// Construct the aggregation expression to be used to fetch the selected expressions.
98
- let aggr_expr = select_expr
99
- . into_iter ( )
100
- . map ( |e| first_value ( vec ! [ e] , false , None , sort_expr. clone ( ) , None ) ) ;
98
+ let first_value_udaf =
99
+ config. function_registry ( ) . unwrap ( ) . udaf ( "first_value" ) ?;
100
+ let aggr_expr = select_expr. into_iter ( ) . map ( |e| {
101
+ Expr :: AggregateFunction ( AggregateFunction :: new_udf (
102
+ first_value_udaf. clone ( ) ,
103
+ vec ! [ e] ,
104
+ false ,
105
+ None ,
106
+ sort_expr. clone ( ) ,
107
+ None ,
108
+ ) )
109
+ } ) ;
101
110
102
111
let aggr_expr = normalize_cols ( aggr_expr, input. as_ref ( ) ) ?;
103
112
let group_expr = normalize_cols ( on_expr, input. as_ref ( ) ) ?;
@@ -163,53 +172,3 @@ impl OptimizerRule for ReplaceDistinctWithAggregate {
163
172
Some ( BottomUp )
164
173
}
165
174
}
166
-
167
- #[ cfg( test) ]
168
- mod tests {
169
- use crate :: replace_distinct_aggregate:: ReplaceDistinctWithAggregate ;
170
- use crate :: test:: { assert_optimized_plan_eq, test_table_scan} ;
171
- use datafusion_expr:: { col, LogicalPlanBuilder } ;
172
- use std:: sync:: Arc ;
173
-
174
- #[ test]
175
- fn replace_distinct ( ) -> datafusion_common:: Result < ( ) > {
176
- let table_scan = test_table_scan ( ) . unwrap ( ) ;
177
- let plan = LogicalPlanBuilder :: from ( table_scan)
178
- . project ( vec ! [ col( "a" ) , col( "b" ) ] ) ?
179
- . distinct ( ) ?
180
- . build ( ) ?;
181
-
182
- let expected = "Aggregate: groupBy=[[test.a, test.b]], aggr=[[]]\
183
- \n Projection: test.a, test.b\
184
- \n TableScan: test";
185
-
186
- assert_optimized_plan_eq (
187
- Arc :: new ( ReplaceDistinctWithAggregate :: new ( ) ) ,
188
- plan,
189
- expected,
190
- )
191
- }
192
-
193
- #[ test]
194
- fn replace_distinct_on ( ) -> datafusion_common:: Result < ( ) > {
195
- let table_scan = test_table_scan ( ) . unwrap ( ) ;
196
- let plan = LogicalPlanBuilder :: from ( table_scan)
197
- . distinct_on (
198
- vec ! [ col( "a" ) ] ,
199
- vec ! [ col( "b" ) ] ,
200
- Some ( vec ! [ col( "a" ) . sort( false , true ) , col( "c" ) . sort( true , false ) ] ) ,
201
- ) ?
202
- . build ( ) ?;
203
-
204
- let expected = "Projection: first_value(test.b) ORDER BY [test.a DESC NULLS FIRST, test.c ASC NULLS LAST] AS b\
205
- \n Sort: test.a DESC NULLS FIRST\
206
- \n Aggregate: groupBy=[[test.a]], aggr=[[first_value(test.b) ORDER BY [test.a DESC NULLS FIRST, test.c ASC NULLS LAST]]]\
207
- \n TableScan: test";
208
-
209
- assert_optimized_plan_eq (
210
- Arc :: new ( ReplaceDistinctWithAggregate :: new ( ) ) ,
211
- plan,
212
- expected,
213
- )
214
- }
215
- }
0 commit comments