Skip to content

Commit 5cc7d06

Browse files
committed
Adds name parameter to WindowFunctionDefinition::return_type
1 parent 045d352 commit 5cc7d06

File tree

2 files changed

+19
-27
lines changed

2 files changed

+19
-27
lines changed

datafusion/expr/src/expr.rs

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ use datafusion_common::tree_node::{
4040
use datafusion_common::{
4141
plan_err, Column, DFSchema, Result, ScalarValue, TableReference,
4242
};
43+
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
4344
use sqlparser::ast::{
4445
display_comma_separated, ExceptSelectItem, ExcludeSelectItem, IlikeSelectItem,
4546
NullTreatment, RenameSelectItem, ReplaceSelectElement,
@@ -706,6 +707,7 @@ impl WindowFunctionDefinition {
706707
&self,
707708
input_expr_types: &[DataType],
708709
_input_expr_nullable: &[bool],
710+
display_name: &str,
709711
) -> Result<DataType> {
710712
match self {
711713
WindowFunctionDefinition::BuiltInWindowFunction(fun) => {
@@ -714,12 +716,9 @@ impl WindowFunctionDefinition {
714716
WindowFunctionDefinition::AggregateUDF(fun) => {
715717
fun.return_type(input_expr_types)
716718
}
717-
WindowFunctionDefinition::WindowUDF(_) => {
718-
// To get the return data type of the result from
719-
// evaluating the user-defined window function instead
720-
// use the `WindowUDF::field` trait method.
721-
unreachable!()
722-
}
719+
WindowFunctionDefinition::WindowUDF(fun) => fun
720+
.field(WindowUDFFieldArgs::new(input_expr_types, display_name))
721+
.map(|field| field.data_type().clone()),
723722
}
724723
}
725724

@@ -2558,10 +2557,10 @@ mod test {
25582557
#[test]
25592558
fn test_first_value_return_type() -> Result<()> {
25602559
let fun = find_df_window_func("first_value").unwrap();
2561-
let observed = fun.return_type(&[DataType::Utf8], &[true])?;
2560+
let observed = fun.return_type(&[DataType::Utf8], &[true], "")?;
25622561
assert_eq!(DataType::Utf8, observed);
25632562

2564-
let observed = fun.return_type(&[DataType::UInt64], &[true])?;
2563+
let observed = fun.return_type(&[DataType::UInt64], &[true], "")?;
25652564
assert_eq!(DataType::UInt64, observed);
25662565

25672566
Ok(())
@@ -2570,10 +2569,10 @@ mod test {
25702569
#[test]
25712570
fn test_last_value_return_type() -> Result<()> {
25722571
let fun = find_df_window_func("last_value").unwrap();
2573-
let observed = fun.return_type(&[DataType::Utf8], &[true])?;
2572+
let observed = fun.return_type(&[DataType::Utf8], &[true], "")?;
25742573
assert_eq!(DataType::Utf8, observed);
25752574

2576-
let observed = fun.return_type(&[DataType::Float64], &[true])?;
2575+
let observed = fun.return_type(&[DataType::Float64], &[true], "")?;
25772576
assert_eq!(DataType::Float64, observed);
25782577

25792578
Ok(())
@@ -2582,10 +2581,10 @@ mod test {
25822581
#[test]
25832582
fn test_lead_return_type() -> Result<()> {
25842583
let fun = find_df_window_func("lead").unwrap();
2585-
let observed = fun.return_type(&[DataType::Utf8], &[true])?;
2584+
let observed = fun.return_type(&[DataType::Utf8], &[true], "")?;
25862585
assert_eq!(DataType::Utf8, observed);
25872586

2588-
let observed = fun.return_type(&[DataType::Float64], &[true])?;
2587+
let observed = fun.return_type(&[DataType::Float64], &[true], "")?;
25892588
assert_eq!(DataType::Float64, observed);
25902589

25912590
Ok(())
@@ -2594,10 +2593,10 @@ mod test {
25942593
#[test]
25952594
fn test_lag_return_type() -> Result<()> {
25962595
let fun = find_df_window_func("lag").unwrap();
2597-
let observed = fun.return_type(&[DataType::Utf8], &[true])?;
2596+
let observed = fun.return_type(&[DataType::Utf8], &[true], "")?;
25982597
assert_eq!(DataType::Utf8, observed);
25992598

2600-
let observed = fun.return_type(&[DataType::Float64], &[true])?;
2599+
let observed = fun.return_type(&[DataType::Float64], &[true], "")?;
26012600
assert_eq!(DataType::Float64, observed);
26022601

26032602
Ok(())
@@ -2607,11 +2606,11 @@ mod test {
26072606
fn test_nth_value_return_type() -> Result<()> {
26082607
let fun = find_df_window_func("nth_value").unwrap();
26092608
let observed =
2610-
fun.return_type(&[DataType::Utf8, DataType::UInt64], &[true, true])?;
2609+
fun.return_type(&[DataType::Utf8, DataType::UInt64], &[true, true], "")?;
26112610
assert_eq!(DataType::Utf8, observed);
26122611

26132612
let observed =
2614-
fun.return_type(&[DataType::Float64, DataType::UInt64], &[true, true])?;
2613+
fun.return_type(&[DataType::Float64, DataType::UInt64], &[true, true], "")?;
26152614
assert_eq!(DataType::Float64, observed);
26162615

26172616
Ok(())
@@ -2620,7 +2619,7 @@ mod test {
26202619
#[test]
26212620
fn test_percent_rank_return_type() -> Result<()> {
26222621
let fun = find_df_window_func("percent_rank").unwrap();
2623-
let observed = fun.return_type(&[], &[])?;
2622+
let observed = fun.return_type(&[], &[], "")?;
26242623
assert_eq!(DataType::Float64, observed);
26252624

26262625
Ok(())
@@ -2629,7 +2628,7 @@ mod test {
26292628
#[test]
26302629
fn test_cume_dist_return_type() -> Result<()> {
26312630
let fun = find_df_window_func("cume_dist").unwrap();
2632-
let observed = fun.return_type(&[], &[])?;
2631+
let observed = fun.return_type(&[], &[], "")?;
26332632
assert_eq!(DataType::Float64, observed);
26342633

26352634
Ok(())
@@ -2638,7 +2637,7 @@ mod test {
26382637
#[test]
26392638
fn test_ntile_return_type() -> Result<()> {
26402639
let fun = find_df_window_func("ntile").unwrap();
2641-
let observed = fun.return_type(&[DataType::Int16], &[true])?;
2640+
let observed = fun.return_type(&[DataType::Int16], &[true], "")?;
26422641
assert_eq!(DataType::UInt64, observed);
26432642

26442643
Ok(())

datafusion/physical-plan/src/windows/mod.rs

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -75,14 +75,7 @@ pub fn schema_add_window_field(
7575
.map(|e| Arc::clone(e).as_ref().nullable(schema))
7676
.collect::<Result<Vec<_>>>()?;
7777
let window_expr_return_type =
78-
if let WindowFunctionDefinition::WindowUDF(udwf) = window_fn {
79-
let field_args = WindowUDFFieldArgs::new(&data_types, fn_name);
80-
81-
udwf.field(field_args)
82-
.map(|field| field.data_type().clone())?
83-
} else {
84-
window_fn.return_type(&data_types, &nullability)?
85-
};
78+
window_fn.return_type(&data_types, &nullability, fn_name)?;
8679
let mut window_fields = schema
8780
.fields()
8881
.iter()

0 commit comments

Comments
 (0)