Skip to content

Commit 40cd987

Browse files
committed
feat: function name hints for UDFs
1 parent f229dcc commit 40cd987

File tree

11 files changed

+131
-39
lines changed

11 files changed

+131
-39
lines changed

datafusion-cli/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion-examples/examples/rewrite_expr.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ use datafusion_sql::sqlparser::dialect::PostgreSqlDialect;
3030
use datafusion_sql::sqlparser::parser::Parser;
3131
use datafusion_sql::TableReference;
3232
use std::any::Any;
33+
use std::collections::HashMap;
3334
use std::sync::Arc;
3435

3536
pub fn main() -> Result<()> {
@@ -223,6 +224,18 @@ impl ContextProvider for MyContextProvider {
223224
fn options(&self) -> &ConfigOptions {
224225
&self.options
225226
}
227+
228+
fn udfs(&self) -> HashMap<String, Arc<ScalarUDF>> {
229+
HashMap::new()
230+
}
231+
232+
fn udafs(&self) -> HashMap<String, Arc<AggregateUDF>> {
233+
HashMap::new()
234+
}
235+
236+
fn udwfs(&self) -> HashMap<String, Arc<WindowUDF>> {
237+
HashMap::new()
238+
}
226239
}
227240

228241
struct MyTableSource {

datafusion/core/src/execution/context/mod.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1993,6 +1993,18 @@ impl<'a> ContextProvider for SessionContextProvider<'a> {
19931993
fn options(&self) -> &ConfigOptions {
19941994
self.state.config_options()
19951995
}
1996+
1997+
fn udfs(&self) -> HashMap<String, Arc<ScalarUDF>> {
1998+
self.state.scalar_functions.clone()
1999+
}
2000+
2001+
fn udafs(&self) -> HashMap<String, Arc<AggregateUDF>> {
2002+
self.state.aggregate_functions().clone()
2003+
}
2004+
2005+
fn udwfs(&self) -> HashMap<String, Arc<WindowUDF>> {
2006+
self.state.window_functions().clone()
2007+
}
19962008
}
19972009

19982010
impl FunctionRegistry for SessionState {

datafusion/expr/src/function.rs

Lines changed: 3 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,12 @@
1717

1818
//! Function module contains typing and signature for built-in and user defined functions.
1919
20-
use crate::{Accumulator, BuiltinScalarFunction, PartitionEvaluator, Signature};
21-
use crate::{AggregateFunction, BuiltInWindowFunction, ColumnarValue};
20+
use crate::{
21+
Accumulator, BuiltinScalarFunction, ColumnarValue, PartitionEvaluator, Signature,
22+
};
2223
use arrow::datatypes::DataType;
23-
use datafusion_common::utils::datafusion_strsim;
2424
use datafusion_common::Result;
2525
use std::sync::Arc;
26-
use strum::IntoEnumIterator;
2726

2827
/// Scalar function
2928
///
@@ -75,33 +74,3 @@ pub fn return_type(
7574
pub fn signature(fun: &BuiltinScalarFunction) -> Signature {
7675
fun.signature()
7776
}
78-
79-
/// Suggest a valid function based on an invalid input function name
80-
pub fn suggest_valid_function(input_function_name: &str, is_window_func: bool) -> String {
81-
let valid_funcs = if is_window_func {
82-
// All aggregate functions and builtin window functions
83-
AggregateFunction::iter()
84-
.map(|func| func.to_string())
85-
.chain(BuiltInWindowFunction::iter().map(|func| func.to_string()))
86-
.collect()
87-
} else {
88-
// All scalar functions and aggregate functions
89-
BuiltinScalarFunction::iter()
90-
.map(|func| func.to_string())
91-
.chain(AggregateFunction::iter().map(|func| func.to_string()))
92-
.collect()
93-
};
94-
find_closest_match(valid_funcs, input_function_name)
95-
}
96-
97-
/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitve)
98-
/// Input `candidates` must not be empty otherwise it will panic
99-
fn find_closest_match(candidates: Vec<String>, target: &str) -> String {
100-
let target = target.to_lowercase();
101-
candidates
102-
.into_iter()
103-
.min_by_key(|candidate| {
104-
datafusion_strsim::levenshtein(&candidate.to_lowercase(), &target)
105-
})
106-
.expect("No candidates provided.") // Panic if `candidates` argument is empty
107-
}

datafusion/optimizer/tests/optimizer_integration.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -418,6 +418,18 @@ impl ContextProvider for MyContextProvider {
418418
fn options(&self) -> &ConfigOptions {
419419
&self.options
420420
}
421+
422+
fn udfs(&self) -> HashMap<String, Arc<ScalarUDF>> {
423+
HashMap::new()
424+
}
425+
426+
fn udafs(&self) -> HashMap<String, Arc<AggregateUDF>> {
427+
HashMap::new()
428+
}
429+
430+
fn udwfs(&self) -> HashMap<String, Arc<WindowUDF>> {
431+
HashMap::new()
432+
}
421433
}
422434

423435
struct MyTableSource {

datafusion/sql/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ datafusion-common = { workspace = true, default-features = true }
4343
datafusion-expr = { workspace = true }
4444
log = { workspace = true }
4545
sqlparser = { workspace = true }
46+
strum = { version = "0.26.1", features = ["derive"] }
4647

4748
[dev-dependencies]
4849
ctor = { workspace = true }

datafusion/sql/examples/sql.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,4 +131,16 @@ impl ContextProvider for MyContextProvider {
131131
fn options(&self) -> &ConfigOptions {
132132
&self.options
133133
}
134+
135+
fn udfs(&self) -> HashMap<String, Arc<ScalarUDF>> {
136+
HashMap::new()
137+
}
138+
139+
fn udafs(&self) -> HashMap<String, Arc<AggregateUDF>> {
140+
HashMap::new()
141+
}
142+
143+
fn udwfs(&self) -> HashMap<String, Arc<WindowUDF>> {
144+
HashMap::new()
145+
}
134146
}

datafusion/sql/src/expr/function.rs

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,63 @@ use arrow_schema::DataType;
2020
use datafusion_common::{
2121
not_impl_err, plan_datafusion_err, plan_err, DFSchema, Dependency, Result,
2222
};
23-
use datafusion_expr::expr::{ScalarFunction, Unnest};
24-
use datafusion_expr::function::suggest_valid_function;
2523
use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by};
2624
use datafusion_expr::{
27-
expr, AggregateFunction, BuiltinScalarFunction, Expr, ExprSchemable, WindowFrame,
28-
WindowFunctionDefinition,
25+
expr, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFunctionDefinition,
26+
};
27+
use datafusion_expr::{
28+
expr::{ScalarFunction, Unnest},
29+
BuiltInWindowFunction, BuiltinScalarFunction,
2930
};
3031
use sqlparser::ast::{
3132
Expr as SQLExpr, Function as SQLFunction, FunctionArg, FunctionArgExpr, WindowType,
3233
};
3334
use std::str::FromStr;
35+
use strum::IntoEnumIterator;
3436

3537
use super::arrow_cast::ARROW_CAST_NAME;
3638

39+
/// Suggest a valid function based on an invalid input function name
40+
pub fn suggest_valid_function(
41+
input_function_name: &str,
42+
is_window_func: bool,
43+
ctx: &dyn ContextProvider,
44+
) -> String {
45+
let valid_funcs = if is_window_func {
46+
// All aggregate functions and builtin window functions
47+
AggregateFunction::iter()
48+
.map(|func| func.to_string())
49+
.chain(BuiltInWindowFunction::iter().map(|func| func.to_string()))
50+
.collect()
51+
} else {
52+
// All scalar functions and aggregate functions
53+
let mut funcs = Vec::new();
54+
55+
funcs.extend(BuiltinScalarFunction::iter().map(|func| func.to_string()));
56+
funcs.extend(ctx.udfs().into_keys());
57+
funcs.extend(AggregateFunction::iter().map(|func| func.to_string()));
58+
funcs.extend(ctx.udafs().into_keys());
59+
60+
funcs
61+
};
62+
find_closest_match(valid_funcs, input_function_name)
63+
}
64+
65+
/// Find the closest matching string to the target string in the candidates list, using edit distance(case insensitve)
66+
/// Input `candidates` must not be empty otherwise it will panic
67+
fn find_closest_match(candidates: Vec<String>, target: &str) -> String {
68+
let target = target.to_lowercase();
69+
candidates
70+
.into_iter()
71+
.min_by_key(|candidate| {
72+
datafusion_common::utils::datafusion_strsim::levenshtein(
73+
&candidate.to_lowercase(),
74+
&target,
75+
)
76+
})
77+
.expect("No candidates provided.") // Panic if `candidates` argument is empty
78+
}
79+
3780
impl<'a, S: ContextProvider> SqlToRel<'a, S> {
3881
pub(super) fn sql_function_to_expr(
3982
&self,
@@ -211,7 +254,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
211254
}
212255

213256
// Could not find the relevant function, so return an error
214-
let suggested_func_name = suggest_valid_function(&name, is_function_window);
257+
let suggested_func_name =
258+
suggest_valid_function(&name, is_function_window, self.context_provider);
215259
plan_err!("Invalid function '{name}'.\nDid you mean '{suggested_func_name}'?")
216260
}
217261

datafusion/sql/src/expr/mod.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -951,6 +951,18 @@ mod tests {
951951
fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
952952
None
953953
}
954+
955+
fn udfs(&self) -> HashMap<String, Arc<ScalarUDF>> {
956+
HashMap::new()
957+
}
958+
959+
fn udafs(&self) -> HashMap<String, Arc<AggregateUDF>> {
960+
HashMap::new()
961+
}
962+
963+
fn udwfs(&self) -> HashMap<String, Arc<WindowUDF>> {
964+
HashMap::new()
965+
}
954966
}
955967

956968
fn create_table_source(fields: Vec<Field>) -> Arc<dyn TableSource> {

datafusion/sql/src/planner.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ pub trait ContextProvider {
8585

8686
/// Get configuration options
8787
fn options(&self) -> &ConfigOptions;
88+
89+
fn udfs(&self) -> HashMap<String, Arc<ScalarUDF>>;
90+
fn udafs(&self) -> HashMap<String, Arc<AggregateUDF>>;
91+
fn udwfs(&self) -> HashMap<String, Arc<WindowUDF>>;
8892
}
8993

9094
/// SQL parser options

datafusion/sql/tests/sql_integration.rs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2907,6 +2907,18 @@ impl ContextProvider for MockContextProvider {
29072907
) -> Result<Arc<dyn TableSource>> {
29082908
Ok(Arc::new(EmptyTable::new(schema)))
29092909
}
2910+
2911+
fn udfs(&self) -> HashMap<String, Arc<ScalarUDF>> {
2912+
self.udfs.clone()
2913+
}
2914+
2915+
fn udafs(&self) -> HashMap<String, Arc<AggregateUDF>> {
2916+
self.udafs.clone()
2917+
}
2918+
2919+
fn udwfs(&self) -> HashMap<String, Arc<WindowUDF>> {
2920+
HashMap::new()
2921+
}
29102922
}
29112923

29122924
#[test]

0 commit comments

Comments
 (0)