Skip to content

Commit 1f8ede5

Browse files
authored
fix: cast literal to timestamp (#5517)
* fix: cast literal to timestamp * update tests for all transformation * handle cast between same type * refactor cast_between_timestamp to avoid overflow * handle overflow to None
1 parent 9464bf2 commit 1f8ede5

File tree

1 file changed

+210
-4
lines changed

1 file changed

+210
-4
lines changed

datafusion/optimizer/src/unwrap_cast_in_comparison.rs

Lines changed: 210 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,15 @@ use crate::{OptimizerConfig, OptimizerRule};
2424
use arrow::datatypes::{
2525
DataType, TimeUnit, MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION,
2626
};
27+
use arrow::temporal_conversions::{MICROSECONDS, MILLISECONDS, NANOSECONDS};
2728
use datafusion_common::{DFSchemaRef, DataFusionError, Result, ScalarValue};
2829
use datafusion_expr::expr::{BinaryExpr, Cast, TryCast};
2930
use datafusion_expr::expr_rewriter::{ExprRewriter, RewriteRecursion};
3031
use datafusion_expr::utils::from_plan;
3132
use datafusion_expr::{
3233
binary_expr, in_list, lit, Expr, ExprSchemable, LogicalPlan, Operator,
3334
};
35+
use std::cmp::Ordering;
3436
use std::sync::Arc;
3537

3638
/// [`UnwrapCastInComparison`] attempts to remove casts from
@@ -400,16 +402,36 @@ fn try_cast_literal_to_type(
400402
DataType::UInt32 => ScalarValue::UInt32(Some(value as u32)),
401403
DataType::UInt64 => ScalarValue::UInt64(Some(value as u64)),
402404
DataType::Timestamp(TimeUnit::Second, tz) => {
403-
ScalarValue::TimestampSecond(Some(value as i64), tz.clone())
405+
let value = cast_between_timestamp(
406+
lit_data_type,
407+
DataType::Timestamp(TimeUnit::Second, tz.clone()),
408+
value,
409+
);
410+
ScalarValue::TimestampSecond(value, tz.clone())
404411
}
405412
DataType::Timestamp(TimeUnit::Millisecond, tz) => {
406-
ScalarValue::TimestampMillisecond(Some(value as i64), tz.clone())
413+
let value = cast_between_timestamp(
414+
lit_data_type,
415+
DataType::Timestamp(TimeUnit::Millisecond, tz.clone()),
416+
value,
417+
);
418+
ScalarValue::TimestampMillisecond(value, tz.clone())
407419
}
408420
DataType::Timestamp(TimeUnit::Microsecond, tz) => {
409-
ScalarValue::TimestampMicrosecond(Some(value as i64), tz.clone())
421+
let value = cast_between_timestamp(
422+
lit_data_type,
423+
DataType::Timestamp(TimeUnit::Microsecond, tz.clone()),
424+
value,
425+
);
426+
ScalarValue::TimestampMicrosecond(value, tz.clone())
410427
}
411428
DataType::Timestamp(TimeUnit::Nanosecond, tz) => {
412-
ScalarValue::TimestampNanosecond(Some(value as i64), tz.clone())
429+
let value = cast_between_timestamp(
430+
lit_data_type,
431+
DataType::Timestamp(TimeUnit::Nanosecond, tz.clone()),
432+
value,
433+
);
434+
ScalarValue::TimestampNanosecond(value, tz.clone())
413435
}
414436
DataType::Decimal128(p, s) => {
415437
ScalarValue::Decimal128(Some(value), *p, *s)
@@ -428,6 +450,32 @@ fn try_cast_literal_to_type(
428450
}
429451
}
430452

453+
/// Cast a timestamp value from one unit to another
454+
fn cast_between_timestamp(from: DataType, to: DataType, value: i128) -> Option<i64> {
455+
let value = value as i64;
456+
let from_scale = match from {
457+
DataType::Timestamp(TimeUnit::Second, _) => 1,
458+
DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
459+
DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
460+
DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
461+
_ => return Some(value),
462+
};
463+
464+
let to_scale = match to {
465+
DataType::Timestamp(TimeUnit::Second, _) => 1,
466+
DataType::Timestamp(TimeUnit::Millisecond, _) => MILLISECONDS,
467+
DataType::Timestamp(TimeUnit::Microsecond, _) => MICROSECONDS,
468+
DataType::Timestamp(TimeUnit::Nanosecond, _) => NANOSECONDS,
469+
_ => return Some(value),
470+
};
471+
472+
match from_scale.cmp(&to_scale) {
473+
Ordering::Less => value.checked_mul(to_scale / from_scale),
474+
Ordering::Greater => Some(value / (from_scale / to_scale)),
475+
Ordering::Equal => Some(value),
476+
}
477+
}
478+
431479
#[cfg(test)]
432480
mod tests {
433481
use super::*;
@@ -1070,4 +1118,162 @@ mod tests {
10701118
}
10711119
}
10721120
}
1121+
1122+
#[test]
1123+
fn test_try_cast_literal_to_timestamp() {
1124+
// same timestamp
1125+
let new_scalar = try_cast_literal_to_type(
1126+
&ScalarValue::TimestampNanosecond(Some(123456), None),
1127+
&DataType::Timestamp(TimeUnit::Nanosecond, None),
1128+
)
1129+
.unwrap()
1130+
.unwrap();
1131+
1132+
assert_eq!(
1133+
new_scalar,
1134+
ScalarValue::TimestampNanosecond(Some(123456), None)
1135+
);
1136+
1137+
// TimestampNanosecond to TimestampMicrosecond
1138+
let new_scalar = try_cast_literal_to_type(
1139+
&ScalarValue::TimestampNanosecond(Some(123456), None),
1140+
&DataType::Timestamp(TimeUnit::Microsecond, None),
1141+
)
1142+
.unwrap()
1143+
.unwrap();
1144+
1145+
assert_eq!(
1146+
new_scalar,
1147+
ScalarValue::TimestampMicrosecond(Some(123), None)
1148+
);
1149+
1150+
// TimestampNanosecond to TimestampMillisecond
1151+
let new_scalar = try_cast_literal_to_type(
1152+
&ScalarValue::TimestampNanosecond(Some(123456), None),
1153+
&DataType::Timestamp(TimeUnit::Millisecond, None),
1154+
)
1155+
.unwrap()
1156+
.unwrap();
1157+
1158+
assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None));
1159+
1160+
// TimestampNanosecond to TimestampSecond
1161+
let new_scalar = try_cast_literal_to_type(
1162+
&ScalarValue::TimestampNanosecond(Some(123456), None),
1163+
&DataType::Timestamp(TimeUnit::Second, None),
1164+
)
1165+
.unwrap()
1166+
.unwrap();
1167+
1168+
assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(0), None));
1169+
1170+
// TimestampMicrosecond to TimestampNanosecond
1171+
let new_scalar = try_cast_literal_to_type(
1172+
&ScalarValue::TimestampMicrosecond(Some(123), None),
1173+
&DataType::Timestamp(TimeUnit::Nanosecond, None),
1174+
)
1175+
.unwrap()
1176+
.unwrap();
1177+
1178+
assert_eq!(
1179+
new_scalar,
1180+
ScalarValue::TimestampNanosecond(Some(123000), None)
1181+
);
1182+
1183+
// TimestampMicrosecond to TimestampMillisecond
1184+
let new_scalar = try_cast_literal_to_type(
1185+
&ScalarValue::TimestampMicrosecond(Some(123), None),
1186+
&DataType::Timestamp(TimeUnit::Millisecond, None),
1187+
)
1188+
.unwrap()
1189+
.unwrap();
1190+
1191+
assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(Some(0), None));
1192+
1193+
// TimestampMicrosecond to TimestampSecond
1194+
let new_scalar = try_cast_literal_to_type(
1195+
&ScalarValue::TimestampMicrosecond(Some(123456789), None),
1196+
&DataType::Timestamp(TimeUnit::Second, None),
1197+
)
1198+
.unwrap()
1199+
.unwrap();
1200+
assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123), None));
1201+
1202+
// TimestampMillisecond to TimestampNanosecond
1203+
let new_scalar = try_cast_literal_to_type(
1204+
&ScalarValue::TimestampMillisecond(Some(123), None),
1205+
&DataType::Timestamp(TimeUnit::Nanosecond, None),
1206+
)
1207+
.unwrap()
1208+
.unwrap();
1209+
assert_eq!(
1210+
new_scalar,
1211+
ScalarValue::TimestampNanosecond(Some(123000000), None)
1212+
);
1213+
1214+
// TimestampMillisecond to TimestampMicrosecond
1215+
let new_scalar = try_cast_literal_to_type(
1216+
&ScalarValue::TimestampMillisecond(Some(123), None),
1217+
&DataType::Timestamp(TimeUnit::Microsecond, None),
1218+
)
1219+
.unwrap()
1220+
.unwrap();
1221+
assert_eq!(
1222+
new_scalar,
1223+
ScalarValue::TimestampMicrosecond(Some(123000), None)
1224+
);
1225+
// TimestampMillisecond to TimestampSecond
1226+
let new_scalar = try_cast_literal_to_type(
1227+
&ScalarValue::TimestampMillisecond(Some(123456789), None),
1228+
&DataType::Timestamp(TimeUnit::Second, None),
1229+
)
1230+
.unwrap()
1231+
.unwrap();
1232+
assert_eq!(new_scalar, ScalarValue::TimestampSecond(Some(123456), None));
1233+
1234+
// TimestampSecond to TimestampNanosecond
1235+
let new_scalar = try_cast_literal_to_type(
1236+
&ScalarValue::TimestampSecond(Some(123), None),
1237+
&DataType::Timestamp(TimeUnit::Nanosecond, None),
1238+
)
1239+
.unwrap()
1240+
.unwrap();
1241+
assert_eq!(
1242+
new_scalar,
1243+
ScalarValue::TimestampNanosecond(Some(123000000000), None)
1244+
);
1245+
1246+
// TimestampSecond to TimestampMicrosecond
1247+
let new_scalar = try_cast_literal_to_type(
1248+
&ScalarValue::TimestampSecond(Some(123), None),
1249+
&DataType::Timestamp(TimeUnit::Microsecond, None),
1250+
)
1251+
.unwrap()
1252+
.unwrap();
1253+
assert_eq!(
1254+
new_scalar,
1255+
ScalarValue::TimestampMicrosecond(Some(123000000), None)
1256+
);
1257+
1258+
// TimestampSecond to TimestampMillisecond
1259+
let new_scalar = try_cast_literal_to_type(
1260+
&ScalarValue::TimestampSecond(Some(123), None),
1261+
&DataType::Timestamp(TimeUnit::Millisecond, None),
1262+
)
1263+
.unwrap()
1264+
.unwrap();
1265+
assert_eq!(
1266+
new_scalar,
1267+
ScalarValue::TimestampMillisecond(Some(123000), None)
1268+
);
1269+
1270+
// overflow
1271+
let new_scalar = try_cast_literal_to_type(
1272+
&ScalarValue::TimestampSecond(Some(i64::MAX), None),
1273+
&DataType::Timestamp(TimeUnit::Millisecond, None),
1274+
)
1275+
.unwrap()
1276+
.unwrap();
1277+
assert_eq!(new_scalar, ScalarValue::TimestampMillisecond(None, None));
1278+
}
10731279
}

0 commit comments

Comments
 (0)