From 2ee5c99f7120a97d9423ead7aa6855ca92782a92 Mon Sep 17 00:00:00 2001 From: Qianqian <130200611+Sevenannn@users.noreply.github.com> Date: Tue, 12 Nov 2024 17:49:57 -0800 Subject: [PATCH 1/2] Fix duckdb & sqlite character_length scalar unparsing (#59) * Fix duckdb & sqlite character_length scalar unparsing * Add comments * Update CharacterLengthStyle::SQLStandard to CharacterLengthExtractStyle::CharacterLength --- datafusion/sql/src/unparser/dialect.rs | 101 ++++++++++++++++++++++--- datafusion/sql/src/unparser/expr.rs | 31 +++++++- datafusion/sql/src/unparser/utils.rs | 20 ++++- 3 files changed, 140 insertions(+), 12 deletions(-) diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index 87ed1b8f4140..9a04f8b4c10b 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -27,7 +27,7 @@ use sqlparser::{ use datafusion_common::Result; -use super::{utils::date_part_to_sql, Unparser}; +use super::{utils::character_length_to_sql, utils::date_part_to_sql, Unparser}; /// `Dialect` to use for Unparsing /// @@ -80,6 +80,11 @@ pub trait Dialect: Send + Sync { DateFieldExtractStyle::DatePart } + /// The character length extraction style to use: `CharacterLengthStyle` + fn character_length_style(&self) -> CharacterLengthStyle { + CharacterLengthStyle::CharacterLength + } + /// The SQL type to use for Arrow Int64 unparsing /// Most dialects use BigInt, but some, like MySQL, require SIGNED fn int64_cast_dtype(&self) -> ast::DataType { @@ -176,6 +181,17 @@ pub enum DateFieldExtractStyle { Strftime, } +/// `CharacterLengthStyle` to use for unparsing +/// +/// Different DBMSs use different names for the function that calculates the number of characters in a string +/// `Length` style uses length(x) +/// `CharacterLength` style uses character_length(x) +#[derive(Clone, Copy, PartialEq)] +pub enum CharacterLengthStyle { + Length, + CharacterLength, +} + pub struct DefaultDialect {} impl Dialect for 
DefaultDialect { @@ -271,6 +287,35 @@ impl PostgreSqlDialect { } } +pub struct DuckDBDialect {} + +impl Dialect for DuckDBDialect { + fn identifier_quote_style(&self, _: &str) -> Option<char> { + Some('"') + } + + fn character_length_style(&self) -> CharacterLengthStyle { + CharacterLengthStyle::Length + } + + fn scalar_function_to_sql_overrides( + &self, + unparser: &Unparser, + func_name: &str, + args: &[Expr], + ) -> Result<Option<ast::Expr>> { + if func_name == "character_length" { + return character_length_to_sql( + unparser, + self.character_length_style(), + args, + ); + } + + Ok(None) + } +} + pub struct MySqlDialect {} impl Dialect for MySqlDialect { @@ -347,6 +392,10 @@ impl Dialect for SqliteDialect { ast::DataType::Text } + fn character_length_style(&self) -> CharacterLengthStyle { + CharacterLengthStyle::Length + } + fn supports_column_alias_in_table_alias(&self) -> bool { false } @@ -357,11 +406,19 @@ impl Dialect for SqliteDialect { func_name: &str, args: &[Expr], ) -> Result<Option<ast::Expr>> { - if func_name == "date_part" { - return date_part_to_sql(unparser, self.date_field_extract_style(), args); + match func_name { + "date_part" => { + return date_part_to_sql(unparser, self.date_field_extract_style(), args); + } + "character_length" => { + return character_length_to_sql( + unparser, + self.character_length_style(), + args, + ); + } + _ => return Ok(None), } - - Ok(None) } } @@ -374,6 +431,7 @@ pub struct CustomDialect { utf8_cast_dtype: ast::DataType, large_utf8_cast_dtype: ast::DataType, date_field_extract_style: DateFieldExtractStyle, + character_length_style: CharacterLengthStyle, int64_cast_dtype: ast::DataType, int32_cast_dtype: ast::DataType, timestamp_cast_dtype: ast::DataType, @@ -395,6 +453,7 @@ impl Default for CustomDialect { utf8_cast_dtype: ast::DataType::Varchar(None), large_utf8_cast_dtype: ast::DataType::Text, date_field_extract_style: DateFieldExtractStyle::DatePart, + character_length_style: CharacterLengthStyle::CharacterLength, int64_cast_dtype: 
ast::DataType::BigInt(None), int32_cast_dtype: ast::DataType::Integer(None), timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None), @@ -454,6 +513,10 @@ impl Dialect for CustomDialect { self.date_field_extract_style } + fn character_length_style(&self) -> CharacterLengthStyle { + self.character_length_style + } + fn int64_cast_dtype(&self) -> ast::DataType { self.int64_cast_dtype.clone() } @@ -488,11 +551,19 @@ impl Dialect for CustomDialect { func_name: &str, args: &[Expr], ) -> Result<Option<ast::Expr>> { - if func_name == "date_part" { - return date_part_to_sql(unparser, self.date_field_extract_style(), args); + match func_name { + "date_part" => { + return date_part_to_sql(unparser, self.date_field_extract_style(), args); + } + "character_length" => { + return character_length_to_sql( + unparser, + self.character_length_style(), + args, + ) + } + _ => return Ok(None), } - - Ok(None) } fn requires_derived_table_alias(&self) -> bool { @@ -527,6 +598,7 @@ pub struct CustomDialectBuilder { utf8_cast_dtype: ast::DataType, large_utf8_cast_dtype: ast::DataType, date_field_extract_style: DateFieldExtractStyle, + character_length_style: CharacterLengthStyle, int64_cast_dtype: ast::DataType, int32_cast_dtype: ast::DataType, timestamp_cast_dtype: ast::DataType, @@ -554,6 +626,7 @@ impl CustomDialectBuilder { utf8_cast_dtype: ast::DataType::Varchar(None), large_utf8_cast_dtype: ast::DataType::Text, date_field_extract_style: DateFieldExtractStyle::DatePart, + character_length_style: CharacterLengthStyle::CharacterLength, int64_cast_dtype: ast::DataType::BigInt(None), int32_cast_dtype: ast::DataType::Integer(None), timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None), @@ -578,6 +651,7 @@ impl CustomDialectBuilder { utf8_cast_dtype: self.utf8_cast_dtype, large_utf8_cast_dtype: self.large_utf8_cast_dtype, date_field_extract_style: self.date_field_extract_style, + character_length_style: self.character_length_style, int64_cast_dtype: self.int64_cast_dtype, 
int32_cast_dtype: self.int32_cast_dtype, timestamp_cast_dtype: self.timestamp_cast_dtype, @@ -620,6 +694,15 @@ impl CustomDialectBuilder { self } + /// Customize the dialect with a specific character_length_style listed in `CharacterLengthStyle` + pub fn with_character_length_style( + mut self, + character_length_style: CharacterLengthStyle, + ) -> Self { + self.character_length_style = character_length_style; + self + } + /// Customize the dialect with a specific SQL type for Float64 casting: DOUBLE, DOUBLE PRECISION, etc. pub fn with_float64_ast_dtype(mut self, float64_ast_dtype: ast::DataType) -> Self { self.float64_ast_dtype = float64_ast_dtype; diff --git a/datafusion/sql/src/unparser/expr.rs b/datafusion/sql/src/unparser/expr.rs index 8f6ffa51f76a..d09bd6e8b90c 100644 --- a/datafusion/sql/src/unparser/expr.rs +++ b/datafusion/sql/src/unparser/expr.rs @@ -1488,8 +1488,8 @@ mod tests { use datafusion_functions_window::row_number::row_number_udwf; use crate::unparser::dialect::{ - CustomDialect, CustomDialectBuilder, DateFieldExtractStyle, Dialect, - PostgreSqlDialect, + CharacterLengthStyle, CustomDialect, CustomDialectBuilder, DateFieldExtractStyle, + Dialect, PostgreSqlDialect, }; use super::*; @@ -2007,6 +2007,33 @@ mod tests { Ok(()) } + #[test] + fn test_character_length_scalar_to_expr() { + let tests = [ + (CharacterLengthStyle::Length, "length(x)"), + (CharacterLengthStyle::CharacterLength, "character_length(x)"), + ]; + + for (style, expected) in tests { + let dialect = CustomDialectBuilder::new() + .with_character_length_style(style) + .build(); + let unparser = Unparser::new(&dialect); + + let expr = ScalarUDF::new_from_impl( + datafusion_functions::unicode::character_length::CharacterLengthFunc::new( + ), + ) + .call(vec![col("x")]); + + let ast = unparser.expr_to_sql(&expr).expect("to be unparsed"); + + let actual = format!("{ast}"); + + assert_eq!(actual, expected); + } + } + #[test] fn test_interval_scalar_to_expr() { let tests = [ diff --git 
a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index 284956cef195..ec515120ac54 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -28,7 +28,10 @@ use datafusion_expr::{ }; use sqlparser::ast; -use super::{dialect::DateFieldExtractStyle, rewrite::TableAliasRewriter, Unparser}; +use super::{ + dialect::CharacterLengthStyle, dialect::DateFieldExtractStyle, + rewrite::TableAliasRewriter, Unparser, +}; /// Recursively searches children of [LogicalPlan] to find an Aggregate node if exists /// prior to encountering a Join, TableScan, or a nested subquery (derived table factor). @@ -445,3 +448,18 @@ pub(crate) fn date_part_to_sql( Ok(None) } + +pub(crate) fn character_length_to_sql( + unparser: &Unparser, + style: CharacterLengthStyle, + character_length_args: &[Expr], +) -> Result<Option<ast::Expr>> { + let func_name = match style { + CharacterLengthStyle::CharacterLength => "character_length", + CharacterLengthStyle::Length => "length", + }; + + return Ok(Some( + unparser.scalar_function_to_sql(func_name, character_length_args)?, + )); +} From a78fd0dd5ee3ed7f0a5584552ef90a36f797b46b Mon Sep 17 00:00:00 2001 From: Sevenannn Date: Thu, 14 Nov 2024 15:33:50 -0800 Subject: [PATCH 2/2] Fix clippy error --- datafusion/sql/src/unparser/dialect.rs | 20 ++++++-------------- datafusion/sql/src/unparser/utils.rs | 7 ++++--- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/datafusion/sql/src/unparser/dialect.rs b/datafusion/sql/src/unparser/dialect.rs index 9a04f8b4c10b..fbaa402e703c 100644 --- a/datafusion/sql/src/unparser/dialect.rs +++ b/datafusion/sql/src/unparser/dialect.rs @@ -408,16 +408,12 @@ impl Dialect for SqliteDialect { ) -> Result<Option<ast::Expr>> { match func_name { "date_part" => { - return date_part_to_sql(unparser, self.date_field_extract_style(), args); + date_part_to_sql(unparser, self.date_field_extract_style(), args) } "character_length" => { - return character_length_to_sql( - unparser, - 
self.character_length_style(), - args, - ); + character_length_to_sql(unparser, self.character_length_style(), args) } - _ => return Ok(None), + _ => Ok(None), } } } @@ -553,16 +549,12 @@ impl Dialect for CustomDialect { ) -> Result<Option<ast::Expr>> { match func_name { "date_part" => { - return date_part_to_sql(unparser, self.date_field_extract_style(), args); + date_part_to_sql(unparser, self.date_field_extract_style(), args) } "character_length" => { - return character_length_to_sql( - unparser, - self.character_length_style(), - args, - ) + character_length_to_sql(unparser, self.character_length_style(), args) } - _ => return Ok(None), + _ => Ok(None), } } diff --git a/datafusion/sql/src/unparser/utils.rs b/datafusion/sql/src/unparser/utils.rs index ec515120ac54..d0f80da83d63 100644 --- a/datafusion/sql/src/unparser/utils.rs +++ b/datafusion/sql/src/unparser/utils.rs @@ -459,7 +459,8 @@ pub(crate) fn character_length_to_sql( CharacterLengthStyle::Length => "length", }; - return Ok(Some( - unparser.scalar_function_to_sql(func_name, character_length_args)?, - )); + Ok(Some(unparser.scalar_function_to_sql( + func_name, + character_length_args, + )?)) }