From 33f3e9c13550d6d390f6128b8d94836938f652b3 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 10 Jul 2024 13:32:40 -0600 Subject: [PATCH 01/68] feat: Create new `datafusion-comet-spark-expr` crate containing Spark-compatible DataFusion expressions (#638) * convert into workspace project * update GitHub actions * update Makefile * fix regression * update target path * update protobuf path in pom.xml * update more paths * add new datafusion-comet-expr crate * rename CometAbsFunc to Abs and add documentation * fix error message * improve error handling * update crate description * remove unused dep * address feedback * finish renaming crate * update README for datafusion-spark-expr * rename crate to datafusion-comet-spark-expr --- Cargo.toml | 38 +++++++++++++++++++++++ README.md | 23 ++++++++++++++ src/abs.rs | 88 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 56 ++++++++++++++++++++++++++++++++++ 4 files changed, 205 insertions(+) create mode 100644 Cargo.toml create mode 100644 README.md create mode 100644 src/abs.rs create mode 100644 src/lib.rs diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 000000000000..d10d04944b76 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,38 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-comet-spark-expr" +description = "DataFusion expressions that emulate Apache Spark's behavior" +version = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +readme = { workspace = true } +license = { workspace = true } +edition = { workspace = true } + +[dependencies] +arrow = { workspace = true } +arrow-schema = { workspace = true } +datafusion = { workspace = true } +datafusion-common = { workspace = true } +datafusion-functions = { workspace = true } + +[lib] +name = "datafusion_comet_spark_expr" +path = "src/lib.rs" diff --git a/README.md b/README.md new file mode 100644 index 000000000000..a7ee7536328e --- /dev/null +++ b/README.md @@ -0,0 +1,23 @@ + + +# datafusion-comet-spark-expr: Spark-compatible Expressions + +This crate provides Apache Spark-compatible expressions for use with DataFusion and is maintained as part of the +[Apache DataFusion Comet](https://github.com/apache/datafusion-comet/) subproject. \ No newline at end of file diff --git a/src/abs.rs b/src/abs.rs new file mode 100644 index 000000000000..198a96e571f3 --- /dev/null +++ b/src/abs.rs @@ -0,0 +1,88 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Spark-compatible implementation of abs function
+
+use std::{any::Any, sync::Arc};
+
+use arrow::datatypes::DataType;
+use arrow_schema::ArrowError;
+
+use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature};
+use datafusion_common::DataFusionError;
+use datafusion_functions::math;
+
+use super::{EvalMode, SparkError};
+
+/// Spark-compatible ABS expression
+#[derive(Debug)]
+pub struct Abs {
+    inner_abs_func: Arc<dyn ScalarUDFImpl>,
+    eval_mode: EvalMode,
+    data_type_name: String,
+}
+
+impl Abs {
+    pub fn new(eval_mode: EvalMode, data_type_name: String) -> Result<Self, DataFusionError> {
+        if let EvalMode::Legacy | EvalMode::Ansi = eval_mode {
+            Ok(Self {
+                inner_abs_func: math::abs().inner().clone(),
+                eval_mode,
+                data_type_name,
+            })
+        } else {
+            Err(DataFusionError::Execution(format!(
+                "Invalid EvalMode: \"{:?}\"",
+                eval_mode
+            )))
+        }
+    }
+}
+
+impl ScalarUDFImpl for Abs {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+    fn name(&self) -> &str {
+        "abs"
+    }
+
+    fn signature(&self) -> &Signature {
+        self.inner_abs_func.signature()
+    }
+
+    fn return_type(&self, arg_types: &[DataType]) -> Result<DataType, DataFusionError> {
+        self.inner_abs_func.return_type(arg_types)
+    }
+
+    fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
+        match self.inner_abs_func.invoke(args) {
+            Err(DataFusionError::ArrowError(ArrowError::ComputeError(msg), _))
+                if msg.contains("overflow") =>
+            {
+                if self.eval_mode == EvalMode::Legacy {
+                    Ok(args[0].clone())
+                } else {
+                    Err(DataFusionError::External(Box::new(
+                        SparkError::ArithmeticOverflow(self.data_type_name.clone()),
+                    )))
+                }
+            }
+            other => other,
+        }
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 000000000000..3873754be5b0
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,56 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::error::Error;
+use std::fmt::{Display, Formatter};
+
+pub mod abs;
+
+/// Spark supports three evaluation modes when evaluating expressions, which affect
+/// the behavior when processing input values that are invalid or would result in an
+/// error, such as divide by zero errors, and also affect behavior when converting
+/// between types.
+#[derive(Debug, Hash, PartialEq, Clone, Copy)] +pub enum EvalMode { + /// Legacy is the default behavior in Spark prior to Spark 4.0. This mode silently ignores + /// or replaces errors during SQL operations. Operations resulting in errors (like + /// division by zero) will produce NULL values instead of failing. Legacy mode also + /// enables implicit type conversions. + Legacy, + /// Adheres to the ANSI SQL standard for error handling by throwing exceptions for + /// operations that result in errors. Does not perform implicit type conversions. + Ansi, + /// Same as Ansi mode, except that it converts errors to NULL values without + /// failing the entire query. + Try, +} + +#[derive(Debug)] +pub enum SparkError { + ArithmeticOverflow(String), +} + +impl Error for SparkError {} + +impl Display for SparkError { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + match self { + Self::ArithmeticOverflow(data_type) => + write!(f, "[ARITHMETIC_OVERFLOW] {} overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.", data_type) + } + } +} From 96a2f41dd4b6eca1cb0d6bab977fde596c9e22c9 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Thu, 11 Jul 2024 05:39:58 -0600 Subject: [PATCH 02/68] feat: Move `IfExpr` to `spark-expr` crate (#653) --- Cargo.toml | 2 + src/if_expr.rs | 231 +++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 6 +- 3 files changed, 238 insertions(+), 1 deletion(-) create mode 100644 src/if_expr.rs diff --git a/Cargo.toml b/Cargo.toml index d10d04944b76..8bf76dff6e25 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,6 +32,8 @@ arrow-schema = { workspace = true } datafusion = { workspace = true } datafusion-common = { workspace = true } datafusion-functions = { workspace = true } +datafusion-physical-expr = { workspace = true } +datafusion-comet-utils = { workspace = true } [lib] name = "datafusion_comet_spark_expr" diff --git a/src/if_expr.rs b/src/if_expr.rs new file mode 100644 index 000000000000..c04494ec4ffb --- /dev/null +++ b/src/if_expr.rs @@ -0,0 +1,231 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+use std::{
+    any::Any,
+    hash::{Hash, Hasher},
+    sync::Arc,
+};
+
+use arrow::{
+    array::*,
+    compute::{and, is_null, kernels::zip::zip, not, or_kleene},
+    datatypes::{DataType, Schema},
+    record_batch::RecordBatch,
+};
+use datafusion::logical_expr::ColumnarValue;
+use datafusion_common::{cast::as_boolean_array, Result};
+use datafusion_physical_expr::PhysicalExpr;
+
+use datafusion_comet_utils::down_cast_any_ref;
+
+#[derive(Debug, Hash)]
+pub struct IfExpr {
+    if_expr: Arc<dyn PhysicalExpr>,
+    true_expr: Arc<dyn PhysicalExpr>,
+    false_expr: Arc<dyn PhysicalExpr>,
+}
+
+impl std::fmt::Display for IfExpr {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(
+            f,
+            "If [if: {}, true_expr: {}, false_expr: {}]",
+            self.if_expr, self.true_expr, self.false_expr
+        )
+    }
+}
+
+impl IfExpr {
+    /// Create a new IF expression
+    pub fn new(
+        if_expr: Arc<dyn PhysicalExpr>,
+        true_expr: Arc<dyn PhysicalExpr>,
+        false_expr: Arc<dyn PhysicalExpr>,
+    ) -> Self {
+        Self {
+            if_expr,
+            true_expr,
+            false_expr,
+        }
+    }
+}
+
+impl PhysicalExpr for IfExpr {
+    /// Return a reference to Any that can be used for down-casting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn data_type(&self, input_schema: &Schema) -> Result<DataType> {
+        let data_type = self.true_expr.data_type(input_schema)?;
+        Ok(data_type)
+    }
+
+    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+        if self.true_expr.nullable(_input_schema)? || self.false_expr.nullable(_input_schema)? {
+            Ok(true)
+        } else {
+            Ok(false)
+        }
+    }
+
+    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
+        let mut remainder = BooleanArray::from(vec![true; batch.num_rows()]);
+
+        // evaluate if condition on batch
+        let if_value = self.if_expr.evaluate_selection(batch, &remainder)?;
+        let if_value = if_value.into_array(batch.num_rows())?;
+        let if_value =
+            as_boolean_array(&if_value).expect("if expression did not return a BooleanArray");
+
+        let true_value = self.true_expr.evaluate_selection(batch, if_value)?;
+        let true_value = true_value.into_array(batch.num_rows())?;
+
+        remainder = and(
+            &remainder,
+            &or_kleene(&not(if_value)?, &is_null(if_value)?)?,
+        )?;
+
+        let false_value = self
+            .false_expr
+            .evaluate_selection(batch, &remainder)?
+            .into_array(batch.num_rows())?;
+        let current_value = zip(&remainder, &false_value, &true_value)?;
+
+        Ok(ColumnarValue::Array(current_value))
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        vec![&self.if_expr, &self.true_expr, &self.false_expr]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        Ok(Arc::new(IfExpr::new(
+            children[0].clone(),
+            children[1].clone(),
+            children[2].clone(),
+        )))
+    }
+
+    fn dyn_hash(&self, state: &mut dyn Hasher) {
+        let mut s = state;
+        self.if_expr.hash(&mut s);
+        self.true_expr.hash(&mut s);
+        self.false_expr.hash(&mut s);
+        self.hash(&mut s);
+    }
+}
+
+impl PartialEq<dyn Any> for IfExpr {
+    fn eq(&self, other: &dyn Any) -> bool {
+        down_cast_any_ref(other)
+            .downcast_ref::<Self>()
+            .map(|x| {
+                self.if_expr.eq(&x.if_expr)
+                    && self.true_expr.eq(&x.true_expr)
+                    && self.false_expr.eq(&x.false_expr)
+            })
+            .unwrap_or(false)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use arrow::{array::StringArray, datatypes::*};
+    use datafusion::logical_expr::Operator;
+    use datafusion_common::cast::as_int32_array;
+    use datafusion_physical_expr::expressions::{binary, col, lit};
+
+    use super::*;
+
+    /// Create an If expression
+    fn if_fn(
+        if_expr: Arc<dyn PhysicalExpr>,
+        true_expr: Arc<dyn PhysicalExpr>,
+        false_expr: Arc<dyn PhysicalExpr>,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        Ok(Arc::new(IfExpr::new(if_expr, true_expr, false_expr)))
+    }
+
+    #[test]
+    fn test_if_1() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
+        let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]);
+        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
+        let schema_ref = batch.schema();
+
+        // if a = 'foo' 123 else 999
+        let if_expr = binary(
+            col("a", &schema_ref)?,
+            Operator::Eq,
+            lit("foo"),
+            &schema_ref,
+        )?;
+        let true_expr = lit(123i32);
+        let false_expr = lit(999i32);
+
+        let expr = if_fn(if_expr, true_expr, false_expr);
+        let result = expr?.evaluate(&batch)?.into_array(batch.num_rows())?;
+        let result = as_int32_array(&result)?;
+
+        let expected = &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(999)]);
+
+        assert_eq!(expected, result);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_if_2() -> Result<()> {
+        let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]);
+        let a = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]);
+        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?;
+        let schema_ref = batch.schema();
+
+        // if a >= 1 123 else 999
+        let if_expr = binary(col("a", &schema_ref)?, Operator::GtEq, lit(1), &schema_ref)?;
+        let true_expr = lit(123i32);
+        let false_expr = lit(999i32);
+
+        let expr = if_fn(if_expr, true_expr, false_expr);
+        let result = expr?.evaluate(&batch)?.into_array(batch.num_rows())?;
+        let result = as_int32_array(&result)?;
+
+        let expected = &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(123)]);
+        assert_eq!(expected, result);
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_if_children() {
+        let if_expr = lit(true);
+        let true_expr = lit(123i32);
+        let false_expr = lit(999i32);
+
+        let expr = if_fn(if_expr, true_expr, false_expr).unwrap();
+        let children = expr.children();
+        assert_eq!(children.len(), 3);
+        assert_eq!(children[0].to_string(), "true");
+        assert_eq!(children[1].to_string(), "123");
+        assert_eq!(children[2].to_string(), "999");
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
index 3873754be5b0..c36e8855edf8 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -18,7 +18,11 @@
 use std::error::Error;
 use std::fmt::{Display, Formatter};
 
-pub mod abs;
+mod abs;
+mod if_expr;
+
+pub use abs::Abs;
+pub
use if_expr::IfExpr; /// Spark supports three evaluation modes when evaluating expressions, which affect /// the behavior when processing input values that are invalid or would result in an From 2f22a4dff765bf6e5e77ebed4c2cdc7baf02276a Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 12 Jul 2024 05:52:16 -0600 Subject: [PATCH 03/68] chore: Refactoring of CometError/SparkError (#655) --- Cargo.toml | 1 + src/abs.rs | 7 ++--- src/error.rs | 73 ++++++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 21 ++------------- 4 files changed, 80 insertions(+), 22 deletions(-) create mode 100644 src/error.rs diff --git a/Cargo.toml b/Cargo.toml index 8bf76dff6e25..4a9b94087321 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ datafusion-common = { workspace = true } datafusion-functions = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-comet-utils = { workspace = true } +thiserror = { workspace = true } [lib] name = "datafusion_comet_spark_expr" diff --git a/src/abs.rs b/src/abs.rs index 198a96e571f3..fa25a7775ae7 100644 --- a/src/abs.rs +++ b/src/abs.rs @@ -77,9 +77,10 @@ impl ScalarUDFImpl for Abs { if self.eval_mode == EvalMode::Legacy { Ok(args[0].clone()) } else { - Err(DataFusionError::External(Box::new( - SparkError::ArithmeticOverflow(self.data_type_name.clone()), - ))) + Err(SparkError::ArithmeticOverflow { + from_type: self.data_type_name.clone(), + } + .into()) } } other => other, diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 000000000000..728a35a9d2e0 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_schema::ArrowError; +use datafusion_common::DataFusionError; + +#[derive(thiserror::Error, Debug)] +pub enum SparkError { + // Note that this message format is based on Spark 3.4 and is more detailed than the message + // returned by Spark 3.3 + #[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ + because it is malformed. Correct the value as per the syntax, or change its target type. \ + Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \ + set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + CastInvalidValue { + value: String, + from_type: String, + to_type: String, + }, + + #[error("[NUMERIC_VALUE_OUT_OF_RANGE] {value} cannot be represented as Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")] + NumericValueOutOfRange { + value: String, + precision: u8, + scale: i8, + }, + + #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ + due to an overflow. 
Use `try_cast` to tolerate overflow and return NULL instead. If necessary \ + set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + CastOverFlow { + value: String, + from_type: String, + to_type: String, + }, + + #[error("[ARITHMETIC_OVERFLOW] {from_type} overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + ArithmeticOverflow { from_type: String }, + + #[error("ArrowError: {0}.")] + Arrow(ArrowError), + + #[error("InternalError: {0}.")] + Internal(String), +} + +pub type SparkResult = Result; + +impl From for SparkError { + fn from(value: ArrowError) -> Self { + SparkError::Arrow(value) + } +} + +impl From for DataFusionError { + fn from(value: SparkError) -> Self { + DataFusionError::External(Box::new(value)) + } +} diff --git a/src/lib.rs b/src/lib.rs index c36e8855edf8..57da56f9aca6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,13 +15,12 @@ // specific language governing permissions and limitations // under the License. -use std::error::Error; -use std::fmt::{Display, Formatter}; - mod abs; +mod error; mod if_expr; pub use abs::Abs; +pub use error::{SparkError, SparkResult}; pub use if_expr::IfExpr; /// Spark supports three evaluation modes when evaluating expressions, which affect @@ -42,19 +41,3 @@ pub enum EvalMode { /// failing the entire query. Try, } - -#[derive(Debug)] -pub enum SparkError { - ArithmeticOverflow(String), -} - -impl Error for SparkError {} - -impl Display for SparkError { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - match self { - Self::ArithmeticOverflow(data_type) => - write!(f, "[ARITHMETIC_OVERFLOW] {} overflow. If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.", data_type) - } - } -} From 11138bb47ee9f15523c8da457d5966bb468653fc Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 12 Jul 2024 13:21:50 -0600 Subject: [PATCH 04/68] chore: Move `cast` to `spark-expr` crate (#654) * refactor in preparation for moving cast to spark-expr crate * errors * move cast to spark-expr crate * machete * refactor errors * clean up imports --- Cargo.toml | 5 + src/cast.rs | 2016 +++++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 1 + 3 files changed, 2022 insertions(+) create mode 100644 src/cast.rs diff --git a/Cargo.toml b/Cargo.toml index 4a9b94087321..220417fe8b05 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,12 +28,17 @@ edition = { workspace = true } [dependencies] arrow = { workspace = true } +arrow-array = { workspace = true } arrow-schema = { workspace = true } +chrono = { workspace = true } datafusion = { workspace = true } datafusion-common = { workspace = true } datafusion-functions = { workspace = true } +datafusion-expr = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-comet-utils = { workspace = true } +num = { workspace = true } +regex = { workspace = true } thiserror = { workspace = true } [lib] diff --git a/src/cast.rs b/src/cast.rs new file mode 100644 index 000000000000..b9cf2790b5d2 --- /dev/null +++ b/src/cast.rs @@ -0,0 +1,2016 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ + any::Any, + fmt::{Debug, Display, Formatter}, + hash::{Hash, Hasher}, + num::Wrapping, + sync::Arc, +}; + +use arrow::{ + array::{ + cast::AsArray, + types::{Date32Type, Int16Type, Int32Type, Int8Type}, + Array, ArrayRef, BooleanArray, Decimal128Array, Float32Array, Float64Array, + GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array, OffsetSizeTrait, + PrimitiveArray, + }, + compute::{cast_with_options, unary, CastOptions}, + datatypes::{ + ArrowPrimitiveType, Decimal128Type, DecimalType, Float32Type, Float64Type, Int64Type, + TimestampMicrosecondType, + }, + error::ArrowError, + record_batch::RecordBatch, + util::display::FormatOptions, +}; +use arrow_schema::{DataType, Schema}; + +use datafusion_common::{ + cast::as_generic_string_array, internal_err, Result as DataFusionResult, ScalarValue, +}; +use datafusion_expr::ColumnarValue; +use datafusion_physical_expr::PhysicalExpr; + +use chrono::{NaiveDate, NaiveDateTime, TimeZone, Timelike}; +use num::{ + cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num, + ToPrimitive, +}; +use regex::Regex; + +use datafusion_comet_utils::{array_with_timezone, down_cast_any_ref}; + +use crate::{EvalMode, SparkError, SparkResult}; + +static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f"); + +const MICROS_PER_SECOND: i64 = 1000000; + +static CAST_OPTIONS: CastOptions = CastOptions { + safe: true, + format_options: FormatOptions::new() + .with_timestamp_tz_format(TIMESTAMP_FORMAT) + .with_timestamp_format(TIMESTAMP_FORMAT), +}; + +#[derive(Debug, Hash)] +pub struct Cast { + pub child: Arc, + pub data_type: DataType, + pub eval_mode: EvalMode, + + /// When cast from/to timezone related types, we need timezone, which will be resolved with + /// session local timezone by an analyzer in Spark. + pub timezone: String, +} + +macro_rules! cast_utf8_to_int { + ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{ + let len = $array.len(); + let mut cast_array = PrimitiveArray::<$array_type>::builder(len); + for i in 0..len { + if $array.is_null(i) { + cast_array.append_null() + } else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? { + cast_array.append_value(cast_value); + } else { + cast_array.append_null() + } + } + let result: SparkResult = Ok(Arc::new(cast_array.finish()) as ArrayRef); + result + }}; +} + +macro_rules! cast_utf8_to_timestamp { + ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{ + let len = $array.len(); + let mut cast_array = PrimitiveArray::<$array_type>::builder(len).with_timezone("UTC"); + for i in 0..len { + if $array.is_null(i) { + cast_array.append_null() + } else if let Ok(Some(cast_value)) = $cast_method($array.value(i).trim(), $eval_mode) { + cast_array.append_value(cast_value); + } else { + cast_array.append_null() + } + } + let result: ArrayRef = Arc::new(cast_array.finish()) as ArrayRef; + result + }}; +} + +macro_rules! 
cast_float_to_string { + ($from:expr, $eval_mode:expr, $type:ty, $output_type:ty, $offset_type:ty) => {{ + + fn cast( + from: &dyn Array, + _eval_mode: EvalMode, + ) -> SparkResult + where + OffsetSize: OffsetSizeTrait, { + let array = from.as_any().downcast_ref::<$output_type>().unwrap(); + + // If the absolute number is less than 10,000,000 and greater or equal than 0.001, the + // result is expressed without scientific notation with at least one digit on either side of + // the decimal point. Otherwise, Spark uses a mantissa followed by E and an + // exponent. The mantissa has an optional leading minus sign followed by one digit to the + // left of the decimal point, and the minimal number of digits greater than zero to the + // right. The exponent has and optional leading minus sign. + // source: https://docs.databricks.com/en/sql/language-manual/functions/cast.html + + const LOWER_SCIENTIFIC_BOUND: $type = 0.001; + const UPPER_SCIENTIFIC_BOUND: $type = 10000000.0; + + let output_array = array + .iter() + .map(|value| match value { + Some(value) if value == <$type>::INFINITY => Ok(Some("Infinity".to_string())), + Some(value) if value == <$type>::NEG_INFINITY => Ok(Some("-Infinity".to_string())), + Some(value) + if (value.abs() < UPPER_SCIENTIFIC_BOUND + && value.abs() >= LOWER_SCIENTIFIC_BOUND) + || value.abs() == 0.0 => + { + let trailing_zero = if value.fract() == 0.0 { ".0" } else { "" }; + + Ok(Some(format!("{value}{trailing_zero}"))) + } + Some(value) + if value.abs() >= UPPER_SCIENTIFIC_BOUND + || value.abs() < LOWER_SCIENTIFIC_BOUND => + { + let formatted = format!("{value:E}"); + + if formatted.contains(".") { + Ok(Some(formatted)) + } else { + // `formatted` is already in scientific notation and can be split up by E + // in order to add the missing trailing 0 which gets removed for numbers with a fraction of 0.0 + let prepare_number: Vec<&str> = formatted.split("E").collect(); + + let coefficient = prepare_number[0]; + + let exponent = prepare_number[1]; + + Ok(Some(format!("{coefficient}.0E{exponent}"))) + } + } + Some(value) => Ok(Some(value.to_string())), + _ => Ok(None), + }) + .collect::, SparkError>>()?; + + Ok(Arc::new(output_array)) + } + + cast::<$offset_type>($from, $eval_mode) + }}; +} + +macro_rules! cast_int_to_int_macro { + ( + $array: expr, + $eval_mode:expr, + $from_arrow_primitive_type: ty, + $to_arrow_primitive_type: ty, + $from_data_type: expr, + $to_native_type: ty, + $spark_from_data_type_name: expr, + $spark_to_data_type_name: expr + ) => {{ + let cast_array = $array + .as_any() + .downcast_ref::>() + .unwrap(); + let spark_int_literal_suffix = match $from_data_type { + &DataType::Int64 => "L", + &DataType::Int16 => "S", + &DataType::Int8 => "T", + _ => "", + }; + + let output_array = match $eval_mode { + EvalMode::Legacy => cast_array + .iter() + .map(|value| match value { + Some(value) => { + Ok::, SparkError>(Some(value as $to_native_type)) + } + _ => Ok(None), + }) + .collect::, _>>(), + _ => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let res = <$to_native_type>::try_from(value); + if res.is_err() { + Err(cast_overflow( + &(value.to_string() + spark_int_literal_suffix), + $spark_from_data_type_name, + $spark_to_data_type_name, + )) + } else { + Ok::, SparkError>(Some(res.unwrap())) + } + } + _ => Ok(None), + }) + .collect::, _>>(), + }?; + let result: SparkResult = Ok(Arc::new(output_array) as ArrayRef); + result + }}; +} + +// When Spark casts to Byte/Short Types, it does not cast directly to Byte/Short. 
+// It casts to Int first and then to Byte/Short. Because of potential overflows in the Int cast, +// this can cause unexpected Short/Byte cast results. Replicate this behavior. +macro_rules! cast_float_to_int16_down { + ( + $array:expr, + $eval_mode:expr, + $src_array_type:ty, + $dest_array_type:ty, + $rust_src_type:ty, + $rust_dest_type:ty, + $src_type_str:expr, + $dest_type_str:expr, + $format_str:expr + ) => {{ + let cast_array = $array + .as_any() + .downcast_ref::<$src_array_type>() + .expect(concat!("Expected a ", stringify!($src_array_type))); + + let output_array = match $eval_mode { + EvalMode::Ansi => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let is_overflow = value.is_nan() || value.abs() as i32 == i32::MAX; + if is_overflow { + return Err(cast_overflow( + &format!($format_str, value).replace("e", "E"), + $src_type_str, + $dest_type_str, + )); + } + let i32_value = value as i32; + <$rust_dest_type>::try_from(i32_value) + .map_err(|_| { + cast_overflow( + &format!($format_str, value).replace("e", "E"), + $src_type_str, + $dest_type_str, + ) + }) + .map(Some) + } + None => Ok(None), + }) + .collect::>()?, + _ => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let i32_value = value as i32; + Ok::, SparkError>(Some( + i32_value as $rust_dest_type, + )) + } + None => Ok(None), + }) + .collect::>()?, + }; + Ok(Arc::new(output_array) as ArrayRef) + }}; +} + +macro_rules! cast_float_to_int32_up { + ( + $array:expr, + $eval_mode:expr, + $src_array_type:ty, + $dest_array_type:ty, + $rust_src_type:ty, + $rust_dest_type:ty, + $src_type_str:expr, + $dest_type_str:expr, + $max_dest_val:expr, + $format_str:expr + ) => {{ + let cast_array = $array + .as_any() + .downcast_ref::<$src_array_type>() + .expect(concat!("Expected a ", stringify!($src_array_type))); + + let output_array = match $eval_mode { + EvalMode::Ansi => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let is_overflow = + value.is_nan() || value.abs() as $rust_dest_type == $max_dest_val; + if is_overflow { + return Err(cast_overflow( + &format!($format_str, value).replace("e", "E"), + $src_type_str, + $dest_type_str, + )); + } + Ok(Some(value as $rust_dest_type)) + } + None => Ok(None), + }) + .collect::>()?, + _ => cast_array + .iter() + .map(|value| match value { + Some(value) => { + Ok::, SparkError>(Some(value as $rust_dest_type)) + } + None => Ok(None), + }) + .collect::>()?, + }; + Ok(Arc::new(output_array) as ArrayRef) + }}; +} + +// When Spark casts to Byte/Short Types, it does not cast directly to Byte/Short. +// It casts to Int first and then to Byte/Short. Because of potential overflows in the Int cast, +// this can cause unexpected Short/Byte cast results. Replicate this behavior. +macro_rules! 
cast_decimal_to_int16_down { + ( + $array:expr, + $eval_mode:expr, + $dest_array_type:ty, + $rust_dest_type:ty, + $dest_type_str:expr, + $precision:expr, + $scale:expr + ) => {{ + let cast_array = $array + .as_any() + .downcast_ref::() + .expect(concat!("Expected a Decimal128ArrayType")); + + let output_array = match $eval_mode { + EvalMode::Ansi => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let divisor = 10_i128.pow($scale as u32); + let (truncated, decimal) = (value / divisor, (value % divisor).abs()); + let is_overflow = truncated.abs() > i32::MAX.into(); + if is_overflow { + return Err(cast_overflow( + &format!("{}.{}BD", truncated, decimal), + &format!("DECIMAL({},{})", $precision, $scale), + $dest_type_str, + )); + } + let i32_value = truncated as i32; + <$rust_dest_type>::try_from(i32_value) + .map_err(|_| { + cast_overflow( + &format!("{}.{}BD", truncated, decimal), + &format!("DECIMAL({},{})", $precision, $scale), + $dest_type_str, + ) + }) + .map(Some) + } + None => Ok(None), + }) + .collect::>()?, + _ => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let divisor = 10_i128.pow($scale as u32); + let i32_value = (value / divisor) as i32; + Ok::, SparkError>(Some( + i32_value as $rust_dest_type, + )) + } + None => Ok(None), + }) + .collect::>()?, + }; + Ok(Arc::new(output_array) as ArrayRef) + }}; +} + +macro_rules! cast_decimal_to_int32_up { + ( + $array:expr, + $eval_mode:expr, + $dest_array_type:ty, + $rust_dest_type:ty, + $dest_type_str:expr, + $max_dest_val:expr, + $precision:expr, + $scale:expr + ) => {{ + let cast_array = $array + .as_any() + .downcast_ref::() + .expect(concat!("Expected a Decimal128ArrayType")); + + let output_array = match $eval_mode { + EvalMode::Ansi => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let divisor = 10_i128.pow($scale as u32); + let (truncated, decimal) = (value / divisor, (value % divisor).abs()); + let is_overflow = truncated.abs() > $max_dest_val.into(); + if is_overflow { + return Err(cast_overflow( + &format!("{}.{}BD", truncated, decimal), + &format!("DECIMAL({},{})", $precision, $scale), + $dest_type_str, + )); + } + Ok(Some(truncated as $rust_dest_type)) + } + None => Ok(None), + }) + .collect::>()?, + _ => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let divisor = 10_i128.pow($scale as u32); + let truncated = value / divisor; + Ok::, SparkError>(Some( + truncated as $rust_dest_type, + )) + } + None => Ok(None), + }) + .collect::>()?, + }; + Ok(Arc::new(output_array) as ArrayRef) + }}; +} + +impl Cast { + pub fn new( + child: Arc, + data_type: DataType, + eval_mode: EvalMode, + timezone: String, + ) -> Self { + Self { + child, + data_type, + timezone, + eval_mode, + } + } + + pub fn new_without_timezone( + child: Arc, + data_type: DataType, + eval_mode: EvalMode, + ) -> Self { + Self { + child, + data_type, + timezone: "".to_string(), + eval_mode, + } + } + + fn cast_array(&self, array: ArrayRef) -> DataFusionResult { + let to_type = &self.data_type; + let array = array_with_timezone(array, self.timezone.clone(), Some(to_type))?; + let from_type = array.data_type().clone(); + + // unpack dictionary string arrays first + // TODO: we are unpacking a dictionary-encoded array and then performing + // the cast. 
We could potentially improve performance here by casting the + // dictionary values directly without unpacking the array first, although this + // would add more complexity to the code + let array = match &from_type { + DataType::Dictionary(key_type, value_type) + if key_type.as_ref() == &DataType::Int32 + && (value_type.as_ref() == &DataType::Utf8 + || value_type.as_ref() == &DataType::LargeUtf8) => + { + cast_with_options(&array, value_type.as_ref(), &CAST_OPTIONS)? + } + _ => array, + }; + let from_type = array.data_type(); + + let cast_result = match (from_type, to_type) { + (DataType::Utf8, DataType::Boolean) => { + Self::spark_cast_utf8_to_boolean::(&array, self.eval_mode) + } + (DataType::LargeUtf8, DataType::Boolean) => { + Self::spark_cast_utf8_to_boolean::(&array, self.eval_mode) + } + (DataType::Utf8, DataType::Timestamp(_, _)) => { + Self::cast_string_to_timestamp(&array, to_type, self.eval_mode) + } + (DataType::Utf8, DataType::Date32) => { + Self::cast_string_to_date(&array, to_type, self.eval_mode) + } + (DataType::Int64, DataType::Int32) + | (DataType::Int64, DataType::Int16) + | (DataType::Int64, DataType::Int8) + | (DataType::Int32, DataType::Int16) + | (DataType::Int32, DataType::Int8) + | (DataType::Int16, DataType::Int8) + if self.eval_mode != EvalMode::Try => + { + Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, to_type) + } + ( + DataType::Utf8, + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, + ) => Self::cast_string_to_int::(to_type, &array, self.eval_mode), + ( + DataType::LargeUtf8, + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, + ) => Self::cast_string_to_int::(to_type, &array, self.eval_mode), + (DataType::Float64, DataType::Utf8) => { + Self::spark_cast_float64_to_utf8::(&array, self.eval_mode) + } + (DataType::Float64, DataType::LargeUtf8) => { + Self::spark_cast_float64_to_utf8::(&array, self.eval_mode) + } + (DataType::Float32, DataType::Utf8) => { + Self::spark_cast_float32_to_utf8::(&array, self.eval_mode) + } + (DataType::Float32, DataType::LargeUtf8) => { + Self::spark_cast_float32_to_utf8::(&array, self.eval_mode) + } + (DataType::Float32, DataType::Decimal128(precision, scale)) => { + Self::cast_float32_to_decimal128(&array, *precision, *scale, self.eval_mode) + } + (DataType::Float64, DataType::Decimal128(precision, scale)) => { + Self::cast_float64_to_decimal128(&array, *precision, *scale, self.eval_mode) + } + (DataType::Float32, DataType::Int8) + | (DataType::Float32, DataType::Int16) + | (DataType::Float32, DataType::Int32) + | (DataType::Float32, DataType::Int64) + | (DataType::Float64, DataType::Int8) + | (DataType::Float64, DataType::Int16) + | (DataType::Float64, DataType::Int32) + | (DataType::Float64, DataType::Int64) + | (DataType::Decimal128(_, _), DataType::Int8) + | (DataType::Decimal128(_, _), DataType::Int16) + | (DataType::Decimal128(_, _), DataType::Int32) + | (DataType::Decimal128(_, _), DataType::Int64) + if self.eval_mode != EvalMode::Try => + { + Self::spark_cast_nonintegral_numeric_to_integral( + &array, + self.eval_mode, + from_type, + to_type, + ) + } + _ if Self::is_datafusion_spark_compatible(from_type, to_type) => { + // use DataFusion cast only when we know that it is compatible with Spark + Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?) 
+ } + _ => { + // we should never reach this code because the Scala code should be checking + // for supported cast operations and falling back to Spark for anything that + // is not yet supported + Err(SparkError::Internal(format!( + "Native cast invoked for unsupported cast from {from_type:?} to {to_type:?}" + ))) + } + }; + Ok(spark_cast(cast_result?, from_type, to_type)) + } + + /// Determines if DataFusion supports the given cast in a way that is + /// compatible with Spark + fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> bool { + if from_type == to_type { + return true; + } + match from_type { + DataType::Boolean => matches!( + to_type, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Utf8 + ), + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + // note that the cast from Int32/Int64 -> Decimal128 here is actually + // not compatible with Spark (no overflow checks) but we have tests that + // rely on this cast working so we have to leave it here for now + matches!( + to_type, + DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal128(_, _) + | DataType::Utf8 + ) + } + DataType::Float32 | DataType::Float64 => matches!( + to_type, + DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + ), + DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => matches!( + to_type, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + ), + DataType::Utf8 => matches!(to_type, DataType::Binary), + DataType::Date32 => matches!(to_type, DataType::Utf8), + DataType::Timestamp(_, _) => { + matches!( + to_type, + DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _) + ) + } + DataType::Binary => { + // note that this is not completely Spark compatible because + // DataFusion only supports binary data containing valid UTF-8 strings + matches!(to_type, DataType::Utf8) + } + _ => false, + } + } + + fn cast_string_to_int( + to_type: &DataType, + array: &ArrayRef, + eval_mode: EvalMode, + ) -> SparkResult { + let string_array = array + .as_any() + .downcast_ref::>() + .expect("cast_string_to_int expected a string array"); + + let cast_array: ArrayRef = match to_type { + DataType::Int8 => { + cast_utf8_to_int!(string_array, eval_mode, Int8Type, cast_string_to_i8)? + } + DataType::Int16 => { + cast_utf8_to_int!(string_array, eval_mode, Int16Type, cast_string_to_i16)? + } + DataType::Int32 => { + cast_utf8_to_int!(string_array, eval_mode, Int32Type, cast_string_to_i32)? + } + DataType::Int64 => { + cast_utf8_to_int!(string_array, eval_mode, Int64Type, cast_string_to_i64)? 
+ } + dt => unreachable!( + "{}", + format!("invalid integer type {dt} in cast from string") + ), + }; + Ok(cast_array) + } + + fn cast_string_to_date( + array: &ArrayRef, + to_type: &DataType, + eval_mode: EvalMode, + ) -> SparkResult { + let string_array = array + .as_any() + .downcast_ref::>() + .expect("Expected a string array"); + + let cast_array: ArrayRef = match to_type { + DataType::Date32 => { + let len = string_array.len(); + let mut cast_array = PrimitiveArray::::builder(len); + for i in 0..len { + if !string_array.is_null(i) { + match date_parser(string_array.value(i), eval_mode) { + Ok(Some(cast_value)) => cast_array.append_value(cast_value), + Ok(None) => cast_array.append_null(), + Err(e) => return Err(e), + } + } else { + cast_array.append_null() + } + } + Arc::new(cast_array.finish()) as ArrayRef + } + _ => unreachable!("Invalid data type {:?} in cast from string", to_type), + }; + Ok(cast_array) + } + + fn cast_string_to_timestamp( + array: &ArrayRef, + to_type: &DataType, + eval_mode: EvalMode, + ) -> SparkResult { + let string_array = array + .as_any() + .downcast_ref::>() + .expect("Expected a string array"); + + let cast_array: ArrayRef = match to_type { + DataType::Timestamp(_, _) => { + cast_utf8_to_timestamp!( + string_array, + eval_mode, + TimestampMicrosecondType, + timestamp_parser + ) + } + _ => unreachable!("Invalid data type {:?} in cast from string", to_type), + }; + Ok(cast_array) + } + + fn cast_float64_to_decimal128( + array: &dyn Array, + precision: u8, + scale: i8, + eval_mode: EvalMode, + ) -> SparkResult { + Self::cast_floating_point_to_decimal128::(array, precision, scale, eval_mode) + } + + fn cast_float32_to_decimal128( + array: &dyn Array, + precision: u8, + scale: i8, + eval_mode: EvalMode, + ) -> SparkResult { + Self::cast_floating_point_to_decimal128::(array, precision, scale, eval_mode) + } + + fn cast_floating_point_to_decimal128( + array: &dyn Array, + precision: u8, + scale: i8, + eval_mode: EvalMode, + ) -> SparkResult + where + ::Native: AsPrimitive, + { + let input = array.as_any().downcast_ref::>().unwrap(); + let mut cast_array = PrimitiveArray::::builder(input.len()); + + let mul = 10_f64.powi(scale as i32); + + for i in 0..input.len() { + if input.is_null(i) { + cast_array.append_null(); + } else { + let input_value = input.value(i).as_(); + let value = (input_value * mul).round().to_i128(); + + match value { + Some(v) => { + if Decimal128Type::validate_decimal_precision(v, precision).is_err() { + if eval_mode == EvalMode::Ansi { + return Err(SparkError::NumericValueOutOfRange { + value: input_value.to_string(), + precision, + scale, + }); + } else { + cast_array.append_null(); + } + } + cast_array.append_value(v); + } + None => { + if eval_mode == EvalMode::Ansi { + return Err(SparkError::NumericValueOutOfRange { + value: input_value.to_string(), + precision, + scale, + }); + } else { + cast_array.append_null(); + } + } + } + } + } + + let res = Arc::new( + cast_array + .with_precision_and_scale(precision, scale)? 
+ .finish(), + ) as ArrayRef; + Ok(res) + } + + fn spark_cast_float64_to_utf8( + from: &dyn Array, + _eval_mode: EvalMode, + ) -> SparkResult + where + OffsetSize: OffsetSizeTrait, + { + cast_float_to_string!(from, _eval_mode, f64, Float64Array, OffsetSize) + } + + fn spark_cast_float32_to_utf8( + from: &dyn Array, + _eval_mode: EvalMode, + ) -> SparkResult + where + OffsetSize: OffsetSizeTrait, + { + cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize) + } + + fn spark_cast_int_to_int( + array: &dyn Array, + eval_mode: EvalMode, + from_type: &DataType, + to_type: &DataType, + ) -> SparkResult { + match (from_type, to_type) { + (DataType::Int64, DataType::Int32) => cast_int_to_int_macro!( + array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT" + ), + (DataType::Int64, DataType::Int16) => cast_int_to_int_macro!( + array, eval_mode, Int64Type, Int16Type, from_type, i16, "BIGINT", "SMALLINT" + ), + (DataType::Int64, DataType::Int8) => cast_int_to_int_macro!( + array, eval_mode, Int64Type, Int8Type, from_type, i8, "BIGINT", "TINYINT" + ), + (DataType::Int32, DataType::Int16) => cast_int_to_int_macro!( + array, eval_mode, Int32Type, Int16Type, from_type, i16, "INT", "SMALLINT" + ), + (DataType::Int32, DataType::Int8) => cast_int_to_int_macro!( + array, eval_mode, Int32Type, Int8Type, from_type, i8, "INT", "TINYINT" + ), + (DataType::Int16, DataType::Int8) => cast_int_to_int_macro!( + array, eval_mode, Int16Type, Int8Type, from_type, i8, "SMALLINT", "TINYINT" + ), + _ => unreachable!( + "{}", + format!("invalid integer type {to_type} in cast from {from_type}") + ), + } + } + + fn spark_cast_utf8_to_boolean( + from: &dyn Array, + eval_mode: EvalMode, + ) -> SparkResult + where + OffsetSize: OffsetSizeTrait, + { + let array = from + .as_any() + .downcast_ref::>() + .unwrap(); + + let output_array = array + .iter() + .map(|value| match value { + Some(value) => match value.to_ascii_lowercase().trim() { + "t" | "true" | "y" | "yes" | "1" => Ok(Some(true)), + "f" | "false" | "n" | "no" | "0" => Ok(Some(false)), + _ if eval_mode == EvalMode::Ansi => Err(SparkError::CastInvalidValue { + value: value.to_string(), + from_type: "STRING".to_string(), + to_type: "BOOLEAN".to_string(), + }), + _ => Ok(None), + }, + _ => Ok(None), + }) + .collect::>()?; + + Ok(Arc::new(output_array)) + } + + fn spark_cast_nonintegral_numeric_to_integral( + array: &dyn Array, + eval_mode: EvalMode, + from_type: &DataType, + to_type: &DataType, + ) -> SparkResult { + match (from_type, to_type) { + (DataType::Float32, DataType::Int8) => cast_float_to_int16_down!( + array, + eval_mode, + Float32Array, + Int8Array, + f32, + i8, + "FLOAT", + "TINYINT", + "{:e}" + ), + (DataType::Float32, DataType::Int16) => cast_float_to_int16_down!( + array, + eval_mode, + Float32Array, + Int16Array, + f32, + i16, + "FLOAT", + "SMALLINT", + "{:e}" + ), + (DataType::Float32, DataType::Int32) => cast_float_to_int32_up!( + array, + eval_mode, + Float32Array, + Int32Array, + f32, + i32, + "FLOAT", + "INT", + i32::MAX, + "{:e}" + ), + (DataType::Float32, DataType::Int64) => cast_float_to_int32_up!( + array, + eval_mode, + Float32Array, + Int64Array, + f32, + i64, + "FLOAT", + "BIGINT", + i64::MAX, + "{:e}" + ), + (DataType::Float64, DataType::Int8) => cast_float_to_int16_down!( + array, + eval_mode, + Float64Array, + Int8Array, + f64, + i8, + "DOUBLE", + "TINYINT", + "{:e}D" + ), + (DataType::Float64, DataType::Int16) => cast_float_to_int16_down!( + array, + eval_mode, + Float64Array, + Int16Array, + f64, + i16, + 
"DOUBLE", + "SMALLINT", + "{:e}D" + ), + (DataType::Float64, DataType::Int32) => cast_float_to_int32_up!( + array, + eval_mode, + Float64Array, + Int32Array, + f64, + i32, + "DOUBLE", + "INT", + i32::MAX, + "{:e}D" + ), + (DataType::Float64, DataType::Int64) => cast_float_to_int32_up!( + array, + eval_mode, + Float64Array, + Int64Array, + f64, + i64, + "DOUBLE", + "BIGINT", + i64::MAX, + "{:e}D" + ), + (DataType::Decimal128(precision, scale), DataType::Int8) => { + cast_decimal_to_int16_down!( + array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale + ) + } + (DataType::Decimal128(precision, scale), DataType::Int16) => { + cast_decimal_to_int16_down!( + array, eval_mode, Int16Array, i16, "SMALLINT", precision, *scale + ) + } + (DataType::Decimal128(precision, scale), DataType::Int32) => { + cast_decimal_to_int32_up!( + array, + eval_mode, + Int32Array, + i32, + "INT", + i32::MAX, + *precision, + *scale + ) + } + (DataType::Decimal128(precision, scale), DataType::Int64) => { + cast_decimal_to_int32_up!( + array, + eval_mode, + Int64Array, + i64, + "BIGINT", + i64::MAX, + *precision, + *scale + ) + } + _ => unreachable!( + "{}", + format!("invalid cast from non-integral numeric type: {from_type} to integral numeric type: {to_type}") + ), + } + } +} + +/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toByte +fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> SparkResult> { + Ok(cast_string_to_int_with_range_check( + str, + eval_mode, + "TINYINT", + i8::MIN as i32, + i8::MAX as i32, + )? + .map(|v| v as i8)) +} + +/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toShort +fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> SparkResult> { + Ok(cast_string_to_int_with_range_check( + str, + eval_mode, + "SMALLINT", + i16::MIN as i32, + i16::MAX as i32, + )? + .map(|v| v as i16)) +} + +/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper) +fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> SparkResult> { + do_cast_string_to_int::(str, eval_mode, "INT", i32::MIN) +} + +/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper intWrapper) +fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> SparkResult> { + do_cast_string_to_int::(str, eval_mode, "BIGINT", i64::MIN) +} + +fn cast_string_to_int_with_range_check( + str: &str, + eval_mode: EvalMode, + type_name: &str, + min: i32, + max: i32, +) -> SparkResult> { + match do_cast_string_to_int(str, eval_mode, type_name, i32::MIN)? 
{ + None => Ok(None), + Some(v) if v >= min && v <= max => Ok(Some(v)), + _ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), + _ => Ok(None), + } +} + +/// Equivalent to +/// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal) +/// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal) +fn do_cast_string_to_int< + T: Num + PartialOrd + Integer + CheckedSub + CheckedNeg + From + Copy, +>( + str: &str, + eval_mode: EvalMode, + type_name: &str, + min_value: T, +) -> SparkResult> { + let trimmed_str = str.trim(); + if trimmed_str.is_empty() { + return none_or_err(eval_mode, type_name, str); + } + let len = trimmed_str.len(); + let mut result: T = T::zero(); + let mut negative = false; + let radix = T::from(10); + let stop_value = min_value / radix; + let mut parse_sign_and_digits = true; + + for (i, ch) in trimmed_str.char_indices() { + if parse_sign_and_digits { + if i == 0 { + negative = ch == '-'; + let positive = ch == '+'; + if negative || positive { + if i + 1 == len { + // input string is just "+" or "-" + return none_or_err(eval_mode, type_name, str); + } + // consume this char + continue; + } + } + + if ch == '.' { + if eval_mode == EvalMode::Legacy { + // truncate decimal in legacy mode + parse_sign_and_digits = false; + continue; + } else { + return none_or_err(eval_mode, type_name, str); + } + } + + let digit = if ch.is_ascii_digit() { + (ch as u32) - ('0' as u32) + } else { + return none_or_err(eval_mode, type_name, str); + }; + + // We are going to process the new digit and accumulate the result. However, before + // doing this, if the result is already smaller than the + // stopValue(Integer.MIN_VALUE / radix), then result * 10 will definitely be + // smaller than minValue, and we can stop + if result < stop_value { + return none_or_err(eval_mode, type_name, str); + } + + // Since the previous result is greater than or equal to stopValue(Integer.MIN_VALUE / + // radix), we can just use `result > 0` to check overflow. 
If result + // overflows, we should stop + let v = result * radix; + let digit = (digit as i32).into(); + match v.checked_sub(&digit) { + Some(x) if x <= T::zero() => result = x, + _ => { + return none_or_err(eval_mode, type_name, str); + } + } + } else { + // make sure fractional digits are valid digits but ignore them + if !ch.is_ascii_digit() { + return none_or_err(eval_mode, type_name, str); + } + } + } + + if !negative { + if let Some(neg) = result.checked_neg() { + if neg < T::zero() { + return none_or_err(eval_mode, type_name, str); + } + result = neg; + } else { + return none_or_err(eval_mode, type_name, str); + } + } + + Ok(Some(result)) +} + +/// Either return Ok(None) or Err(SparkError::CastInvalidValue) depending on the evaluation mode +#[inline] +fn none_or_err(eval_mode: EvalMode, type_name: &str, str: &str) -> SparkResult> { + match eval_mode { + EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), + _ => Ok(None), + } +} + +#[inline] +fn invalid_value(value: &str, from_type: &str, to_type: &str) -> SparkError { + SparkError::CastInvalidValue { + value: value.to_string(), + from_type: from_type.to_string(), + to_type: to_type.to_string(), + } +} + +#[inline] +fn cast_overflow(value: &str, from_type: &str, to_type: &str) -> SparkError { + SparkError::CastOverFlow { + value: value.to_string(), + from_type: from_type.to_string(), + to_type: to_type.to_string(), + } +} + +impl Display for Cast { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Cast [data_type: {}, timezone: {}, child: {}, eval_mode: {:?}]", + self.data_type, self.timezone, self.child, &self.eval_mode + ) + } +} + +impl PartialEq for Cast { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.child.eq(&x.child) + && self.timezone.eq(&x.timezone) + && self.data_type.eq(&x.data_type) + && self.eval_mode.eq(&x.eval_mode) + }) + .unwrap_or(false) + } +} + +impl PhysicalExpr for Cast { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _: &Schema) -> DataFusionResult { + Ok(self.data_type.clone()) + } + + fn nullable(&self, _: &Schema) -> DataFusionResult { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { + let arg = self.child.evaluate(batch)?; + match arg { + ColumnarValue::Array(array) => Ok(ColumnarValue::Array(self.cast_array(array)?)), + ColumnarValue::Scalar(scalar) => { + // Note that normally CAST(scalar) should be fold in Spark JVM side. However, for + // some cases e.g., scalar subquery, Spark will not fold it, so we need to handle it + // here. 
+ let array = scalar.to_array()?; + let scalar = ScalarValue::try_from_array(&self.cast_array(array)?, 0)?; + Ok(ColumnarValue::Scalar(scalar)) + } + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion_common::Result> { + match children.len() { + 1 => Ok(Arc::new(Cast::new( + children[0].clone(), + self.data_type.clone(), + self.eval_mode, + self.timezone.clone(), + ))), + _ => internal_err!("Cast should have exactly one child"), + } + } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.child.hash(&mut s); + self.data_type.hash(&mut s); + self.timezone.hash(&mut s); + self.eval_mode.hash(&mut s); + self.hash(&mut s); + } +} + +fn timestamp_parser(value: &str, eval_mode: EvalMode) -> SparkResult> { + let value = value.trim(); + if value.is_empty() { + return Ok(None); + } + // Define regex patterns and corresponding parsing functions + let patterns = &[ + ( + Regex::new(r"^\d{4}$").unwrap(), + parse_str_to_year_timestamp as fn(&str) -> SparkResult>, + ), + ( + Regex::new(r"^\d{4}-\d{2}$").unwrap(), + parse_str_to_month_timestamp, + ), + ( + Regex::new(r"^\d{4}-\d{2}-\d{2}$").unwrap(), + parse_str_to_day_timestamp, + ), + ( + Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{1,2}$").unwrap(), + parse_str_to_hour_timestamp, + ), + ( + Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap(), + parse_str_to_minute_timestamp, + ), + ( + Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap(), + parse_str_to_second_timestamp, + ), + ( + Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap(), + parse_str_to_microsecond_timestamp, + ), + ( + Regex::new(r"^T\d{1,2}$").unwrap(), + parse_str_to_time_only_timestamp, + ), + ]; + + let mut timestamp = None; + + // Iterate through patterns and try matching + for (pattern, parse_func) in patterns { + if pattern.is_match(value) { + timestamp = parse_func(value)?; + break; + } + } + + if timestamp.is_none() { + return if eval_mode == EvalMode::Ansi { + Err(SparkError::CastInvalidValue { + value: value.to_string(), + from_type: "STRING".to_string(), + to_type: "TIMESTAMP".to_string(), + }) + } else { + Ok(None) + }; + } + + match timestamp { + Some(ts) => Ok(Some(ts)), + None => Err(SparkError::Internal( + "Failed to parse timestamp".to_string(), + )), + } +} + +fn parse_ymd_timestamp(year: i32, month: u32, day: u32) -> SparkResult> { + let datetime = chrono::Utc.with_ymd_and_hms(year, month, day, 0, 0, 0); + + // Check if datetime is not None + let utc_datetime = match datetime.single() { + Some(dt) => dt.with_timezone(&chrono::Utc), + None => { + return Err(SparkError::Internal( + "Failed to parse timestamp".to_string(), + )); + } + }; + + Ok(Some(utc_datetime.timestamp_micros())) +} + +fn parse_hms_timestamp( + year: i32, + month: u32, + day: u32, + hour: u32, + minute: u32, + second: u32, + microsecond: u32, +) -> SparkResult> { + let datetime = chrono::Utc.with_ymd_and_hms(year, month, day, hour, minute, second); + + // Check if datetime is not None + let utc_datetime = match datetime.single() { + Some(dt) => dt + .with_timezone(&chrono::Utc) + .with_nanosecond(microsecond * 1000), + None => { + return Err(SparkError::Internal( + "Failed to parse timestamp".to_string(), + )); + } + }; + + let result = match utc_datetime { + Some(dt) => dt.timestamp_micros(), + None => { + return Err(SparkError::Internal( + "Failed to parse timestamp".to_string(), + )); + } + }; + + Ok(Some(result)) +} + +fn get_timestamp_values(value: &str, 
timestamp_type: &str) -> SparkResult> { + let values: Vec<_> = value + .split(|c| c == 'T' || c == '-' || c == ':' || c == '.') + .collect(); + let year = values[0].parse::().unwrap_or_default(); + let month = values.get(1).map_or(1, |m| m.parse::().unwrap_or(1)); + let day = values.get(2).map_or(1, |d| d.parse::().unwrap_or(1)); + let hour = values.get(3).map_or(0, |h| h.parse::().unwrap_or(0)); + let minute = values.get(4).map_or(0, |m| m.parse::().unwrap_or(0)); + let second = values.get(5).map_or(0, |s| s.parse::().unwrap_or(0)); + let microsecond = values.get(6).map_or(0, |ms| ms.parse::().unwrap_or(0)); + + match timestamp_type { + "year" => parse_ymd_timestamp(year, 1, 1), + "month" => parse_ymd_timestamp(year, month, 1), + "day" => parse_ymd_timestamp(year, month, day), + "hour" => parse_hms_timestamp(year, month, day, hour, 0, 0, 0), + "minute" => parse_hms_timestamp(year, month, day, hour, minute, 0, 0), + "second" => parse_hms_timestamp(year, month, day, hour, minute, second, 0), + "microsecond" => parse_hms_timestamp(year, month, day, hour, minute, second, microsecond), + _ => Err(SparkError::CastInvalidValue { + value: value.to_string(), + from_type: "STRING".to_string(), + to_type: "TIMESTAMP".to_string(), + }), + } +} + +fn parse_str_to_year_timestamp(value: &str) -> SparkResult> { + get_timestamp_values(value, "year") +} + +fn parse_str_to_month_timestamp(value: &str) -> SparkResult> { + get_timestamp_values(value, "month") +} + +fn parse_str_to_day_timestamp(value: &str) -> SparkResult> { + get_timestamp_values(value, "day") +} + +fn parse_str_to_hour_timestamp(value: &str) -> SparkResult> { + get_timestamp_values(value, "hour") +} + +fn parse_str_to_minute_timestamp(value: &str) -> SparkResult> { + get_timestamp_values(value, "minute") +} + +fn parse_str_to_second_timestamp(value: &str) -> SparkResult> { + get_timestamp_values(value, "second") +} + +fn parse_str_to_microsecond_timestamp(value: &str) -> SparkResult> { + get_timestamp_values(value, "microsecond") +} + +fn parse_str_to_time_only_timestamp(value: &str) -> SparkResult> { + let values: Vec<&str> = value.split('T').collect(); + let time_values: Vec = values[1] + .split(':') + .map(|v| v.parse::().unwrap_or(0)) + .collect(); + + let datetime = chrono::Utc::now(); + let timestamp = datetime + .with_hour(time_values.first().copied().unwrap_or_default()) + .and_then(|dt| dt.with_minute(*time_values.get(1).unwrap_or(&0))) + .and_then(|dt| dt.with_second(*time_values.get(2).unwrap_or(&0))) + .and_then(|dt| dt.with_nanosecond(*time_values.get(3).unwrap_or(&0) * 1_000)) + .map(|dt| dt.to_utc().timestamp_micros()) + .unwrap_or_default(); + + Ok(Some(timestamp)) +} + +//a string to date parser - port of spark's SparkDateTimeUtils#stringToDate. +fn date_parser(date_str: &str, eval_mode: EvalMode) -> SparkResult> { + // local functions + fn get_trimmed_start(bytes: &[u8]) -> usize { + let mut start = 0; + while start < bytes.len() && is_whitespace_or_iso_control(bytes[start]) { + start += 1; + } + start + } + + fn get_trimmed_end(start: usize, bytes: &[u8]) -> usize { + let mut end = bytes.len() - 1; + while end > start && is_whitespace_or_iso_control(bytes[end]) { + end -= 1; + } + end + 1 + } + + fn is_whitespace_or_iso_control(byte: u8) -> bool { + byte.is_ascii_whitespace() || byte.is_ascii_control() + } + + fn is_valid_digits(segment: i32, digits: usize) -> bool { + // An integer is able to represent a date within [+-]5 million years. 
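+        // (a year of up to five million has at most 7 digits, hence the cap below)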
+ let max_digits_year = 7; + //year (segment 0) can be between 4 to 7 digits, + //month and day (segment 1 and 2) can be between 1 to 2 digits + (segment == 0 && digits >= 4 && digits <= max_digits_year) + || (segment != 0 && digits > 0 && digits <= 2) + } + + fn return_result(date_str: &str, eval_mode: EvalMode) -> SparkResult> { + if eval_mode == EvalMode::Ansi { + Err(SparkError::CastInvalidValue { + value: date_str.to_string(), + from_type: "STRING".to_string(), + to_type: "DATE".to_string(), + }) + } else { + Ok(None) + } + } + // end local functions + + if date_str.is_empty() { + return return_result(date_str, eval_mode); + } + + //values of date segments year, month and day defaulting to 1 + let mut date_segments = [1, 1, 1]; + let mut sign = 1; + let mut current_segment = 0; + let mut current_segment_value = Wrapping(0); + let mut current_segment_digits = 0; + let bytes = date_str.as_bytes(); + + let mut j = get_trimmed_start(bytes); + let str_end_trimmed = get_trimmed_end(j, bytes); + + if j == str_end_trimmed { + return return_result(date_str, eval_mode); + } + + //assign a sign to the date + if bytes[j] == b'-' || bytes[j] == b'+' { + sign = if bytes[j] == b'-' { -1 } else { 1 }; + j += 1; + } + + //loop to the end of string until we have processed 3 segments, + //exit loop on encountering any space ' ' or 'T' after the 3rd segment + while j < str_end_trimmed && (current_segment < 3 && !(bytes[j] == b' ' || bytes[j] == b'T')) { + let b = bytes[j]; + if current_segment < 2 && b == b'-' { + //check for validity of year and month segments if current byte is separator + if !is_valid_digits(current_segment, current_segment_digits) { + return return_result(date_str, eval_mode); + } + //if valid update corresponding segment with the current segment value. + date_segments[current_segment as usize] = current_segment_value.0; + current_segment_value = Wrapping(0); + current_segment_digits = 0; + current_segment += 1; + } else if !b.is_ascii_digit() { + return return_result(date_str, eval_mode); + } else { + //increment value of current segment by the next digit + let parsed_value = Wrapping((b - b'0') as i32); + current_segment_value = current_segment_value * Wrapping(10) + parsed_value; + current_segment_digits += 1; + } + j += 1; + } + + //check for validity of last segment + if !is_valid_digits(current_segment, current_segment_digits) { + return return_result(date_str, eval_mode); + } + + if current_segment < 2 && j < str_end_trimmed { + // For the `yyyy` and `yyyy-[m]m` formats, entire input must be consumed. + return return_result(date_str, eval_mode); + } + + date_segments[current_segment as usize] = current_segment_value.0; + + match NaiveDate::from_ymd_opt( + sign * date_segments[0], + date_segments[1] as u32, + date_segments[2] as u32, + ) { + Some(date) => { + let duration_since_epoch = date + .signed_duration_since(NaiveDateTime::UNIX_EPOCH.date()) + .num_days(); + Ok(Some(duration_since_epoch.to_i32().unwrap())) + } + None => Ok(None), + } +} + +/// This takes for special casting cases of Spark. E.g., Timestamp to Long. +/// This function runs as a post process of the DataFusion cast(). By the time it arrives here, +/// Dictionary arrays are already unpacked by the DataFusion cast() since Spark cannot specify +/// Dictionary as to_type. The from_type is taken before the DataFusion cast() runs in +/// expressions/cast.rs, so it can be still Dictionary. 
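+///
+/// For example (illustrative only): casting `Timestamp(Microsecond)` to `Int64`
+/// floor-divides the microsecond value by `MICROS_PER_SECOND`, so `-1_500_000`
+/// (1.5 seconds before the epoch) becomes `-2` rather than `-1`.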
+fn spark_cast(array: ArrayRef, from_type: &DataType, to_type: &DataType) -> ArrayRef { + match (from_type, to_type) { + (DataType::Timestamp(_, _), DataType::Int64) => { + // See Spark's `Cast` expression + unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap() + } + (DataType::Dictionary(_, value_type), DataType::Int64) + if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) => + { + // See Spark's `Cast` expression + unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)).unwrap() + } + (DataType::Timestamp(_, _), DataType::Utf8) => remove_trailing_zeroes(array), + (DataType::Dictionary(_, value_type), DataType::Utf8) + if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) => + { + remove_trailing_zeroes(array) + } + _ => array, + } +} + +/// A fork & modified version of Arrow's `unary_dyn` which is being deprecated +fn unary_dyn(array: &ArrayRef, op: F) -> Result +where + T: ArrowPrimitiveType, + F: Fn(T::Native) -> T::Native, +{ + if let Some(d) = array.as_any_dictionary_opt() { + let new_values = unary_dyn::(d.values(), op)?; + return Ok(Arc::new(d.with_values(Arc::new(new_values)))); + } + + match array.as_primitive_opt::() { + Some(a) if PrimitiveArray::::is_compatible(a.data_type()) => { + Ok(Arc::new(unary::( + array.as_any().downcast_ref::>().unwrap(), + op, + ))) + } + _ => Err(ArrowError::NotYetImplemented(format!( + "Cannot perform unary operation of type {} on array of type {}", + T::DATA_TYPE, + array.data_type() + ))), + } +} + +/// Remove any trailing zeroes in the string if they occur after in the fractional seconds, +/// to match Spark behavior +/// example: +/// "1970-01-01 05:29:59.900" => "1970-01-01 05:29:59.9" +/// "1970-01-01 05:29:59.990" => "1970-01-01 05:29:59.99" +/// "1970-01-01 05:29:59.999" => "1970-01-01 05:29:59.999" +/// "1970-01-01 05:30:00" => "1970-01-01 05:30:00" +/// "1970-01-01 05:30:00.001" => "1970-01-01 05:30:00.001" +fn remove_trailing_zeroes(array: ArrayRef) -> ArrayRef { + let string_array = as_generic_string_array::(&array).unwrap(); + let result = string_array + .iter() + .map(|s| s.map(trim_end)) + .collect::>(); + Arc::new(result) as ArrayRef +} + +fn trim_end(s: &str) -> &str { + if s.rfind('.').is_some() { + s.trim_end_matches('0') + } else { + s + } +} + +#[cfg(test)] +mod tests { + use arrow::datatypes::TimestampMicrosecondType; + use arrow_array::StringArray; + use arrow_schema::TimeUnit; + + use datafusion_physical_expr::expressions::Column; + + use super::*; + + #[test] + #[cfg_attr(miri, ignore)] // test takes too long with miri + fn timestamp_parser_test() { + // write for all formats + assert_eq!( + timestamp_parser("2020", EvalMode::Legacy).unwrap(), + Some(1577836800000000) // this is in milliseconds + ); + assert_eq!( + timestamp_parser("2020-01", EvalMode::Legacy).unwrap(), + Some(1577836800000000) + ); + assert_eq!( + timestamp_parser("2020-01-01", EvalMode::Legacy).unwrap(), + Some(1577836800000000) + ); + assert_eq!( + timestamp_parser("2020-01-01T12", EvalMode::Legacy).unwrap(), + Some(1577880000000000) + ); + assert_eq!( + timestamp_parser("2020-01-01T12:34", EvalMode::Legacy).unwrap(), + Some(1577882040000000) + ); + assert_eq!( + timestamp_parser("2020-01-01T12:34:56", EvalMode::Legacy).unwrap(), + Some(1577882096000000) + ); + assert_eq!( + timestamp_parser("2020-01-01T12:34:56.123456", EvalMode::Legacy).unwrap(), + Some(1577882096123456) + ); + // assert_eq!( + // timestamp_parser("T2", EvalMode::Legacy).unwrap(), + // Some(1714356000000000) // this value needs 
to change everyday. + // ); + } + + #[test] + #[cfg_attr(miri, ignore)] // test takes too long with miri + fn test_cast_string_to_timestamp() { + let array: ArrayRef = Arc::new(StringArray::from(vec![ + Some("2020-01-01T12:34:56.123456"), + Some("T2"), + ])); + + let string_array = array + .as_any() + .downcast_ref::>() + .expect("Expected a string array"); + + let eval_mode = EvalMode::Legacy; + let result = cast_utf8_to_timestamp!( + &string_array, + eval_mode, + TimestampMicrosecondType, + timestamp_parser + ); + + assert_eq!( + result.data_type(), + &DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())) + ); + assert_eq!(result.len(), 2); + } + + #[test] + fn date_parser_test() { + for date in &[ + "2020", + "2020-01", + "2020-01-01", + "02020-01-01", + "002020-01-01", + "0002020-01-01", + "2020-1-1", + "2020-01-01 ", + "2020-01-01T", + ] { + for eval_mode in &[EvalMode::Legacy, EvalMode::Ansi, EvalMode::Try] { + assert_eq!(date_parser(*date, *eval_mode).unwrap(), Some(18262)); + } + } + + //dates in invalid formats + for date in &[ + "abc", + "", + "not_a_date", + "3/", + "3/12", + "3/12/2020", + "3/12/2002 T", + "202", + "2020-010-01", + "2020-10-010", + "2020-10-010T", + "--262143-12-31", + "--262143-12-31 ", + ] { + for eval_mode in &[EvalMode::Legacy, EvalMode::Try] { + assert_eq!(date_parser(*date, *eval_mode).unwrap(), None); + } + assert!(date_parser(*date, EvalMode::Ansi).is_err()); + } + + for date in &["-3638-5"] { + for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] { + assert_eq!(date_parser(*date, *eval_mode).unwrap(), Some(-2048160)); + } + } + + //Naive Date only supports years 262142 AD to 262143 BC + //returns None for dates out of range supported by Naive Date. + for date in &[ + "-262144-1-1", + "262143-01-1", + "262143-1-1", + "262143-01-1 ", + "262143-01-01T ", + "262143-1-01T 1234", + "-0973250", + ] { + for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] { + assert_eq!(date_parser(*date, *eval_mode).unwrap(), None); + } + } + } + + #[test] + fn test_cast_string_to_date() { + let array: ArrayRef = Arc::new(StringArray::from(vec![ + Some("2020"), + Some("2020-01"), + Some("2020-01-01"), + Some("2020-01-01T"), + ])); + + let result = + Cast::cast_string_to_date(&array, &DataType::Date32, EvalMode::Legacy).unwrap(); + + let date32_array = result + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(date32_array.len(), 4); + date32_array + .iter() + .for_each(|v| assert_eq!(v.unwrap(), 18262)); + } + + #[test] + fn test_cast_string_array_with_valid_dates() { + let array_with_invalid_date: ArrayRef = Arc::new(StringArray::from(vec![ + Some("-262143-12-31"), + Some("\n -262143-12-31 "), + Some("-262143-12-31T \t\n"), + Some("\n\t-262143-12-31T\r"), + Some("-262143-12-31T 123123123"), + Some("\r\n-262143-12-31T \r123123123"), + Some("\n -262143-12-31T \n\t"), + ])); + + for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] { + let result = + Cast::cast_string_to_date(&array_with_invalid_date, &DataType::Date32, *eval_mode) + .unwrap(); + + let date32_array = result + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(result.len(), 7); + date32_array + .iter() + .for_each(|v| assert_eq!(v.unwrap(), -96464928)); + } + } + + #[test] + fn test_cast_string_array_with_invalid_dates() { + let array_with_invalid_date: ArrayRef = Arc::new(StringArray::from(vec![ + Some("2020"), + Some("2020-01"), + Some("2020-01-01"), + //4 invalid dates + Some("2020-010-01T"), + Some("202"), + Some(" 202 "), + Some("\n 2020-\r8 
"), + Some("2020-01-01T"), + // Overflows i32 + Some("-4607172990231812908"), + ])); + + for eval_mode in &[EvalMode::Legacy, EvalMode::Try] { + let result = + Cast::cast_string_to_date(&array_with_invalid_date, &DataType::Date32, *eval_mode) + .unwrap(); + + let date32_array = result + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + date32_array.iter().collect::>(), + vec![ + Some(18262), + Some(18262), + Some(18262), + None, + None, + None, + None, + Some(18262), + None + ] + ); + } + + let result = + Cast::cast_string_to_date(&array_with_invalid_date, &DataType::Date32, EvalMode::Ansi); + match result { + Err(e) => assert!( + e.to_string().contains( + "[CAST_INVALID_INPUT] The value '2020-010-01T' of the type \"STRING\" cannot be cast to \"DATE\" because it is malformed") + ), + _ => panic!("Expected error"), + } + } + + #[test] + fn test_cast_string_as_i8() { + // basic + assert_eq!( + cast_string_to_i8("127", EvalMode::Legacy).unwrap(), + Some(127_i8) + ); + assert_eq!(cast_string_to_i8("128", EvalMode::Legacy).unwrap(), None); + assert!(cast_string_to_i8("128", EvalMode::Ansi).is_err()); + // decimals + assert_eq!( + cast_string_to_i8("0.2", EvalMode::Legacy).unwrap(), + Some(0_i8) + ); + assert_eq!( + cast_string_to_i8(".", EvalMode::Legacy).unwrap(), + Some(0_i8) + ); + // TRY should always return null for decimals + assert_eq!(cast_string_to_i8("0.2", EvalMode::Try).unwrap(), None); + assert_eq!(cast_string_to_i8(".", EvalMode::Try).unwrap(), None); + // ANSI mode should throw error on decimal + assert!(cast_string_to_i8("0.2", EvalMode::Ansi).is_err()); + assert!(cast_string_to_i8(".", EvalMode::Ansi).is_err()); + } + + #[test] + fn test_cast_unsupported_timestamp_to_date() { + // Since datafusion uses chrono::Datetime internally not all dates representable by TimestampMicrosecondType are supported + let timestamps: PrimitiveArray = vec![i64::MAX].into(); + let cast = Cast::new( + Arc::new(Column::new("a", 0)), + DataType::Date32, + EvalMode::Legacy, + "UTC".to_owned(), + ); + let result = cast.cast_array(Arc::new(timestamps.with_timezone("Europe/Copenhagen"))); + assert!(result.is_err()) + } + + #[test] + fn test_cast_invalid_timezone() { + let timestamps: PrimitiveArray = vec![i64::MAX].into(); + let cast = Cast::new( + Arc::new(Column::new("a", 0)), + DataType::Date32, + EvalMode::Legacy, + "Not a valid timezone".to_owned(), + ); + let result = cast.cast_array(Arc::new(timestamps.with_timezone("Europe/Copenhagen"))); + assert!(result.is_err()) + } +} diff --git a/src/lib.rs b/src/lib.rs index 57da56f9aca6..93c7f249eb2e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,6 +16,7 @@ // under the License. 
mod abs; +pub mod cast; mod error; mod if_expr; From fb7b1981bf528481fdb43606b24e8b829c457470 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 12 Jul 2024 16:13:38 -0600 Subject: [PATCH 05/68] remove utils crate and move utils into spark-expr crate (#658) --- Cargo.toml | 3 +- src/cast.rs | 2 +- src/if_expr.rs | 2 +- src/lib.rs | 3 + src/timezone.rs | 143 +++++++++++++++++++++++++++++++++++ src/utils.rs | 196 ++++++++++++++++++++++++++++++++++++++++++++++++ 6 files changed, 346 insertions(+), 3 deletions(-) create mode 100644 src/timezone.rs create mode 100644 src/utils.rs diff --git a/Cargo.toml b/Cargo.toml index 220417fe8b05..976a1f36f354 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,7 +36,8 @@ datafusion-common = { workspace = true } datafusion-functions = { workspace = true } datafusion-expr = { workspace = true } datafusion-physical-expr = { workspace = true } -datafusion-comet-utils = { workspace = true } +datafusion-physical-plan = { workspace = true } +chrono-tz = { workspace = true } num = { workspace = true } regex = { workspace = true } thiserror = { workspace = true } diff --git a/src/cast.rs b/src/cast.rs index b9cf2790b5d2..7f53583e8d76 100644 --- a/src/cast.rs +++ b/src/cast.rs @@ -55,7 +55,7 @@ use num::{ }; use regex::Regex; -use datafusion_comet_utils::{array_with_timezone, down_cast_any_ref}; +use crate::utils::{array_with_timezone, down_cast_any_ref}; use crate::{EvalMode, SparkError, SparkResult}; diff --git a/src/if_expr.rs b/src/if_expr.rs index c04494ec4ffb..fa52c5d5b9b9 100644 --- a/src/if_expr.rs +++ b/src/if_expr.rs @@ -31,7 +31,7 @@ use datafusion::logical_expr::ColumnarValue; use datafusion_common::{cast::as_boolean_array, Result}; use datafusion_physical_expr::PhysicalExpr; -use datafusion_comet_utils::down_cast_any_ref; +use crate::utils::down_cast_any_ref; #[derive(Debug, Hash)] pub struct IfExpr { diff --git a/src/lib.rs b/src/lib.rs index 93c7f249eb2e..3c726f52a8e8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,9 @@ pub mod cast; mod error; mod if_expr; +pub mod timezone; +pub mod utils; + pub use abs::Abs; pub use error::{SparkError, SparkResult}; pub use if_expr::IfExpr; diff --git a/src/timezone.rs b/src/timezone.rs new file mode 100644 index 000000000000..7aad386aa915 --- /dev/null +++ b/src/timezone.rs @@ -0,0 +1,143 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Utils for timezone. This is basically from arrow-array::timezone (private). 
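+///
+/// For example, `"+09:00".parse::<Tz>()` yields a fixed offset, while
+/// `"America/New_York".parse::<Tz>()` resolves a named IANA timezone via
+/// `chrono_tz`; see `impl FromStr for Tz` below.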
+use arrow_schema::ArrowError; +use chrono::{ + format::{parse, Parsed, StrftimeItems}, + offset::TimeZone, + FixedOffset, LocalResult, NaiveDate, NaiveDateTime, Offset, +}; +use std::str::FromStr; + +/// Parses a fixed offset of the form "+09:00" +fn parse_fixed_offset(tz: &str) -> Result { + let mut parsed = Parsed::new(); + + if let Ok(fixed_offset) = + parse(&mut parsed, tz, StrftimeItems::new("%:z")).and_then(|_| parsed.to_fixed_offset()) + { + return Ok(fixed_offset); + } + + if let Ok(fixed_offset) = + parse(&mut parsed, tz, StrftimeItems::new("%#z")).and_then(|_| parsed.to_fixed_offset()) + { + return Ok(fixed_offset); + } + + Err(ArrowError::ParseError(format!( + "Invalid timezone \"{}\": Expected format [+-]XX:XX, [+-]XX, or [+-]XXXX", + tz + ))) +} + +/// An [`Offset`] for [`Tz`] +#[derive(Debug, Copy, Clone)] +pub struct TzOffset { + tz: Tz, + offset: FixedOffset, +} + +impl std::fmt::Display for TzOffset { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.offset.fmt(f) + } +} + +impl Offset for TzOffset { + fn fix(&self) -> FixedOffset { + self.offset + } +} + +/// An Arrow [`TimeZone`] +#[derive(Debug, Copy, Clone)] +pub struct Tz(TzInner); + +#[derive(Debug, Copy, Clone)] +enum TzInner { + Timezone(chrono_tz::Tz), + Offset(FixedOffset), +} + +impl FromStr for Tz { + type Err = ArrowError; + + fn from_str(tz: &str) -> Result { + if tz.starts_with('+') || tz.starts_with('-') { + Ok(Self(TzInner::Offset(parse_fixed_offset(tz)?))) + } else { + Ok(Self(TzInner::Timezone(tz.parse().map_err(|e| { + ArrowError::ParseError(format!("Invalid timezone \"{}\": {}", tz, e)) + })?))) + } + } +} + +macro_rules! tz { + ($s:ident, $tz:ident, $b:block) => { + match $s.0 { + TzInner::Timezone($tz) => $b, + TzInner::Offset($tz) => $b, + } + }; +} + +impl TimeZone for Tz { + type Offset = TzOffset; + + fn from_offset(offset: &Self::Offset) -> Self { + offset.tz + } + + fn offset_from_local_date(&self, local: &NaiveDate) -> LocalResult { + tz!(self, tz, { + tz.offset_from_local_date(local).map(|x| TzOffset { + tz: *self, + offset: x.fix(), + }) + }) + } + + fn offset_from_local_datetime(&self, local: &NaiveDateTime) -> LocalResult { + tz!(self, tz, { + tz.offset_from_local_datetime(local).map(|x| TzOffset { + tz: *self, + offset: x.fix(), + }) + }) + } + + fn offset_from_utc_date(&self, utc: &NaiveDate) -> Self::Offset { + tz!(self, tz, { + TzOffset { + tz: *self, + offset: tz.offset_from_utc_date(utc).fix(), + } + }) + } + + fn offset_from_utc_datetime(&self, utc: &NaiveDateTime) -> Self::Offset { + tz!(self, tz, { + TzOffset { + tz: *self, + offset: tz.offset_from_utc_datetime(utc).fix(), + } + }) + } +} diff --git a/src/utils.rs b/src/utils.rs new file mode 100644 index 000000000000..6945e82b3e4f --- /dev/null +++ b/src/utils.rs @@ -0,0 +1,196 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::{ + cast::as_primitive_array, + types::{Int32Type, TimestampMicrosecondType}, +}; +use arrow_schema::{ArrowError, DataType}; +use std::any::Any; +use std::sync::Arc; + +use crate::timezone::Tz; +use arrow::{ + array::{as_dictionary_array, Array, ArrayRef, PrimitiveArray}, + temporal_conversions::as_datetime, +}; +use chrono::{DateTime, Offset, TimeZone}; + +use datafusion_physical_plan::PhysicalExpr; + +/// A utility function from DataFusion. It is not exposed by DataFusion. +pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any { + if any.is::>() { + any.downcast_ref::>() + .unwrap() + .as_any() + } else if any.is::>() { + any.downcast_ref::>() + .unwrap() + .as_any() + } else { + any + } +} + +/// Preprocesses input arrays to add timezone information from Spark to Arrow array datatype or +/// to apply timezone offset. +// +// We consider the following cases: +// +// | --------------------- | ------------ | ----------------- | -------------------------------- | +// | Conversion | Input array | Timezone | Output array | +// | --------------------- | ------------ | ----------------- | -------------------------------- | +// | Timestamp -> | Array in UTC | Timezone of input | A timestamp with the timezone | +// | Utf8 or Date32 | | | offset applied and timezone | +// | | | | removed | +// | --------------------- | ------------ | ----------------- | -------------------------------- | +// | Timestamp -> | Array in UTC | Timezone of input | Same as input array | +// | Timestamp w/Timezone| | | | +// | --------------------- | ------------ | ----------------- | -------------------------------- | +// | Timestamp_ntz -> | Array in | Timezone of input | Same as input array | +// | Utf8 or Date32 | timezone | | | +// | | session local| | | +// | | timezone | | | +// | --------------------- | ------------ | ----------------- | -------------------------------- | +// | Timestamp_ntz -> | Array in | Timezone of input | Array in UTC and timezone | +// | Timestamp w/Timezone | session local| | specified in input | +// | | timezone | | | +// | --------------------- | ------------ | ----------------- | -------------------------------- | +// | Timestamp(_ntz) -> | | +// | Any other type | Not Supported | +// | --------------------- | ------------ | ----------------- | -------------------------------- | +// +pub fn array_with_timezone( + array: ArrayRef, + timezone: String, + to_type: Option<&DataType>, +) -> Result { + match array.data_type() { + DataType::Timestamp(_, None) => { + assert!(!timezone.is_empty()); + match to_type { + Some(DataType::Utf8) | Some(DataType::Date32) => Ok(array), + Some(DataType::Timestamp(_, Some(_))) => { + timestamp_ntz_to_timestamp(array, timezone.as_str(), Some(timezone.as_str())) + } + _ => { + // Not supported + panic!( + "Cannot convert from {:?} to {:?}", + array.data_type(), + to_type.unwrap() + ) + } + } + } + DataType::Timestamp(_, Some(_)) => { + assert!(!timezone.is_empty()); + let array = as_primitive_array::(&array); + let array_with_timezone = array.clone().with_timezone(timezone.clone()); + let array = Arc::new(array_with_timezone) as ArrayRef; + match to_type { + Some(DataType::Utf8) | Some(DataType::Date32) => { + pre_timestamp_cast(array, timezone) + } + _ => Ok(array), + } + } + DataType::Dictionary(_, value_type) + if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) => + { + let dict = as_dictionary_array::(&array); + let array 
= as_primitive_array::(dict.values()); + let array_with_timezone = + array_with_timezone(Arc::new(array.clone()) as ArrayRef, timezone, to_type)?; + let dict = dict.with_values(array_with_timezone); + Ok(Arc::new(dict)) + } + _ => Ok(array), + } +} + +fn datetime_cast_err(value: i64) -> ArrowError { + ArrowError::CastError(format!( + "Cannot convert TimestampMicrosecondType {value} to datetime. Comet only supports dates between Jan 1, 262145 BCE and Dec 31, 262143 CE", + )) +} + +/// Takes in a Timestamp(Microsecond, None) array and a timezone id, and returns +/// a Timestamp(Microsecond, Some<_>) array. +/// The understanding is that the input array has time in the timezone specified in the second +/// argument. +/// Parameters: +/// array - input array of timestamp without timezone +/// tz - timezone of the values in the input array +/// to_timezone - timezone to change the input values to +fn timestamp_ntz_to_timestamp( + array: ArrayRef, + tz: &str, + to_timezone: Option<&str>, +) -> Result { + assert!(!tz.is_empty()); + match array.data_type() { + DataType::Timestamp(_, None) => { + let array = as_primitive_array::(&array); + let tz: Tz = tz.parse()?; + let array: PrimitiveArray = array.try_unary(|value| { + as_datetime::(value) + .ok_or_else(|| datetime_cast_err(value)) + .map(|local_datetime| { + let datetime: DateTime = + tz.from_local_datetime(&local_datetime).unwrap(); + datetime.timestamp_micros() + }) + })?; + let array_with_tz = if let Some(to_tz) = to_timezone { + array.with_timezone(to_tz) + } else { + array + }; + Ok(Arc::new(array_with_tz)) + } + _ => Ok(array), + } +} + +/// This takes for special pre-casting cases of Spark. E.g., Timestamp to String. +fn pre_timestamp_cast(array: ArrayRef, timezone: String) -> Result { + assert!(!timezone.is_empty()); + match array.data_type() { + DataType::Timestamp(_, _) => { + // Spark doesn't output timezone while casting timestamp to string, but arrow's cast + // kernel does if timezone exists. So we need to apply offset of timezone to array + // timestamp value and remove timezone from array datatype. + let array = as_primitive_array::(&array); + + let tz: Tz = timezone.parse()?; + let array: PrimitiveArray = array.try_unary(|value| { + as_datetime::(value) + .ok_or_else(|| datetime_cast_err(value)) + .map(|datetime| { + let offset = tz.offset_from_utc_datetime(&datetime).fix(); + let datetime = datetime + offset; + datetime.and_utc().timestamp_micros() + }) + })?; + + Ok(Arc::new(array)) + } + _ => Ok(array), + } +} From d510649c789526c500ed1973655a9da54f4bbdea Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Mon, 15 Jul 2024 11:47:44 -0600 Subject: [PATCH 06/68] chore: Move temporal kernels and expressions to spark-expr crate (#660) * Move temporal expressions to spark-expr crate * reduce public api * reduce public api * update imports in benchmarks * fmt * remove unused dep --- src/kernels/mod.rs | 20 + src/kernels/temporal.rs | 1148 +++++++++++++++++++++++++++++++++++++++ src/lib.rs | 6 +- src/temporal.rs | 534 ++++++++++++++++++ 4 files changed, 1707 insertions(+), 1 deletion(-) create mode 100644 src/kernels/mod.rs create mode 100644 src/kernels/temporal.rs create mode 100644 src/temporal.rs diff --git a/src/kernels/mod.rs b/src/kernels/mod.rs new file mode 100644 index 000000000000..88aa34b1a3f8 --- /dev/null +++ b/src/kernels/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Kernels + +pub(crate) mod temporal; diff --git a/src/kernels/temporal.rs b/src/kernels/temporal.rs new file mode 100644 index 000000000000..6f2474e8d7a8 --- /dev/null +++ b/src/kernels/temporal.rs @@ -0,0 +1,1148 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! temporal kernels + +use chrono::{DateTime, Datelike, Duration, NaiveDateTime, Timelike, Utc}; + +use std::sync::Arc; + +use arrow::{array::*, datatypes::DataType}; +use arrow_array::{ + downcast_dictionary_array, downcast_temporal_array, + temporal_conversions::*, + timezone::Tz, + types::{ArrowDictionaryKeyType, ArrowTemporalType, Date32Type, TimestampMicrosecondType}, + ArrowNumericType, +}; + +use arrow_schema::TimeUnit; + +use crate::SparkError; + +// Copied from arrow_arith/temporal.rs +macro_rules! 
return_compute_error_with { + ($msg:expr, $param:expr) => { + return { Err(SparkError::Internal(format!("{}: {:?}", $msg, $param))) } + }; +} + +// The number of days between the beginning of the proleptic gregorian calendar (0001-01-01) +// and the beginning of the Unix Epoch (1970-01-01) +const DAYS_TO_UNIX_EPOCH: i32 = 719_163; + +// Copied from arrow_arith/temporal.rs with modification to the output datatype +// Transforms a array of NaiveDate to an array of Date32 after applying an operation +fn as_datetime_with_op, T: ArrowTemporalType, F>( + iter: ArrayIter, + mut builder: PrimitiveBuilder, + op: F, +) -> Date32Array +where + F: Fn(NaiveDateTime) -> i32, + i64: From, +{ + iter.into_iter().for_each(|value| { + if let Some(value) = value { + match as_datetime::(i64::from(value)) { + Some(dt) => builder.append_value(op(dt)), + None => builder.append_null(), + } + } else { + builder.append_null(); + } + }); + + builder.finish() +} + +#[inline] +fn as_datetime_with_op_single( + value: Option, + builder: &mut PrimitiveBuilder, + op: F, +) where + F: Fn(NaiveDateTime) -> i32, +{ + if let Some(value) = value { + match as_datetime::(i64::from(value)) { + Some(dt) => builder.append_value(op(dt)), + None => builder.append_null(), + } + } else { + builder.append_null(); + } +} + +// Based on arrow_arith/temporal.rs:extract_component_from_datetime_array +// Transforms an array of DateTime to an arrayOf TimeStampMicrosecond after applying an +// operation +fn as_timestamp_tz_with_op, T: ArrowTemporalType, F>( + iter: ArrayIter, + mut builder: PrimitiveBuilder, + tz: &str, + op: F, +) -> Result +where + F: Fn(DateTime) -> i64, + i64: From, +{ + let tz: Tz = tz.parse()?; + for value in iter { + match value { + Some(value) => match as_datetime_with_timezone::(value.into(), tz) { + Some(time) => builder.append_value(op(time)), + _ => { + return Err(SparkError::Internal( + "Unable to read value as datetime".to_string(), + )); + } + }, + None => builder.append_null(), + } + } + Ok(builder.finish()) +} + +fn as_timestamp_tz_with_op_single( + value: Option, + builder: &mut PrimitiveBuilder, + tz: &Tz, + op: F, +) -> Result<(), SparkError> +where + F: Fn(DateTime) -> i64, + i64: From, +{ + match value { + Some(value) => match as_datetime_with_timezone::(value.into(), *tz) { + Some(time) => builder.append_value(op(time)), + _ => { + return Err(SparkError::Internal( + "Unable to read value as datetime".to_string(), + )); + } + }, + None => builder.append_null(), + } + Ok(()) +} + +#[inline] +fn as_days_from_unix_epoch(dt: Option) -> i32 { + dt.unwrap().num_days_from_ce() - DAYS_TO_UNIX_EPOCH +} + +// Apply the Tz to the Naive Date Time,,convert to UTC, and return as microseconds in Unix epoch +#[inline] +fn as_micros_from_unix_epoch_utc(dt: Option>) -> i64 { + dt.unwrap().with_timezone(&Utc).timestamp_micros() +} + +#[inline] +fn trunc_date_to_year(dt: T) -> Option { + Some(dt) + .and_then(|d| d.with_nanosecond(0)) + .and_then(|d| d.with_second(0)) + .and_then(|d| d.with_minute(0)) + .and_then(|d| d.with_hour(0)) + .and_then(|d| d.with_day0(0)) + .and_then(|d| d.with_month0(0)) +} + +/// returns the month of the beginning of the quarter +#[inline] +fn quarter_month(dt: &T) -> u32 { + 1 + 3 * ((dt.month() - 1) / 3) +} + +#[inline] +fn trunc_date_to_quarter(dt: T) -> Option { + Some(dt) + .and_then(|d| d.with_nanosecond(0)) + .and_then(|d| d.with_second(0)) + .and_then(|d| d.with_minute(0)) + .and_then(|d| d.with_hour(0)) + .and_then(|d| d.with_day0(0)) + .and_then(|d| d.with_month(quarter_month(&d))) +} + 
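+// For example (illustrative): truncating 2020-08-15T10:11:12 to the month below
+// yields 2020-08-01T00:00:00, while truncating it to the quarter above yields
+// 2020-07-01T00:00:00 because quarter_month maps August to July.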
+#[inline] +fn trunc_date_to_month(dt: T) -> Option { + Some(dt) + .and_then(|d| d.with_nanosecond(0)) + .and_then(|d| d.with_second(0)) + .and_then(|d| d.with_minute(0)) + .and_then(|d| d.with_hour(0)) + .and_then(|d| d.with_day0(0)) +} + +#[inline] +fn trunc_date_to_week(dt: T) -> Option +where + T: Datelike + Timelike + std::ops::Sub + Copy, +{ + Some(dt) + .map(|d| d - Duration::try_seconds(60 * 60 * 24 * d.weekday() as i64).unwrap()) + .and_then(|d| d.with_nanosecond(0)) + .and_then(|d| d.with_second(0)) + .and_then(|d| d.with_minute(0)) + .and_then(|d| d.with_hour(0)) +} + +#[inline] +fn trunc_date_to_day(dt: T) -> Option { + Some(dt) + .and_then(|d| d.with_nanosecond(0)) + .and_then(|d| d.with_second(0)) + .and_then(|d| d.with_minute(0)) + .and_then(|d| d.with_hour(0)) +} + +#[inline] +fn trunc_date_to_hour(dt: T) -> Option { + Some(dt) + .and_then(|d| d.with_nanosecond(0)) + .and_then(|d| d.with_second(0)) + .and_then(|d| d.with_minute(0)) +} + +#[inline] +fn trunc_date_to_minute(dt: T) -> Option { + Some(dt) + .and_then(|d| d.with_nanosecond(0)) + .and_then(|d| d.with_second(0)) +} + +#[inline] +fn trunc_date_to_second(dt: T) -> Option { + Some(dt).and_then(|d| d.with_nanosecond(0)) +} + +#[inline] +fn trunc_date_to_ms(dt: T) -> Option { + Some(dt).and_then(|d| d.with_nanosecond(1_000_000 * (d.nanosecond() / 1_000_000))) +} + +#[inline] +fn trunc_date_to_microsec(dt: T) -> Option { + Some(dt).and_then(|d| d.with_nanosecond(1_000 * (d.nanosecond() / 1_000))) +} + +/// +/// Implements the spark [TRUNC](https://spark.apache.org/docs/latest/api/sql/index.html#trunc) +/// function where the specified format is a scalar value +/// +/// array is an array of Date32 values. The array may be a dictionary array. +/// +/// format is a scalar string specifying the format to apply to the timestamp value. 
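+///
+/// For example (illustrative): `date_trunc_dyn(&dates, "MONTH".to_string())`
+/// returns a `Date32` array with every value moved back to the first day of its
+/// month. Supported formats are YEAR/YYYY/YY, QUARTER, MONTH/MON/MM and WEEK.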
+pub(crate) fn date_trunc_dyn(array: &dyn Array, format: String) -> Result { + match array.data_type().clone() { + DataType::Dictionary(_, _) => { + downcast_dictionary_array!( + array => { + let truncated_values = date_trunc_dyn(array.values(), format)?; + Ok(Arc::new(array.with_values(truncated_values))) + } + dt => return_compute_error_with!("date_trunc does not support", dt), + ) + } + _ => { + downcast_temporal_array!( + array => { + date_trunc(array, format) + .map(|a| Arc::new(a) as ArrayRef) + } + dt => return_compute_error_with!("date_trunc does not support", dt), + ) + } + } +} + +pub(crate) fn date_trunc( + array: &PrimitiveArray, + format: String, +) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + let builder = Date32Builder::with_capacity(array.len()); + let iter = ArrayIter::new(array); + match array.data_type() { + DataType::Date32 => match format.to_uppercase().as_str() { + "YEAR" | "YYYY" | "YY" => Ok(as_datetime_with_op::<&PrimitiveArray, T, _>( + iter, + builder, + |dt| as_days_from_unix_epoch(trunc_date_to_year(dt)), + )), + "QUARTER" => Ok(as_datetime_with_op::<&PrimitiveArray, T, _>( + iter, + builder, + |dt| as_days_from_unix_epoch(trunc_date_to_quarter(dt)), + )), + "MONTH" | "MON" | "MM" => Ok(as_datetime_with_op::<&PrimitiveArray, T, _>( + iter, + builder, + |dt| as_days_from_unix_epoch(trunc_date_to_month(dt)), + )), + "WEEK" => Ok(as_datetime_with_op::<&PrimitiveArray, T, _>( + iter, + builder, + |dt| as_days_from_unix_epoch(trunc_date_to_week(dt)), + )), + _ => Err(SparkError::Internal(format!( + "Unsupported format: {:?} for function 'date_trunc'", + format + ))), + }, + dt => return_compute_error_with!( + "Unsupported input type '{:?}' for function 'date_trunc'", + dt + ), + } +} + +/// +/// Implements the spark [TRUNC](https://spark.apache.org/docs/latest/api/sql/index.html#trunc) +/// function where the specified format may be an array +/// +/// array is an array of Date32 values. The array may be a dictionary array. +/// +/// format is an array of strings specifying the format to apply to the corresponding date value. +/// The array may be a dictionary array. 
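+///
+/// For example (illustrative): given dates `[2020-08-15, 2020-08-15]` and
+/// formats `["YEAR", "MONTH"]`, the result is `[2020-01-01, 2020-08-01]`;
+/// each date is truncated using the format at the same index.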
+pub(crate) fn date_trunc_array_fmt_dyn( + array: &dyn Array, + formats: &dyn Array, +) -> Result { + match (array.data_type().clone(), formats.data_type().clone()) { + (DataType::Dictionary(_, v), DataType::Dictionary(_, f)) => { + if !matches!(*v, DataType::Date32) { + return_compute_error_with!("date_trunc does not support", v) + } + if !matches!(*f, DataType::Utf8) { + return_compute_error_with!("date_trunc does not support format type ", f) + } + downcast_dictionary_array!( + formats => { + downcast_dictionary_array!( + array => { + date_trunc_array_fmt_dict_dict( + &array.downcast_dict::().unwrap(), + &formats.downcast_dict::().unwrap()) + .map(|a| Arc::new(a) as ArrayRef) + } + dt => return_compute_error_with!("date_trunc does not support", dt) + ) + } + fmt => return_compute_error_with!("date_trunc does not support format type", fmt), + ) + } + (DataType::Dictionary(_, v), DataType::Utf8) => { + if !matches!(*v, DataType::Date32) { + return_compute_error_with!("date_trunc does not support", v) + } + downcast_dictionary_array!( + array => { + date_trunc_array_fmt_dict_plain( + &array.downcast_dict::().unwrap(), + formats.as_any().downcast_ref::() + .expect("Unexpected value type in formats")) + .map(|a| Arc::new(a) as ArrayRef) + } + dt => return_compute_error_with!("date_trunc does not support", dt), + ) + } + (DataType::Date32, DataType::Dictionary(_, f)) => { + if !matches!(*f, DataType::Utf8) { + return_compute_error_with!("date_trunc does not support format type ", f) + } + downcast_dictionary_array!( + formats => { + downcast_temporal_array!(array => { + date_trunc_array_fmt_plain_dict( + array.as_any().downcast_ref::() + .expect("Unexpected error in casting date array"), + &formats.downcast_dict::().unwrap()) + .map(|a| Arc::new(a) as ArrayRef) + } + dt => return_compute_error_with!("date_trunc does not support", dt), + ) + } + fmt => return_compute_error_with!("date_trunc does not support format type", fmt), + ) + } + (DataType::Date32, DataType::Utf8) => date_trunc_array_fmt_plain_plain( + array + .as_any() + .downcast_ref::() + .expect("Unexpected error in casting date array"), + formats + .as_any() + .downcast_ref::() + .expect("Unexpected value type in formats"), + ) + .map(|a| Arc::new(a) as ArrayRef), + (dt, fmt) => Err(SparkError::Internal(format!( + "Unsupported datatype: {:}, format: {:?} for function 'date_trunc'", + dt, fmt + ))), + } +} + +macro_rules! date_trunc_array_fmt_helper { + ($array: ident, $formats: ident, $datatype: ident) => {{ + let mut builder = Date32Builder::with_capacity($array.len()); + let iter = $array.into_iter(); + match $datatype { + DataType::Date32 => { + for (index, val) in iter.enumerate() { + let op_result = match $formats.value(index).to_uppercase().as_str() { + "YEAR" | "YYYY" | "YY" => { + Ok(as_datetime_with_op_single(val, &mut builder, |dt| { + as_days_from_unix_epoch(trunc_date_to_year(dt)) + })) + } + "QUARTER" => Ok(as_datetime_with_op_single(val, &mut builder, |dt| { + as_days_from_unix_epoch(trunc_date_to_quarter(dt)) + })), + "MONTH" | "MON" | "MM" => { + Ok(as_datetime_with_op_single(val, &mut builder, |dt| { + as_days_from_unix_epoch(trunc_date_to_month(dt)) + })) + } + "WEEK" => Ok(as_datetime_with_op_single(val, &mut builder, |dt| { + as_days_from_unix_epoch(trunc_date_to_week(dt)) + })), + _ => Err(SparkError::Internal(format!( + "Unsupported format: {:?} for function 'date_trunc'", + $formats.value(index) + ))), + }; + op_result? 
+ } + Ok(builder.finish()) + } + dt => return_compute_error_with!( + "Unsupported input type '{:?}' for function 'date_trunc'", + dt + ), + } + }}; +} + +fn date_trunc_array_fmt_plain_plain( + array: &Date32Array, + formats: &StringArray, +) -> Result +where +{ + let data_type = array.data_type(); + date_trunc_array_fmt_helper!(array, formats, data_type) +} + +fn date_trunc_array_fmt_plain_dict( + array: &Date32Array, + formats: &TypedDictionaryArray, +) -> Result +where + K: ArrowDictionaryKeyType, +{ + let data_type = array.data_type(); + date_trunc_array_fmt_helper!(array, formats, data_type) +} + +fn date_trunc_array_fmt_dict_plain( + array: &TypedDictionaryArray, + formats: &StringArray, +) -> Result +where + K: ArrowDictionaryKeyType, +{ + let data_type = array.values().data_type(); + date_trunc_array_fmt_helper!(array, formats, data_type) +} + +fn date_trunc_array_fmt_dict_dict( + array: &TypedDictionaryArray, + formats: &TypedDictionaryArray, +) -> Result +where + K: ArrowDictionaryKeyType, + F: ArrowDictionaryKeyType, +{ + let data_type = array.values().data_type(); + date_trunc_array_fmt_helper!(array, formats, data_type) +} + +/// +/// Implements the spark [DATE_TRUNC](https://spark.apache.org/docs/latest/api/sql/index.html#date_trunc) +/// function where the specified format is a scalar value +/// +/// array is an array of Timestamp(Microsecond) values. Timestamp values must have a valid +/// timezone or no timezone. The array may be a dictionary array. +/// +/// format is a scalar string specifying the format to apply to the timestamp value. +pub(crate) fn timestamp_trunc_dyn( + array: &dyn Array, + format: String, +) -> Result { + match array.data_type().clone() { + DataType::Dictionary(_, _) => { + downcast_dictionary_array!( + array => { + let truncated_values = timestamp_trunc_dyn(array.values(), format)?; + Ok(Arc::new(array.with_values(truncated_values))) + } + dt => return_compute_error_with!("timestamp_trunc does not support", dt), + ) + } + _ => { + downcast_temporal_array!( + array => { + timestamp_trunc(array, format) + .map(|a| Arc::new(a) as ArrayRef) + } + dt => return_compute_error_with!("timestamp_trunc does not support", dt), + ) + } + } +} + +pub(crate) fn timestamp_trunc( + array: &PrimitiveArray, + format: String, +) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + let builder = TimestampMicrosecondBuilder::with_capacity(array.len()); + let iter = ArrayIter::new(array); + match array.data_type() { + DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => { + match format.to_uppercase().as_str() { + "YEAR" | "YYYY" | "YY" => { + as_timestamp_tz_with_op::<&PrimitiveArray, T, _>(iter, builder, tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_year(dt)) + }) + } + "QUARTER" => { + as_timestamp_tz_with_op::<&PrimitiveArray, T, _>(iter, builder, tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_quarter(dt)) + }) + } + "MONTH" | "MON" | "MM" => { + as_timestamp_tz_with_op::<&PrimitiveArray, T, _>(iter, builder, tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_month(dt)) + }) + } + "WEEK" => { + as_timestamp_tz_with_op::<&PrimitiveArray, T, _>(iter, builder, tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_week(dt)) + }) + } + "DAY" | "DD" => { + as_timestamp_tz_with_op::<&PrimitiveArray, T, _>(iter, builder, tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_day(dt)) + }) + } + "HOUR" => { + as_timestamp_tz_with_op::<&PrimitiveArray, T, _>(iter, builder, tz, |dt| { + 
as_micros_from_unix_epoch_utc(trunc_date_to_hour(dt)) + }) + } + "MINUTE" => { + as_timestamp_tz_with_op::<&PrimitiveArray, T, _>(iter, builder, tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_minute(dt)) + }) + } + "SECOND" => { + as_timestamp_tz_with_op::<&PrimitiveArray, T, _>(iter, builder, tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_second(dt)) + }) + } + "MILLISECOND" => { + as_timestamp_tz_with_op::<&PrimitiveArray, T, _>(iter, builder, tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_ms(dt)) + }) + } + "MICROSECOND" => { + as_timestamp_tz_with_op::<&PrimitiveArray, T, _>(iter, builder, tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_microsec(dt)) + }) + } + _ => Err(SparkError::Internal(format!( + "Unsupported format: {:?} for function 'timestamp_trunc'", + format + ))), + } + } + dt => return_compute_error_with!( + "Unsupported input type '{:?}' for function 'timestamp_trunc'", + dt + ), + } +} + +/// +/// Implements the spark [DATE_TRUNC](https://spark.apache.org/docs/latest/api/sql/index.html#date_trunc) +/// function where the specified format may be an array +/// +/// array is an array of Timestamp(Microsecond) values. Timestamp values must have a valid +/// timezone or no timezone. The array may be a dictionary array. +/// +/// format is an array of strings specifying the format to apply to the corresponding timestamp +/// value. The array may be a dictionary array. +pub(crate) fn timestamp_trunc_array_fmt_dyn( + array: &dyn Array, + formats: &dyn Array, +) -> Result { + match (array.data_type().clone(), formats.data_type().clone()) { + (DataType::Dictionary(_, _), DataType::Dictionary(_, _)) => { + downcast_dictionary_array!( + formats => { + downcast_dictionary_array!( + array => { + timestamp_trunc_array_fmt_dict_dict( + &array.downcast_dict::().unwrap(), + &formats.downcast_dict::().unwrap()) + .map(|a| Arc::new(a) as ArrayRef) + } + dt => return_compute_error_with!("timestamp_trunc does not support", dt) + ) + } + fmt => return_compute_error_with!("timestamp_trunc does not support format type", fmt), + ) + } + (DataType::Dictionary(_, _), DataType::Utf8) => { + downcast_dictionary_array!( + array => { + timestamp_trunc_array_fmt_dict_plain( + &array.downcast_dict::>().unwrap(), + formats.as_any().downcast_ref::() + .expect("Unexpected value type in formats")) + .map(|a| Arc::new(a) as ArrayRef) + } + dt => return_compute_error_with!("timestamp_trunc does not support", dt), + ) + } + (DataType::Timestamp(TimeUnit::Microsecond, _), DataType::Dictionary(_, _)) => { + downcast_dictionary_array!( + formats => { + downcast_temporal_array!(array => { + timestamp_trunc_array_fmt_plain_dict( + array, + &formats.downcast_dict::().unwrap()) + .map(|a| Arc::new(a) as ArrayRef) + } + dt => return_compute_error_with!("timestamp_trunc does not support", dt), + ) + } + fmt => return_compute_error_with!("timestamp_trunc does not support format type", fmt), + ) + } + (DataType::Timestamp(TimeUnit::Microsecond, _), DataType::Utf8) => { + downcast_temporal_array!( + array => { + timestamp_trunc_array_fmt_plain_plain(array, + formats.as_any().downcast_ref::().expect("Unexpected value type in formats")) + .map(|a| Arc::new(a) as ArrayRef) + }, + dt => return_compute_error_with!("timestamp_trunc does not support", dt), + ) + } + (dt, fmt) => Err(SparkError::Internal(format!( + "Unsupported datatype: {:}, format: {:?} for function 'timestamp_trunc'", + dt, fmt + ))), + } +} + +macro_rules! 
timestamp_trunc_array_fmt_helper { + ($array: ident, $formats: ident, $datatype: ident) => {{ + let mut builder = TimestampMicrosecondBuilder::with_capacity($array.len()); + let iter = $array.into_iter(); + assert_eq!( + $array.len(), + $formats.len(), + "lengths of values array and format array must be the same" + ); + match $datatype { + DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => { + let tz: Tz = tz.parse()?; + for (index, val) in iter.enumerate() { + let op_result = match $formats.value(index).to_uppercase().as_str() { + "YEAR" | "YYYY" | "YY" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_year(dt)) + }) + } + "QUARTER" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_quarter(dt)) + }) + } + "MONTH" | "MON" | "MM" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_month(dt)) + }) + } + "WEEK" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_week(dt)) + }) + } + "DAY" | "DD" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_day(dt)) + }) + } + "HOUR" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_hour(dt)) + }) + } + "MINUTE" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_minute(dt)) + }) + } + "SECOND" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_second(dt)) + }) + } + "MILLISECOND" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_ms(dt)) + }) + } + "MICROSECOND" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_microsec(dt)) + }) + } + _ => Err(SparkError::Internal(format!( + "Unsupported format: {:?} for function 'timestamp_trunc'", + $formats.value(index) + ))), + }; + op_result? 
+ } + Ok(builder.finish()) + } + dt => { + return_compute_error_with!( + "Unsupported input type '{:?}' for function 'timestamp_trunc'", + dt + ) + } + } + }}; +} + +fn timestamp_trunc_array_fmt_plain_plain( + array: &PrimitiveArray, + formats: &StringArray, +) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + let data_type = array.data_type(); + timestamp_trunc_array_fmt_helper!(array, formats, data_type) +} +fn timestamp_trunc_array_fmt_plain_dict( + array: &PrimitiveArray, + formats: &TypedDictionaryArray, +) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, + K: ArrowDictionaryKeyType, +{ + let data_type = array.data_type(); + timestamp_trunc_array_fmt_helper!(array, formats, data_type) +} + +fn timestamp_trunc_array_fmt_dict_plain( + array: &TypedDictionaryArray>, + formats: &StringArray, +) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, + K: ArrowDictionaryKeyType, +{ + let data_type = array.values().data_type(); + timestamp_trunc_array_fmt_helper!(array, formats, data_type) +} + +fn timestamp_trunc_array_fmt_dict_dict( + array: &TypedDictionaryArray>, + formats: &TypedDictionaryArray, +) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, + K: ArrowDictionaryKeyType, + F: ArrowDictionaryKeyType, +{ + let data_type = array.values().data_type(); + timestamp_trunc_array_fmt_helper!(array, formats, data_type) +} + +#[cfg(test)] +mod tests { + use crate::kernels::temporal::{ + date_trunc, date_trunc_array_fmt_dyn, timestamp_trunc, timestamp_trunc_array_fmt_dyn, + }; + use arrow_array::{ + builder::{PrimitiveDictionaryBuilder, StringDictionaryBuilder}, + iterator::ArrayIter, + types::{Date32Type, Int32Type, TimestampMicrosecondType}, + Array, Date32Array, PrimitiveArray, StringArray, TimestampMicrosecondArray, + }; + use std::sync::Arc; + + #[test] + #[cfg_attr(miri, ignore)] // test takes too long with miri + fn test_date_trunc() { + let size = 1000; + let mut vec: Vec = Vec::with_capacity(size); + for i in 0..size { + vec.push(i as i32); + } + let array = Date32Array::from(vec); + for fmt in [ + "YEAR", "YYYY", "YY", "QUARTER", "MONTH", "MON", "MM", "WEEK", + ] { + match date_trunc(&array, fmt.to_string()) { + Ok(a) => { + for i in 0..size { + assert!(array.values().get(i) >= a.values().get(i)) + } + } + _ => assert!(false), + } + } + } + + #[test] + // This test only verifies that the various input array types work. 
Actually correctness to + // ensure this produces the same results as spark is verified in the JVM tests + fn test_date_trunc_array_fmt_dyn() { + let size = 10; + let formats = [ + "YEAR", "YYYY", "YY", "QUARTER", "MONTH", "MON", "MM", "WEEK", + ]; + let mut vec: Vec = Vec::with_capacity(size * formats.len()); + let mut fmt_vec: Vec<&str> = Vec::with_capacity(size * formats.len()); + for i in 0..size { + for j in 0..formats.len() { + vec.push(i as i32 * 1_000_001); + fmt_vec.push(formats[j]); + } + } + + // timestamp array + let array = Date32Array::from(vec); + + // formats array + let fmt_array = StringArray::from(fmt_vec); + + // timestamp dictionary array + let mut date_dict_builder = PrimitiveDictionaryBuilder::::new(); + for v in array.iter() { + date_dict_builder + .append(v.unwrap()) + .expect("Error in building timestamp array"); + } + let mut array_dict = date_dict_builder.finish(); + // apply timezone + array_dict = array_dict.with_values(Arc::new( + array_dict + .values() + .as_any() + .downcast_ref::() + .unwrap() + .clone(), + )); + + // formats dictionary array + let mut formats_dict_builder = StringDictionaryBuilder::::new(); + for v in fmt_array.iter() { + formats_dict_builder + .append(v.unwrap()) + .expect("Error in building formats array"); + } + let fmt_dict = formats_dict_builder.finish(); + + // verify input arrays + let iter = ArrayIter::new(&array); + let mut dict_iter = array_dict + .downcast_dict::>() + .unwrap() + .into_iter(); + for val in iter { + assert_eq!( + dict_iter + .next() + .expect("array and dictionary array do not match"), + val + ) + } + + // verify input format arrays + let fmt_iter = ArrayIter::new(&fmt_array); + let mut fmt_dict_iter = fmt_dict.downcast_dict::().unwrap().into_iter(); + for val in fmt_iter { + assert_eq!( + fmt_dict_iter + .next() + .expect("formats and dictionary formats do not match"), + val + ) + } + + // test cases + if let Ok(a) = date_trunc_array_fmt_dyn(&array, &fmt_array) { + for i in 0..array.len() { + assert!( + array.value(i) >= a.as_any().downcast_ref::().unwrap().value(i) + ) + } + } else { + assert!(false) + } + if let Ok(a) = date_trunc_array_fmt_dyn(&array_dict, &fmt_array) { + for i in 0..array.len() { + assert!( + array.value(i) >= a.as_any().downcast_ref::().unwrap().value(i) + ) + } + } else { + assert!(false) + } + if let Ok(a) = date_trunc_array_fmt_dyn(&array, &fmt_dict) { + for i in 0..array.len() { + assert!( + array.value(i) >= a.as_any().downcast_ref::().unwrap().value(i) + ) + } + } else { + assert!(false) + } + if let Ok(a) = date_trunc_array_fmt_dyn(&array_dict, &fmt_dict) { + for i in 0..array.len() { + assert!( + array.value(i) >= a.as_any().downcast_ref::().unwrap().value(i) + ) + } + } else { + assert!(false) + } + } + + #[test] + #[cfg_attr(miri, ignore)] // test takes too long with miri + fn test_timestamp_trunc() { + let size = 1000; + let mut vec: Vec = Vec::with_capacity(size); + for i in 0..size { + vec.push(i as i64); + } + let array = TimestampMicrosecondArray::from(vec).with_timezone_utc(); + for fmt in [ + "YEAR", + "YYYY", + "YY", + "QUARTER", + "MONTH", + "MON", + "MM", + "WEEK", + "DAY", + "DD", + "HOUR", + "MINUTE", + "SECOND", + "MILLISECOND", + "MICROSECOND", + ] { + match timestamp_trunc(&array, fmt.to_string()) { + Ok(a) => { + for i in 0..size { + assert!(array.values().get(i) >= a.values().get(i)) + } + } + _ => assert!(false), + } + } + } + + #[test] + // test takes too long with miri + #[cfg_attr(miri, ignore)] + // This test only verifies that the various input array types 
work. Actually correctness to + // ensure this produces the same results as spark is verified in the JVM tests + fn test_timestamp_trunc_array_fmt_dyn() { + let size = 10; + let formats = [ + "YEAR", + "YYYY", + "YY", + "QUARTER", + "MONTH", + "MON", + "MM", + "WEEK", + "DAY", + "DD", + "HOUR", + "MINUTE", + "SECOND", + "MILLISECOND", + "MICROSECOND", + ]; + let mut vec: Vec = Vec::with_capacity(size * formats.len()); + let mut fmt_vec: Vec<&str> = Vec::with_capacity(size * formats.len()); + for i in 0..size { + for j in 0..formats.len() { + vec.push(i as i64 * 1_000_000_001); + fmt_vec.push(formats[j]); + } + } + + // timestamp array + let array = TimestampMicrosecondArray::from(vec).with_timezone_utc(); + + // formats array + let fmt_array = StringArray::from(fmt_vec); + + // timestamp dictionary array + let mut timestamp_dict_builder = + PrimitiveDictionaryBuilder::::new(); + for v in array.iter() { + timestamp_dict_builder + .append(v.unwrap()) + .expect("Error in building timestamp array"); + } + let mut array_dict = timestamp_dict_builder.finish(); + // apply timezone + array_dict = array_dict.with_values(Arc::new( + array_dict + .values() + .as_any() + .downcast_ref::() + .unwrap() + .clone() + .with_timezone_utc(), + )); + + // formats dictionary array + let mut formats_dict_builder = StringDictionaryBuilder::::new(); + for v in fmt_array.iter() { + formats_dict_builder + .append(v.unwrap()) + .expect("Error in building formats array"); + } + let fmt_dict = formats_dict_builder.finish(); + + // verify input arrays + let iter = ArrayIter::new(&array); + let mut dict_iter = array_dict + .downcast_dict::>() + .unwrap() + .into_iter(); + for val in iter { + assert_eq!( + dict_iter + .next() + .expect("array and dictionary array do not match"), + val + ) + } + + // verify input format arrays + let fmt_iter = ArrayIter::new(&fmt_array); + let mut fmt_dict_iter = fmt_dict.downcast_dict::().unwrap().into_iter(); + for val in fmt_iter { + assert_eq!( + fmt_dict_iter + .next() + .expect("formats and dictionary formats do not match"), + val + ) + } + + // test cases + if let Ok(a) = timestamp_trunc_array_fmt_dyn(&array, &fmt_array) { + for i in 0..array.len() { + assert!( + array.value(i) + >= a.as_any() + .downcast_ref::() + .unwrap() + .value(i) + ) + } + } else { + assert!(false) + } + if let Ok(a) = timestamp_trunc_array_fmt_dyn(&array_dict, &fmt_array) { + for i in 0..array.len() { + assert!( + array.value(i) + >= a.as_any() + .downcast_ref::() + .unwrap() + .value(i) + ) + } + } else { + assert!(false) + } + if let Ok(a) = timestamp_trunc_array_fmt_dyn(&array, &fmt_dict) { + for i in 0..array.len() { + assert!( + array.value(i) + >= a.as_any() + .downcast_ref::() + .unwrap() + .value(i) + ) + } + } else { + assert!(false) + } + if let Ok(a) = timestamp_trunc_array_fmt_dyn(&array_dict, &fmt_dict) { + for i in 0..array.len() { + assert!( + array.value(i) + >= a.as_any() + .downcast_ref::() + .unwrap() + .value(i) + ) + } + } else { + assert!(false) + } + } +} diff --git a/src/lib.rs b/src/lib.rs index 3c726f52a8e8..5168e0e80747 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -16,16 +16,20 @@ // under the License. 
mod abs; -pub mod cast; +mod cast; mod error; mod if_expr; +mod kernels; +mod temporal; pub mod timezone; pub mod utils; pub use abs::Abs; +pub use cast::Cast; pub use error::{SparkError, SparkResult}; pub use if_expr::IfExpr; +pub use temporal::{DateTruncExec, HourExec, MinuteExec, SecondExec, TimestampTruncExec}; /// Spark supports three evaluation modes when evaluating expressions, which affect /// the behavior when processing input values that are invalid or would result in an diff --git a/src/temporal.rs b/src/temporal.rs new file mode 100644 index 000000000000..ea30d3383dd5 --- /dev/null +++ b/src/temporal.rs @@ -0,0 +1,534 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{ + any::Any, + fmt::{Debug, Display, Formatter}, + hash::{Hash, Hasher}, + sync::Arc, +}; + +use arrow::{ + compute::{date_part, DatePart}, + record_batch::RecordBatch, +}; +use arrow_schema::{DataType, Schema, TimeUnit::Microsecond}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{DataFusionError, ScalarValue::Utf8}; +use datafusion_physical_expr::PhysicalExpr; + +use crate::utils::{array_with_timezone, down_cast_any_ref}; + +use crate::kernels::temporal::{ + date_trunc_array_fmt_dyn, date_trunc_dyn, timestamp_trunc_array_fmt_dyn, timestamp_trunc_dyn, +}; + +#[derive(Debug, Hash)] +pub struct HourExec { + /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) + child: Arc, + timezone: String, +} + +impl HourExec { + pub fn new(child: Arc, timezone: String) -> Self { + HourExec { child, timezone } + } +} + +impl Display for HourExec { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Hour [timezone:{}, child: {}]", + self.timezone, self.child + ) + } +} + +impl PartialEq for HourExec { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| self.child.eq(&x.child) && self.timezone.eq(&x.timezone)) + .unwrap_or(false) + } +} + +impl PhysicalExpr for HourExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + match self.child.data_type(input_schema).unwrap() { + DataType::Dictionary(key_type, _) => { + Ok(DataType::Dictionary(key_type, Box::new(DataType::Int32))) + } + _ => Ok(DataType::Int32), + } + } + + fn nullable(&self, _: &Schema) -> datafusion_common::Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + let arg = self.child.evaluate(batch)?; + match arg { + ColumnarValue::Array(array) => { + let array = array_with_timezone( + array, + self.timezone.clone(), + Some(&DataType::Timestamp( + Microsecond, + Some(self.timezone.clone().into()), + )), + )?; + let result = date_part(&array, DatePart::Hour)?; + + 
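// In other words, the hour is extracted only after array_with_timezone has applied the
// session timezone to the UTC-backed microsecond values, so the Int32 result returned
// below reflects local time rather than UTC.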
Ok(ColumnarValue::Array(result)) + } + _ => Err(DataFusionError::Execution( + "Hour(scalar) should be fold in Spark JVM side.".to_string(), + )), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result, DataFusionError> { + Ok(Arc::new(HourExec::new( + children[0].clone(), + self.timezone.clone(), + ))) + } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.child.hash(&mut s); + self.timezone.hash(&mut s); + self.hash(&mut s); + } +} + +#[derive(Debug, Hash)] +pub struct MinuteExec { + /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) + child: Arc, + timezone: String, +} + +impl MinuteExec { + pub fn new(child: Arc, timezone: String) -> Self { + MinuteExec { child, timezone } + } +} + +impl Display for MinuteExec { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Minute [timezone:{}, child: {}]", + self.timezone, self.child + ) + } +} + +impl PartialEq for MinuteExec { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| self.child.eq(&x.child) && self.timezone.eq(&x.timezone)) + .unwrap_or(false) + } +} + +impl PhysicalExpr for MinuteExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + match self.child.data_type(input_schema).unwrap() { + DataType::Dictionary(key_type, _) => { + Ok(DataType::Dictionary(key_type, Box::new(DataType::Int32))) + } + _ => Ok(DataType::Int32), + } + } + + fn nullable(&self, _: &Schema) -> datafusion_common::Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + let arg = self.child.evaluate(batch)?; + match arg { + ColumnarValue::Array(array) => { + let array = array_with_timezone( + array, + self.timezone.clone(), + Some(&DataType::Timestamp( + Microsecond, + Some(self.timezone.clone().into()), + )), + )?; + let result = date_part(&array, DatePart::Minute)?; + + Ok(ColumnarValue::Array(result)) + } + _ => Err(DataFusionError::Execution( + "Minute(scalar) should be fold in Spark JVM side.".to_string(), + )), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result, DataFusionError> { + Ok(Arc::new(MinuteExec::new( + children[0].clone(), + self.timezone.clone(), + ))) + } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.child.hash(&mut s); + self.timezone.hash(&mut s); + self.hash(&mut s); + } +} + +#[derive(Debug, Hash)] +pub struct SecondExec { + /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) + child: Arc, + timezone: String, +} + +impl SecondExec { + pub fn new(child: Arc, timezone: String) -> Self { + SecondExec { child, timezone } + } +} + +impl Display for SecondExec { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Second (timezone:{}, child: {}]", + self.timezone, self.child + ) + } +} + +impl PartialEq for SecondExec { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| self.child.eq(&x.child) && self.timezone.eq(&x.timezone)) + .unwrap_or(false) + } +} + +impl PhysicalExpr for SecondExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + match self.child.data_type(input_schema).unwrap() { + DataType::Dictionary(key_type, _) => { + 
Ok(DataType::Dictionary(key_type, Box::new(DataType::Int32))) + } + _ => Ok(DataType::Int32), + } + } + + fn nullable(&self, _: &Schema) -> datafusion_common::Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + let arg = self.child.evaluate(batch)?; + match arg { + ColumnarValue::Array(array) => { + let array = array_with_timezone( + array, + self.timezone.clone(), + Some(&DataType::Timestamp( + Microsecond, + Some(self.timezone.clone().into()), + )), + )?; + let result = date_part(&array, DatePart::Second)?; + + Ok(ColumnarValue::Array(result)) + } + _ => Err(DataFusionError::Execution( + "Second(scalar) should be fold in Spark JVM side.".to_string(), + )), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result, DataFusionError> { + Ok(Arc::new(SecondExec::new( + children[0].clone(), + self.timezone.clone(), + ))) + } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.child.hash(&mut s); + self.timezone.hash(&mut s); + self.hash(&mut s); + } +} + +#[derive(Debug, Hash)] +pub struct DateTruncExec { + /// An array with DataType::Date32 + child: Arc, + /// Scalar UTF8 string matching the valid values in Spark SQL: https://spark.apache.org/docs/latest/api/sql/index.html#trunc + format: Arc, +} + +impl DateTruncExec { + pub fn new(child: Arc, format: Arc) -> Self { + DateTruncExec { child, format } + } +} + +impl Display for DateTruncExec { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "DateTrunc [child:{}, format: {}]", + self.child, self.format + ) + } +} + +impl PartialEq for DateTruncExec { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| self.child.eq(&x.child) && self.format.eq(&x.format)) + .unwrap_or(false) + } +} + +impl PhysicalExpr for DateTruncExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + self.child.data_type(input_schema) + } + + fn nullable(&self, _: &Schema) -> datafusion_common::Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + let date = self.child.evaluate(batch)?; + let format = self.format.evaluate(batch)?; + match (date, format) { + (ColumnarValue::Array(date), ColumnarValue::Scalar(Utf8(Some(format)))) => { + let result = date_trunc_dyn(&date, format)?; + Ok(ColumnarValue::Array(result)) + } + (ColumnarValue::Array(date), ColumnarValue::Array(formats)) => { + let result = date_trunc_array_fmt_dyn(&date, &formats)?; + Ok(ColumnarValue::Array(result)) + } + _ => Err(DataFusionError::Execution( + "Invalid input to function DateTrunc. 
Expected (PrimitiveArray, Scalar) or \ + (PrimitiveArray, StringArray)".to_string(), + )), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result, DataFusionError> { + Ok(Arc::new(DateTruncExec::new( + children[0].clone(), + self.format.clone(), + ))) + } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.child.hash(&mut s); + self.format.hash(&mut s); + self.hash(&mut s); + } +} + +#[derive(Debug, Hash)] +pub struct TimestampTruncExec { + /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) + child: Arc, + /// Scalar UTF8 string matching the valid values in Spark SQL: https://spark.apache.org/docs/latest/api/sql/index.html#date_trunc + format: Arc, + /// String containing a timezone name. The name must be found in the standard timezone + /// database (https://en.wikipedia.org/wiki/List_of_tz_database_time_zones). The string is + /// later parsed into a chrono::TimeZone. + /// Timestamp arrays in this implementation are kept in arrays of UTC timestamps (in micros) + /// along with a single value for the associated TimeZone. The timezone offset is applied + /// just before any operations on the timestamp + timezone: String, +} + +impl TimestampTruncExec { + pub fn new( + child: Arc, + format: Arc, + timezone: String, + ) -> Self { + TimestampTruncExec { + child, + format, + timezone, + } + } +} + +impl Display for TimestampTruncExec { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "TimestampTrunc [child:{}, format:{}, timezone: {}]", + self.child, self.format, self.timezone + ) + } +} + +impl PartialEq for TimestampTruncExec { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.child.eq(&x.child) + && self.format.eq(&x.format) + && self.timezone.eq(&x.timezone) + }) + .unwrap_or(false) + } +} + +impl PhysicalExpr for TimestampTruncExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + match self.child.data_type(input_schema)? { + DataType::Dictionary(key_type, _) => Ok(DataType::Dictionary( + key_type, + Box::new(DataType::Timestamp(Microsecond, None)), + )), + _ => Ok(DataType::Timestamp(Microsecond, None)), + } + } + + fn nullable(&self, _: &Schema) -> datafusion_common::Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + let timestamp = self.child.evaluate(batch)?; + let format = self.format.evaluate(batch)?; + let tz = self.timezone.clone(); + match (timestamp, format) { + (ColumnarValue::Array(ts), ColumnarValue::Scalar(Utf8(Some(format)))) => { + let ts = array_with_timezone( + ts, + tz.clone(), + Some(&DataType::Timestamp(Microsecond, Some(tz.into()))), + )?; + let result = timestamp_trunc_dyn(&ts, format)?; + Ok(ColumnarValue::Array(result)) + } + (ColumnarValue::Array(ts), ColumnarValue::Array(formats)) => { + let ts = array_with_timezone( + ts, + tz.clone(), + Some(&DataType::Timestamp(Microsecond, Some(tz.into()))), + )?; + let result = timestamp_trunc_array_fmt_dyn(&ts, &formats)?; + Ok(ColumnarValue::Array(result)) + } + _ => Err(DataFusionError::Execution( + "Invalid input to function TimestampTrunc. 
\ + Expected (PrimitiveArray, Scalar, String) or \ + (PrimitiveArray, StringArray, String)" + .to_string(), + )), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result, DataFusionError> { + Ok(Arc::new(TimestampTruncExec::new( + children[0].clone(), + self.format.clone(), + self.timezone.clone(), + ))) + } + + fn dyn_hash(&self, state: &mut dyn Hasher) { + let mut s = state; + self.child.hash(&mut s); + self.format.hash(&mut s); + self.timezone.hash(&mut s); + self.hash(&mut s); + } +} From 46e8bf287a93a977181da0c3499a582231afffa3 Mon Sep 17 00:00:00 2001 From: Vipul Vaibhaw Date: Tue, 16 Jul 2024 01:03:10 +0530 Subject: [PATCH 07/68] fix: Optimize some functions to rewrite dictionary-encoded strings (#627) * dedup code * transforming the dict directly * code optimization for cast string to timestamp * minor optimizations * fmt fixes and casting to dict array without unpacking to array first * bug fixes * revert unrelated change * Added test case and code refactor * minor optimization * minor optimization again * convert the cast to array * Revert "convert the cast to array" This reverts commit 9270aedeafa12dacabc664ca9df7c85236e05d85. * bug fixes * rename the test to cast_dict_to_timestamp arr --- src/cast.rs | 98 ++++++++++++++++++++++++++++++++++++++--------------- 1 file changed, 71 insertions(+), 27 deletions(-) diff --git a/src/cast.rs b/src/cast.rs index 7f53583e8d76..8702ce7070a8 100644 --- a/src/cast.rs +++ b/src/cast.rs @@ -31,7 +31,7 @@ use arrow::{ GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array, OffsetSizeTrait, PrimitiveArray, }, - compute::{cast_with_options, unary, CastOptions}, + compute::{cast_with_options, take, unary, CastOptions}, datatypes::{ ArrowPrimitiveType, Decimal128Type, DecimalType, Float32Type, Float64Type, Int64Type, TimestampMicrosecondType, @@ -40,6 +40,7 @@ use arrow::{ record_batch::RecordBatch, util::display::FormatOptions, }; +use arrow_array::DictionaryArray; use arrow_schema::{DataType, Schema}; use datafusion_common::{ @@ -98,7 +99,6 @@ macro_rules! cast_utf8_to_int { result }}; } - macro_rules! cast_utf8_to_timestamp { ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{ let len = $array.len(); @@ -507,19 +507,27 @@ impl Cast { let to_type = &self.data_type; let array = array_with_timezone(array, self.timezone.clone(), Some(to_type))?; let from_type = array.data_type().clone(); - - // unpack dictionary string arrays first - // TODO: we are unpacking a dictionary-encoded array and then performing - // the cast. We could potentially improve performance here by casting the - // dictionary values directly without unpacking the array first, although this - // would add more complexity to the code let array = match &from_type { DataType::Dictionary(key_type, value_type) if key_type.as_ref() == &DataType::Int32 && (value_type.as_ref() == &DataType::Utf8 || value_type.as_ref() == &DataType::LargeUtf8) => { - cast_with_options(&array, value_type.as_ref(), &CAST_OPTIONS)? 
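// The replacement below avoids unpacking the dictionary before casting: it casts the
// dictionary's values array once, rebuilds a DictionaryArray with the original keys, and
// only flattens the result with `take` when the target type is not itself a dictionary.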
+ let dict_array = array + .as_any() + .downcast_ref::>() + .expect("Expected a dictionary array"); + + let casted_dictionary = DictionaryArray::::new( + dict_array.keys().clone(), + self.cast_array(dict_array.values().clone())?, + ); + + let casted_result = match to_type { + DataType::Dictionary(_, _) => Arc::new(casted_dictionary.clone()), + _ => take(casted_dictionary.values().as_ref(), dict_array.keys(), None)?, + }; + return Ok(spark_cast(casted_result, &from_type, to_type)); } _ => array, }; @@ -724,26 +732,31 @@ impl Cast { .downcast_ref::>() .expect("Expected a string array"); - let cast_array: ArrayRef = match to_type { - DataType::Date32 => { - let len = string_array.len(); - let mut cast_array = PrimitiveArray::::builder(len); - for i in 0..len { - if !string_array.is_null(i) { - match date_parser(string_array.value(i), eval_mode) { - Ok(Some(cast_value)) => cast_array.append_value(cast_value), - Ok(None) => cast_array.append_null(), - Err(e) => return Err(e), - } - } else { - cast_array.append_null() - } + if to_type != &DataType::Date32 { + unreachable!("Invalid data type {:?} in cast from string", to_type); + } + + let len = string_array.len(); + let mut cast_array = PrimitiveArray::::builder(len); + + for i in 0..len { + let value = if string_array.is_null(i) { + None + } else { + match date_parser(string_array.value(i), eval_mode) { + Ok(Some(cast_value)) => Some(cast_value), + Ok(None) => None, + Err(e) => return Err(e), } - Arc::new(cast_array.finish()) as ArrayRef + }; + + match value { + Some(cast_value) => cast_array.append_value(cast_value), + None => cast_array.append_null(), } - _ => unreachable!("Invalid data type {:?} in cast from string", to_type), - }; - Ok(cast_array) + } + + Ok(Arc::new(cast_array.finish()) as ArrayRef) } fn cast_string_to_timestamp( @@ -1796,6 +1809,37 @@ mod tests { assert_eq!(result.len(), 2); } + #[test] + fn test_cast_dict_string_to_timestamp() -> DataFusionResult<()> { + // prepare input data + let keys = Int32Array::from(vec![0, 1]); + let values: ArrayRef = Arc::new(StringArray::from(vec![ + Some("2020-01-01T12:34:56.123456"), + Some("T2"), + ])); + let dict_array = Arc::new(DictionaryArray::new(keys, values)); + + // prepare cast expression + let timezone = "UTC".to_string(); + let expr = Arc::new(Column::new("a", 0)); // this is not used by the test + let cast = Cast::new( + expr, + DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.clone().into())), + EvalMode::Legacy, + timezone.clone(), + ); + + // test casting string dictionary array to timestamp array + let result = cast.cast_array(dict_array)?; + assert_eq!( + *result.data_type(), + DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.into())) + ); + assert_eq!(result.len(), 2); + + Ok(()) + } + #[test] fn date_parser_test() { for date in &[ From 21793315c46434e9c60967de0a4ea7f9a29c30be Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Tue, 16 Jul 2024 13:17:46 -0600 Subject: [PATCH 08/68] Change suffix on some expressions from Exec to Expr (#673) --- src/lib.rs | 2 +- src/temporal.rs | 70 ++++++++++++++++++++++++------------------------- 2 files changed, 36 insertions(+), 36 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 5168e0e80747..91d61f70a14d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,7 +29,7 @@ pub use abs::Abs; pub use cast::Cast; pub use error::{SparkError, SparkResult}; pub use if_expr::IfExpr; -pub use temporal::{DateTruncExec, HourExec, MinuteExec, SecondExec, TimestampTruncExec}; +pub use temporal::{DateTruncExpr, HourExpr, MinuteExpr, 
SecondExpr, TimestampTruncExpr}; /// Spark supports three evaluation modes when evaluating expressions, which affect /// the behavior when processing input values that are invalid or would result in an diff --git a/src/temporal.rs b/src/temporal.rs index ea30d3383dd5..34b71a284a4e 100644 --- a/src/temporal.rs +++ b/src/temporal.rs @@ -38,19 +38,19 @@ use crate::kernels::temporal::{ }; #[derive(Debug, Hash)] -pub struct HourExec { +pub struct HourExpr { /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) child: Arc, timezone: String, } -impl HourExec { +impl HourExpr { pub fn new(child: Arc, timezone: String) -> Self { - HourExec { child, timezone } + HourExpr { child, timezone } } } -impl Display for HourExec { +impl Display for HourExpr { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!( f, @@ -60,7 +60,7 @@ impl Display for HourExec { } } -impl PartialEq for HourExec { +impl PartialEq for HourExpr { fn eq(&self, other: &dyn Any) -> bool { down_cast_any_ref(other) .downcast_ref::() @@ -69,7 +69,7 @@ impl PartialEq for HourExec { } } -impl PhysicalExpr for HourExec { +impl PhysicalExpr for HourExpr { fn as_any(&self) -> &dyn Any { self } @@ -117,7 +117,7 @@ impl PhysicalExpr for HourExec { self: Arc, children: Vec>, ) -> Result, DataFusionError> { - Ok(Arc::new(HourExec::new( + Ok(Arc::new(HourExpr::new( children[0].clone(), self.timezone.clone(), ))) @@ -132,19 +132,19 @@ impl PhysicalExpr for HourExec { } #[derive(Debug, Hash)] -pub struct MinuteExec { +pub struct MinuteExpr { /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) child: Arc, timezone: String, } -impl MinuteExec { +impl MinuteExpr { pub fn new(child: Arc, timezone: String) -> Self { - MinuteExec { child, timezone } + MinuteExpr { child, timezone } } } -impl Display for MinuteExec { +impl Display for MinuteExpr { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!( f, @@ -154,7 +154,7 @@ impl Display for MinuteExec { } } -impl PartialEq for MinuteExec { +impl PartialEq for MinuteExpr { fn eq(&self, other: &dyn Any) -> bool { down_cast_any_ref(other) .downcast_ref::() @@ -163,7 +163,7 @@ impl PartialEq for MinuteExec { } } -impl PhysicalExpr for MinuteExec { +impl PhysicalExpr for MinuteExpr { fn as_any(&self) -> &dyn Any { self } @@ -211,7 +211,7 @@ impl PhysicalExpr for MinuteExec { self: Arc, children: Vec>, ) -> Result, DataFusionError> { - Ok(Arc::new(MinuteExec::new( + Ok(Arc::new(MinuteExpr::new( children[0].clone(), self.timezone.clone(), ))) @@ -226,19 +226,19 @@ impl PhysicalExpr for MinuteExec { } #[derive(Debug, Hash)] -pub struct SecondExec { +pub struct SecondExpr { /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) child: Arc, timezone: String, } -impl SecondExec { +impl SecondExpr { pub fn new(child: Arc, timezone: String) -> Self { - SecondExec { child, timezone } + SecondExpr { child, timezone } } } -impl Display for SecondExec { +impl Display for SecondExpr { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!( f, @@ -248,7 +248,7 @@ impl Display for SecondExec { } } -impl PartialEq for SecondExec { +impl PartialEq for SecondExpr { fn eq(&self, other: &dyn Any) -> bool { down_cast_any_ref(other) .downcast_ref::() @@ -257,7 +257,7 @@ impl PartialEq for SecondExec { } } -impl PhysicalExpr for SecondExec { +impl PhysicalExpr for SecondExpr { fn as_any(&self) -> &dyn Any { self } @@ -305,7 +305,7 @@ impl PhysicalExpr for SecondExec { self: Arc, children: Vec>, ) -> Result, DataFusionError> { - Ok(Arc::new(SecondExec::new( 
+ Ok(Arc::new(SecondExpr::new( children[0].clone(), self.timezone.clone(), ))) @@ -320,20 +320,20 @@ impl PhysicalExpr for SecondExec { } #[derive(Debug, Hash)] -pub struct DateTruncExec { +pub struct DateTruncExpr { /// An array with DataType::Date32 child: Arc, /// Scalar UTF8 string matching the valid values in Spark SQL: https://spark.apache.org/docs/latest/api/sql/index.html#trunc format: Arc, } -impl DateTruncExec { +impl DateTruncExpr { pub fn new(child: Arc, format: Arc) -> Self { - DateTruncExec { child, format } + DateTruncExpr { child, format } } } -impl Display for DateTruncExec { +impl Display for DateTruncExpr { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!( f, @@ -343,7 +343,7 @@ impl Display for DateTruncExec { } } -impl PartialEq for DateTruncExec { +impl PartialEq for DateTruncExpr { fn eq(&self, other: &dyn Any) -> bool { down_cast_any_ref(other) .downcast_ref::() @@ -352,7 +352,7 @@ impl PartialEq for DateTruncExec { } } -impl PhysicalExpr for DateTruncExec { +impl PhysicalExpr for DateTruncExpr { fn as_any(&self) -> &dyn Any { self } @@ -392,7 +392,7 @@ impl PhysicalExpr for DateTruncExec { self: Arc, children: Vec>, ) -> Result, DataFusionError> { - Ok(Arc::new(DateTruncExec::new( + Ok(Arc::new(DateTruncExpr::new( children[0].clone(), self.format.clone(), ))) @@ -407,7 +407,7 @@ impl PhysicalExpr for DateTruncExec { } #[derive(Debug, Hash)] -pub struct TimestampTruncExec { +pub struct TimestampTruncExpr { /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) child: Arc, /// Scalar UTF8 string matching the valid values in Spark SQL: https://spark.apache.org/docs/latest/api/sql/index.html#date_trunc @@ -421,13 +421,13 @@ pub struct TimestampTruncExec { timezone: String, } -impl TimestampTruncExec { +impl TimestampTruncExpr { pub fn new( child: Arc, format: Arc, timezone: String, ) -> Self { - TimestampTruncExec { + TimestampTruncExpr { child, format, timezone, @@ -435,7 +435,7 @@ impl TimestampTruncExec { } } -impl Display for TimestampTruncExec { +impl Display for TimestampTruncExpr { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!( f, @@ -445,7 +445,7 @@ impl Display for TimestampTruncExec { } } -impl PartialEq for TimestampTruncExec { +impl PartialEq for TimestampTruncExpr { fn eq(&self, other: &dyn Any) -> bool { down_cast_any_ref(other) .downcast_ref::() @@ -458,7 +458,7 @@ impl PartialEq for TimestampTruncExec { } } -impl PhysicalExpr for TimestampTruncExec { +impl PhysicalExpr for TimestampTruncExpr { fn as_any(&self) -> &dyn Any { self } @@ -517,7 +517,7 @@ impl PhysicalExpr for TimestampTruncExec { self: Arc, children: Vec>, ) -> Result, DataFusionError> { - Ok(Arc::new(TimestampTruncExec::new( + Ok(Arc::new(TimestampTruncExpr::new( children[0].clone(), self.format.clone(), self.timezone.clone(), From 01e21a931947326d9eb620e76c91068b4fdf1495 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Sat, 20 Jul 2024 14:08:00 -0600 Subject: [PATCH 09/68] chore: Disable abs and signum because they return incorrect results (#695) --- Cargo.toml | 1 - src/abs.rs | 89 ------------------------------------------------------ src/lib.rs | 2 -- 3 files changed, 92 deletions(-) delete mode 100644 src/abs.rs diff --git a/Cargo.toml b/Cargo.toml index 976a1f36f354..192ed102b7f6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,7 +33,6 @@ arrow-schema = { workspace = true } chrono = { workspace = true } datafusion = { workspace = true } datafusion-common = { workspace = true } -datafusion-functions = { workspace = true } datafusion-expr 
= { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-plan = { workspace = true } diff --git a/src/abs.rs b/src/abs.rs deleted file mode 100644 index fa25a7775ae7..000000000000 --- a/src/abs.rs +++ /dev/null @@ -1,89 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -//! Spark-compatible implementation of abs function - -use std::{any::Any, sync::Arc}; - -use arrow::datatypes::DataType; -use arrow_schema::ArrowError; - -use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature}; -use datafusion_common::DataFusionError; -use datafusion_functions::math; - -use super::{EvalMode, SparkError}; - -/// Spark-compatible ABS expression -#[derive(Debug)] -pub struct Abs { - inner_abs_func: Arc, - eval_mode: EvalMode, - data_type_name: String, -} - -impl Abs { - pub fn new(eval_mode: EvalMode, data_type_name: String) -> Result { - if let EvalMode::Legacy | EvalMode::Ansi = eval_mode { - Ok(Self { - inner_abs_func: math::abs().inner().clone(), - eval_mode, - data_type_name, - }) - } else { - Err(DataFusionError::Execution(format!( - "Invalid EvalMode: \"{:?}\"", - eval_mode - ))) - } - } -} - -impl ScalarUDFImpl for Abs { - fn as_any(&self) -> &dyn Any { - self - } - fn name(&self) -> &str { - "abs" - } - - fn signature(&self) -> &Signature { - self.inner_abs_func.signature() - } - - fn return_type(&self, arg_types: &[DataType]) -> Result { - self.inner_abs_func.return_type(arg_types) - } - - fn invoke(&self, args: &[ColumnarValue]) -> Result { - match self.inner_abs_func.invoke(args) { - Err(DataFusionError::ArrowError(ArrowError::ComputeError(msg), _)) - if msg.contains("overflow") => - { - if self.eval_mode == EvalMode::Legacy { - Ok(args[0].clone()) - } else { - Err(SparkError::ArithmeticOverflow { - from_type: self.data_type_name.clone(), - } - .into()) - } - } - other => other, - } - } -} diff --git a/src/lib.rs b/src/lib.rs index 91d61f70a14d..336201f4846e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,7 +15,6 @@ // specific language governing permissions and limitations // under the License. 
-mod abs; mod cast; mod error; mod if_expr; @@ -25,7 +24,6 @@ mod temporal; pub mod timezone; pub mod utils; -pub use abs::Abs; pub use cast::Cast; pub use error::{SparkError, SparkResult}; pub use if_expr::IfExpr; From 01362b50b22e2629851e6ad363ecaef4e32b33ab Mon Sep 17 00:00:00 2001 From: Xuanwo Date: Mon, 22 Jul 2024 02:51:08 +0800 Subject: [PATCH 10/68] chore: Make rust clippy happy (#701) * chore: Make rust clippy happy Signed-off-by: Xuanwo * Format code Signed-off-by: Xuanwo --------- Signed-off-by: Xuanwo --- src/cast.rs | 10 +++++----- src/kernels/temporal.rs | 28 ++++++++++++++-------------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/src/cast.rs b/src/cast.rs index 8702ce7070a8..9a47cc87334e 100644 --- a/src/cast.rs +++ b/src/cast.rs @@ -1854,7 +1854,7 @@ mod tests { "2020-01-01T", ] { for eval_mode in &[EvalMode::Legacy, EvalMode::Ansi, EvalMode::Try] { - assert_eq!(date_parser(*date, *eval_mode).unwrap(), Some(18262)); + assert_eq!(date_parser(date, *eval_mode).unwrap(), Some(18262)); } } @@ -1875,14 +1875,14 @@ mod tests { "--262143-12-31 ", ] { for eval_mode in &[EvalMode::Legacy, EvalMode::Try] { - assert_eq!(date_parser(*date, *eval_mode).unwrap(), None); + assert_eq!(date_parser(date, *eval_mode).unwrap(), None); } - assert!(date_parser(*date, EvalMode::Ansi).is_err()); + assert!(date_parser(date, EvalMode::Ansi).is_err()); } for date in &["-3638-5"] { for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] { - assert_eq!(date_parser(*date, *eval_mode).unwrap(), Some(-2048160)); + assert_eq!(date_parser(date, *eval_mode).unwrap(), Some(-2048160)); } } @@ -1898,7 +1898,7 @@ mod tests { "-0973250", ] { for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] { - assert_eq!(date_parser(*date, *eval_mode).unwrap(), None); + assert_eq!(date_parser(date, *eval_mode).unwrap(), None); } } } diff --git a/src/kernels/temporal.rs b/src/kernels/temporal.rs index 6f2474e8d7a8..cda4bef5d184 100644 --- a/src/kernels/temporal.rs +++ b/src/kernels/temporal.rs @@ -838,7 +838,7 @@ mod tests { assert!(array.values().get(i) >= a.values().get(i)) } } - _ => assert!(false), + _ => unreachable!(), } } } @@ -854,9 +854,9 @@ mod tests { let mut vec: Vec = Vec::with_capacity(size * formats.len()); let mut fmt_vec: Vec<&str> = Vec::with_capacity(size * formats.len()); for i in 0..size { - for j in 0..formats.len() { + for fmt_value in &formats { vec.push(i as i32 * 1_000_001); - fmt_vec.push(formats[j]); + fmt_vec.push(fmt_value); } } @@ -928,7 +928,7 @@ mod tests { ) } } else { - assert!(false) + unreachable!() } if let Ok(a) = date_trunc_array_fmt_dyn(&array_dict, &fmt_array) { for i in 0..array.len() { @@ -937,7 +937,7 @@ mod tests { ) } } else { - assert!(false) + unreachable!() } if let Ok(a) = date_trunc_array_fmt_dyn(&array, &fmt_dict) { for i in 0..array.len() { @@ -946,7 +946,7 @@ mod tests { ) } } else { - assert!(false) + unreachable!() } if let Ok(a) = date_trunc_array_fmt_dyn(&array_dict, &fmt_dict) { for i in 0..array.len() { @@ -955,7 +955,7 @@ mod tests { ) } } else { - assert!(false) + unreachable!() } } @@ -991,7 +991,7 @@ mod tests { assert!(array.values().get(i) >= a.values().get(i)) } } - _ => assert!(false), + _ => unreachable!(), } } } @@ -1023,9 +1023,9 @@ mod tests { let mut vec: Vec = Vec::with_capacity(size * formats.len()); let mut fmt_vec: Vec<&str> = Vec::with_capacity(size * formats.len()); for i in 0..size { - for j in 0..formats.len() { + for fmt_value in &formats { vec.push(i as i64 * 1_000_000_001); - 
fmt_vec.push(formats[j]); + fmt_vec.push(fmt_value); } } @@ -1103,7 +1103,7 @@ mod tests { ) } } else { - assert!(false) + unreachable!() } if let Ok(a) = timestamp_trunc_array_fmt_dyn(&array_dict, &fmt_array) { for i in 0..array.len() { @@ -1116,7 +1116,7 @@ mod tests { ) } } else { - assert!(false) + unreachable!() } if let Ok(a) = timestamp_trunc_array_fmt_dyn(&array, &fmt_dict) { for i in 0..array.len() { @@ -1129,7 +1129,7 @@ mod tests { ) } } else { - assert!(false) + unreachable!() } if let Ok(a) = timestamp_trunc_array_fmt_dyn(&array_dict, &fmt_dict) { for i in 0..array.len() { @@ -1142,7 +1142,7 @@ mod tests { ) } } else { - assert!(false) + unreachable!() } } } From e2d838ec3abdd7701198738799e40bb54abe5814 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 24 Jul 2024 15:11:21 -0600 Subject: [PATCH 11/68] perf: Optimize IfExpr by delegating to CaseExpr (#681) * Unify IF and CASE expressions * revert test changes * fix --- Cargo.toml | 17 +++++ benches/cast_from_string.rs | 91 +++++++++++++++++++++++ benches/cast_numeric.rs | 79 ++++++++++++++++++++ benches/conditional.rs | 139 ++++++++++++++++++++++++++++++++++++ src/if_expr.rs | 44 ++++-------- 5 files changed, 340 insertions(+), 30 deletions(-) create mode 100644 benches/cast_from_string.rs create mode 100644 benches/cast_numeric.rs create mode 100644 benches/conditional.rs diff --git a/Cargo.toml b/Cargo.toml index 192ed102b7f6..aa4fcfc5f022 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ chrono = { workspace = true } datafusion = { workspace = true } datafusion-common = { workspace = true } datafusion-expr = { workspace = true } +datafusion-physical-expr-common = { workspace = true } datafusion-physical-expr = { workspace = true } datafusion-physical-plan = { workspace = true } chrono-tz = { workspace = true } @@ -41,6 +42,22 @@ num = { workspace = true } regex = { workspace = true } thiserror = { workspace = true } +[dev-dependencies] +criterion = "0.5.1" +rand = "0.8.5" + [lib] name = "datafusion_comet_spark_expr" path = "src/lib.rs" + +[[bench]] +name = "cast_from_string" +harness = false + +[[bench]] +name = "cast_numeric" +harness = false + +[[bench]] +name = "conditional" +harness = false \ No newline at end of file diff --git a/benches/cast_from_string.rs b/benches/cast_from_string.rs new file mode 100644 index 000000000000..51410a68ad90 --- /dev/null +++ b/benches/cast_from_string.rs @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow_array::{builder::StringBuilder, RecordBatch}; +use arrow_schema::{DataType, Field, Schema}; +use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion_comet_spark_expr::{Cast, EvalMode}; +use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let batch = create_utf8_batch(); + let expr = Arc::new(Column::new("a", 0)); + let timezone = "".to_string(); + let cast_string_to_i8 = Cast::new( + expr.clone(), + DataType::Int8, + EvalMode::Legacy, + timezone.clone(), + ); + let cast_string_to_i16 = Cast::new( + expr.clone(), + DataType::Int16, + EvalMode::Legacy, + timezone.clone(), + ); + let cast_string_to_i32 = Cast::new( + expr.clone(), + DataType::Int32, + EvalMode::Legacy, + timezone.clone(), + ); + let cast_string_to_i64 = Cast::new(expr, DataType::Int64, EvalMode::Legacy, timezone); + + let mut group = c.benchmark_group("cast_string_to_int"); + group.bench_function("cast_string_to_i8", |b| { + b.iter(|| cast_string_to_i8.evaluate(&batch).unwrap()); + }); + group.bench_function("cast_string_to_i16", |b| { + b.iter(|| cast_string_to_i16.evaluate(&batch).unwrap()); + }); + group.bench_function("cast_string_to_i32", |b| { + b.iter(|| cast_string_to_i32.evaluate(&batch).unwrap()); + }); + group.bench_function("cast_string_to_i64", |b| { + b.iter(|| cast_string_to_i64.evaluate(&batch).unwrap()); + }); +} + +// Create UTF8 batch with strings representing ints, floats, nulls +fn create_utf8_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])); + let mut b = StringBuilder::new(); + for i in 0..1000 { + if i % 10 == 0 { + b.append_null(); + } else if i % 2 == 0 { + b.append_value(format!("{}", rand::random::())); + } else { + b.append_value(format!("{}", rand::random::())); + } + } + let array = b.finish(); + + RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap() +} + +fn config() -> Criterion { + Criterion::default() +} + +criterion_group! { + name = benches; + config = config(); + targets = criterion_benchmark +} +criterion_main!(benches); diff --git a/benches/cast_numeric.rs b/benches/cast_numeric.rs new file mode 100644 index 000000000000..dc0ceea79ad1 --- /dev/null +++ b/benches/cast_numeric.rs @@ -0,0 +1,79 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow_array::{builder::Int32Builder, RecordBatch}; +use arrow_schema::{DataType, Field, Schema}; +use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion_comet_spark_expr::{Cast, EvalMode}; +use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let batch = create_int32_batch(); + let expr = Arc::new(Column::new("a", 0)); + let timezone = "".to_string(); + let cast_i32_to_i8 = Cast::new( + expr.clone(), + DataType::Int8, + EvalMode::Legacy, + timezone.clone(), + ); + let cast_i32_to_i16 = Cast::new( + expr.clone(), + DataType::Int16, + EvalMode::Legacy, + timezone.clone(), + ); + let cast_i32_to_i64 = Cast::new(expr, DataType::Int64, EvalMode::Legacy, timezone); + + let mut group = c.benchmark_group("cast_int_to_int"); + group.bench_function("cast_i32_to_i8", |b| { + b.iter(|| cast_i32_to_i8.evaluate(&batch).unwrap()); + }); + group.bench_function("cast_i32_to_i16", |b| { + b.iter(|| cast_i32_to_i16.evaluate(&batch).unwrap()); + }); + group.bench_function("cast_i32_to_i64", |b| { + b.iter(|| cast_i32_to_i64.evaluate(&batch).unwrap()); + }); +} + +fn create_int32_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let mut b = Int32Builder::new(); + for i in 0..1000 { + if i % 10 == 0 { + b.append_null(); + } else { + b.append_value(rand::random::()); + } + } + let array = b.finish(); + + RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap() +} + +fn config() -> Criterion { + Criterion::default() +} + +criterion_group! { + name = benches; + config = config(); + targets = criterion_benchmark +} +criterion_main!(benches); diff --git a/benches/conditional.rs b/benches/conditional.rs new file mode 100644 index 000000000000..d86ef76f82ee --- /dev/null +++ b/benches/conditional.rs @@ -0,0 +1,139 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::datatypes::{Field, Schema}; +use arrow::record_batch::RecordBatch; +use arrow_array::builder::{Int32Builder, StringBuilder}; +use arrow_schema::DataType; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_comet_spark_expr::IfExpr; +use datafusion_common::ScalarValue; +use datafusion_expr::Operator; +use datafusion_physical_expr::expressions::{BinaryExpr, CaseExpr}; +use datafusion_physical_expr_common::expressions::column::Column; +use datafusion_physical_expr_common::expressions::Literal; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use std::sync::Arc; + +fn make_col(name: &str, index: usize) -> Arc { + Arc::new(Column::new(name, index)) +} + +fn make_lit_i32(n: i32) -> Arc { + Arc::new(Literal::new(ScalarValue::Int32(Some(n)))) +} + +fn make_null_lit() -> Arc { + Arc::new(Literal::new(ScalarValue::Utf8(None))) +} + +fn criterion_benchmark(c: &mut Criterion) { + // create input data + let mut c1 = Int32Builder::new(); + let mut c2 = StringBuilder::new(); + let mut c3 = StringBuilder::new(); + for i in 0..1000 { + c1.append_value(i); + if i % 7 == 0 { + c2.append_null(); + } else { + c2.append_value(&format!("string {i}")); + } + if i % 9 == 0 { + c3.append_null(); + } else { + c3.append_value(&format!("other string {i}")); + } + } + let c1 = Arc::new(c1.finish()); + let c2 = Arc::new(c2.finish()); + let c3 = Arc::new(c3.finish()); + let schema = Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Utf8, true), + Field::new("c3", DataType::Utf8, true), + ]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2, c3]).unwrap(); + + // use same predicate for all benchmarks + let predicate = Arc::new(BinaryExpr::new( + make_col("c1", 0), + Operator::LtEq, + make_lit_i32(500), + )); + + // CASE WHEN c1 <= 500 THEN 1 ELSE 0 END + c.bench_function("case_when: scalar or scalar", |b| { + let expr = Arc::new( + CaseExpr::try_new( + None, + vec![(predicate.clone(), make_lit_i32(1))], + Some(make_lit_i32(0)), + ) + .unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + c.bench_function("if: scalar or scalar", |b| { + let expr = Arc::new(IfExpr::new( + predicate.clone(), + make_lit_i32(1), + make_lit_i32(0), + )); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + + // CASE WHEN c1 <= 500 THEN c2 [ELSE NULL] END + c.bench_function("case_when: column or null", |b| { + let expr = Arc::new( + CaseExpr::try_new(None, vec![(predicate.clone(), make_col("c2", 1))], None).unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + c.bench_function("if: column or null", |b| { + let expr = Arc::new(IfExpr::new( + predicate.clone(), + make_col("c2", 1), + make_null_lit(), + )); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + + // CASE WHEN c1 <= 500 THEN c2 ELSE c3 END + c.bench_function("case_when: expr or expr", |b| { + let expr = Arc::new( + CaseExpr::try_new( + None, + vec![(predicate.clone(), make_col("c2", 1))], + Some(make_col("c3", 2)), + ) + .unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + c.bench_function("if: expr or expr", |b| { + let expr = Arc::new(IfExpr::new( + predicate.clone(), + make_col("c2", 1), + make_col("c3", 2), + )); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/src/if_expr.rs b/src/if_expr.rs 
index fa52c5d5b9b9..a5344140bb8a 100644 --- a/src/if_expr.rs +++ b/src/if_expr.rs @@ -22,22 +22,24 @@ use std::{ }; use arrow::{ - array::*, - compute::{and, is_null, kernels::zip::zip, not, or_kleene}, datatypes::{DataType, Schema}, record_batch::RecordBatch, }; use datafusion::logical_expr::ColumnarValue; -use datafusion_common::{cast::as_boolean_array, Result}; -use datafusion_physical_expr::PhysicalExpr; +use datafusion_common::Result; +use datafusion_physical_expr::{expressions::CaseExpr, PhysicalExpr}; use crate::utils::down_cast_any_ref; +/// IfExpr is a wrapper around CaseExpr, because `IF(a, b, c)` is semantically equivalent to +/// `CASE WHEN a THEN b ELSE c END`. #[derive(Debug, Hash)] pub struct IfExpr { if_expr: Arc, true_expr: Arc, false_expr: Arc, + // we delegate to case_expr for evaluation + case_expr: Arc, } impl std::fmt::Display for IfExpr { @@ -58,9 +60,12 @@ impl IfExpr { false_expr: Arc, ) -> Self { Self { - if_expr, - true_expr, - false_expr, + if_expr: if_expr.clone(), + true_expr: true_expr.clone(), + false_expr: false_expr.clone(), + case_expr: Arc::new( + CaseExpr::try_new(None, vec![(if_expr, true_expr)], Some(false_expr)).unwrap(), + ), } } } @@ -85,29 +90,7 @@ impl PhysicalExpr for IfExpr { } fn evaluate(&self, batch: &RecordBatch) -> Result { - let mut remainder = BooleanArray::from(vec![true; batch.num_rows()]); - - // evaluate if condition on batch - let if_value = self.if_expr.evaluate_selection(batch, &remainder)?; - let if_value = if_value.into_array(batch.num_rows())?; - let if_value = - as_boolean_array(&if_value).expect("if expression did not return a BooleanArray"); - - let true_value = self.true_expr.evaluate_selection(batch, if_value)?; - let true_value = true_value.into_array(batch.num_rows())?; - - remainder = and( - &remainder, - &or_kleene(¬(if_value)?, &is_null(if_value)?)?, - )?; - - let false_value = self - .false_expr - .evaluate_selection(batch, &remainder)? 
- .into_array(batch.num_rows())?; - let current_value = zip(&remainder, &false_value, &true_value)?; - - Ok(ColumnarValue::Array(current_value)) + self.case_expr.evaluate(batch) } fn children(&self) -> Vec<&Arc> { @@ -150,6 +133,7 @@ impl PartialEq for IfExpr { #[cfg(test)] mod tests { use arrow::{array::StringArray, datatypes::*}; + use arrow_array::Int32Array; use datafusion::logical_expr::Operator; use datafusion_common::cast::as_int32_array; use datafusion_physical_expr::expressions::{binary, col, lit}; From 5dcf7138be3f3e5e86568866566882cc49c6f811 Mon Sep 17 00:00:00 2001 From: Arttu Date: Sat, 27 Jul 2024 14:38:05 +0200 Subject: [PATCH 12/68] chore: make Cast's logic reusable for other projects (#716) --- src/cast.rs | 1099 +++++++++++++++++++++++++-------------------------- src/lib.rs | 2 +- 2 files changed, 550 insertions(+), 551 deletions(-) diff --git a/src/cast.rs b/src/cast.rs index 9a47cc87334e..ae0818970f03 100644 --- a/src/cast.rs +++ b/src/cast.rs @@ -502,158 +502,166 @@ impl Cast { eval_mode, } } +} - fn cast_array(&self, array: ArrayRef) -> DataFusionResult { - let to_type = &self.data_type; - let array = array_with_timezone(array, self.timezone.clone(), Some(to_type))?; - let from_type = array.data_type().clone(); - let array = match &from_type { - DataType::Dictionary(key_type, value_type) - if key_type.as_ref() == &DataType::Int32 - && (value_type.as_ref() == &DataType::Utf8 - || value_type.as_ref() == &DataType::LargeUtf8) => - { - let dict_array = array - .as_any() - .downcast_ref::>() - .expect("Expected a dictionary array"); - - let casted_dictionary = DictionaryArray::::new( - dict_array.keys().clone(), - self.cast_array(dict_array.values().clone())?, - ); - - let casted_result = match to_type { - DataType::Dictionary(_, _) => Arc::new(casted_dictionary.clone()), - _ => take(casted_dictionary.values().as_ref(), dict_array.keys(), None)?, - }; - return Ok(spark_cast(casted_result, &from_type, to_type)); - } - _ => array, - }; - let from_type = array.data_type(); - - let cast_result = match (from_type, to_type) { - (DataType::Utf8, DataType::Boolean) => { - Self::spark_cast_utf8_to_boolean::(&array, self.eval_mode) - } - (DataType::LargeUtf8, DataType::Boolean) => { - Self::spark_cast_utf8_to_boolean::(&array, self.eval_mode) - } - (DataType::Utf8, DataType::Timestamp(_, _)) => { - Self::cast_string_to_timestamp(&array, to_type, self.eval_mode) - } - (DataType::Utf8, DataType::Date32) => { - Self::cast_string_to_date(&array, to_type, self.eval_mode) - } - (DataType::Int64, DataType::Int32) - | (DataType::Int64, DataType::Int16) - | (DataType::Int64, DataType::Int8) - | (DataType::Int32, DataType::Int16) - | (DataType::Int32, DataType::Int8) - | (DataType::Int16, DataType::Int8) - if self.eval_mode != EvalMode::Try => - { - Self::spark_cast_int_to_int(&array, self.eval_mode, from_type, to_type) - } - ( - DataType::Utf8, - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, - ) => Self::cast_string_to_int::(to_type, &array, self.eval_mode), - ( - DataType::LargeUtf8, - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, - ) => Self::cast_string_to_int::(to_type, &array, self.eval_mode), - (DataType::Float64, DataType::Utf8) => { - Self::spark_cast_float64_to_utf8::(&array, self.eval_mode) - } - (DataType::Float64, DataType::LargeUtf8) => { - Self::spark_cast_float64_to_utf8::(&array, self.eval_mode) - } - (DataType::Float32, DataType::Utf8) => { - Self::spark_cast_float32_to_utf8::(&array, self.eval_mode) - } - (DataType::Float32, 
DataType::LargeUtf8) => { - Self::spark_cast_float32_to_utf8::(&array, self.eval_mode) - } - (DataType::Float32, DataType::Decimal128(precision, scale)) => { - Self::cast_float32_to_decimal128(&array, *precision, *scale, self.eval_mode) - } - (DataType::Float64, DataType::Decimal128(precision, scale)) => { - Self::cast_float64_to_decimal128(&array, *precision, *scale, self.eval_mode) - } - (DataType::Float32, DataType::Int8) - | (DataType::Float32, DataType::Int16) - | (DataType::Float32, DataType::Int32) - | (DataType::Float32, DataType::Int64) - | (DataType::Float64, DataType::Int8) - | (DataType::Float64, DataType::Int16) - | (DataType::Float64, DataType::Int32) - | (DataType::Float64, DataType::Int64) - | (DataType::Decimal128(_, _), DataType::Int8) - | (DataType::Decimal128(_, _), DataType::Int16) - | (DataType::Decimal128(_, _), DataType::Int32) - | (DataType::Decimal128(_, _), DataType::Int64) - if self.eval_mode != EvalMode::Try => - { - Self::spark_cast_nonintegral_numeric_to_integral( - &array, - self.eval_mode, - from_type, - to_type, - ) - } - _ if Self::is_datafusion_spark_compatible(from_type, to_type) => { - // use DataFusion cast only when we know that it is compatible with Spark - Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?) - } - _ => { - // we should never reach this code because the Scala code should be checking - // for supported cast operations and falling back to Spark for anything that - // is not yet supported - Err(SparkError::Internal(format!( - "Native cast invoked for unsupported cast from {from_type:?} to {to_type:?}" - ))) - } - }; - Ok(spark_cast(cast_result?, from_type, to_type)) +/// Spark-compatible cast implementation. Defers to DataFusion's cast where that is known +/// to be compatible, and returns an error when a not supported and not DF-compatible cast +/// is requested. +pub fn spark_cast( + arg: ColumnarValue, + data_type: &DataType, + eval_mode: EvalMode, + timezone: String, +) -> DataFusionResult { + match arg { + ColumnarValue::Array(array) => Ok(ColumnarValue::Array(cast_array( + array, + data_type, + eval_mode, + timezone.to_owned(), + )?)), + ColumnarValue::Scalar(scalar) => { + // Note that normally CAST(scalar) should be fold in Spark JVM side. However, for + // some cases e.g., scalar subquery, Spark will not fold it, so we need to handle it + // here. 
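// In other words, a scalar is routed through the same code path as arrays: it is
// materialized as a single-element array, cast via cast_array, and the lone result row is
// converted back into a ScalarValue.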
+ let array = scalar.to_array()?; + let scalar = ScalarValue::try_from_array( + &cast_array(array, data_type, eval_mode, timezone.to_owned())?, + 0, + )?; + Ok(ColumnarValue::Scalar(scalar)) + } } +} - /// Determines if DataFusion supports the given cast in a way that is - /// compatible with Spark - fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> bool { - if from_type == to_type { - return true; +fn cast_array( + array: ArrayRef, + to_type: &DataType, + eval_mode: EvalMode, + timezone: String, +) -> DataFusionResult { + let array = array_with_timezone(array, timezone.clone(), Some(to_type))?; + let from_type = array.data_type().clone(); + let array = match &from_type { + DataType::Dictionary(key_type, value_type) + if key_type.as_ref() == &DataType::Int32 + && (value_type.as_ref() == &DataType::Utf8 + || value_type.as_ref() == &DataType::LargeUtf8) => + { + let dict_array = array + .as_any() + .downcast_ref::>() + .expect("Expected a dictionary array"); + + let casted_dictionary = DictionaryArray::::new( + dict_array.keys().clone(), + cast_array(dict_array.values().clone(), to_type, eval_mode, timezone)?, + ); + + let casted_result = match to_type { + DataType::Dictionary(_, _) => Arc::new(casted_dictionary.clone()), + _ => take(casted_dictionary.values().as_ref(), dict_array.keys(), None)?, + }; + return Ok(spark_cast_postprocess(casted_result, &from_type, to_type)); } - match from_type { - DataType::Boolean => matches!( - to_type, - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 - | DataType::Utf8 - ), - DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { - // note that the cast from Int32/Int64 -> Decimal128 here is actually - // not compatible with Spark (no overflow checks) but we have tests that - // rely on this cast working so we have to leave it here for now - matches!( - to_type, - DataType::Boolean - | DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 - | DataType::Decimal128(_, _) - | DataType::Utf8 - ) - } - DataType::Float32 | DataType::Float64 => matches!( + _ => array, + }; + let from_type = array.data_type(); + + let cast_result = match (from_type, to_type) { + (DataType::Utf8, DataType::Boolean) => spark_cast_utf8_to_boolean::(&array, eval_mode), + (DataType::LargeUtf8, DataType::Boolean) => { + spark_cast_utf8_to_boolean::(&array, eval_mode) + } + (DataType::Utf8, DataType::Timestamp(_, _)) => { + cast_string_to_timestamp(&array, to_type, eval_mode) + } + (DataType::Utf8, DataType::Date32) => cast_string_to_date(&array, to_type, eval_mode), + (DataType::Int64, DataType::Int32) + | (DataType::Int64, DataType::Int16) + | (DataType::Int64, DataType::Int8) + | (DataType::Int32, DataType::Int16) + | (DataType::Int32, DataType::Int8) + | (DataType::Int16, DataType::Int8) + if eval_mode != EvalMode::Try => + { + spark_cast_int_to_int(&array, eval_mode, from_type, to_type) + } + (DataType::Utf8, DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64) => { + cast_string_to_int::(to_type, &array, eval_mode) + } + ( + DataType::LargeUtf8, + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64, + ) => cast_string_to_int::(to_type, &array, eval_mode), + (DataType::Float64, DataType::Utf8) => spark_cast_float64_to_utf8::(&array, eval_mode), + (DataType::Float64, DataType::LargeUtf8) => { + spark_cast_float64_to_utf8::(&array, eval_mode) + } + (DataType::Float32, 
DataType::Utf8) => spark_cast_float32_to_utf8::(&array, eval_mode), + (DataType::Float32, DataType::LargeUtf8) => { + spark_cast_float32_to_utf8::(&array, eval_mode) + } + (DataType::Float32, DataType::Decimal128(precision, scale)) => { + cast_float32_to_decimal128(&array, *precision, *scale, eval_mode) + } + (DataType::Float64, DataType::Decimal128(precision, scale)) => { + cast_float64_to_decimal128(&array, *precision, *scale, eval_mode) + } + (DataType::Float32, DataType::Int8) + | (DataType::Float32, DataType::Int16) + | (DataType::Float32, DataType::Int32) + | (DataType::Float32, DataType::Int64) + | (DataType::Float64, DataType::Int8) + | (DataType::Float64, DataType::Int16) + | (DataType::Float64, DataType::Int32) + | (DataType::Float64, DataType::Int64) + | (DataType::Decimal128(_, _), DataType::Int8) + | (DataType::Decimal128(_, _), DataType::Int16) + | (DataType::Decimal128(_, _), DataType::Int32) + | (DataType::Decimal128(_, _), DataType::Int64) + if eval_mode != EvalMode::Try => + { + spark_cast_nonintegral_numeric_to_integral(&array, eval_mode, from_type, to_type) + } + _ if is_datafusion_spark_compatible(from_type, to_type) => { + // use DataFusion cast only when we know that it is compatible with Spark + Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?) + } + _ => { + // we should never reach this code because the Scala code should be checking + // for supported cast operations and falling back to Spark for anything that + // is not yet supported + Err(SparkError::Internal(format!( + "Native cast invoked for unsupported cast from {from_type:?} to {to_type:?}" + ))) + } + }; + Ok(spark_cast_postprocess(cast_result?, from_type, to_type)) +} + +/// Determines if DataFusion supports the given cast in a way that is +/// compatible with Spark +fn is_datafusion_spark_compatible(from_type: &DataType, to_type: &DataType) -> bool { + if from_type == to_type { + return true; + } + match from_type { + DataType::Boolean => matches!( + to_type, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Utf8 + ), + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + // note that the cast from Int32/Int64 -> Decimal128 here is actually + // not compatible with Spark (no overflow checks) but we have tests that + // rely on this cast working so we have to leave it here for now + matches!( to_type, DataType::Boolean | DataType::Int8 @@ -662,182 +670,180 @@ impl Cast { | DataType::Int64 | DataType::Float32 | DataType::Float64 - ), - DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => matches!( - to_type, - DataType::Int8 - | DataType::Int16 - | DataType::Int32 - | DataType::Int64 - | DataType::Float32 - | DataType::Float64 | DataType::Decimal128(_, _) - | DataType::Decimal256(_, _) - ), - DataType::Utf8 => matches!(to_type, DataType::Binary), - DataType::Date32 => matches!(to_type, DataType::Utf8), - DataType::Timestamp(_, _) => { - matches!( - to_type, - DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _) - ) - } - DataType::Binary => { - // note that this is not completely Spark compatible because - // DataFusion only supports binary data containing valid UTF-8 strings - matches!(to_type, DataType::Utf8) - } - _ => false, + | DataType::Utf8 + ) + } + DataType::Float32 | DataType::Float64 => matches!( + to_type, + DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + 
), + DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => matches!( + to_type, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + ), + DataType::Utf8 => matches!(to_type, DataType::Binary), + DataType::Date32 => matches!(to_type, DataType::Utf8), + DataType::Timestamp(_, _) => { + matches!( + to_type, + DataType::Int64 | DataType::Date32 | DataType::Utf8 | DataType::Timestamp(_, _) + ) } + DataType::Binary => { + // note that this is not completely Spark compatible because + // DataFusion only supports binary data containing valid UTF-8 strings + matches!(to_type, DataType::Utf8) + } + _ => false, } +} - fn cast_string_to_int( - to_type: &DataType, - array: &ArrayRef, - eval_mode: EvalMode, - ) -> SparkResult { - let string_array = array - .as_any() - .downcast_ref::>() - .expect("cast_string_to_int expected a string array"); +fn cast_string_to_int( + to_type: &DataType, + array: &ArrayRef, + eval_mode: EvalMode, +) -> SparkResult { + let string_array = array + .as_any() + .downcast_ref::>() + .expect("cast_string_to_int expected a string array"); + + let cast_array: ArrayRef = match to_type { + DataType::Int8 => cast_utf8_to_int!(string_array, eval_mode, Int8Type, cast_string_to_i8)?, + DataType::Int16 => { + cast_utf8_to_int!(string_array, eval_mode, Int16Type, cast_string_to_i16)? + } + DataType::Int32 => { + cast_utf8_to_int!(string_array, eval_mode, Int32Type, cast_string_to_i32)? + } + DataType::Int64 => { + cast_utf8_to_int!(string_array, eval_mode, Int64Type, cast_string_to_i64)? + } + dt => unreachable!( + "{}", + format!("invalid integer type {dt} in cast from string") + ), + }; + Ok(cast_array) +} - let cast_array: ArrayRef = match to_type { - DataType::Int8 => { - cast_utf8_to_int!(string_array, eval_mode, Int8Type, cast_string_to_i8)? - } - DataType::Int16 => { - cast_utf8_to_int!(string_array, eval_mode, Int16Type, cast_string_to_i16)? - } - DataType::Int32 => { - cast_utf8_to_int!(string_array, eval_mode, Int32Type, cast_string_to_i32)? - } - DataType::Int64 => { - cast_utf8_to_int!(string_array, eval_mode, Int64Type, cast_string_to_i64)? 
- } - dt => unreachable!( - "{}", - format!("invalid integer type {dt} in cast from string") - ), - }; - Ok(cast_array) +fn cast_string_to_date( + array: &ArrayRef, + to_type: &DataType, + eval_mode: EvalMode, +) -> SparkResult { + let string_array = array + .as_any() + .downcast_ref::>() + .expect("Expected a string array"); + + if to_type != &DataType::Date32 { + unreachable!("Invalid data type {:?} in cast from string", to_type); } - fn cast_string_to_date( - array: &ArrayRef, - to_type: &DataType, - eval_mode: EvalMode, - ) -> SparkResult { - let string_array = array - .as_any() - .downcast_ref::>() - .expect("Expected a string array"); + let len = string_array.len(); + let mut cast_array = PrimitiveArray::::builder(len); - if to_type != &DataType::Date32 { - unreachable!("Invalid data type {:?} in cast from string", to_type); - } + for i in 0..len { + let value = if string_array.is_null(i) { + None + } else { + match date_parser(string_array.value(i), eval_mode) { + Ok(Some(cast_value)) => Some(cast_value), + Ok(None) => None, + Err(e) => return Err(e), + } + }; - let len = string_array.len(); - let mut cast_array = PrimitiveArray::::builder(len); + match value { + Some(cast_value) => cast_array.append_value(cast_value), + None => cast_array.append_null(), + } + } - for i in 0..len { - let value = if string_array.is_null(i) { - None - } else { - match date_parser(string_array.value(i), eval_mode) { - Ok(Some(cast_value)) => Some(cast_value), - Ok(None) => None, - Err(e) => return Err(e), - } - }; + Ok(Arc::new(cast_array.finish()) as ArrayRef) +} - match value { - Some(cast_value) => cast_array.append_value(cast_value), - None => cast_array.append_null(), - } +fn cast_string_to_timestamp( + array: &ArrayRef, + to_type: &DataType, + eval_mode: EvalMode, +) -> SparkResult { + let string_array = array + .as_any() + .downcast_ref::>() + .expect("Expected a string array"); + + let cast_array: ArrayRef = match to_type { + DataType::Timestamp(_, _) => { + cast_utf8_to_timestamp!( + string_array, + eval_mode, + TimestampMicrosecondType, + timestamp_parser + ) } + _ => unreachable!("Invalid data type {:?} in cast from string", to_type), + }; + Ok(cast_array) +} - Ok(Arc::new(cast_array.finish()) as ArrayRef) - } +fn cast_float64_to_decimal128( + array: &dyn Array, + precision: u8, + scale: i8, + eval_mode: EvalMode, +) -> SparkResult { + cast_floating_point_to_decimal128::(array, precision, scale, eval_mode) +} - fn cast_string_to_timestamp( - array: &ArrayRef, - to_type: &DataType, - eval_mode: EvalMode, - ) -> SparkResult { - let string_array = array - .as_any() - .downcast_ref::>() - .expect("Expected a string array"); +fn cast_float32_to_decimal128( + array: &dyn Array, + precision: u8, + scale: i8, + eval_mode: EvalMode, +) -> SparkResult { + cast_floating_point_to_decimal128::(array, precision, scale, eval_mode) +} - let cast_array: ArrayRef = match to_type { - DataType::Timestamp(_, _) => { - cast_utf8_to_timestamp!( - string_array, - eval_mode, - TimestampMicrosecondType, - timestamp_parser - ) - } - _ => unreachable!("Invalid data type {:?} in cast from string", to_type), - }; - Ok(cast_array) - } +fn cast_floating_point_to_decimal128( + array: &dyn Array, + precision: u8, + scale: i8, + eval_mode: EvalMode, +) -> SparkResult +where + ::Native: AsPrimitive, +{ + let input = array.as_any().downcast_ref::>().unwrap(); + let mut cast_array = PrimitiveArray::::builder(input.len()); - fn cast_float64_to_decimal128( - array: &dyn Array, - precision: u8, - scale: i8, - eval_mode: EvalMode, - 
) -> SparkResult { - Self::cast_floating_point_to_decimal128::(array, precision, scale, eval_mode) - } + let mul = 10_f64.powi(scale as i32); - fn cast_float32_to_decimal128( - array: &dyn Array, - precision: u8, - scale: i8, - eval_mode: EvalMode, - ) -> SparkResult { - Self::cast_floating_point_to_decimal128::(array, precision, scale, eval_mode) - } + for i in 0..input.len() { + if input.is_null(i) { + cast_array.append_null(); + } else { + let input_value = input.value(i).as_(); + let value = (input_value * mul).round().to_i128(); - fn cast_floating_point_to_decimal128( - array: &dyn Array, - precision: u8, - scale: i8, - eval_mode: EvalMode, - ) -> SparkResult - where - ::Native: AsPrimitive, - { - let input = array.as_any().downcast_ref::>().unwrap(); - let mut cast_array = PrimitiveArray::::builder(input.len()); - - let mul = 10_f64.powi(scale as i32); - - for i in 0..input.len() { - if input.is_null(i) { - cast_array.append_null(); - } else { - let input_value = input.value(i).as_(); - let value = (input_value * mul).round().to_i128(); - - match value { - Some(v) => { - if Decimal128Type::validate_decimal_precision(v, precision).is_err() { - if eval_mode == EvalMode::Ansi { - return Err(SparkError::NumericValueOutOfRange { - value: input_value.to_string(), - precision, - scale, - }); - } else { - cast_array.append_null(); - } - } - cast_array.append_value(v); - } - None => { + match value { + Some(v) => { + if Decimal128Type::validate_decimal_precision(v, precision).is_err() { if eval_mode == EvalMode::Ansi { return Err(SparkError::NumericValueOutOfRange { value: input_value.to_string(), @@ -848,240 +854,252 @@ impl Cast { cast_array.append_null(); } } + cast_array.append_value(v); + } + None => { + if eval_mode == EvalMode::Ansi { + return Err(SparkError::NumericValueOutOfRange { + value: input_value.to_string(), + precision, + scale, + }); + } else { + cast_array.append_null(); + } } } } - - let res = Arc::new( - cast_array - .with_precision_and_scale(precision, scale)? - .finish(), - ) as ArrayRef; - Ok(res) } - fn spark_cast_float64_to_utf8( - from: &dyn Array, - _eval_mode: EvalMode, - ) -> SparkResult - where - OffsetSize: OffsetSizeTrait, - { - cast_float_to_string!(from, _eval_mode, f64, Float64Array, OffsetSize) - } + let res = Arc::new( + cast_array + .with_precision_and_scale(precision, scale)? 
+ .finish(), + ) as ArrayRef; + Ok(res) +} - fn spark_cast_float32_to_utf8( - from: &dyn Array, - _eval_mode: EvalMode, - ) -> SparkResult - where - OffsetSize: OffsetSizeTrait, - { - cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize) - } +fn spark_cast_float64_to_utf8( + from: &dyn Array, + _eval_mode: EvalMode, +) -> SparkResult +where + OffsetSize: OffsetSizeTrait, +{ + cast_float_to_string!(from, _eval_mode, f64, Float64Array, OffsetSize) +} - fn spark_cast_int_to_int( - array: &dyn Array, - eval_mode: EvalMode, - from_type: &DataType, - to_type: &DataType, - ) -> SparkResult { - match (from_type, to_type) { - (DataType::Int64, DataType::Int32) => cast_int_to_int_macro!( - array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT" - ), - (DataType::Int64, DataType::Int16) => cast_int_to_int_macro!( - array, eval_mode, Int64Type, Int16Type, from_type, i16, "BIGINT", "SMALLINT" - ), - (DataType::Int64, DataType::Int8) => cast_int_to_int_macro!( - array, eval_mode, Int64Type, Int8Type, from_type, i8, "BIGINT", "TINYINT" - ), - (DataType::Int32, DataType::Int16) => cast_int_to_int_macro!( - array, eval_mode, Int32Type, Int16Type, from_type, i16, "INT", "SMALLINT" - ), - (DataType::Int32, DataType::Int8) => cast_int_to_int_macro!( - array, eval_mode, Int32Type, Int8Type, from_type, i8, "INT", "TINYINT" - ), - (DataType::Int16, DataType::Int8) => cast_int_to_int_macro!( - array, eval_mode, Int16Type, Int8Type, from_type, i8, "SMALLINT", "TINYINT" - ), - _ => unreachable!( - "{}", - format!("invalid integer type {to_type} in cast from {from_type}") - ), - } +fn spark_cast_float32_to_utf8( + from: &dyn Array, + _eval_mode: EvalMode, +) -> SparkResult +where + OffsetSize: OffsetSizeTrait, +{ + cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize) +} + +fn spark_cast_int_to_int( + array: &dyn Array, + eval_mode: EvalMode, + from_type: &DataType, + to_type: &DataType, +) -> SparkResult { + match (from_type, to_type) { + (DataType::Int64, DataType::Int32) => cast_int_to_int_macro!( + array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT" + ), + (DataType::Int64, DataType::Int16) => cast_int_to_int_macro!( + array, eval_mode, Int64Type, Int16Type, from_type, i16, "BIGINT", "SMALLINT" + ), + (DataType::Int64, DataType::Int8) => cast_int_to_int_macro!( + array, eval_mode, Int64Type, Int8Type, from_type, i8, "BIGINT", "TINYINT" + ), + (DataType::Int32, DataType::Int16) => cast_int_to_int_macro!( + array, eval_mode, Int32Type, Int16Type, from_type, i16, "INT", "SMALLINT" + ), + (DataType::Int32, DataType::Int8) => cast_int_to_int_macro!( + array, eval_mode, Int32Type, Int8Type, from_type, i8, "INT", "TINYINT" + ), + (DataType::Int16, DataType::Int8) => cast_int_to_int_macro!( + array, eval_mode, Int16Type, Int8Type, from_type, i8, "SMALLINT", "TINYINT" + ), + _ => unreachable!( + "{}", + format!("invalid integer type {to_type} in cast from {from_type}") + ), } +} - fn spark_cast_utf8_to_boolean( - from: &dyn Array, - eval_mode: EvalMode, - ) -> SparkResult - where - OffsetSize: OffsetSizeTrait, - { - let array = from - .as_any() - .downcast_ref::>() - .unwrap(); +fn spark_cast_utf8_to_boolean( + from: &dyn Array, + eval_mode: EvalMode, +) -> SparkResult +where + OffsetSize: OffsetSizeTrait, +{ + let array = from + .as_any() + .downcast_ref::>() + .unwrap(); - let output_array = array - .iter() - .map(|value| match value { - Some(value) => match value.to_ascii_lowercase().trim() { - "t" | "true" | "y" | "yes" | "1" => Ok(Some(true)), - 
"f" | "false" | "n" | "no" | "0" => Ok(Some(false)), - _ if eval_mode == EvalMode::Ansi => Err(SparkError::CastInvalidValue { - value: value.to_string(), - from_type: "STRING".to_string(), - to_type: "BOOLEAN".to_string(), - }), - _ => Ok(None), - }, + let output_array = array + .iter() + .map(|value| match value { + Some(value) => match value.to_ascii_lowercase().trim() { + "t" | "true" | "y" | "yes" | "1" => Ok(Some(true)), + "f" | "false" | "n" | "no" | "0" => Ok(Some(false)), + _ if eval_mode == EvalMode::Ansi => Err(SparkError::CastInvalidValue { + value: value.to_string(), + from_type: "STRING".to_string(), + to_type: "BOOLEAN".to_string(), + }), _ => Ok(None), - }) - .collect::>()?; + }, + _ => Ok(None), + }) + .collect::>()?; - Ok(Arc::new(output_array)) - } + Ok(Arc::new(output_array)) +} - fn spark_cast_nonintegral_numeric_to_integral( - array: &dyn Array, - eval_mode: EvalMode, - from_type: &DataType, - to_type: &DataType, - ) -> SparkResult { - match (from_type, to_type) { - (DataType::Float32, DataType::Int8) => cast_float_to_int16_down!( - array, - eval_mode, - Float32Array, - Int8Array, - f32, - i8, - "FLOAT", - "TINYINT", - "{:e}" - ), - (DataType::Float32, DataType::Int16) => cast_float_to_int16_down!( - array, - eval_mode, - Float32Array, - Int16Array, - f32, - i16, - "FLOAT", - "SMALLINT", - "{:e}" - ), - (DataType::Float32, DataType::Int32) => cast_float_to_int32_up!( - array, - eval_mode, - Float32Array, - Int32Array, - f32, - i32, - "FLOAT", - "INT", - i32::MAX, - "{:e}" - ), - (DataType::Float32, DataType::Int64) => cast_float_to_int32_up!( - array, - eval_mode, - Float32Array, - Int64Array, - f32, - i64, - "FLOAT", - "BIGINT", - i64::MAX, - "{:e}" - ), - (DataType::Float64, DataType::Int8) => cast_float_to_int16_down!( - array, - eval_mode, - Float64Array, - Int8Array, - f64, - i8, - "DOUBLE", - "TINYINT", - "{:e}D" - ), - (DataType::Float64, DataType::Int16) => cast_float_to_int16_down!( - array, - eval_mode, - Float64Array, - Int16Array, - f64, - i16, - "DOUBLE", - "SMALLINT", - "{:e}D" - ), - (DataType::Float64, DataType::Int32) => cast_float_to_int32_up!( +fn spark_cast_nonintegral_numeric_to_integral( + array: &dyn Array, + eval_mode: EvalMode, + from_type: &DataType, + to_type: &DataType, +) -> SparkResult { + match (from_type, to_type) { + (DataType::Float32, DataType::Int8) => cast_float_to_int16_down!( + array, + eval_mode, + Float32Array, + Int8Array, + f32, + i8, + "FLOAT", + "TINYINT", + "{:e}" + ), + (DataType::Float32, DataType::Int16) => cast_float_to_int16_down!( + array, + eval_mode, + Float32Array, + Int16Array, + f32, + i16, + "FLOAT", + "SMALLINT", + "{:e}" + ), + (DataType::Float32, DataType::Int32) => cast_float_to_int32_up!( + array, + eval_mode, + Float32Array, + Int32Array, + f32, + i32, + "FLOAT", + "INT", + i32::MAX, + "{:e}" + ), + (DataType::Float32, DataType::Int64) => cast_float_to_int32_up!( + array, + eval_mode, + Float32Array, + Int64Array, + f32, + i64, + "FLOAT", + "BIGINT", + i64::MAX, + "{:e}" + ), + (DataType::Float64, DataType::Int8) => cast_float_to_int16_down!( + array, + eval_mode, + Float64Array, + Int8Array, + f64, + i8, + "DOUBLE", + "TINYINT", + "{:e}D" + ), + (DataType::Float64, DataType::Int16) => cast_float_to_int16_down!( + array, + eval_mode, + Float64Array, + Int16Array, + f64, + i16, + "DOUBLE", + "SMALLINT", + "{:e}D" + ), + (DataType::Float64, DataType::Int32) => cast_float_to_int32_up!( + array, + eval_mode, + Float64Array, + Int32Array, + f64, + i32, + "DOUBLE", + "INT", + i32::MAX, + "{:e}D" + ), + 
(DataType::Float64, DataType::Int64) => cast_float_to_int32_up!( + array, + eval_mode, + Float64Array, + Int64Array, + f64, + i64, + "DOUBLE", + "BIGINT", + i64::MAX, + "{:e}D" + ), + (DataType::Decimal128(precision, scale), DataType::Int8) => { + cast_decimal_to_int16_down!( + array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale + ) + } + (DataType::Decimal128(precision, scale), DataType::Int16) => { + cast_decimal_to_int16_down!( + array, eval_mode, Int16Array, i16, "SMALLINT", precision, *scale + ) + } + (DataType::Decimal128(precision, scale), DataType::Int32) => { + cast_decimal_to_int32_up!( array, eval_mode, - Float64Array, Int32Array, - f64, i32, - "DOUBLE", "INT", i32::MAX, - "{:e}D" - ), - (DataType::Float64, DataType::Int64) => cast_float_to_int32_up!( + *precision, + *scale + ) + } + (DataType::Decimal128(precision, scale), DataType::Int64) => { + cast_decimal_to_int32_up!( array, eval_mode, - Float64Array, Int64Array, - f64, i64, - "DOUBLE", "BIGINT", i64::MAX, - "{:e}D" - ), - (DataType::Decimal128(precision, scale), DataType::Int8) => { - cast_decimal_to_int16_down!( - array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale - ) - } - (DataType::Decimal128(precision, scale), DataType::Int16) => { - cast_decimal_to_int16_down!( - array, eval_mode, Int16Array, i16, "SMALLINT", precision, *scale - ) - } - (DataType::Decimal128(precision, scale), DataType::Int32) => { - cast_decimal_to_int32_up!( - array, - eval_mode, - Int32Array, - i32, - "INT", - i32::MAX, - *precision, - *scale - ) - } - (DataType::Decimal128(precision, scale), DataType::Int64) => { - cast_decimal_to_int32_up!( - array, - eval_mode, - Int64Array, - i64, - "BIGINT", - i64::MAX, - *precision, - *scale - ) - } - _ => unreachable!( - "{}", - format!("invalid cast from non-integral numeric type: {from_type} to integral numeric type: {to_type}") - ), + *precision, + *scale + ) } + _ => unreachable!( + "{}", + format!("invalid cast from non-integral numeric type: {from_type} to integral numeric type: {to_type}") + ), } } @@ -1294,17 +1312,7 @@ impl PhysicalExpr for Cast { fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { let arg = self.child.evaluate(batch)?; - match arg { - ColumnarValue::Array(array) => Ok(ColumnarValue::Array(self.cast_array(array)?)), - ColumnarValue::Scalar(scalar) => { - // Note that normally CAST(scalar) should be fold in Spark JVM side. However, for - // some cases e.g., scalar subquery, Spark will not fold it, so we need to handle it - // here. - let array = scalar.to_array()?; - let scalar = ScalarValue::try_from_array(&self.cast_array(array)?, 0)?; - Ok(ColumnarValue::Scalar(scalar)) - } - } + spark_cast(arg, &self.data_type, self.eval_mode, self.timezone.clone()) } fn children(&self) -> Vec<&Arc> { @@ -1660,7 +1668,7 @@ fn date_parser(date_str: &str, eval_mode: EvalMode) -> SparkResult> /// Dictionary arrays are already unpacked by the DataFusion cast() since Spark cannot specify /// Dictionary as to_type. The from_type is taken before the DataFusion cast() runs in /// expressions/cast.rs, so it can be still Dictionary. 
-fn spark_cast(array: ArrayRef, from_type: &DataType, to_type: &DataType) -> ArrayRef { +fn spark_cast_postprocess(array: ArrayRef, from_type: &DataType, to_type: &DataType) -> ArrayRef { match (from_type, to_type) { (DataType::Timestamp(_, _), DataType::Int64) => { // See Spark's `Cast` expression @@ -1739,8 +1747,6 @@ mod tests { use arrow_array::StringArray; use arrow_schema::TimeUnit; - use datafusion_physical_expr::expressions::Column; - use super::*; #[test] @@ -1819,18 +1825,14 @@ mod tests { ])); let dict_array = Arc::new(DictionaryArray::new(keys, values)); - // prepare cast expression let timezone = "UTC".to_string(); - let expr = Arc::new(Column::new("a", 0)); // this is not used by the test - let cast = Cast::new( - expr, - DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.clone().into())), + // test casting string dictionary array to timestamp array + let result = cast_array( + dict_array, + &DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.clone().into())), EvalMode::Legacy, timezone.clone(), - ); - - // test casting string dictionary array to timestamp array - let result = cast.cast_array(dict_array)?; + )?; assert_eq!( *result.data_type(), DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.into())) @@ -1912,8 +1914,7 @@ mod tests { Some("2020-01-01T"), ])); - let result = - Cast::cast_string_to_date(&array, &DataType::Date32, EvalMode::Legacy).unwrap(); + let result = cast_string_to_date(&array, &DataType::Date32, EvalMode::Legacy).unwrap(); let date32_array = result .as_any() @@ -1939,7 +1940,7 @@ mod tests { for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] { let result = - Cast::cast_string_to_date(&array_with_invalid_date, &DataType::Date32, *eval_mode) + cast_string_to_date(&array_with_invalid_date, &DataType::Date32, *eval_mode) .unwrap(); let date32_array = result @@ -1971,7 +1972,7 @@ mod tests { for eval_mode in &[EvalMode::Legacy, EvalMode::Try] { let result = - Cast::cast_string_to_date(&array_with_invalid_date, &DataType::Date32, *eval_mode) + cast_string_to_date(&array_with_invalid_date, &DataType::Date32, *eval_mode) .unwrap(); let date32_array = result @@ -1995,7 +1996,7 @@ mod tests { } let result = - Cast::cast_string_to_date(&array_with_invalid_date, &DataType::Date32, EvalMode::Ansi); + cast_string_to_date(&array_with_invalid_date, &DataType::Date32, EvalMode::Ansi); match result { Err(e) => assert!( e.to_string().contains( @@ -2035,26 +2036,24 @@ mod tests { fn test_cast_unsupported_timestamp_to_date() { // Since datafusion uses chrono::Datetime internally not all dates representable by TimestampMicrosecondType are supported let timestamps: PrimitiveArray = vec![i64::MAX].into(); - let cast = Cast::new( - Arc::new(Column::new("a", 0)), - DataType::Date32, + let result = cast_array( + Arc::new(timestamps.with_timezone("Europe/Copenhagen")), + &DataType::Date32, EvalMode::Legacy, "UTC".to_owned(), ); - let result = cast.cast_array(Arc::new(timestamps.with_timezone("Europe/Copenhagen"))); assert!(result.is_err()) } #[test] fn test_cast_invalid_timezone() { let timestamps: PrimitiveArray = vec![i64::MAX].into(); - let cast = Cast::new( - Arc::new(Column::new("a", 0)), - DataType::Date32, + let result = cast_array( + Arc::new(timestamps.with_timezone("Europe/Copenhagen")), + &DataType::Date32, EvalMode::Legacy, "Not a valid timezone".to_owned(), ); - let result = cast.cast_array(Arc::new(timestamps.with_timezone("Europe/Copenhagen"))); assert!(result.is_err()) } } diff --git a/src/lib.rs b/src/lib.rs index 
336201f4846e..22628978d5b5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,7 +24,7 @@ mod temporal; pub mod timezone; pub mod utils; -pub use cast::Cast; +pub use cast::{spark_cast, Cast}; pub use error::{SparkError, SparkResult}; pub use if_expr::IfExpr; pub use temporal::{DateTruncExpr, HourExpr, MinuteExpr, SecondExpr, TimestampTruncExpr}; From 2a4dc7b1ed6a17f21df9aae2f59de2460921714d Mon Sep 17 00:00:00 2001 From: Arttu Date: Sun, 28 Jul 2024 15:44:32 +0200 Subject: [PATCH 13/68] chore: move scalar_funcs into spark-expr (#712) --- Cargo.toml | 7 +- src/lib.rs | 3 + src/scalar_funcs.rs | 533 ++++++++++++++++++++ src/scalar_funcs/chr.rs | 125 +++++ src/scalar_funcs/hash_expressions.rs | 162 ++++++ src/scalar_funcs/hex.rs | 296 +++++++++++ src/scalar_funcs/unhex.rs | 258 ++++++++++ src/spark_hash.rs | 708 +++++++++++++++++++++++++++ src/xxhash64.rs | 190 +++++++ 9 files changed, 2280 insertions(+), 2 deletions(-) create mode 100644 src/scalar_funcs.rs create mode 100644 src/scalar_funcs/chr.rs create mode 100644 src/scalar_funcs/hash_expressions.rs create mode 100644 src/scalar_funcs/hex.rs create mode 100644 src/scalar_funcs/unhex.rs create mode 100644 src/spark_hash.rs create mode 100644 src/xxhash64.rs diff --git a/Cargo.toml b/Cargo.toml index aa4fcfc5f022..a535a2b817e6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -41,10 +41,13 @@ chrono-tz = { workspace = true } num = { workspace = true } regex = { workspace = true } thiserror = { workspace = true } +unicode-segmentation = "1.11.0" [dev-dependencies] +arrow-data = {workspace = true} criterion = "0.5.1" -rand = "0.8.5" +rand = { workspace = true} +twox-hash = "1.6.3" [lib] name = "datafusion_comet_spark_expr" @@ -60,4 +63,4 @@ harness = false [[bench]] name = "conditional" -harness = false \ No newline at end of file +harness = false diff --git a/src/lib.rs b/src/lib.rs index 22628978d5b5..14ab080b466c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,9 +20,12 @@ mod error; mod if_expr; mod kernels; +pub mod scalar_funcs; +pub mod spark_hash; mod temporal; pub mod timezone; pub mod utils; +mod xxhash64; pub use cast::{spark_cast, Cast}; pub use error::{SparkError, SparkResult}; diff --git a/src/scalar_funcs.rs b/src/scalar_funcs.rs new file mode 100644 index 000000000000..c50b98bafea4 --- /dev/null +++ b/src/scalar_funcs.rs @@ -0,0 +1,533 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
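+
+//! Spark-compatible scalar functions used by Comet's native engine: `ceil`, `floor`, `round`,
+//! `rpad`, `unscaled_value`, `make_decimal`, decimal division, and `isnan`, plus the `hex`,
+//! `unhex`, `chr`, and hash-expression submodules re-exported below.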
+ +use std::{cmp::min, sync::Arc}; + +use arrow::{ + array::{ + ArrayRef, AsArray, Decimal128Builder, Float32Array, Float64Array, GenericStringArray, + Int16Array, Int32Array, Int64Array, Int64Builder, Int8Array, OffsetSizeTrait, + }, + datatypes::{validate_decimal_precision, Decimal128Type, Int64Type}, +}; +use arrow_array::{Array, ArrowNativeTypeOp, BooleanArray, Decimal128Array}; +use arrow_schema::DataType; +use datafusion::{functions::math::round::round, physical_plan::ColumnarValue}; +use datafusion_common::{ + cast::as_generic_string_array, exec_err, internal_err, DataFusionError, + Result as DataFusionResult, ScalarValue, +}; +use num::{ + integer::{div_ceil, div_floor}, + BigInt, Signed, ToPrimitive, +}; +use unicode_segmentation::UnicodeSegmentation; + +mod unhex; +pub use unhex::spark_unhex; + +mod hex; +pub use hex::spark_hex; + +mod chr; +pub use chr::SparkChrFunc; + +pub mod hash_expressions; +// exposed for benchmark only +pub use hash_expressions::{spark_murmur3_hash, spark_xxhash64}; + +#[inline] +fn get_precision_scale(data_type: &DataType) -> (u8, i8) { + let DataType::Decimal128(precision, scale) = data_type else { + unreachable!() + }; + (*precision, *scale) +} + +macro_rules! downcast_compute_op { + ($ARRAY:expr, $NAME:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident) => {{ + let n = $ARRAY.as_any().downcast_ref::<$TYPE>(); + match n { + Some(array) => { + let res: $RESULT = + arrow::compute::kernels::arity::unary(array, |x| x.$FUNC() as i64); + Ok(Arc::new(res)) + } + _ => Err(DataFusionError::Internal(format!( + "Invalid data type for {}", + $NAME + ))), + } + }}; +} + +/// `ceil` function that simulates Spark `ceil` expression +pub fn spark_ceil( + args: &[ColumnarValue], + data_type: &DataType, +) -> Result { + let value = &args[0]; + match value { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Float32 => { + let result = downcast_compute_op!(array, "ceil", ceil, Float32Array, Int64Array); + Ok(ColumnarValue::Array(result?)) + } + DataType::Float64 => { + let result = downcast_compute_op!(array, "ceil", ceil, Float64Array, Int64Array); + Ok(ColumnarValue::Array(result?)) + } + DataType::Int64 => { + let result = array.as_any().downcast_ref::().unwrap(); + Ok(ColumnarValue::Array(Arc::new(result.clone()))) + } + DataType::Decimal128(_, scale) if *scale > 0 => { + let f = decimal_ceil_f(scale); + let (precision, scale) = get_precision_scale(data_type); + make_decimal_array(array, precision, scale, &f) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function ceil", + other, + ))), + }, + ColumnarValue::Scalar(a) => match a { + ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( + a.map(|x| x.ceil() as i64), + ))), + ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( + a.map(|x| x.ceil() as i64), + ))), + ScalarValue::Int64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(a.map(|x| x)))), + ScalarValue::Decimal128(a, _, scale) if *scale > 0 => { + let f = decimal_ceil_f(scale); + let (precision, scale) = get_precision_scale(data_type); + make_decimal_scalar(a, precision, scale, &f) + } + _ => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function ceil", + value.data_type(), + ))), + }, + } +} + +/// `floor` function that simulates Spark `floor` expression +pub fn spark_floor( + args: &[ColumnarValue], + data_type: &DataType, +) -> Result { + let value = &args[0]; + match value { + ColumnarValue::Array(array) => match array.data_type() { + 
DataType::Float32 => { + let result = downcast_compute_op!(array, "floor", floor, Float32Array, Int64Array); + Ok(ColumnarValue::Array(result?)) + } + DataType::Float64 => { + let result = downcast_compute_op!(array, "floor", floor, Float64Array, Int64Array); + Ok(ColumnarValue::Array(result?)) + } + DataType::Int64 => { + let result = array.as_any().downcast_ref::().unwrap(); + Ok(ColumnarValue::Array(Arc::new(result.clone()))) + } + DataType::Decimal128(_, scale) if *scale > 0 => { + let f = decimal_floor_f(scale); + let (precision, scale) = get_precision_scale(data_type); + make_decimal_array(array, precision, scale, &f) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function floor", + other, + ))), + }, + ColumnarValue::Scalar(a) => match a { + ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( + a.map(|x| x.floor() as i64), + ))), + ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( + a.map(|x| x.floor() as i64), + ))), + ScalarValue::Int64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64(a.map(|x| x)))), + ScalarValue::Decimal128(a, _, scale) if *scale > 0 => { + let f = decimal_floor_f(scale); + let (precision, scale) = get_precision_scale(data_type); + make_decimal_scalar(a, precision, scale, &f) + } + _ => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function floor", + value.data_type(), + ))), + }, + } +} + +/// Spark-compatible `UnscaledValue` expression (internal to Spark optimizer) +pub fn spark_unscaled_value(args: &[ColumnarValue]) -> DataFusionResult { + match &args[0] { + ColumnarValue::Scalar(v) => match v { + ScalarValue::Decimal128(d, _, _) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( + d.map(|n| n as i64), + ))), + dt => internal_err!("Expected Decimal128 but found {dt:}"), + }, + ColumnarValue::Array(a) => { + let arr = a.as_primitive::(); + let mut result = Int64Builder::new(); + for v in arr.into_iter() { + result.append_option(v.map(|v| v as i64)); + } + Ok(ColumnarValue::Array(Arc::new(result.finish()))) + } + } +} + +/// Spark-compatible `MakeDecimal` expression (internal to Spark optimizer) +pub fn spark_make_decimal( + args: &[ColumnarValue], + data_type: &DataType, +) -> DataFusionResult { + let (precision, scale) = get_precision_scale(data_type); + match &args[0] { + ColumnarValue::Scalar(v) => match v { + ScalarValue::Int64(n) => Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( + long_to_decimal(n, precision), + precision, + scale, + ))), + sv => internal_err!("Expected Int64 but found {sv:?}"), + }, + ColumnarValue::Array(a) => { + let arr = a.as_primitive::(); + let mut result = Decimal128Builder::new(); + for v in arr.into_iter() { + result.append_option(long_to_decimal(&v, precision)) + } + let result_type = DataType::Decimal128(precision, scale); + + Ok(ColumnarValue::Array(Arc::new( + result.finish().with_data_type(result_type), + ))) + } + } +} + +/// Convert the input long to decimal with the given maximum precision. If overflows, returns null +/// instead. 
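+///
+/// For example, `long_to_decimal(&Some(123), 3)` yields `Some(123)`, while
+/// `long_to_decimal(&Some(1234), 3)` does not fit in 3 digits of precision and yields `None`.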
+#[inline] +fn long_to_decimal(v: &Option, precision: u8) -> Option { + match v { + Some(v) if validate_decimal_precision(*v as i128, precision).is_ok() => Some(*v as i128), + _ => None, + } +} + +#[inline] +fn decimal_ceil_f(scale: &i8) -> impl Fn(i128) -> i128 { + let div = 10_i128.pow_wrapping(*scale as u32); + move |x: i128| div_ceil(x, div) +} + +#[inline] +fn decimal_floor_f(scale: &i8) -> impl Fn(i128) -> i128 { + let div = 10_i128.pow_wrapping(*scale as u32); + move |x: i128| div_floor(x, div) +} + +// Spark uses BigDecimal. See RoundBase implementation in Spark. Instead, we do the same by +// 1) add the half of divisor, 2) round down by division, 3) adjust precision by multiplication +#[inline] +fn decimal_round_f(scale: &i8, point: &i64) -> Box i128> { + if *point < 0 { + if let Some(div) = 10_i128.checked_pow((-(*point) as u32) + (*scale as u32)) { + let half = div / 2; + let mul = 10_i128.pow_wrapping((-(*point)) as u32); + // i128 can hold 39 digits of a base 10 number, adding half will not cause overflow + Box::new(move |x: i128| (x + x.signum() * half) / div * mul) + } else { + Box::new(move |_: i128| 0) + } + } else { + let div = 10_i128.pow_wrapping((*scale as u32) - min(*scale as u32, *point as u32)); + let half = div / 2; + Box::new(move |x: i128| (x + x.signum() * half) / div) + } +} + +#[inline] +fn make_decimal_array( + array: &ArrayRef, + precision: u8, + scale: i8, + f: &dyn Fn(i128) -> i128, +) -> Result { + let array = array.as_primitive::(); + let result: Decimal128Array = arrow::compute::kernels::arity::unary(array, f); + let result = result.with_data_type(DataType::Decimal128(precision, scale)); + Ok(ColumnarValue::Array(Arc::new(result))) +} + +#[inline] +fn make_decimal_scalar( + a: &Option, + precision: u8, + scale: i8, + f: &dyn Fn(i128) -> i128, +) -> Result { + let result = ScalarValue::Decimal128(a.map(f), precision, scale); + Ok(ColumnarValue::Scalar(result)) +} + +macro_rules! integer_round { + ($X:expr, $DIV:expr, $HALF:expr) => {{ + let rem = $X % $DIV; + if rem <= -$HALF { + ($X - rem).sub_wrapping($DIV) + } else if rem >= $HALF { + ($X - rem).add_wrapping($DIV) + } else { + $X - rem + } + }}; +} + +macro_rules! round_integer_array { + ($ARRAY:expr, $POINT:expr, $TYPE:ty, $NATIVE:ty) => {{ + let array = $ARRAY.as_any().downcast_ref::<$TYPE>().unwrap(); + let ten: $NATIVE = 10; + let result: $TYPE = if let Some(div) = ten.checked_pow((-(*$POINT)) as u32) { + let half = div / 2; + arrow::compute::kernels::arity::unary(array, |x| integer_round!(x, div, half)) + } else { + arrow::compute::kernels::arity::unary(array, |_| 0) + }; + Ok(ColumnarValue::Array(Arc::new(result))) + }}; +} + +macro_rules! 
round_integer_scalar { + ($SCALAR:expr, $POINT:expr, $TYPE:expr, $NATIVE:ty) => {{ + let ten: $NATIVE = 10; + if let Some(div) = ten.checked_pow((-(*$POINT)) as u32) { + let half = div / 2; + Ok(ColumnarValue::Scalar($TYPE( + $SCALAR.map(|x| integer_round!(x, div, half)), + ))) + } else { + Ok(ColumnarValue::Scalar($TYPE(Some(0)))) + } + }}; +} + +/// `round` function that simulates Spark `round` expression +pub fn spark_round( + args: &[ColumnarValue], + data_type: &DataType, +) -> Result { + let value = &args[0]; + let point = &args[1]; + let ColumnarValue::Scalar(ScalarValue::Int64(Some(point))) = point else { + return internal_err!("Invalid point argument for Round(): {:#?}", point); + }; + match value { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Int64 if *point < 0 => round_integer_array!(array, point, Int64Array, i64), + DataType::Int32 if *point < 0 => round_integer_array!(array, point, Int32Array, i32), + DataType::Int16 if *point < 0 => round_integer_array!(array, point, Int16Array, i16), + DataType::Int8 if *point < 0 => round_integer_array!(array, point, Int8Array, i8), + DataType::Decimal128(_, scale) if *scale > 0 => { + let f = decimal_round_f(scale, point); + let (precision, scale) = get_precision_scale(data_type); + make_decimal_array(array, precision, scale, &f) + } + DataType::Float32 | DataType::Float64 => { + Ok(ColumnarValue::Array(round(&[array.clone()])?)) + } + dt => exec_err!("Not supported datatype for ROUND: {dt}"), + }, + ColumnarValue::Scalar(a) => match a { + ScalarValue::Int64(a) if *point < 0 => { + round_integer_scalar!(a, point, ScalarValue::Int64, i64) + } + ScalarValue::Int32(a) if *point < 0 => { + round_integer_scalar!(a, point, ScalarValue::Int32, i32) + } + ScalarValue::Int16(a) if *point < 0 => { + round_integer_scalar!(a, point, ScalarValue::Int16, i16) + } + ScalarValue::Int8(a) if *point < 0 => { + round_integer_scalar!(a, point, ScalarValue::Int8, i8) + } + ScalarValue::Decimal128(a, _, scale) if *scale >= 0 => { + let f = decimal_round_f(scale, point); + let (precision, scale) = get_precision_scale(data_type); + make_decimal_scalar(a, precision, scale, &f) + } + ScalarValue::Float32(_) | ScalarValue::Float64(_) => Ok(ColumnarValue::Scalar( + ScalarValue::try_from_array(&round(&[a.to_array()?])?, 0)?, + )), + dt => exec_err!("Not supported datatype for ROUND: {dt}"), + }, + } +} + +/// Similar to DataFusion `rpad`, but not to truncate when the string is already longer than length +pub fn spark_rpad(args: &[ColumnarValue]) -> Result { + match args { + [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => { + match args[0].data_type() { + DataType::Utf8 => spark_rpad_internal::(array, *length), + DataType::LargeUtf8 => spark_rpad_internal::(array, *length), + // TODO: handle Dictionary types + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function rpad", + ))), + } + } + other => Err(DataFusionError::Internal(format!( + "Unsupported arguments {other:?} for function rpad", + ))), + } +} + +fn spark_rpad_internal( + array: &ArrayRef, + length: i32, +) -> Result { + let string_array = as_generic_string_array::(array)?; + + let result = string_array + .iter() + .map(|string| match string { + Some(string) => { + let length = if length < 0 { 0 } else { length as usize }; + if length == 0 { + Ok(Some("".to_string())) + } else { + let graphemes = string.graphemes(true).collect::>(); + if length < graphemes.len() { + Ok(Some(string.to_string())) + } 
else { + let mut s = string.to_string(); + s.push_str(" ".repeat(length - graphemes.len()).as_str()); + Ok(Some(s)) + } + } + } + _ => Ok(None), + }) + .collect::, DataFusionError>>()?; + Ok(ColumnarValue::Array(Arc::new(result))) +} + +// Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3). +// Conversely, Decimal(p1, s1) = Decimal(p2, s2) * Decimal(p3, s3). This means that, in order to +// get enough scale that matches with Spark behavior, it requires to widen s1 to s2 + s3 + 1. Since +// both s2 and s3 are 38 at max., s1 is 77 at max. DataFusion division cannot handle such scale > +// Decimal256Type::MAX_SCALE. Therefore, we need to implement this decimal division using BigInt. +pub fn spark_decimal_div( + args: &[ColumnarValue], + data_type: &DataType, +) -> Result { + let left = &args[0]; + let right = &args[1]; + let (p3, s3) = get_precision_scale(data_type); + + let (left, right): (ArrayRef, ArrayRef) = match (left, right) { + (ColumnarValue::Array(l), ColumnarValue::Array(r)) => (l.clone(), r.clone()), + (ColumnarValue::Scalar(l), ColumnarValue::Array(r)) => { + (l.to_array_of_size(r.len())?, r.clone()) + } + (ColumnarValue::Array(l), ColumnarValue::Scalar(r)) => { + (l.clone(), r.to_array_of_size(l.len())?) + } + (ColumnarValue::Scalar(l), ColumnarValue::Scalar(r)) => (l.to_array()?, r.to_array()?), + }; + let left = left.as_primitive::(); + let right = right.as_primitive::(); + let (_, s1) = get_precision_scale(left.data_type()); + let (_, s2) = get_precision_scale(right.data_type()); + + let ten = BigInt::from(10); + let l_exp = ((s2 + s3 + 1) as u32).saturating_sub(s1 as u32); + let r_exp = (s1 as u32).saturating_sub((s2 + s3 + 1) as u32); + let l_mul = ten.pow(l_exp); + let r_mul = ten.pow(r_exp); + let five = BigInt::from(5); + let zero = BigInt::from(0); + let result: Decimal128Array = arrow::compute::kernels::arity::binary(left, right, |l, r| { + let l = BigInt::from(l) * &l_mul; + let r = BigInt::from(r) * &r_mul; + let div = if r.eq(&zero) { zero.clone() } else { &l / &r }; + let res = if div.is_negative() { + div - &five + } else { + div + &five + } / &ten; + res.to_i128().unwrap_or(i128::MAX) + })?; + let result = result.with_data_type(DataType::Decimal128(p3, s3)); + Ok(ColumnarValue::Array(Arc::new(result))) +} + +/// Spark-compatible `isnan` expression +pub fn spark_isnan(args: &[ColumnarValue]) -> Result { + fn set_nulls_to_false(is_nan: BooleanArray) -> ColumnarValue { + match is_nan.nulls() { + Some(nulls) => { + let is_not_null = nulls.inner(); + ColumnarValue::Array(Arc::new(BooleanArray::new( + is_nan.values() & is_not_null, + None, + ))) + } + None => ColumnarValue::Array(Arc::new(is_nan)), + } + } + let value = &args[0]; + match value { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Float64 => { + let array = array.as_any().downcast_ref::().unwrap(); + let is_nan = BooleanArray::from_unary(array, |x| x.is_nan()); + Ok(set_nulls_to_false(is_nan)) + } + DataType::Float32 => { + let array = array.as_any().downcast_ref::().unwrap(); + let is_nan = BooleanArray::from_unary(array, |x| x.is_nan()); + Ok(set_nulls_to_false(is_nan)) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function isnan", + other, + ))), + }, + ColumnarValue::Scalar(a) => match a { + ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some( + a.map(|x| x.is_nan()).unwrap_or(false), + )))), + ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean(Some( + 
a.map(|x| x.is_nan()).unwrap_or(false), + )))), + _ => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function isnan", + value.data_type(), + ))), + }, + } +} diff --git a/src/scalar_funcs/chr.rs b/src/scalar_funcs/chr.rs new file mode 100644 index 000000000000..5de59f9f27ca --- /dev/null +++ b/src/scalar_funcs/chr.rs @@ -0,0 +1,125 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, sync::Arc}; + +use arrow::{ + array::{ArrayRef, StringArray}, + datatypes::{ + DataType, + DataType::{Int64, Utf8}, + }, +}; + +use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_common::{cast::as_int64_array, exec_err, Result, ScalarValue}; + +fn chr(args: &[ArrayRef]) -> Result { + let integer_array = as_int64_array(&args[0])?; + + // first map is the iterator, second is for the `Option<_>` + let result = integer_array + .iter() + .map(|integer: Option| { + integer + .map(|integer| { + if integer < 0 { + return Ok("".to_string()); // Return empty string for negative integers + } + match core::char::from_u32((integer % 256) as u32) { + Some(ch) => Ok(ch.to_string()), + None => { + exec_err!("requested character not compatible for encoding.") + } + } + }) + .transpose() + }) + .collect::>()?; + + Ok(Arc::new(result) as ArrayRef) +} + +/// Spark-compatible `chr` expression +#[derive(Debug)] +pub struct SparkChrFunc { + signature: Signature, +} + +impl Default for SparkChrFunc { + fn default() -> Self { + Self::new() + } +} + +impl SparkChrFunc { + pub fn new() -> Self { + Self { + signature: Signature::uniform(1, vec![Int64], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkChrFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "chr" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Utf8) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + spark_chr(args) + } +} + +/// Returns the ASCII character having the binary equivalent to the input expression. +/// E.g., chr(65) = 'A'. 
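+/// Inputs are taken modulo 256, and negative inputs produce an empty string.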
+/// Compatible with Apache Spark's Chr function +fn spark_chr(args: &[ColumnarValue]) -> Result { + let array = args[0].clone(); + match array { + ColumnarValue::Array(array) => { + let array = chr(&[array])?; + Ok(ColumnarValue::Array(array)) + } + ColumnarValue::Scalar(ScalarValue::Int64(Some(value))) => { + if value < 0 { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + "".to_string(), + )))) + } else { + match core::char::from_u32((value % 256) as u32) { + Some(ch) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + ch.to_string(), + )))), + None => exec_err!("requested character was incompatible for encoding."), + } + } + } + _ => exec_err!("The argument must be an Int64 array or scalar."), + } +} diff --git a/src/scalar_funcs/hash_expressions.rs b/src/scalar_funcs/hash_expressions.rs new file mode 100644 index 000000000000..1a403b9e3db1 --- /dev/null +++ b/src/scalar_funcs/hash_expressions.rs @@ -0,0 +1,162 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::scalar_funcs::hex::hex_strings; +use crate::spark_hash::{create_murmur3_hashes, create_xxhash64_hashes}; + +use arrow_array::{ArrayRef, Int32Array, Int64Array, StringArray}; +use datafusion::functions::crypto::{sha224, sha256, sha384, sha512}; +use datafusion_common::cast::as_binary_array; +use datafusion_common::{exec_err, internal_err, DataFusionError, ScalarValue}; +use datafusion_expr::{ColumnarValue, ScalarFunctionImplementation}; +use std::sync::Arc; + +/// Spark compatible murmur3 hash (just `hash` in Spark) in vectorized execution fashion +pub fn spark_murmur3_hash(args: &[ColumnarValue]) -> Result { + let length = args.len(); + let seed = &args[length - 1]; + match seed { + ColumnarValue::Scalar(ScalarValue::Int32(Some(seed))) => { + // iterate over the arguments to find out the length of the array + let num_rows = args[0..args.len() - 1] + .iter() + .find_map(|arg| match arg { + ColumnarValue::Array(array) => Some(array.len()), + ColumnarValue::Scalar(_) => None, + }) + .unwrap_or(1); + let mut hashes: Vec = vec![0_u32; num_rows]; + hashes.fill(*seed as u32); + let arrays = args[0..args.len() - 1] + .iter() + .map(|arg| match arg { + ColumnarValue::Array(array) => array.clone(), + ColumnarValue::Scalar(scalar) => { + scalar.clone().to_array_of_size(num_rows).unwrap() + } + }) + .collect::>(); + create_murmur3_hashes(&arrays, &mut hashes)?; + if num_rows == 1 { + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some( + hashes[0] as i32, + )))) + } else { + let hashes: Vec = hashes.into_iter().map(|x| x as i32).collect(); + Ok(ColumnarValue::Array(Arc::new(Int32Array::from(hashes)))) + } + } + _ => { + internal_err!( + "The seed of function murmur3_hash must be an Int32 scalar value, but got: {:?}.", + seed + ) + } + } +} + +/// Spark compatible xxhash64 in 
vectorized execution fashion +pub fn spark_xxhash64(args: &[ColumnarValue]) -> Result { + let length = args.len(); + let seed = &args[length - 1]; + match seed { + ColumnarValue::Scalar(ScalarValue::Int64(Some(seed))) => { + // iterate over the arguments to find out the length of the array + let num_rows = args[0..args.len() - 1] + .iter() + .find_map(|arg| match arg { + ColumnarValue::Array(array) => Some(array.len()), + ColumnarValue::Scalar(_) => None, + }) + .unwrap_or(1); + let mut hashes: Vec = vec![0_u64; num_rows]; + hashes.fill(*seed as u64); + let arrays = args[0..args.len() - 1] + .iter() + .map(|arg| match arg { + ColumnarValue::Array(array) => array.clone(), + ColumnarValue::Scalar(scalar) => { + scalar.clone().to_array_of_size(num_rows).unwrap() + } + }) + .collect::>(); + create_xxhash64_hashes(&arrays, &mut hashes)?; + if num_rows == 1 { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some( + hashes[0] as i64, + )))) + } else { + let hashes: Vec = hashes.into_iter().map(|x| x as i64).collect(); + Ok(ColumnarValue::Array(Arc::new(Int64Array::from(hashes)))) + } + } + _ => { + internal_err!( + "The seed of function xxhash64 must be an Int64 scalar value, but got: {:?}.", + seed + ) + } + } +} + +/// `sha224` function that simulates Spark's `sha2` expression with bit width 224 +pub fn spark_sha224(args: &[ColumnarValue]) -> Result { + wrap_digest_result_as_hex_string(args, sha224().fun()) +} + +/// `sha256` function that simulates Spark's `sha2` expression with bit width 0 or 256 +pub fn spark_sha256(args: &[ColumnarValue]) -> Result { + wrap_digest_result_as_hex_string(args, sha256().fun()) +} + +/// `sha384` function that simulates Spark's `sha2` expression with bit width 384 +pub fn spark_sha384(args: &[ColumnarValue]) -> Result { + wrap_digest_result_as_hex_string(args, sha384().fun()) +} + +/// `sha512` function that simulates Spark's `sha2` expression with bit width 512 +pub fn spark_sha512(args: &[ColumnarValue]) -> Result { + wrap_digest_result_as_hex_string(args, sha512().fun()) +} + +// Spark requires hex string as the result of sha2 functions, we have to wrap the +// result of digest functions as hex string +fn wrap_digest_result_as_hex_string( + args: &[ColumnarValue], + digest: ScalarFunctionImplementation, +) -> Result { + let value = digest(args)?; + match value { + ColumnarValue::Array(array) => { + let binary_array = as_binary_array(&array)?; + let string_array: StringArray = binary_array + .iter() + .map(|opt| opt.map(hex_strings::<_>)) + .collect(); + Ok(ColumnarValue::Array(Arc::new(string_array))) + } + ColumnarValue::Scalar(ScalarValue::Binary(opt)) => Ok(ColumnarValue::Scalar( + ScalarValue::Utf8(opt.map(hex_strings::<_>)), + )), + _ => { + exec_err!( + "digest function should return binary value, but got: {:?}", + value.data_type() + ) + } + } +} diff --git a/src/scalar_funcs/hex.rs b/src/scalar_funcs/hex.rs new file mode 100644 index 000000000000..e572ba5ef39a --- /dev/null +++ b/src/scalar_funcs/hex.rs @@ -0,0 +1,296 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::{ + array::{as_dictionary_array, as_largestring_array, as_string_array}, + datatypes::Int32Type, +}; +use arrow_array::StringArray; +use arrow_schema::DataType; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{ + cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array}, + exec_err, DataFusionError, +}; +use std::fmt::Write; + +fn hex_int64(num: i64) -> String { + format!("{:X}", num) +} + +#[inline(always)] +fn hex_encode>(data: T, lower_case: bool) -> String { + let mut s = String::with_capacity(data.as_ref().len() * 2); + if lower_case { + for b in data.as_ref() { + // Writing to a string never errors, so we can unwrap here. + write!(&mut s, "{b:02x}").unwrap(); + } + } else { + for b in data.as_ref() { + // Writing to a string never errors, so we can unwrap here. + write!(&mut s, "{b:02X}").unwrap(); + } + } + s +} + +#[inline(always)] +pub(super) fn hex_strings>(data: T) -> String { + hex_encode(data, true) +} + +#[inline(always)] +fn hex_bytes>(bytes: T) -> Result { + let hex_string = hex_encode(bytes, false); + Ok(hex_string) +} + +/// Spark-compatible `hex` function +pub fn spark_hex(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return Err(DataFusionError::Internal( + "hex expects exactly one argument".to_string(), + )); + } + + match &args[0] { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Int64 => { + let array = as_int64_array(array)?; + + let hexed_array: StringArray = array.iter().map(|v| v.map(hex_int64)).collect(); + + Ok(ColumnarValue::Array(Arc::new(hexed_array))) + } + DataType::Utf8 => { + let array = as_string_array(array); + + let hexed: StringArray = array + .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?; + + Ok(ColumnarValue::Array(Arc::new(hexed))) + } + DataType::LargeUtf8 => { + let array = as_largestring_array(array); + + let hexed: StringArray = array + .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?; + + Ok(ColumnarValue::Array(Arc::new(hexed))) + } + DataType::Binary => { + let array = as_binary_array(array)?; + + let hexed: StringArray = array + .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?; + + Ok(ColumnarValue::Array(Arc::new(hexed))) + } + DataType::FixedSizeBinary(_) => { + let array = as_fixed_size_binary_array(array)?; + + let hexed: StringArray = array + .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?; + + Ok(ColumnarValue::Array(Arc::new(hexed))) + } + DataType::Dictionary(_, value_type) => { + let dict = as_dictionary_array::(&array); + + let values = match **value_type { + DataType::Int64 => as_int64_array(dict.values())? + .iter() + .map(|v| v.map(hex_int64)) + .collect::>(), + DataType::Utf8 => as_string_array(dict.values()) + .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?, + DataType::Binary => as_binary_array(dict.values())? 
+ .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?, + _ => exec_err!( + "hex got an unexpected argument type: {:?}", + array.data_type() + )?, + }; + + let new_values: Vec> = dict + .keys() + .iter() + .map(|key| key.map(|k| values[k as usize].clone()).unwrap_or(None)) + .collect(); + + let string_array_values = StringArray::from(new_values); + + Ok(ColumnarValue::Array(Arc::new(string_array_values))) + } + _ => exec_err!( + "hex got an unexpected argument type: {:?}", + array.data_type() + ), + }, + _ => exec_err!("native hex does not support scalar values at this time"), + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::{ + array::{ + as_string_array, BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, StringBuilder, + StringDictionaryBuilder, + }, + datatypes::{Int32Type, Int64Type}, + }; + use arrow_array::{Int64Array, StringArray}; + use datafusion::logical_expr::ColumnarValue; + + #[test] + fn test_dictionary_hex_utf8() { + let mut input_builder = StringDictionaryBuilder::::new(); + input_builder.append_value("hi"); + input_builder.append_value("bye"); + input_builder.append_null(); + input_builder.append_value("rust"); + let input = input_builder.finish(); + + let mut string_builder = StringBuilder::new(); + string_builder.append_value("6869"); + string_builder.append_value("627965"); + string_builder.append_null(); + string_builder.append_value("72757374"); + let expected = string_builder.finish(); + + let columnar_value = ColumnarValue::Array(Arc::new(input)); + let result = super::spark_hex(&[columnar_value]).unwrap(); + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let result = as_string_array(&result); + + assert_eq!(result, &expected); + } + + #[test] + fn test_dictionary_hex_int64() { + let mut input_builder = PrimitiveDictionaryBuilder::::new(); + input_builder.append_value(1); + input_builder.append_value(2); + input_builder.append_null(); + input_builder.append_value(3); + let input = input_builder.finish(); + + let mut string_builder = StringBuilder::new(); + string_builder.append_value("1"); + string_builder.append_value("2"); + string_builder.append_null(); + string_builder.append_value("3"); + let expected = string_builder.finish(); + + let columnar_value = ColumnarValue::Array(Arc::new(input)); + let result = super::spark_hex(&[columnar_value]).unwrap(); + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let result = as_string_array(&result); + + assert_eq!(result, &expected); + } + + #[test] + fn test_dictionary_hex_binary() { + let mut input_builder = BinaryDictionaryBuilder::::new(); + input_builder.append_value("1"); + input_builder.append_value("j"); + input_builder.append_null(); + input_builder.append_value("3"); + let input = input_builder.finish(); + + let mut expected_builder = StringBuilder::new(); + expected_builder.append_value("31"); + expected_builder.append_value("6A"); + expected_builder.append_null(); + expected_builder.append_value("33"); + let expected = expected_builder.finish(); + + let columnar_value = ColumnarValue::Array(Arc::new(input)); + let result = super::spark_hex(&[columnar_value]).unwrap(); + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let result = as_string_array(&result); + + assert_eq!(result, &expected); + } + + #[test] + fn test_hex_int64() { + let num = 1234; + let hexed = super::hex_int64(num); + 
assert_eq!(hexed, "4D2".to_string()); + + let num = -1; + let hexed = super::hex_int64(num); + assert_eq!(hexed, "FFFFFFFFFFFFFFFF".to_string()); + } + + #[test] + fn test_spark_hex_int64() { + let int_array = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]); + let columnar_value = ColumnarValue::Array(Arc::new(int_array)); + + let result = super::spark_hex(&[columnar_value]).unwrap(); + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let string_array = as_string_array(&result); + let expected_array = StringArray::from(vec![ + Some("1".to_string()), + Some("2".to_string()), + None, + Some("3".to_string()), + ]); + + assert_eq!(string_array, &expected_array); + } +} diff --git a/src/scalar_funcs/unhex.rs b/src/scalar_funcs/unhex.rs new file mode 100644 index 000000000000..9996392b63a4 --- /dev/null +++ b/src/scalar_funcs/unhex.rs @@ -0,0 +1,258 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow_array::OffsetSizeTrait; +use arrow_schema::DataType; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{cast::as_generic_string_array, exec_err, DataFusionError, ScalarValue}; + +/// Helper function to convert a hex digit to a binary value. +fn unhex_digit(c: u8) -> Result { + match c { + b'0'..=b'9' => Ok(c - b'0'), + b'A'..=b'F' => Ok(10 + c - b'A'), + b'a'..=b'f' => Ok(10 + c - b'a'), + _ => Err(DataFusionError::Execution( + "Input to unhex_digit is not a valid hex digit".to_string(), + )), + } +} + +/// Convert a hex string to binary and store the result in `result`. Returns an error if the input +/// is not a valid hex string. 
+fn unhex(hex_str: &str, result: &mut Vec) -> Result<(), DataFusionError> { + let bytes = hex_str.as_bytes(); + + let mut i = 0; + + if (bytes.len() & 0x01) != 0 { + let v = unhex_digit(bytes[0])?; + + result.push(v); + i += 1; + } + + while i < bytes.len() { + let first = unhex_digit(bytes[i])?; + let second = unhex_digit(bytes[i + 1])?; + result.push((first << 4) | second); + + i += 2; + } + + Ok(()) +} + +fn spark_unhex_inner( + array: &ColumnarValue, + fail_on_error: bool, +) -> Result { + match array { + ColumnarValue::Array(array) => { + let string_array = as_generic_string_array::(array)?; + + let mut encoded = Vec::new(); + let mut builder = arrow::array::BinaryBuilder::new(); + + for item in string_array.iter() { + if let Some(s) = item { + if unhex(s, &mut encoded).is_ok() { + builder.append_value(encoded.as_slice()); + } else if fail_on_error { + return exec_err!("Input to unhex is not a valid hex string: {s}"); + } else { + builder.append_null(); + } + encoded.clear(); + } else { + builder.append_null(); + } + } + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) + } + ColumnarValue::Scalar(ScalarValue::Utf8(Some(string))) => { + let mut encoded = Vec::new(); + + if unhex(string, &mut encoded).is_ok() { + Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(encoded)))) + } else if fail_on_error { + exec_err!("Input to unhex is not a valid hex string: {string}") + } else { + Ok(ColumnarValue::Scalar(ScalarValue::Binary(None))) + } + } + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Binary(None))) + } + _ => { + exec_err!( + "The first argument must be a string scalar or array, but got: {:?}", + array + ) + } + } +} + +/// Spark-compatible `unhex` expression +pub fn spark_unhex(args: &[ColumnarValue]) -> Result { + if args.len() > 2 { + return exec_err!("unhex takes at most 2 arguments, but got: {}", args.len()); + } + + let val_to_unhex = &args[0]; + let fail_on_error = if args.len() == 2 { + match &args[1] { + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))) => *fail_on_error, + _ => { + return exec_err!( + "The second argument must be boolean scalar, but got: {:?}", + args[1] + ); + } + } + } else { + false + }; + + match val_to_unhex.data_type() { + DataType::Utf8 => spark_unhex_inner::(val_to_unhex, fail_on_error), + DataType::LargeUtf8 => spark_unhex_inner::(val_to_unhex, fail_on_error), + other => exec_err!( + "The first argument must be a Utf8 or LargeUtf8: {:?}", + other + ), + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::{BinaryBuilder, StringBuilder}; + use arrow_array::make_array; + use arrow_data::ArrayData; + use datafusion::logical_expr::ColumnarValue; + use datafusion_common::ScalarValue; + + use super::unhex; + + #[test] + fn test_spark_unhex_null() -> Result<(), Box> { + let input = ArrayData::new_null(&arrow_schema::DataType::Utf8, 2); + let output = ArrayData::new_null(&arrow_schema::DataType::Binary, 2); + + let input = ColumnarValue::Array(Arc::new(make_array(input))); + let expected = ColumnarValue::Array(Arc::new(make_array(output))); + + let result = super::spark_unhex(&[input])?; + + match (result, expected) { + (ColumnarValue::Array(result), ColumnarValue::Array(expected)) => { + assert_eq!(*result, *expected); + Ok(()) + } + _ => Err("Unexpected result type".into()), + } + } + + #[test] + fn test_partial_error() -> Result<(), Box> { + let mut input = StringBuilder::new(); + + input.append_value("1CGG"); // 1C is ok, but GG is invalid + 
input.append_value("537061726B2053514C"); // followed by valid + + let input = ColumnarValue::Array(Arc::new(input.finish())); + let fail_on_error = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))); + + let result = super::spark_unhex(&[input, fail_on_error])?; + + let mut expected = BinaryBuilder::new(); + expected.append_null(); + expected.append_value("Spark SQL".as_bytes()); + + match (result, ColumnarValue::Array(Arc::new(expected.finish()))) { + (ColumnarValue::Array(result), ColumnarValue::Array(expected)) => { + assert_eq!(*result, *expected); + + Ok(()) + } + _ => Err("Unexpected result type".into()), + } + } + + #[test] + fn test_unhex_valid() -> Result<(), Box> { + let mut result = Vec::new(); + + unhex("537061726B2053514C", &mut result)?; + let result_str = std::str::from_utf8(&result)?; + assert_eq!(result_str, "Spark SQL"); + result.clear(); + + unhex("1C", &mut result)?; + assert_eq!(result, vec![28]); + result.clear(); + + unhex("737472696E67", &mut result)?; + assert_eq!(result, "string".as_bytes()); + result.clear(); + + unhex("1", &mut result)?; + assert_eq!(result, vec![1]); + result.clear(); + + Ok(()) + } + + #[test] + fn test_odd_length() -> Result<(), Box> { + let mut result = Vec::new(); + + unhex("A1B", &mut result)?; + assert_eq!(result, vec![10, 27]); + result.clear(); + + unhex("0A1B", &mut result)?; + assert_eq!(result, vec![10, 27]); + result.clear(); + + Ok(()) + } + + #[test] + fn test_unhex_empty() { + let mut result = Vec::new(); + + // Empty hex string + unhex("", &mut result).unwrap(); + assert!(result.is_empty()); + } + + #[test] + fn test_unhex_invalid() { + let mut result = Vec::new(); + + // Invalid hex strings + assert!(unhex("##", &mut result).is_err()); + assert!(unhex("G123", &mut result).is_err()); + assert!(unhex("hello", &mut result).is_err()); + assert!(unhex("\0", &mut result).is_err()); + } +} diff --git a/src/spark_hash.rs b/src/spark_hash.rs new file mode 100644 index 000000000000..66a103a2ae27 --- /dev/null +++ b/src/spark_hash.rs @@ -0,0 +1,708 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This includes utilities for hashing and murmur3 hashing. 
+ +use arrow::{ + compute::take, + datatypes::{ArrowNativeTypeOp, UInt16Type, UInt32Type, UInt64Type, UInt8Type}, +}; +use std::sync::Arc; + +use datafusion::{ + arrow::{ + array::*, + datatypes::{ + ArrowDictionaryKeyType, ArrowNativeType, DataType, Int16Type, Int32Type, Int64Type, + Int8Type, TimeUnit, + }, + }, + error::{DataFusionError, Result}, +}; + +use crate::xxhash64::spark_compatible_xxhash64; + +/// Spark-compatible murmur3 hash function +#[inline] +pub fn spark_compatible_murmur3_hash>(data: T, seed: u32) -> u32 { + #[inline] + fn mix_k1(mut k1: i32) -> i32 { + k1 = k1.mul_wrapping(0xcc9e2d51u32 as i32); + k1 = k1.rotate_left(15); + k1 = k1.mul_wrapping(0x1b873593u32 as i32); + k1 + } + + #[inline] + fn mix_h1(mut h1: i32, k1: i32) -> i32 { + h1 ^= k1; + h1 = h1.rotate_left(13); + h1 = h1.mul_wrapping(5).add_wrapping(0xe6546b64u32 as i32); + h1 + } + + #[inline] + fn fmix(mut h1: i32, len: i32) -> i32 { + h1 ^= len; + h1 ^= (h1 as u32 >> 16) as i32; + h1 = h1.mul_wrapping(0x85ebca6bu32 as i32); + h1 ^= (h1 as u32 >> 13) as i32; + h1 = h1.mul_wrapping(0xc2b2ae35u32 as i32); + h1 ^= (h1 as u32 >> 16) as i32; + h1 + } + + #[inline] + unsafe fn hash_bytes_by_int(data: &[u8], seed: u32) -> i32 { + // safety: data length must be aligned to 4 bytes + let mut h1 = seed as i32; + for i in (0..data.len()).step_by(4) { + let ints = data.as_ptr().add(i) as *const i32; + let mut half_word = ints.read_unaligned(); + if cfg!(target_endian = "big") { + half_word = half_word.reverse_bits(); + } + h1 = mix_h1(h1, mix_k1(half_word)); + } + h1 + } + let data = data.as_ref(); + let len = data.len(); + let len_aligned = len - len % 4; + + // safety: + // avoid boundary checking in performance critical codes. + // all operations are guaranteed to be safe + // data is &[u8] so we do not need to check for proper alignment + unsafe { + let mut h1 = if len_aligned > 0 { + hash_bytes_by_int(&data[0..len_aligned], seed) + } else { + seed as i32 + }; + + for i in len_aligned..len { + let half_word = *data.get_unchecked(i) as i8 as i32; + h1 = mix_h1(h1, mix_k1(half_word)); + } + fmix(h1, len as i32) as u32 + } +} + +macro_rules! hash_array { + ($array_type: ident, $column: ident, $hashes: ident, $hash_method: ident) => { + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + if array.null_count() == 0 { + for (i, hash) in $hashes.iter_mut().enumerate() { + *hash = $hash_method(&array.value(i), *hash); + } + } else { + for (i, hash) in $hashes.iter_mut().enumerate() { + if !array.is_null(i) { + *hash = $hash_method(&array.value(i), *hash); + } + } + } + }; +} + +macro_rules! hash_array_boolean { + ($array_type: ident, $column: ident, $hash_input_type: ident, $hashes: ident, $hash_method: ident) => { + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + if array.null_count() == 0 { + for (i, hash) in $hashes.iter_mut().enumerate() { + *hash = $hash_method($hash_input_type::from(array.value(i)).to_le_bytes(), *hash); + } + } else { + for (i, hash) in $hashes.iter_mut().enumerate() { + if !array.is_null(i) { + *hash = + $hash_method($hash_input_type::from(array.value(i)).to_le_bytes(), *hash); + } + } + } + }; +} + +macro_rules! 
hash_array_primitive { + ($array_type: ident, $column: ident, $ty: ident, $hashes: ident, $hash_method: ident) => { + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + let values = array.values(); + + if array.null_count() == 0 { + for (hash, value) in $hashes.iter_mut().zip(values.iter()) { + *hash = $hash_method((*value as $ty).to_le_bytes(), *hash); + } + } else { + for (i, (hash, value)) in $hashes.iter_mut().zip(values.iter()).enumerate() { + if !array.is_null(i) { + *hash = $hash_method((*value as $ty).to_le_bytes(), *hash); + } + } + } + }; +} + +macro_rules! hash_array_primitive_float { + ($array_type: ident, $column: ident, $ty: ident, $ty2: ident, $hashes: ident, $hash_method: ident) => { + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + let values = array.values(); + + if array.null_count() == 0 { + for (hash, value) in $hashes.iter_mut().zip(values.iter()) { + // Spark uses 0 as hash for -0.0, see `Murmur3Hash` expression. + if *value == 0.0 && value.is_sign_negative() { + *hash = $hash_method((0 as $ty2).to_le_bytes(), *hash); + } else { + *hash = $hash_method((*value as $ty).to_le_bytes(), *hash); + } + } + } else { + for (i, (hash, value)) in $hashes.iter_mut().zip(values.iter()).enumerate() { + if !array.is_null(i) { + // Spark uses 0 as hash for -0.0, see `Murmur3Hash` expression. + if *value == 0.0 && value.is_sign_negative() { + *hash = $hash_method((0 as $ty2).to_le_bytes(), *hash); + } else { + *hash = $hash_method((*value as $ty).to_le_bytes(), *hash); + } + } + } + } + }; +} + +macro_rules! hash_array_decimal { + ($array_type:ident, $column: ident, $hashes: ident, $hash_method: ident) => { + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + + if array.null_count() == 0 { + for (i, hash) in $hashes.iter_mut().enumerate() { + *hash = $hash_method(array.value(i).to_le_bytes(), *hash); + } + } else { + for (i, hash) in $hashes.iter_mut().enumerate() { + if !array.is_null(i) { + *hash = $hash_method(array.value(i).to_le_bytes(), *hash); + } + } + } + }; +} + +/// Hash the values in a dictionary array +fn create_hashes_dictionary( + array: &ArrayRef, + hashes_buffer: &mut [u32], + first_col: bool, +) -> Result<()> { + let dict_array = array.as_any().downcast_ref::>().unwrap(); + if !first_col { + // unpack the dictionary array as each row may have a different hash input + let unpacked = take(dict_array.values().as_ref(), dict_array.keys(), None)?; + create_murmur3_hashes(&[unpacked], hashes_buffer)?; + } else { + // For the first column, hash each dictionary value once, and then use + // that computed hash for each key value to avoid a potentially + // expensive redundant hashing for large dictionary elements (e.g. 
strings) + let dict_values = Arc::clone(dict_array.values()); + // same initial seed as Spark + let mut dict_hashes = vec![42; dict_values.len()]; + create_murmur3_hashes(&[dict_values], &mut dict_hashes)?; + for (hash, key) in hashes_buffer.iter_mut().zip(dict_array.keys().iter()) { + if let Some(key) = key { + let idx = key.to_usize().ok_or_else(|| { + DataFusionError::Internal(format!( + "Can not convert key value {:?} to usize in dictionary of type {:?}", + key, + dict_array.data_type() + )) + })?; + *hash = dict_hashes[idx] + } // no update for Null, consistent with other hashes + } + } + Ok(()) +} + +// Hash the values in a dictionary array using xxhash64 +fn create_xxhash64_hashes_dictionary( + array: &ArrayRef, + hashes_buffer: &mut [u64], + first_col: bool, +) -> Result<()> { + let dict_array = array.as_any().downcast_ref::>().unwrap(); + if !first_col { + let unpacked = take(dict_array.values().as_ref(), dict_array.keys(), None)?; + create_xxhash64_hashes(&[unpacked], hashes_buffer)?; + } else { + // Hash each dictionary value once, and then use that computed + // hash for each key value to avoid a potentially expensive + // redundant hashing for large dictionary elements (e.g. strings) + let dict_values = Arc::clone(dict_array.values()); + // same initial seed as Spark + let mut dict_hashes = vec![42u64; dict_values.len()]; + create_xxhash64_hashes(&[dict_values], &mut dict_hashes)?; + + for (hash, key) in hashes_buffer.iter_mut().zip(dict_array.keys().iter()) { + if let Some(key) = key { + let idx = key.to_usize().ok_or_else(|| { + DataFusionError::Internal(format!( + "Can not convert key value {:?} to usize in dictionary of type {:?}", + key, + dict_array.data_type() + )) + })?; + *hash = dict_hashes[idx] + } // no update for Null, consistent with other hashes + } + } + Ok(()) +} + +/// Creates hash values for every row, based on the values in the +/// columns. +/// +/// The number of rows to hash is determined by `hashes_buffer.len()`. +/// `hashes_buffer` should be pre-sized appropriately +/// +/// `hash_method` is the hash function to use. +/// `create_dictionary_hash_method` is the function to create hashes for dictionary arrays input. +macro_rules! 
create_hashes_internal { + ($arrays: ident, $hashes_buffer: ident, $hash_method: ident, $create_dictionary_hash_method: ident) => { + for (i, col) in $arrays.iter().enumerate() { + let first_col = i == 0; + match col.data_type() { + DataType::Boolean => { + hash_array_boolean!(BooleanArray, col, i32, $hashes_buffer, $hash_method); + } + DataType::Int8 => { + hash_array_primitive!(Int8Array, col, i32, $hashes_buffer, $hash_method); + } + DataType::Int16 => { + hash_array_primitive!(Int16Array, col, i32, $hashes_buffer, $hash_method); + } + DataType::Int32 => { + hash_array_primitive!(Int32Array, col, i32, $hashes_buffer, $hash_method); + } + DataType::Int64 => { + hash_array_primitive!(Int64Array, col, i64, $hashes_buffer, $hash_method); + } + DataType::Float32 => { + hash_array_primitive_float!( + Float32Array, + col, + f32, + i32, + $hashes_buffer, + $hash_method + ); + } + DataType::Float64 => { + hash_array_primitive_float!( + Float64Array, + col, + f64, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Timestamp(TimeUnit::Second, _) => { + hash_array_primitive!( + TimestampSecondArray, + col, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + hash_array_primitive!( + TimestampMillisecondArray, + col, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + hash_array_primitive!( + TimestampMicrosecondArray, + col, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + hash_array_primitive!( + TimestampNanosecondArray, + col, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Date32 => { + hash_array_primitive!(Date32Array, col, i32, $hashes_buffer, $hash_method); + } + DataType::Date64 => { + hash_array_primitive!(Date64Array, col, i64, $hashes_buffer, $hash_method); + } + DataType::Utf8 => { + hash_array!(StringArray, col, $hashes_buffer, $hash_method); + } + DataType::LargeUtf8 => { + hash_array!(LargeStringArray, col, $hashes_buffer, $hash_method); + } + DataType::Binary => { + hash_array!(BinaryArray, col, $hashes_buffer, $hash_method); + } + DataType::LargeBinary => { + hash_array!(LargeBinaryArray, col, $hashes_buffer, $hash_method); + } + DataType::FixedSizeBinary(_) => { + hash_array!(FixedSizeBinaryArray, col, $hashes_buffer, $hash_method); + } + DataType::Decimal128(_, _) => { + hash_array_decimal!(Decimal128Array, col, $hashes_buffer, $hash_method); + } + DataType::Dictionary(index_type, _) => match **index_type { + DataType::Int8 => { + $create_dictionary_hash_method::(col, $hashes_buffer, first_col)?; + } + DataType::Int16 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::Int32 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::Int64 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::UInt8 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::UInt16 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::UInt32 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::UInt64 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + _ => { + return Err(DataFusionError::Internal(format!( + "Unsupported dictionary type in hasher hashing: {}", + col.data_type(), + ))) + } + }, + _ => { + // This is 
internal because we should have caught this before. + return Err(DataFusionError::Internal(format!( + "Unsupported data type in hasher: {}", + col.data_type() + ))); + } + } + } + }; +} + +/// Creates hash values for every row, based on the values in the +/// columns. +/// +/// The number of rows to hash is determined by `hashes_buffer.len()`. +/// `hashes_buffer` should be pre-sized appropriately +pub fn create_murmur3_hashes<'a>( + arrays: &[ArrayRef], + hashes_buffer: &'a mut [u32], +) -> Result<&'a mut [u32]> { + create_hashes_internal!( + arrays, + hashes_buffer, + spark_compatible_murmur3_hash, + create_hashes_dictionary + ); + Ok(hashes_buffer) +} + +/// Creates xxhash64 hash values for every row, based on the values in the +/// columns. +/// +/// The number of rows to hash is determined by `hashes_buffer.len()`. +/// `hashes_buffer` should be pre-sized appropriately +pub fn create_xxhash64_hashes<'a>( + arrays: &[ArrayRef], + hashes_buffer: &'a mut [u64], +) -> Result<&'a mut [u64]> { + create_hashes_internal!( + arrays, + hashes_buffer, + spark_compatible_xxhash64, + create_xxhash64_hashes_dictionary + ); + Ok(hashes_buffer) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Float32Array, Float64Array}; + use std::sync::Arc; + + use super::{create_murmur3_hashes, create_xxhash64_hashes}; + use datafusion::arrow::array::{ArrayRef, Int32Array, Int64Array, Int8Array, StringArray}; + + macro_rules! test_hashes_internal { + ($hash_method: ident, $input: expr, $initial_seeds: expr, $expected: expr) => { + let i = $input; + let mut hashes = $initial_seeds.clone(); + $hash_method(&[i], &mut hashes).unwrap(); + assert_eq!(hashes, $expected); + }; + } + + macro_rules! test_hashes_with_nulls { + ($method: ident, $t: ty, $values: ident, $expected: ident, $seed_type: ty) => { + // copied before inserting nulls + let mut input_with_nulls = $values.clone(); + let mut expected_with_nulls = $expected.clone(); + // test before inserting nulls + let len = $values.len(); + let initial_seeds = vec![42 as $seed_type; len]; + let i = Arc::new(<$t>::from($values)) as ArrayRef; + test_hashes_internal!($method, i, initial_seeds, $expected); + + // test with nulls + let median = len / 2; + input_with_nulls.insert(0, None); + input_with_nulls.insert(median, None); + expected_with_nulls.insert(0, 42 as $seed_type); + expected_with_nulls.insert(median, 42 as $seed_type); + let len_with_nulls = len + 2; + let initial_seeds_with_nulls = vec![42 as $seed_type; len_with_nulls]; + let nullable_input = Arc::new(<$t>::from(input_with_nulls)) as ArrayRef; + test_hashes_internal!( + $method, + nullable_input, + initial_seeds_with_nulls, + expected_with_nulls + ); + }; + } + + fn test_murmur3_hash>> + 'static>( + values: Vec>, + expected: Vec, + ) { + test_hashes_with_nulls!(create_murmur3_hashes, T, values, expected, u32); + } + + fn test_xxhash64_hash>> + 'static>( + values: Vec>, + expected: Vec, + ) { + test_hashes_with_nulls!(create_xxhash64_hashes, T, values, expected, u64); + } + + #[test] + fn test_i8() { + test_murmur3_hash::( + vec![Some(1), Some(0), Some(-1), Some(i8::MAX), Some(i8::MIN)], + vec![0xdea578e3, 0x379fae8f, 0xa0590e3d, 0x43b4d8ed, 0x422a1365], + ); + test_xxhash64_hash::( + vec![Some(1), Some(0), Some(-1), Some(i8::MAX), Some(i8::MIN)], + vec![ + 0xa309b38455455929, + 0x3229fbc4681e48f3, + 0x1bfdda8861c06e45, + 0x77cc15d9f9f2cdc2, + 0x39bc22b9e94d81d0, + ], + ); + } + + #[test] + fn test_i32() { + test_murmur3_hash::( + vec![Some(1), Some(0), Some(-1), Some(i32::MAX), Some(i32::MIN)], + 
vec![0xdea578e3, 0x379fae8f, 0xa0590e3d, 0x07fb67e7, 0x2b1f0fc6], + ); + test_xxhash64_hash::( + vec![Some(1), Some(0), Some(-1), Some(i32::MAX), Some(i32::MIN)], + vec![ + 0xa309b38455455929, + 0x3229fbc4681e48f3, + 0x1bfdda8861c06e45, + 0x14f0ac009c21721c, + 0x1cc7cb8d034769cd, + ], + ); + } + + #[test] + fn test_i64() { + test_murmur3_hash::( + vec![Some(1), Some(0), Some(-1), Some(i64::MAX), Some(i64::MIN)], + vec![0x99f0149d, 0x9c67b85d, 0xc8008529, 0xa05b5d7b, 0xcd1e64fb], + ); + test_xxhash64_hash::( + vec![Some(1), Some(0), Some(-1), Some(i64::MAX), Some(i64::MIN)], + vec![ + 0x9ed50fd59358d232, + 0xb71b47ebda15746c, + 0x358ae035bfb46fd2, + 0xd2f1c616ae7eb306, + 0x88608019c494c1f4, + ], + ); + } + + #[test] + fn test_f32() { + test_murmur3_hash::( + vec![ + Some(1.0), + Some(0.0), + Some(-0.0), + Some(-1.0), + Some(99999999999.99999999999), + Some(-99999999999.99999999999), + ], + vec![ + 0xe434cc39, 0x379fae8f, 0x379fae8f, 0xdc0da8eb, 0xcbdc340f, 0xc0361c86, + ], + ); + test_xxhash64_hash::( + vec![ + Some(1.0), + Some(0.0), + Some(-0.0), + Some(-1.0), + Some(99999999999.99999999999), + Some(-99999999999.99999999999), + ], + vec![ + 0x9b92689757fcdbd, + 0x3229fbc4681e48f3, + 0x3229fbc4681e48f3, + 0xa2becc0e61bb3823, + 0x8f20ab82d4f3687f, + 0xdce4982d97f7ac4, + ], + ) + } + + #[test] + fn test_f64() { + test_murmur3_hash::( + vec![ + Some(1.0), + Some(0.0), + Some(-0.0), + Some(-1.0), + Some(99999999999.99999999999), + Some(-99999999999.99999999999), + ], + vec![ + 0xe4876492, 0x9c67b85d, 0x9c67b85d, 0x13d81357, 0xb87e1595, 0xa0eef9f9, + ], + ); + + test_xxhash64_hash::( + vec![ + Some(1.0), + Some(0.0), + Some(-0.0), + Some(-1.0), + Some(99999999999.99999999999), + Some(-99999999999.99999999999), + ], + vec![ + 0xe1fd6e07fee8ad53, + 0xb71b47ebda15746c, + 0xb71b47ebda15746c, + 0x8cdde022746f8f1f, + 0x793c5c88d313eac7, + 0xc5e60e7b75d9b232, + ], + ) + } + + #[test] + fn test_str() { + let input = [ + "hello", "bar", "", "😁", "天地", "a", "ab", "abc", "abcd", "abcde", + ] + .iter() + .map(|s| Some(s.to_string())) + .collect::>>(); + let expected: Vec = vec![ + 3286402344, 2486176763, 142593372, 885025535, 2395000894, 1485273170, 0xfa37157b, + 1322437556, 0xe860e5cc, 814637928, + ]; + + test_murmur3_hash::(input.clone(), expected); + test_xxhash64_hash::( + input, + vec![ + 0xc3629e6318d53932, + 0xe7097b6a54378d8a, + 0x98b1582b0977e704, + 0xa80d9d5a6a523bd5, + 0xfcba5f61ac666c61, + 0x88e4fe59adf7b0cc, + 0x259dd873209a3fe3, + 0x13c1d910702770e6, + 0xa17b5eb5dc364dff, + 0xf241303e4a90f299, + ], + ) + } +} diff --git a/src/xxhash64.rs b/src/xxhash64.rs new file mode 100644 index 000000000000..f5a11f66cd7d --- /dev/null +++ b/src/xxhash64.rs @@ -0,0 +1,190 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
xxhash64 implementation + +const CHUNK_SIZE: usize = 32; + +const PRIME_1: u64 = 11_400_714_785_074_694_791; +const PRIME_2: u64 = 14_029_467_366_897_019_727; +const PRIME_3: u64 = 1_609_587_929_392_839_161; +const PRIME_4: u64 = 9_650_029_242_287_828_579; +const PRIME_5: u64 = 2_870_177_450_012_600_261; + +/// Custom implementation of xxhash64 based on code from https://github.com/shepmaster/twox-hash +/// but optimized for our use case by removing any intermediate buffering, which is +/// not required because we are operating on data that is already in memory. +#[inline] +pub(crate) fn spark_compatible_xxhash64>(data: T, seed: u64) -> u64 { + let data: &[u8] = data.as_ref(); + let length_bytes = data.len(); + + let mut v1 = seed.wrapping_add(PRIME_1).wrapping_add(PRIME_2); + let mut v2 = seed.wrapping_add(PRIME_2); + let mut v3 = seed; + let mut v4 = seed.wrapping_sub(PRIME_1); + + // process chunks of 32 bytes + let mut offset_u64_4 = 0; + let ptr_u64 = data.as_ptr() as *const u64; + unsafe { + while offset_u64_4 * CHUNK_SIZE + CHUNK_SIZE <= length_bytes { + v1 = ingest_one_number(v1, ptr_u64.add(offset_u64_4 * 4).read_unaligned().to_le()); + v2 = ingest_one_number( + v2, + ptr_u64.add(offset_u64_4 * 4 + 1).read_unaligned().to_le(), + ); + v3 = ingest_one_number( + v3, + ptr_u64.add(offset_u64_4 * 4 + 2).read_unaligned().to_le(), + ); + v4 = ingest_one_number( + v4, + ptr_u64.add(offset_u64_4 * 4 + 3).read_unaligned().to_le(), + ); + offset_u64_4 += 1; + } + } + + let mut hash = if length_bytes >= CHUNK_SIZE { + // We have processed at least one full chunk + let mut hash = v1.rotate_left(1); + hash = hash.wrapping_add(v2.rotate_left(7)); + hash = hash.wrapping_add(v3.rotate_left(12)); + hash = hash.wrapping_add(v4.rotate_left(18)); + + hash = mix_one(hash, v1); + hash = mix_one(hash, v2); + hash = mix_one(hash, v3); + hash = mix_one(hash, v4); + + hash + } else { + seed.wrapping_add(PRIME_5) + }; + + hash = hash.wrapping_add(length_bytes as u64); + + // process u64s + let mut offset_u64 = offset_u64_4 * 4; + while offset_u64 * 8 + 8 <= length_bytes { + let mut k1 = unsafe { + ptr_u64 + .add(offset_u64) + .read_unaligned() + .to_le() + .wrapping_mul(PRIME_2) + }; + k1 = k1.rotate_left(31); + k1 = k1.wrapping_mul(PRIME_1); + hash ^= k1; + hash = hash.rotate_left(27); + hash = hash.wrapping_mul(PRIME_1); + hash = hash.wrapping_add(PRIME_4); + offset_u64 += 1; + } + + // process u32s + let data = &data[offset_u64 * 8..]; + let ptr_u32 = data.as_ptr() as *const u32; + let length_bytes = length_bytes - offset_u64 * 8; + let mut offset_u32 = 0; + while offset_u32 * 4 + 4 <= length_bytes { + let k1 = unsafe { + u64::from(ptr_u32.add(offset_u32).read_unaligned().to_le()).wrapping_mul(PRIME_1) + }; + hash ^= k1; + hash = hash.rotate_left(23); + hash = hash.wrapping_mul(PRIME_2); + hash = hash.wrapping_add(PRIME_3); + offset_u32 += 1; + } + + // process u8s + let data = &data[offset_u32 * 4..]; + let length_bytes = length_bytes - offset_u32 * 4; + let mut offset_u8 = 0; + while offset_u8 < length_bytes { + let k1 = u64::from(data[offset_u8]).wrapping_mul(PRIME_5); + hash ^= k1; + hash = hash.rotate_left(11); + hash = hash.wrapping_mul(PRIME_1); + offset_u8 += 1; + } + + // The final intermixing + hash ^= hash >> 33; + hash = hash.wrapping_mul(PRIME_2); + hash ^= hash >> 29; + hash = hash.wrapping_mul(PRIME_3); + hash ^= hash >> 32; + + hash +} + +#[inline(always)] +fn ingest_one_number(mut current_value: u64, mut value: u64) -> u64 { + value = value.wrapping_mul(PRIME_2); + current_value = 
current_value.wrapping_add(value); + current_value = current_value.rotate_left(31); + current_value.wrapping_mul(PRIME_1) +} + +#[inline(always)] +fn mix_one(mut hash: u64, mut value: u64) -> u64 { + value = value.wrapping_mul(PRIME_2); + value = value.rotate_left(31); + value = value.wrapping_mul(PRIME_1); + hash ^= value; + hash = hash.wrapping_mul(PRIME_1); + hash.wrapping_add(PRIME_4) +} + +#[cfg(test)] +mod test { + use super::spark_compatible_xxhash64; + use rand::Rng; + use std::hash::Hasher; + use twox_hash::XxHash64; + + #[test] + #[cfg_attr(miri, ignore)] // test takes too long with miri + fn test_xxhash64_random() { + let mut rng = rand::thread_rng(); + for len in 0..128 { + for _ in 0..10 { + let data: Vec = (0..len).map(|_| rng.gen()).collect(); + let seed = rng.gen(); + check_xxhash64(&data, seed); + } + } + } + + fn check_xxhash64(data: &[u8], seed: u64) { + let mut hasher = XxHash64::with_seed(seed); + hasher.write(data.as_ref()); + let hash1 = hasher.finish(); + let hash2 = spark_compatible_xxhash64(data, seed); + if hash1 != hash2 { + panic!("input: {} with seed {seed} produced incorrect hash (comet={hash2}, twox-hash={hash1})", + data.iter().fold(String::new(), |mut output, byte| { + output.push_str(&format!("{:02x}", byte)); + output + })) + } + } +} From 0a003250aa425fe936a5930df74d498f1f2a01bd Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Wed, 31 Jul 2024 10:57:46 -0600 Subject: [PATCH 14/68] chore: Add criterion benchmark for decimal_div (#743) --- Cargo.toml | 4 ++++ benches/decimal_div.rs | 54 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) create mode 100644 benches/decimal_div.rs diff --git a/Cargo.toml b/Cargo.toml index a535a2b817e6..96eae39ffbd2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -64,3 +64,7 @@ harness = false [[bench]] name = "conditional" harness = false + +[[bench]] +name = "decimal_div" +harness = false diff --git a/benches/decimal_div.rs b/benches/decimal_div.rs new file mode 100644 index 000000000000..89f06e50532e --- /dev/null +++ b/benches/decimal_div.rs @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
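With the `[[bench]]` entry registered above (`harness = false`), this Criterion benchmark is invoked with `cargo bench --bench decimal_div`. The body below builds two 1,000-row decimal columns, casts them to `Decimal128(10, 4)` and `Decimal128(10, 3)`, and measures `spark_decimal_div` over the pair.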
+ +use arrow::compute::cast; +use arrow_array::builder::Decimal128Builder; +use arrow_schema::DataType; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_comet_spark_expr::scalar_funcs::spark_decimal_div; +use datafusion_expr::ColumnarValue; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + // create input data + let mut c1 = Decimal128Builder::new(); + let mut c2 = Decimal128Builder::new(); + for i in 0..1000 { + c1.append_value(99999999 + i); + c2.append_value(88888888 - i); + } + let c1 = Arc::new(c1.finish()); + let c2 = Arc::new(c2.finish()); + + let c1_type = DataType::Decimal128(10, 4); + let c1 = cast(c1.as_ref(), &c1_type).unwrap(); + let c2_type = DataType::Decimal128(10, 3); + let c2 = cast(c2.as_ref(), &c2_type).unwrap(); + + let args = [ColumnarValue::Array(c1), ColumnarValue::Array(c2)]; + c.bench_function("decimal_div", |b| { + b.iter(|| { + black_box(spark_decimal_div( + black_box(&args), + black_box(&DataType::Decimal128(10, 4)), + )) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); From abfce136d94e4a07cca77ac79c28175376be36e7 Mon Sep 17 00:00:00 2001 From: Akhil S S <88586412+akhilss99@users.noreply.github.com> Date: Thu, 1 Aug 2024 20:30:58 +0530 Subject: [PATCH 15/68] Add support for time-zone, 3 & 5 digit years: Cast from string to timestamp (#704) --- src/cast.rs | 343 ++++++++++++++++++++++++++++++++++++++-------------- 1 file changed, 251 insertions(+), 92 deletions(-) diff --git a/src/cast.rs b/src/cast.rs index ae0818970f03..e44b1c9f5db4 100644 --- a/src/cast.rs +++ b/src/cast.rs @@ -15,14 +15,6 @@ // specific language governing permissions and limitations // under the License. -use std::{ - any::Any, - fmt::{Debug, Display, Formatter}, - hash::{Hash, Hasher}, - num::Wrapping, - sync::Arc, -}; - use arrow::{ array::{ cast::AsArray, @@ -42,6 +34,14 @@ use arrow::{ }; use arrow_array::DictionaryArray; use arrow_schema::{DataType, Schema}; +use std::str::FromStr; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter}, + hash::{Hash, Hasher}, + num::Wrapping, + sync::Arc, +}; use datafusion_common::{ cast::as_generic_string_array, internal_err, Result as DataFusionResult, ScalarValue, @@ -56,6 +56,7 @@ use num::{ }; use regex::Regex; +use crate::timezone; use crate::utils::{array_with_timezone, down_cast_any_ref}; use crate::{EvalMode, SparkError, SparkResult}; @@ -71,6 +72,67 @@ static CAST_OPTIONS: CastOptions = CastOptions { .with_timestamp_format(TIMESTAMP_FORMAT), }; +struct TimeStampInfo { + year: i32, + month: u32, + day: u32, + hour: u32, + minute: u32, + second: u32, + microsecond: u32, +} + +impl Default for TimeStampInfo { + fn default() -> Self { + TimeStampInfo { + year: 1, + month: 1, + day: 1, + hour: 0, + minute: 0, + second: 0, + microsecond: 0, + } + } +} + +impl TimeStampInfo { + pub fn with_year(&mut self, year: i32) -> &mut Self { + self.year = year; + self + } + + pub fn with_month(&mut self, month: u32) -> &mut Self { + self.month = month; + self + } + + pub fn with_day(&mut self, day: u32) -> &mut Self { + self.day = day; + self + } + + pub fn with_hour(&mut self, hour: u32) -> &mut Self { + self.hour = hour; + self + } + + pub fn with_minute(&mut self, minute: u32) -> &mut Self { + self.minute = minute; + self + } + + pub fn with_second(&mut self, second: u32) -> &mut Self { + self.second = second; + self + } + + pub fn with_microsecond(&mut self, microsecond: u32) -> &mut Self { + self.microsecond = microsecond; + self + } +} + 
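The setters above each return `&mut Self`, so the parsing helpers later in this patch can chain them and fall back to the defaults (`0001-01-01 00:00:00.000000`) for any field the input string does not supply. A small illustrative sketch (the helper function is assumed, not part of the patch):

```rust
// Hypothetical: how a partially specified input such as "2024-07" maps onto
// TimeStampInfo: only year and month are set, the remaining fields keep defaults.
fn example_info() -> TimeStampInfo {
    let mut info = TimeStampInfo::default();
    info.with_year(2024).with_month(7);
    info
}
```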
#[derive(Debug, Hash)] pub struct Cast { pub child: Arc, @@ -100,13 +162,15 @@ macro_rules! cast_utf8_to_int { }}; } macro_rules! cast_utf8_to_timestamp { - ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{ + ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident, $tz:expr) => {{ let len = $array.len(); let mut cast_array = PrimitiveArray::<$array_type>::builder(len).with_timezone("UTC"); for i in 0..len { if $array.is_null(i) { cast_array.append_null() - } else if let Ok(Some(cast_value)) = $cast_method($array.value(i).trim(), $eval_mode) { + } else if let Ok(Some(cast_value)) = + $cast_method($array.value(i).trim(), $eval_mode, $tz) + { cast_array.append_value(cast_value); } else { cast_array.append_null() @@ -574,7 +638,7 @@ fn cast_array( spark_cast_utf8_to_boolean::(&array, eval_mode) } (DataType::Utf8, DataType::Timestamp(_, _)) => { - cast_string_to_timestamp(&array, to_type, eval_mode) + cast_string_to_timestamp(&array, to_type, eval_mode, &timezone) } (DataType::Utf8, DataType::Date32) => cast_string_to_date(&array, to_type, eval_mode), (DataType::Int64, DataType::Int32) @@ -782,19 +846,23 @@ fn cast_string_to_timestamp( array: &ArrayRef, to_type: &DataType, eval_mode: EvalMode, + timezone_str: &str, ) -> SparkResult { let string_array = array .as_any() .downcast_ref::>() .expect("Expected a string array"); + let tz = &timezone::Tz::from_str(timezone_str).unwrap(); + let cast_array: ArrayRef = match to_type { DataType::Timestamp(_, _) => { cast_utf8_to_timestamp!( string_array, eval_mode, TimestampMicrosecondType, - timestamp_parser + timestamp_parser, + tz ) } _ => unreachable!("Invalid data type {:?} in cast from string", to_type), @@ -1344,7 +1412,11 @@ impl PhysicalExpr for Cast { } } -fn timestamp_parser(value: &str, eval_mode: EvalMode) -> SparkResult> { +fn timestamp_parser( + value: &str, + eval_mode: EvalMode, + tz: &T, +) -> SparkResult> { let value = value.trim(); if value.is_empty() { return Ok(None); @@ -1352,31 +1424,31 @@ fn timestamp_parser(value: &str, eval_mode: EvalMode) -> SparkResult // Define regex patterns and corresponding parsing functions let patterns = &[ ( - Regex::new(r"^\d{4}$").unwrap(), - parse_str_to_year_timestamp as fn(&str) -> SparkResult>, + Regex::new(r"^\d{4,5}$").unwrap(), + parse_str_to_year_timestamp as fn(&str, &T) -> SparkResult>, ), ( - Regex::new(r"^\d{4}-\d{2}$").unwrap(), + Regex::new(r"^\d{4,5}-\d{2}$").unwrap(), parse_str_to_month_timestamp, ), ( - Regex::new(r"^\d{4}-\d{2}-\d{2}$").unwrap(), + Regex::new(r"^\d{4,5}-\d{2}-\d{2}$").unwrap(), parse_str_to_day_timestamp, ), ( - Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{1,2}$").unwrap(), + Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{1,2}$").unwrap(), parse_str_to_hour_timestamp, ), ( - Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap(), + Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap(), parse_str_to_minute_timestamp, ), ( - Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap(), + Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap(), parse_str_to_second_timestamp, ), ( - Regex::new(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap(), + Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap(), parse_str_to_microsecond_timestamp, ), ( @@ -1390,7 +1462,7 @@ fn timestamp_parser(value: &str, eval_mode: EvalMode) -> SparkResult // Iterate through patterns and try matching for (pattern, parse_func) in patterns { if pattern.is_match(value) { - timestamp = parse_func(value)?; + timestamp = parse_func(value, tz)?; 
break; } } @@ -1415,38 +1487,24 @@ fn timestamp_parser(value: &str, eval_mode: EvalMode) -> SparkResult } } -fn parse_ymd_timestamp(year: i32, month: u32, day: u32) -> SparkResult> { - let datetime = chrono::Utc.with_ymd_and_hms(year, month, day, 0, 0, 0); - - // Check if datetime is not None - let utc_datetime = match datetime.single() { - Some(dt) => dt.with_timezone(&chrono::Utc), - None => { - return Err(SparkError::Internal( - "Failed to parse timestamp".to_string(), - )); - } - }; - - Ok(Some(utc_datetime.timestamp_micros())) -} - -fn parse_hms_timestamp( - year: i32, - month: u32, - day: u32, - hour: u32, - minute: u32, - second: u32, - microsecond: u32, +fn parse_timestamp_to_micros( + timestamp_info: &TimeStampInfo, + tz: &T, ) -> SparkResult> { - let datetime = chrono::Utc.with_ymd_and_hms(year, month, day, hour, minute, second); + let datetime = tz.with_ymd_and_hms( + timestamp_info.year, + timestamp_info.month, + timestamp_info.day, + timestamp_info.hour, + timestamp_info.minute, + timestamp_info.second, + ); // Check if datetime is not None - let utc_datetime = match datetime.single() { + let tz_datetime = match datetime.single() { Some(dt) => dt - .with_timezone(&chrono::Utc) - .with_nanosecond(microsecond * 1000), + .with_timezone(tz) + .with_nanosecond(timestamp_info.microsecond * 1000), None => { return Err(SparkError::Internal( "Failed to parse timestamp".to_string(), @@ -1454,7 +1512,7 @@ fn parse_hms_timestamp( } }; - let result = match utc_datetime { + let result = match tz_datetime { Some(dt) => dt.timestamp_micros(), None => { return Err(SparkError::Internal( @@ -1466,7 +1524,11 @@ fn parse_hms_timestamp( Ok(Some(result)) } -fn get_timestamp_values(value: &str, timestamp_type: &str) -> SparkResult> { +fn get_timestamp_values( + value: &str, + timestamp_type: &str, + tz: &T, +) -> SparkResult> { let values: Vec<_> = value .split(|c| c == 'T' || c == '-' || c == ':' || c == '.') .collect(); @@ -1478,64 +1540,99 @@ fn get_timestamp_values(value: &str, timestamp_type: &str) -> SparkResult - logo - - -DataFusion is an extensible query engine written in [Rust] that -uses [Apache Arrow] as its in-memory format. - -This crate provides libraries and binaries for developers building fast and -feature rich database and analytic systems, customized to particular workloads. -See [use cases] for examples. The following related subprojects target end users: - -- [DataFusion Python](https://github.com/apache/datafusion-python/) offers a Python interface for SQL and DataFrame - queries. -- [DataFusion Ray](https://github.com/apache/datafusion-ray/) provides a distributed version of DataFusion that scales - out on Ray clusters. -- [DataFusion Comet](https://github.com/apache/datafusion-comet/) is an accelerator for Apache Spark based on - DataFusion. - -"Out of the box," -DataFusion offers [SQL] and [`Dataframe`] APIs, excellent [performance], -built-in support for CSV, Parquet, JSON, and Avro, extensive customization, and -a great community. - -DataFusion features a full query planner, a columnar, streaming, multi-threaded, -vectorized execution engine, and partitioned data sources. You can -customize DataFusion at almost all points including additional data sources, -query languages, functions, custom operators and more. -See the [Architecture] section for more details. 
- -[rust]: http://rustlang.org -[apache arrow]: https://arrow.apache.org -[use cases]: https://datafusion.apache.org/user-guide/introduction.html#use-cases -[python bindings]: https://github.com/apache/datafusion-python -[performance]: https://benchmark.clickhouse.com/ -[architecture]: https://datafusion.apache.org/contributor-guide/architecture.html - -Here are links to some important information - -- [Project Site](https://datafusion.apache.org/) -- [Installation](https://datafusion.apache.org/user-guide/cli/installation.html) -- [Rust Getting Started](https://datafusion.apache.org/user-guide/example-usage.html) -- [Rust DataFrame API](https://datafusion.apache.org/user-guide/dataframe.html) -- [Rust API docs](https://docs.rs/datafusion/latest/datafusion) -- [Rust Examples](https://github.com/apache/datafusion/tree/main/datafusion-examples) -- [Python DataFrame API](https://arrow.apache.org/datafusion-python/) -- [Architecture](https://docs.rs/datafusion/latest/datafusion/index.html#architecture) - -## What can you do with this crate? - -DataFusion is great for building projects such as domain specific query engines, new database platforms and data pipelines, query languages and more. -It lets you start quickly from a fully working engine, and then customize those features specific to your use. [Click Here](https://datafusion.apache.org/user-guide/introduction.html#known-users) to see a list known users. - -## Contributing to DataFusion - -Please see the [contributor guide] and [communication] pages for more information. - -[contributor guide]: https://datafusion.apache.org/contributor-guide -[communication]: https://datafusion.apache.org/contributor-guide/communication.html - -## Crate features - -This crate has several [features] which can be specified in your `Cargo.toml`. - -[features]: https://doc.rust-lang.org/cargo/reference/features.html - -Default features: - -- `nested_expressions`: functions for working with nested type function such as `array_to_string` -- `compression`: reading files compressed with `xz2`, `bzip2`, `flate2`, and `zstd` -- `crypto_expressions`: cryptographic functions such as `md5` and `sha256` -- `datetime_expressions`: date and time functions such as `to_timestamp` -- `encoding_expressions`: `encode` and `decode` functions -- `parquet`: support for reading the [Apache Parquet] format -- `regex_expressions`: regular expression functions, such as `regexp_match` -- `unicode_expressions`: Include unicode aware functions such as `character_length` -- `unparser`: enables support to reverse LogicalPlans back into SQL -- `recursive_protection`: uses [recursive](https://docs.rs/recursive/latest/recursive/) for stack overflow protection. - -Optional features: - -- `avro`: support for reading the [Apache Avro] format -- `backtrace`: include backtrace information in error messages -- `pyarrow`: conversions between PyArrow and DataFusion types -- `serde`: enable arrow-schema's `serde` feature - -[apache avro]: https://avro.apache.org/ -[apache parquet]: https://parquet.apache.org/ - -## Rust Version Compatibility Policy - -The Rust toolchain releases are tracked at [Rust Versions](https://releases.rs) and follow -[semantic versioning](https://semver.org/). A Rust toolchain release can be identified -by a version string like `1.80.0`, or more generally `major.minor.patch`. - -DataFusion's supports the last 4 stable Rust minor versions released and any such versions released within the last 4 months. 
- -For example, given the releases `1.78.0`, `1.79.0`, `1.80.0`, `1.80.1` and `1.81.0` DataFusion will support 1.78.0, which is 3 minor versions prior to the most minor recent `1.81`. - -Note: If a Rust hotfix is released for the current MSRV, the MSRV will be updated to the specific minor version that includes all applicable hotfixes preceding other policies. - -DataFusion enforces MSRV policy using a [MSRV CI Check](https://github.com/search?q=repo%3Aapache%2Fdatafusion+rust-version+language%3ATOML+path%3A%2F%5ECargo.toml%2F&type=code) - -## DataFusion API Evolution and Deprecation Guidelines - -Public methods in Apache DataFusion evolve over time: while we try to maintain a -stable API, we also improve the API over time. As a result, we typically -deprecate methods before removing them, according to the [deprecation guidelines]. - -[deprecation guidelines]: https://datafusion.apache.org/library-user-guide/api-health.html - -## Dependencies and a `Cargo.lock` - -`datafusion` is intended for use as a library and thus purposely does not have a -`Cargo.lock` file checked in. You can read more about the distinction in the -[Cargo book]. - -CI tests always run against the latest compatible versions of all dependencies -(the equivalent of doing `cargo update`), as suggested in the [Cargo CI guide] -and we rely on Dependabot for other upgrades. This strategy has two problems -that occasionally arise: - -1. CI failures when downstream libraries upgrade in some non compatible way -2. Local development builds that fail when DataFusion inadvertently relies on - a feature in a newer version of a dependency than declared in `Cargo.toml` - (e.g. a new method is added to a trait that we use). - -However, we think the current strategy is the best tradeoff between maintenance -overhead and user experience and ensures DataFusion always works with the latest -compatible versions of all dependencies. If you encounter either of these -problems, please open an issue or PR. - -[cargo book]: https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html - -# [cargo ci guide]: https://doc.rust-lang.org/cargo/guide/continuous-integration.html#verifying-latest-dependencies - -# datafusion-comet-spark-expr: Spark-compatible Expressions - -This crate provides Apache Spark-compatible expressions for use with DataFusion and is maintained as part of the -[Apache DataFusion Comet](https://github.com/apache/datafusion-comet/) subproject. +# datafusion-functions-spark: Spark-compatible Expressions -> > > > > > > comet/main +This crate provides Apache Spark-compatible expressions for use with DataFusion. 
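As a quick orientation for the crate this README introduces, the scalar kernels shown earlier in the series can be called directly on Arrow data. A hedged sketch; the module path is assumed from the `decimal_div` bench's `scalar_funcs` import and the `datafusion_comet_spark_expr` lib name, and may differ after the move:

```rust
use std::sync::Arc;
use arrow_array::Int64Array;
use datafusion_expr::ColumnarValue;
// Path assumed from this patch series; re-exports may differ in the final layout.
use datafusion_comet_spark_expr::scalar_funcs::hex::spark_hex;

fn main() -> datafusion_common::Result<()> {
    let input = ColumnarValue::Array(Arc::new(Int64Array::from(vec![Some(17), None])));
    let hexed = spark_hex(&[input])?; // 17 -> "11", nulls stay null, matching Spark's `hex`
    println!("{:?}", hexed.data_type());
    Ok(())
}
```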
From d7eaf68978629c764f963ed4d114e5b757623a4f Mon Sep 17 00:00:00 2001
From: Andy Grove
Date: Fri, 31 Jan 2025 14:26:19 -0700
Subject: [PATCH 66/68] fix cargo doc failures

---
 datafusion/functions-spark/src/datetime_funcs/date_trunc.rs | 2 +-
 .../functions-spark/src/datetime_funcs/timestamp_trunc.rs | 4 ++--
 datafusion/functions-spark/src/predicate_funcs/rlike.rs | 2 +-
 datafusion/functions-spark/src/utils.rs | 2 +-
 4 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/datafusion/functions-spark/src/datetime_funcs/date_trunc.rs b/datafusion/functions-spark/src/datetime_funcs/date_trunc.rs
index 5c044945d04c..a3b06e6a1c0f 100644
--- a/datafusion/functions-spark/src/datetime_funcs/date_trunc.rs
+++ b/datafusion/functions-spark/src/datetime_funcs/date_trunc.rs
@@ -33,7 +33,7 @@ use crate::kernels::temporal::{date_trunc_array_fmt_dyn, date_trunc_dyn};
 pub struct DateTruncExpr {
     /// An array with DataType::Date32
     child: Arc<dyn PhysicalExpr>,
-    /// Scalar UTF8 string matching the valid values in Spark SQL: https://spark.apache.org/docs/latest/api/sql/index.html#trunc
+    /// Scalar UTF8 string matching the valid values in Spark SQL: <https://spark.apache.org/docs/latest/api/sql/index.html#trunc>
     format: Arc<dyn PhysicalExpr>,
 }

diff --git a/datafusion/functions-spark/src/datetime_funcs/timestamp_trunc.rs b/datafusion/functions-spark/src/datetime_funcs/timestamp_trunc.rs
index 349992322f9b..bca9b8e8daab 100644
--- a/datafusion/functions-spark/src/datetime_funcs/timestamp_trunc.rs
+++ b/datafusion/functions-spark/src/datetime_funcs/timestamp_trunc.rs
@@ -34,10 +34,10 @@ use crate::kernels::temporal::{timestamp_trunc_array_fmt_dyn, timestamp_trunc_dyn};
 pub struct TimestampTruncExpr {
     /// An array with DataType::Timestamp(TimeUnit::Microsecond, None)
     child: Arc<dyn PhysicalExpr>,
-    /// Scalar UTF8 string matching the valid values in Spark SQL: https://spark.apache.org/docs/latest/api/sql/index.html#date_trunc
+    /// Scalar UTF8 string matching the valid values in Spark SQL: <https://spark.apache.org/docs/latest/api/sql/index.html#date_trunc>
     format: Arc<dyn PhysicalExpr>,
     /// String containing a timezone name. The name must be found in the standard timezone
-    /// database (https://en.wikipedia.org/wiki/List_of_tz_database_time_zones). The string is
+    /// database (<https://en.wikipedia.org/wiki/List_of_tz_database_time_zones>). The string is
     /// later parsed into a chrono::TimeZone.
     /// Timestamp arrays in this implementation are kept in arrays of UTC timestamps (in micros)
     /// along with a single value for the associated TimeZone.
The timezone offset is applied diff --git a/datafusion/functions-spark/src/predicate_funcs/rlike.rs b/datafusion/functions-spark/src/predicate_funcs/rlike.rs index 7b67b0099c37..bfee0cc769cb 100644 --- a/datafusion/functions-spark/src/predicate_funcs/rlike.rs +++ b/datafusion/functions-spark/src/predicate_funcs/rlike.rs @@ -38,7 +38,7 @@ use std::sync::Arc; /// differences in whitespace handling and does not support all the features of Java's /// regular expression engine, which are documented at: /// -/// https://docs.oracle.com/javase/8/docs/api/java/util/regex/Pattern.html +/// #[derive(Debug)] pub struct RLike { child: Arc, diff --git a/datafusion/functions-spark/src/utils.rs b/datafusion/functions-spark/src/utils.rs index d6090014d05a..37d633e52549 100644 --- a/datafusion/functions-spark/src/utils.rs +++ b/datafusion/functions-spark/src/utils.rs @@ -227,7 +227,7 @@ fn pre_timestamp_cast(array: ArrayRef, timezone: String) -> Result #[inline] pub fn is_valid_decimal_precision(value: i128, precision: u8) -> bool { precision <= DECIMAL128_MAX_PRECISION From 96f21363a3b4114b2e1a99d7f0e5d2c6dad94d53 Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 31 Jan 2025 14:30:08 -0700 Subject: [PATCH 67/68] taplo fmt --- datafusion/functions-spark/Cargo.toml | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/datafusion/functions-spark/Cargo.toml b/datafusion/functions-spark/Cargo.toml index 08361c06b0dc..eaa48845820f 100644 --- a/datafusion/functions-spark/Cargo.toml +++ b/datafusion/functions-spark/Cargo.toml @@ -33,27 +33,26 @@ arrow-buffer = { workspace = true } arrow-data = { workspace = true } arrow-schema = { workspace = true } chrono = { workspace = true } +chrono-tz = "0.10.1" datafusion = { workspace = true, features = ["parquet"] } datafusion-common = { workspace = true } datafusion-expr = { workspace = true } datafusion-expr-common = { workspace = true } datafusion-physical-expr = { workspace = true } -chrono-tz = "0.10.1" +futures = { workspace = true } num = "0.4.3" +rand = { workspace = true } regex = { workspace = true } thiserror = "2.0.11" -futures = { workspace = true } twox-hash = "2.0.0" -rand = { workspace = true } [dev-dependencies] -arrow-data = {workspace = true} -parquet = { workspace = true, features = ["arrow"] } +arrow-data = { workspace = true } criterion = "0.5.1" -rand = { workspace = true} +parquet = { workspace = true, features = ["arrow"] } +rand = { workspace = true } tokio = { version = "1", features = ["rt-multi-thread"] } - [lib] name = "datafusion_comet_spark_expr" path = "src/lib.rs" @@ -77,4 +76,3 @@ harness = false [[bench]] name = "aggregate" harness = false - From 77e7831897a4b38581c40730128ccb350e72af2a Mon Sep 17 00:00:00 2001 From: Andy Grove Date: Fri, 31 Jan 2025 16:04:07 -0700 Subject: [PATCH 68/68] rename functions-spark to spark --- Cargo.toml | 2 +- datafusion/{functions-spark => spark}/Cargo.toml | 2 +- datafusion/{functions-spark => spark}/README.md | 2 +- datafusion/{functions-spark => spark}/benches/aggregate.rs | 0 .../{functions-spark => spark}/benches/cast_from_string.rs | 0 datafusion/{functions-spark => spark}/benches/cast_numeric.rs | 0 datafusion/{functions-spark => spark}/benches/conditional.rs | 0 datafusion/{functions-spark => spark}/benches/decimal_div.rs | 0 datafusion/{functions-spark => spark}/src/agg_funcs/avg.rs | 0 .../{functions-spark => spark}/src/agg_funcs/avg_decimal.rs | 0 .../{functions-spark => spark}/src/agg_funcs/correlation.rs | 0 .../{functions-spark => 
spark}/src/agg_funcs/covariance.rs | 0 datafusion/{functions-spark => spark}/src/agg_funcs/mod.rs | 0 datafusion/{functions-spark => spark}/src/agg_funcs/stddev.rs | 0 .../{functions-spark => spark}/src/agg_funcs/sum_decimal.rs | 0 datafusion/{functions-spark => spark}/src/agg_funcs/variance.rs | 0 .../{functions-spark => spark}/src/array_funcs/array_insert.rs | 0 .../src/array_funcs/get_array_struct_fields.rs | 0 .../{functions-spark => spark}/src/array_funcs/list_extract.rs | 0 datafusion/{functions-spark => spark}/src/array_funcs/mod.rs | 0 .../{functions-spark => spark}/src/bitwise_funcs/bitwise_not.rs | 0 datafusion/{functions-spark => spark}/src/bitwise_funcs/mod.rs | 0 datafusion/{functions-spark => spark}/src/comet_scalar_funcs.rs | 0 .../{functions-spark => spark}/src/conditional_funcs/if_expr.rs | 0 .../{functions-spark => spark}/src/conditional_funcs/mod.rs | 0 .../{functions-spark => spark}/src/conversion_funcs/cast.rs | 0 .../{functions-spark => spark}/src/conversion_funcs/mod.rs | 0 .../src/datetime_funcs/date_arithmetic.rs | 0 .../{functions-spark => spark}/src/datetime_funcs/date_trunc.rs | 0 .../{functions-spark => spark}/src/datetime_funcs/hour.rs | 0 .../{functions-spark => spark}/src/datetime_funcs/minute.rs | 0 datafusion/{functions-spark => spark}/src/datetime_funcs/mod.rs | 0 .../{functions-spark => spark}/src/datetime_funcs/second.rs | 0 .../src/datetime_funcs/timestamp_trunc.rs | 0 datafusion/{functions-spark => spark}/src/error.rs | 0 datafusion/{functions-spark => spark}/src/hash_funcs/mod.rs | 0 datafusion/{functions-spark => spark}/src/hash_funcs/murmur3.rs | 0 datafusion/{functions-spark => spark}/src/hash_funcs/sha2.rs | 0 datafusion/{functions-spark => spark}/src/hash_funcs/utils.rs | 0 .../{functions-spark => spark}/src/hash_funcs/xxhash64.rs | 0 datafusion/{functions-spark => spark}/src/json_funcs/mod.rs | 0 datafusion/{functions-spark => spark}/src/json_funcs/to_json.rs | 0 datafusion/{functions-spark => spark}/src/kernels/mod.rs | 0 datafusion/{functions-spark => spark}/src/kernels/strings.rs | 0 datafusion/{functions-spark => spark}/src/kernels/temporal.rs | 0 datafusion/{functions-spark => spark}/src/lib.rs | 0 datafusion/{functions-spark => spark}/src/math_funcs/ceil.rs | 0 datafusion/{functions-spark => spark}/src/math_funcs/div.rs | 0 datafusion/{functions-spark => spark}/src/math_funcs/floor.rs | 0 datafusion/{functions-spark => spark}/src/math_funcs/hex.rs | 0 .../src/math_funcs/internal/checkoverflow.rs | 0 .../src/math_funcs/internal/make_decimal.rs | 0 .../{functions-spark => spark}/src/math_funcs/internal/mod.rs | 0 .../src/math_funcs/internal/normalize_nan.rs | 0 .../src/math_funcs/internal/unscaled_value.rs | 0 datafusion/{functions-spark => spark}/src/math_funcs/mod.rs | 0 .../{functions-spark => spark}/src/math_funcs/negative.rs | 0 datafusion/{functions-spark => spark}/src/math_funcs/round.rs | 0 datafusion/{functions-spark => spark}/src/math_funcs/unhex.rs | 0 datafusion/{functions-spark => spark}/src/math_funcs/utils.rs | 0 .../{functions-spark => spark}/src/predicate_funcs/is_nan.rs | 0 .../{functions-spark => spark}/src/predicate_funcs/mod.rs | 0 .../{functions-spark => spark}/src/predicate_funcs/rlike.rs | 0 .../src/static_invoke/char_varchar_utils/mod.rs | 0 .../src/static_invoke/char_varchar_utils/read_side_padding.rs | 0 datafusion/{functions-spark => spark}/src/static_invoke/mod.rs | 0 datafusion/{functions-spark => spark}/src/string_funcs/chr.rs | 0 datafusion/{functions-spark => spark}/src/string_funcs/mod.rs | 0 
.../{functions-spark => spark}/src/string_funcs/prediction.rs | 0 .../{functions-spark => spark}/src/string_funcs/string_space.rs | 0 .../{functions-spark => spark}/src/string_funcs/substring.rs | 0 .../src/struct_funcs/create_named_struct.rs | 0 .../src/struct_funcs/get_struct_field.rs | 0 datafusion/{functions-spark => spark}/src/struct_funcs/mod.rs | 0 .../{functions-spark => spark}/src/test_common/file_util.rs | 0 datafusion/{functions-spark => spark}/src/test_common/mod.rs | 0 datafusion/{functions-spark => spark}/src/timezone.rs | 0 datafusion/{functions-spark => spark}/src/unbound.rs | 0 datafusion/{functions-spark => spark}/src/utils.rs | 0 79 files changed, 3 insertions(+), 3 deletions(-) rename datafusion/{functions-spark => spark}/Cargo.toml (98%) rename datafusion/{functions-spark => spark}/README.md (93%) rename datafusion/{functions-spark => spark}/benches/aggregate.rs (100%) rename datafusion/{functions-spark => spark}/benches/cast_from_string.rs (100%) rename datafusion/{functions-spark => spark}/benches/cast_numeric.rs (100%) rename datafusion/{functions-spark => spark}/benches/conditional.rs (100%) rename datafusion/{functions-spark => spark}/benches/decimal_div.rs (100%) rename datafusion/{functions-spark => spark}/src/agg_funcs/avg.rs (100%) rename datafusion/{functions-spark => spark}/src/agg_funcs/avg_decimal.rs (100%) rename datafusion/{functions-spark => spark}/src/agg_funcs/correlation.rs (100%) rename datafusion/{functions-spark => spark}/src/agg_funcs/covariance.rs (100%) rename datafusion/{functions-spark => spark}/src/agg_funcs/mod.rs (100%) rename datafusion/{functions-spark => spark}/src/agg_funcs/stddev.rs (100%) rename datafusion/{functions-spark => spark}/src/agg_funcs/sum_decimal.rs (100%) rename datafusion/{functions-spark => spark}/src/agg_funcs/variance.rs (100%) rename datafusion/{functions-spark => spark}/src/array_funcs/array_insert.rs (100%) rename datafusion/{functions-spark => spark}/src/array_funcs/get_array_struct_fields.rs (100%) rename datafusion/{functions-spark => spark}/src/array_funcs/list_extract.rs (100%) rename datafusion/{functions-spark => spark}/src/array_funcs/mod.rs (100%) rename datafusion/{functions-spark => spark}/src/bitwise_funcs/bitwise_not.rs (100%) rename datafusion/{functions-spark => spark}/src/bitwise_funcs/mod.rs (100%) rename datafusion/{functions-spark => spark}/src/comet_scalar_funcs.rs (100%) rename datafusion/{functions-spark => spark}/src/conditional_funcs/if_expr.rs (100%) rename datafusion/{functions-spark => spark}/src/conditional_funcs/mod.rs (100%) rename datafusion/{functions-spark => spark}/src/conversion_funcs/cast.rs (100%) rename datafusion/{functions-spark => spark}/src/conversion_funcs/mod.rs (100%) rename datafusion/{functions-spark => spark}/src/datetime_funcs/date_arithmetic.rs (100%) rename datafusion/{functions-spark => spark}/src/datetime_funcs/date_trunc.rs (100%) rename datafusion/{functions-spark => spark}/src/datetime_funcs/hour.rs (100%) rename datafusion/{functions-spark => spark}/src/datetime_funcs/minute.rs (100%) rename datafusion/{functions-spark => spark}/src/datetime_funcs/mod.rs (100%) rename datafusion/{functions-spark => spark}/src/datetime_funcs/second.rs (100%) rename datafusion/{functions-spark => spark}/src/datetime_funcs/timestamp_trunc.rs (100%) rename datafusion/{functions-spark => spark}/src/error.rs (100%) rename datafusion/{functions-spark => spark}/src/hash_funcs/mod.rs (100%) rename datafusion/{functions-spark => spark}/src/hash_funcs/murmur3.rs (100%) rename 
datafusion/{functions-spark => spark}/src/hash_funcs/sha2.rs (100%) rename datafusion/{functions-spark => spark}/src/hash_funcs/utils.rs (100%) rename datafusion/{functions-spark => spark}/src/hash_funcs/xxhash64.rs (100%) rename datafusion/{functions-spark => spark}/src/json_funcs/mod.rs (100%) rename datafusion/{functions-spark => spark}/src/json_funcs/to_json.rs (100%) rename datafusion/{functions-spark => spark}/src/kernels/mod.rs (100%) rename datafusion/{functions-spark => spark}/src/kernels/strings.rs (100%) rename datafusion/{functions-spark => spark}/src/kernels/temporal.rs (100%) rename datafusion/{functions-spark => spark}/src/lib.rs (100%) rename datafusion/{functions-spark => spark}/src/math_funcs/ceil.rs (100%) rename datafusion/{functions-spark => spark}/src/math_funcs/div.rs (100%) rename datafusion/{functions-spark => spark}/src/math_funcs/floor.rs (100%) rename datafusion/{functions-spark => spark}/src/math_funcs/hex.rs (100%) rename datafusion/{functions-spark => spark}/src/math_funcs/internal/checkoverflow.rs (100%) rename datafusion/{functions-spark => spark}/src/math_funcs/internal/make_decimal.rs (100%) rename datafusion/{functions-spark => spark}/src/math_funcs/internal/mod.rs (100%) rename datafusion/{functions-spark => spark}/src/math_funcs/internal/normalize_nan.rs (100%) rename datafusion/{functions-spark => spark}/src/math_funcs/internal/unscaled_value.rs (100%) rename datafusion/{functions-spark => spark}/src/math_funcs/mod.rs (100%) rename datafusion/{functions-spark => spark}/src/math_funcs/negative.rs (100%) rename datafusion/{functions-spark => spark}/src/math_funcs/round.rs (100%) rename datafusion/{functions-spark => spark}/src/math_funcs/unhex.rs (100%) rename datafusion/{functions-spark => spark}/src/math_funcs/utils.rs (100%) rename datafusion/{functions-spark => spark}/src/predicate_funcs/is_nan.rs (100%) rename datafusion/{functions-spark => spark}/src/predicate_funcs/mod.rs (100%) rename datafusion/{functions-spark => spark}/src/predicate_funcs/rlike.rs (100%) rename datafusion/{functions-spark => spark}/src/static_invoke/char_varchar_utils/mod.rs (100%) rename datafusion/{functions-spark => spark}/src/static_invoke/char_varchar_utils/read_side_padding.rs (100%) rename datafusion/{functions-spark => spark}/src/static_invoke/mod.rs (100%) rename datafusion/{functions-spark => spark}/src/string_funcs/chr.rs (100%) rename datafusion/{functions-spark => spark}/src/string_funcs/mod.rs (100%) rename datafusion/{functions-spark => spark}/src/string_funcs/prediction.rs (100%) rename datafusion/{functions-spark => spark}/src/string_funcs/string_space.rs (100%) rename datafusion/{functions-spark => spark}/src/string_funcs/substring.rs (100%) rename datafusion/{functions-spark => spark}/src/struct_funcs/create_named_struct.rs (100%) rename datafusion/{functions-spark => spark}/src/struct_funcs/get_struct_field.rs (100%) rename datafusion/{functions-spark => spark}/src/struct_funcs/mod.rs (100%) rename datafusion/{functions-spark => spark}/src/test_common/file_util.rs (100%) rename datafusion/{functions-spark => spark}/src/test_common/mod.rs (100%) rename datafusion/{functions-spark => spark}/src/timezone.rs (100%) rename datafusion/{functions-spark => spark}/src/unbound.rs (100%) rename datafusion/{functions-spark => spark}/src/utils.rs (100%) diff --git a/Cargo.toml b/Cargo.toml index 63b9c0d3315e..e8f94885e79e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -32,7 +32,6 @@ members = [ "datafusion/functions-aggregate-common", "datafusion/functions-table", 
"datafusion/functions-nested", - "datafusion/functions-spark", "datafusion/functions-window", "datafusion/functions-window-common", "datafusion/optimizer", @@ -44,6 +43,7 @@ members = [ "datafusion/proto/gen", "datafusion/proto-common", "datafusion/proto-common/gen", + "datafusion/spark", "datafusion/sql", "datafusion/sqllogictest", "datafusion/substrait", diff --git a/datafusion/functions-spark/Cargo.toml b/datafusion/spark/Cargo.toml similarity index 98% rename from datafusion/functions-spark/Cargo.toml rename to datafusion/spark/Cargo.toml index eaa48845820f..be5a897d1648 100644 --- a/datafusion/functions-spark/Cargo.toml +++ b/datafusion/spark/Cargo.toml @@ -16,7 +16,7 @@ # under the License. [package] -name = "datafusion-functions-spark" +name = "datafusion-spark" description = "DataFusion expressions that emulate Apache Spark's behavior" version = { workspace = true } homepage = { workspace = true } diff --git a/datafusion/functions-spark/README.md b/datafusion/spark/README.md similarity index 93% rename from datafusion/functions-spark/README.md rename to datafusion/spark/README.md index 42c7f4d2a8e4..afd94d2c0690 100644 --- a/datafusion/functions-spark/README.md +++ b/datafusion/spark/README.md @@ -17,6 +17,6 @@ specific language governing permissions and limitations under the License. --> -# datafusion-functions-spark: Spark-compatible Expressions +# datafusion-spark: Spark-compatible Expressions This crate provides Apache Spark-compatible expressions for use with DataFusion. diff --git a/datafusion/functions-spark/benches/aggregate.rs b/datafusion/spark/benches/aggregate.rs similarity index 100% rename from datafusion/functions-spark/benches/aggregate.rs rename to datafusion/spark/benches/aggregate.rs diff --git a/datafusion/functions-spark/benches/cast_from_string.rs b/datafusion/spark/benches/cast_from_string.rs similarity index 100% rename from datafusion/functions-spark/benches/cast_from_string.rs rename to datafusion/spark/benches/cast_from_string.rs diff --git a/datafusion/functions-spark/benches/cast_numeric.rs b/datafusion/spark/benches/cast_numeric.rs similarity index 100% rename from datafusion/functions-spark/benches/cast_numeric.rs rename to datafusion/spark/benches/cast_numeric.rs diff --git a/datafusion/functions-spark/benches/conditional.rs b/datafusion/spark/benches/conditional.rs similarity index 100% rename from datafusion/functions-spark/benches/conditional.rs rename to datafusion/spark/benches/conditional.rs diff --git a/datafusion/functions-spark/benches/decimal_div.rs b/datafusion/spark/benches/decimal_div.rs similarity index 100% rename from datafusion/functions-spark/benches/decimal_div.rs rename to datafusion/spark/benches/decimal_div.rs diff --git a/datafusion/functions-spark/src/agg_funcs/avg.rs b/datafusion/spark/src/agg_funcs/avg.rs similarity index 100% rename from datafusion/functions-spark/src/agg_funcs/avg.rs rename to datafusion/spark/src/agg_funcs/avg.rs diff --git a/datafusion/functions-spark/src/agg_funcs/avg_decimal.rs b/datafusion/spark/src/agg_funcs/avg_decimal.rs similarity index 100% rename from datafusion/functions-spark/src/agg_funcs/avg_decimal.rs rename to datafusion/spark/src/agg_funcs/avg_decimal.rs diff --git a/datafusion/functions-spark/src/agg_funcs/correlation.rs b/datafusion/spark/src/agg_funcs/correlation.rs similarity index 100% rename from datafusion/functions-spark/src/agg_funcs/correlation.rs rename to datafusion/spark/src/agg_funcs/correlation.rs diff --git a/datafusion/functions-spark/src/agg_funcs/covariance.rs 
b/datafusion/spark/src/agg_funcs/covariance.rs similarity index 100% rename from datafusion/functions-spark/src/agg_funcs/covariance.rs rename to datafusion/spark/src/agg_funcs/covariance.rs diff --git a/datafusion/functions-spark/src/agg_funcs/mod.rs b/datafusion/spark/src/agg_funcs/mod.rs similarity index 100% rename from datafusion/functions-spark/src/agg_funcs/mod.rs rename to datafusion/spark/src/agg_funcs/mod.rs diff --git a/datafusion/functions-spark/src/agg_funcs/stddev.rs b/datafusion/spark/src/agg_funcs/stddev.rs similarity index 100% rename from datafusion/functions-spark/src/agg_funcs/stddev.rs rename to datafusion/spark/src/agg_funcs/stddev.rs diff --git a/datafusion/functions-spark/src/agg_funcs/sum_decimal.rs b/datafusion/spark/src/agg_funcs/sum_decimal.rs similarity index 100% rename from datafusion/functions-spark/src/agg_funcs/sum_decimal.rs rename to datafusion/spark/src/agg_funcs/sum_decimal.rs diff --git a/datafusion/functions-spark/src/agg_funcs/variance.rs b/datafusion/spark/src/agg_funcs/variance.rs similarity index 100% rename from datafusion/functions-spark/src/agg_funcs/variance.rs rename to datafusion/spark/src/agg_funcs/variance.rs diff --git a/datafusion/functions-spark/src/array_funcs/array_insert.rs b/datafusion/spark/src/array_funcs/array_insert.rs similarity index 100% rename from datafusion/functions-spark/src/array_funcs/array_insert.rs rename to datafusion/spark/src/array_funcs/array_insert.rs diff --git a/datafusion/functions-spark/src/array_funcs/get_array_struct_fields.rs b/datafusion/spark/src/array_funcs/get_array_struct_fields.rs similarity index 100% rename from datafusion/functions-spark/src/array_funcs/get_array_struct_fields.rs rename to datafusion/spark/src/array_funcs/get_array_struct_fields.rs diff --git a/datafusion/functions-spark/src/array_funcs/list_extract.rs b/datafusion/spark/src/array_funcs/list_extract.rs similarity index 100% rename from datafusion/functions-spark/src/array_funcs/list_extract.rs rename to datafusion/spark/src/array_funcs/list_extract.rs diff --git a/datafusion/functions-spark/src/array_funcs/mod.rs b/datafusion/spark/src/array_funcs/mod.rs similarity index 100% rename from datafusion/functions-spark/src/array_funcs/mod.rs rename to datafusion/spark/src/array_funcs/mod.rs diff --git a/datafusion/functions-spark/src/bitwise_funcs/bitwise_not.rs b/datafusion/spark/src/bitwise_funcs/bitwise_not.rs similarity index 100% rename from datafusion/functions-spark/src/bitwise_funcs/bitwise_not.rs rename to datafusion/spark/src/bitwise_funcs/bitwise_not.rs diff --git a/datafusion/functions-spark/src/bitwise_funcs/mod.rs b/datafusion/spark/src/bitwise_funcs/mod.rs similarity index 100% rename from datafusion/functions-spark/src/bitwise_funcs/mod.rs rename to datafusion/spark/src/bitwise_funcs/mod.rs diff --git a/datafusion/functions-spark/src/comet_scalar_funcs.rs b/datafusion/spark/src/comet_scalar_funcs.rs similarity index 100% rename from datafusion/functions-spark/src/comet_scalar_funcs.rs rename to datafusion/spark/src/comet_scalar_funcs.rs diff --git a/datafusion/functions-spark/src/conditional_funcs/if_expr.rs b/datafusion/spark/src/conditional_funcs/if_expr.rs similarity index 100% rename from datafusion/functions-spark/src/conditional_funcs/if_expr.rs rename to datafusion/spark/src/conditional_funcs/if_expr.rs diff --git a/datafusion/functions-spark/src/conditional_funcs/mod.rs b/datafusion/spark/src/conditional_funcs/mod.rs similarity index 100% rename from datafusion/functions-spark/src/conditional_funcs/mod.rs 
rename to datafusion/spark/src/conditional_funcs/mod.rs diff --git a/datafusion/functions-spark/src/conversion_funcs/cast.rs b/datafusion/spark/src/conversion_funcs/cast.rs similarity index 100% rename from datafusion/functions-spark/src/conversion_funcs/cast.rs rename to datafusion/spark/src/conversion_funcs/cast.rs diff --git a/datafusion/functions-spark/src/conversion_funcs/mod.rs b/datafusion/spark/src/conversion_funcs/mod.rs similarity index 100% rename from datafusion/functions-spark/src/conversion_funcs/mod.rs rename to datafusion/spark/src/conversion_funcs/mod.rs diff --git a/datafusion/functions-spark/src/datetime_funcs/date_arithmetic.rs b/datafusion/spark/src/datetime_funcs/date_arithmetic.rs similarity index 100% rename from datafusion/functions-spark/src/datetime_funcs/date_arithmetic.rs rename to datafusion/spark/src/datetime_funcs/date_arithmetic.rs diff --git a/datafusion/functions-spark/src/datetime_funcs/date_trunc.rs b/datafusion/spark/src/datetime_funcs/date_trunc.rs similarity index 100% rename from datafusion/functions-spark/src/datetime_funcs/date_trunc.rs rename to datafusion/spark/src/datetime_funcs/date_trunc.rs diff --git a/datafusion/functions-spark/src/datetime_funcs/hour.rs b/datafusion/spark/src/datetime_funcs/hour.rs similarity index 100% rename from datafusion/functions-spark/src/datetime_funcs/hour.rs rename to datafusion/spark/src/datetime_funcs/hour.rs diff --git a/datafusion/functions-spark/src/datetime_funcs/minute.rs b/datafusion/spark/src/datetime_funcs/minute.rs similarity index 100% rename from datafusion/functions-spark/src/datetime_funcs/minute.rs rename to datafusion/spark/src/datetime_funcs/minute.rs diff --git a/datafusion/functions-spark/src/datetime_funcs/mod.rs b/datafusion/spark/src/datetime_funcs/mod.rs similarity index 100% rename from datafusion/functions-spark/src/datetime_funcs/mod.rs rename to datafusion/spark/src/datetime_funcs/mod.rs diff --git a/datafusion/functions-spark/src/datetime_funcs/second.rs b/datafusion/spark/src/datetime_funcs/second.rs similarity index 100% rename from datafusion/functions-spark/src/datetime_funcs/second.rs rename to datafusion/spark/src/datetime_funcs/second.rs diff --git a/datafusion/functions-spark/src/datetime_funcs/timestamp_trunc.rs b/datafusion/spark/src/datetime_funcs/timestamp_trunc.rs similarity index 100% rename from datafusion/functions-spark/src/datetime_funcs/timestamp_trunc.rs rename to datafusion/spark/src/datetime_funcs/timestamp_trunc.rs diff --git a/datafusion/functions-spark/src/error.rs b/datafusion/spark/src/error.rs similarity index 100% rename from datafusion/functions-spark/src/error.rs rename to datafusion/spark/src/error.rs diff --git a/datafusion/functions-spark/src/hash_funcs/mod.rs b/datafusion/spark/src/hash_funcs/mod.rs similarity index 100% rename from datafusion/functions-spark/src/hash_funcs/mod.rs rename to datafusion/spark/src/hash_funcs/mod.rs diff --git a/datafusion/functions-spark/src/hash_funcs/murmur3.rs b/datafusion/spark/src/hash_funcs/murmur3.rs similarity index 100% rename from datafusion/functions-spark/src/hash_funcs/murmur3.rs rename to datafusion/spark/src/hash_funcs/murmur3.rs diff --git a/datafusion/functions-spark/src/hash_funcs/sha2.rs b/datafusion/spark/src/hash_funcs/sha2.rs similarity index 100% rename from datafusion/functions-spark/src/hash_funcs/sha2.rs rename to datafusion/spark/src/hash_funcs/sha2.rs diff --git a/datafusion/functions-spark/src/hash_funcs/utils.rs b/datafusion/spark/src/hash_funcs/utils.rs similarity index 100% rename from 
datafusion/functions-spark/src/hash_funcs/utils.rs rename to datafusion/spark/src/hash_funcs/utils.rs diff --git a/datafusion/functions-spark/src/hash_funcs/xxhash64.rs b/datafusion/spark/src/hash_funcs/xxhash64.rs similarity index 100% rename from datafusion/functions-spark/src/hash_funcs/xxhash64.rs rename to datafusion/spark/src/hash_funcs/xxhash64.rs diff --git a/datafusion/functions-spark/src/json_funcs/mod.rs b/datafusion/spark/src/json_funcs/mod.rs similarity index 100% rename from datafusion/functions-spark/src/json_funcs/mod.rs rename to datafusion/spark/src/json_funcs/mod.rs diff --git a/datafusion/functions-spark/src/json_funcs/to_json.rs b/datafusion/spark/src/json_funcs/to_json.rs similarity index 100% rename from datafusion/functions-spark/src/json_funcs/to_json.rs rename to datafusion/spark/src/json_funcs/to_json.rs diff --git a/datafusion/functions-spark/src/kernels/mod.rs b/datafusion/spark/src/kernels/mod.rs similarity index 100% rename from datafusion/functions-spark/src/kernels/mod.rs rename to datafusion/spark/src/kernels/mod.rs diff --git a/datafusion/functions-spark/src/kernels/strings.rs b/datafusion/spark/src/kernels/strings.rs similarity index 100% rename from datafusion/functions-spark/src/kernels/strings.rs rename to datafusion/spark/src/kernels/strings.rs diff --git a/datafusion/functions-spark/src/kernels/temporal.rs b/datafusion/spark/src/kernels/temporal.rs similarity index 100% rename from datafusion/functions-spark/src/kernels/temporal.rs rename to datafusion/spark/src/kernels/temporal.rs diff --git a/datafusion/functions-spark/src/lib.rs b/datafusion/spark/src/lib.rs similarity index 100% rename from datafusion/functions-spark/src/lib.rs rename to datafusion/spark/src/lib.rs diff --git a/datafusion/functions-spark/src/math_funcs/ceil.rs b/datafusion/spark/src/math_funcs/ceil.rs similarity index 100% rename from datafusion/functions-spark/src/math_funcs/ceil.rs rename to datafusion/spark/src/math_funcs/ceil.rs diff --git a/datafusion/functions-spark/src/math_funcs/div.rs b/datafusion/spark/src/math_funcs/div.rs similarity index 100% rename from datafusion/functions-spark/src/math_funcs/div.rs rename to datafusion/spark/src/math_funcs/div.rs diff --git a/datafusion/functions-spark/src/math_funcs/floor.rs b/datafusion/spark/src/math_funcs/floor.rs similarity index 100% rename from datafusion/functions-spark/src/math_funcs/floor.rs rename to datafusion/spark/src/math_funcs/floor.rs diff --git a/datafusion/functions-spark/src/math_funcs/hex.rs b/datafusion/spark/src/math_funcs/hex.rs similarity index 100% rename from datafusion/functions-spark/src/math_funcs/hex.rs rename to datafusion/spark/src/math_funcs/hex.rs diff --git a/datafusion/functions-spark/src/math_funcs/internal/checkoverflow.rs b/datafusion/spark/src/math_funcs/internal/checkoverflow.rs similarity index 100% rename from datafusion/functions-spark/src/math_funcs/internal/checkoverflow.rs rename to datafusion/spark/src/math_funcs/internal/checkoverflow.rs diff --git a/datafusion/functions-spark/src/math_funcs/internal/make_decimal.rs b/datafusion/spark/src/math_funcs/internal/make_decimal.rs similarity index 100% rename from datafusion/functions-spark/src/math_funcs/internal/make_decimal.rs rename to datafusion/spark/src/math_funcs/internal/make_decimal.rs diff --git a/datafusion/functions-spark/src/math_funcs/internal/mod.rs b/datafusion/spark/src/math_funcs/internal/mod.rs similarity index 100% rename from datafusion/functions-spark/src/math_funcs/internal/mod.rs rename to 
datafusion/spark/src/math_funcs/internal/mod.rs diff --git a/datafusion/functions-spark/src/math_funcs/internal/normalize_nan.rs b/datafusion/spark/src/math_funcs/internal/normalize_nan.rs similarity index 100% rename from datafusion/functions-spark/src/math_funcs/internal/normalize_nan.rs rename to datafusion/spark/src/math_funcs/internal/normalize_nan.rs diff --git a/datafusion/functions-spark/src/math_funcs/internal/unscaled_value.rs b/datafusion/spark/src/math_funcs/internal/unscaled_value.rs similarity index 100% rename from datafusion/functions-spark/src/math_funcs/internal/unscaled_value.rs rename to datafusion/spark/src/math_funcs/internal/unscaled_value.rs diff --git a/datafusion/functions-spark/src/math_funcs/mod.rs b/datafusion/spark/src/math_funcs/mod.rs similarity index 100% rename from datafusion/functions-spark/src/math_funcs/mod.rs rename to datafusion/spark/src/math_funcs/mod.rs diff --git a/datafusion/functions-spark/src/math_funcs/negative.rs b/datafusion/spark/src/math_funcs/negative.rs similarity index 100% rename from datafusion/functions-spark/src/math_funcs/negative.rs rename to datafusion/spark/src/math_funcs/negative.rs diff --git a/datafusion/functions-spark/src/math_funcs/round.rs b/datafusion/spark/src/math_funcs/round.rs similarity index 100% rename from datafusion/functions-spark/src/math_funcs/round.rs rename to datafusion/spark/src/math_funcs/round.rs diff --git a/datafusion/functions-spark/src/math_funcs/unhex.rs b/datafusion/spark/src/math_funcs/unhex.rs similarity index 100% rename from datafusion/functions-spark/src/math_funcs/unhex.rs rename to datafusion/spark/src/math_funcs/unhex.rs diff --git a/datafusion/functions-spark/src/math_funcs/utils.rs b/datafusion/spark/src/math_funcs/utils.rs similarity index 100% rename from datafusion/functions-spark/src/math_funcs/utils.rs rename to datafusion/spark/src/math_funcs/utils.rs diff --git a/datafusion/functions-spark/src/predicate_funcs/is_nan.rs b/datafusion/spark/src/predicate_funcs/is_nan.rs similarity index 100% rename from datafusion/functions-spark/src/predicate_funcs/is_nan.rs rename to datafusion/spark/src/predicate_funcs/is_nan.rs diff --git a/datafusion/functions-spark/src/predicate_funcs/mod.rs b/datafusion/spark/src/predicate_funcs/mod.rs similarity index 100% rename from datafusion/functions-spark/src/predicate_funcs/mod.rs rename to datafusion/spark/src/predicate_funcs/mod.rs diff --git a/datafusion/functions-spark/src/predicate_funcs/rlike.rs b/datafusion/spark/src/predicate_funcs/rlike.rs similarity index 100% rename from datafusion/functions-spark/src/predicate_funcs/rlike.rs rename to datafusion/spark/src/predicate_funcs/rlike.rs diff --git a/datafusion/functions-spark/src/static_invoke/char_varchar_utils/mod.rs b/datafusion/spark/src/static_invoke/char_varchar_utils/mod.rs similarity index 100% rename from datafusion/functions-spark/src/static_invoke/char_varchar_utils/mod.rs rename to datafusion/spark/src/static_invoke/char_varchar_utils/mod.rs diff --git a/datafusion/functions-spark/src/static_invoke/char_varchar_utils/read_side_padding.rs b/datafusion/spark/src/static_invoke/char_varchar_utils/read_side_padding.rs similarity index 100% rename from datafusion/functions-spark/src/static_invoke/char_varchar_utils/read_side_padding.rs rename to datafusion/spark/src/static_invoke/char_varchar_utils/read_side_padding.rs diff --git a/datafusion/functions-spark/src/static_invoke/mod.rs b/datafusion/spark/src/static_invoke/mod.rs similarity index 100% rename from 
datafusion/functions-spark/src/static_invoke/mod.rs rename to datafusion/spark/src/static_invoke/mod.rs diff --git a/datafusion/functions-spark/src/string_funcs/chr.rs b/datafusion/spark/src/string_funcs/chr.rs similarity index 100% rename from datafusion/functions-spark/src/string_funcs/chr.rs rename to datafusion/spark/src/string_funcs/chr.rs diff --git a/datafusion/functions-spark/src/string_funcs/mod.rs b/datafusion/spark/src/string_funcs/mod.rs similarity index 100% rename from datafusion/functions-spark/src/string_funcs/mod.rs rename to datafusion/spark/src/string_funcs/mod.rs diff --git a/datafusion/functions-spark/src/string_funcs/prediction.rs b/datafusion/spark/src/string_funcs/prediction.rs similarity index 100% rename from datafusion/functions-spark/src/string_funcs/prediction.rs rename to datafusion/spark/src/string_funcs/prediction.rs diff --git a/datafusion/functions-spark/src/string_funcs/string_space.rs b/datafusion/spark/src/string_funcs/string_space.rs similarity index 100% rename from datafusion/functions-spark/src/string_funcs/string_space.rs rename to datafusion/spark/src/string_funcs/string_space.rs diff --git a/datafusion/functions-spark/src/string_funcs/substring.rs b/datafusion/spark/src/string_funcs/substring.rs similarity index 100% rename from datafusion/functions-spark/src/string_funcs/substring.rs rename to datafusion/spark/src/string_funcs/substring.rs diff --git a/datafusion/functions-spark/src/struct_funcs/create_named_struct.rs b/datafusion/spark/src/struct_funcs/create_named_struct.rs similarity index 100% rename from datafusion/functions-spark/src/struct_funcs/create_named_struct.rs rename to datafusion/spark/src/struct_funcs/create_named_struct.rs diff --git a/datafusion/functions-spark/src/struct_funcs/get_struct_field.rs b/datafusion/spark/src/struct_funcs/get_struct_field.rs similarity index 100% rename from datafusion/functions-spark/src/struct_funcs/get_struct_field.rs rename to datafusion/spark/src/struct_funcs/get_struct_field.rs diff --git a/datafusion/functions-spark/src/struct_funcs/mod.rs b/datafusion/spark/src/struct_funcs/mod.rs similarity index 100% rename from datafusion/functions-spark/src/struct_funcs/mod.rs rename to datafusion/spark/src/struct_funcs/mod.rs diff --git a/datafusion/functions-spark/src/test_common/file_util.rs b/datafusion/spark/src/test_common/file_util.rs similarity index 100% rename from datafusion/functions-spark/src/test_common/file_util.rs rename to datafusion/spark/src/test_common/file_util.rs diff --git a/datafusion/functions-spark/src/test_common/mod.rs b/datafusion/spark/src/test_common/mod.rs similarity index 100% rename from datafusion/functions-spark/src/test_common/mod.rs rename to datafusion/spark/src/test_common/mod.rs diff --git a/datafusion/functions-spark/src/timezone.rs b/datafusion/spark/src/timezone.rs similarity index 100% rename from datafusion/functions-spark/src/timezone.rs rename to datafusion/spark/src/timezone.rs diff --git a/datafusion/functions-spark/src/unbound.rs b/datafusion/spark/src/unbound.rs similarity index 100% rename from datafusion/functions-spark/src/unbound.rs rename to datafusion/spark/src/unbound.rs diff --git a/datafusion/functions-spark/src/utils.rs b/datafusion/spark/src/utils.rs similarity index 100% rename from datafusion/functions-spark/src/utils.rs rename to datafusion/spark/src/utils.rs
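The rename above moves the crate directory from `datafusion/functions-spark` to `datafusion/spark` and changes the Cargo package name from `datafusion-functions-spark` to `datafusion-spark`; the `[lib]` target name shown in the earlier Cargo.toml (`datafusion_comet_spark_expr`) is not touched by this patch, so existing `use datafusion_comet_spark_expr::...` paths keep compiling. A downstream `Cargo.toml` entry would be updated roughly as follows — a sketch only, where the `path` values are illustrative workspace-relative paths rather than anything defined in this patch.

```toml
[dependencies]
# Before this patch the entry would have referenced the old package name:
# datafusion-functions-spark = { path = "../datafusion/functions-spark" }
datafusion-spark = { path = "../datafusion/spark" }
```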