@@ -24,13 +24,15 @@ use crate::{OptimizerConfig, OptimizerRule};
24
24
use arrow:: datatypes:: {
25
25
DataType , TimeUnit , MAX_DECIMAL_FOR_EACH_PRECISION , MIN_DECIMAL_FOR_EACH_PRECISION ,
26
26
} ;
27
+ use arrow:: temporal_conversions:: { MICROSECONDS , MILLISECONDS , NANOSECONDS } ;
27
28
use datafusion_common:: { DFSchemaRef , DataFusionError , Result , ScalarValue } ;
28
29
use datafusion_expr:: expr:: { BinaryExpr , Cast , TryCast } ;
29
30
use datafusion_expr:: expr_rewriter:: { ExprRewriter , RewriteRecursion } ;
30
31
use datafusion_expr:: utils:: from_plan;
31
32
use datafusion_expr:: {
32
33
binary_expr, in_list, lit, Expr , ExprSchemable , LogicalPlan , Operator ,
33
34
} ;
35
+ use std:: cmp:: Ordering ;
34
36
use std:: sync:: Arc ;
35
37
36
38
/// [`UnwrapCastInComparison`] attempts to remove casts from
@@ -400,16 +402,36 @@ fn try_cast_literal_to_type(
400
402
DataType :: UInt32 => ScalarValue :: UInt32 ( Some ( value as u32 ) ) ,
401
403
DataType :: UInt64 => ScalarValue :: UInt64 ( Some ( value as u64 ) ) ,
402
404
DataType :: Timestamp ( TimeUnit :: Second , tz) => {
403
- ScalarValue :: TimestampSecond ( Some ( value as i64 ) , tz. clone ( ) )
405
+ let value = cast_between_timestamp (
406
+ lit_data_type,
407
+ DataType :: Timestamp ( TimeUnit :: Second , tz. clone ( ) ) ,
408
+ value,
409
+ ) ;
410
+ ScalarValue :: TimestampSecond ( value, tz. clone ( ) )
404
411
}
405
412
DataType :: Timestamp ( TimeUnit :: Millisecond , tz) => {
406
- ScalarValue :: TimestampMillisecond ( Some ( value as i64 ) , tz. clone ( ) )
413
+ let value = cast_between_timestamp (
414
+ lit_data_type,
415
+ DataType :: Timestamp ( TimeUnit :: Millisecond , tz. clone ( ) ) ,
416
+ value,
417
+ ) ;
418
+ ScalarValue :: TimestampMillisecond ( value, tz. clone ( ) )
407
419
}
408
420
DataType :: Timestamp ( TimeUnit :: Microsecond , tz) => {
409
- ScalarValue :: TimestampMicrosecond ( Some ( value as i64 ) , tz. clone ( ) )
421
+ let value = cast_between_timestamp (
422
+ lit_data_type,
423
+ DataType :: Timestamp ( TimeUnit :: Microsecond , tz. clone ( ) ) ,
424
+ value,
425
+ ) ;
426
+ ScalarValue :: TimestampMicrosecond ( value, tz. clone ( ) )
410
427
}
411
428
DataType :: Timestamp ( TimeUnit :: Nanosecond , tz) => {
412
- ScalarValue :: TimestampNanosecond ( Some ( value as i64 ) , tz. clone ( ) )
429
+ let value = cast_between_timestamp (
430
+ lit_data_type,
431
+ DataType :: Timestamp ( TimeUnit :: Nanosecond , tz. clone ( ) ) ,
432
+ value,
433
+ ) ;
434
+ ScalarValue :: TimestampNanosecond ( value, tz. clone ( ) )
413
435
}
414
436
DataType :: Decimal128 ( p, s) => {
415
437
ScalarValue :: Decimal128 ( Some ( value) , * p, * s)
@@ -428,6 +450,32 @@ fn try_cast_literal_to_type(
428
450
}
429
451
}
430
452
453
+ /// Cast a timestamp value from one unit to another
454
+ fn cast_between_timestamp ( from : DataType , to : DataType , value : i128 ) -> Option < i64 > {
455
+ let value = value as i64 ;
456
+ let from_scale = match from {
457
+ DataType :: Timestamp ( TimeUnit :: Second , _) => 1 ,
458
+ DataType :: Timestamp ( TimeUnit :: Millisecond , _) => MILLISECONDS ,
459
+ DataType :: Timestamp ( TimeUnit :: Microsecond , _) => MICROSECONDS ,
460
+ DataType :: Timestamp ( TimeUnit :: Nanosecond , _) => NANOSECONDS ,
461
+ _ => return Some ( value) ,
462
+ } ;
463
+
464
+ let to_scale = match to {
465
+ DataType :: Timestamp ( TimeUnit :: Second , _) => 1 ,
466
+ DataType :: Timestamp ( TimeUnit :: Millisecond , _) => MILLISECONDS ,
467
+ DataType :: Timestamp ( TimeUnit :: Microsecond , _) => MICROSECONDS ,
468
+ DataType :: Timestamp ( TimeUnit :: Nanosecond , _) => NANOSECONDS ,
469
+ _ => return Some ( value) ,
470
+ } ;
471
+
472
+ match from_scale. cmp ( & to_scale) {
473
+ Ordering :: Less => value. checked_mul ( to_scale / from_scale) ,
474
+ Ordering :: Greater => Some ( value / ( from_scale / to_scale) ) ,
475
+ Ordering :: Equal => Some ( value) ,
476
+ }
477
+ }
478
+
431
479
#[ cfg( test) ]
432
480
mod tests {
433
481
use super :: * ;
@@ -1070,4 +1118,162 @@ mod tests {
1070
1118
}
1071
1119
}
1072
1120
}
1121
+
1122
+ #[ test]
1123
+ fn test_try_cast_literal_to_timestamp ( ) {
1124
+ // same timestamp
1125
+ let new_scalar = try_cast_literal_to_type (
1126
+ & ScalarValue :: TimestampNanosecond ( Some ( 123456 ) , None ) ,
1127
+ & DataType :: Timestamp ( TimeUnit :: Nanosecond , None ) ,
1128
+ )
1129
+ . unwrap ( )
1130
+ . unwrap ( ) ;
1131
+
1132
+ assert_eq ! (
1133
+ new_scalar,
1134
+ ScalarValue :: TimestampNanosecond ( Some ( 123456 ) , None )
1135
+ ) ;
1136
+
1137
+ // TimestampNanosecond to TimestampMicrosecond
1138
+ let new_scalar = try_cast_literal_to_type (
1139
+ & ScalarValue :: TimestampNanosecond ( Some ( 123456 ) , None ) ,
1140
+ & DataType :: Timestamp ( TimeUnit :: Microsecond , None ) ,
1141
+ )
1142
+ . unwrap ( )
1143
+ . unwrap ( ) ;
1144
+
1145
+ assert_eq ! (
1146
+ new_scalar,
1147
+ ScalarValue :: TimestampMicrosecond ( Some ( 123 ) , None )
1148
+ ) ;
1149
+
1150
+ // TimestampNanosecond to TimestampMillisecond
1151
+ let new_scalar = try_cast_literal_to_type (
1152
+ & ScalarValue :: TimestampNanosecond ( Some ( 123456 ) , None ) ,
1153
+ & DataType :: Timestamp ( TimeUnit :: Millisecond , None ) ,
1154
+ )
1155
+ . unwrap ( )
1156
+ . unwrap ( ) ;
1157
+
1158
+ assert_eq ! ( new_scalar, ScalarValue :: TimestampMillisecond ( Some ( 0 ) , None ) ) ;
1159
+
1160
+ // TimestampNanosecond to TimestampSecond
1161
+ let new_scalar = try_cast_literal_to_type (
1162
+ & ScalarValue :: TimestampNanosecond ( Some ( 123456 ) , None ) ,
1163
+ & DataType :: Timestamp ( TimeUnit :: Second , None ) ,
1164
+ )
1165
+ . unwrap ( )
1166
+ . unwrap ( ) ;
1167
+
1168
+ assert_eq ! ( new_scalar, ScalarValue :: TimestampSecond ( Some ( 0 ) , None ) ) ;
1169
+
1170
+ // TimestampMicrosecond to TimestampNanosecond
1171
+ let new_scalar = try_cast_literal_to_type (
1172
+ & ScalarValue :: TimestampMicrosecond ( Some ( 123 ) , None ) ,
1173
+ & DataType :: Timestamp ( TimeUnit :: Nanosecond , None ) ,
1174
+ )
1175
+ . unwrap ( )
1176
+ . unwrap ( ) ;
1177
+
1178
+ assert_eq ! (
1179
+ new_scalar,
1180
+ ScalarValue :: TimestampNanosecond ( Some ( 123000 ) , None )
1181
+ ) ;
1182
+
1183
+ // TimestampMicrosecond to TimestampMillisecond
1184
+ let new_scalar = try_cast_literal_to_type (
1185
+ & ScalarValue :: TimestampMicrosecond ( Some ( 123 ) , None ) ,
1186
+ & DataType :: Timestamp ( TimeUnit :: Millisecond , None ) ,
1187
+ )
1188
+ . unwrap ( )
1189
+ . unwrap ( ) ;
1190
+
1191
+ assert_eq ! ( new_scalar, ScalarValue :: TimestampMillisecond ( Some ( 0 ) , None ) ) ;
1192
+
1193
+ // TimestampMicrosecond to TimestampSecond
1194
+ let new_scalar = try_cast_literal_to_type (
1195
+ & ScalarValue :: TimestampMicrosecond ( Some ( 123456789 ) , None ) ,
1196
+ & DataType :: Timestamp ( TimeUnit :: Second , None ) ,
1197
+ )
1198
+ . unwrap ( )
1199
+ . unwrap ( ) ;
1200
+ assert_eq ! ( new_scalar, ScalarValue :: TimestampSecond ( Some ( 123 ) , None ) ) ;
1201
+
1202
+ // TimestampMillisecond to TimestampNanosecond
1203
+ let new_scalar = try_cast_literal_to_type (
1204
+ & ScalarValue :: TimestampMillisecond ( Some ( 123 ) , None ) ,
1205
+ & DataType :: Timestamp ( TimeUnit :: Nanosecond , None ) ,
1206
+ )
1207
+ . unwrap ( )
1208
+ . unwrap ( ) ;
1209
+ assert_eq ! (
1210
+ new_scalar,
1211
+ ScalarValue :: TimestampNanosecond ( Some ( 123000000 ) , None )
1212
+ ) ;
1213
+
1214
+ // TimestampMillisecond to TimestampMicrosecond
1215
+ let new_scalar = try_cast_literal_to_type (
1216
+ & ScalarValue :: TimestampMillisecond ( Some ( 123 ) , None ) ,
1217
+ & DataType :: Timestamp ( TimeUnit :: Microsecond , None ) ,
1218
+ )
1219
+ . unwrap ( )
1220
+ . unwrap ( ) ;
1221
+ assert_eq ! (
1222
+ new_scalar,
1223
+ ScalarValue :: TimestampMicrosecond ( Some ( 123000 ) , None )
1224
+ ) ;
1225
+ // TimestampMillisecond to TimestampSecond
1226
+ let new_scalar = try_cast_literal_to_type (
1227
+ & ScalarValue :: TimestampMillisecond ( Some ( 123456789 ) , None ) ,
1228
+ & DataType :: Timestamp ( TimeUnit :: Second , None ) ,
1229
+ )
1230
+ . unwrap ( )
1231
+ . unwrap ( ) ;
1232
+ assert_eq ! ( new_scalar, ScalarValue :: TimestampSecond ( Some ( 123456 ) , None ) ) ;
1233
+
1234
+ // TimestampSecond to TimestampNanosecond
1235
+ let new_scalar = try_cast_literal_to_type (
1236
+ & ScalarValue :: TimestampSecond ( Some ( 123 ) , None ) ,
1237
+ & DataType :: Timestamp ( TimeUnit :: Nanosecond , None ) ,
1238
+ )
1239
+ . unwrap ( )
1240
+ . unwrap ( ) ;
1241
+ assert_eq ! (
1242
+ new_scalar,
1243
+ ScalarValue :: TimestampNanosecond ( Some ( 123000000000 ) , None )
1244
+ ) ;
1245
+
1246
+ // TimestampSecond to TimestampMicrosecond
1247
+ let new_scalar = try_cast_literal_to_type (
1248
+ & ScalarValue :: TimestampSecond ( Some ( 123 ) , None ) ,
1249
+ & DataType :: Timestamp ( TimeUnit :: Microsecond , None ) ,
1250
+ )
1251
+ . unwrap ( )
1252
+ . unwrap ( ) ;
1253
+ assert_eq ! (
1254
+ new_scalar,
1255
+ ScalarValue :: TimestampMicrosecond ( Some ( 123000000 ) , None )
1256
+ ) ;
1257
+
1258
+ // TimestampSecond to TimestampMillisecond
1259
+ let new_scalar = try_cast_literal_to_type (
1260
+ & ScalarValue :: TimestampSecond ( Some ( 123 ) , None ) ,
1261
+ & DataType :: Timestamp ( TimeUnit :: Millisecond , None ) ,
1262
+ )
1263
+ . unwrap ( )
1264
+ . unwrap ( ) ;
1265
+ assert_eq ! (
1266
+ new_scalar,
1267
+ ScalarValue :: TimestampMillisecond ( Some ( 123000 ) , None )
1268
+ ) ;
1269
+
1270
+ // overflow
1271
+ let new_scalar = try_cast_literal_to_type (
1272
+ & ScalarValue :: TimestampSecond ( Some ( i64:: MAX ) , None ) ,
1273
+ & DataType :: Timestamp ( TimeUnit :: Millisecond , None ) ,
1274
+ )
1275
+ . unwrap ( )
1276
+ . unwrap ( ) ;
1277
+ assert_eq ! ( new_scalar, ScalarValue :: TimestampMillisecond ( None , None ) ) ;
1278
+ }
1073
1279
}
0 commit comments