Skip to content

Commit 014306d

Browse files
committed
refactored nth_value
1 parent 223bb02 commit 014306d

File tree

24 files changed

+643
-1001
lines changed

24 files changed

+643
-1001
lines changed

datafusion/core/src/dataframe/mod.rs

Lines changed: 2 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1946,9 +1946,8 @@ mod tests {
19461946
use datafusion_common_runtime::SpawnedTask;
19471947
use datafusion_expr::expr::WindowFunction;
19481948
use datafusion_expr::{
1949-
cast, create_udf, lit, BuiltInWindowFunction, ExprFunctionExt,
1950-
ScalarFunctionImplementation, Volatility, WindowFrame, WindowFrameBound,
1951-
WindowFrameUnits, WindowFunctionDefinition,
1949+
cast, create_udf, lit, ExprFunctionExt, ScalarFunctionImplementation, Volatility,
1950+
WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
19521951
};
19531952
use datafusion_functions_aggregate::expr_fn::{array_agg, count_distinct};
19541953
use datafusion_functions_window::expr_fn::row_number;
@@ -2172,31 +2171,6 @@ mod tests {
21722171
Ok(())
21732172
}
21742173

2175-
#[tokio::test]
2176-
async fn select_with_window_exprs() -> Result<()> {
2177-
// build plan using Table API
2178-
let t = test_table().await?;
2179-
let first_row = Expr::WindowFunction(WindowFunction::new(
2180-
WindowFunctionDefinition::BuiltInWindowFunction(
2181-
BuiltInWindowFunction::FirstValue,
2182-
),
2183-
vec![col("aggregate_test_100.c1")],
2184-
))
2185-
.partition_by(vec![col("aggregate_test_100.c2")])
2186-
.build()
2187-
.unwrap();
2188-
let t2 = t.select(vec![col("c1"), first_row])?;
2189-
let plan = t2.plan.clone();
2190-
2191-
let sql_plan = create_plan(
2192-
"select c1, first_value(c1) over (partition by c2) from aggregate_test_100",
2193-
)
2194-
.await?;
2195-
2196-
assert_same_plan(&plan, &sql_plan);
2197-
Ok(())
2198-
}
2199-
22002174
#[tokio::test]
22012175
async fn select_with_periods() -> Result<()> {
22022176
// define data with a column name that has a "." in it:

datafusion/core/tests/fuzz_cases/window_fuzz.rs

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ use datafusion_common::{Result, ScalarValue};
3434
use datafusion_common_runtime::SpawnedTask;
3535
use datafusion_expr::type_coercion::functions::data_types_with_aggregate_udf;
3636
use datafusion_expr::{
37-
BuiltInWindowFunction, WindowFrame, WindowFrameBound, WindowFrameUnits,
38-
WindowFunctionDefinition,
37+
WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunctionDefinition,
3938
};
4039
use datafusion_functions_aggregate::count::count_udaf;
4140
use datafusion_functions_aggregate::min_max::{max_udaf, min_udaf};
@@ -414,36 +413,6 @@ fn get_random_function(
414413
),
415414
);
416415
}
417-
window_fn_map.insert(
418-
"first_value",
419-
(
420-
WindowFunctionDefinition::BuiltInWindowFunction(
421-
BuiltInWindowFunction::FirstValue,
422-
),
423-
vec![arg.clone()],
424-
),
425-
);
426-
window_fn_map.insert(
427-
"last_value",
428-
(
429-
WindowFunctionDefinition::BuiltInWindowFunction(
430-
BuiltInWindowFunction::LastValue,
431-
),
432-
vec![arg.clone()],
433-
),
434-
);
435-
window_fn_map.insert(
436-
"nth_value",
437-
(
438-
WindowFunctionDefinition::BuiltInWindowFunction(
439-
BuiltInWindowFunction::NthValue,
440-
),
441-
vec![
442-
arg.clone(),
443-
lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))),
444-
],
445-
),
446-
);
447416

448417
let rand_fn_idx = rng.gen_range(0..window_fn_map.len());
449418
let fn_name = window_fn_map.keys().collect::<Vec<_>>()[rand_fn_idx];

datafusion/expr/src/expr.rs

Lines changed: 3 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,13 @@ use std::collections::{HashMap, HashSet};
2121
use std::fmt::{self, Display, Formatter, Write};
2222
use std::hash::{Hash, Hasher};
2323
use std::mem;
24-
use std::str::FromStr;
2524
use std::sync::Arc;
2625

2726
use crate::expr_fn::binary_expr;
2827
use crate::logical_plan::Subquery;
2928
use crate::utils::expr_to_columns;
3029
use crate::Volatility;
31-
use crate::{
32-
udaf, BuiltInWindowFunction, ExprSchemable, Operator, Signature, WindowFrame,
33-
WindowUDF,
34-
};
30+
use crate::{udaf, ExprSchemable, Operator, Signature, WindowFrame, WindowUDF};
3531

3632
use arrow::datatypes::{DataType, FieldRef};
3733
use datafusion_common::cse::HashNode;
@@ -693,9 +689,6 @@ impl AggregateFunction {
693689
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
694690
/// Defines which implementation of an aggregate function DataFusion should call.
695691
pub enum WindowFunctionDefinition {
696-
/// A built in aggregate function that leverages an aggregate function
697-
/// A a built-in window function
698-
BuiltInWindowFunction(BuiltInWindowFunction),
699692
/// A user defined aggregate function
700693
AggregateUDF(Arc<crate::AggregateUDF>),
701694
/// A user defined aggregate function
@@ -711,9 +704,6 @@ impl WindowFunctionDefinition {
711704
display_name: &str,
712705
) -> Result<DataType> {
713706
match self {
714-
WindowFunctionDefinition::BuiltInWindowFunction(fun) => {
715-
fun.return_type(input_expr_types)
716-
}
717707
WindowFunctionDefinition::AggregateUDF(fun) => {
718708
fun.return_type(input_expr_types)
719709
}
@@ -726,7 +716,6 @@ impl WindowFunctionDefinition {
726716
/// The signatures supported by the function `fun`.
727717
pub fn signature(&self) -> Signature {
728718
match self {
729-
WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.signature(),
730719
WindowFunctionDefinition::AggregateUDF(fun) => fun.signature().clone(),
731720
WindowFunctionDefinition::WindowUDF(fun) => fun.signature().clone(),
732721
}
@@ -735,7 +724,6 @@ impl WindowFunctionDefinition {
735724
/// Function's name for display
736725
pub fn name(&self) -> &str {
737726
match self {
738-
WindowFunctionDefinition::BuiltInWindowFunction(fun) => fun.name(),
739727
WindowFunctionDefinition::WindowUDF(fun) => fun.name(),
740728
WindowFunctionDefinition::AggregateUDF(fun) => fun.name(),
741729
}
@@ -745,19 +733,12 @@ impl WindowFunctionDefinition {
745733
impl Display for WindowFunctionDefinition {
746734
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
747735
match self {
748-
WindowFunctionDefinition::BuiltInWindowFunction(fun) => Display::fmt(fun, f),
749736
WindowFunctionDefinition::AggregateUDF(fun) => Display::fmt(fun, f),
750737
WindowFunctionDefinition::WindowUDF(fun) => Display::fmt(fun, f),
751738
}
752739
}
753740
}
754741

755-
impl From<BuiltInWindowFunction> for WindowFunctionDefinition {
756-
fn from(value: BuiltInWindowFunction) -> Self {
757-
Self::BuiltInWindowFunction(value)
758-
}
759-
}
760-
761742
impl From<Arc<crate::AggregateUDF>> for WindowFunctionDefinition {
762743
fn from(value: Arc<crate::AggregateUDF>) -> Self {
763744
Self::AggregateUDF(value)
@@ -783,9 +764,10 @@ impl From<Arc<WindowUDF>> for WindowFunctionDefinition {
783764
/// ```
784765
/// # use datafusion_expr::{Expr, BuiltInWindowFunction, col, ExprFunctionExt};
785766
/// # use datafusion_expr::expr::WindowFunction;
767+
/// use datafusion_expr::WindowFunctionDefinition::WindowUDF;
786768
/// // Create FIRST_VALUE(a) OVER (PARTITION BY b ORDER BY c)
787769
/// let expr = Expr::WindowFunction(
788-
/// WindowFunction::new(BuiltInWindowFunction::FirstValue, vec![col("a")])
770+
/// WindowFunction::new(WindowUDF::, vec![col("a")])
789771
/// )
790772
/// .partition_by(vec![col("b")])
791773
/// .order_by(vec![col("b").sort(true, true)])
@@ -823,23 +805,6 @@ impl WindowFunction {
823805
}
824806
}
825807

826-
/// Find DataFusion's built-in window function by name.
827-
pub fn find_df_window_func(name: &str) -> Option<WindowFunctionDefinition> {
828-
let name = name.to_lowercase();
829-
// Code paths for window functions leveraging ordinary aggregators and
830-
// built-in window functions are quite different, and the same function
831-
// may have different implementations for these cases. If the sought
832-
// function is not found among built-in window functions, we search for
833-
// it among aggregate functions.
834-
if let Ok(built_in_function) = BuiltInWindowFunction::from_str(name.as_str()) {
835-
Some(WindowFunctionDefinition::BuiltInWindowFunction(
836-
built_in_function,
837-
))
838-
} else {
839-
None
840-
}
841-
}
842-
843808
/// EXISTS expression
844809
#[derive(Clone, PartialEq, Eq, PartialOrd, Hash, Debug)]
845810
pub struct Exists {
@@ -2525,77 +2490,6 @@ mod test {
25252490

25262491
use super::*;
25272492

2528-
#[test]
2529-
fn test_first_value_return_type() -> Result<()> {
2530-
let fun = find_df_window_func("first_value").unwrap();
2531-
let observed = fun.return_type(&[DataType::Utf8], &[true], "")?;
2532-
assert_eq!(DataType::Utf8, observed);
2533-
2534-
let observed = fun.return_type(&[DataType::UInt64], &[true], "")?;
2535-
assert_eq!(DataType::UInt64, observed);
2536-
2537-
Ok(())
2538-
}
2539-
2540-
#[test]
2541-
fn test_last_value_return_type() -> Result<()> {
2542-
let fun = find_df_window_func("last_value").unwrap();
2543-
let observed = fun.return_type(&[DataType::Utf8], &[true], "")?;
2544-
assert_eq!(DataType::Utf8, observed);
2545-
2546-
let observed = fun.return_type(&[DataType::Float64], &[true], "")?;
2547-
assert_eq!(DataType::Float64, observed);
2548-
2549-
Ok(())
2550-
}
2551-
2552-
#[test]
2553-
fn test_nth_value_return_type() -> Result<()> {
2554-
let fun = find_df_window_func("nth_value").unwrap();
2555-
let observed =
2556-
fun.return_type(&[DataType::Utf8, DataType::UInt64], &[true, true], "")?;
2557-
assert_eq!(DataType::Utf8, observed);
2558-
2559-
let observed =
2560-
fun.return_type(&[DataType::Float64, DataType::UInt64], &[true, true], "")?;
2561-
assert_eq!(DataType::Float64, observed);
2562-
2563-
Ok(())
2564-
}
2565-
2566-
#[test]
2567-
fn test_window_function_case_insensitive() -> Result<()> {
2568-
let names = vec!["first_value", "last_value", "nth_value"];
2569-
for name in names {
2570-
let fun = find_df_window_func(name).unwrap();
2571-
let fun2 = find_df_window_func(name.to_uppercase().as_str()).unwrap();
2572-
assert_eq!(fun, fun2);
2573-
if fun.to_string() == "first_value" || fun.to_string() == "last_value" {
2574-
assert_eq!(fun.to_string(), name);
2575-
} else {
2576-
assert_eq!(fun.to_string(), name.to_uppercase());
2577-
}
2578-
}
2579-
Ok(())
2580-
}
2581-
2582-
#[test]
2583-
fn test_find_df_window_function() {
2584-
assert_eq!(
2585-
find_df_window_func("first_value"),
2586-
Some(WindowFunctionDefinition::BuiltInWindowFunction(
2587-
BuiltInWindowFunction::FirstValue
2588-
))
2589-
);
2590-
assert_eq!(
2591-
find_df_window_func("LAST_value"),
2592-
Some(WindowFunctionDefinition::BuiltInWindowFunction(
2593-
BuiltInWindowFunction::LastValue
2594-
))
2595-
);
2596-
assert_eq!(find_df_window_func("not_exist"), None)
2597-
}
2598-
25992493
#[test]
26002494
fn test_display_wildcard() {
26012495
assert_eq!(format!("{}", wildcard()), "*");

datafusion/expr/src/expr_schema.rs

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -478,12 +478,6 @@ impl Expr {
478478
.map(|e| e.get_type(schema))
479479
.collect::<Result<Vec<_>>>()?;
480480
match fun {
481-
WindowFunctionDefinition::BuiltInWindowFunction(window_fun) => {
482-
let return_type = window_fun.return_type(&data_types)?;
483-
let nullable =
484-
!["RANK", "NTILE", "CUME_DIST"].contains(&window_fun.name());
485-
Ok((return_type, nullable))
486-
}
487481
WindowFunctionDefinition::AggregateUDF(udaf) => {
488482
let new_types = data_types_with_aggregate_udf(&data_types, udaf)
489483
.map_err(|err| {

datafusion/expr/src/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ pub mod type_coercion;
6464
pub mod utils;
6565
pub mod var_provider;
6666
pub mod window_frame;
67-
pub mod window_function;
6867
pub mod window_state;
6968

7069
pub use built_in_window_function::BuiltInWindowFunction;

datafusion/expr/src/window_function.rs

Lines changed: 0 additions & 26 deletions
This file was deleted.

datafusion/functions-window/src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ pub mod macros;
3434

3535
pub mod cume_dist;
3636
pub mod lead_lag;
37+
pub mod nth_value;
3738
pub mod ntile;
3839
pub mod rank;
3940
pub mod row_number;
@@ -44,6 +45,7 @@ pub mod expr_fn {
4445
pub use super::cume_dist::cume_dist;
4546
pub use super::lead_lag::lag;
4647
pub use super::lead_lag::lead;
48+
pub use super::nth_value::{first_value, last_value, nth_value};
4749
pub use super::ntile::ntile;
4850
pub use super::rank::{dense_rank, percent_rank, rank};
4951
pub use super::row_number::row_number;
@@ -60,6 +62,9 @@ pub fn all_default_window_functions() -> Vec<Arc<WindowUDF>> {
6062
rank::dense_rank_udwf(),
6163
rank::percent_rank_udwf(),
6264
ntile::ntile_udwf(),
65+
nth_value::first_value_udwf(),
66+
nth_value::last_value_udwf(),
67+
nth_value::nth_value_udwf(),
6368
]
6469
}
6570
/// Registers all enabled packages with a [`FunctionRegistry`]

0 commit comments

Comments
 (0)