Skip to content

Commit 54ab128

Browse files
buraksenn, jcsherin, berkaysynnada, alamb
authored
Convert nth_value builtIn function to User Defined Window Function (#13201)
* refactored nth_value
* continue
* test
* proto and rustlint
* fix datatype
* cont
* cont
* apply jcsherins early validation
* docs
* doc
* Apply suggestions from code review

Co-authored-by: Sherin Jacob <[email protected]>

* passes lint but does not have tests
* continue
* Update roundtrip_physical_plan.rs
* udwf, not udaf
* fix bounded but not fixed roundtrip
* added
* Update datafusion/sqllogictest/test_files/errors.slt

Co-authored-by: Sherin Jacob <[email protected]>

---------

Co-authored-by: Sherin Jacob <[email protected]>
Co-authored-by: berkaysynnada <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
1 parent 4e1f839 commit 54ab128

File tree

27 files changed

+728
-828
lines changed

27 files changed

+728
-828
lines changed

datafusion/core/src/dataframe/mod.rs

+4-6
Original file line numberDiff line numberDiff line change
@@ -1946,12 +1946,12 @@ mod tests {
19461946
use datafusion_common_runtime::SpawnedTask;
19471947
use datafusion_expr::expr::WindowFunction;
19481948
use datafusion_expr::{
1949-
cast, create_udf, lit, BuiltInWindowFunction, ExprFunctionExt,
1950-
ScalarFunctionImplementation, Volatility, WindowFrame, WindowFrameBound,
1951-
WindowFrameUnits, WindowFunctionDefinition,
1949+
cast, create_udf, lit, ExprFunctionExt, ScalarFunctionImplementation, Volatility,
1950+
WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
19521951
};
19531952
use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct};
19541953
use datafusion_functions_window::expr_fn::row_number;
1954+
use datafusion_functions_window::nth_value::first_value_udwf;
19551955
use datafusion_physical_expr::expressions::Column;
19561956
use datafusion_physical_plan::{get_plan_string, ExecutionPlanProperties};
19571957
use sqlparser::ast::NullTreatment;
@@ -2177,9 +2177,7 @@ mod tests {
21772177
// build plan using Table API
21782178
let t = test_table().await?;
21792179
let first_row = Expr::WindowFunction(WindowFunction::new(
2180-
WindowFunctionDefinition::BuiltInWindowFunction(
2181-
BuiltInWindowFunction::FirstValue,
2182-
),
2180+
WindowFunctionDefinition::WindowUDF(first_value_udwf()),
21832181
vec![col("aggregate_test_100.c1")],
21842182
))
21852183
.partition_by(vec![col("aggregate_test_100.c2")])

datafusion/core/tests/fuzz_cases/window_fuzz.rs

+7-11
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ use datafusion_common::{Result, ScalarValue};
3434
use datafusion_common_runtime::SpawnedTask;
3535
use datafusion_expr::type_coercion::functions::data_types_with_aggregate_udf;
3636
use datafusion_expr::{
37-
BuiltInWindowFunction, WindowFrame, WindowFrameBound, WindowFrameUnits,
38-
WindowFunctionDefinition,
37+
WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
3938
};
4039
use datafusion_functions_aggregate::count::count_udaf;
4140
use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf};
@@ -47,6 +46,9 @@ use test_utils::add_empty_batches;
4746
use datafusion::functions_window::row_number::row_number_udwf;
4847
use datafusion_common::HashMap;
4948
use datafusion_functions_window::lead_lag::{lag_udwf, lead_udwf};
49+
use datafusion_functions_window::nth_value::{
50+
first_value_udwf, last_value_udwf, nth_value_udwf,
51+
};
5052
use datafusion_functions_window::rank::{dense_rank_udwf, rank_udwf};
5153
use datafusion_physical_expr_common::sort_expr::LexOrdering;
5254
use rand::distributions::Alphanumeric;
@@ -418,27 +420,21 @@ fn get_random_function(
418420
window_fn_map.insert(
419421
"first_value",
420422
(
421-
WindowFunctionDefinition::BuiltInWindowFunction(
422-
BuiltInWindowFunction::FirstValue,
423-
),
423+
WindowFunctionDefinition::WindowUDF(first_value_udwf()),
424424
vec![arg.clone()],
425425
),
426426
);
427427
window_fn_map.insert(
428428
"last_value",
429429
(
430-
WindowFunctionDefinition::BuiltInWindowFunction(
431-
BuiltInWindowFunction::LastValue,
432-
),
430+
WindowFunctionDefinition::WindowUDF(last_value_udwf()),
433431
vec![arg.clone()],
434432
),
435433
);
436434
window_fn_map.insert(
437435
"nth_value",
438436
(
439-
WindowFunctionDefinition::BuiltInWindowFunction(
440-
BuiltInWindowFunction::NthValue,
441-
),
437+
WindowFunctionDefinition::WindowUDF(nth_value_udwf()),
442438
vec![
443439
arg.clone(),
444440
lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))),

datafusion/expr/src/expr.rs

-89
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ use std::collections::HashSet;
2121
use std::fmt::{self, Display, Formatter, Write};
2222
use std::hash::{Hash, Hasher};
2323
use std::mem;
24-
use std::str::FromStr;
2524
use std::sync::Arc;
2625

2726
use crate::expr_fn::binary_expr;
@@ -832,23 +831,6 @@ impl WindowFunction {
832831
}
833832
}
834833

835-
/// Find DataFusion's built-in window function by name.
836-
pub fn find_df_window_func(name: &str) -> Option<WindowFunctionDefinition> {
837-
let name = name.to_lowercase();
838-
// Code paths for window functions leveraging ordinary aggregators and
839-
// built-in window functions are quite different, and the same function
840-
// may have different implementations for these cases. If the sought
841-
// function is not found among built-in window functions, we search for
842-
// it among aggregate functions.
843-
if let Ok(built_in_function) = BuiltInWindowFunction::from_str(name.as_str()) {
844-
Some(WindowFunctionDefinition::BuiltInWindowFunction(
845-
built_in_function,
846-
))
847-
} else {
848-
None
849-
}
850-
}
851-
852834
/// EXISTS expression
853835
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
854836
pub struct Exists {
@@ -2548,77 +2530,6 @@ mod test {
25482530

25492531
use super::*;
25502532

2551-
#[test]
2552-
fn test_first_value_return_type() -> Result<()> {
2553-
let fun = find_df_window_func("first_value").unwrap();
2554-
let observed = fun.return_type(&[DataType::Utf8], &[true], "")?;
2555-
assert_eq!(DataType::Utf8, observed);
2556-
2557-
let observed = fun.return_type(&[DataType::UInt64], &[true], "")?;
2558-
assert_eq!(DataType::UInt64, observed);
2559-
2560-
Ok(())
2561-
}
2562-
2563-
#[test]
2564-
fn test_last_value_return_type() -> Result<()> {
2565-
let fun = find_df_window_func("last_value").unwrap();
2566-
let observed = fun.return_type(&[DataType::Utf8], &[true], "")?;
2567-
assert_eq!(DataType::Utf8, observed);
2568-
2569-
let observed = fun.return_type(&[DataType::Float64], &[true], "")?;
2570-
assert_eq!(DataType::Float64, observed);
2571-
2572-
Ok(())
2573-
}
2574-
2575-
#[test]
2576-
fn test_nth_value_return_type() -> Result<()> {
2577-
let fun = find_df_window_func("nth_value").unwrap();
2578-
let observed =
2579-
fun.return_type(&[DataType::Utf8, DataType::UInt64], &[true, true], "")?;
2580-
assert_eq!(DataType::Utf8, observed);
2581-
2582-
let observed =
2583-
fun.return_type(&[DataType::Float64, DataType::UInt64], &[true, true], "")?;
2584-
assert_eq!(DataType::Float64, observed);
2585-
2586-
Ok(())
2587-
}
2588-
2589-
#[test]
2590-
fn test_window_function_case_insensitive() -> Result<()> {
2591-
let names = vec!["first_value", "last_value", "nth_value"];
2592-
for name in names {
2593-
let fun = find_df_window_func(name).unwrap();
2594-
let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap();
2595-
assert_eq!(fun, fun2);
2596-
if fun.to_string() == "first_value" || fun.to_string() == "last_value" {
2597-
assert_eq!(fun.to_string(), name);
2598-
} else {
2599-
assert_eq!(fun.to_string(), name.to_uppercase());
2600-
}
2601-
}
2602-
Ok(())
2603-
}
2604-
2605-
#[test]
2606-
fn test_find_df_window_function() {
2607-
assert_eq!(
2608-
find_df_window_func("first_value"),
2609-
Some(WindowFunctionDefinition::BuiltInWindowFunction(
2610-
BuiltInWindowFunction::FirstValue
2611-
))
2612-
);
2613-
assert_eq!(
2614-
find_df_window_func("LAST_value"),
2615-
Some(WindowFunctionDefinition::BuiltInWindowFunction(
2616-
BuiltInWindowFunction::LastValue
2617-
))
2618-
);
2619-
assert_eq!(find_df_window_func("not_exist"), None)
2620-
}
2621-
26222533
#[test]
26232534
fn test_display_wildcard() {
26242535
assert_eq!(format!("{}", wildcard()), "*");

datafusion/expr/src/lib.rs

-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ pub mod type_coercion;
6565
pub mod utils;
6666
pub mod var_provider;
6767
pub mod window_frame;
68-
pub mod window_function;
6968
pub mod window_state;
7069

7170
pub use built_in_window_function::BuiltInWindowFunction;

datafusion/expr/src/window_function.rs

-26
This file was deleted.

datafusion/functions-window/src/lib.rs

+6
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
//!
2323
//! [DataFusion]: https://crates.io/crates/datafusion
2424
//!
25+
2526
use std::sync::Arc;
2627

2728
use log::debug;
@@ -34,6 +35,7 @@ pub mod macros;
3435

3536
pub mod cume_dist;
3637
pub mod lead_lag;
38+
pub mod nth_value;
3739
pub mod ntile;
3840
pub mod rank;
3941
pub mod row_number;
@@ -44,6 +46,7 @@ pub mod expr_fn {
4446
pub use super::cume_dist::cume_dist;
4547
pub use super::lead_lag::lag;
4648
pub use super::lead_lag::lead;
49+
pub use super::nth_value::{first_value, last_value, nth_value};
4750
pub use super::ntile::ntile;
4851
pub use super::rank::{dense_rank, percent_rank, rank};
4952
pub use super::row_number::row_number;
@@ -60,6 +63,9 @@ pub fn all_default_window_functions() -> Vec<Arc<WindowUDF>> {
6063
rank::dense_rank_udwf(),
6164
rank::percent_rank_udwf(),
6265
ntile::ntile_udwf(),
66+
nth_value::first_value_udwf(),
67+
nth_value::last_value_udwf(),
68+
nth_value::nth_value_udwf(),
6369
]
6470
}
6571
/// Registers all enabled packages with a [`FunctionRegistry`]

0 commit comments

Comments
 (0)