Skip to content

Commit f1f0965

Browse files
SteveLauCalamb
andauthored
feat: function name hints for UDFs (#9407)
* feat: function name hints for UDFs * refactor: rebase fn to xxx_names() * style: fix clippy * style: fix clippy * Add test --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent 96664ce commit f1f0965

File tree

12 files changed

+135
-40
lines changed

12 files changed

+135
-40
lines changed

datafusion-cli/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion-examples/examples/rewrite_expr.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,18 @@ impl ContextProvider for MyContextProvider {
226226
fn options(&self) -> &ConfigOptions {
227227
&self.options
228228
}
229+
230+
fn udfs_names(&self) -> Vec<String> {
231+
Vec::new()
232+
}
233+
234+
fn udafs_names(&self) -> Vec<String> {
235+
Vec::new()
236+
}
237+
238+
fn udwfs_names(&self) -> Vec<String> {
239+
Vec::new()
240+
}
229241
}
230242

231243
struct MyTableSource {

datafusion/core/src/execution/context/mod.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2098,6 +2098,18 @@ impl<'a> ContextProvider for SessionContextProvider<'a> {
20982098
fn options(&self) -> &ConfigOptions {
20992099
self.state.config_options()
21002100
}
2101+
2102+
fn udfs_names(&self) -> Vec<String> {
2103+
self.state.scalar_functions().keys().cloned().collect()
2104+
}
2105+
2106+
fn udafs_names(&self) -> Vec<String> {
2107+
self.state.aggregate_functions().keys().cloned().collect()
2108+
}
2109+
2110+
fn udwfs_names(&self) -> Vec<String> {
2111+
self.state.window_functions().keys().cloned().collect()
2112+
}
21012113
}
21022114

21032115
impl FunctionRegistry for SessionState {

datafusion/expr/src/function.rs

Lines changed: 3 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,12 @@
1717

1818
//! Function module contains typing and signature for built-in and user defined functions.
1919
20-
use crate::{Accumulator, BuiltinScalarFunction, PartitionEvaluator, Signature};
21-
use crate::{AggregateFunction, BuiltInWindowFunction, ColumnarValue};
20+
use crate::{
21+
Accumulator, BuiltinScalarFunction, ColumnarValue, PartitionEvaluator, Signature,
22+
};
2223
use arrow::datatypes::DataType;
23-
use datafusion_common::utils::datafusion_strsim;
2424
use datafusion_common::Result;
2525
use std::sync::Arc;
26-
use strum::IntoEnumIterator;
2726

2827
/// Scalar function
2928
///
@@ -75,33 +74,3 @@ pub fn return_type(
7574
pub fn signature(fun: &BuiltinScalarFunction) -> Signature {
7675
fun.signature()
7776
}
78-
79-
/// Suggest a valid function based on an invalid input function name
80-
pub fn suggest_valid_function(input_function_name: &str, is_window_func: bool) -> String {
81-
let valid_funcs = if is_window_func {
82-
// All aggregate functions and builtin window functions
83-
AggregateFunction::iter()
84-
.map(|func| func.to_string())
85-
.chain(BuiltInWindowFunction::iter().map(|func| func.to_string()))
86-
.collect()
87-
} else {
88-
// All scalar functions and aggregate functions
89-
BuiltinScalarFunction::iter()
90-
.map(|func| func.to_string())
91-
.chain(AggregateFunction::iter().map(|func| func.to_string()))
92-
.collect()
93-
};
94-
find_closest_match(valid_funcs, input_function_name)
95-
}
96-
97-
/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitve)
98-
/// Input `candidates` must not be empty otherwise it will panic
99-
fn find_closest_match(candidates: Vec<String>, target: &str) -> String {
100-
let target = target.to_lowercase();
101-
candidates
102-
.into_iter()
103-
.min_by_key(|candidate| {
104-
datafusion_strsim::levenshtein(&candidate.to_lowercase(), &target)
105-
})
106-
.expect("No candidates provided.") // Panic if `candidates` argument is empty
107-
}

datafusion/optimizer/tests/optimizer_integration.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,6 +417,18 @@ impl ContextProvider for MyContextProvider {
417417
fn options(&self) -> &ConfigOptions {
418418
&self.options
419419
}
420+
421+
fn udfs_names(&self) -> Vec<String> {
422+
Vec::new()
423+
}
424+
425+
fn udafs_names(&self) -> Vec<String> {
426+
Vec::new()
427+
}
428+
429+
fn udwfs_names(&self) -> Vec<String> {
430+
Vec::new()
431+
}
420432
}
421433

422434
struct MyTableSource {

datafusion/sql/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ datafusion-common = { workspace = true, default-features = true }
4343
datafusion-expr = { workspace = true }
4444
log = { workspace = true }
4545
sqlparser = { workspace = true }
46+
strum = { version = "0.26.1", features = ["derive"] }
4647

4748
[dev-dependencies]
4849
ctor = { workspace = true }

datafusion/sql/examples/sql.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,4 +131,16 @@ impl ContextProvider for MyContextProvider {
131131
fn options(&self) -> &ConfigOptions {
132132
&self.options
133133
}
134+
135+
fn udfs_names(&self) -> Vec<String> {
136+
Vec::new()
137+
}
138+
139+
fn udafs_names(&self) -> Vec<String> {
140+
Vec::new()
141+
}
142+
143+
fn udwfs_names(&self) -> Vec<String> {
144+
Vec::new()
145+
}
134146
}

datafusion/sql/src/expr/function.rs

Lines changed: 53 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,67 @@ use arrow_schema::DataType;
2020
use datafusion_common::{
2121
not_impl_err, plan_datafusion_err, plan_err, DFSchema, Dependency, Result,
2222
};
23-
use datafusion_expr::expr::{ScalarFunction, Unnest};
24-
use datafusion_expr::function::suggest_valid_function;
2523
use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by};
2624
use datafusion_expr::{
27-
expr, AggregateFunction, BuiltinScalarFunction, Expr, ExprSchemable, WindowFrame,
28-
WindowFunctionDefinition,
25+
expr, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFunctionDefinition,
26+
};
27+
use datafusion_expr::{
28+
expr::{ScalarFunction, Unnest},
29+
BuiltInWindowFunction, BuiltinScalarFunction,
2930
};
3031
use sqlparser::ast::{
3132
Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, WindowType,
3233
};
3334
use std::str::FromStr;
35+
use strum::IntoEnumIterator;
3436

3537
use super::arrow_cast::ARROW_CAST_NAME;
3638

39+
/// Suggest a valid function based on an invalid input function name
40+
pub fn suggest_valid_function(
41+
input_function_name: &str,
42+
is_window_func: bool,
43+
ctx: &dyn ContextProvider,
44+
) -> String {
45+
let valid_funcs = if is_window_func {
46+
// All aggregate functions and builtin window functions
47+
let mut funcs = Vec::new();
48+
49+
funcs.extend(AggregateFunction::iter().map(|func| func.to_string()));
50+
funcs.extend(ctx.udafs_names());
51+
funcs.extend(BuiltInWindowFunction::iter().map(|func| func.to_string()));
52+
funcs.extend(ctx.udwfs_names());
53+
54+
funcs
55+
} else {
56+
// All scalar functions and aggregate functions
57+
let mut funcs = Vec::new();
58+
59+
funcs.extend(BuiltinScalarFunction::iter().map(|func| func.to_string()));
60+
funcs.extend(ctx.udfs_names());
61+
funcs.extend(AggregateFunction::iter().map(|func| func.to_string()));
62+
funcs.extend(ctx.udafs_names());
63+
64+
funcs
65+
};
66+
find_closest_match(valid_funcs, input_function_name)
67+
}
68+
69+
/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitve)
70+
/// Input `candidates` must not be empty otherwise it will panic
71+
fn find_closest_match(candidates: Vec<String>, target: &str) -> String {
72+
let target = target.to_lowercase();
73+
candidates
74+
.into_iter()
75+
.min_by_key(|candidate| {
76+
datafusion_common::utils::datafusion_strsim::levenshtein(
77+
&candidate.to_lowercase(),
78+
&target,
79+
)
80+
})
81+
.expect("No candidates provided.") // Panic if `candidates` argument is empty
82+
}
83+
3784
impl<'a, S: ContextProvider> SqlToRel<'a, S> {
3885
pub(super) fn sql_function_to_expr(
3986
&self,
@@ -211,7 +258,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
211258
}
212259

213260
// Could not find the relevant function, so return an error
214-
let suggested_func_name = suggest_valid_function(&name, is_function_window);
261+
let suggested_func_name =
262+
suggest_valid_function(&name, is_function_window, self.context_provider);
215263
plan_err!("Invalid function '{name}'.\nDid you mean '{suggested_func_name}'?")
216264
}
217265

datafusion/sql/src/expr/mod.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -983,6 +983,18 @@ mod tests {
983983
fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
984984
None
985985
}
986+
987+
fn udfs_names(&self) -> Vec<String> {
988+
Vec::new()
989+
}
990+
991+
fn udafs_names(&self) -> Vec<String> {
992+
Vec::new()
993+
}
994+
995+
fn udwfs_names(&self) -> Vec<String> {
996+
Vec::new()
997+
}
986998
}
987999

9881000
fn create_table_source(fields: Vec<Field>) -> Arc<dyn TableSource> {

datafusion/sql/src/planner.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ pub trait ContextProvider {
8585

8686
/// Get configuration options
8787
fn options(&self) -> &ConfigOptions;
88+
89+
fn udfs_names(&self) -> Vec<String>;
90+
fn udafs_names(&self) -> Vec<String>;
91+
fn udwfs_names(&self) -> Vec<String>;
8892
}
8993

9094
/// SQL parser options

datafusion/sql/tests/sql_integration.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2901,6 +2901,18 @@ impl ContextProvider for MockContextProvider {
29012901
) -> Result<Arc<dyn TableSource>> {
29022902
Ok(Arc::new(EmptyTable::new(schema)))
29032903
}
2904+
2905+
fn udfs_names(&self) -> Vec<String> {
2906+
self.udfs.keys().cloned().collect()
2907+
}
2908+
2909+
fn udafs_names(&self) -> Vec<String> {
2910+
self.udafs.keys().cloned().collect()
2911+
}
2912+
2913+
fn udwfs_names(&self) -> Vec<String> {
2914+
Vec::new()
2915+
}
29042916
}
29052917

29062918
#[test]

datafusion/sqllogictest/test_files/functions.slt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ statement error Did you mean 'arrow_typeof'?
483483
SELECT arrowtypeof(v1) from test;
484484

485485
# Scalar function
486-
statement error Invalid function 'to_timestamps_second'
486+
statement error Did you mean 'to_timestamp_seconds'?
487487
SELECT to_TIMESTAMPS_second(v2) from test;
488488

489489
# Aggregate function

0 commit comments

Comments
 (0)