Skip to content

Commit 4e38abd

Browse files
authored
unify cast_to function of ScalarValue (#13122)
1 parent d00a089 commit 4e38abd

File tree

2 files changed

+25
-35
lines changed

2 files changed

+25
-35
lines changed

datafusion/common/src/scalar/mod.rs

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ use arrow::{
5858
use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano, ScalarBuffer};
5959
use arrow_schema::{UnionFields, UnionMode};
6060

61+
use crate::format::DEFAULT_CAST_OPTIONS;
6162
use half::f16;
6263
pub use struct_builder::ScalarStructBuilder;
6364

@@ -2809,22 +2810,30 @@ impl ScalarValue {
28092810

28102811
/// Try to parse `value` into a ScalarValue of type `target_type`
28112812
pub fn try_from_string(value: String, target_type: &DataType) -> Result<Self> {
2812-
let value = ScalarValue::from(value);
2813-
let cast_options = CastOptions {
2814-
safe: false,
2815-
format_options: Default::default(),
2816-
};
2817-
let cast_arr = cast_with_options(&value.to_array()?, target_type, &cast_options)?;
2818-
ScalarValue::try_from_array(&cast_arr, 0)
2813+
ScalarValue::from(value).cast_to(target_type)
28192814
}
28202815

28212816
/// Try to cast this value to a ScalarValue of type `data_type`
2822-
pub fn cast_to(&self, data_type: &DataType) -> Result<Self> {
2823-
let cast_options = CastOptions {
2824-
safe: false,
2825-
format_options: Default::default(),
2817+
pub fn cast_to(&self, target_type: &DataType) -> Result<Self> {
2818+
self.cast_to_with_options(target_type, &DEFAULT_CAST_OPTIONS)
2819+
}
2820+
2821+
/// Try to cast this value to a ScalarValue of type `data_type` with [`CastOptions`]
2822+
pub fn cast_to_with_options(
2823+
&self,
2824+
target_type: &DataType,
2825+
cast_options: &CastOptions<'static>,
2826+
) -> Result<Self> {
2827+
let scalar_array = match (self, target_type) {
2828+
(
2829+
ScalarValue::Float64(Some(float_ts)),
2830+
DataType::Timestamp(TimeUnit::Nanosecond, None),
2831+
) => ScalarValue::Int64(Some((float_ts * 1_000_000_000_f64).trunc() as i64))
2832+
.to_array()?,
2833+
_ => self.to_array()?,
28262834
};
2827-
let cast_arr = cast_with_options(&self.to_array()?, data_type, &cast_options)?;
2835+
2836+
let cast_arr = cast_with_options(&scalar_array, target_type, cast_options)?;
28282837
ScalarValue::try_from_array(&cast_arr, 0)
28292838
}
28302839

datafusion/expr-common/src/columnar_value.rs

Lines changed: 4 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
2020
use arrow::array::{Array, ArrayRef, NullArray};
2121
use arrow::compute::{kernels, CastOptions};
22-
use arrow::datatypes::{DataType, TimeUnit};
22+
use arrow::datatypes::DataType;
2323
use datafusion_common::format::DEFAULT_CAST_OPTIONS;
2424
use datafusion_common::{internal_err, Result, ScalarValue};
2525
use std::sync::Arc;
@@ -193,28 +193,9 @@ impl ColumnarValue {
193193
ColumnarValue::Array(array) => Ok(ColumnarValue::Array(
194194
kernels::cast::cast_with_options(array, cast_type, &cast_options)?,
195195
)),
196-
ColumnarValue::Scalar(scalar) => {
197-
let scalar_array =
198-
if cast_type == &DataType::Timestamp(TimeUnit::Nanosecond, None) {
199-
if let ScalarValue::Float64(Some(float_ts)) = scalar {
200-
ScalarValue::Int64(Some(
201-
(float_ts * 1_000_000_000_f64).trunc() as i64,
202-
))
203-
.to_array()?
204-
} else {
205-
scalar.to_array()?
206-
}
207-
} else {
208-
scalar.to_array()?
209-
};
210-
let cast_array = kernels::cast::cast_with_options(
211-
&scalar_array,
212-
cast_type,
213-
&cast_options,
214-
)?;
215-
let cast_scalar = ScalarValue::try_from_array(&cast_array, 0)?;
216-
Ok(ColumnarValue::Scalar(cast_scalar))
217-
}
196+
ColumnarValue::Scalar(scalar) => Ok(ColumnarValue::Scalar(
197+
scalar.cast_to_with_options(cast_type, &cast_options)?,
198+
)),
218199
}
219200
}
220201
}

0 commit comments

Comments
 (0)