Skip to content

Commit e1b992a

Browse files
authored
Add field trait method to WindowUDFImpl, remove return_type/nullable (#12374)
* Adds new library `functions-window-common`
* Adds `FieldArgs` struct for field of final result
* Adds `field` method to `WindowUDFImpl` trait
* Minor: fixes formatting
* Fixes: udwf doc test
* Fixes: implements missing trait items
* Updates `datafusion-cli` dependencies
* Fixes: formatting of `Cargo.toml` files
* Fixes: implementation of `field` in udwf example
* Pass `FieldArgs` argument to `field`
* Use `field` in place of `return_type` for udwf
* Update `field` in udwf implementations
* Fixes: implementation of `field` in udwf example
* Revert unrelated change
* Mark `return_type` for udwf as unreachable
* Delete code
* Uses schema name of udwf to construct `FieldArgs`
* Adds deprecated notice to `return_type` trait method
* Add doc comments to `field` trait method
* Reify `input_types` when creating the udwf window expression
* Rename name field to `schema_name` in `FieldArgs`
* Make `FieldArgs` opaque
* Minor refactor
* Removes `nullable` trait method from `WindowUDFImpl`
* Add doc comments
* Rename to `WindowUDFResultArgs`
* Minor: fixes formatting
* Copy edits for doc comments
* Renames field to `function_name`
* Rename struct to `WindowUDFFieldArgs`
* Add comments for unreachable code
* Copy edit for `WindowUDFImpl::field` trait method
* Renames module
* Fix warning: unused doc comment
* Minor: rename bindings
* Minor refactor
* Minor: copy edit
* Fixes: use `Expr::qualified_name` for window function name
* Fixes: apply previous fix to `Expr::nullable`
* Refactor: reuse type coercion for window functions
* Fixes: clippy errors
* Adds name parameter to `WindowFunctionDefinition::return_type`
* Removes `return_type` field from `SimpleWindowUDF`
* Add doc comment for helper method
* Rewrite doc comments
* Minor: remove empty comment
* Remove `WindowUDFImpl::return_type`
* Fixes doc test
1 parent d9cb6e6 commit e1b992a

File tree

24 files changed

+357
-168
lines changed

24 files changed

+357
-168
lines changed

Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ members = [
3131
"datafusion/functions-aggregate-common",
3232
"datafusion/functions-nested",
3333
"datafusion/functions-window",
34+
"datafusion/functions-window-common",
3435
"datafusion/optimizer",
3536
"datafusion/physical-expr",
3637
"datafusion/physical-expr-common",
@@ -103,6 +104,7 @@ datafusion-functions-aggregate = { path = "datafusion/functions-aggregate", vers
103104
datafusion-functions-aggregate-common = { path = "datafusion/functions-aggregate-common", version = "42.0.0" }
104105
datafusion-functions-nested = { path = "datafusion/functions-nested", version = "42.0.0" }
105106
datafusion-functions-window = { path = "datafusion/functions-window", version = "42.0.0" }
107+
datafusion-functions-window-common = { path = "datafusion/functions-window-common", version = "42.0.0" }
106108
datafusion-optimizer = { path = "datafusion/optimizer", version = "42.0.0", default-features = false }
107109
datafusion-physical-expr = { path = "datafusion/physical-expr", version = "42.0.0", default-features = false }
108110
datafusion-physical-expr-common = { path = "datafusion/physical-expr-common", version = "42.0.0", default-features = false }

datafusion-cli/Cargo.lock

Lines changed: 10 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion-examples/examples/advanced_udwf.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@ use arrow::{
2222
array::{ArrayRef, AsArray, Float64Array},
2323
datatypes::Float64Type,
2424
};
25+
use arrow_schema::Field;
2526
use datafusion::error::Result;
2627
use datafusion::prelude::*;
2728
use datafusion_common::ScalarValue;
29+
use datafusion_expr::function::WindowUDFFieldArgs;
2830
use datafusion_expr::{
2931
PartitionEvaluator, Signature, WindowFrame, WindowUDF, WindowUDFImpl,
3032
};
@@ -70,16 +72,15 @@ impl WindowUDFImpl for SmoothItUdf {
7072
&self.signature
7173
}
7274

73-
/// What is the type of value that will be returned by this function.
74-
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
75-
Ok(DataType::Float64)
76-
}
77-
7875
/// Create a `PartitionEvaluator` to evaluate this function on a new
7976
/// partition.
8077
fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
8178
Ok(Box::new(MyPartitionEvaluator::new()))
8279
}
80+
81+
fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
82+
Ok(Field::new(field_args.name(), DataType::Float64, true))
83+
}
8384
}
8485

8586
/// This implements the lowest level evaluation for a window function

datafusion-examples/examples/simplify_udwf_expression.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@
1717

1818
use std::any::Any;
1919

20-
use arrow_schema::DataType;
20+
use arrow_schema::{DataType, Field};
2121

2222
use datafusion::execution::context::SessionContext;
2323
use datafusion::functions_aggregate::average::avg_udaf;
2424
use datafusion::{error::Result, execution::options::CsvReadOptions};
25-
use datafusion_expr::function::WindowFunctionSimplification;
25+
use datafusion_expr::function::{WindowFunctionSimplification, WindowUDFFieldArgs};
2626
use datafusion_expr::{
2727
expr::WindowFunction, simplify::SimplifyInfo, Expr, PartitionEvaluator, Signature,
2828
Volatility, WindowUDF, WindowUDFImpl,
@@ -60,10 +60,6 @@ impl WindowUDFImpl for SimplifySmoothItUdf {
6060
&self.signature
6161
}
6262

63-
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
64-
Ok(DataType::Float64)
65-
}
66-
6763
fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
6864
todo!()
6965
}
@@ -84,6 +80,10 @@ impl WindowUDFImpl for SimplifySmoothItUdf {
8480

8581
Some(Box::new(simplify))
8682
}
83+
84+
fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
85+
Ok(Field::new(field_args.name(), DataType::Float64, true))
86+
}
8787
}
8888

8989
// create local execution context with `cars.csv` registered as a table named `cars`

datafusion/core/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ bigdecimal = { workspace = true }
145145
criterion = { version = "0.5", features = ["async_tokio"] }
146146
csv = "1.1.6"
147147
ctor = { workspace = true }
148+
datafusion-functions-window-common = { workspace = true }
148149
doc-comment = { workspace = true }
149150
env_logger = { workspace = true }
150151
half = { workspace = true, default-features = true }

datafusion/core/tests/user_defined/user_defined_window_functions.rs

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,12 +29,13 @@ use std::{
2929

3030
use arrow::array::AsArray;
3131
use arrow_array::{ArrayRef, Int64Array, RecordBatch, StringArray};
32-
use arrow_schema::DataType;
32+
use arrow_schema::{DataType, Field};
3333
use datafusion::{assert_batches_eq, prelude::SessionContext};
3434
use datafusion_common::{Result, ScalarValue};
3535
use datafusion_expr::{
3636
PartitionEvaluator, Signature, Volatility, WindowUDF, WindowUDFImpl,
3737
};
38+
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
3839

3940
/// A query with a window function evaluated over the entire partition
4041
const UNBOUNDED_WINDOW_QUERY: &str = "SELECT x, y, val, \
@@ -522,7 +523,6 @@ impl OddCounter {
522523
#[derive(Debug, Clone)]
523524
struct SimpleWindowUDF {
524525
signature: Signature,
525-
return_type: DataType,
526526
test_state: Arc<TestState>,
527527
aliases: Vec<String>,
528528
}
@@ -531,10 +531,8 @@ impl OddCounter {
531531
fn new(test_state: Arc<TestState>) -> Self {
532532
let signature =
533533
Signature::exact(vec![DataType::Float64], Volatility::Immutable);
534-
let return_type = DataType::Int64;
535534
Self {
536535
signature,
537-
return_type,
538536
test_state,
539537
aliases: vec!["odd_counter_alias".to_string()],
540538
}
@@ -554,17 +552,17 @@ impl OddCounter {
554552
&self.signature
555553
}
556554

557-
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
558-
Ok(self.return_type.clone())
559-
}
560-
561555
fn partition_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
562556
Ok(Box::new(OddCounter::new(Arc::clone(&self.test_state))))
563557
}
564558

565559
fn aliases(&self) -> &[String] {
566560
&self.aliases
567561
}
562+
563+
fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
564+
Ok(Field::new(field_args.name(), DataType::Int64, true))
565+
}
568566
}
569567

570568
ctx.register_udwf(WindowUDF::from(SimpleWindowUDF::new(test_state)))

datafusion/expr/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ chrono = { workspace = true }
4646
datafusion-common = { workspace = true }
4747
datafusion-expr-common = { workspace = true }
4848
datafusion-functions-aggregate-common = { workspace = true }
49+
datafusion-functions-window-common = { workspace = true }
4950
datafusion-physical-expr-common = { workspace = true }
5051
paste = "^1.0"
5152
serde_json = { workspace = true }

datafusion/expr/src/expr.rs

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ use datafusion_common::tree_node::{
4040
use datafusion_common::{
4141
plan_err, Column, DFSchema, Result, ScalarValue, TableReference,
4242
};
43+
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
4344
use sqlparser::ast::{
4445
display_comma_separated, ExceptSelectItem, ExcludeSelectItem, IlikeSelectItem,
4546
NullTreatment, RenameSelectItem, ReplaceSelectElement,
@@ -706,6 +707,7 @@ impl WindowFunctionDefinition {
706707
&self,
707708
input_expr_types: &[DataType],
708709
_input_expr_nullable: &[bool],
710+
display_name: &str,
709711
) -> Result<DataType> {
710712
match self {
711713
WindowFunctionDefinition::BuiltInWindowFunction(fun) => {
@@ -714,7 +716,9 @@ impl WindowFunctionDefinition {
714716
WindowFunctionDefinition::AggregateUDF(fun) => {
715717
fun.return_type(input_expr_types)
716718
}
717-
WindowFunctionDefinition::WindowUDF(fun) => fun.return_type(input_expr_types),
719+
WindowFunctionDefinition::WindowUDF(fun) => fun
720+
.field(WindowUDFFieldArgs::new(input_expr_types, display_name))
721+
.map(|field| field.data_type().clone()),
718722
}
719723
}
720724

@@ -2536,10 +2540,10 @@ mod test {
25362540
#[test]
25372541
fn test_first_value_return_type() -> Result<()> {
25382542
let fun = find_df_window_func("first_value").unwrap();
2539-
let observed = fun.return_type(&[DataType::Utf8], &[true])?;
2543+
let observed = fun.return_type(&[DataType::Utf8], &[true], "")?;
25402544
assert_eq!(DataType::Utf8, observed);
25412545

2542-
let observed = fun.return_type(&[DataType::UInt64], &[true])?;
2546+
let observed = fun.return_type(&[DataType::UInt64], &[true], "")?;
25432547
assert_eq!(DataType::UInt64, observed);
25442548

25452549
Ok(())
@@ -2548,10 +2552,10 @@ mod test {
25482552
#[test]
25492553
fn test_last_value_return_type() -> Result<()> {
25502554
let fun = find_df_window_func("last_value").unwrap();
2551-
let observed = fun.return_type(&[DataType::Utf8], &[true])?;
2555+
let observed = fun.return_type(&[DataType::Utf8], &[true], "")?;
25522556
assert_eq!(DataType::Utf8, observed);
25532557

2554-
let observed = fun.return_type(&[DataType::Float64], &[true])?;
2558+
let observed = fun.return_type(&[DataType::Float64], &[true], "")?;
25552559
assert_eq!(DataType::Float64, observed);
25562560

25572561
Ok(())
@@ -2560,10 +2564,10 @@ mod test {
25602564
#[test]
25612565
fn test_lead_return_type() -> Result<()> {
25622566
let fun = find_df_window_func("lead").unwrap();
2563-
let observed = fun.return_type(&[DataType::Utf8], &[true])?;
2567+
let observed = fun.return_type(&[DataType::Utf8], &[true], "")?;
25642568
assert_eq!(DataType::Utf8, observed);
25652569

2566-
let observed = fun.return_type(&[DataType::Float64], &[true])?;
2570+
let observed = fun.return_type(&[DataType::Float64], &[true], "")?;
25672571
assert_eq!(DataType::Float64, observed);
25682572

25692573
Ok(())
@@ -2572,10 +2576,10 @@ mod test {
25722576
#[test]
25732577
fn test_lag_return_type() -> Result<()> {
25742578
let fun = find_df_window_func("lag").unwrap();
2575-
let observed = fun.return_type(&[DataType::Utf8], &[true])?;
2579+
let observed = fun.return_type(&[DataType::Utf8], &[true], "")?;
25762580
assert_eq!(DataType::Utf8, observed);
25772581

2578-
let observed = fun.return_type(&[DataType::Float64], &[true])?;
2582+
let observed = fun.return_type(&[DataType::Float64], &[true], "")?;
25792583
assert_eq!(DataType::Float64, observed);
25802584

25812585
Ok(())
@@ -2585,11 +2589,11 @@ mod test {
25852589
fn test_nth_value_return_type() -> Result<()> {
25862590
let fun = find_df_window_func("nth_value").unwrap();
25872591
let observed =
2588-
fun.return_type(&[DataType::Utf8, DataType::UInt64], &[true, true])?;
2592+
fun.return_type(&[DataType::Utf8, DataType::UInt64], &[true, true], "")?;
25892593
assert_eq!(DataType::Utf8, observed);
25902594

25912595
let observed =
2592-
fun.return_type(&[DataType::Float64, DataType::UInt64], &[true, true])?;
2596+
fun.return_type(&[DataType::Float64, DataType::UInt64], &[true, true], "")?;
25932597
assert_eq!(DataType::Float64, observed);
25942598

25952599
Ok(())
@@ -2598,7 +2602,7 @@ mod test {
25982602
#[test]
25992603
fn test_percent_rank_return_type() -> Result<()> {
26002604
let fun = find_df_window_func("percent_rank").unwrap();
2601-
let observed = fun.return_type(&[], &[])?;
2605+
let observed = fun.return_type(&[], &[], "")?;
26022606
assert_eq!(DataType::Float64, observed);
26032607

26042608
Ok(())
@@ -2607,7 +2611,7 @@ mod test {
26072611
#[test]
26082612
fn test_cume_dist_return_type() -> Result<()> {
26092613
let fun = find_df_window_func("cume_dist").unwrap();
2610-
let observed = fun.return_type(&[], &[])?;
2614+
let observed = fun.return_type(&[], &[], "")?;
26112615
assert_eq!(DataType::Float64, observed);
26122616

26132617
Ok(())
@@ -2616,7 +2620,7 @@ mod test {
26162620
#[test]
26172621
fn test_ntile_return_type() -> Result<()> {
26182622
let fun = find_df_window_func("ntile").unwrap();
2619-
let observed = fun.return_type(&[DataType::Int16], &[true])?;
2623+
let observed = fun.return_type(&[DataType::Int16], &[true], "")?;
26202624
assert_eq!(DataType::UInt64, observed);
26212625

26222626
Ok(())

datafusion/expr/src/expr_fn.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ use arrow::compute::kernels::cast_utils::{
3838
};
3939
use arrow::datatypes::{DataType, Field};
4040
use datafusion_common::{plan_err, Column, Result, ScalarValue, TableReference};
41+
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
4142
use sqlparser::ast::NullTreatment;
4243
use std::any::Any;
4344
use std::fmt::Debug;
@@ -657,13 +658,17 @@ impl WindowUDFImpl for SimpleWindowUDF {
657658
&self.signature
658659
}
659660

660-
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
661-
Ok(self.return_type.clone())
662-
}
663-
664661
fn partition_evaluator(&self) -> Result<Box<dyn crate::PartitionEvaluator>> {
665662
(self.partition_evaluator_factory)()
666663
}
664+
665+
fn field(&self, field_args: WindowUDFFieldArgs) -> Result<Field> {
666+
Ok(Field::new(
667+
field_args.name(),
668+
self.return_type.clone(),
669+
true,
670+
))
671+
}
667672
}
668673

669674
pub fn interval_year_month_lit(value: &str) -> Expr {

0 commit comments

Comments
 (0)