From 3a3c173c4bc2e389a4213dc316d7358bc629fb3e Mon Sep 17 00:00:00 2001 From: Steve Lau Date: Fri, 1 Mar 2024 09:04:29 +0800 Subject: [PATCH 1/5] feat: function name hints for UDFs --- datafusion-cli/Cargo.lock | 1 + datafusion-examples/examples/rewrite_expr.rs | 13 +++++ datafusion/core/src/execution/context/mod.rs | 12 +++++ datafusion/expr/src/function.rs | 37 ++----------- .../optimizer/tests/optimizer_integration.rs | 12 +++++ datafusion/sql/Cargo.toml | 1 + datafusion/sql/examples/sql.rs | 12 +++++ datafusion/sql/src/expr/function.rs | 54 +++++++++++++++++-- datafusion/sql/src/expr/mod.rs | 12 +++++ datafusion/sql/src/planner.rs | 4 ++ datafusion/sql/tests/sql_integration.rs | 12 +++++ 11 files changed, 131 insertions(+), 39 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 5e3c8648fc25..b4af7896821b 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1363,6 +1363,7 @@ dependencies = [ "datafusion-expr", "log", "sqlparser", + "strum 0.26.1", ] [[package]] diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index cc1396f770e4..f99e83dd18de 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -30,6 +30,7 @@ use datafusion_sql::sqlparser::dialect::PostgreSqlDialect; use datafusion_sql::sqlparser::parser::Parser; use datafusion_sql::TableReference; use std::any::Any; +use std::collections::HashMap; use std::sync::Arc; pub fn main() -> Result<()> { @@ -226,6 +227,18 @@ impl ContextProvider for MyContextProvider { fn options(&self) -> &ConfigOptions { &self.options } + + fn udfs(&self) -> HashMap> { + HashMap::new() + } + + fn udafs(&self) -> HashMap> { + HashMap::new() + } + + fn udwfs(&self) -> HashMap> { + HashMap::new() + } } struct MyTableSource { diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index e071c5c80e11..ef415df46ab9 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -2084,6 +2084,18 @@ impl<'a> ContextProvider for SessionContextProvider<'a> { fn options(&self) -> &ConfigOptions { self.state.config_options() } + + fn udfs(&self) -> HashMap> { + self.state.scalar_functions.clone() + } + + fn udafs(&self) -> HashMap> { + self.state.aggregate_functions().clone() + } + + fn udwfs(&self) -> HashMap> { + self.state.window_functions().clone() + } } impl FunctionRegistry for SessionState { diff --git a/datafusion/expr/src/function.rs b/datafusion/expr/src/function.rs index 3e30a5574be0..a3760eeb357d 100644 --- a/datafusion/expr/src/function.rs +++ b/datafusion/expr/src/function.rs @@ -17,13 +17,12 @@ //! Function module contains typing and signature for built-in and user defined functions. -use crate::{Accumulator, BuiltinScalarFunction, PartitionEvaluator, Signature}; -use crate::{AggregateFunction, BuiltInWindowFunction, ColumnarValue}; +use crate::{ + Accumulator, BuiltinScalarFunction, ColumnarValue, PartitionEvaluator, Signature, +}; use arrow::datatypes::DataType; -use datafusion_common::utils::datafusion_strsim; use datafusion_common::Result; use std::sync::Arc; -use strum::IntoEnumIterator; /// Scalar function /// @@ -75,33 +74,3 @@ pub fn return_type( pub fn signature(fun: &BuiltinScalarFunction) -> Signature { fun.signature() } - -/// Suggest a valid function based on an invalid input function name -pub fn suggest_valid_function(input_function_name: &str, is_window_func: bool) -> String { - let valid_funcs = if is_window_func { - // All aggregate functions and builtin window functions - AggregateFunction::iter() - .map(|func| func.to_string()) - .chain(BuiltInWindowFunction::iter().map(|func| func.to_string())) - .collect() - } else { - // All scalar functions and aggregate functions - BuiltinScalarFunction::iter() - .map(|func| func.to_string()) - .chain(AggregateFunction::iter().map(|func| func.to_string())) - .collect() - }; - find_closest_match(valid_funcs, input_function_name) -} - -/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitve) -/// Input `candidates` must not be empty otherwise it will panic -fn find_closest_match(candidates: Vec, target: &str) -> String { - let target = target.to_lowercase(); - candidates - .into_iter() - .min_by_key(|candidate| { - datafusion_strsim::levenshtein(&candidate.to_lowercase(), &target) - }) - .expect("No candidates provided.") // Panic if `candidates` argument is empty -} diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index db7bfa8b3bc8..ea8976a29323 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -417,6 +417,18 @@ impl ContextProvider for MyContextProvider { fn options(&self) -> &ConfigOptions { &self.options } + + fn udfs(&self) -> HashMap> { + HashMap::new() + } + + fn udafs(&self) -> HashMap> { + HashMap::new() + } + + fn udwfs(&self) -> HashMap> { + HashMap::new() + } } struct MyTableSource { diff --git a/datafusion/sql/Cargo.toml b/datafusion/sql/Cargo.toml index fb300e2c8791..7739058a5c9d 100644 --- a/datafusion/sql/Cargo.toml +++ b/datafusion/sql/Cargo.toml @@ -43,6 +43,7 @@ datafusion-common = { workspace = true, default-features = true } datafusion-expr = { workspace = true } log = { workspace = true } sqlparser = { workspace = true } +strum = { version = "0.26.1", features = ["derive"] } [dev-dependencies] ctor = { workspace = true } diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs index 8744a905481f..52f4e015aa54 100644 --- a/datafusion/sql/examples/sql.rs +++ b/datafusion/sql/examples/sql.rs @@ -131,4 +131,16 @@ impl ContextProvider for MyContextProvider { fn options(&self) -> &ConfigOptions { &self.options } + + fn udfs(&self) -> HashMap> { + HashMap::new() + } + + fn udafs(&self) -> HashMap> { + HashMap::new() + } + + fn udwfs(&self) -> HashMap> { + HashMap::new() + } } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index bcf641e4b5a0..3d60135ecae9 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -20,20 +20,63 @@ use arrow_schema::DataType; use datafusion_common::{ not_impl_err, plan_datafusion_err, plan_err, DFSchema, Dependency, Result, }; -use datafusion_expr::expr::{ScalarFunction, Unnest}; -use datafusion_expr::function::suggest_valid_function; use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by}; use datafusion_expr::{ - expr, AggregateFunction, BuiltinScalarFunction, Expr, ExprSchemable, WindowFrame, - WindowFunctionDefinition, + expr, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFunctionDefinition, +}; +use datafusion_expr::{ + expr::{ScalarFunction, Unnest}, + BuiltInWindowFunction, BuiltinScalarFunction, }; use sqlparser::ast::{ Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, WindowType, }; use std::str::FromStr; +use strum::IntoEnumIterator; use super::arrow_cast::ARROW_CAST_NAME; +/// Suggest a valid function based on an invalid input function name +pub fn suggest_valid_function( + input_function_name: &str, + is_window_func: bool, + ctx: &dyn ContextProvider, +) -> String { + let valid_funcs = if is_window_func { + // All aggregate functions and builtin window functions + AggregateFunction::iter() + .map(|func| func.to_string()) + .chain(BuiltInWindowFunction::iter().map(|func| func.to_string())) + .collect() + } else { + // All scalar functions and aggregate functions + let mut funcs = Vec::new(); + + funcs.extend(BuiltinScalarFunction::iter().map(|func| func.to_string())); + funcs.extend(ctx.udfs().into_keys()); + funcs.extend(AggregateFunction::iter().map(|func| func.to_string())); + funcs.extend(ctx.udafs().into_keys()); + + funcs + }; + find_closest_match(valid_funcs, input_function_name) +} + +/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitve) +/// Input `candidates` must not be empty otherwise it will panic +fn find_closest_match(candidates: Vec, target: &str) -> String { + let target = target.to_lowercase(); + candidates + .into_iter() + .min_by_key(|candidate| { + datafusion_common::utils::datafusion_strsim::levenshtein( + &candidate.to_lowercase(), + &target, + ) + }) + .expect("No candidates provided.") // Panic if `candidates` argument is empty +} + impl<'a, S: ContextProvider> SqlToRel<'a, S> { pub(super) fn sql_function_to_expr( &self, @@ -211,7 +254,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { } // Could not find the relevant function, so return an error - let suggested_func_name = suggest_valid_function(&name, is_function_window); + let suggested_func_name = + suggest_valid_function(&name, is_function_window, self.context_provider); plan_err!("Invalid function '{name}'.\nDid you mean '{suggested_func_name}'?") } diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index aa0b619167dc..39b932ea0673 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -955,6 +955,18 @@ mod tests { fn get_window_meta(&self, _name: &str) -> Option> { None } + + fn udfs(&self) -> HashMap> { + HashMap::new() + } + + fn udafs(&self) -> HashMap> { + HashMap::new() + } + + fn udwfs(&self) -> HashMap> { + HashMap::new() + } } fn create_table_source(fields: Vec) -> Arc { diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 2db2c01c5ee1..0b5893e058c5 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -85,6 +85,10 @@ pub trait ContextProvider { /// Get configuration options fn options(&self) -> &ConfigOptions; + + fn udfs(&self) -> HashMap>; + fn udafs(&self) -> HashMap>; + fn udwfs(&self) -> HashMap>; } /// SQL parser options diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 57e2e1ef06a7..b7526a4b605d 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -2896,6 +2896,18 @@ impl ContextProvider for MockContextProvider { ) -> Result> { Ok(Arc::new(EmptyTable::new(schema))) } + + fn udfs(&self) -> HashMap> { + self.udfs.clone() + } + + fn udafs(&self) -> HashMap> { + self.udafs.clone() + } + + fn udwfs(&self) -> HashMap> { + HashMap::new() + } } #[test] From 5740ebb72680ba82d83e262d947eb0dc75211e86 Mon Sep 17 00:00:00 2001 From: Steve Lau Date: Sun, 10 Mar 2024 18:01:24 +0800 Subject: [PATCH 2/5] refactor: rebase fn to xxx_names() --- datafusion-examples/examples/rewrite_expr.rs | 12 +++++----- datafusion/core/src/execution/context/mod.rs | 24 ++++++++++++++----- .../optimizer/tests/optimizer_integration.rs | 12 +++++----- datafusion/sql/examples/sql.rs | 12 +++++----- datafusion/sql/src/expr/function.rs | 16 ++++++++----- datafusion/sql/src/expr/mod.rs | 12 +++++----- datafusion/sql/src/planner.rs | 6 ++--- datafusion/sql/tests/sql_integration.rs | 12 +++++----- 8 files changed, 61 insertions(+), 45 deletions(-) diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index f99e83dd18de..6aba1e91c7f9 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -228,16 +228,16 @@ impl ContextProvider for MyContextProvider { &self.options } - fn udfs(&self) -> HashMap> { - HashMap::new() + fn udfs_names(&self) -> Vec { + Vec::new() } - fn udafs(&self) -> HashMap> { - HashMap::new() + fn udafs_names(&self) -> Vec { + Vec::new() } - fn udwfs(&self) -> HashMap> { - HashMap::new() + fn udwfs_names(&self) -> Vec { + Vec::new() } } diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index ef415df46ab9..45c9fd3d5888 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -2085,16 +2085,28 @@ impl<'a> ContextProvider for SessionContextProvider<'a> { self.state.config_options() } - fn udfs(&self) -> HashMap> { - self.state.scalar_functions.clone() + fn udfs_names(&self) -> Vec { + self.state + .scalar_functions() + .keys() + .map(|str| str.clone()) + .collect() } - fn udafs(&self) -> HashMap> { - self.state.aggregate_functions().clone() + fn udafs_names(&self) -> Vec { + self.state + .aggregate_functions() + .keys() + .map(|str| str.clone()) + .collect() } - fn udwfs(&self) -> HashMap> { - self.state.window_functions().clone() + fn udwfs_names(&self) -> Vec { + self.state + .window_functions() + .keys() + .map(|str| str.clone()) + .collect() } } diff --git a/datafusion/optimizer/tests/optimizer_integration.rs b/datafusion/optimizer/tests/optimizer_integration.rs index ea8976a29323..b02623854b8a 100644 --- a/datafusion/optimizer/tests/optimizer_integration.rs +++ b/datafusion/optimizer/tests/optimizer_integration.rs @@ -418,16 +418,16 @@ impl ContextProvider for MyContextProvider { &self.options } - fn udfs(&self) -> HashMap> { - HashMap::new() + fn udfs_names(&self) -> Vec { + Vec::new() } - fn udafs(&self) -> HashMap> { - HashMap::new() + fn udafs_names(&self) -> Vec { + Vec::new() } - fn udwfs(&self) -> HashMap> { - HashMap::new() + fn udwfs_names(&self) -> Vec { + Vec::new() } } diff --git a/datafusion/sql/examples/sql.rs b/datafusion/sql/examples/sql.rs index 52f4e015aa54..5bab2f19cfc0 100644 --- a/datafusion/sql/examples/sql.rs +++ b/datafusion/sql/examples/sql.rs @@ -132,15 +132,15 @@ impl ContextProvider for MyContextProvider { &self.options } - fn udfs(&self) -> HashMap> { - HashMap::new() + fn udfs_names(&self) -> Vec { + Vec::new() } - fn udafs(&self) -> HashMap> { - HashMap::new() + fn udafs_names(&self) -> Vec { + Vec::new() } - fn udwfs(&self) -> HashMap> { - HashMap::new() + fn udwfs_names(&self) -> Vec { + Vec::new() } } diff --git a/datafusion/sql/src/expr/function.rs b/datafusion/sql/src/expr/function.rs index 3d60135ecae9..ffc951a6fa66 100644 --- a/datafusion/sql/src/expr/function.rs +++ b/datafusion/sql/src/expr/function.rs @@ -44,18 +44,22 @@ pub fn suggest_valid_function( ) -> String { let valid_funcs = if is_window_func { // All aggregate functions and builtin window functions - AggregateFunction::iter() - .map(|func| func.to_string()) - .chain(BuiltInWindowFunction::iter().map(|func| func.to_string())) - .collect() + let mut funcs = Vec::new(); + + funcs.extend(AggregateFunction::iter().map(|func| func.to_string())); + funcs.extend(ctx.udafs_names()); + funcs.extend(BuiltInWindowFunction::iter().map(|func| func.to_string())); + funcs.extend(ctx.udwfs_names()); + + funcs } else { // All scalar functions and aggregate functions let mut funcs = Vec::new(); funcs.extend(BuiltinScalarFunction::iter().map(|func| func.to_string())); - funcs.extend(ctx.udfs().into_keys()); + funcs.extend(ctx.udfs_names()); funcs.extend(AggregateFunction::iter().map(|func| func.to_string())); - funcs.extend(ctx.udafs().into_keys()); + funcs.extend(ctx.udafs_names()); funcs }; diff --git a/datafusion/sql/src/expr/mod.rs b/datafusion/sql/src/expr/mod.rs index 39b932ea0673..c0520002ed21 100644 --- a/datafusion/sql/src/expr/mod.rs +++ b/datafusion/sql/src/expr/mod.rs @@ -956,16 +956,16 @@ mod tests { None } - fn udfs(&self) -> HashMap> { - HashMap::new() + fn udfs_names(&self) -> Vec { + Vec::new() } - fn udafs(&self) -> HashMap> { - HashMap::new() + fn udafs_names(&self) -> Vec { + Vec::new() } - fn udwfs(&self) -> HashMap> { - HashMap::new() + fn udwfs_names(&self) -> Vec { + Vec::new() } } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index 0b5893e058c5..f94c6ec4e8c9 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -86,9 +86,9 @@ pub trait ContextProvider { /// Get configuration options fn options(&self) -> &ConfigOptions; - fn udfs(&self) -> HashMap>; - fn udafs(&self) -> HashMap>; - fn udwfs(&self) -> HashMap>; + fn udfs_names(&self) -> Vec; + fn udafs_names(&self) -> Vec; + fn udwfs_names(&self) -> Vec; } /// SQL parser options diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index b7526a4b605d..fd4c5c923230 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -2897,16 +2897,16 @@ impl ContextProvider for MockContextProvider { Ok(Arc::new(EmptyTable::new(schema))) } - fn udfs(&self) -> HashMap> { - self.udfs.clone() + fn udfs_names(&self) -> Vec { + self.udfs.keys().map(|str| str.clone()).collect() } - fn udafs(&self) -> HashMap> { - self.udafs.clone() + fn udafs_names(&self) -> Vec { + self.udafs.keys().map(|str| str.clone()).collect() } - fn udwfs(&self) -> HashMap> { - HashMap::new() + fn udwfs_names(&self) -> Vec { + Vec::new() } } From 081028478ef615cb8d407ac3a282fff2e0bdad09 Mon Sep 17 00:00:00 2001 From: Steve Lau Date: Sun, 10 Mar 2024 18:05:37 +0800 Subject: [PATCH 3/5] style: fix clippy --- datafusion/core/src/execution/context/mod.rs | 18 +++--------------- datafusion/sql/tests/sql_integration.rs | 4 ++-- 2 files changed, 5 insertions(+), 17 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 45c9fd3d5888..9f06fe9c8d6c 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -2086,27 +2086,15 @@ impl<'a> ContextProvider for SessionContextProvider<'a> { } fn udfs_names(&self) -> Vec { - self.state - .scalar_functions() - .keys() - .map(|str| str.clone()) - .collect() + self.state.scalar_functions().keys().cloned().collect() } fn udafs_names(&self) -> Vec { - self.state - .aggregate_functions() - .keys() - .map(|str| str.clone()) - .collect() + self.state.aggregate_functions().keys().cloned().collect() } fn udwfs_names(&self) -> Vec { - self.state - .window_functions() - .keys() - .map(|str| str.clone()) - .collect() + self.state.window_functions().keys().cloned().collect() } } diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index fd4c5c923230..b128383a3a59 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -2898,11 +2898,11 @@ impl ContextProvider for MockContextProvider { } fn udfs_names(&self) -> Vec { - self.udfs.keys().map(|str| str.clone()).collect() + self.udfs.keys().cloned().collect() } fn udafs_names(&self) -> Vec { - self.udafs.keys().map(|str| str.clone()).collect() + self.udafs.keys().cloned().collect() } fn udwfs_names(&self) -> Vec { From eca853eefa940bbc02c90edd3dd41c6bbae077ab Mon Sep 17 00:00:00 2001 From: Steve Lau Date: Sun, 10 Mar 2024 18:06:04 +0800 Subject: [PATCH 4/5] style: fix clippy --- datafusion-examples/examples/rewrite_expr.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/datafusion-examples/examples/rewrite_expr.rs b/datafusion-examples/examples/rewrite_expr.rs index 6aba1e91c7f9..541448ebf149 100644 --- a/datafusion-examples/examples/rewrite_expr.rs +++ b/datafusion-examples/examples/rewrite_expr.rs @@ -30,7 +30,6 @@ use datafusion_sql::sqlparser::dialect::PostgreSqlDialect; use datafusion_sql::sqlparser::parser::Parser; use datafusion_sql::TableReference; use std::any::Any; -use std::collections::HashMap; use std::sync::Arc; pub fn main() -> Result<()> { From d2a291c97ebda73ed934e85ec08ef7a678a13d52 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 10 Mar 2024 06:44:10 -0400 Subject: [PATCH 5/5] Add test --- datafusion/sqllogictest/test_files/functions.slt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/sqllogictest/test_files/functions.slt b/datafusion/sqllogictest/test_files/functions.slt index 96aa3e275209..21433ba16810 100644 --- a/datafusion/sqllogictest/test_files/functions.slt +++ b/datafusion/sqllogictest/test_files/functions.slt @@ -483,7 +483,7 @@ statement error Did you mean 'arrow_typeof'? SELECT arrowtypeof(v1) from test; # Scalar function -statement error Invalid function 'to_timestamps_second' +statement error Did you mean 'to_timestamp_seconds'? SELECT to_TIMESTAMPS_second(v2) from test; # Aggregate function