15
15
// specific language governing permissions and limitations
16
16
// under the License.
17
17
18
- use std:: sync:: Arc ;
18
+ use std:: { collections :: HashMap , sync:: Arc } ;
19
19
20
20
use arrow_schema:: TimeUnit ;
21
21
use datafusion_common:: Result ;
@@ -29,6 +29,9 @@ use sqlparser::{
29
29
30
30
use super :: { utils:: character_length_to_sql, utils:: date_part_to_sql, Unparser } ;
31
31
32
+ pub type ScalarFnToSqlHandler =
33
+ Box < dyn Fn ( & Unparser , & [ Expr ] ) -> Result < Option < ast:: Expr > > + Send + Sync > ;
34
+
32
35
/// `Dialect` to use for Unparsing
33
36
///
34
37
/// The default dialect tries to avoid quoting identifiers unless necessary (e.g. `a` instead of `"a"`)
@@ -150,6 +153,18 @@ pub trait Dialect: Send + Sync {
150
153
Ok ( None )
151
154
}
152
155
156
+ /// Extends the dialect's default rules for unparsing scalar functions.
157
+ /// This is useful for supporting application-specific UDFs or custom engine extensions.
158
+ fn with_custom_scalar_overrides (
159
+ self ,
160
+ _handlers : Vec < ( & str , ScalarFnToSqlHandler ) > ,
161
+ ) -> Self
162
+ where
163
+ Self : Sized ,
164
+ {
165
+ unimplemented ! ( "Custom scalar overrides are not supported by this dialect yet" ) ;
166
+ }
167
+
153
168
/// Allow to unparse a qualified column with a full qualified name
154
169
/// (e.g. catalog_name.schema_name.table_name.column_name)
155
170
/// Otherwise, the column will be unparsed with only the table name and column name
@@ -305,7 +320,19 @@ impl PostgreSqlDialect {
305
320
}
306
321
}
307
322
308
- pub struct DuckDBDialect { }
323
+ #[ derive( Default ) ]
324
+ pub struct DuckDBDialect {
325
+ custom_scalar_fn_overrides : HashMap < String , ScalarFnToSqlHandler > ,
326
+ }
327
+
328
+ impl DuckDBDialect {
329
+ #[ must_use]
330
+ pub fn new ( ) -> Self {
331
+ Self {
332
+ custom_scalar_fn_overrides : HashMap :: new ( ) ,
333
+ }
334
+ }
335
+ }
309
336
310
337
impl Dialect for DuckDBDialect {
311
338
fn identifier_quote_style ( & self , _: & str ) -> Option < char > {
@@ -320,12 +347,27 @@ impl Dialect for DuckDBDialect {
320
347
BinaryOperator :: DuckIntegerDivide
321
348
}
322
349
350
+ fn with_custom_scalar_overrides (
351
+ mut self ,
352
+ handlers : Vec < ( & str , ScalarFnToSqlHandler ) > ,
353
+ ) -> Self {
354
+ for ( func_name, handler) in handlers {
355
+ self . custom_scalar_fn_overrides
356
+ . insert ( func_name. to_string ( ) , handler) ;
357
+ }
358
+ self
359
+ }
360
+
323
361
fn scalar_function_to_sql_overrides (
324
362
& self ,
325
363
unparser : & Unparser ,
326
364
func_name : & str ,
327
365
args : & [ Expr ] ,
328
366
) -> Result < Option < ast:: Expr > > {
367
+ if let Some ( handler) = self . custom_scalar_fn_overrides . get ( func_name) {
368
+ return handler ( unparser, args) ;
369
+ }
370
+
329
371
if func_name == "character_length" {
330
372
return character_length_to_sql (
331
373
unparser,
0 commit comments