Skip to content

Commit 80c828b

Browse files
authored
Custom scalar to sql overrides support for DuckDB Unparser dialect (#13915)
* Allow adding custom scalar to sql overrides for DuckDB (#68) * Add unit test: custom_scalar_overrides_duckdb * Move `with_custom_scalar_overrides` definition on `Dialect` trait level
1 parent f379719 commit 80c828b

File tree

2 files changed

+69
-3
lines changed

2 files changed

+69
-3
lines changed

datafusion/sql/src/unparser/dialect.rs

+44-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use std::sync::Arc;
18+
use std::{collections::HashMap, sync::Arc};
1919

2020
use arrow_schema::TimeUnit;
2121
use datafusion_common::Result;
@@ -29,6 +29,9 @@ use sqlparser::{
2929

3030
use super::{utils::character_length_to_sql, utils::date_part_to_sql, Unparser};
3131

32+
pub type ScalarFnToSqlHandler =
33+
Box<dyn Fn(&Unparser, &[Expr]) -> Result<Option<ast::Expr>> + Send + Sync>;
34+
3235
/// `Dialect` to use for Unparsing
3336
///
3437
/// The default dialect tries to avoid quoting identifiers unless necessary (e.g. `a` instead of `"a"`)
@@ -150,6 +153,18 @@ pub trait Dialect: Send + Sync {
150153
Ok(None)
151154
}
152155

156+
/// Extends the dialect's default rules for unparsing scalar functions.
157+
/// This is useful for supporting application-specific UDFs or custom engine extensions.
158+
fn with_custom_scalar_overrides(
159+
self,
160+
_handlers: Vec<(&str, ScalarFnToSqlHandler)>,
161+
) -> Self
162+
where
163+
Self: Sized,
164+
{
165+
unimplemented!("Custom scalar overrides are not supported by this dialect yet");
166+
}
167+
153168
/// Allow to unparse a qualified column with a full qualified name
154169
/// (e.g. catalog_name.schema_name.table_name.column_name)
155170
/// Otherwise, the column will be unparsed with only the table name and column name
@@ -305,7 +320,19 @@ impl PostgreSqlDialect {
305320
}
306321
}
307322

308-
pub struct DuckDBDialect {}
323+
#[derive(Default)]
324+
pub struct DuckDBDialect {
325+
custom_scalar_fn_overrides: HashMap<String, ScalarFnToSqlHandler>,
326+
}
327+
328+
impl DuckDBDialect {
329+
#[must_use]
330+
pub fn new() -> Self {
331+
Self {
332+
custom_scalar_fn_overrides: HashMap::new(),
333+
}
334+
}
335+
}
309336

310337
impl Dialect for DuckDBDialect {
311338
fn identifier_quote_style(&self, _: &str) -> Option<char> {
@@ -320,12 +347,27 @@ impl Dialect for DuckDBDialect {
320347
BinaryOperator::DuckIntegerDivide
321348
}
322349

350+
fn with_custom_scalar_overrides(
351+
mut self,
352+
handlers: Vec<(&str, ScalarFnToSqlHandler)>,
353+
) -> Self {
354+
for (func_name, handler) in handlers {
355+
self.custom_scalar_fn_overrides
356+
.insert(func_name.to_string(), handler);
357+
}
358+
self
359+
}
360+
323361
fn scalar_function_to_sql_overrides(
324362
&self,
325363
unparser: &Unparser,
326364
func_name: &str,
327365
args: &[Expr],
328366
) -> Result<Option<ast::Expr>> {
367+
if let Some(handler) = self.custom_scalar_fn_overrides.get(func_name) {
368+
return handler(unparser, args);
369+
}
370+
329371
if func_name == "character_length" {
330372
return character_length_to_sql(
331373
unparser,

datafusion/sql/src/unparser/expr.rs

+25-1
Original file line numberDiff line numberDiff line change
@@ -1636,7 +1636,7 @@ mod tests {
16361636

16371637
use crate::unparser::dialect::{
16381638
CharacterLengthStyle, CustomDialect, CustomDialectBuilder, DateFieldExtractStyle,
1639-
Dialect, PostgreSqlDialect,
1639+
Dialect, DuckDBDialect, PostgreSqlDialect, ScalarFnToSqlHandler,
16401640
};
16411641

16421642
use super::*;
@@ -2722,4 +2722,28 @@ mod tests {
27222722

27232723
Ok(())
27242724
}
2725+
2726+
#[test]
2727+
fn test_custom_scalar_overrides_duckdb() -> Result<()> {
2728+
let duckdb_default = DuckDBDialect::new();
2729+
let duckdb_extended = DuckDBDialect::new().with_custom_scalar_overrides(vec![(
2730+
"dummy_udf",
2731+
Box::new(|unparser: &Unparser, args: &[Expr]| {
2732+
unparser.scalar_function_to_sql("smart_udf", args).map(Some)
2733+
}) as ScalarFnToSqlHandler,
2734+
)]);
2735+
2736+
for (dialect, expected) in [
2737+
(duckdb_default, r#"dummy_udf("a", "b")"#),
2738+
(duckdb_extended, r#"smart_udf("a", "b")"#),
2739+
] {
2740+
let unparser = Unparser::new(&dialect);
2741+
let expr =
2742+
ScalarUDF::new_from_impl(DummyUDF::new()).call(vec![col("a"), col("b")]);
2743+
let actual = format!("{}", unparser.expr_to_sql(&expr)?);
2744+
assert_eq!(actual, expected);
2745+
}
2746+
2747+
Ok(())
2748+
}
27252749
}

0 commit comments

Comments
 (0)