Skip to content

Commit 3d1d28d

Browse files
authored
fix: Add Int32 type override for Dialects (apache#12916)
* fix: Add Int32 type override for Dialects * fix: Dialect builder with_int32_cast_dtype: * test: Fix with_int32 test
1 parent d8e4e92 commit 3d1d28d

File tree

2 files changed

+54
-1
lines changed

2 files changed

+54
-1
lines changed

datafusion/sql/src/unparser/dialect.rs

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@ pub trait Dialect: Send + Sync {
8686
ast::DataType::BigInt(None)
8787
}
8888

89+
/// The SQL type to use for Arrow Int32 unparsing
90+
/// Most dialects use Integer, but some, like MySQL, require SIGNED
91+
fn int32_cast_dtype(&self) -> ast::DataType {
92+
ast::DataType::Integer(None)
93+
}
94+
8995
/// The SQL type to use for Timestamp unparsing
9096
/// Most dialects use Timestamp, but some, like MySQL, require Datetime
9197
/// Some dialects like Dremio does not support WithTimeZone and requires always Timestamp
@@ -282,6 +288,10 @@ impl Dialect for MySqlDialect {
282288
ast::DataType::Custom(ObjectName(vec![Ident::new("SIGNED")]), vec![])
283289
}
284290

291+
fn int32_cast_dtype(&self) -> ast::DataType {
292+
ast::DataType::Custom(ObjectName(vec![Ident::new("SIGNED")]), vec![])
293+
}
294+
285295
fn timestamp_cast_dtype(
286296
&self,
287297
_time_unit: &TimeUnit,
@@ -347,6 +357,7 @@ pub struct CustomDialect {
347357
large_utf8_cast_dtype: ast::DataType,
348358
date_field_extract_style: DateFieldExtractStyle,
349359
int64_cast_dtype: ast::DataType,
360+
int32_cast_dtype: ast::DataType,
350361
timestamp_cast_dtype: ast::DataType,
351362
timestamp_tz_cast_dtype: ast::DataType,
352363
date32_cast_dtype: sqlparser::ast::DataType,
@@ -365,6 +376,7 @@ impl Default for CustomDialect {
365376
large_utf8_cast_dtype: ast::DataType::Text,
366377
date_field_extract_style: DateFieldExtractStyle::DatePart,
367378
int64_cast_dtype: ast::DataType::BigInt(None),
379+
int32_cast_dtype: ast::DataType::Integer(None),
368380
timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None),
369381
timestamp_tz_cast_dtype: ast::DataType::Timestamp(
370382
None,
@@ -424,6 +436,10 @@ impl Dialect for CustomDialect {
424436
self.int64_cast_dtype.clone()
425437
}
426438

439+
fn int32_cast_dtype(&self) -> ast::DataType {
440+
self.int32_cast_dtype.clone()
441+
}
442+
427443
fn timestamp_cast_dtype(
428444
&self,
429445
_time_unit: &TimeUnit,
@@ -482,6 +498,7 @@ pub struct CustomDialectBuilder {
482498
large_utf8_cast_dtype: ast::DataType,
483499
date_field_extract_style: DateFieldExtractStyle,
484500
int64_cast_dtype: ast::DataType,
501+
int32_cast_dtype: ast::DataType,
485502
timestamp_cast_dtype: ast::DataType,
486503
timestamp_tz_cast_dtype: ast::DataType,
487504
date32_cast_dtype: ast::DataType,
@@ -506,6 +523,7 @@ impl CustomDialectBuilder {
506523
large_utf8_cast_dtype: ast::DataType::Text,
507524
date_field_extract_style: DateFieldExtractStyle::DatePart,
508525
int64_cast_dtype: ast::DataType::BigInt(None),
526+
int32_cast_dtype: ast::DataType::Integer(None),
509527
timestamp_cast_dtype: ast::DataType::Timestamp(None, TimezoneInfo::None),
510528
timestamp_tz_cast_dtype: ast::DataType::Timestamp(
511529
None,
@@ -527,6 +545,7 @@ impl CustomDialectBuilder {
527545
large_utf8_cast_dtype: self.large_utf8_cast_dtype,
528546
date_field_extract_style: self.date_field_extract_style,
529547
int64_cast_dtype: self.int64_cast_dtype,
548+
int32_cast_dtype: self.int32_cast_dtype,
530549
timestamp_cast_dtype: self.timestamp_cast_dtype,
531550
timestamp_tz_cast_dtype: self.timestamp_tz_cast_dtype,
532551
date32_cast_dtype: self.date32_cast_dtype,
@@ -604,6 +623,12 @@ impl CustomDialectBuilder {
604623
self
605624
}
606625

626+
/// Customize the dialect with a specific SQL type for Int32 casting: Integer, SIGNED, etc.
627+
pub fn with_int32_cast_dtype(mut self, int32_cast_dtype: ast::DataType) -> Self {
628+
self.int32_cast_dtype = int32_cast_dtype;
629+
self
630+
}
631+
607632
/// Customize the dialect with a specific SQL type for Timestamp casting: Timestamp, Datetime, etc.
608633
pub fn with_timestamp_cast_dtype(
609634
mut self,

datafusion/sql/src/unparser/expr.rs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1352,7 +1352,7 @@ impl Unparser<'_> {
13521352
DataType::Boolean => Ok(ast::DataType::Bool),
13531353
DataType::Int8 => Ok(ast::DataType::TinyInt(None)),
13541354
DataType::Int16 => Ok(ast::DataType::SmallInt(None)),
1355-
DataType::Int32 => Ok(ast::DataType::Integer(None)),
1355+
DataType::Int32 => Ok(self.dialect.int32_cast_dtype()),
13561356
DataType::Int64 => Ok(self.dialect.int64_cast_dtype()),
13571357
DataType::UInt8 => Ok(ast::DataType::UnsignedTinyInt(None)),
13581358
DataType::UInt16 => Ok(ast::DataType::UnsignedSmallInt(None)),
@@ -2253,6 +2253,34 @@ mod tests {
22532253
Ok(())
22542254
}
22552255

2256+
#[test]
2257+
fn custom_dialect_with_int32_cast_dtype() -> Result<()> {
2258+
let default_dialect = CustomDialectBuilder::new().build();
2259+
let mysql_dialect = CustomDialectBuilder::new()
2260+
.with_int32_cast_dtype(ast::DataType::Custom(
2261+
ObjectName(vec![Ident::new("SIGNED")]),
2262+
vec![],
2263+
))
2264+
.build();
2265+
2266+
for (dialect, identifier) in
2267+
[(default_dialect, "INTEGER"), (mysql_dialect, "SIGNED")]
2268+
{
2269+
let unparser = Unparser::new(&dialect);
2270+
let expr = Expr::Cast(Cast {
2271+
expr: Box::new(col("a")),
2272+
data_type: DataType::Int32,
2273+
});
2274+
let ast = unparser.expr_to_sql(&expr)?;
2275+
2276+
let actual = format!("{}", ast);
2277+
let expected = format!(r#"CAST(a AS {identifier})"#);
2278+
2279+
assert_eq!(actual, expected);
2280+
}
2281+
Ok(())
2282+
}
2283+
22562284
#[test]
22572285
fn custom_dialect_with_timestamp_cast_dtype() -> Result<()> {
22582286
let default_dialect = CustomDialectBuilder::new().build();

0 commit comments

Comments
 (0)