Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix duckdb & sqlite character_length scalar unparsing #13428

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 84 additions & 9 deletions datafusion/sql/src/unparser/dialect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use sqlparser::{

use datafusion_common::Result;

use super::{utils::date_part_to_sql, Unparser};
use super::{utils::character_length_to_sql, utils::date_part_to_sql, Unparser};

/// `Dialect` to use for Unparsing
///
Expand Down Expand Up @@ -80,6 +80,11 @@ pub trait Dialect: Send + Sync {
DateFieldExtractStyle::DatePart
}

/// The character length extraction style to use: `CharacterLengthStyle`
fn character_length_style(&self) -> CharacterLengthStyle {
CharacterLengthStyle::CharacterLength
}

/// The SQL type to use for Arrow Int64 unparsing
/// Most dialects use BigInt, but some, like MySQL, require SIGNED
fn int64_cast_dtype(&self) -> ast::DataType {
Expand Down Expand Up @@ -176,6 +181,17 @@ pub enum DateFieldExtractStyle {
Strftime,
}

/// `CharacterLengthStyle` to use for unparsing
///
/// Different DBMSs uses different names for function calculating the number of characters in the string
/// `Length` style uses length(x)
/// `SQLStandard` style uses character_length(x)
#[derive(Clone, Copy, PartialEq)]
pub enum CharacterLengthStyle {
Length,
CharacterLength,
}

pub struct DefaultDialect {}

impl Dialect for DefaultDialect {
Expand Down Expand Up @@ -271,6 +287,35 @@ impl PostgreSqlDialect {
}
}

pub struct DuckDBDialect {}

impl Dialect for DuckDBDialect {
fn identifier_quote_style(&self, _: &str) -> Option<char> {
Some('"')
}

fn character_length_style(&self) -> CharacterLengthStyle {
CharacterLengthStyle::Length
}

fn scalar_function_to_sql_overrides(
&self,
unparser: &Unparser,
func_name: &str,
args: &[Expr],
) -> Result<Option<ast::Expr>> {
if func_name == "character_length" {
return character_length_to_sql(
unparser,
self.character_length_style(),
args,
);
}

Ok(None)
}
}

pub struct MySqlDialect {}

impl Dialect for MySqlDialect {
Expand Down Expand Up @@ -347,6 +392,10 @@ impl Dialect for SqliteDialect {
ast::DataType::Text
}

fn character_length_style(&self) -> CharacterLengthStyle {
CharacterLengthStyle::Length
}

fn supports_column_alias_in_table_alias(&self) -> bool {
false
}
Expand All @@ -357,11 +406,15 @@ impl Dialect for SqliteDialect {
func_name: &str,
args: &[Expr],
) -> Result<Option<ast::Expr>> {
if func_name == "date_part" {
return date_part_to_sql(unparser, self.date_field_extract_style(), args);
match func_name {
"date_part" => {
date_part_to_sql(unparser, self.date_field_extract_style(), args)
}
"character_length" => {
character_length_to_sql(unparser, self.character_length_style(), args)
}
_ => Ok(None),
}

Ok(None)
}
}

Expand All @@ -374,6 +427,7 @@ pub struct CustomDialect {
utf8_cast_dtype: ast::DataType,
large_utf8_cast_dtype: ast::DataType,
date_field_extract_style: DateFieldExtractStyle,
character_length_style: CharacterLengthStyle,
int64_cast_dtype: ast::DataType,
int32_cast_dtype: ast::DataType,
timestamp_cast_dtype: ast::DataType,
Expand All @@ -395,6 +449,7 @@ impl Default for CustomDialect {
utf8_cast_dtype: ast::DataType::Varchar(None),
large_utf8_cast_dtype: ast::DataType::Text,
date_field_extract_style: DateFieldExtractStyle::DatePart,
character_length_style: CharacterLengthStyle::CharacterLength,
int64_cast_dtype: ast::DataType::BigInt(None),
int32_cast_dtype: ast::DataType::Integer(None),
timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None),
Expand Down Expand Up @@ -454,6 +509,10 @@ impl Dialect for CustomDialect {
self.date_field_extract_style
}

fn character_length_style(&self) -> CharacterLengthStyle {
self.character_length_style
}

fn int64_cast_dtype(&self) -> ast::DataType {
self.int64_cast_dtype.clone()
}
Expand Down Expand Up @@ -488,11 +547,15 @@ impl Dialect for CustomDialect {
func_name: &str,
args: &[Expr],
) -> Result<Option<ast::Expr>> {
if func_name == "date_part" {
return date_part_to_sql(unparser, self.date_field_extract_style(), args);
match func_name {
"date_part" => {
date_part_to_sql(unparser, self.date_field_extract_style(), args)
}
"character_length" => {
character_length_to_sql(unparser, self.character_length_style(), args)
}
_ => Ok(None),
}

Ok(None)
}

fn requires_derived_table_alias(&self) -> bool {
Expand Down Expand Up @@ -527,6 +590,7 @@ pub struct CustomDialectBuilder {
utf8_cast_dtype: ast::DataType,
large_utf8_cast_dtype: ast::DataType,
date_field_extract_style: DateFieldExtractStyle,
character_length_style: CharacterLengthStyle,
int64_cast_dtype: ast::DataType,
int32_cast_dtype: ast::DataType,
timestamp_cast_dtype: ast::DataType,
Expand Down Expand Up @@ -554,6 +618,7 @@ impl CustomDialectBuilder {
utf8_cast_dtype: ast::DataType::Varchar(None),
large_utf8_cast_dtype: ast::DataType::Text,
date_field_extract_style: DateFieldExtractStyle::DatePart,
character_length_style: CharacterLengthStyle::CharacterLength,
int64_cast_dtype: ast::DataType::BigInt(None),
int32_cast_dtype: ast::DataType::Integer(None),
timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None),
Expand All @@ -578,6 +643,7 @@ impl CustomDialectBuilder {
utf8_cast_dtype: self.utf8_cast_dtype,
large_utf8_cast_dtype: self.large_utf8_cast_dtype,
date_field_extract_style: self.date_field_extract_style,
character_length_style: self.character_length_style,
int64_cast_dtype: self.int64_cast_dtype,
int32_cast_dtype: self.int32_cast_dtype,
timestamp_cast_dtype: self.timestamp_cast_dtype,
Expand Down Expand Up @@ -620,6 +686,15 @@ impl CustomDialectBuilder {
self
}

/// Customize the dialect with a specific character_length_style listed in `CharacterLengthStyle`
pub fn with_character_length_style(
mut self,
character_length_style: CharacterLengthStyle,
) -> Self {
self.character_length_style = character_length_style;
self
}

/// Customize the dialect with a specific SQL type for Float64 casting: DOUBLE, DOUBLE PRECISION, etc.
pub fn with_float64_ast_dtype(mut self, float64_ast_dtype: ast::DataType) -> Self {
self.float64_ast_dtype = float64_ast_dtype;
Expand Down
31 changes: 29 additions & 2 deletions datafusion/sql/src/unparser/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1488,8 +1488,8 @@ mod tests {
use datafusion_functions_window::row_number::row_number_udwf;

use crate::unparser::dialect::{
CustomDialect, CustomDialectBuilder, DateFieldExtractStyle, Dialect,
PostgreSqlDialect,
CharacterLengthStyle, CustomDialect, CustomDialectBuilder, DateFieldExtractStyle,
Dialect, PostgreSqlDialect,
};

use super::*;
Expand Down Expand Up @@ -2007,6 +2007,33 @@ mod tests {
Ok(())
}

#[test]
fn test_character_length_scalar_to_expr() {
let tests = [
(CharacterLengthStyle::Length, "length(x)"),
(CharacterLengthStyle::CharacterLength, "character_length(x)"),
];

for (style, expected) in tests {
let dialect = CustomDialectBuilder::new()
.with_character_length_style(style)
.build();
let unparser = Unparser::new(&dialect);

let expr = ScalarUDF::new_from_impl(
datafusion_functions::unicode::character_length::CharacterLengthFunc::new(
),
)
.call(vec![col("x")]);

let ast = unparser.expr_to_sql(&expr).expect("to be unparsed");

let actual = format!("{ast}");

assert_eq!(actual, expected);
}
}

#[test]
fn test_interval_scalar_to_expr() {
let tests = [
Expand Down
21 changes: 20 additions & 1 deletion datafusion/sql/src/unparser/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ use datafusion_expr::{
};
use sqlparser::ast;

use super::{dialect::DateFieldExtractStyle, rewrite::TableAliasRewriter, Unparser};
use super::{
dialect::CharacterLengthStyle, dialect::DateFieldExtractStyle,
rewrite::TableAliasRewriter, Unparser,
};

/// Recursively searches children of [LogicalPlan] to find an Aggregate node if exists
/// prior to encountering a Join, TableScan, or a nested subquery (derived table factor).
Expand Down Expand Up @@ -445,3 +448,19 @@ pub(crate) fn date_part_to_sql(

Ok(None)
}

pub(crate) fn character_length_to_sql(
unparser: &Unparser,
style: CharacterLengthStyle,
character_length_args: &[Expr],
) -> Result<Option<ast::Expr>> {
let func_name = match style {
CharacterLengthStyle::CharacterLength => "character_length",
CharacterLengthStyle::Length => "length",
};

Ok(Some(unparser.scalar_function_to_sql(
func_name,
character_length_args,
)?))
}