@@ -40,6 +40,7 @@ use datafusion_common::tree_node::{
40
40
use datafusion_common:: {
41
41
plan_err, Column , DFSchema , Result , ScalarValue , TableReference ,
42
42
} ;
43
+ use datafusion_functions_window_common:: field:: WindowUDFFieldArgs ;
43
44
use sqlparser:: ast:: {
44
45
display_comma_separated, ExceptSelectItem , ExcludeSelectItem , IlikeSelectItem ,
45
46
NullTreatment , RenameSelectItem , ReplaceSelectElement ,
@@ -706,6 +707,7 @@ impl WindowFunctionDefinition {
706
707
& self ,
707
708
input_expr_types : & [ DataType ] ,
708
709
_input_expr_nullable : & [ bool ] ,
710
+ display_name : & str ,
709
711
) -> Result < DataType > {
710
712
match self {
711
713
WindowFunctionDefinition :: BuiltInWindowFunction ( fun) => {
@@ -714,12 +716,9 @@ impl WindowFunctionDefinition {
714
716
WindowFunctionDefinition :: AggregateUDF ( fun) => {
715
717
fun. return_type ( input_expr_types)
716
718
}
717
- WindowFunctionDefinition :: WindowUDF ( _) => {
718
- // To get the return data type of the result from
719
- // evaluating the user-defined window function instead
720
- // use the `WindowUDF::field` trait method.
721
- unreachable ! ( )
722
- }
719
+ WindowFunctionDefinition :: WindowUDF ( fun) => fun
720
+ . field ( WindowUDFFieldArgs :: new ( input_expr_types, display_name) )
721
+ . map ( |field| field. data_type ( ) . clone ( ) ) ,
723
722
}
724
723
}
725
724
@@ -2558,10 +2557,10 @@ mod test {
2558
2557
#[ test]
2559
2558
fn test_first_value_return_type ( ) -> Result < ( ) > {
2560
2559
let fun = find_df_window_func ( "first_value" ) . unwrap ( ) ;
2561
- let observed = fun. return_type ( & [ DataType :: Utf8 ] , & [ true ] ) ?;
2560
+ let observed = fun. return_type ( & [ DataType :: Utf8 ] , & [ true ] , "" ) ?;
2562
2561
assert_eq ! ( DataType :: Utf8 , observed) ;
2563
2562
2564
- let observed = fun. return_type ( & [ DataType :: UInt64 ] , & [ true ] ) ?;
2563
+ let observed = fun. return_type ( & [ DataType :: UInt64 ] , & [ true ] , "" ) ?;
2565
2564
assert_eq ! ( DataType :: UInt64 , observed) ;
2566
2565
2567
2566
Ok ( ( ) )
@@ -2570,10 +2569,10 @@ mod test {
2570
2569
#[ test]
2571
2570
fn test_last_value_return_type ( ) -> Result < ( ) > {
2572
2571
let fun = find_df_window_func ( "last_value" ) . unwrap ( ) ;
2573
- let observed = fun. return_type ( & [ DataType :: Utf8 ] , & [ true ] ) ?;
2572
+ let observed = fun. return_type ( & [ DataType :: Utf8 ] , & [ true ] , "" ) ?;
2574
2573
assert_eq ! ( DataType :: Utf8 , observed) ;
2575
2574
2576
- let observed = fun. return_type ( & [ DataType :: Float64 ] , & [ true ] ) ?;
2575
+ let observed = fun. return_type ( & [ DataType :: Float64 ] , & [ true ] , "" ) ?;
2577
2576
assert_eq ! ( DataType :: Float64 , observed) ;
2578
2577
2579
2578
Ok ( ( ) )
@@ -2582,10 +2581,10 @@ mod test {
2582
2581
#[ test]
2583
2582
fn test_lead_return_type ( ) -> Result < ( ) > {
2584
2583
let fun = find_df_window_func ( "lead" ) . unwrap ( ) ;
2585
- let observed = fun. return_type ( & [ DataType :: Utf8 ] , & [ true ] ) ?;
2584
+ let observed = fun. return_type ( & [ DataType :: Utf8 ] , & [ true ] , "" ) ?;
2586
2585
assert_eq ! ( DataType :: Utf8 , observed) ;
2587
2586
2588
- let observed = fun. return_type ( & [ DataType :: Float64 ] , & [ true ] ) ?;
2587
+ let observed = fun. return_type ( & [ DataType :: Float64 ] , & [ true ] , "" ) ?;
2589
2588
assert_eq ! ( DataType :: Float64 , observed) ;
2590
2589
2591
2590
Ok ( ( ) )
@@ -2594,10 +2593,10 @@ mod test {
2594
2593
#[ test]
2595
2594
fn test_lag_return_type ( ) -> Result < ( ) > {
2596
2595
let fun = find_df_window_func ( "lag" ) . unwrap ( ) ;
2597
- let observed = fun. return_type ( & [ DataType :: Utf8 ] , & [ true ] ) ?;
2596
+ let observed = fun. return_type ( & [ DataType :: Utf8 ] , & [ true ] , "" ) ?;
2598
2597
assert_eq ! ( DataType :: Utf8 , observed) ;
2599
2598
2600
- let observed = fun. return_type ( & [ DataType :: Float64 ] , & [ true ] ) ?;
2599
+ let observed = fun. return_type ( & [ DataType :: Float64 ] , & [ true ] , "" ) ?;
2601
2600
assert_eq ! ( DataType :: Float64 , observed) ;
2602
2601
2603
2602
Ok ( ( ) )
@@ -2607,11 +2606,11 @@ mod test {
2607
2606
fn test_nth_value_return_type ( ) -> Result < ( ) > {
2608
2607
let fun = find_df_window_func ( "nth_value" ) . unwrap ( ) ;
2609
2608
let observed =
2610
- fun. return_type ( & [ DataType :: Utf8 , DataType :: UInt64 ] , & [ true , true ] ) ?;
2609
+ fun. return_type ( & [ DataType :: Utf8 , DataType :: UInt64 ] , & [ true , true ] , "" ) ?;
2611
2610
assert_eq ! ( DataType :: Utf8 , observed) ;
2612
2611
2613
2612
let observed =
2614
- fun. return_type ( & [ DataType :: Float64 , DataType :: UInt64 ] , & [ true , true ] ) ?;
2613
+ fun. return_type ( & [ DataType :: Float64 , DataType :: UInt64 ] , & [ true , true ] , "" ) ?;
2615
2614
assert_eq ! ( DataType :: Float64 , observed) ;
2616
2615
2617
2616
Ok ( ( ) )
@@ -2620,7 +2619,7 @@ mod test {
2620
2619
#[ test]
2621
2620
fn test_percent_rank_return_type ( ) -> Result < ( ) > {
2622
2621
let fun = find_df_window_func ( "percent_rank" ) . unwrap ( ) ;
2623
- let observed = fun. return_type ( & [ ] , & [ ] ) ?;
2622
+ let observed = fun. return_type ( & [ ] , & [ ] , "" ) ?;
2624
2623
assert_eq ! ( DataType :: Float64 , observed) ;
2625
2624
2626
2625
Ok ( ( ) )
@@ -2629,7 +2628,7 @@ mod test {
2629
2628
#[ test]
2630
2629
fn test_cume_dist_return_type ( ) -> Result < ( ) > {
2631
2630
let fun = find_df_window_func ( "cume_dist" ) . unwrap ( ) ;
2632
- let observed = fun. return_type ( & [ ] , & [ ] ) ?;
2631
+ let observed = fun. return_type ( & [ ] , & [ ] , "" ) ?;
2633
2632
assert_eq ! ( DataType :: Float64 , observed) ;
2634
2633
2635
2634
Ok ( ( ) )
@@ -2638,7 +2637,7 @@ mod test {
2638
2637
#[ test]
2639
2638
fn test_ntile_return_type ( ) -> Result < ( ) > {
2640
2639
let fun = find_df_window_func ( "ntile" ) . unwrap ( ) ;
2641
- let observed = fun. return_type ( & [ DataType :: Int16 ] , & [ true ] ) ?;
2640
+ let observed = fun. return_type ( & [ DataType :: Int16 ] , & [ true ] , "" ) ?;
2642
2641
assert_eq ! ( DataType :: UInt64 , observed) ;
2643
2642
2644
2643
Ok ( ( ) )
0 commit comments