diff --git a/Cargo.toml b/Cargo.toml index 85b26f802f05..e8f94885e79e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,7 @@ members = [ "datafusion/proto/gen", "datafusion/proto-common", "datafusion/proto-common/gen", + "datafusion/spark", "datafusion/sql", "datafusion/sqllogictest", "datafusion/substrait", @@ -84,6 +85,7 @@ arrow-array = { version = "54.0.0", default-features = false, features = [ "chrono-tz", ] } arrow-buffer = { version = "54.0.0", default-features = false } +arrow-data = { version = "54.0.0", default-features = false } arrow-flight = { version = "54.0.0", features = [ "flight-sql-experimental", ] } diff --git a/datafusion/spark/Cargo.toml b/datafusion/spark/Cargo.toml new file mode 100644 index 000000000000..be5a897d1648 --- /dev/null +++ b/datafusion/spark/Cargo.toml @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +[package] +name = "datafusion-spark" +description = "DataFusion expressions that emulate Apache Spark's behavior" +version = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +authors = { workspace = true } +readme = { workspace = true } +license = { workspace = true } +edition = { workspace = true } + +[dependencies] +arrow = { workspace = true } +arrow-array = { workspace = true } +arrow-buffer = { workspace = true } +arrow-data = { workspace = true } +arrow-schema = { workspace = true } +chrono = { workspace = true } +chrono-tz = "0.10.1" +datafusion = { workspace = true, features = ["parquet"] } +datafusion-common = { workspace = true } +datafusion-expr = { workspace = true } +datafusion-expr-common = { workspace = true } +datafusion-physical-expr = { workspace = true } +futures = { workspace = true } +num = "0.4.3" +rand = { workspace = true } +regex = { workspace = true } +thiserror = "2.0.11" +twox-hash = "2.0.0" + +[dev-dependencies] +arrow-data = { workspace = true } +criterion = "0.5.1" +parquet = { workspace = true, features = ["arrow"] } +rand = { workspace = true } +tokio = { version = "1", features = ["rt-multi-thread"] } + +[lib] +name = "datafusion_comet_spark_expr" +path = "src/lib.rs" + +[[bench]] +name = "cast_from_string" +harness = false + +[[bench]] +name = "cast_numeric" +harness = false + +[[bench]] +name = "conditional" +harness = false + +[[bench]] +name = "decimal_div" +harness = false + +[[bench]] +name = "aggregate" +harness = false diff --git a/datafusion/spark/README.md b/datafusion/spark/README.md new file mode 100644 index 000000000000..afd94d2c0690 --- /dev/null +++ b/datafusion/spark/README.md @@ -0,0 +1,22 @@ + + +# datafusion-spark: Spark-compatible Expressions + +This crate provides Apache Spark-compatible expressions for use with DataFusion. 
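The README introduces the crate but does not show how its expressions are used. Below is a minimal sketch (not part of the patch itself) that evaluates the crate's Spark-compatible `Cast` physical expression against a `RecordBatch`, mirroring the pattern used in `benches/cast_from_string.rs`; the column name `a` and the sample string values are illustrative assumptions.

```rust
// Minimal sketch: evaluate the Spark-compatible Cast expression from this crate.
// Mirrors benches/cast_from_string.rs; input data and column names are illustrative.
use std::sync::Arc;

use arrow_array::{ArrayRef, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};
use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions};
use datafusion_physical_expr::{expressions::Column, PhysicalExpr};

fn main() -> datafusion_common::Result<()> {
    // A single nullable Utf8 column containing values Spark would cast to integers.
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
    let array: ArrayRef =
        Arc::new(StringArray::from(vec![Some("123"), None, Some("-42")]));
    let batch = RecordBatch::try_new(schema, vec![array])?;

    // Spark-compatible cast of column "a" to Int32 in Legacy evaluation mode,
    // constructed the same way the benchmarks construct it.
    let expr = Arc::new(Column::new("a", 0));
    let options = SparkCastOptions::new(EvalMode::Legacy, "", false);
    let cast = Cast::new(expr, DataType::Int32, options);

    // Evaluate the physical expression against the batch and inspect the result.
    let result = cast.evaluate(&batch)?;
    println!("{result:?}");
    Ok(())
}
```

The aggregate functions added by this patch (for example `AvgDecimal` and `SumDecimal`) follow the usual DataFusion UDAF pattern and can be wrapped with `AggregateUDF::new_from_impl`, as `benches/aggregate.rs` below demonstrates.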
diff --git a/datafusion/spark/benches/aggregate.rs b/datafusion/spark/benches/aggregate.rs new file mode 100644 index 000000000000..5791197ac13b --- /dev/null +++ b/datafusion/spark/benches/aggregate.rs @@ -0,0 +1,201 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::{DataType, Field, Schema}; +use arrow_array::builder::{Decimal128Builder, StringBuilder}; +use arrow_array::{ArrayRef, RecordBatch}; +use arrow_schema::SchemaRef; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion::execution::TaskContext; +use datafusion::functions_aggregate::average::avg_udaf; +use datafusion::functions_aggregate::sum::sum_udaf; +use datafusion::physical_expr::PhysicalExpr; +use datafusion::physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, +}; +use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::ExecutionPlan; +use datafusion_comet_spark_expr::AvgDecimal; +use datafusion_comet_spark_expr::SumDecimal; +use datafusion_expr::AggregateUDF; +use datafusion_physical_expr::aggregate::AggregateExprBuilder; +use datafusion_physical_expr::expressions::Column; +use futures::StreamExt; +use std::sync::Arc; +use std::time::Duration; +use tokio::runtime::Runtime; + +fn criterion_benchmark(c: &mut Criterion) { + let mut group = c.benchmark_group("aggregate"); + let num_rows = 8192; + let batch = create_record_batch(num_rows); + let mut batches = Vec::new(); + for _ in 0..10 { + batches.push(batch.clone()); + } + let partitions = &[batches]; + let c0: Arc = Arc::new(Column::new("c0", 0)); + let c1: Arc = Arc::new(Column::new("c1", 1)); + + let rt = Runtime::new().unwrap(); + + group.bench_function("avg_decimal_datafusion", |b| { + let datafusion_sum_decimal = avg_udaf(); + b.to_async(&rt).iter(|| { + black_box(agg_test( + partitions, + c0.clone(), + c1.clone(), + datafusion_sum_decimal.clone(), + "avg", + )) + }) + }); + + group.bench_function("avg_decimal_comet", |b| { + let comet_avg_decimal = Arc::new(AggregateUDF::new_from_impl(AvgDecimal::new( + DataType::Decimal128(38, 10), + DataType::Decimal128(38, 10), + ))); + b.to_async(&rt).iter(|| { + black_box(agg_test( + partitions, + c0.clone(), + c1.clone(), + comet_avg_decimal.clone(), + "avg", + )) + }) + }); + + group.bench_function("sum_decimal_datafusion", |b| { + let datafusion_sum_decimal = sum_udaf(); + b.to_async(&rt).iter(|| { + black_box(agg_test( + partitions, + c0.clone(), + c1.clone(), + datafusion_sum_decimal.clone(), + "sum", + )) + }) + }); + + group.bench_function("sum_decimal_comet", |b| { + let comet_sum_decimal = Arc::new(AggregateUDF::new_from_impl( + SumDecimal::try_new(DataType::Decimal128(38, 10)).unwrap(), + )); + b.to_async(&rt).iter(|| { + black_box(agg_test( + partitions, + 
c0.clone(), + c1.clone(), + comet_sum_decimal.clone(), + "sum", + )) + }) + }); + + group.finish(); +} + +async fn agg_test( + partitions: &[Vec], + c0: Arc, + c1: Arc, + aggregate_udf: Arc, + alias: &str, +) { + let schema = &partitions[0][0].schema(); + let scan: Arc = + Arc::new(MemoryExec::try_new(partitions, Arc::clone(schema), None).unwrap()); + let aggregate = + create_aggregate(scan, c0.clone(), c1.clone(), schema, aggregate_udf, alias); + let mut stream = aggregate + .execute(0, Arc::new(TaskContext::default())) + .unwrap(); + while let Some(batch) = stream.next().await { + let _batch = batch.unwrap(); + } +} + +fn create_aggregate( + scan: Arc, + c0: Arc, + c1: Arc, + schema: &SchemaRef, + aggregate_udf: Arc, + alias: &str, +) -> Arc { + let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c1]) + .schema(schema.clone()) + .alias(alias) + .with_ignore_nulls(false) + .with_distinct(false) + .build() + .unwrap(); + + Arc::new( + AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(vec![(c0, "c0".to_string())]), + vec![aggr_expr.into()], + vec![None], // no filter expressions + scan, + Arc::clone(schema), + ) + .unwrap(), + ) +} + +fn create_record_batch(num_rows: usize) -> RecordBatch { + let mut decimal_builder = Decimal128Builder::with_capacity(num_rows); + let mut string_builder = StringBuilder::with_capacity(num_rows, num_rows * 32); + for i in 0..num_rows { + decimal_builder.append_value(i as i128); + string_builder.append_value(format!("this is string #{}", i % 1024)); + } + let decimal_array = Arc::new(decimal_builder.finish()); + let string_array = Arc::new(string_builder.finish()); + + let mut fields = vec![]; + let mut columns: Vec = vec![]; + + // string column + fields.push(Field::new("c0", DataType::Utf8, false)); + columns.push(string_array); + + // decimal column + fields.push(Field::new("c1", DataType::Decimal128(38, 10), false)); + columns.push(decimal_array); + + let schema = Schema::new(fields); + RecordBatch::try_new(Arc::new(schema), columns).unwrap() +} + +fn config() -> Criterion { + Criterion::default() + .measurement_time(Duration::from_millis(500)) + .warm_up_time(Duration::from_millis(500)) +} + +criterion_group! { + name = benches; + config = config(); + targets = criterion_benchmark +} +criterion_main!(benches); diff --git a/datafusion/spark/benches/cast_from_string.rs b/datafusion/spark/benches/cast_from_string.rs new file mode 100644 index 000000000000..ad76abe6650f --- /dev/null +++ b/datafusion/spark/benches/cast_from_string.rs @@ -0,0 +1,79 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow_array::{builder::StringBuilder, RecordBatch}; +use arrow_schema::{DataType, Field, Schema}; +use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions}; +use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let batch = create_utf8_batch(); + let expr = Arc::new(Column::new("a", 0)); + let spark_cast_options = SparkCastOptions::new(EvalMode::Legacy, "", false); + let cast_string_to_i8 = + Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone()); + let cast_string_to_i16 = + Cast::new(expr.clone(), DataType::Int16, spark_cast_options.clone()); + let cast_string_to_i32 = + Cast::new(expr.clone(), DataType::Int32, spark_cast_options.clone()); + let cast_string_to_i64 = Cast::new(expr, DataType::Int64, spark_cast_options); + + let mut group = c.benchmark_group("cast_string_to_int"); + group.bench_function("cast_string_to_i8", |b| { + b.iter(|| cast_string_to_i8.evaluate(&batch).unwrap()); + }); + group.bench_function("cast_string_to_i16", |b| { + b.iter(|| cast_string_to_i16.evaluate(&batch).unwrap()); + }); + group.bench_function("cast_string_to_i32", |b| { + b.iter(|| cast_string_to_i32.evaluate(&batch).unwrap()); + }); + group.bench_function("cast_string_to_i64", |b| { + b.iter(|| cast_string_to_i64.evaluate(&batch).unwrap()); + }); +} + +// Create UTF8 batch with strings representing ints, floats, nulls +fn create_utf8_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)])); + let mut b = StringBuilder::new(); + for i in 0..1000 { + if i % 10 == 0 { + b.append_null(); + } else if i % 2 == 0 { + b.append_value(format!("{}", rand::random::())); + } else { + b.append_value(format!("{}", rand::random::())); + } + } + let array = b.finish(); + + RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap() +} + +fn config() -> Criterion { + Criterion::default() +} + +criterion_group! { + name = benches; + config = config(); + targets = criterion_benchmark +} +criterion_main!(benches); diff --git a/datafusion/spark/benches/cast_numeric.rs b/datafusion/spark/benches/cast_numeric.rs new file mode 100644 index 000000000000..7f040d2960a8 --- /dev/null +++ b/datafusion/spark/benches/cast_numeric.rs @@ -0,0 +1,72 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow_array::{builder::Int32Builder, RecordBatch}; +use arrow_schema::{DataType, Field, Schema}; +use criterion::{criterion_group, criterion_main, Criterion}; +use datafusion_comet_spark_expr::{Cast, EvalMode, SparkCastOptions}; +use datafusion_physical_expr::{expressions::Column, PhysicalExpr}; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + let batch = create_int32_batch(); + let expr = Arc::new(Column::new("a", 0)); + let spark_cast_options = + SparkCastOptions::new_without_timezone(EvalMode::Legacy, false); + let cast_i32_to_i8 = + Cast::new(expr.clone(), DataType::Int8, spark_cast_options.clone()); + let cast_i32_to_i16 = + Cast::new(expr.clone(), DataType::Int16, spark_cast_options.clone()); + let cast_i32_to_i64 = Cast::new(expr, DataType::Int64, spark_cast_options); + + let mut group = c.benchmark_group("cast_int_to_int"); + group.bench_function("cast_i32_to_i8", |b| { + b.iter(|| cast_i32_to_i8.evaluate(&batch).unwrap()); + }); + group.bench_function("cast_i32_to_i16", |b| { + b.iter(|| cast_i32_to_i16.evaluate(&batch).unwrap()); + }); + group.bench_function("cast_i32_to_i64", |b| { + b.iter(|| cast_i32_to_i64.evaluate(&batch).unwrap()); + }); +} + +fn create_int32_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + let mut b = Int32Builder::new(); + for i in 0..1000 { + if i % 10 == 0 { + b.append_null(); + } else { + b.append_value(rand::random::()); + } + } + let array = b.finish(); + + RecordBatch::try_new(schema.clone(), vec![Arc::new(array)]).unwrap() +} + +fn config() -> Criterion { + Criterion::default() +} + +criterion_group! { + name = benches; + config = config(); + targets = criterion_benchmark +} +criterion_main!(benches); diff --git a/datafusion/spark/benches/conditional.rs b/datafusion/spark/benches/conditional.rs new file mode 100644 index 000000000000..2cf51c53247b --- /dev/null +++ b/datafusion/spark/benches/conditional.rs @@ -0,0 +1,140 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::datatypes::{Field, Schema}; +use arrow::record_batch::RecordBatch; +use arrow_array::builder::{Int32Builder, StringBuilder}; +use arrow_schema::DataType; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_comet_spark_expr::IfExpr; +use datafusion_common::ScalarValue; +use datafusion_expr::Operator; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::expressions::Literal; +use datafusion_physical_expr::expressions::{BinaryExpr, CaseExpr}; +use datafusion_physical_expr::PhysicalExpr; +use std::sync::Arc; + +fn make_col(name: &str, index: usize) -> Arc { + Arc::new(Column::new(name, index)) +} + +fn make_lit_i32(n: i32) -> Arc { + Arc::new(Literal::new(ScalarValue::Int32(Some(n)))) +} + +fn make_null_lit() -> Arc { + Arc::new(Literal::new(ScalarValue::Utf8(None))) +} + +fn criterion_benchmark(c: &mut Criterion) { + // create input data + let mut c1 = Int32Builder::new(); + let mut c2 = StringBuilder::new(); + let mut c3 = StringBuilder::new(); + for i in 0..1000 { + c1.append_value(i); + if i % 7 == 0 { + c2.append_null(); + } else { + c2.append_value(format!("string {i}")); + } + if i % 9 == 0 { + c3.append_null(); + } else { + c3.append_value(format!("other string {i}")); + } + } + let c1 = Arc::new(c1.finish()); + let c2 = Arc::new(c2.finish()); + let c3 = Arc::new(c3.finish()); + let schema = Schema::new(vec![ + Field::new("c1", DataType::Int32, true), + Field::new("c2", DataType::Utf8, true), + Field::new("c3", DataType::Utf8, true), + ]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![c1, c2, c3]).unwrap(); + + // use same predicate for all benchmarks + let predicate = Arc::new(BinaryExpr::new( + make_col("c1", 0), + Operator::LtEq, + make_lit_i32(500), + )); + + // CASE WHEN c1 <= 500 THEN 1 ELSE 0 END + c.bench_function("case_when: scalar or scalar", |b| { + let expr = Arc::new( + CaseExpr::try_new( + None, + vec![(predicate.clone(), make_lit_i32(1))], + Some(make_lit_i32(0)), + ) + .unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + c.bench_function("if: scalar or scalar", |b| { + let expr = Arc::new(IfExpr::new( + predicate.clone(), + make_lit_i32(1), + make_lit_i32(0), + )); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + + // CASE WHEN c1 <= 500 THEN c2 [ELSE NULL] END + c.bench_function("case_when: column or null", |b| { + let expr = Arc::new( + CaseExpr::try_new(None, vec![(predicate.clone(), make_col("c2", 1))], None) + .unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + c.bench_function("if: column or null", |b| { + let expr = Arc::new(IfExpr::new( + predicate.clone(), + make_col("c2", 1), + make_null_lit(), + )); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + + // CASE WHEN c1 <= 500 THEN c2 ELSE c3 END + c.bench_function("case_when: expr or expr", |b| { + let expr = Arc::new( + CaseExpr::try_new( + None, + vec![(predicate.clone(), make_col("c2", 1))], + Some(make_col("c3", 2)), + ) + .unwrap(), + ); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); + c.bench_function("if: expr or expr", |b| { + let expr = Arc::new(IfExpr::new( + predicate.clone(), + make_col("c2", 1), + make_col("c3", 2), + )); + b.iter(|| black_box(expr.evaluate(black_box(&batch)).unwrap())) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/spark/benches/decimal_div.rs 
b/datafusion/spark/benches/decimal_div.rs new file mode 100644 index 000000000000..ad527fecba41 --- /dev/null +++ b/datafusion/spark/benches/decimal_div.rs @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::compute::cast; +use arrow_array::builder::Decimal128Builder; +use arrow_schema::DataType; +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use datafusion_comet_spark_expr::spark_decimal_div; +use datafusion_expr::ColumnarValue; +use std::sync::Arc; + +fn criterion_benchmark(c: &mut Criterion) { + // create input data + let mut c1 = Decimal128Builder::new(); + let mut c2 = Decimal128Builder::new(); + for i in 0..1000 { + c1.append_value(99999999 + i); + c2.append_value(88888888 - i); + } + let c1 = Arc::new(c1.finish()); + let c2 = Arc::new(c2.finish()); + + let c1_type = DataType::Decimal128(10, 4); + let c1 = cast(c1.as_ref(), &c1_type).unwrap(); + let c2_type = DataType::Decimal128(10, 3); + let c2 = cast(c2.as_ref(), &c2_type).unwrap(); + + let args = [ColumnarValue::Array(c1), ColumnarValue::Array(c2)]; + c.bench_function("decimal_div", |b| { + b.iter(|| { + black_box(spark_decimal_div( + black_box(&args), + black_box(&DataType::Decimal128(10, 4)), + )) + }) + }); +} + +criterion_group!(benches, criterion_benchmark); +criterion_main!(benches); diff --git a/datafusion/spark/src/agg_funcs/avg.rs b/datafusion/spark/src/agg_funcs/avg.rs new file mode 100644 index 000000000000..0618596d1a6a --- /dev/null +++ b/datafusion/spark/src/agg_funcs/avg.rs @@ -0,0 +1,344 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::compute::sum; +use arrow_array::{ + builder::PrimitiveBuilder, + cast::AsArray, + types::{Float64Type, Int64Type}, + Array, ArrayRef, ArrowNumericType, Int64Array, PrimitiveArray, +}; +use arrow_schema::{DataType, Field}; +use datafusion::logical_expr::{ + type_coercion::aggregates::avg_return_type, Accumulator, EmitTo, GroupsAccumulator, + Signature, +}; +use datafusion_common::{not_impl_err, Result, ScalarValue}; +use datafusion_physical_expr::expressions::format_state_name; +use std::{any::Any, sync::Arc}; + +use arrow_array::ArrowNativeTypeOp; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::Volatility::Immutable; +use datafusion_expr::{AggregateUDFImpl, ReversedUDAF}; +use DataType::*; + +/// AVG aggregate expression +#[derive(Debug, Clone)] +pub struct Avg { + name: String, + signature: Signature, + // expr: Arc, + input_data_type: DataType, + result_data_type: DataType, +} + +impl Avg { + /// Create a new AVG aggregate function + pub fn new(name: impl Into, data_type: DataType) -> Self { + let result_data_type = avg_return_type("avg", &data_type).unwrap(); + + Self { + name: name.into(), + signature: Signature::user_defined(Immutable), + input_data_type: data_type, + result_data_type, + } + } +} + +impl AggregateUDFImpl for Avg { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + // instantiate specialized accumulator based for the type + match (&self.input_data_type, &self.result_data_type) { + (Float64, Float64) => Ok(Box::::default()), + _ => not_impl_err!( + "AvgAccumulator for ({} --> {})", + self.input_data_type, + self.result_data_type + ), + } + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + Ok(vec![ + Field::new( + format_state_name(&self.name, "sum"), + self.input_data_type.clone(), + true, + ), + Field::new( + format_state_name(&self.name, "count"), + DataType::Int64, + true, + ), + ]) + } + + fn name(&self) -> &str { + &self.name + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + // instantiate specialized accumulator based for the type + match (&self.input_data_type, &self.result_data_type) { + (Float64, Float64) => { + Ok(Box::new(AvgGroupsAccumulator::::new( + &self.input_data_type, + |sum: f64, count: i64| Ok(sum / count as f64), + ))) + } + + _ => not_impl_err!( + "AvgGroupsAccumulator for ({} --> {})", + self.input_data_type, + self.result_data_type + ), + } + } + + fn default_value(&self, _data_type: &DataType) -> Result { + Ok(ScalarValue::Float64(None)) + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + avg_return_type(self.name(), &arg_types[0]) + } +} + +/// An accumulator to compute the average +#[derive(Debug, Default)] +pub struct AvgAccumulator { + sum: Option, + count: i64, +} + +impl Accumulator for AvgAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::Float64(self.sum), + ScalarValue::from(self.count), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = values[0].as_primitive::(); + self.count += (values.len() - values.null_count()) as i64; + let v = self.sum.get_or_insert(0.); + if let Some(x) = sum(values) { + 
*v += x; + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // counts are summed + self.count += sum(states[1].as_primitive::()).unwrap_or_default(); + + // sums are summed + if let Some(x) = sum(states[0].as_primitive::()) { + let v = self.sum.get_or_insert(0.); + *v += x; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + if self.count == 0 { + // If all input are nulls, count will be 0 and we will get null after the division. + // This is consistent with Spark Average implementation. + Ok(ScalarValue::Float64(None)) + } else { + Ok(ScalarValue::Float64( + self.sum.map(|f| f / self.count as f64), + )) + } + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } +} + +/// An accumulator to compute the average of `[PrimitiveArray]`. +/// Stores values as native types, and does overflow checking +/// +/// F: Function that calculates the average value from a sum of +/// T::Native and a total count +#[derive(Debug)] +struct AvgGroupsAccumulator +where + T: ArrowNumericType + Send, + F: Fn(T::Native, i64) -> Result + Send, +{ + /// The type of the returned average + return_data_type: DataType, + + /// Count per group (use i64 to make Int64Array) + counts: Vec, + + /// Sums per group, stored as the native type + sums: Vec, + + /// Function that computes the final average (value / count) + avg_fn: F, +} + +impl AvgGroupsAccumulator +where + T: ArrowNumericType + Send, + F: Fn(T::Native, i64) -> Result + Send, +{ + pub fn new(return_data_type: &DataType, avg_fn: F) -> Self { + Self { + return_data_type: return_data_type.clone(), + counts: vec![], + sums: vec![], + avg_fn, + } + } +} + +impl GroupsAccumulator for AvgGroupsAccumulator +where + T: ArrowNumericType + Send, + F: Fn(T::Native, i64) -> Result + Send, +{ + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + _opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let values = values[0].as_primitive::(); + let data = values.values(); + + // increment counts, update sums + self.counts.resize(total_num_groups, 0); + self.sums.resize(total_num_groups, T::default_value()); + + let iter = group_indices.iter().zip(data.iter()); + if values.null_count() == 0 { + for (&group_index, &value) in iter { + let sum = &mut self.sums[group_index]; + *sum = (*sum).add_wrapping(value); + self.counts[group_index] += 1; + } + } else { + for (idx, (&group_index, &value)) in iter.enumerate() { + if values.is_null(idx) { + continue; + } + let sum = &mut self.sums[group_index]; + *sum = (*sum).add_wrapping(value); + + self.counts[group_index] += 1; + } + } + + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + _opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 2, "two arguments to merge_batch"); + // first batch is partial sums, second is counts + let partial_sums = values[0].as_primitive::(); + let partial_counts = values[1].as_primitive::(); + // update counts with partial counts + self.counts.resize(total_num_groups, 0); + let iter1 = group_indices.iter().zip(partial_counts.values().iter()); + for (&group_index, &partial_count) in iter1 { + self.counts[group_index] += partial_count; + } + + // update sums + self.sums.resize(total_num_groups, T::default_value()); + let iter2 = group_indices.iter().zip(partial_sums.values().iter()); + for (&group_index, 
&new_value) in iter2 { + let sum = &mut self.sums[group_index]; + *sum = sum.add_wrapping(new_value); + } + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let counts = emit_to.take_needed(&mut self.counts); + let sums = emit_to.take_needed(&mut self.sums); + let mut builder = PrimitiveBuilder::::with_capacity(sums.len()); + let iter = sums.into_iter().zip(counts); + + for (sum, count) in iter { + if count != 0 { + builder.append_value((self.avg_fn)(sum, count)?) + } else { + builder.append_null(); + } + } + let array: PrimitiveArray = builder.finish(); + + Ok(Arc::new(array)) + } + + // return arrays for sums and counts + fn state(&mut self, emit_to: EmitTo) -> Result> { + let counts = emit_to.take_needed(&mut self.counts); + let counts = Int64Array::new(counts.into(), None); + + let sums = emit_to.take_needed(&mut self.sums); + let sums = PrimitiveArray::::new(sums.into(), None) + .with_data_type(self.return_data_type.clone()); + + Ok(vec![ + Arc::new(sums) as ArrayRef, + Arc::new(counts) as ArrayRef, + ]) + } + + fn size(&self) -> usize { + self.counts.capacity() * std::mem::size_of::() + + self.sums.capacity() * std::mem::size_of::() + } +} diff --git a/datafusion/spark/src/agg_funcs/avg_decimal.rs b/datafusion/spark/src/agg_funcs/avg_decimal.rs new file mode 100644 index 000000000000..e51f359decf1 --- /dev/null +++ b/datafusion/spark/src/agg_funcs/avg_decimal.rs @@ -0,0 +1,541 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use arrow::{array::BooleanBufferBuilder, buffer::NullBuffer, compute::sum}; +use arrow_array::{ + builder::PrimitiveBuilder, + cast::AsArray, + types::{Decimal128Type, Int64Type}, + Array, ArrayRef, Decimal128Array, Int64Array, PrimitiveArray, +}; +use arrow_schema::{DataType, Field}; +use datafusion::logical_expr::{Accumulator, EmitTo, GroupsAccumulator, Signature}; +use datafusion_common::{not_impl_err, Result, ScalarValue}; +use datafusion_physical_expr::expressions::format_state_name; +use std::{any::Any, sync::Arc}; + +use crate::utils::is_valid_decimal_precision; +use arrow_array::ArrowNativeTypeOp; +use arrow_data::decimal::{ + MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, +}; +use datafusion::logical_expr::Volatility::Immutable; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::avg_return_type; +use datafusion_expr::{AggregateUDFImpl, ReversedUDAF}; +use num::{integer::div_ceil, Integer}; +use DataType::*; + +/// AVG aggregate expression +#[derive(Debug, Clone)] +pub struct AvgDecimal { + signature: Signature, + sum_data_type: DataType, + result_data_type: DataType, +} + +impl AvgDecimal { + /// Create a new AVG aggregate function + pub fn new(result_type: DataType, sum_type: DataType) -> Self { + Self { + signature: Signature::user_defined(Immutable), + result_data_type: result_type, + sum_data_type: sum_type, + } + } +} + +impl AggregateUDFImpl for AvgDecimal { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + match (&self.sum_data_type, &self.result_data_type) { + ( + Decimal128(sum_precision, sum_scale), + Decimal128(target_precision, target_scale), + ) => Ok(Box::new(AvgDecimalAccumulator::new( + *sum_scale, + *sum_precision, + *target_precision, + *target_scale, + ))), + _ => not_impl_err!( + "AvgDecimalAccumulator for ({} --> {})", + self.sum_data_type, + self.result_data_type + ), + } + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + Ok(vec![ + Field::new( + format_state_name(self.name(), "sum"), + self.sum_data_type.clone(), + true, + ), + Field::new( + format_state_name(self.name(), "count"), + DataType::Int64, + true, + ), + ]) + } + + fn name(&self) -> &str { + "avg" + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + // instantiate specialized accumulator based for the type + match (&self.sum_data_type, &self.result_data_type) { + ( + Decimal128(sum_precision, sum_scale), + Decimal128(target_precision, target_scale), + ) => Ok(Box::new(AvgDecimalGroupsAccumulator::new( + &self.result_data_type, + &self.sum_data_type, + *target_precision, + *target_scale, + *sum_precision, + *sum_scale, + ))), + _ => not_impl_err!( + "AvgDecimalGroupsAccumulator for ({} --> {})", + self.sum_data_type, + self.result_data_type + ), + } + } + + fn default_value(&self, _data_type: &DataType) -> Result { + match &self.result_data_type { + Decimal128(target_precision, target_scale) => { + Ok(make_decimal128(None, *target_precision, *target_scale)) + } + _ => not_impl_err!( + "The result_data_type of AvgDecimal should be Decimal128 but got{}", + self.result_data_type + ), + } + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn 
return_type(&self, arg_types: &[DataType]) -> Result { + avg_return_type(self.name(), &arg_types[0]) + } +} + +/// An accumulator to compute the average for decimals +#[derive(Debug)] +struct AvgDecimalAccumulator { + sum: Option, + count: i64, + is_empty: bool, + is_not_null: bool, + sum_scale: i8, + sum_precision: u8, + target_precision: u8, + target_scale: i8, +} + +impl AvgDecimalAccumulator { + pub fn new( + sum_scale: i8, + sum_precision: u8, + target_precision: u8, + target_scale: i8, + ) -> Self { + Self { + sum: None, + count: 0, + is_empty: true, + is_not_null: true, + sum_scale, + sum_precision, + target_precision, + target_scale, + } + } + + fn update_single(&mut self, values: &Decimal128Array, idx: usize) { + let v = unsafe { values.value_unchecked(idx) }; + let (new_sum, is_overflow) = match self.sum { + Some(sum) => sum.overflowing_add(v), + None => (v, false), + }; + + if is_overflow || !is_valid_decimal_precision(new_sum, self.sum_precision) { + // Overflow: set buffer accumulator to null + self.is_not_null = false; + return; + } + + self.sum = Some(new_sum); + + if let Some(new_count) = self.count.checked_add(1) { + self.count = new_count; + } else { + self.is_not_null = false; + return; + } + + self.is_not_null = true; + } +} + +fn make_decimal128(value: Option, precision: u8, scale: i8) -> ScalarValue { + ScalarValue::Decimal128(value, precision, scale) +} + +impl Accumulator for AvgDecimalAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::Decimal128(self.sum, self.sum_precision, self.sum_scale), + ScalarValue::from(self.count), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if !self.is_empty && !self.is_not_null { + // This means there's a overflow in decimal, so we will just skip the rest + // of the computation + return Ok(()); + } + + let values = &values[0]; + let data = values.as_primitive::(); + + self.is_empty = self.is_empty && values.len() == values.null_count(); + + if values.null_count() == 0 { + for i in 0..data.len() { + self.update_single(data, i); + } + } else { + for i in 0..data.len() { + if data.is_null(i) { + continue; + } + self.update_single(data, i); + } + } + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + // counts are summed + self.count += sum(states[1].as_primitive::()).unwrap_or_default(); + + // sums are summed + if let Some(x) = sum(states[0].as_primitive::()) { + let v = self.sum.get_or_insert(0); + let (result, overflowed) = v.overflowing_add(x); + if overflowed { + // Set to None if overflow happens + self.sum = None; + } else { + *v = result; + } + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let scaler = 10_i128.pow(self.target_scale.saturating_sub(self.sum_scale) as u32); + let target_min = + MIN_DECIMAL_FOR_EACH_PRECISION[self.target_precision as usize - 1]; + let target_max = + MAX_DECIMAL_FOR_EACH_PRECISION[self.target_precision as usize - 1]; + + let result = self + .sum + .map(|v| avg(v, self.count as i128, target_min, target_max, scaler)); + + match result { + Some(value) => Ok(make_decimal128( + value, + self.target_precision, + self.target_scale, + )), + _ => Ok(make_decimal128( + None, + self.target_precision, + self.target_scale, + )), + } + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } +} + +#[derive(Debug)] +struct AvgDecimalGroupsAccumulator { + /// Tracks if the value is null + is_not_null: BooleanBufferBuilder, + + // Tracks if the value is empty + is_empty: BooleanBufferBuilder, + + /// The type of 
the avg return type + return_data_type: DataType, + target_precision: u8, + target_scale: i8, + + /// Count per group (use i64 to make Int64Array) + counts: Vec, + + /// Sums per group, stored as i128 + sums: Vec, + + /// The type of the sum + sum_data_type: DataType, + /// This is input_precision + 10 to be consistent with Spark + sum_precision: u8, + sum_scale: i8, +} + +impl AvgDecimalGroupsAccumulator { + pub fn new( + return_data_type: &DataType, + sum_data_type: &DataType, + target_precision: u8, + target_scale: i8, + sum_precision: u8, + sum_scale: i8, + ) -> Self { + Self { + is_not_null: BooleanBufferBuilder::new(0), + is_empty: BooleanBufferBuilder::new(0), + return_data_type: return_data_type.clone(), + target_precision, + target_scale, + sum_data_type: sum_data_type.clone(), + sum_precision, + sum_scale, + counts: vec![], + sums: vec![], + } + } + + fn is_overflow(&self, index: usize) -> bool { + !self.is_empty.get_bit(index) && !self.is_not_null.get_bit(index) + } + + fn update_single(&mut self, group_index: usize, value: i128) { + if self.is_overflow(group_index) { + // This means there's a overflow in decimal, so we will just skip the rest + // of the computation + return; + } + + self.is_empty.set_bit(group_index, false); + let (new_sum, is_overflow) = self.sums[group_index].overflowing_add(value); + self.counts[group_index] += 1; + + if is_overflow || !is_valid_decimal_precision(new_sum, self.sum_precision) { + // Overflow: set buffer accumulator to null + self.is_not_null.set_bit(group_index, false); + return; + } + + self.sums[group_index] = new_sum; + self.is_not_null.set_bit(group_index, true) + } +} + +fn ensure_bit_capacity(builder: &mut BooleanBufferBuilder, capacity: usize) { + if builder.len() < capacity { + let additional = capacity - builder.len(); + builder.append_n(additional, true); + } +} + +impl GroupsAccumulator for AvgDecimalGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + _opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 1, "single argument to update_batch"); + let values = values[0].as_primitive::(); + let data = values.values(); + + // increment counts, update sums + self.counts.resize(total_num_groups, 0); + self.sums.resize(total_num_groups, 0); + ensure_bit_capacity(&mut self.is_empty, total_num_groups); + ensure_bit_capacity(&mut self.is_not_null, total_num_groups); + + let iter = group_indices.iter().zip(data.iter()); + if values.null_count() == 0 { + for (&group_index, &value) in iter { + self.update_single(group_index, value); + } + } else { + for (idx, (&group_index, &value)) in iter.enumerate() { + if values.is_null(idx) { + continue; + } + self.update_single(group_index, value); + } + } + Ok(()) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + _opt_filter: Option<&arrow_array::BooleanArray>, + total_num_groups: usize, + ) -> Result<()> { + assert_eq!(values.len(), 2, "two arguments to merge_batch"); + // first batch is partial sums, second is counts + let partial_sums = values[0].as_primitive::(); + let partial_counts = values[1].as_primitive::(); + // update counts with partial counts + self.counts.resize(total_num_groups, 0); + let iter1 = group_indices.iter().zip(partial_counts.values().iter()); + for (&group_index, &partial_count) in iter1 { + self.counts[group_index] += partial_count; + } + + // update sums + self.sums.resize(total_num_groups, 0); + let iter2 = 
group_indices.iter().zip(partial_sums.values().iter()); + for (&group_index, &new_value) in iter2 { + let sum = &mut self.sums[group_index]; + *sum = sum.add_wrapping(new_value); + } + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> Result { + let counts = emit_to.take_needed(&mut self.counts); + let sums = emit_to.take_needed(&mut self.sums); + + let mut builder = PrimitiveBuilder::::with_capacity(sums.len()) + .with_data_type(self.return_data_type.clone()); + let iter = sums.into_iter().zip(counts); + + let scaler = 10_i128.pow(self.target_scale.saturating_sub(self.sum_scale) as u32); + let target_min = + MIN_DECIMAL_FOR_EACH_PRECISION[self.target_precision as usize - 1]; + let target_max = + MAX_DECIMAL_FOR_EACH_PRECISION[self.target_precision as usize - 1]; + + for (sum, count) in iter { + if count != 0 { + match avg(sum, count as i128, target_min, target_max, scaler) { + Some(value) => { + builder.append_value(value); + } + _ => { + builder.append_null(); + } + } + } else { + builder.append_null(); + } + } + let array: PrimitiveArray = builder.finish(); + + Ok(Arc::new(array)) + } + + // return arrays for sums and counts + fn state(&mut self, emit_to: EmitTo) -> Result> { + let nulls = self.is_not_null.finish(); + let nulls = Some(NullBuffer::new(nulls)); + + let counts = emit_to.take_needed(&mut self.counts); + let counts = Int64Array::new(counts.into(), nulls.clone()); + + let sums = emit_to.take_needed(&mut self.sums); + let sums = Decimal128Array::new(sums.into(), nulls) + .with_data_type(self.sum_data_type.clone()); + + Ok(vec![ + Arc::new(sums) as ArrayRef, + Arc::new(counts) as ArrayRef, + ]) + } + + fn size(&self) -> usize { + self.counts.capacity() * std::mem::size_of::() + + self.sums.capacity() * std::mem::size_of::() + } +} + +/// Returns the `sum`/`count` as a i128 Decimal128 with +/// target_scale and target_precision and return None if overflows. +/// +/// * sum: The total sum value stored as Decimal128 with sum_scale +/// * count: total count, stored as a i128 (*NOT* a Decimal128 value) +/// * target_min: The minimum output value possible to represent with the target precision +/// * target_max: The maximum output value possible to represent with the target precision +/// * scaler: scale factor for avg +#[inline(always)] +fn avg( + sum: i128, + count: i128, + target_min: i128, + target_max: i128, + scaler: i128, +) -> Option { + if let Some(value) = sum.checked_mul(scaler) { + // `sum / count` with ROUND_HALF_UP + let (div, rem) = value.div_rem(&count); + let half = div_ceil(count, 2); + let half_neg = half.neg_wrapping(); + let new_value = match value >= 0 { + true if rem >= half => div.add_wrapping(1), + false if rem <= half_neg => div.sub_wrapping(1), + _ => div, + }; + if new_value >= target_min && new_value <= target_max { + Some(new_value) + } else { + None + } + } else { + None + } +} diff --git a/datafusion/spark/src/agg_funcs/correlation.rs b/datafusion/spark/src/agg_funcs/correlation.rs new file mode 100644 index 000000000000..81b042cf6c2a --- /dev/null +++ b/datafusion/spark/src/agg_funcs/correlation.rs @@ -0,0 +1,261 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::compute::{and, filter, is_not_null}; + +use std::{any::Any, sync::Arc}; + +use crate::agg_funcs::covariance::CovarianceAccumulator; +use crate::agg_funcs::stddev::StddevAccumulator; +use arrow::{ + array::ArrayRef, + datatypes::{DataType, Field}, +}; +use datafusion::logical_expr::Accumulator; +use datafusion_common::{Result, ScalarValue}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::NUMERICS; +use datafusion_expr::{AggregateUDFImpl, Signature, Volatility}; +use datafusion_physical_expr::expressions::format_state_name; +use datafusion_physical_expr::expressions::StatsType; + +/// CORR aggregate expression +/// The implementation mostly is the same as the DataFusion's implementation. The reason +/// we have our own implementation is that DataFusion has UInt64 for state_field `count`, +/// while Spark has Double for count. Also we have added `null_on_divide_by_zero` +/// to be consistent with Spark's implementation. +#[derive(Debug)] +pub struct Correlation { + name: String, + signature: Signature, + null_on_divide_by_zero: bool, +} + +impl Correlation { + pub fn new( + name: impl Into, + data_type: DataType, + null_on_divide_by_zero: bool, + ) -> Self { + // the result of correlation just support FLOAT64 data type. + assert!(matches!(data_type, DataType::Float64)); + Self { + name: name.into(), + signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + null_on_divide_by_zero, + } + } +} + +impl AggregateUDFImpl for Correlation { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + fn default_value(&self, _data_type: &DataType) -> Result { + Ok(ScalarValue::Float64(None)) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(CorrelationAccumulator::try_new( + self.null_on_divide_by_zero, + )?)) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + Ok(vec![ + Field::new( + format_state_name(&self.name, "count"), + DataType::Float64, + true, + ), + Field::new( + format_state_name(&self.name, "mean1"), + DataType::Float64, + true, + ), + Field::new( + format_state_name(&self.name, "mean2"), + DataType::Float64, + true, + ), + Field::new( + format_state_name(&self.name, "algo_const"), + DataType::Float64, + true, + ), + Field::new( + format_state_name(&self.name, "m2_1"), + DataType::Float64, + true, + ), + Field::new( + format_state_name(&self.name, "m2_2"), + DataType::Float64, + true, + ), + ]) + } +} + +/// An accumulator to compute correlation +#[derive(Debug)] +pub struct CorrelationAccumulator { + covar: CovarianceAccumulator, + stddev1: StddevAccumulator, + stddev2: StddevAccumulator, + null_on_divide_by_zero: bool, +} + +impl CorrelationAccumulator { + /// Creates a new `CorrelationAccumulator` + pub fn try_new(null_on_divide_by_zero: bool) -> Result { + Ok(Self { + 
covar: CovarianceAccumulator::try_new( + StatsType::Population, + null_on_divide_by_zero, + )?, + stddev1: StddevAccumulator::try_new( + StatsType::Population, + null_on_divide_by_zero, + )?, + stddev2: StddevAccumulator::try_new( + StatsType::Population, + null_on_divide_by_zero, + )?, + null_on_divide_by_zero, + }) + } +} + +impl Accumulator for CorrelationAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::from(self.covar.get_count()), + ScalarValue::from(self.covar.get_mean1()), + ScalarValue::from(self.covar.get_mean2()), + ScalarValue::from(self.covar.get_algo_const()), + ScalarValue::from(self.stddev1.get_m2()), + ScalarValue::from(self.stddev2.get_m2()), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = if values[0].null_count() != 0 || values[1].null_count() != 0 { + let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?; + let values1 = filter(&values[0], &mask)?; + let values2 = filter(&values[1], &mask)?; + + vec![values1, values2] + } else { + values.to_vec() + }; + + if !values[0].is_empty() && !values[1].is_empty() { + self.covar.update_batch(&values)?; + self.stddev1.update_batch(&values[0..1])?; + self.stddev2.update_batch(&values[1..2])?; + } + + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values = if values[0].null_count() != 0 || values[1].null_count() != 0 { + let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?; + let values1 = filter(&values[0], &mask)?; + let values2 = filter(&values[1], &mask)?; + + vec![values1, values2] + } else { + values.to_vec() + }; + + self.covar.retract_batch(&values)?; + self.stddev1.retract_batch(&values[0..1])?; + self.stddev2.retract_batch(&values[1..2])?; + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let states_c = [ + Arc::clone(&states[0]), + Arc::clone(&states[1]), + Arc::clone(&states[2]), + Arc::clone(&states[3]), + ]; + let states_s1 = [ + Arc::clone(&states[0]), + Arc::clone(&states[1]), + Arc::clone(&states[4]), + ]; + let states_s2 = [ + Arc::clone(&states[0]), + Arc::clone(&states[2]), + Arc::clone(&states[5]), + ]; + + if states[0].len() > 0 && states[1].len() > 0 && states[2].len() > 0 { + self.covar.merge_batch(&states_c)?; + self.stddev1.merge_batch(&states_s1)?; + self.stddev2.merge_batch(&states_s2)?; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let covar = self.covar.evaluate()?; + let stddev1 = self.stddev1.evaluate()?; + let stddev2 = self.stddev2.evaluate()?; + + match (covar, stddev1, stddev2) { + ( + ScalarValue::Float64(Some(c)), + ScalarValue::Float64(Some(s1)), + ScalarValue::Float64(Some(s2)), + ) if s1 != 0.0 && s2 != 0.0 => Ok(ScalarValue::Float64(Some(c / (s1 * s2)))), + _ if self.null_on_divide_by_zero => Ok(ScalarValue::Float64(None)), + _ => { + if self.covar.get_count() == 1.0 { + return Ok(ScalarValue::Float64(Some(f64::NAN))); + } + Ok(ScalarValue::Float64(None)) + } + } + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) - std::mem::size_of_val(&self.covar) + + self.covar.size() + - std::mem::size_of_val(&self.stddev1) + + self.stddev1.size() + - std::mem::size_of_val(&self.stddev2) + + self.stddev2.size() + } +} diff --git a/datafusion/spark/src/agg_funcs/covariance.rs b/datafusion/spark/src/agg_funcs/covariance.rs new file mode 100644 index 000000000000..8ef6afe54f91 --- /dev/null +++ b/datafusion/spark/src/agg_funcs/covariance.rs @@ -0,0 +1,307 @@ +// Licensed to the Apache Software Foundation (ASF) 
under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; + +use arrow::{ + array::{ArrayRef, Float64Array}, + compute::cast, + datatypes::{DataType, Field}, +}; +use datafusion::logical_expr::Accumulator; +use datafusion_common::{ + downcast_value, unwrap_or_internal_err, DataFusionError, Result, ScalarValue, +}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::type_coercion::aggregates::NUMERICS; +use datafusion_expr::{AggregateUDFImpl, Signature, Volatility}; +use datafusion_physical_expr::expressions::format_state_name; +use datafusion_physical_expr::expressions::StatsType; + +/// COVAR_SAMP and COVAR_POP aggregate expression +/// The implementation mostly is the same as the DataFusion's implementation. The reason +/// we have our own implementation is that DataFusion has UInt64 for state_field count, +/// while Spark has Double for count. +#[derive(Debug, Clone)] +pub struct Covariance { + name: String, + signature: Signature, + stats_type: StatsType, + null_on_divide_by_zero: bool, +} + +impl Covariance { + /// Create a new COVAR aggregate function + pub fn new( + name: impl Into, + data_type: DataType, + stats_type: StatsType, + null_on_divide_by_zero: bool, + ) -> Self { + // the result of covariance just support FLOAT64 data type. 
+ assert!(matches!(data_type, DataType::Float64)); + Self { + name: name.into(), + signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable), + stats_type, + null_on_divide_by_zero, + } + } +} + +impl AggregateUDFImpl for Covariance { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + fn default_value(&self, _data_type: &DataType) -> Result { + Ok(ScalarValue::Float64(None)) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(CovarianceAccumulator::try_new( + self.stats_type, + self.null_on_divide_by_zero, + )?)) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + Ok(vec![ + Field::new( + format_state_name(&self.name, "count"), + DataType::Float64, + true, + ), + Field::new( + format_state_name(&self.name, "mean1"), + DataType::Float64, + true, + ), + Field::new( + format_state_name(&self.name, "mean2"), + DataType::Float64, + true, + ), + Field::new( + format_state_name(&self.name, "algo_const"), + DataType::Float64, + true, + ), + ]) + } +} + +/// An accumulator to compute covariance +#[derive(Debug)] +pub struct CovarianceAccumulator { + algo_const: f64, + mean1: f64, + mean2: f64, + count: f64, + stats_type: StatsType, + null_on_divide_by_zero: bool, +} + +impl CovarianceAccumulator { + /// Creates a new `CovarianceAccumulator` + pub fn try_new(s_type: StatsType, null_on_divide_by_zero: bool) -> Result { + Ok(Self { + algo_const: 0_f64, + mean1: 0_f64, + mean2: 0_f64, + count: 0_f64, + stats_type: s_type, + null_on_divide_by_zero, + }) + } + + pub fn get_count(&self) -> f64 { + self.count + } + + pub fn get_mean1(&self) -> f64 { + self.mean1 + } + + pub fn get_mean2(&self) -> f64 { + self.mean2 + } + + pub fn get_algo_const(&self) -> f64 { + self.algo_const + } +} + +impl Accumulator for CovarianceAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::from(self.count), + ScalarValue::from(self.mean1), + ScalarValue::from(self.mean2), + ScalarValue::from(self.algo_const), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values1 = &cast(&values[0], &DataType::Float64)?; + let values2 = &cast(&values[1], &DataType::Float64)?; + + let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten(); + let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten(); + + for i in 0..values1.len() { + let value1 = if values1.is_valid(i) { + arr1.next() + } else { + None + }; + let value2 = if values2.is_valid(i) { + arr2.next() + } else { + None + }; + + if value1.is_none() || value2.is_none() { + continue; + } + + let value1 = unwrap_or_internal_err!(value1); + let value2 = unwrap_or_internal_err!(value2); + let new_count = self.count + 1.0; + let delta1 = value1 - self.mean1; + let new_mean1 = delta1 / new_count + self.mean1; + let delta2 = value2 - self.mean2; + let new_mean2 = delta2 / new_count + self.mean2; + let new_c = delta1 * (value2 - new_mean2) + self.algo_const; + + self.count += 1.0; + self.mean1 = new_mean1; + self.mean2 = new_mean2; + self.algo_const = new_c; + } + + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let values1 = &cast(&values[0], &DataType::Float64)?; + let values2 = &cast(&values[1], &DataType::Float64)?; + let mut arr1 = 
downcast_value!(values1, Float64Array).iter().flatten(); + let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten(); + + for i in 0..values1.len() { + let value1 = if values1.is_valid(i) { + arr1.next() + } else { + None + }; + let value2 = if values2.is_valid(i) { + arr2.next() + } else { + None + }; + + if value1.is_none() || value2.is_none() { + continue; + } + + let value1 = unwrap_or_internal_err!(value1); + let value2 = unwrap_or_internal_err!(value2); + + let new_count = self.count - 1.0; + let delta1 = self.mean1 - value1; + let new_mean1 = delta1 / new_count + self.mean1; + let delta2 = self.mean2 - value2; + let new_mean2 = delta2 / new_count + self.mean2; + let new_c = self.algo_const - delta1 * (new_mean2 - value2); + + self.count -= 1.0; + self.mean1 = new_mean1; + self.mean2 = new_mean2; + self.algo_const = new_c; + } + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let counts = downcast_value!(states[0], Float64Array); + let means1 = downcast_value!(states[1], Float64Array); + let means2 = downcast_value!(states[2], Float64Array); + let cs = downcast_value!(states[3], Float64Array); + + for i in 0..counts.len() { + let c = counts.value(i); + if c == 0.0 { + continue; + } + let new_count = self.count + c; + let new_mean1 = + self.mean1 * self.count / new_count + means1.value(i) * c / new_count; + let new_mean2 = + self.mean2 * self.count / new_count + means2.value(i) * c / new_count; + let delta1 = self.mean1 - means1.value(i); + let delta2 = self.mean2 - means2.value(i); + let new_c = self.algo_const + + cs.value(i) + + delta1 * delta2 * self.count * c / new_count; + + self.count = new_count; + self.mean1 = new_mean1; + self.mean2 = new_mean2; + self.algo_const = new_c; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + if self.count == 0.0 { + return Ok(ScalarValue::Float64(None)); + } + + let count = match self.stats_type { + StatsType::Population => self.count, + StatsType::Sample if self.count > 1.0 => self.count - 1.0, + StatsType::Sample => { + // self.count == 1.0 + return if self.null_on_divide_by_zero { + Ok(ScalarValue::Float64(None)) + } else { + Ok(ScalarValue::Float64(Some(f64::NAN))) + }; + } + }; + + Ok(ScalarValue::Float64(Some(self.algo_const / count))) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } +} diff --git a/datafusion/spark/src/agg_funcs/mod.rs b/datafusion/spark/src/agg_funcs/mod.rs new file mode 100644 index 000000000000..252da788900d --- /dev/null +++ b/datafusion/spark/src/agg_funcs/mod.rs @@ -0,0 +1,32 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
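+
+// The aggregates re-exported below follow Spark's semantics rather than
+// DataFusion's built-ins (Float64 `count` state, `null_on_divide_by_zero`,
+// decimal overflow handling). A minimal registration sketch, assuming a
+// `SessionContext` named `ctx` (hypothetical) and sample statistics:
+//
+//     use arrow_schema::DataType;
+//     use datafusion_expr::AggregateUDF;
+//     use datafusion_physical_expr::expressions::StatsType;
+//
+//     let stddev = AggregateUDF::new_from_impl(Stddev::new(
+//         "stddev_samp",
+//         DataType::Float64,
+//         StatsType::Sample,
+//         true, // null_on_divide_by_zero, matching Spark's default behaviour
+//     ));
+//     ctx.register_udaf(stddev);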
+ +mod avg; +mod avg_decimal; +mod correlation; +mod covariance; +mod stddev; +mod sum_decimal; +mod variance; + +pub use avg::Avg; +pub use avg_decimal::AvgDecimal; +pub use correlation::Correlation; +pub use covariance::Covariance; +pub use stddev::Stddev; +pub use sum_decimal::SumDecimal; +pub use variance::Variance; diff --git a/datafusion/spark/src/agg_funcs/stddev.rs b/datafusion/spark/src/agg_funcs/stddev.rs new file mode 100644 index 000000000000..3694ec302f87 --- /dev/null +++ b/datafusion/spark/src/agg_funcs/stddev.rs @@ -0,0 +1,181 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, sync::Arc}; + +use crate::agg_funcs::variance::VarianceAccumulator; +use arrow::{ + array::ArrayRef, + datatypes::{DataType, Field}, +}; +use datafusion::logical_expr::Accumulator; +use datafusion_common::types::NativeType; +use datafusion_common::{internal_err, Result, ScalarValue}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::{AggregateUDFImpl, Signature, Volatility}; +use datafusion_physical_expr::expressions::format_state_name; +use datafusion_physical_expr::expressions::StatsType; + +/// STDDEV and STDDEV_SAMP (standard deviation) aggregate expression +/// The implementation mostly is the same as the DataFusion's implementation. The reason +/// we have our own implementation is that DataFusion has UInt64 for state_field `count`, +/// while Spark has Double for count. Also we have added `null_on_divide_by_zero` +/// to be consistent with Spark's implementation. +#[derive(Debug)] +pub struct Stddev { + name: String, + signature: Signature, + stats_type: StatsType, + null_on_divide_by_zero: bool, +} + +impl Stddev { + /// Create a new STDDEV aggregate function + pub fn new( + name: impl Into, + data_type: DataType, + stats_type: StatsType, + null_on_divide_by_zero: bool, + ) -> Self { + // the result of stddev just support FLOAT64. 
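+        // The single input argument is coerced to Float64 by the coercible
+        // signature below before it reaches the accumulator.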
+ assert!(matches!(data_type, DataType::Float64)); + Self { + name: name.into(), + signature: Signature::coercible( + vec![ + datafusion_expr_common::signature::TypeSignatureClass::Native( + Arc::new(NativeType::Float64), + ), + ], + Volatility::Immutable, + ), + stats_type, + null_on_divide_by_zero, + } + } +} + +impl AggregateUDFImpl for Stddev { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(StddevAccumulator::try_new( + self.stats_type, + self.null_on_divide_by_zero, + )?)) + } + + fn create_sliding_accumulator( + &self, + _acc_args: AccumulatorArgs, + ) -> Result> { + Ok(Box::new(StddevAccumulator::try_new( + self.stats_type, + self.null_on_divide_by_zero, + )?)) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + Ok(vec![ + Field::new( + format_state_name(&self.name, "count"), + DataType::Float64, + true, + ), + Field::new( + format_state_name(&self.name, "mean"), + DataType::Float64, + true, + ), + Field::new(format_state_name(&self.name, "m2"), DataType::Float64, true), + ]) + } + + fn default_value(&self, _data_type: &DataType) -> Result { + Ok(ScalarValue::Float64(None)) + } +} + +/// An accumulator to compute the standard deviation +#[derive(Debug)] +pub struct StddevAccumulator { + variance: VarianceAccumulator, +} + +impl StddevAccumulator { + /// Creates a new `StddevAccumulator` + pub fn try_new(s_type: StatsType, null_on_divide_by_zero: bool) -> Result { + Ok(Self { + variance: VarianceAccumulator::try_new(s_type, null_on_divide_by_zero)?, + }) + } + + pub fn get_m2(&self) -> f64 { + self.variance.get_m2() + } +} + +impl Accumulator for StddevAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::from(self.variance.get_count()), + ScalarValue::from(self.variance.get_mean()), + ScalarValue::from(self.variance.get_m2()), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.variance.update_batch(values) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + self.variance.retract_batch(values) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + self.variance.merge_batch(states) + } + + fn evaluate(&mut self) -> Result { + let variance = self.variance.evaluate()?; + match variance { + ScalarValue::Float64(Some(e)) => Ok(ScalarValue::Float64(Some(e.sqrt()))), + ScalarValue::Float64(None) => Ok(ScalarValue::Float64(None)), + _ => internal_err!("Variance should be f64"), + } + } + + fn size(&self) -> usize { + std::mem::align_of_val(self) - std::mem::align_of_val(&self.variance) + + self.variance.size() + } +} diff --git a/datafusion/spark/src/agg_funcs/sum_decimal.rs b/datafusion/spark/src/agg_funcs/sum_decimal.rs new file mode 100644 index 000000000000..f17f80f77d70 --- /dev/null +++ b/datafusion/spark/src/agg_funcs/sum_decimal.rs @@ -0,0 +1,557 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::utils::{is_valid_decimal_precision, unlikely}; +use arrow::{ + array::BooleanBufferBuilder, + buffer::{BooleanBuffer, NullBuffer}, +}; +use arrow_array::{ + cast::AsArray, types::Decimal128Type, Array, ArrayRef, BooleanArray, Decimal128Array, +}; +use arrow_schema::{DataType, Field}; +use datafusion::logical_expr::{Accumulator, EmitTo, GroupsAccumulator}; +use datafusion_common::{DataFusionError, Result as DFResult, ScalarValue}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::Volatility::Immutable; +use datafusion_expr::{AggregateUDFImpl, ReversedUDAF, Signature}; +use std::{any::Any, ops::BitAnd, sync::Arc}; + +#[derive(Debug)] +pub struct SumDecimal { + /// Aggregate function signature + signature: Signature, + /// The data type of the SUM result. This will always be a decimal type + /// with the same precision and scale as specified in this struct + result_type: DataType, + /// Decimal precision + precision: u8, + /// Decimal scale + scale: i8, +} + +impl SumDecimal { + pub fn try_new(data_type: DataType) -> DFResult { + // The `data_type` is the SUM result type passed from Spark side + let (precision, scale) = match data_type { + DataType::Decimal128(p, s) => (p, s), + _ => { + return Err(DataFusionError::Internal( + "Invalid data type for SumDecimal".into(), + )) + } + }; + Ok(Self { + signature: Signature::user_defined(Immutable), + result_type: data_type, + precision, + scale, + }) + } +} + +impl AggregateUDFImpl for SumDecimal { + fn as_any(&self) -> &dyn Any { + self + } + + fn accumulator(&self, _args: AccumulatorArgs) -> DFResult> { + Ok(Box::new(SumDecimalAccumulator::new( + self.precision, + self.scale, + ))) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> DFResult> { + let fields = vec![ + Field::new(self.name(), self.result_type.clone(), self.is_nullable()), + Field::new("is_empty", DataType::Boolean, false), + ]; + Ok(fields) + } + + fn name(&self) -> &str { + "sum" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> DFResult { + Ok(self.result_type.clone()) + } + + fn groups_accumulator_supported(&self, _args: AccumulatorArgs) -> bool { + true + } + + fn create_groups_accumulator( + &self, + _args: AccumulatorArgs, + ) -> DFResult> { + Ok(Box::new(SumDecimalGroupsAccumulator::new( + self.result_type.clone(), + self.precision, + ))) + } + + fn default_value(&self, _data_type: &DataType) -> DFResult { + ScalarValue::new_primitive::( + None, + &DataType::Decimal128(self.precision, self.scale), + ) + } + + fn reverse_expr(&self) -> ReversedUDAF { + ReversedUDAF::Identical + } + + fn is_nullable(&self) -> bool { + // SumDecimal is always nullable because overflows can cause null values + true + } +} + +#[derive(Debug)] +struct SumDecimalAccumulator { + sum: i128, + is_empty: bool, + is_not_null: bool, + + precision: u8, + scale: i8, +} + +impl SumDecimalAccumulator { + fn new(precision: u8, scale: i8) -> Self { + Self { + sum: 0, + is_empty: true, + is_not_null: true, + precision, + scale, + } + } + + fn update_single(&mut self, 
values: &Decimal128Array, idx: usize) { + let v = unsafe { values.value_unchecked(idx) }; + let (new_sum, is_overflow) = self.sum.overflowing_add(v); + + if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) { + // Overflow: set buffer accumulator to null + self.is_not_null = false; + return; + } + + self.sum = new_sum; + self.is_not_null = true; + } +} + +impl Accumulator for SumDecimalAccumulator { + fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> { + assert_eq!( + values.len(), + 1, + "Expect only one element in 'values' but found {}", + values.len() + ); + + if !self.is_empty && !self.is_not_null { + // This means there's a overflow in decimal, so we will just skip the rest + // of the computation + return Ok(()); + } + + let values = &values[0]; + let data = values.as_primitive::(); + + self.is_empty = self.is_empty && values.len() == values.null_count(); + + if values.null_count() == 0 { + for i in 0..data.len() { + self.update_single(data, i); + } + } else { + for i in 0..data.len() { + if data.is_null(i) { + continue; + } + self.update_single(data, i); + } + } + + Ok(()) + } + + fn evaluate(&mut self) -> DFResult { + // For each group: + // 1. if `is_empty` is true, it means either there is no value or all values for the group + // are null, in this case we'll return null + // 2. if `is_empty` is false, but `null_state` is true, it means there's an overflow. In + // non-ANSI mode Spark returns null. + if self.is_empty || !self.is_not_null { + ScalarValue::new_primitive::( + None, + &DataType::Decimal128(self.precision, self.scale), + ) + } else { + ScalarValue::try_new_decimal128(self.sum, self.precision, self.scale) + } + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } + + fn state(&mut self) -> DFResult> { + let sum = if self.is_not_null { + ScalarValue::try_new_decimal128(self.sum, self.precision, self.scale)? + } else { + ScalarValue::new_primitive::( + None, + &DataType::Decimal128(self.precision, self.scale), + )? + }; + Ok(vec![sum, ScalarValue::from(self.is_empty)]) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> { + assert_eq!( + states.len(), + 2, + "Expect two element in 'states' but found {}", + states.len() + ); + assert_eq!(states[0].len(), 1); + assert_eq!(states[1].len(), 1); + + let that_sum = states[0].as_primitive::(); + let that_is_empty = states[1].as_any().downcast_ref::().unwrap(); + + let this_overflow = !self.is_empty && !self.is_not_null; + let that_overflow = !that_is_empty.value(0) && that_sum.is_null(0); + + self.is_not_null = !this_overflow && !that_overflow; + self.is_empty = self.is_empty && that_is_empty.value(0); + + if self.is_not_null { + self.sum += that_sum.value(0); + } + + Ok(()) + } +} + +struct SumDecimalGroupsAccumulator { + // Whether aggregate buffer for a particular group is null. True indicates it is not null. 
+ is_not_null: BooleanBufferBuilder, + is_empty: BooleanBufferBuilder, + sum: Vec, + result_type: DataType, + precision: u8, +} + +impl SumDecimalGroupsAccumulator { + fn new(result_type: DataType, precision: u8) -> Self { + Self { + is_not_null: BooleanBufferBuilder::new(0), + is_empty: BooleanBufferBuilder::new(0), + sum: Vec::new(), + result_type, + precision, + } + } + + fn is_overflow(&self, index: usize) -> bool { + !self.is_empty.get_bit(index) && !self.is_not_null.get_bit(index) + } + + fn update_single(&mut self, group_index: usize, value: i128) { + if unlikely(self.is_overflow(group_index)) { + // This means there's a overflow in decimal, so we will just skip the rest + // of the computation + return; + } + + self.is_empty.set_bit(group_index, false); + let (new_sum, is_overflow) = self.sum[group_index].overflowing_add(value); + + if is_overflow || !is_valid_decimal_precision(new_sum, self.precision) { + // Overflow: set buffer accumulator to null + self.is_not_null.set_bit(group_index, false); + return; + } + + self.sum[group_index] = new_sum; + self.is_not_null.set_bit(group_index, true) + } +} + +fn ensure_bit_capacity(builder: &mut BooleanBufferBuilder, capacity: usize) { + if builder.len() < capacity { + let additional = capacity - builder.len(); + builder.append_n(additional, true); + } +} + +/// Build a boolean buffer from the state and reset the state, based on the emit_to +/// strategy. +fn build_bool_state(state: &mut BooleanBufferBuilder, emit_to: &EmitTo) -> BooleanBuffer { + let bool_state: BooleanBuffer = state.finish(); + + match emit_to { + EmitTo::All => bool_state, + EmitTo::First(n) => { + // split off the first N values in bool_state + let first_n_bools: BooleanBuffer = bool_state.iter().take(*n).collect(); + // reset the existing seen buffer + for seen in bool_state.iter().skip(*n) { + state.append(seen); + } + first_n_bools + } + } +} + +impl GroupsAccumulator for SumDecimalGroupsAccumulator { + fn update_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> DFResult<()> { + assert!(opt_filter.is_none(), "opt_filter is not supported yet"); + assert_eq!(values.len(), 1); + let values = values[0].as_primitive::(); + let data = values.values(); + + // Update size for the accumulate states + self.sum.resize(total_num_groups, 0); + ensure_bit_capacity(&mut self.is_empty, total_num_groups); + ensure_bit_capacity(&mut self.is_not_null, total_num_groups); + + let iter = group_indices.iter().zip(data.iter()); + if values.null_count() == 0 { + for (&group_index, &value) in iter { + self.update_single(group_index, value); + } + } else { + for (idx, (&group_index, &value)) in iter.enumerate() { + if values.is_null(idx) { + continue; + } + self.update_single(group_index, value); + } + } + + Ok(()) + } + + fn evaluate(&mut self, emit_to: EmitTo) -> DFResult { + // For each group: + // 1. if `is_empty` is true, it means either there is no value or all values for the group + // are null, in this case we'll return null + // 2. if `is_empty` is false, but `null_state` is true, it means there's an overflow. In + // non-ANSI mode Spark returns null. 
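+        // A group's result is valid only when it saw at least one non-null input
+        // (`is_empty == false`) and did not overflow (`is_not_null == true`), so
+        // the validity buffer below is computed as `!is_empty AND is_not_null`.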
+ let nulls = build_bool_state(&mut self.is_not_null, &emit_to); + let is_empty = build_bool_state(&mut self.is_empty, &emit_to); + let x = (!&is_empty).bitand(&nulls); + + let result = emit_to.take_needed(&mut self.sum); + let result = Decimal128Array::new(result.into(), Some(NullBuffer::new(x))) + .with_data_type(self.result_type.clone()); + + Ok(Arc::new(result)) + } + + fn state(&mut self, emit_to: EmitTo) -> DFResult> { + let nulls = build_bool_state(&mut self.is_not_null, &emit_to); + let nulls = Some(NullBuffer::new(nulls)); + + let sum = emit_to.take_needed(&mut self.sum); + let sum = Decimal128Array::new(sum.into(), nulls.clone()) + .with_data_type(self.result_type.clone()); + + let is_empty = build_bool_state(&mut self.is_empty, &emit_to); + let is_empty = BooleanArray::new(is_empty, None); + + Ok(vec![ + Arc::new(sum) as ArrayRef, + Arc::new(is_empty) as ArrayRef, + ]) + } + + fn merge_batch( + &mut self, + values: &[ArrayRef], + group_indices: &[usize], + opt_filter: Option<&BooleanArray>, + total_num_groups: usize, + ) -> DFResult<()> { + assert_eq!( + values.len(), + 2, + "Expected two arrays: 'sum' and 'is_empty', but found {}", + values.len() + ); + assert!(opt_filter.is_none(), "opt_filter is not supported yet"); + + // Make sure we have enough capacity for the additional groups + self.sum.resize(total_num_groups, 0); + ensure_bit_capacity(&mut self.is_empty, total_num_groups); + ensure_bit_capacity(&mut self.is_not_null, total_num_groups); + + let that_sum = &values[0]; + let that_sum = that_sum.as_primitive::(); + let that_is_empty = &values[1]; + let that_is_empty = that_is_empty + .as_any() + .downcast_ref::() + .unwrap(); + + group_indices + .iter() + .enumerate() + .for_each(|(idx, &group_index)| unsafe { + let this_overflow = self.is_overflow(group_index); + let that_is_empty = that_is_empty.value_unchecked(idx); + let that_overflow = !that_is_empty && that_sum.is_null(idx); + let is_overflow = this_overflow || that_overflow; + + // This part follows the logic in Spark: + // `org.apache.spark.sql.catalyst.expressions.aggregate.Sum` + self.is_not_null.set_bit(group_index, !is_overflow); + self.is_empty.set_bit( + group_index, + self.is_empty.get_bit(group_index) && that_is_empty, + ); + if !is_overflow { + // .. otherwise, the sum value for this particular index must not be null, + // and thus we merge both values and update this sum. 
+ self.sum[group_index] += that_sum.value_unchecked(idx); + } + }); + + Ok(()) + } + + fn size(&self) -> usize { + self.sum.capacity() * std::mem::size_of::() + + self.is_empty.capacity() / 8 + + self.is_not_null.capacity() / 8 + } +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow::datatypes::*; + use arrow_array::builder::{Decimal128Builder, StringBuilder}; + use arrow_array::RecordBatch; + use datafusion::execution::TaskContext; + use datafusion::physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, + }; + use datafusion::physical_plan::memory::MemoryExec; + use datafusion::physical_plan::ExecutionPlan; + use datafusion_common::Result; + use datafusion_expr::AggregateUDF; + use datafusion_physical_expr::aggregate::AggregateExprBuilder; + use datafusion_physical_expr::expressions::Column; + use datafusion_physical_expr::PhysicalExpr; + use futures::StreamExt; + + #[test] + fn invalid_data_type() { + assert!(SumDecimal::try_new(DataType::Int32).is_err()); + } + + #[tokio::test] + async fn sum_no_overflow() -> Result<()> { + let num_rows = 8192; + let batch = create_record_batch(num_rows); + let mut batches = Vec::new(); + for _ in 0..10 { + batches.push(batch.clone()); + } + let partitions = &[batches]; + let c0: Arc = Arc::new(Column::new("c0", 0)); + let c1: Arc = Arc::new(Column::new("c1", 1)); + + let data_type = DataType::Decimal128(8, 2); + let schema = Arc::clone(&partitions[0][0].schema()); + let scan: Arc = + Arc::new(MemoryExec::try_new(partitions, Arc::clone(&schema), None).unwrap()); + + let aggregate_udf = Arc::new(AggregateUDF::new_from_impl(SumDecimal::try_new( + data_type.clone(), + )?)); + + let aggr_expr = AggregateExprBuilder::new(aggregate_udf, vec![c1]) + .schema(Arc::clone(&schema)) + .alias("sum") + .with_ignore_nulls(false) + .with_distinct(false) + .build()?; + + let aggregate = Arc::new(AggregateExec::try_new( + AggregateMode::Partial, + PhysicalGroupBy::new_single(vec![(c0, "c0".to_string())]), + vec![aggr_expr.into()], + vec![None], // no filter expressions + scan, + Arc::clone(&schema), + )?); + + let mut stream = aggregate + .execute(0, Arc::new(TaskContext::default())) + .unwrap(); + while let Some(batch) = stream.next().await { + let _batch = batch?; + } + + Ok(()) + } + + fn create_record_batch(num_rows: usize) -> RecordBatch { + let mut decimal_builder = Decimal128Builder::with_capacity(num_rows); + let mut string_builder = StringBuilder::with_capacity(num_rows, num_rows * 32); + for i in 0..num_rows { + decimal_builder.append_value(i as i128); + string_builder.append_value(format!("this is string #{}", i % 1024)); + } + let decimal_array = Arc::new(decimal_builder.finish()); + let string_array = Arc::new(string_builder.finish()); + + let mut fields = vec![]; + let mut columns: Vec = vec![]; + + // string column + fields.push(Field::new("c0", DataType::Utf8, false)); + columns.push(string_array); + + // decimal column + fields.push(Field::new("c1", DataType::Decimal128(38, 10), false)); + columns.push(decimal_array); + + let schema = Schema::new(fields); + RecordBatch::try_new(Arc::new(schema), columns).unwrap() + } +} diff --git a/datafusion/spark/src/agg_funcs/variance.rs b/datafusion/spark/src/agg_funcs/variance.rs new file mode 100644 index 000000000000..99332f5571a1 --- /dev/null +++ b/datafusion/spark/src/agg_funcs/variance.rs @@ -0,0 +1,252 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::any::Any; + +use arrow::{ + array::{ArrayRef, Float64Array}, + datatypes::{DataType, Field}, +}; +use datafusion::logical_expr::Accumulator; +use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue}; +use datafusion_expr::function::{AccumulatorArgs, StateFieldsArgs}; +use datafusion_expr::Volatility::Immutable; +use datafusion_expr::{AggregateUDFImpl, Signature}; +use datafusion_physical_expr::expressions::format_state_name; +use datafusion_physical_expr::expressions::StatsType; + +/// VAR_SAMP and VAR_POP aggregate expression +/// The implementation mostly is the same as the DataFusion's implementation. The reason +/// we have our own implementation is that DataFusion has UInt64 for state_field `count`, +/// while Spark has Double for count. Also we have added `null_on_divide_by_zero` +/// to be consistent with Spark's implementation. +#[derive(Debug)] +pub struct Variance { + name: String, + signature: Signature, + stats_type: StatsType, + null_on_divide_by_zero: bool, +} + +impl Variance { + /// Create a new VARIANCE aggregate function + pub fn new( + name: impl Into, + data_type: DataType, + stats_type: StatsType, + null_on_divide_by_zero: bool, + ) -> Self { + // the result of variance just support FLOAT64 data type. 
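+        // `data_type` is the declared result type; Spark reports DoubleType for
+        // var_samp/var_pop, so only Float64 is accepted here.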
+ assert!(matches!(data_type, DataType::Float64)); + Self { + name: name.into(), + signature: Signature::numeric(1, Immutable), + stats_type, + null_on_divide_by_zero, + } + } +} + +impl AggregateUDFImpl for Variance { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(DataType::Float64) + } + + fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result> { + Ok(Box::new(VarianceAccumulator::try_new( + self.stats_type, + self.null_on_divide_by_zero, + )?)) + } + + fn create_sliding_accumulator( + &self, + _args: AccumulatorArgs, + ) -> Result> { + Ok(Box::new(VarianceAccumulator::try_new( + self.stats_type, + self.null_on_divide_by_zero, + )?)) + } + + fn state_fields(&self, _args: StateFieldsArgs) -> Result> { + Ok(vec![ + Field::new( + format_state_name(&self.name, "count"), + DataType::Float64, + true, + ), + Field::new( + format_state_name(&self.name, "mean"), + DataType::Float64, + true, + ), + Field::new(format_state_name(&self.name, "m2"), DataType::Float64, true), + ]) + } + + fn default_value(&self, _data_type: &DataType) -> Result { + Ok(ScalarValue::Float64(None)) + } +} + +/// An accumulator to compute variance +#[derive(Debug)] +pub struct VarianceAccumulator { + m2: f64, + mean: f64, + count: f64, + stats_type: StatsType, + null_on_divide_by_zero: bool, +} + +impl VarianceAccumulator { + /// Creates a new `VarianceAccumulator` + pub fn try_new(s_type: StatsType, null_on_divide_by_zero: bool) -> Result { + Ok(Self { + m2: 0_f64, + mean: 0_f64, + count: 0_f64, + stats_type: s_type, + null_on_divide_by_zero, + }) + } + + pub fn get_count(&self) -> f64 { + self.count + } + + pub fn get_mean(&self) -> f64 { + self.mean + } + + pub fn get_m2(&self) -> f64 { + self.m2 + } +} + +impl Accumulator for VarianceAccumulator { + fn state(&mut self) -> Result> { + Ok(vec![ + ScalarValue::from(self.count), + ScalarValue::from(self.mean), + ScalarValue::from(self.m2), + ]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let arr = downcast_value!(&values[0], Float64Array).iter().flatten(); + + for value in arr { + let new_count = self.count + 1.0; + let delta1 = value - self.mean; + let new_mean = delta1 / new_count + self.mean; + let delta2 = value - new_mean; + let new_m2 = self.m2 + delta1 * delta2; + + self.count += 1.0; + self.mean = new_mean; + self.m2 = new_m2; + } + + Ok(()) + } + + fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + let arr = downcast_value!(&values[0], Float64Array).iter().flatten(); + + for value in arr { + let new_count = self.count - 1.0; + let delta1 = self.mean - value; + let new_mean = delta1 / new_count + self.mean; + let delta2 = new_mean - value; + let new_m2 = self.m2 - delta1 * delta2; + + self.count -= 1.0; + self.mean = new_mean; + self.m2 = new_m2; + } + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + let counts = downcast_value!(states[0], Float64Array); + let means = downcast_value!(states[1], Float64Array); + let m2s = downcast_value!(states[2], Float64Array); + + for i in 0..counts.len() { + let c = counts.value(i); + if c == 0_f64 { + continue; + } + let new_count = self.count + c; + let new_mean = + self.mean * self.count / new_count + means.value(i) * c / new_count; + let delta = self.mean - means.value(i); + let new_m2 = + self.m2 + 
m2s.value(i) + delta * delta * self.count * c / new_count; + + self.count = new_count; + self.mean = new_mean; + self.m2 = new_m2; + } + Ok(()) + } + + fn evaluate(&mut self) -> Result { + let count = match self.stats_type { + StatsType::Population => self.count, + StatsType::Sample => { + if self.count > 0.0 { + self.count - 1.0 + } else { + self.count + } + } + }; + + Ok(ScalarValue::Float64(match self.count { + count if count == 0.0 => None, + count if count == 1.0 && StatsType::Sample == self.stats_type => { + if self.null_on_divide_by_zero { + None + } else { + Some(f64::NAN) + } + } + _ => Some(self.m2 / count), + })) + } + + fn size(&self) -> usize { + std::mem::size_of_val(self) + } +} diff --git a/datafusion/spark/src/array_funcs/array_insert.rs b/datafusion/spark/src/array_funcs/array_insert.rs new file mode 100644 index 000000000000..f37e470f5bc0 --- /dev/null +++ b/datafusion/spark/src/array_funcs/array_insert.rs @@ -0,0 +1,451 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::{ + array::{as_primitive_array, Capacities, MutableArrayData}, + buffer::{NullBuffer, OffsetBuffer}, + datatypes::ArrowNativeType, + record_batch::RecordBatch, +}; +use arrow_array::{ + make_array, Array, ArrayRef, GenericListArray, Int32Array, OffsetSizeTrait, +}; +use arrow_schema::{DataType, Field, Schema}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{ + cast::{as_large_list_array, as_list_array}, + internal_err, DataFusionError, Result as DataFusionResult, +}; +use datafusion_physical_expr::PhysicalExpr; +use std::hash::Hash; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter}, + sync::Arc, +}; + +// 2147483632 == java.lang.Integer.MAX_VALUE - 15 +// It is a value of ByteArrayUtils.MAX_ROUNDED_ARRAY_LENGTH +// https://github.com/apache/spark/blob/master/common/utils/src/main/java/org/apache/spark/unsafe/array/ByteArrayUtils.java +const MAX_ROUNDED_ARRAY_LENGTH: usize = 2147483632; + +#[derive(Debug, Eq)] +pub struct ArrayInsert { + src_array_expr: Arc, + pos_expr: Arc, + item_expr: Arc, + legacy_negative_index: bool, +} + +impl Hash for ArrayInsert { + fn hash(&self, state: &mut H) { + self.src_array_expr.hash(state); + self.pos_expr.hash(state); + self.item_expr.hash(state); + self.legacy_negative_index.hash(state); + } +} +impl PartialEq for ArrayInsert { + fn eq(&self, other: &Self) -> bool { + self.src_array_expr.eq(&other.src_array_expr) + && self.pos_expr.eq(&other.pos_expr) + && self.item_expr.eq(&other.item_expr) + && self.legacy_negative_index.eq(&other.legacy_negative_index) + } +} + +impl ArrayInsert { + pub fn new( + src_array_expr: Arc, + pos_expr: Arc, + item_expr: Arc, + legacy_negative_index: bool, + ) -> Self { + Self { + src_array_expr, + pos_expr, + item_expr, + legacy_negative_index, + 
} + } + + pub fn array_type(&self, data_type: &DataType) -> DataFusionResult { + match data_type { + DataType::List(field) => Ok(DataType::List(Arc::clone(field))), + DataType::LargeList(field) => Ok(DataType::LargeList(Arc::clone(field))), + data_type => Err(DataFusionError::Internal(format!( + "Unexpected src array type in ArrayInsert: {:?}", + data_type + ))), + } + } +} + +impl PhysicalExpr for ArrayInsert { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> DataFusionResult { + self.array_type(&self.src_array_expr.data_type(input_schema)?) + } + + fn nullable(&self, input_schema: &Schema) -> DataFusionResult { + self.src_array_expr.nullable(input_schema) + } + + fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { + let pos_value = self + .pos_expr + .evaluate(batch)? + .into_array(batch.num_rows())?; + + // Spark supports only IntegerType (Int32): + // https://github.com/apache/spark/blob/branch-3.5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L4737 + if !matches!(pos_value.data_type(), DataType::Int32) { + return Err(DataFusionError::Internal(format!( + "Unexpected index data type in ArrayInsert: {:?}, expected type is Int32", + pos_value.data_type() + ))); + } + + // Check that src array is actually an array and get it's value type + let src_value = self + .src_array_expr + .evaluate(batch)? + .into_array(batch.num_rows())?; + + let src_element_type = match self.array_type(src_value.data_type())? { + DataType::List(field) => &field.data_type().clone(), + DataType::LargeList(field) => &field.data_type().clone(), + _ => unreachable!(), + }; + + // Check that inserted value has the same type as an array + let item_value = self + .item_expr + .evaluate(batch)? 
+ .into_array(batch.num_rows())?; + if item_value.data_type() != src_element_type { + return Err(DataFusionError::Internal(format!( + "Type mismatch in ArrayInsert: array type is {:?} but item type is {:?}", + src_element_type, + item_value.data_type() + ))); + } + + match src_value.data_type() { + DataType::List(_) => { + let list_array = as_list_array(&src_value)?; + array_insert( + list_array, + &item_value, + &pos_value, + self.legacy_negative_index, + ) + } + DataType::LargeList(_) => { + let list_array = as_large_list_array(&src_value)?; + array_insert( + list_array, + &item_value, + &pos_value, + self.legacy_negative_index, + ) + } + _ => unreachable!(), // This case is checked already + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.src_array_expr, &self.pos_expr, &self.item_expr] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> DataFusionResult> { + match children.len() { + 3 => Ok(Arc::new(ArrayInsert::new( + Arc::clone(&children[0]), + Arc::clone(&children[1]), + Arc::clone(&children[2]), + self.legacy_negative_index, + ))), + _ => internal_err!("ArrayInsert should have exactly three childrens"), + } + } +} + +fn array_insert( + list_array: &GenericListArray, + items_array: &ArrayRef, + pos_array: &ArrayRef, + legacy_mode: bool, +) -> DataFusionResult { + // The code is based on the implementation of the array_append from the Apache DataFusion + // https://github.com/apache/datafusion/blob/main/datafusion/functions-nested/src/concat.rs#L513 + // + // This code is also based on the implementation of the array_insert from the Apache Spark + // https://github.com/apache/spark/blob/branch-3.5/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala#L4713 + + let values = list_array.values(); + let offsets = list_array.offsets(); + let values_data = values.to_data(); + let item_data = items_array.to_data(); + let new_capacity = Capacities::Array(values_data.len() + item_data.len()); + + let mut mutable_values = MutableArrayData::with_capacities( + vec![&values_data, &item_data], + true, + new_capacity, + ); + + let mut new_offsets = vec![O::usize_as(0)]; + let mut new_nulls = Vec::::with_capacity(list_array.len()); + + let pos_data: &Int32Array = as_primitive_array(&pos_array); // Spark supports only i32 for positions + + for (row_index, offset_window) in offsets.windows(2).enumerate() { + let pos = pos_data.values()[row_index]; + let start = offset_window[0].as_usize(); + let end = offset_window[1].as_usize(); + let is_item_null = items_array.is_null(row_index); + + if list_array.is_null(row_index) { + // In Spark if value of the array is NULL than nothing happens + mutable_values.extend_nulls(1); + new_offsets.push(new_offsets[row_index] + O::one()); + new_nulls.push(false); + continue; + } + + if pos == 0 { + return Err(DataFusionError::Internal( + "Position for array_insert should be greter or less than zero" + .to_string(), + )); + } + + if (pos > 0) || ((-pos).as_usize() < (end - start + 1)) { + let corrected_pos = if pos > 0 { + (pos - 1).as_usize() + } else { + end - start - (-pos).as_usize() + if legacy_mode { 0 } else { 1 } + }; + let new_array_len = std::cmp::max(end - start + 1, corrected_pos); + if new_array_len > MAX_ROUNDED_ARRAY_LENGTH { + return Err(DataFusionError::Internal(format!( + "Max array length in Spark is {:?}, but got {:?}", + MAX_ROUNDED_ARRAY_LENGTH, new_array_len + ))); + } + + if (start + corrected_pos) <= end { + mutable_values.extend(0, start, start + corrected_pos); + 
mutable_values.extend(1, row_index, row_index + 1); + mutable_values.extend(0, start + corrected_pos, end); + new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len)); + } else { + mutable_values.extend(0, start, end); + mutable_values.extend_nulls(new_array_len - (end - start)); + mutable_values.extend(1, row_index, row_index + 1); + // In that case spark actualy makes array longer than expected; + // For example, if pos is equal to 5, len is eq to 3, than resulted len will be 5 + new_offsets + .push(new_offsets[row_index] + O::usize_as(new_array_len) + O::one()); + } + } else { + // This comment is takes from the Apache Spark source code as is: + // special case- if the new position is negative but larger than the current array size + // place the new item at start of array, place the current array contents at the end + // and fill the newly created array elements inbetween with a null + let base_offset = if legacy_mode { 1 } else { 0 }; + let new_array_len = (-pos + base_offset).as_usize(); + if new_array_len > MAX_ROUNDED_ARRAY_LENGTH { + return Err(DataFusionError::Internal(format!( + "Max array length in Spark is {:?}, but got {:?}", + MAX_ROUNDED_ARRAY_LENGTH, new_array_len + ))); + } + mutable_values.extend(1, row_index, row_index + 1); + mutable_values.extend_nulls(new_array_len - (end - start + 1)); + mutable_values.extend(0, start, end); + new_offsets.push(new_offsets[row_index] + O::usize_as(new_array_len)); + } + if is_item_null { + if (start == end) || (values.is_null(row_index)) { + new_nulls.push(false) + } else { + new_nulls.push(true) + } + } else { + new_nulls.push(true) + } + } + + let data = make_array(mutable_values.freeze()); + let data_type = match list_array.data_type() { + DataType::List(field) => field.data_type(), + DataType::LargeList(field) => field.data_type(), + _ => unreachable!(), + }; + let new_array = GenericListArray::::try_new( + Arc::new(Field::new("item", data_type.clone(), true)), + OffsetBuffer::new(new_offsets.into()), + data, + Some(NullBuffer::new(new_nulls.into())), + )?; + + Ok(ColumnarValue::Array(Arc::new(new_array))) +} + +impl Display for ArrayInsert { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "ArrayInsert [array: {:?}, pos: {:?}, item: {:?}]", + self.src_array_expr, self.pos_expr, self.item_expr + ) + } +} + +#[cfg(test)] +mod test { + use super::*; + use arrow::datatypes::Int32Type; + use arrow_array::{Array, ArrayRef, Int32Array, ListArray}; + use datafusion_common::Result; + use datafusion_expr::ColumnarValue; + use std::sync::Arc; + + #[test] + fn test_array_insert() -> Result<()> { + // Test inserting an item into a list array + // Inputs and expected values are taken from the Spark results + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + Some(vec![None]), + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(1), Some(2), Some(3)]), + None, + ]); + + let positions = Int32Array::from(vec![2, 1, 1, 5, 6, 1]); + let items = Int32Array::from(vec![ + Some(10), + Some(20), + Some(30), + Some(100), + Some(100), + Some(40), + ]); + + let ColumnarValue::Array(result) = array_insert( + &list, + &(Arc::new(items) as ArrayRef), + &(Arc::new(positions) as ArrayRef), + false, + )? 
+ else { + unreachable!() + }; + + let expected = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(10), Some(2), Some(3)]), + Some(vec![Some(20), Some(4), Some(5)]), + Some(vec![Some(30), None]), + Some(vec![Some(1), Some(2), Some(3), None, Some(100)]), + Some(vec![Some(1), Some(2), Some(3), None, None, Some(100)]), + None, + ]); + + assert_eq!(&result.to_data(), &expected.to_data()); + + Ok(()) + } + + #[test] + fn test_array_insert_negative_index() -> Result<()> { + // Test insert with negative index + // Inputs and expected values are taken from the Spark results + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + Some(vec![Some(1)]), + None, + ]); + + let positions = Int32Array::from(vec![-2, -1, -3, -1]); + let items = Int32Array::from(vec![Some(10), Some(20), Some(100), Some(30)]); + + let ColumnarValue::Array(result) = array_insert( + &list, + &(Arc::new(items) as ArrayRef), + &(Arc::new(positions) as ArrayRef), + false, + )? + else { + unreachable!() + }; + + let expected = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(10), Some(3)]), + Some(vec![Some(4), Some(5), Some(20)]), + Some(vec![Some(100), None, Some(1)]), + None, + ]); + + assert_eq!(&result.to_data(), &expected.to_data()); + + Ok(()) + } + + #[test] + fn test_array_insert_legacy_mode() -> Result<()> { + // Test the so-called "legacy" mode exisiting in the Spark + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(4), Some(5)]), + None, + ]); + + let positions = Int32Array::from(vec![-1, -1, -1]); + let items = Int32Array::from(vec![Some(10), Some(20), Some(30)]); + + let ColumnarValue::Array(result) = array_insert( + &list, + &(Arc::new(items) as ArrayRef), + &(Arc::new(positions) as ArrayRef), + true, + )? + else { + unreachable!() + }; + + let expected = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(10), Some(3)]), + Some(vec![Some(4), Some(20), Some(5)]), + None, + ]); + + assert_eq!(&result.to_data(), &expected.to_data()); + + Ok(()) + } +} diff --git a/datafusion/spark/src/array_funcs/get_array_struct_fields.rs b/datafusion/spark/src/array_funcs/get_array_struct_fields.rs new file mode 100644 index 000000000000..3f2a0c44ee95 --- /dev/null +++ b/datafusion/spark/src/array_funcs/get_array_struct_fields.rs @@ -0,0 +1,171 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
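+
+// `GetArrayStructFields` projects a single struct field out of every element of a
+// list-of-structs column, mirroring Spark's expression of the same name (e.g.
+// `arr.field` where `arr` is an `array<struct<...>>`). An illustrative shape, not
+// taken from a test in this file:
+//
+//     input  (ordinal = 0): [ {a: 1, b: "x"}, {a: 2, b: "y"} ]
+//     output:               [ 1, 2 ]
+//
+// The list offsets and null buffer of the input are reused unchanged; only the
+// child values are replaced by the selected struct column.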
+ +use arrow::record_batch::RecordBatch; +use arrow_array::{Array, GenericListArray, OffsetSizeTrait, StructArray}; +use arrow_schema::{DataType, FieldRef, Schema}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{ + cast::{as_large_list_array, as_list_array}, + internal_err, DataFusionError, Result as DataFusionResult, +}; +use datafusion_physical_expr::PhysicalExpr; +use std::hash::Hash; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter}, + sync::Arc, +}; + +#[derive(Debug, Eq)] +pub struct GetArrayStructFields { + child: Arc, + ordinal: usize, +} + +impl Hash for GetArrayStructFields { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.ordinal.hash(state); + } +} +impl PartialEq for GetArrayStructFields { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) && self.ordinal.eq(&other.ordinal) + } +} + +impl GetArrayStructFields { + pub fn new(child: Arc, ordinal: usize) -> Self { + Self { child, ordinal } + } + + fn list_field(&self, input_schema: &Schema) -> DataFusionResult { + match self.child.data_type(input_schema)? { + DataType::List(field) | DataType::LargeList(field) => Ok(field), + data_type => Err(DataFusionError::Internal(format!( + "Unexpected data type in GetArrayStructFields: {:?}", + data_type + ))), + } + } + + fn child_field(&self, input_schema: &Schema) -> DataFusionResult { + match self.list_field(input_schema)?.data_type() { + DataType::Struct(fields) => Ok(Arc::clone(&fields[self.ordinal])), + data_type => Err(DataFusionError::Internal(format!( + "Unexpected data type in GetArrayStructFields: {:?}", + data_type + ))), + } + } +} + +impl PhysicalExpr for GetArrayStructFields { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> DataFusionResult { + let struct_field = self.child_field(input_schema)?; + match self.child.data_type(input_schema)? 
{ + DataType::List(_) => Ok(DataType::List(struct_field)), + DataType::LargeList(_) => Ok(DataType::LargeList(struct_field)), + data_type => Err(DataFusionError::Internal(format!( + "Unexpected data type in GetArrayStructFields: {:?}", + data_type + ))), + } + } + + fn nullable(&self, input_schema: &Schema) -> DataFusionResult { + Ok(self.list_field(input_schema)?.is_nullable() + || self.child_field(input_schema)?.is_nullable()) + } + + fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { + let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?; + + match child_value.data_type() { + DataType::List(_) => { + let list_array = as_list_array(&child_value)?; + + get_array_struct_fields(list_array, self.ordinal) + } + DataType::LargeList(_) => { + let list_array = as_large_list_array(&child_value)?; + + get_array_struct_fields(list_array, self.ordinal) + } + data_type => Err(DataFusionError::Internal(format!( + "Unexpected child type for ListExtract: {:?}", + data_type + ))), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion_common::Result> { + match children.len() { + 1 => Ok(Arc::new(GetArrayStructFields::new( + Arc::clone(&children[0]), + self.ordinal, + ))), + _ => internal_err!("GetArrayStructFields should have exactly one child"), + } + } +} + +fn get_array_struct_fields( + list_array: &GenericListArray, + ordinal: usize, +) -> DataFusionResult { + let values = list_array + .values() + .as_any() + .downcast_ref::() + .expect("A struct is expected"); + + let column = Arc::clone(values.column(ordinal)); + let field = Arc::clone(&values.fields()[ordinal]); + + let offsets = list_array.offsets(); + let array = GenericListArray::new( + field, + offsets.clone(), + column, + list_array.nulls().cloned(), + ); + + Ok(ColumnarValue::Array(Arc::new(array))) +} + +impl Display for GetArrayStructFields { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "GetArrayStructFields [child: {:?}, ordinal: {:?}]", + self.child, self.ordinal + ) + } +} diff --git a/datafusion/spark/src/array_funcs/list_extract.rs b/datafusion/spark/src/array_funcs/list_extract.rs new file mode 100644 index 000000000000..b24a1ebf74f3 --- /dev/null +++ b/datafusion/spark/src/array_funcs/list_extract.rs @@ -0,0 +1,317 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
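+
+// `ListExtract` returns one element per list row, selected by an Int32 ordinal.
+// Indexing is one-based (negative ordinals count from the end) or zero-based
+// depending on `one_based`; out-of-bounds ordinals either raise an error when
+// `fail_on_error` is set (ANSI mode) or fall back to `default_value` (null when
+// none is supplied). Illustrative behaviour, assuming one-based indexing:
+//
+//     extract([10, 20, 30],  2) -> 20
+//     extract([10, 20, 30], -1) -> 30
+//     extract([10, 20, 30],  5) -> null   (error when fail_on_error is set)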
+ +use arrow::{ + array::MutableArrayData, datatypes::ArrowNativeType, record_batch::RecordBatch, +}; +use arrow_array::{Array, GenericListArray, Int32Array, OffsetSizeTrait}; +use arrow_schema::{DataType, FieldRef, Schema}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{ + cast::{as_int32_array, as_large_list_array, as_list_array}, + internal_err, DataFusionError, Result as DataFusionResult, ScalarValue, +}; +use datafusion_physical_expr::PhysicalExpr; +use std::hash::Hash; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter}, + sync::Arc, +}; + +#[derive(Debug, Eq)] +pub struct ListExtract { + child: Arc, + ordinal: Arc, + default_value: Option>, + one_based: bool, + fail_on_error: bool, +} + +impl Hash for ListExtract { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.ordinal.hash(state); + self.default_value.hash(state); + self.one_based.hash(state); + self.fail_on_error.hash(state); + } +} +impl PartialEq for ListExtract { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) + && self.ordinal.eq(&other.ordinal) + && self.default_value.eq(&other.default_value) + && self.one_based.eq(&other.one_based) + && self.fail_on_error.eq(&other.fail_on_error) + } +} + +impl ListExtract { + pub fn new( + child: Arc, + ordinal: Arc, + default_value: Option>, + one_based: bool, + fail_on_error: bool, + ) -> Self { + Self { + child, + ordinal, + default_value, + one_based, + fail_on_error, + } + } + + fn child_field(&self, input_schema: &Schema) -> DataFusionResult { + match self.child.data_type(input_schema)? { + DataType::List(field) | DataType::LargeList(field) => Ok(field), + data_type => Err(DataFusionError::Internal(format!( + "Unexpected data type in ListExtract: {:?}", + data_type + ))), + } + } +} + +impl PhysicalExpr for ListExtract { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> DataFusionResult { + Ok(self.child_field(input_schema)?.data_type().clone()) + } + + fn nullable(&self, input_schema: &Schema) -> DataFusionResult { + // Only non-nullable if fail_on_error is enabled and the element is non-nullable + Ok(!self.fail_on_error || self.child_field(input_schema)?.is_nullable()) + } + + fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { + let child_value = self.child.evaluate(batch)?.into_array(batch.num_rows())?; + let ordinal_value = self.ordinal.evaluate(batch)?.into_array(batch.num_rows())?; + + let default_value = self + .default_value + .as_ref() + .map(|d| { + d.evaluate(batch).map(|value| match value { + ColumnarValue::Scalar(scalar) + if !scalar + .data_type() + .equals_datatype(child_value.data_type()) => + { + scalar.cast_to(child_value.data_type()) + } + ColumnarValue::Scalar(scalar) => Ok(scalar), + v => Err(DataFusionError::Execution(format!( + "Expected scalar default value for ListExtract, got {:?}", + v + ))), + }) + }) + .transpose()? 
+ .unwrap_or(self.data_type(&batch.schema())?.try_into())?; + + let adjust_index = if self.one_based { + one_based_index + } else { + zero_based_index + }; + + match child_value.data_type() { + DataType::List(_) => { + let list_array = as_list_array(&child_value)?; + let index_array = as_int32_array(&ordinal_value)?; + + list_extract( + list_array, + index_array, + &default_value, + self.fail_on_error, + adjust_index, + ) + } + DataType::LargeList(_) => { + let list_array = as_large_list_array(&child_value)?; + let index_array = as_int32_array(&ordinal_value)?; + + list_extract( + list_array, + index_array, + &default_value, + self.fail_on_error, + adjust_index, + ) + } + data_type => Err(DataFusionError::Internal(format!( + "Unexpected child type for ListExtract: {:?}", + data_type + ))), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child, &self.ordinal] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion_common::Result> { + match children.len() { + 2 => Ok(Arc::new(ListExtract::new( + Arc::clone(&children[0]), + Arc::clone(&children[1]), + self.default_value.clone(), + self.one_based, + self.fail_on_error, + ))), + _ => internal_err!("ListExtract should have exactly two children"), + } + } +} + +fn one_based_index(index: i32, len: usize) -> DataFusionResult> { + if index == 0 { + return Err(DataFusionError::Execution( + "Invalid index of 0 for one-based ListExtract".to_string(), + )); + } + + let abs_index = index.abs().as_usize(); + if abs_index <= len { + if index > 0 { + Ok(Some(abs_index - 1)) + } else { + Ok(Some(len - abs_index)) + } + } else { + Ok(None) + } +} + +fn zero_based_index(index: i32, len: usize) -> DataFusionResult> { + if index < 0 { + Ok(None) + } else { + let positive_index = index.as_usize(); + if positive_index < len { + Ok(Some(positive_index)) + } else { + Ok(None) + } + } +} + +fn list_extract( + list_array: &GenericListArray, + index_array: &Int32Array, + default_value: &ScalarValue, + fail_on_error: bool, + adjust_index: impl Fn(i32, usize) -> DataFusionResult>, +) -> DataFusionResult { + let values = list_array.values(); + let offsets = list_array.offsets(); + + let data = values.to_data(); + + let default_data = default_value.to_array()?.to_data(); + + let mut mutable = + MutableArrayData::new(vec![&data, &default_data], true, index_array.len()); + + for (row, (offset_window, index)) in + offsets.windows(2).zip(index_array.values()).enumerate() + { + let start = offset_window[0].as_usize(); + let len = offset_window[1].as_usize() - start; + + if let Some(i) = adjust_index(*index, len)? 
{ + mutable.extend(0, start + i, start + i + 1); + } else if list_array.is_null(row) { + mutable.extend_nulls(1); + } else if fail_on_error { + return Err(DataFusionError::Execution( + "Index out of bounds for array".to_string(), + )); + } else { + mutable.extend(1, 0, 1); + } + } + + let data = mutable.freeze(); + Ok(ColumnarValue::Array(arrow::array::make_array(data))) +} + +impl Display for ListExtract { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "ListExtract [child: {:?}, ordinal: {:?}, default_value: {:?}, one_based: {:?}, fail_on_error: {:?}]", + self.child, self.ordinal, self.default_value, self.one_based, self.fail_on_error + ) + } +} + +#[cfg(test)] +mod test { + use super::*; + use arrow::datatypes::Int32Type; + use arrow_array::{Array, Int32Array, ListArray}; + use datafusion_common::{Result, ScalarValue}; + use datafusion_expr::ColumnarValue; + + #[test] + fn test_list_extract_default_value() -> Result<()> { + let list = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1)]), + None, + Some(vec![]), + ]); + let indices = Int32Array::from(vec![0, 0, 0]); + + let null_default = ScalarValue::Int32(None); + + let ColumnarValue::Array(result) = + list_extract(&list, &indices, &null_default, false, zero_based_index)? + else { + unreachable!() + }; + + assert_eq!( + &result.to_data(), + &Int32Array::from(vec![Some(1), None, None]).to_data() + ); + + let zero_default = ScalarValue::Int32(Some(0)); + + let ColumnarValue::Array(result) = + list_extract(&list, &indices, &zero_default, false, zero_based_index)? + else { + unreachable!() + }; + + assert_eq!( + &result.to_data(), + &Int32Array::from(vec![Some(1), None, Some(0)]).to_data() + ); + Ok(()) + } +} diff --git a/datafusion/spark/src/array_funcs/mod.rs b/datafusion/spark/src/array_funcs/mod.rs new file mode 100644 index 000000000000..0a215f96cf85 --- /dev/null +++ b/datafusion/spark/src/array_funcs/mod.rs @@ -0,0 +1,24 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod array_insert; +mod get_array_struct_fields; +mod list_extract; + +pub use array_insert::ArrayInsert; +pub use get_array_struct_fields::GetArrayStructFields; +pub use list_extract::ListExtract; diff --git a/datafusion/spark/src/bitwise_funcs/bitwise_not.rs b/datafusion/spark/src/bitwise_funcs/bitwise_not.rs new file mode 100644 index 000000000000..668e319adf1c --- /dev/null +++ b/datafusion/spark/src/bitwise_funcs/bitwise_not.rs @@ -0,0 +1,170 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::{ + array::*, + datatypes::{DataType, Schema}, + record_batch::RecordBatch, +}; +use datafusion::{error::DataFusionError, logical_expr::ColumnarValue}; +use datafusion_common::Result; +use datafusion_physical_expr::PhysicalExpr; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; + +macro_rules! compute_op { + ($OPERAND:expr, $DT:ident) => {{ + let operand = $OPERAND + .as_any() + .downcast_ref::<$DT>() + .expect("compute_op failed to downcast array"); + let result: $DT = operand.iter().map(|x| x.map(|y| !y)).collect(); + Ok(Arc::new(result)) + }}; +} + +/// BitwiseNot expression +#[derive(Debug, Eq)] +pub struct BitwiseNotExpr { + /// Input expression + arg: Arc, +} + +impl Hash for BitwiseNotExpr { + fn hash(&self, state: &mut H) { + self.arg.hash(state); + } +} + +impl PartialEq for BitwiseNotExpr { + fn eq(&self, other: &Self) -> bool { + self.arg.eq(&other.arg) + } +} + +impl BitwiseNotExpr { + /// Create new bitwise not expression + pub fn new(arg: Arc) -> Self { + Self { arg } + } + + /// Get the input expression + pub fn arg(&self) -> &Arc { + &self.arg + } +} + +impl std::fmt::Display for BitwiseNotExpr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "(~ {})", self.arg) + } +} + +impl PhysicalExpr for BitwiseNotExpr { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> Result { + self.arg.data_type(input_schema) + } + + fn nullable(&self, input_schema: &Schema) -> Result { + self.arg.nullable(input_schema) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + let arg = self.arg.evaluate(batch)?; + match arg { + ColumnarValue::Array(array) => { + let result: Result = match array.data_type() { + DataType::Int8 => compute_op!(array, Int8Array), + DataType::Int16 => compute_op!(array, Int16Array), + DataType::Int32 => compute_op!(array, Int32Array), + DataType::Int64 => compute_op!(array, Int64Array), + _ => Err(DataFusionError::Execution(format!( + "(- '{:?}') can't be evaluated because the expression's type is {:?}, not signed int", + self, + array.data_type(), + ))), + }; + result.map(ColumnarValue::Array) + } + ColumnarValue::Scalar(_) => Err(DataFusionError::Internal( + "shouldn't go to bitwise not scalar path".to_string(), + )), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.arg] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(BitwiseNotExpr::new(Arc::clone(&children[0])))) + } +} + +pub fn bitwise_not(arg: Arc) -> Result> { + Ok(Arc::new(BitwiseNotExpr::new(arg))) +} + +#[cfg(test)] +mod tests { + use arrow::datatypes::*; + use datafusion_common::{cast::as_int32_array, Result}; + use datafusion_physical_expr::expressions::col; + + use super::*; + + #[test] + fn bitwise_not_op() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + + let expr = 
bitwise_not(col("a", &schema)?)?; + + let input = Int32Array::from(vec![ + Some(1), + Some(2), + None, + Some(12345), + Some(89), + Some(-3456), + ]); + let expected = &Int32Array::from(vec![ + Some(-2), + Some(-3), + None, + Some(-12346), + Some(-90), + Some(3455), + ]); + + let batch = + RecordBatch::try_new(Arc::new(schema.clone()), vec![Arc::new(input)])?; + + let result = expr.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_int32_array(&result).expect("failed to downcast to In32Array"); + assert_eq!(result, expected); + + Ok(()) + } +} diff --git a/datafusion/spark/src/bitwise_funcs/mod.rs b/datafusion/spark/src/bitwise_funcs/mod.rs new file mode 100644 index 000000000000..9c2636331961 --- /dev/null +++ b/datafusion/spark/src/bitwise_funcs/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod bitwise_not; + +pub use bitwise_not::{bitwise_not, BitwiseNotExpr}; diff --git a/datafusion/spark/src/comet_scalar_funcs.rs b/datafusion/spark/src/comet_scalar_funcs.rs new file mode 100644 index 000000000000..ca9ae8e33e17 --- /dev/null +++ b/datafusion/spark/src/comet_scalar_funcs.rs @@ -0,0 +1,192 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::hash_funcs::*; +use crate::{ + spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div, spark_floor, + spark_hex, spark_isnan, spark_make_decimal, spark_read_side_padding, spark_round, + spark_unhex, spark_unscaled_value, SparkChrFunc, +}; +use arrow_schema::DataType; +use datafusion_common::{DataFusionError, Result as DataFusionResult}; +use datafusion_expr::registry::FunctionRegistry; +use datafusion_expr::{ + ColumnarValue, ScalarFunctionImplementation, ScalarUDF, ScalarUDFImpl, Signature, + Volatility, +}; +use std::any::Any; +use std::fmt::Debug; +use std::sync::Arc; + +macro_rules! 
make_comet_scalar_udf { + ($name:expr, $func:ident, $data_type:ident) => {{ + let scalar_func = CometScalarFunction::new( + $name.to_string(), + Signature::variadic_any(Volatility::Immutable), + $data_type.clone(), + Arc::new(move |args| $func(args, &$data_type)), + ); + Ok(Arc::new(ScalarUDF::new_from_impl(scalar_func))) + }}; + ($name:expr, $func:expr, without $data_type:ident) => {{ + let scalar_func = CometScalarFunction::new( + $name.to_string(), + Signature::variadic_any(Volatility::Immutable), + $data_type, + $func, + ); + Ok(Arc::new(ScalarUDF::new_from_impl(scalar_func))) + }}; +} + +/// Create a physical scalar function. +pub fn create_comet_physical_fun( + fun_name: &str, + data_type: DataType, + registry: &dyn FunctionRegistry, +) -> Result, DataFusionError> { + match fun_name { + "ceil" => { + make_comet_scalar_udf!("ceil", spark_ceil, data_type) + } + "floor" => { + make_comet_scalar_udf!("floor", spark_floor, data_type) + } + "read_side_padding" => { + let func = Arc::new(spark_read_side_padding); + make_comet_scalar_udf!("read_side_padding", func, without data_type) + } + "round" => { + make_comet_scalar_udf!("round", spark_round, data_type) + } + "unscaled_value" => { + let func = Arc::new(spark_unscaled_value); + make_comet_scalar_udf!("unscaled_value", func, without data_type) + } + "make_decimal" => { + make_comet_scalar_udf!("make_decimal", spark_make_decimal, data_type) + } + "hex" => { + let func = Arc::new(spark_hex); + make_comet_scalar_udf!("hex", func, without data_type) + } + "unhex" => { + let func = Arc::new(spark_unhex); + make_comet_scalar_udf!("unhex", func, without data_type) + } + "decimal_div" => { + make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type) + } + "murmur3_hash" => { + let func = Arc::new(spark_murmur3_hash); + make_comet_scalar_udf!("murmur3_hash", func, without data_type) + } + "xxhash64" => { + let func = Arc::new(spark_xxhash64); + make_comet_scalar_udf!("xxhash64", func, without data_type) + } + "chr" => Ok(Arc::new(ScalarUDF::new_from_impl(SparkChrFunc::default()))), + "isnan" => { + let func = Arc::new(spark_isnan); + make_comet_scalar_udf!("isnan", func, without data_type) + } + "sha224" => { + let func = Arc::new(spark_sha224); + make_comet_scalar_udf!("sha224", func, without data_type) + } + "sha256" => { + let func = Arc::new(spark_sha256); + make_comet_scalar_udf!("sha256", func, without data_type) + } + "sha384" => { + let func = Arc::new(spark_sha384); + make_comet_scalar_udf!("sha384", func, without data_type) + } + "sha512" => { + let func = Arc::new(spark_sha512); + make_comet_scalar_udf!("sha512", func, without data_type) + } + "date_add" => { + let func = Arc::new(spark_date_add); + make_comet_scalar_udf!("date_add", func, without data_type) + } + "date_sub" => { + let func = Arc::new(spark_date_sub); + make_comet_scalar_udf!("date_sub", func, without data_type) + } + _ => registry.udf(fun_name).map_err(|e| { + DataFusionError::Execution(format!( + "Function {fun_name} not found in the registry: {e}", + )) + }), + } +} + +struct CometScalarFunction { + name: String, + signature: Signature, + data_type: DataType, + func: ScalarFunctionImplementation, +} + +impl Debug for CometScalarFunction { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("CometScalarFunction") + .field("name", &self.name) + .field("signature", &self.signature) + .field("data_type", &self.data_type) + .finish() + } +} + +impl CometScalarFunction { + fn new( + name: String, + signature: Signature, + 
data_type: DataType, + func: ScalarFunctionImplementation, + ) -> Self { + Self { + name, + signature, + data_type, + func, + } + } +} + +impl ScalarUDFImpl for CometScalarFunction { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + self.name.as_str() + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _: &[DataType]) -> DataFusionResult { + Ok(self.data_type.clone()) + } + + fn invoke(&self, args: &[ColumnarValue]) -> DataFusionResult { + (self.func)(args) + } +} diff --git a/datafusion/spark/src/conditional_funcs/if_expr.rs b/datafusion/spark/src/conditional_funcs/if_expr.rs new file mode 100644 index 000000000000..5924028a3929 --- /dev/null +++ b/datafusion/spark/src/conditional_funcs/if_expr.rs @@ -0,0 +1,211 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::{ + datatypes::{DataType, Schema}, + record_batch::RecordBatch, +}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::Result; +use datafusion_physical_expr::{expressions::CaseExpr, PhysicalExpr}; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; + +/// IfExpr is a wrapper around CaseExpr, because `IF(a, b, c)` is semantically equivalent to +/// `CASE WHEN a THEN b ELSE c END`. 
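+///
+/// A minimal usage sketch (illustrative only, not a doc test); `lit` is the literal
+/// helper from `datafusion_physical_expr::expressions`, as used in the tests below:
+///
+/// ```ignore
+/// use datafusion_physical_expr::expressions::lit;
+///
+/// // IF(true, 123, 999) evaluates to 123 for every row of the input batch
+/// let expr = IfExpr::new(lit(true), lit(123i32), lit(999i32));
+/// ```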
+#[derive(Debug, Eq)] +pub struct IfExpr { + if_expr: Arc, + true_expr: Arc, + false_expr: Arc, + // we delegate to case_expr for evaluation + case_expr: Arc, +} + +impl Hash for IfExpr { + fn hash(&self, state: &mut H) { + self.if_expr.hash(state); + self.true_expr.hash(state); + self.false_expr.hash(state); + self.case_expr.hash(state); + } +} +impl PartialEq for IfExpr { + fn eq(&self, other: &Self) -> bool { + self.if_expr.eq(&other.if_expr) + && self.true_expr.eq(&other.true_expr) + && self.false_expr.eq(&other.false_expr) + && self.case_expr.eq(&other.case_expr) + } +} + +impl std::fmt::Display for IfExpr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "If [if: {}, true_expr: {}, false_expr: {}]", + self.if_expr, self.true_expr, self.false_expr + ) + } +} + +impl IfExpr { + /// Create a new IF expression + pub fn new( + if_expr: Arc, + true_expr: Arc, + false_expr: Arc, + ) -> Self { + Self { + if_expr: Arc::clone(&if_expr), + true_expr: Arc::clone(&true_expr), + false_expr: Arc::clone(&false_expr), + case_expr: Arc::new( + CaseExpr::try_new(None, vec![(if_expr, true_expr)], Some(false_expr)) + .unwrap(), + ), + } + } +} + +impl PhysicalExpr for IfExpr { + /// Return a reference to Any that can be used for down-casting + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> Result { + let data_type = self.true_expr.data_type(input_schema)?; + Ok(data_type) + } + + fn nullable(&self, _input_schema: &Schema) -> Result { + if self.true_expr.nullable(_input_schema)? + || self.true_expr.nullable(_input_schema)? + { + Ok(true) + } else { + Ok(false) + } + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + self.case_expr.evaluate(batch) + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.if_expr, &self.true_expr, &self.false_expr] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(IfExpr::new( + Arc::clone(&children[0]), + Arc::clone(&children[1]), + Arc::clone(&children[2]), + ))) + } +} + +#[cfg(test)] +mod tests { + use arrow::{array::StringArray, datatypes::*}; + use arrow_array::Int32Array; + use datafusion::logical_expr::Operator; + use datafusion_common::cast::as_int32_array; + use datafusion_physical_expr::expressions::{binary, col, lit}; + + use super::*; + + /// Create an If expression + fn if_fn( + if_expr: Arc, + true_expr: Arc, + false_expr: Arc, + ) -> Result> { + Ok(Arc::new(IfExpr::new(if_expr, true_expr, false_expr))) + } + + #[test] + fn test_if_1() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); + let a = StringArray::from(vec![Some("foo"), Some("baz"), None, Some("bar")]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; + let schema_ref = batch.schema(); + + // if a = 'foo' 123 else 999 + let if_expr = binary( + col("a", &schema_ref)?, + Operator::Eq, + lit("foo"), + &schema_ref, + )?; + let true_expr = lit(123i32); + let false_expr = lit(999i32); + + let expr = if_fn(if_expr, true_expr, false_expr); + let result = expr?.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_int32_array(&result)?; + + let expected = + &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(999)]); + + assert_eq!(expected, result); + + Ok(()) + } + + #[test] + fn test_if_2() -> Result<()> { + let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); + let a = Int32Array::from(vec![Some(1), Some(0), None, Some(5)]); + let batch = 
RecordBatch::try_new(Arc::new(schema), vec![Arc::new(a)])?; + let schema_ref = batch.schema(); + + // if a >=1 123 else 999 + let if_expr = + binary(col("a", &schema_ref)?, Operator::GtEq, lit(1), &schema_ref)?; + let true_expr = lit(123i32); + let false_expr = lit(999i32); + + let expr = if_fn(if_expr, true_expr, false_expr); + let result = expr?.evaluate(&batch)?.into_array(batch.num_rows())?; + let result = as_int32_array(&result)?; + + let expected = + &Int32Array::from(vec![Some(123), Some(999), Some(999), Some(123)]); + assert_eq!(expected, result); + + Ok(()) + } + + #[test] + fn test_if_children() { + let if_expr = lit(true); + let true_expr = lit(123i32); + let false_expr = lit(999i32); + + let expr = if_fn(if_expr, true_expr, false_expr).unwrap(); + let children = expr.children(); + assert_eq!(children.len(), 3); + assert_eq!(children[0].to_string(), "true"); + assert_eq!(children[1].to_string(), "123"); + assert_eq!(children[2].to_string(), "999"); + } +} diff --git a/datafusion/spark/src/conditional_funcs/mod.rs b/datafusion/spark/src/conditional_funcs/mod.rs new file mode 100644 index 000000000000..70c459ef7c08 --- /dev/null +++ b/datafusion/spark/src/conditional_funcs/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod if_expr; + +pub use if_expr::IfExpr; diff --git a/datafusion/spark/src/conversion_funcs/cast.rs b/datafusion/spark/src/conversion_funcs/cast.rs new file mode 100644 index 000000000000..ac62fc9253ad --- /dev/null +++ b/datafusion/spark/src/conversion_funcs/cast.rs @@ -0,0 +1,2727 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +use crate::timezone; +use crate::utils::array_with_timezone; +use crate::{EvalMode, SparkError, SparkResult}; +use arrow::{ + array::{ + cast::AsArray, + types::{Date32Type, Int16Type, Int32Type, Int8Type}, + Array, ArrayRef, BooleanArray, Decimal128Array, Float32Array, Float64Array, + GenericStringArray, Int16Array, Int32Array, Int64Array, Int8Array, + OffsetSizeTrait, PrimitiveArray, + }, + compute::{cast_with_options, take, unary, CastOptions}, + datatypes::{ + ArrowPrimitiveType, Decimal128Type, DecimalType, Float32Type, Float64Type, + Int64Type, TimestampMicrosecondType, + }, + error::ArrowError, + record_batch::RecordBatch, + util::display::FormatOptions, +}; +use arrow_array::builder::StringBuilder; +use arrow_array::{DictionaryArray, StringArray, StructArray}; +use arrow_schema::{DataType, Schema}; +use chrono::{NaiveDate, NaiveDateTime, TimeZone, Timelike}; +use datafusion_common::{ + cast::as_generic_string_array, internal_err, Result as DataFusionResult, ScalarValue, +}; +use datafusion_expr::ColumnarValue; +use datafusion_physical_expr::PhysicalExpr; +use num::{ + cast::AsPrimitive, integer::div_floor, traits::CheckedNeg, CheckedSub, Integer, Num, + ToPrimitive, +}; +use regex::Regex; +use std::collections::HashMap; +use std::str::FromStr; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter}, + hash::Hash, + num::Wrapping, + sync::Arc, +}; + +static TIMESTAMP_FORMAT: Option<&str> = Some("%Y-%m-%d %H:%M:%S%.f"); + +const MICROS_PER_SECOND: i64 = 1000000; + +static CAST_OPTIONS: CastOptions = CastOptions { + safe: true, + format_options: FormatOptions::new() + .with_timestamp_tz_format(TIMESTAMP_FORMAT) + .with_timestamp_format(TIMESTAMP_FORMAT), +}; + +struct TimeStampInfo { + year: i32, + month: u32, + day: u32, + hour: u32, + minute: u32, + second: u32, + microsecond: u32, +} + +impl Default for TimeStampInfo { + fn default() -> Self { + TimeStampInfo { + year: 1, + month: 1, + day: 1, + hour: 0, + minute: 0, + second: 0, + microsecond: 0, + } + } +} + +impl TimeStampInfo { + pub fn with_year(&mut self, year: i32) -> &mut Self { + self.year = year; + self + } + + pub fn with_month(&mut self, month: u32) -> &mut Self { + self.month = month; + self + } + + pub fn with_day(&mut self, day: u32) -> &mut Self { + self.day = day; + self + } + + pub fn with_hour(&mut self, hour: u32) -> &mut Self { + self.hour = hour; + self + } + + pub fn with_minute(&mut self, minute: u32) -> &mut Self { + self.minute = minute; + self + } + + pub fn with_second(&mut self, second: u32) -> &mut Self { + self.second = second; + self + } + + pub fn with_microsecond(&mut self, microsecond: u32) -> &mut Self { + self.microsecond = microsecond; + self + } +} + +#[derive(Debug, Eq)] +pub struct Cast { + pub child: Arc, + pub data_type: DataType, + pub cast_options: SparkCastOptions, +} + +impl PartialEq for Cast { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) + && self.data_type.eq(&other.data_type) + && self.cast_options.eq(&other.cast_options) + } +} + +impl Hash for Cast { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.data_type.hash(state); + self.cast_options.hash(state); + } +} + +/// Determine if Comet supports a cast, taking options such as EvalMode and Timezone into account. 
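+///
+/// A hedged usage sketch (illustrative only, not a doc test):
+///
+/// ```ignore
+/// use arrow_schema::DataType;
+///
+/// // Legacy evaluation mode, UTC session timezone, incompatible casts disallowed
+/// let options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false);
+/// assert!(cast_supported(&DataType::Int32, &DataType::Int64, &options));
+/// ```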
+pub fn cast_supported( + from_type: &DataType, + to_type: &DataType, + options: &SparkCastOptions, +) -> bool { + use DataType::*; + + let from_type = if let Dictionary(_, dt) = from_type { + dt + } else { + from_type + }; + + let to_type = if let Dictionary(_, dt) = to_type { + dt + } else { + to_type + }; + + if from_type == to_type { + return true; + } + + match (from_type, to_type) { + (Boolean, _) => can_cast_from_boolean(to_type, options), + (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64) + if options.allow_cast_unsigned_ints => + { + true + } + (Int8, _) => can_cast_from_byte(to_type, options), + (Int16, _) => can_cast_from_short(to_type, options), + (Int32, _) => can_cast_from_int(to_type, options), + (Int64, _) => can_cast_from_long(to_type, options), + (Float32, _) => can_cast_from_float(to_type, options), + (Float64, _) => can_cast_from_double(to_type, options), + (Decimal128(p, s), _) => can_cast_from_decimal(p, s, to_type, options), + (Timestamp(_, None), _) => can_cast_from_timestamp_ntz(to_type, options), + (Timestamp(_, Some(_)), _) => can_cast_from_timestamp(to_type, options), + (Utf8 | LargeUtf8, _) => can_cast_from_string(to_type, options), + (_, Utf8 | LargeUtf8) => can_cast_to_string(from_type, options), + (Struct(from_fields), Struct(to_fields)) => from_fields + .iter() + .zip(to_fields.iter()) + .all(|(a, b)| cast_supported(a.data_type(), b.data_type(), options)), + _ => false, + } +} + +fn can_cast_from_string(to_type: &DataType, options: &SparkCastOptions) -> bool { + use DataType::*; + match to_type { + Boolean | Int8 | Int16 | Int32 | Int64 | Binary => true, + Float32 | Float64 => { + // https://github.com/apache/datafusion-comet/issues/326 + // Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. + // Does not support ANSI mode. + options.allow_incompat + } + Decimal128(_, _) => { + // https://github.com/apache/datafusion-comet/issues/325 + // Does not support inputs ending with 'd' or 'f'. Does not support 'inf'. + // Does not support ANSI mode. Returns 0.0 instead of null if input contains no digits + + options.allow_incompat + } + Date32 | Date64 => { + // https://github.com/apache/datafusion-comet/issues/327 + // Only supports years between 262143 BC and 262142 AD + options.allow_incompat + } + Timestamp(_, _) if options.eval_mode == EvalMode::Ansi => { + // ANSI mode not supported + false + } + Timestamp(_, Some(tz)) if tz.as_ref() != "UTC" => { + // Cast will use UTC instead of $timeZoneId + options.allow_incompat + } + Timestamp(_, _) => { + // https://github.com/apache/datafusion-comet/issues/328 + // Not all valid formats are supported + options.allow_incompat + } + _ => false, + } +} + +fn can_cast_to_string(from_type: &DataType, options: &SparkCastOptions) -> bool { + use DataType::*; + match from_type { + Boolean | Int8 | Int16 | Int32 | Int64 | Date32 | Date64 | Timestamp(_, _) => { + true + } + Float32 | Float64 => { + // There can be differences in precision. 
+ // For example, the input \"1.4E-45\" will produce 1.0E-45 " + + // instead of 1.4E-45")) + true + } + Decimal128(_, _) => { + // https://github.com/apache/datafusion-comet/issues/1068 + // There can be formatting differences in some case due to Spark using + // scientific notation where Comet does not + true + } + Binary => { + // https://github.com/apache/datafusion-comet/issues/377 + // Only works for binary data representing valid UTF-8 strings + options.allow_incompat + } + Struct(fields) => fields + .iter() + .all(|f| can_cast_to_string(f.data_type(), options)), + _ => false, + } +} + +fn can_cast_from_timestamp_ntz(to_type: &DataType, options: &SparkCastOptions) -> bool { + use DataType::*; + match to_type { + Timestamp(_, _) | Date32 | Date64 | Utf8 => { + // incompatible + options.allow_incompat + } + _ => { + // unsupported + false + } + } +} + +fn can_cast_from_timestamp(to_type: &DataType, _options: &SparkCastOptions) -> bool { + use DataType::*; + match to_type { + Boolean | Int8 | Int16 => { + // https://github.com/apache/datafusion-comet/issues/352 + // this seems like an edge case that isn't important for us to support + false + } + Int64 => { + // https://github.com/apache/datafusion-comet/issues/352 + true + } + Date32 | Date64 | Utf8 | Decimal128(_, _) => true, + _ => { + // unsupported + false + } + } +} + +fn can_cast_from_boolean(to_type: &DataType, _: &SparkCastOptions) -> bool { + use DataType::*; + matches!(to_type, Int8 | Int16 | Int32 | Int64 | Float32 | Float64) +} + +fn can_cast_from_byte(to_type: &DataType, _: &SparkCastOptions) -> bool { + use DataType::*; + matches!( + to_type, + Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _) + ) +} + +fn can_cast_from_short(to_type: &DataType, _: &SparkCastOptions) -> bool { + use DataType::*; + matches!( + to_type, + Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Decimal128(_, _) + ) +} + +fn can_cast_from_int(to_type: &DataType, options: &SparkCastOptions) -> bool { + use DataType::*; + match to_type { + Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 | Utf8 => true, + Decimal128(_, _) => { + // incompatible: no overflow check + options.allow_incompat + } + _ => false, + } +} + +fn can_cast_from_long(to_type: &DataType, options: &SparkCastOptions) -> bool { + use DataType::*; + match to_type { + Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Float64 => true, + Decimal128(_, _) => { + // incompatible: no overflow check + options.allow_incompat + } + _ => false, + } +} + +fn can_cast_from_float(to_type: &DataType, _: &SparkCastOptions) -> bool { + use DataType::*; + matches!( + to_type, + Boolean | Int8 | Int16 | Int32 | Int64 | Float64 | Decimal128(_, _) + ) +} + +fn can_cast_from_double(to_type: &DataType, _: &SparkCastOptions) -> bool { + use DataType::*; + matches!( + to_type, + Boolean | Int8 | Int16 | Int32 | Int64 | Float32 | Decimal128(_, _) + ) +} + +fn can_cast_from_decimal( + p1: &u8, + _s1: &i8, + to_type: &DataType, + options: &SparkCastOptions, +) -> bool { + use DataType::*; + match to_type { + Int8 | Int16 | Int32 | Int64 | Float32 | Float64 => true, + Decimal128(p2, _) => { + if p2 < p1 { + // https://github.com/apache/datafusion/issues/13492 + // Incompatible(Some("Casting to smaller precision is not supported")) + options.allow_incompat + } else { + true + } + } + _ => false, + } +} + +macro_rules! 
cast_utf8_to_int { + ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident) => {{ + let len = $array.len(); + let mut cast_array = PrimitiveArray::<$array_type>::builder(len); + for i in 0..len { + if $array.is_null(i) { + cast_array.append_null() + } else if let Some(cast_value) = $cast_method($array.value(i), $eval_mode)? { + cast_array.append_value(cast_value); + } else { + cast_array.append_null() + } + } + let result: SparkResult = Ok(Arc::new(cast_array.finish()) as ArrayRef); + result + }}; +} +macro_rules! cast_utf8_to_timestamp { + ($array:expr, $eval_mode:expr, $array_type:ty, $cast_method:ident, $tz:expr) => {{ + let len = $array.len(); + let mut cast_array = + PrimitiveArray::<$array_type>::builder(len).with_timezone("UTC"); + for i in 0..len { + if $array.is_null(i) { + cast_array.append_null() + } else if let Ok(Some(cast_value)) = + $cast_method($array.value(i).trim(), $eval_mode, $tz) + { + cast_array.append_value(cast_value); + } else { + cast_array.append_null() + } + } + let result: ArrayRef = Arc::new(cast_array.finish()) as ArrayRef; + result + }}; +} + +macro_rules! cast_float_to_string { + ($from:expr, $eval_mode:expr, $type:ty, $output_type:ty, $offset_type:ty) => {{ + + fn cast( + from: &dyn Array, + _eval_mode: EvalMode, + ) -> SparkResult + where + OffsetSize: OffsetSizeTrait, { + let array = from.as_any().downcast_ref::<$output_type>().unwrap(); + + // If the absolute number is less than 10,000,000 and greater or equal than 0.001, the + // result is expressed without scientific notation with at least one digit on either side of + // the decimal point. Otherwise, Spark uses a mantissa followed by E and an + // exponent. The mantissa has an optional leading minus sign followed by one digit to the + // left of the decimal point, and the minimal number of digits greater than zero to the + // right. The exponent has and optional leading minus sign. + // source: https://docs.databricks.com/en/sql/language-manual/functions/cast.html + + const LOWER_SCIENTIFIC_BOUND: $type = 0.001; + const UPPER_SCIENTIFIC_BOUND: $type = 10000000.0; + + let output_array = array + .iter() + .map(|value| match value { + Some(value) if value == <$type>::INFINITY => Ok(Some("Infinity".to_string())), + Some(value) if value == <$type>::NEG_INFINITY => Ok(Some("-Infinity".to_string())), + Some(value) + if (value.abs() < UPPER_SCIENTIFIC_BOUND + && value.abs() >= LOWER_SCIENTIFIC_BOUND) + || value.abs() == 0.0 => + { + let trailing_zero = if value.fract() == 0.0 { ".0" } else { "" }; + + Ok(Some(format!("{value}{trailing_zero}"))) + } + Some(value) + if value.abs() >= UPPER_SCIENTIFIC_BOUND + || value.abs() < LOWER_SCIENTIFIC_BOUND => + { + let formatted = format!("{value:E}"); + + if formatted.contains(".") { + Ok(Some(formatted)) + } else { + // `formatted` is already in scientific notation and can be split up by E + // in order to add the missing trailing 0 which gets removed for numbers with a fraction of 0.0 + let prepare_number: Vec<&str> = formatted.split("E").collect(); + + let coefficient = prepare_number[0]; + + let exponent = prepare_number[1]; + + Ok(Some(format!("{coefficient}.0E{exponent}"))) + } + } + Some(value) => Ok(Some(value.to_string())), + _ => Ok(None), + }) + .collect::, SparkError>>()?; + + Ok(Arc::new(output_array)) + } + + cast::<$offset_type>($from, $eval_mode) + }}; +} + +macro_rules! 
cast_int_to_int_macro { + ( + $array: expr, + $eval_mode:expr, + $from_arrow_primitive_type: ty, + $to_arrow_primitive_type: ty, + $from_data_type: expr, + $to_native_type: ty, + $spark_from_data_type_name: expr, + $spark_to_data_type_name: expr + ) => {{ + let cast_array = $array + .as_any() + .downcast_ref::>() + .unwrap(); + let spark_int_literal_suffix = match $from_data_type { + &DataType::Int64 => "L", + &DataType::Int16 => "S", + &DataType::Int8 => "T", + _ => "", + }; + + let output_array = match $eval_mode { + EvalMode::Legacy => cast_array + .iter() + .map(|value| match value { + Some(value) => Ok::, SparkError>(Some( + value as $to_native_type, + )), + _ => Ok(None), + }) + .collect::, _>>(), + _ => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let res = <$to_native_type>::try_from(value); + if res.is_err() { + Err(cast_overflow( + &(value.to_string() + spark_int_literal_suffix), + $spark_from_data_type_name, + $spark_to_data_type_name, + )) + } else { + Ok::, SparkError>(Some(res.unwrap())) + } + } + _ => Ok(None), + }) + .collect::, _>>(), + }?; + let result: SparkResult = Ok(Arc::new(output_array) as ArrayRef); + result + }}; +} + +// When Spark casts to Byte/Short Types, it does not cast directly to Byte/Short. +// It casts to Int first and then to Byte/Short. Because of potential overflows in the Int cast, +// this can cause unexpected Short/Byte cast results. Replicate this behavior. +macro_rules! cast_float_to_int16_down { + ( + $array:expr, + $eval_mode:expr, + $src_array_type:ty, + $dest_array_type:ty, + $rust_src_type:ty, + $rust_dest_type:ty, + $src_type_str:expr, + $dest_type_str:expr, + $format_str:expr + ) => {{ + let cast_array = $array + .as_any() + .downcast_ref::<$src_array_type>() + .expect(concat!("Expected a ", stringify!($src_array_type))); + + let output_array = match $eval_mode { + EvalMode::Ansi => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let is_overflow = + value.is_nan() || value.abs() as i32 == i32::MAX; + if is_overflow { + return Err(cast_overflow( + &format!($format_str, value).replace("e", "E"), + $src_type_str, + $dest_type_str, + )); + } + let i32_value = value as i32; + <$rust_dest_type>::try_from(i32_value) + .map_err(|_| { + cast_overflow( + &format!($format_str, value).replace("e", "E"), + $src_type_str, + $dest_type_str, + ) + }) + .map(Some) + } + None => Ok(None), + }) + .collect::>()?, + _ => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let i32_value = value as i32; + Ok::, SparkError>(Some( + i32_value as $rust_dest_type, + )) + } + None => Ok(None), + }) + .collect::>()?, + }; + Ok(Arc::new(output_array) as ArrayRef) + }}; +} + +macro_rules! 
cast_float_to_int32_up { + ( + $array:expr, + $eval_mode:expr, + $src_array_type:ty, + $dest_array_type:ty, + $rust_src_type:ty, + $rust_dest_type:ty, + $src_type_str:expr, + $dest_type_str:expr, + $max_dest_val:expr, + $format_str:expr + ) => {{ + let cast_array = $array + .as_any() + .downcast_ref::<$src_array_type>() + .expect(concat!("Expected a ", stringify!($src_array_type))); + + let output_array = match $eval_mode { + EvalMode::Ansi => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let is_overflow = value.is_nan() + || value.abs() as $rust_dest_type == $max_dest_val; + if is_overflow { + return Err(cast_overflow( + &format!($format_str, value).replace("e", "E"), + $src_type_str, + $dest_type_str, + )); + } + Ok(Some(value as $rust_dest_type)) + } + None => Ok(None), + }) + .collect::>()?, + _ => cast_array + .iter() + .map(|value| match value { + Some(value) => Ok::, SparkError>(Some( + value as $rust_dest_type, + )), + None => Ok(None), + }) + .collect::>()?, + }; + Ok(Arc::new(output_array) as ArrayRef) + }}; +} + +// When Spark casts to Byte/Short Types, it does not cast directly to Byte/Short. +// It casts to Int first and then to Byte/Short. Because of potential overflows in the Int cast, +// this can cause unexpected Short/Byte cast results. Replicate this behavior. +macro_rules! cast_decimal_to_int16_down { + ( + $array:expr, + $eval_mode:expr, + $dest_array_type:ty, + $rust_dest_type:ty, + $dest_type_str:expr, + $precision:expr, + $scale:expr + ) => {{ + let cast_array = $array + .as_any() + .downcast_ref::() + .expect(concat!("Expected a Decimal128ArrayType")); + + let output_array = match $eval_mode { + EvalMode::Ansi => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let divisor = 10_i128.pow($scale as u32); + let (truncated, decimal) = + (value / divisor, (value % divisor).abs()); + let is_overflow = truncated.abs() > i32::MAX.into(); + if is_overflow { + return Err(cast_overflow( + &format!("{}.{}BD", truncated, decimal), + &format!("DECIMAL({},{})", $precision, $scale), + $dest_type_str, + )); + } + let i32_value = truncated as i32; + <$rust_dest_type>::try_from(i32_value) + .map_err(|_| { + cast_overflow( + &format!("{}.{}BD", truncated, decimal), + &format!("DECIMAL({},{})", $precision, $scale), + $dest_type_str, + ) + }) + .map(Some) + } + None => Ok(None), + }) + .collect::>()?, + _ => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let divisor = 10_i128.pow($scale as u32); + let i32_value = (value / divisor) as i32; + Ok::, SparkError>(Some( + i32_value as $rust_dest_type, + )) + } + None => Ok(None), + }) + .collect::>()?, + }; + Ok(Arc::new(output_array) as ArrayRef) + }}; +} + +macro_rules! 
cast_decimal_to_int32_up { + ( + $array:expr, + $eval_mode:expr, + $dest_array_type:ty, + $rust_dest_type:ty, + $dest_type_str:expr, + $max_dest_val:expr, + $precision:expr, + $scale:expr + ) => {{ + let cast_array = $array + .as_any() + .downcast_ref::() + .expect(concat!("Expected a Decimal128ArrayType")); + + let output_array = match $eval_mode { + EvalMode::Ansi => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let divisor = 10_i128.pow($scale as u32); + let (truncated, decimal) = + (value / divisor, (value % divisor).abs()); + let is_overflow = truncated.abs() > $max_dest_val.into(); + if is_overflow { + return Err(cast_overflow( + &format!("{}.{}BD", truncated, decimal), + &format!("DECIMAL({},{})", $precision, $scale), + $dest_type_str, + )); + } + Ok(Some(truncated as $rust_dest_type)) + } + None => Ok(None), + }) + .collect::>()?, + _ => cast_array + .iter() + .map(|value| match value { + Some(value) => { + let divisor = 10_i128.pow($scale as u32); + let truncated = value / divisor; + Ok::, SparkError>(Some( + truncated as $rust_dest_type, + )) + } + None => Ok(None), + }) + .collect::>()?, + }; + Ok(Arc::new(output_array) as ArrayRef) + }}; +} + +impl Cast { + pub fn new( + child: Arc, + data_type: DataType, + cast_options: SparkCastOptions, + ) -> Self { + Self { + child, + data_type, + cast_options, + } + } +} + +/// Spark cast options +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +pub struct SparkCastOptions { + /// Spark evaluation mode + pub eval_mode: EvalMode, + /// When cast from/to timezone related types, we need timezone, which will be resolved with + /// session local timezone by an analyzer in Spark. + // TODO we should change timezone to Tz to avoid repeated parsing + pub timezone: String, + /// Allow casts that are supported but not guaranteed to be 100% compatible + pub allow_incompat: bool, + /// Support casting unsigned ints to signed ints (used by Parquet SchemaAdapter) + pub allow_cast_unsigned_ints: bool, + /// We also use the cast logic for adapting Parquet schemas, so this flag is used + /// for that use case + pub is_adapting_schema: bool, +} + +impl SparkCastOptions { + pub fn new(eval_mode: EvalMode, timezone: &str, allow_incompat: bool) -> Self { + Self { + eval_mode, + timezone: timezone.to_string(), + allow_incompat, + allow_cast_unsigned_ints: false, + is_adapting_schema: false, + } + } + + pub fn new_without_timezone(eval_mode: EvalMode, allow_incompat: bool) -> Self { + Self { + eval_mode, + timezone: "".to_string(), + allow_incompat, + allow_cast_unsigned_ints: false, + is_adapting_schema: false, + } + } +} + +/// Spark-compatible cast implementation. Defers to DataFusion's cast where that is known +/// to be compatible, and returns an error when a not supported and not DF-compatible cast +/// is requested. +pub fn spark_cast( + arg: ColumnarValue, + data_type: &DataType, + cast_options: &SparkCastOptions, +) -> DataFusionResult { + match arg { + ColumnarValue::Array(array) => Ok(ColumnarValue::Array(cast_array( + array, + data_type, + cast_options, + )?)), + ColumnarValue::Scalar(scalar) => { + // Note that normally CAST(scalar) should be fold in Spark JVM side. However, for + // some cases e.g., scalar subquery, Spark will not fold it, so we need to handle it + // here. 
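+ // The scalar is handled by round-tripping through an array: convert it to a
+ // single-element array, reuse the same `cast_array` path as the Array case,
+ // and then read element 0 back out as a scalar.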
+ let array = scalar.to_array()?; + let scalar = ScalarValue::try_from_array( + &cast_array(array, data_type, cast_options)?, + 0, + )?; + Ok(ColumnarValue::Scalar(scalar)) + } + } +} + +fn cast_array( + array: ArrayRef, + to_type: &DataType, + cast_options: &SparkCastOptions, +) -> DataFusionResult { + use DataType::*; + let array = array_with_timezone(array, cast_options.timezone.clone(), Some(to_type))?; + let from_type = array.data_type().clone(); + + let array = match &from_type { + Dictionary(key_type, value_type) + if key_type.as_ref() == &Int32 + && (value_type.as_ref() == &Utf8 + || value_type.as_ref() == &LargeUtf8) => + { + let dict_array = array + .as_any() + .downcast_ref::>() + .expect("Expected a dictionary array"); + + let casted_dictionary = DictionaryArray::::new( + dict_array.keys().clone(), + cast_array(Arc::clone(dict_array.values()), to_type, cast_options)?, + ); + + let casted_result = match to_type { + Dictionary(_, _) => Arc::new(casted_dictionary.clone()), + _ => take(casted_dictionary.values().as_ref(), dict_array.keys(), None)?, + }; + return Ok(spark_cast_postprocess(casted_result, &from_type, to_type)); + } + _ => array, + }; + let from_type = array.data_type(); + let eval_mode = cast_options.eval_mode; + + let cast_result = match (from_type, to_type) { + (Utf8, Boolean) => spark_cast_utf8_to_boolean::(&array, eval_mode), + (LargeUtf8, Boolean) => spark_cast_utf8_to_boolean::(&array, eval_mode), + (Utf8, Timestamp(_, _)) => { + cast_string_to_timestamp(&array, to_type, eval_mode, &cast_options.timezone) + } + (Utf8, Date32) => cast_string_to_date(&array, to_type, eval_mode), + (Int64, Int32) + | (Int64, Int16) + | (Int64, Int8) + | (Int32, Int16) + | (Int32, Int8) + | (Int16, Int8) + if eval_mode != EvalMode::Try => + { + spark_cast_int_to_int(&array, eval_mode, from_type, to_type) + } + (Utf8, Int8 | Int16 | Int32 | Int64) => { + cast_string_to_int::(to_type, &array, eval_mode) + } + (LargeUtf8, Int8 | Int16 | Int32 | Int64) => { + cast_string_to_int::(to_type, &array, eval_mode) + } + (Float64, Utf8) => spark_cast_float64_to_utf8::(&array, eval_mode), + (Float64, LargeUtf8) => spark_cast_float64_to_utf8::(&array, eval_mode), + (Float32, Utf8) => spark_cast_float32_to_utf8::(&array, eval_mode), + (Float32, LargeUtf8) => spark_cast_float32_to_utf8::(&array, eval_mode), + (Float32, Decimal128(precision, scale)) => { + cast_float32_to_decimal128(&array, *precision, *scale, eval_mode) + } + (Float64, Decimal128(precision, scale)) => { + cast_float64_to_decimal128(&array, *precision, *scale, eval_mode) + } + (Float32, Int8) + | (Float32, Int16) + | (Float32, Int32) + | (Float32, Int64) + | (Float64, Int8) + | (Float64, Int16) + | (Float64, Int32) + | (Float64, Int64) + | (Decimal128(_, _), Int8) + | (Decimal128(_, _), Int16) + | (Decimal128(_, _), Int32) + | (Decimal128(_, _), Int64) + if eval_mode != EvalMode::Try => + { + spark_cast_nonintegral_numeric_to_integral( + &array, eval_mode, from_type, to_type, + ) + } + (Struct(_), Utf8) => Ok(casts_struct_to_string(array.as_struct(), cast_options)?), + (Struct(_), Struct(_)) => Ok(cast_struct_to_struct( + array.as_struct(), + from_type, + to_type, + cast_options, + )?), + (UInt8 | UInt16 | UInt32 | UInt64, Int8 | Int16 | Int32 | Int64) + if cast_options.allow_cast_unsigned_ints => + { + Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?) 
+ } + _ if cast_options.is_adapting_schema + || is_datafusion_spark_compatible( + from_type, + to_type, + cast_options.allow_incompat, + ) => + { + // use DataFusion cast only when we know that it is compatible with Spark + Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?) + } + _ => { + // we should never reach this code because the Scala code should be checking + // for supported cast operations and falling back to Spark for anything that + // is not yet supported + Err(SparkError::Internal(format!( + "Native cast invoked for unsupported cast from {from_type:?} to {to_type:?}" + ))) + } + }; + Ok(spark_cast_postprocess(cast_result?, from_type, to_type)) +} + +/// Determines if DataFusion supports the given cast in a way that is +/// compatible with Spark +fn is_datafusion_spark_compatible( + from_type: &DataType, + to_type: &DataType, + allow_incompat: bool, +) -> bool { + if from_type == to_type { + return true; + } + match from_type { + DataType::Null => { + matches!(to_type, DataType::List(_)) + } + DataType::Boolean => matches!( + to_type, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Utf8 + ), + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => { + // note that the cast from Int32/Int64 -> Decimal128 here is actually + // not compatible with Spark (no overflow checks) but we have tests that + // rely on this cast working so we have to leave it here for now + matches!( + to_type, + DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal128(_, _) + | DataType::Utf8 + ) + } + DataType::Float32 | DataType::Float64 => matches!( + to_type, + DataType::Boolean + | DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + ), + DataType::Decimal128(_, _) | DataType::Decimal256(_, _) => matches!( + to_type, + DataType::Int8 + | DataType::Int16 + | DataType::Int32 + | DataType::Int64 + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal128(_, _) + | DataType::Decimal256(_, _) + | DataType::Utf8 // note that there can be formatting differences + ), + DataType::Utf8 if allow_incompat => matches!( + to_type, + DataType::Binary + | DataType::Float32 + | DataType::Float64 + | DataType::Decimal128(_, _) + ), + DataType::Utf8 => matches!(to_type, DataType::Binary), + DataType::Date32 => matches!(to_type, DataType::Utf8), + DataType::Timestamp(_, _) => { + matches!( + to_type, + DataType::Int64 + | DataType::Date32 + | DataType::Utf8 + | DataType::Timestamp(_, _) + ) + } + DataType::Binary => { + // note that this is not completely Spark compatible because + // DataFusion only supports binary data containing valid UTF-8 strings + matches!(to_type, DataType::Utf8) + } + _ => false, + } +} + +/// Cast between struct types based on logic in +/// `org.apache.spark.sql.catalyst.expressions.Cast#castStruct`. 
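+///
+/// Target fields are matched to source fields by name and each matched column is cast
+/// recursively via `cast_array`, so nested structs are converted as well; the source
+/// array's null buffer is reused for the resulting `StructArray`.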
+fn cast_struct_to_struct( + array: &StructArray, + from_type: &DataType, + to_type: &DataType, + cast_options: &SparkCastOptions, +) -> DataFusionResult { + match (from_type, to_type) { + (DataType::Struct(from_fields), DataType::Struct(to_fields)) => { + // TODO some of this logic may be specific to converting Parquet to Spark + let mut field_name_to_index_map = HashMap::new(); + for (i, field) in from_fields.iter().enumerate() { + field_name_to_index_map.insert(field.name(), i); + } + assert_eq!(field_name_to_index_map.len(), from_fields.len()); + let mut cast_fields: Vec = Vec::with_capacity(to_fields.len()); + for i in 0..to_fields.len() { + let from_index = field_name_to_index_map[to_fields[i].name()]; + let cast_field = cast_array( + Arc::clone(array.column(from_index)), + to_fields[i].data_type(), + cast_options, + )?; + cast_fields.push(cast_field); + } + Ok(Arc::new(StructArray::new( + to_fields.clone(), + cast_fields, + array.nulls().cloned(), + ))) + } + _ => unreachable!(), + } +} + +fn casts_struct_to_string( + array: &StructArray, + spark_cast_options: &SparkCastOptions, +) -> DataFusionResult { + // cast each field to a string + let string_arrays: Vec = array + .columns() + .iter() + .map(|arr| { + spark_cast( + ColumnarValue::Array(Arc::clone(arr)), + &DataType::Utf8, + spark_cast_options, + ) + .and_then(|cv| cv.into_array(arr.len())) + }) + .collect::>>()?; + let string_arrays: Vec<&StringArray> = + string_arrays.iter().map(|arr| arr.as_string()).collect(); + // build the struct string containing entries in the format `"field_name":field_value` + let mut builder = StringBuilder::with_capacity(array.len(), array.len() * 16); + let mut str = String::with_capacity(array.len() * 16); + for row_index in 0..array.len() { + if array.is_null(row_index) { + builder.append_null(); + } else { + str.clear(); + let mut any_fields_written = false; + str.push('{'); + for field in &string_arrays { + if any_fields_written { + str.push_str(", "); + } + if field.is_null(row_index) { + str.push_str("null"); + } else { + str.push_str(field.value(row_index)); + } + any_fields_written = true; + } + str.push('}'); + builder.append_value(&str); + } + } + Ok(Arc::new(builder.finish())) +} + +fn cast_string_to_int( + to_type: &DataType, + array: &ArrayRef, + eval_mode: EvalMode, +) -> SparkResult { + let string_array = array + .as_any() + .downcast_ref::>() + .expect("cast_string_to_int expected a string array"); + + let cast_array: ArrayRef = match to_type { + DataType::Int8 => { + cast_utf8_to_int!(string_array, eval_mode, Int8Type, cast_string_to_i8)? + } + DataType::Int16 => { + cast_utf8_to_int!(string_array, eval_mode, Int16Type, cast_string_to_i16)? + } + DataType::Int32 => { + cast_utf8_to_int!(string_array, eval_mode, Int32Type, cast_string_to_i32)? + } + DataType::Int64 => { + cast_utf8_to_int!(string_array, eval_mode, Int64Type, cast_string_to_i64)? 
+ } + dt => unreachable!( + "{}", + format!("invalid integer type {dt} in cast from string") + ), + }; + Ok(cast_array) +} + +fn cast_string_to_date( + array: &ArrayRef, + to_type: &DataType, + eval_mode: EvalMode, +) -> SparkResult { + let string_array = array + .as_any() + .downcast_ref::>() + .expect("Expected a string array"); + + if to_type != &DataType::Date32 { + unreachable!("Invalid data type {:?} in cast from string", to_type); + } + + let len = string_array.len(); + let mut cast_array = PrimitiveArray::::builder(len); + + for i in 0..len { + let value = if string_array.is_null(i) { + None + } else { + match date_parser(string_array.value(i), eval_mode) { + Ok(Some(cast_value)) => Some(cast_value), + Ok(None) => None, + Err(e) => return Err(e), + } + }; + + match value { + Some(cast_value) => cast_array.append_value(cast_value), + None => cast_array.append_null(), + } + } + + Ok(Arc::new(cast_array.finish()) as ArrayRef) +} + +fn cast_string_to_timestamp( + array: &ArrayRef, + to_type: &DataType, + eval_mode: EvalMode, + timezone_str: &str, +) -> SparkResult { + let string_array = array + .as_any() + .downcast_ref::>() + .expect("Expected a string array"); + + let tz = &timezone::Tz::from_str(timezone_str).unwrap(); + + let cast_array: ArrayRef = match to_type { + DataType::Timestamp(_, _) => { + cast_utf8_to_timestamp!( + string_array, + eval_mode, + TimestampMicrosecondType, + timestamp_parser, + tz + ) + } + _ => unreachable!("Invalid data type {:?} in cast from string", to_type), + }; + Ok(cast_array) +} + +fn cast_float64_to_decimal128( + array: &dyn Array, + precision: u8, + scale: i8, + eval_mode: EvalMode, +) -> SparkResult { + cast_floating_point_to_decimal128::(array, precision, scale, eval_mode) +} + +fn cast_float32_to_decimal128( + array: &dyn Array, + precision: u8, + scale: i8, + eval_mode: EvalMode, +) -> SparkResult { + cast_floating_point_to_decimal128::(array, precision, scale, eval_mode) +} + +fn cast_floating_point_to_decimal128( + array: &dyn Array, + precision: u8, + scale: i8, + eval_mode: EvalMode, +) -> SparkResult +where + ::Native: AsPrimitive, +{ + let input = array.as_any().downcast_ref::>().unwrap(); + let mut cast_array = PrimitiveArray::::builder(input.len()); + + let mul = 10_f64.powi(scale as i32); + + for i in 0..input.len() { + if input.is_null(i) { + cast_array.append_null(); + } else { + let input_value = input.value(i).as_(); + let value = (input_value * mul).round().to_i128(); + + match value { + Some(v) => { + if Decimal128Type::validate_decimal_precision(v, precision).is_err() { + if eval_mode == EvalMode::Ansi { + return Err(SparkError::NumericValueOutOfRange { + value: input_value.to_string(), + precision, + scale, + }); + } else { + cast_array.append_null(); + } + } + cast_array.append_value(v); + } + None => { + if eval_mode == EvalMode::Ansi { + return Err(SparkError::NumericValueOutOfRange { + value: input_value.to_string(), + precision, + scale, + }); + } else { + cast_array.append_null(); + } + } + } + } + } + + let res = Arc::new( + cast_array + .with_precision_and_scale(precision, scale)? 
+ .finish(), + ) as ArrayRef; + Ok(res) +} + +fn spark_cast_float64_to_utf8( + from: &dyn Array, + _eval_mode: EvalMode, +) -> SparkResult +where + OffsetSize: OffsetSizeTrait, +{ + cast_float_to_string!(from, _eval_mode, f64, Float64Array, OffsetSize) +} + +fn spark_cast_float32_to_utf8( + from: &dyn Array, + _eval_mode: EvalMode, +) -> SparkResult +where + OffsetSize: OffsetSizeTrait, +{ + cast_float_to_string!(from, _eval_mode, f32, Float32Array, OffsetSize) +} + +fn spark_cast_int_to_int( + array: &dyn Array, + eval_mode: EvalMode, + from_type: &DataType, + to_type: &DataType, +) -> SparkResult { + match (from_type, to_type) { + (DataType::Int64, DataType::Int32) => cast_int_to_int_macro!( + array, eval_mode, Int64Type, Int32Type, from_type, i32, "BIGINT", "INT" + ), + (DataType::Int64, DataType::Int16) => cast_int_to_int_macro!( + array, eval_mode, Int64Type, Int16Type, from_type, i16, "BIGINT", "SMALLINT" + ), + (DataType::Int64, DataType::Int8) => cast_int_to_int_macro!( + array, eval_mode, Int64Type, Int8Type, from_type, i8, "BIGINT", "TINYINT" + ), + (DataType::Int32, DataType::Int16) => cast_int_to_int_macro!( + array, eval_mode, Int32Type, Int16Type, from_type, i16, "INT", "SMALLINT" + ), + (DataType::Int32, DataType::Int8) => cast_int_to_int_macro!( + array, eval_mode, Int32Type, Int8Type, from_type, i8, "INT", "TINYINT" + ), + (DataType::Int16, DataType::Int8) => cast_int_to_int_macro!( + array, eval_mode, Int16Type, Int8Type, from_type, i8, "SMALLINT", "TINYINT" + ), + _ => unreachable!( + "{}", + format!("invalid integer type {to_type} in cast from {from_type}") + ), + } +} + +fn spark_cast_utf8_to_boolean( + from: &dyn Array, + eval_mode: EvalMode, +) -> SparkResult +where + OffsetSize: OffsetSizeTrait, +{ + let array = from + .as_any() + .downcast_ref::>() + .unwrap(); + + let output_array = array + .iter() + .map(|value| match value { + Some(value) => match value.to_ascii_lowercase().trim() { + "t" | "true" | "y" | "yes" | "1" => Ok(Some(true)), + "f" | "false" | "n" | "no" | "0" => Ok(Some(false)), + _ if eval_mode == EvalMode::Ansi => Err(SparkError::CastInvalidValue { + value: value.to_string(), + from_type: "STRING".to_string(), + to_type: "BOOLEAN".to_string(), + }), + _ => Ok(None), + }, + _ => Ok(None), + }) + .collect::>()?; + + Ok(Arc::new(output_array)) +} + +fn spark_cast_nonintegral_numeric_to_integral( + array: &dyn Array, + eval_mode: EvalMode, + from_type: &DataType, + to_type: &DataType, +) -> SparkResult { + match (from_type, to_type) { + (DataType::Float32, DataType::Int8) => cast_float_to_int16_down!( + array, + eval_mode, + Float32Array, + Int8Array, + f32, + i8, + "FLOAT", + "TINYINT", + "{:e}" + ), + (DataType::Float32, DataType::Int16) => cast_float_to_int16_down!( + array, + eval_mode, + Float32Array, + Int16Array, + f32, + i16, + "FLOAT", + "SMALLINT", + "{:e}" + ), + (DataType::Float32, DataType::Int32) => cast_float_to_int32_up!( + array, + eval_mode, + Float32Array, + Int32Array, + f32, + i32, + "FLOAT", + "INT", + i32::MAX, + "{:e}" + ), + (DataType::Float32, DataType::Int64) => cast_float_to_int32_up!( + array, + eval_mode, + Float32Array, + Int64Array, + f32, + i64, + "FLOAT", + "BIGINT", + i64::MAX, + "{:e}" + ), + (DataType::Float64, DataType::Int8) => cast_float_to_int16_down!( + array, + eval_mode, + Float64Array, + Int8Array, + f64, + i8, + "DOUBLE", + "TINYINT", + "{:e}D" + ), + (DataType::Float64, DataType::Int16) => cast_float_to_int16_down!( + array, + eval_mode, + Float64Array, + Int16Array, + f64, + i16, + "DOUBLE", + "SMALLINT", 
+ "{:e}D" + ), + (DataType::Float64, DataType::Int32) => cast_float_to_int32_up!( + array, + eval_mode, + Float64Array, + Int32Array, + f64, + i32, + "DOUBLE", + "INT", + i32::MAX, + "{:e}D" + ), + (DataType::Float64, DataType::Int64) => cast_float_to_int32_up!( + array, + eval_mode, + Float64Array, + Int64Array, + f64, + i64, + "DOUBLE", + "BIGINT", + i64::MAX, + "{:e}D" + ), + (DataType::Decimal128(precision, scale), DataType::Int8) => { + cast_decimal_to_int16_down!( + array, eval_mode, Int8Array, i8, "TINYINT", precision, *scale + ) + } + (DataType::Decimal128(precision, scale), DataType::Int16) => { + cast_decimal_to_int16_down!( + array, eval_mode, Int16Array, i16, "SMALLINT", precision, *scale + ) + } + (DataType::Decimal128(precision, scale), DataType::Int32) => { + cast_decimal_to_int32_up!( + array, + eval_mode, + Int32Array, + i32, + "INT", + i32::MAX, + *precision, + *scale + ) + } + (DataType::Decimal128(precision, scale), DataType::Int64) => { + cast_decimal_to_int32_up!( + array, + eval_mode, + Int64Array, + i64, + "BIGINT", + i64::MAX, + *precision, + *scale + ) + } + _ => unreachable!( + "{}", + format!("invalid cast from non-integral numeric type: {from_type} to integral numeric type: {to_type}") + ), + } +} + +/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toByte +fn cast_string_to_i8(str: &str, eval_mode: EvalMode) -> SparkResult> { + Ok(cast_string_to_int_with_range_check( + str, + eval_mode, + "TINYINT", + i8::MIN as i32, + i8::MAX as i32, + )? + .map(|v| v as i8)) +} + +/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toShort +fn cast_string_to_i16(str: &str, eval_mode: EvalMode) -> SparkResult> { + Ok(cast_string_to_int_with_range_check( + str, + eval_mode, + "SMALLINT", + i16::MIN as i32, + i16::MAX as i32, + )? + .map(|v| v as i16)) +} + +/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper) +fn cast_string_to_i32(str: &str, eval_mode: EvalMode) -> SparkResult> { + do_cast_string_to_int::(str, eval_mode, "INT", i32::MIN) +} + +/// Equivalent to org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper intWrapper) +fn cast_string_to_i64(str: &str, eval_mode: EvalMode) -> SparkResult> { + do_cast_string_to_int::(str, eval_mode, "BIGINT", i64::MIN) +} + +fn cast_string_to_int_with_range_check( + str: &str, + eval_mode: EvalMode, + type_name: &str, + min: i32, + max: i32, +) -> SparkResult> { + match do_cast_string_to_int(str, eval_mode, type_name, i32::MIN)? 
{ + None => Ok(None), + Some(v) if v >= min && v <= max => Ok(Some(v)), + _ if eval_mode == EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), + _ => Ok(None), + } +} + +/// Equivalent to +/// - org.apache.spark.unsafe.types.UTF8String.toInt(IntWrapper intWrapper, boolean allowDecimal) +/// - org.apache.spark.unsafe.types.UTF8String.toLong(LongWrapper longWrapper, boolean allowDecimal) +fn do_cast_string_to_int< + T: Num + PartialOrd + Integer + CheckedSub + CheckedNeg + From + Copy, +>( + str: &str, + eval_mode: EvalMode, + type_name: &str, + min_value: T, +) -> SparkResult> { + let trimmed_str = str.trim(); + if trimmed_str.is_empty() { + return none_or_err(eval_mode, type_name, str); + } + let len = trimmed_str.len(); + let mut result: T = T::zero(); + let mut negative = false; + let radix = T::from(10); + let stop_value = min_value / radix; + let mut parse_sign_and_digits = true; + + for (i, ch) in trimmed_str.char_indices() { + if parse_sign_and_digits { + if i == 0 { + negative = ch == '-'; + let positive = ch == '+'; + if negative || positive { + if i + 1 == len { + // input string is just "+" or "-" + return none_or_err(eval_mode, type_name, str); + } + // consume this char + continue; + } + } + + if ch == '.' { + if eval_mode == EvalMode::Legacy { + // truncate decimal in legacy mode + parse_sign_and_digits = false; + continue; + } else { + return none_or_err(eval_mode, type_name, str); + } + } + + let digit = if ch.is_ascii_digit() { + (ch as u32) - ('0' as u32) + } else { + return none_or_err(eval_mode, type_name, str); + }; + + // We are going to process the new digit and accumulate the result. However, before + // doing this, if the result is already smaller than the + // stopValue(Integer.MIN_VALUE / radix), then result * 10 will definitely be + // smaller than minValue, and we can stop + if result < stop_value { + return none_or_err(eval_mode, type_name, str); + } + + // Since the previous result is greater than or equal to stopValue(Integer.MIN_VALUE / + // radix), we can just use `result > 0` to check overflow. 
If result + // overflows, we should stop + let v = result * radix; + let digit = (digit as i32).into(); + match v.checked_sub(&digit) { + Some(x) if x <= T::zero() => result = x, + _ => { + return none_or_err(eval_mode, type_name, str); + } + } + } else { + // make sure fractional digits are valid digits but ignore them + if !ch.is_ascii_digit() { + return none_or_err(eval_mode, type_name, str); + } + } + } + + if !negative { + if let Some(neg) = result.checked_neg() { + if neg < T::zero() { + return none_or_err(eval_mode, type_name, str); + } + result = neg; + } else { + return none_or_err(eval_mode, type_name, str); + } + } + + Ok(Some(result)) +} + +/// Either return Ok(None) or Err(SparkError::CastInvalidValue) depending on the evaluation mode +#[inline] +fn none_or_err( + eval_mode: EvalMode, + type_name: &str, + str: &str, +) -> SparkResult> { + match eval_mode { + EvalMode::Ansi => Err(invalid_value(str, "STRING", type_name)), + _ => Ok(None), + } +} + +#[inline] +fn invalid_value(value: &str, from_type: &str, to_type: &str) -> SparkError { + SparkError::CastInvalidValue { + value: value.to_string(), + from_type: from_type.to_string(), + to_type: to_type.to_string(), + } +} + +#[inline] +fn cast_overflow(value: &str, from_type: &str, to_type: &str) -> SparkError { + SparkError::CastOverFlow { + value: value.to_string(), + from_type: from_type.to_string(), + to_type: to_type.to_string(), + } +} + +impl Display for Cast { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Cast [data_type: {}, timezone: {}, child: {}, eval_mode: {:?}]", + self.data_type, + self.cast_options.timezone, + self.child, + &self.cast_options.eval_mode + ) + } +} + +impl PhysicalExpr for Cast { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _: &Schema) -> DataFusionResult { + Ok(self.data_type.clone()) + } + + fn nullable(&self, _: &Schema) -> DataFusionResult { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { + let arg = self.child.evaluate(batch)?; + spark_cast(arg, &self.data_type, &self.cast_options) + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion_common::Result> { + match children.len() { + 1 => Ok(Arc::new(Cast::new( + Arc::clone(&children[0]), + self.data_type.clone(), + self.cast_options.clone(), + ))), + _ => internal_err!("Cast should have exactly one child"), + } + } +} + +fn timestamp_parser( + value: &str, + eval_mode: EvalMode, + tz: &T, +) -> SparkResult> { + let value = value.trim(); + if value.is_empty() { + return Ok(None); + } + // Define regex patterns and corresponding parsing functions + let patterns = &[ + ( + Regex::new(r"^\d{4,5}$").unwrap(), + parse_str_to_year_timestamp as fn(&str, &T) -> SparkResult>, + ), + ( + Regex::new(r"^\d{4,5}-\d{2}$").unwrap(), + parse_str_to_month_timestamp, + ), + ( + Regex::new(r"^\d{4,5}-\d{2}-\d{2}$").unwrap(), + parse_str_to_day_timestamp, + ), + ( + Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{1,2}$").unwrap(), + parse_str_to_hour_timestamp, + ), + ( + Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}$").unwrap(), + parse_str_to_minute_timestamp, + ), + ( + Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}$").unwrap(), + parse_str_to_second_timestamp, + ), + ( + Regex::new(r"^\d{4,5}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d{1,6}$").unwrap(), + parse_str_to_microsecond_timestamp, + ), + ( + Regex::new(r"^T\d{1,2}$").unwrap(), + parse_str_to_time_only_timestamp, + ), + ]; + + let mut timestamp = 
None; + + // Iterate through patterns and try matching + for (pattern, parse_func) in patterns { + if pattern.is_match(value) { + timestamp = parse_func(value, tz)?; + break; + } + } + + if timestamp.is_none() { + return if eval_mode == EvalMode::Ansi { + Err(SparkError::CastInvalidValue { + value: value.to_string(), + from_type: "STRING".to_string(), + to_type: "TIMESTAMP".to_string(), + }) + } else { + Ok(None) + }; + } + + match timestamp { + Some(ts) => Ok(Some(ts)), + None => Err(SparkError::Internal( + "Failed to parse timestamp".to_string(), + )), + } +} + +fn parse_timestamp_to_micros( + timestamp_info: &TimeStampInfo, + tz: &T, +) -> SparkResult> { + let datetime = tz.with_ymd_and_hms( + timestamp_info.year, + timestamp_info.month, + timestamp_info.day, + timestamp_info.hour, + timestamp_info.minute, + timestamp_info.second, + ); + + // Check if datetime is not None + let tz_datetime = match datetime.single() { + Some(dt) => dt + .with_timezone(tz) + .with_nanosecond(timestamp_info.microsecond * 1000), + None => { + return Err(SparkError::Internal( + "Failed to parse timestamp".to_string(), + )); + } + }; + + let result = match tz_datetime { + Some(dt) => dt.timestamp_micros(), + None => { + return Err(SparkError::Internal( + "Failed to parse timestamp".to_string(), + )); + } + }; + + Ok(Some(result)) +} + +fn get_timestamp_values( + value: &str, + timestamp_type: &str, + tz: &T, +) -> SparkResult> { + let values: Vec<_> = value.split(['T', '-', ':', '.']).collect(); + let year = values[0].parse::().unwrap_or_default(); + let month = values.get(1).map_or(1, |m| m.parse::().unwrap_or(1)); + let day = values.get(2).map_or(1, |d| d.parse::().unwrap_or(1)); + let hour = values.get(3).map_or(0, |h| h.parse::().unwrap_or(0)); + let minute = values.get(4).map_or(0, |m| m.parse::().unwrap_or(0)); + let second = values.get(5).map_or(0, |s| s.parse::().unwrap_or(0)); + let microsecond = values.get(6).map_or(0, |ms| ms.parse::().unwrap_or(0)); + + let mut timestamp_info = TimeStampInfo::default(); + + let timestamp_info = match timestamp_type { + "year" => timestamp_info.with_year(year), + "month" => timestamp_info.with_year(year).with_month(month), + "day" => timestamp_info + .with_year(year) + .with_month(month) + .with_day(day), + "hour" => timestamp_info + .with_year(year) + .with_month(month) + .with_day(day) + .with_hour(hour), + "minute" => timestamp_info + .with_year(year) + .with_month(month) + .with_day(day) + .with_hour(hour) + .with_minute(minute), + "second" => timestamp_info + .with_year(year) + .with_month(month) + .with_day(day) + .with_hour(hour) + .with_minute(minute) + .with_second(second), + "microsecond" => timestamp_info + .with_year(year) + .with_month(month) + .with_day(day) + .with_hour(hour) + .with_minute(minute) + .with_second(second) + .with_microsecond(microsecond), + _ => { + return Err(SparkError::CastInvalidValue { + value: value.to_string(), + from_type: "STRING".to_string(), + to_type: "TIMESTAMP".to_string(), + }) + } + }; + + parse_timestamp_to_micros(timestamp_info, tz) +} + +fn parse_str_to_year_timestamp( + value: &str, + tz: &T, +) -> SparkResult> { + get_timestamp_values(value, "year", tz) +} + +fn parse_str_to_month_timestamp( + value: &str, + tz: &T, +) -> SparkResult> { + get_timestamp_values(value, "month", tz) +} + +fn parse_str_to_day_timestamp( + value: &str, + tz: &T, +) -> SparkResult> { + get_timestamp_values(value, "day", tz) +} + +fn parse_str_to_hour_timestamp( + value: &str, + tz: &T, +) -> SparkResult> { + get_timestamp_values(value, 
"hour", tz) +} + +fn parse_str_to_minute_timestamp( + value: &str, + tz: &T, +) -> SparkResult> { + get_timestamp_values(value, "minute", tz) +} + +fn parse_str_to_second_timestamp( + value: &str, + tz: &T, +) -> SparkResult> { + get_timestamp_values(value, "second", tz) +} + +fn parse_str_to_microsecond_timestamp( + value: &str, + tz: &T, +) -> SparkResult> { + get_timestamp_values(value, "microsecond", tz) +} + +fn parse_str_to_time_only_timestamp( + value: &str, + tz: &T, +) -> SparkResult> { + let values: Vec<&str> = value.split('T').collect(); + let time_values: Vec = values[1] + .split(':') + .map(|v| v.parse::().unwrap_or(0)) + .collect(); + + let datetime = tz.from_utc_datetime(&chrono::Utc::now().naive_utc()); + let timestamp = datetime + .with_timezone(tz) + .with_hour(time_values.first().copied().unwrap_or_default()) + .and_then(|dt| dt.with_minute(*time_values.get(1).unwrap_or(&0))) + .and_then(|dt| dt.with_second(*time_values.get(2).unwrap_or(&0))) + .and_then(|dt| dt.with_nanosecond(*time_values.get(3).unwrap_or(&0) * 1_000)) + .map(|dt| dt.timestamp_micros()) + .unwrap_or_default(); + + Ok(Some(timestamp)) +} + +//a string to date parser - port of spark's SparkDateTimeUtils#stringToDate. +fn date_parser(date_str: &str, eval_mode: EvalMode) -> SparkResult> { + // local functions + fn get_trimmed_start(bytes: &[u8]) -> usize { + let mut start = 0; + while start < bytes.len() && is_whitespace_or_iso_control(bytes[start]) { + start += 1; + } + start + } + + fn get_trimmed_end(start: usize, bytes: &[u8]) -> usize { + let mut end = bytes.len() - 1; + while end > start && is_whitespace_or_iso_control(bytes[end]) { + end -= 1; + } + end + 1 + } + + fn is_whitespace_or_iso_control(byte: u8) -> bool { + byte.is_ascii_whitespace() || byte.is_ascii_control() + } + + fn is_valid_digits(segment: i32, digits: usize) -> bool { + // An integer is able to represent a date within [+-]5 million years. 
+ let max_digits_year = 7; + //year (segment 0) can be between 4 to 7 digits, + //month and day (segment 1 and 2) can be between 1 to 2 digits + (segment == 0 && digits >= 4 && digits <= max_digits_year) + || (segment != 0 && digits > 0 && digits <= 2) + } + + fn return_result(date_str: &str, eval_mode: EvalMode) -> SparkResult> { + if eval_mode == EvalMode::Ansi { + Err(SparkError::CastInvalidValue { + value: date_str.to_string(), + from_type: "STRING".to_string(), + to_type: "DATE".to_string(), + }) + } else { + Ok(None) + } + } + // end local functions + + if date_str.is_empty() { + return return_result(date_str, eval_mode); + } + + //values of date segments year, month and day defaulting to 1 + let mut date_segments = [1, 1, 1]; + let mut sign = 1; + let mut current_segment = 0; + let mut current_segment_value = Wrapping(0); + let mut current_segment_digits = 0; + let bytes = date_str.as_bytes(); + + let mut j = get_trimmed_start(bytes); + let str_end_trimmed = get_trimmed_end(j, bytes); + + if j == str_end_trimmed { + return return_result(date_str, eval_mode); + } + + //assign a sign to the date + if bytes[j] == b'-' || bytes[j] == b'+' { + sign = if bytes[j] == b'-' { -1 } else { 1 }; + j += 1; + } + + //loop to the end of string until we have processed 3 segments, + //exit loop on encountering any space ' ' or 'T' after the 3rd segment + while j < str_end_trimmed + && (current_segment < 3 && !(bytes[j] == b' ' || bytes[j] == b'T')) + { + let b = bytes[j]; + if current_segment < 2 && b == b'-' { + //check for validity of year and month segments if current byte is separator + if !is_valid_digits(current_segment, current_segment_digits) { + return return_result(date_str, eval_mode); + } + //if valid update corresponding segment with the current segment value. + date_segments[current_segment as usize] = current_segment_value.0; + current_segment_value = Wrapping(0); + current_segment_digits = 0; + current_segment += 1; + } else if !b.is_ascii_digit() { + return return_result(date_str, eval_mode); + } else { + //increment value of current segment by the next digit + let parsed_value = Wrapping((b - b'0') as i32); + current_segment_value = current_segment_value * Wrapping(10) + parsed_value; + current_segment_digits += 1; + } + j += 1; + } + + //check for validity of last segment + if !is_valid_digits(current_segment, current_segment_digits) { + return return_result(date_str, eval_mode); + } + + if current_segment < 2 && j < str_end_trimmed { + // For the `yyyy` and `yyyy-[m]m` formats, entire input must be consumed. + return return_result(date_str, eval_mode); + } + + date_segments[current_segment as usize] = current_segment_value.0; + + match NaiveDate::from_ymd_opt( + sign * date_segments[0], + date_segments[1] as u32, + date_segments[2] as u32, + ) { + Some(date) => { + let duration_since_epoch = date + .signed_duration_since(NaiveDateTime::UNIX_EPOCH.date()) + .num_days(); + Ok(Some(duration_since_epoch.to_i32().unwrap())) + } + None => Ok(None), + } +} + +/// This takes for special casting cases of Spark. E.g., Timestamp to Long. +/// This function runs as a post process of the DataFusion cast(). By the time it arrives here, +/// Dictionary arrays are already unpacked by the DataFusion cast() since Spark cannot specify +/// Dictionary as to_type. The from_type is taken before the DataFusion cast() runs in +/// expressions/cast.rs, so it can be still Dictionary. 
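+///
+/// For example, a `Timestamp(Microsecond, _)` (or a dictionary of it) cast to `Int64` is
+/// floor-divided by `MICROS_PER_SECOND` so the result is whole seconds since the epoch,
+/// and a cast to `Utf8` has trailing fractional-second zeroes trimmed.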
+fn spark_cast_postprocess( + array: ArrayRef, + from_type: &DataType, + to_type: &DataType, +) -> ArrayRef { + match (from_type, to_type) { + (DataType::Timestamp(_, _), DataType::Int64) => { + // See Spark's `Cast` expression + unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)) + .unwrap() + } + (DataType::Dictionary(_, value_type), DataType::Int64) + if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) => + { + // See Spark's `Cast` expression + unary_dyn::<_, Int64Type>(&array, |v| div_floor(v, MICROS_PER_SECOND)) + .unwrap() + } + (DataType::Timestamp(_, _), DataType::Utf8) => remove_trailing_zeroes(array), + (DataType::Dictionary(_, value_type), DataType::Utf8) + if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) => + { + remove_trailing_zeroes(array) + } + _ => array, + } +} + +/// A fork & modified version of Arrow's `unary_dyn` which is being deprecated +fn unary_dyn(array: &ArrayRef, op: F) -> Result +where + T: ArrowPrimitiveType, + F: Fn(T::Native) -> T::Native, +{ + if let Some(d) = array.as_any_dictionary_opt() { + let new_values = unary_dyn::(d.values(), op)?; + return Ok(Arc::new(d.with_values(Arc::new(new_values)))); + } + + match array.as_primitive_opt::() { + Some(a) if PrimitiveArray::::is_compatible(a.data_type()) => { + Ok(Arc::new(unary::( + array.as_any().downcast_ref::>().unwrap(), + op, + ))) + } + _ => Err(ArrowError::NotYetImplemented(format!( + "Cannot perform unary operation of type {} on array of type {}", + T::DATA_TYPE, + array.data_type() + ))), + } +} + +/// Remove any trailing zeroes in the string if they occur after in the fractional seconds, +/// to match Spark behavior +/// example: +/// "1970-01-01 05:29:59.900" => "1970-01-01 05:29:59.9" +/// "1970-01-01 05:29:59.990" => "1970-01-01 05:29:59.99" +/// "1970-01-01 05:29:59.999" => "1970-01-01 05:29:59.999" +/// "1970-01-01 05:30:00" => "1970-01-01 05:30:00" +/// "1970-01-01 05:30:00.001" => "1970-01-01 05:30:00.001" +fn remove_trailing_zeroes(array: ArrayRef) -> ArrayRef { + let string_array = as_generic_string_array::(&array).unwrap(); + let result = string_array + .iter() + .map(|s| s.map(trim_end)) + .collect::>(); + Arc::new(result) as ArrayRef +} + +fn trim_end(s: &str) -> &str { + if s.rfind('.').is_some() { + s.trim_end_matches('0') + } else { + s + } +} + +#[cfg(test)] +mod tests { + use arrow::datatypes::TimestampMicrosecondType; + use arrow_array::StringArray; + use arrow_schema::{Field, Fields, TimeUnit}; + use std::str::FromStr; + + use super::*; + + #[test] + #[cfg_attr(miri, ignore)] // test takes too long with miri + fn timestamp_parser_test() { + let tz = &timezone::Tz::from_str("UTC").unwrap(); + // write for all formats + assert_eq!( + timestamp_parser("2020", EvalMode::Legacy, tz).unwrap(), + Some(1577836800000000) // this is in milliseconds + ); + assert_eq!( + timestamp_parser("2020-01", EvalMode::Legacy, tz).unwrap(), + Some(1577836800000000) + ); + assert_eq!( + timestamp_parser("2020-01-01", EvalMode::Legacy, tz).unwrap(), + Some(1577836800000000) + ); + assert_eq!( + timestamp_parser("2020-01-01T12", EvalMode::Legacy, tz).unwrap(), + Some(1577880000000000) + ); + assert_eq!( + timestamp_parser("2020-01-01T12:34", EvalMode::Legacy, tz).unwrap(), + Some(1577882040000000) + ); + assert_eq!( + timestamp_parser("2020-01-01T12:34:56", EvalMode::Legacy, tz).unwrap(), + Some(1577882096000000) + ); + assert_eq!( + timestamp_parser("2020-01-01T12:34:56.123456", EvalMode::Legacy, tz).unwrap(), + Some(1577882096123456) + ); + assert_eq!( + 
timestamp_parser("0100", EvalMode::Legacy, tz).unwrap(), + Some(-59011459200000000) + ); + assert_eq!( + timestamp_parser("0100-01", EvalMode::Legacy, tz).unwrap(), + Some(-59011459200000000) + ); + assert_eq!( + timestamp_parser("0100-01-01", EvalMode::Legacy, tz).unwrap(), + Some(-59011459200000000) + ); + assert_eq!( + timestamp_parser("0100-01-01T12", EvalMode::Legacy, tz).unwrap(), + Some(-59011416000000000) + ); + assert_eq!( + timestamp_parser("0100-01-01T12:34", EvalMode::Legacy, tz).unwrap(), + Some(-59011413960000000) + ); + assert_eq!( + timestamp_parser("0100-01-01T12:34:56", EvalMode::Legacy, tz).unwrap(), + Some(-59011413904000000) + ); + assert_eq!( + timestamp_parser("0100-01-01T12:34:56.123456", EvalMode::Legacy, tz).unwrap(), + Some(-59011413903876544) + ); + assert_eq!( + timestamp_parser("10000", EvalMode::Legacy, tz).unwrap(), + Some(253402300800000000) + ); + assert_eq!( + timestamp_parser("10000-01", EvalMode::Legacy, tz).unwrap(), + Some(253402300800000000) + ); + assert_eq!( + timestamp_parser("10000-01-01", EvalMode::Legacy, tz).unwrap(), + Some(253402300800000000) + ); + assert_eq!( + timestamp_parser("10000-01-01T12", EvalMode::Legacy, tz).unwrap(), + Some(253402344000000000) + ); + assert_eq!( + timestamp_parser("10000-01-01T12:34", EvalMode::Legacy, tz).unwrap(), + Some(253402346040000000) + ); + assert_eq!( + timestamp_parser("10000-01-01T12:34:56", EvalMode::Legacy, tz).unwrap(), + Some(253402346096000000) + ); + assert_eq!( + timestamp_parser("10000-01-01T12:34:56.123456", EvalMode::Legacy, tz) + .unwrap(), + Some(253402346096123456) + ); + // assert_eq!( + // timestamp_parser("T2", EvalMode::Legacy).unwrap(), + // Some(1714356000000000) // this value needs to change everyday. + // ); + } + + #[test] + #[cfg_attr(miri, ignore)] // test takes too long with miri + fn test_cast_string_to_timestamp() { + let array: ArrayRef = Arc::new(StringArray::from(vec![ + Some("2020-01-01T12:34:56.123456"), + Some("T2"), + Some("0100-01-01T12:34:56.123456"), + Some("10000-01-01T12:34:56.123456"), + ])); + let tz = &timezone::Tz::from_str("UTC").unwrap(); + + let string_array = array + .as_any() + .downcast_ref::>() + .expect("Expected a string array"); + + let eval_mode = EvalMode::Legacy; + let result = cast_utf8_to_timestamp!( + &string_array, + eval_mode, + TimestampMicrosecondType, + timestamp_parser, + tz + ); + + assert_eq!( + result.data_type(), + &DataType::Timestamp(TimeUnit::Microsecond, Some("UTC".into())) + ); + assert_eq!(result.len(), 4); + } + + #[test] + fn test_cast_dict_string_to_timestamp() -> DataFusionResult<()> { + // prepare input data + let keys = Int32Array::from(vec![0, 1]); + let values: ArrayRef = Arc::new(StringArray::from(vec![ + Some("2020-01-01T12:34:56.123456"), + Some("T2"), + ])); + let dict_array = Arc::new(DictionaryArray::new(keys, values)); + + let timezone = "UTC".to_string(); + // test casting string dictionary array to timestamp array + let cast_options = SparkCastOptions::new(EvalMode::Legacy, &timezone, false); + let result = cast_array( + dict_array, + &DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.clone().into())), + &cast_options, + )?; + assert_eq!( + *result.data_type(), + DataType::Timestamp(TimeUnit::Microsecond, Some(timezone.into())) + ); + assert_eq!(result.len(), 2); + + Ok(()) + } + + #[test] + fn date_parser_test() { + for date in &[ + "2020", + "2020-01", + "2020-01-01", + "02020-01-01", + "002020-01-01", + "0002020-01-01", + "2020-1-1", + "2020-01-01 ", + "2020-01-01T", + ] { + for eval_mode in 
&[EvalMode::Legacy, EvalMode::Ansi, EvalMode::Try] { + assert_eq!(date_parser(date, *eval_mode).unwrap(), Some(18262)); + } + } + + //dates in invalid formats + for date in &[ + "abc", + "", + "not_a_date", + "3/", + "3/12", + "3/12/2020", + "3/12/2002 T", + "202", + "2020-010-01", + "2020-10-010", + "2020-10-010T", + "--262143-12-31", + "--262143-12-31 ", + ] { + for eval_mode in &[EvalMode::Legacy, EvalMode::Try] { + assert_eq!(date_parser(date, *eval_mode).unwrap(), None); + } + assert!(date_parser(date, EvalMode::Ansi).is_err()); + } + + for date in &["-3638-5"] { + for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] { + assert_eq!(date_parser(date, *eval_mode).unwrap(), Some(-2048160)); + } + } + + //Naive Date only supports years 262142 AD to 262143 BC + //returns None for dates out of range supported by Naive Date. + for date in &[ + "-262144-1-1", + "262143-01-1", + "262143-1-1", + "262143-01-1 ", + "262143-01-01T ", + "262143-1-01T 1234", + "-0973250", + ] { + for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] { + assert_eq!(date_parser(date, *eval_mode).unwrap(), None); + } + } + } + + #[test] + fn test_cast_string_to_date() { + let array: ArrayRef = Arc::new(StringArray::from(vec![ + Some("2020"), + Some("2020-01"), + Some("2020-01-01"), + Some("2020-01-01T"), + ])); + + let result = + cast_string_to_date(&array, &DataType::Date32, EvalMode::Legacy).unwrap(); + + let date32_array = result + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(date32_array.len(), 4); + date32_array + .iter() + .for_each(|v| assert_eq!(v.unwrap(), 18262)); + } + + #[test] + fn test_cast_string_array_with_valid_dates() { + let array_with_invalid_date: ArrayRef = Arc::new(StringArray::from(vec![ + Some("-262143-12-31"), + Some("\n -262143-12-31 "), + Some("-262143-12-31T \t\n"), + Some("\n\t-262143-12-31T\r"), + Some("-262143-12-31T 123123123"), + Some("\r\n-262143-12-31T \r123123123"), + Some("\n -262143-12-31T \n\t"), + ])); + + for eval_mode in &[EvalMode::Legacy, EvalMode::Try, EvalMode::Ansi] { + let result = cast_string_to_date( + &array_with_invalid_date, + &DataType::Date32, + *eval_mode, + ) + .unwrap(); + + let date32_array = result + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(result.len(), 7); + date32_array + .iter() + .for_each(|v| assert_eq!(v.unwrap(), -96464928)); + } + } + + #[test] + fn test_cast_string_array_with_invalid_dates() { + let array_with_invalid_date: ArrayRef = Arc::new(StringArray::from(vec![ + Some("2020"), + Some("2020-01"), + Some("2020-01-01"), + //4 invalid dates + Some("2020-010-01T"), + Some("202"), + Some(" 202 "), + Some("\n 2020-\r8 "), + Some("2020-01-01T"), + // Overflows i32 + Some("-4607172990231812908"), + ])); + + for eval_mode in &[EvalMode::Legacy, EvalMode::Try] { + let result = cast_string_to_date( + &array_with_invalid_date, + &DataType::Date32, + *eval_mode, + ) + .unwrap(); + + let date32_array = result + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!( + date32_array.iter().collect::>(), + vec![ + Some(18262), + Some(18262), + Some(18262), + None, + None, + None, + None, + Some(18262), + None + ] + ); + } + + let result = cast_string_to_date( + &array_with_invalid_date, + &DataType::Date32, + EvalMode::Ansi, + ); + match result { + Err(e) => assert!( + e.to_string().contains( + "[CAST_INVALID_INPUT] The value '2020-010-01T' of the type \"STRING\" cannot be cast to \"DATE\" because it is malformed") + ), + _ => panic!("Expected error"), + } + } + + #[test] + fn test_cast_string_as_i8() { + 
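+        // covers in-range values, overflow, and fractional input across Legacy/Try/Ansi modes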
// basic + assert_eq!( + cast_string_to_i8("127", EvalMode::Legacy).unwrap(), + Some(127_i8) + ); + assert_eq!(cast_string_to_i8("128", EvalMode::Legacy).unwrap(), None); + assert!(cast_string_to_i8("128", EvalMode::Ansi).is_err()); + // decimals + assert_eq!( + cast_string_to_i8("0.2", EvalMode::Legacy).unwrap(), + Some(0_i8) + ); + assert_eq!( + cast_string_to_i8(".", EvalMode::Legacy).unwrap(), + Some(0_i8) + ); + // TRY should always return null for decimals + assert_eq!(cast_string_to_i8("0.2", EvalMode::Try).unwrap(), None); + assert_eq!(cast_string_to_i8(".", EvalMode::Try).unwrap(), None); + // ANSI mode should throw error on decimal + assert!(cast_string_to_i8("0.2", EvalMode::Ansi).is_err()); + assert!(cast_string_to_i8(".", EvalMode::Ansi).is_err()); + } + + #[test] + fn test_cast_unsupported_timestamp_to_date() { + // Since datafusion uses chrono::Datetime internally not all dates representable by TimestampMicrosecondType are supported + let timestamps: PrimitiveArray = vec![i64::MAX].into(); + let cast_options = SparkCastOptions::new(EvalMode::Legacy, "UTC", false); + let result = cast_array( + Arc::new(timestamps.with_timezone("Europe/Copenhagen")), + &DataType::Date32, + &cast_options, + ); + assert!(result.is_err()) + } + + #[test] + fn test_cast_invalid_timezone() { + let timestamps: PrimitiveArray = vec![i64::MAX].into(); + let cast_options = + SparkCastOptions::new(EvalMode::Legacy, "Not a valid timezone", false); + let result = cast_array( + Arc::new(timestamps.with_timezone("Europe/Copenhagen")), + &DataType::Date32, + &cast_options, + ); + assert!(result.is_err()) + } + + #[test] + fn test_cast_struct_to_utf8() { + let a: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + None, + Some(4), + Some(5), + ])); + let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])); + let c: ArrayRef = Arc::new(StructArray::from(vec![ + (Arc::new(Field::new("a", DataType::Int32, true)), a), + (Arc::new(Field::new("b", DataType::Utf8, true)), b), + ])); + let string_array = cast_array( + c, + &DataType::Utf8, + &SparkCastOptions::new(EvalMode::Legacy, "UTC", false), + ) + .unwrap(); + let string_array = string_array.as_string::(); + assert_eq!(5, string_array.len()); + assert_eq!(r#"{1, a}"#, string_array.value(0)); + assert_eq!(r#"{2, b}"#, string_array.value(1)); + assert_eq!(r#"{null, c}"#, string_array.value(2)); + assert_eq!(r#"{4, d}"#, string_array.value(3)); + assert_eq!(r#"{5, e}"#, string_array.value(4)); + } + + #[test] + fn test_cast_struct_to_struct() { + let a: ArrayRef = Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + None, + Some(4), + Some(5), + ])); + let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])); + let c: ArrayRef = Arc::new(StructArray::from(vec![ + (Arc::new(Field::new("a", DataType::Int32, true)), a), + (Arc::new(Field::new("b", DataType::Utf8, true)), b), + ])); + // change type of "a" from Int32 to Utf8 + let fields = Fields::from(vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::Utf8, true), + ]); + let cast_array = spark_cast( + ColumnarValue::Array(c), + &DataType::Struct(fields), + &SparkCastOptions::new(EvalMode::Legacy, "UTC", false), + ) + .unwrap(); + if let ColumnarValue::Array(cast_array) = cast_array { + assert_eq!(5, cast_array.len()); + let a = cast_array.as_struct().column(0).as_string::(); + assert_eq!("1", a.value(0)); + } else { + unreachable!() + } + } + + #[test] + fn test_cast_struct_to_struct_drop_column() { + let a: ArrayRef = 
Arc::new(Int32Array::from(vec![ + Some(1), + Some(2), + None, + Some(4), + Some(5), + ])); + let b: ArrayRef = Arc::new(StringArray::from(vec!["a", "b", "c", "d", "e"])); + let c: ArrayRef = Arc::new(StructArray::from(vec![ + (Arc::new(Field::new("a", DataType::Int32, true)), a), + (Arc::new(Field::new("b", DataType::Utf8, true)), b), + ])); + // change type of "a" from Int32 to Utf8 and drop "b" + let fields = Fields::from(vec![Field::new("a", DataType::Utf8, true)]); + let cast_array = spark_cast( + ColumnarValue::Array(c), + &DataType::Struct(fields), + &SparkCastOptions::new(EvalMode::Legacy, "UTC", false), + ) + .unwrap(); + if let ColumnarValue::Array(cast_array) = cast_array { + assert_eq!(5, cast_array.len()); + let struct_array = cast_array.as_struct(); + assert_eq!(1, struct_array.columns().len()); + let a = struct_array.column(0).as_string::(); + assert_eq!("1", a.value(0)); + } else { + unreachable!() + } + } +} diff --git a/datafusion/spark/src/conversion_funcs/mod.rs b/datafusion/spark/src/conversion_funcs/mod.rs new file mode 100644 index 000000000000..f2c6f7ca368b --- /dev/null +++ b/datafusion/spark/src/conversion_funcs/mod.rs @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod cast; diff --git a/datafusion/spark/src/datetime_funcs/date_arithmetic.rs b/datafusion/spark/src/datetime_funcs/date_arithmetic.rs new file mode 100644 index 000000000000..602b2104e49c --- /dev/null +++ b/datafusion/spark/src/datetime_funcs/date_arithmetic.rs @@ -0,0 +1,104 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
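+
+//! Spark-compatible `date_add` / `date_sub`: the integer day argument (scalar or array)
+//! is converted to an `IntervalDayTime` so Arrow's `add`/`sub` kernels can be applied to
+//! the date operand.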
+ +use arrow::array::{ArrayRef, AsArray}; +use arrow::compute::kernels::numeric::{add, sub}; +use arrow::datatypes::IntervalDayTime; +use arrow_array::builder::IntervalDayTimeBuilder; +use arrow_array::types::{Int16Type, Int32Type, Int8Type}; +use arrow_array::{Array, Datum}; +use arrow_schema::{ArrowError, DataType}; +use datafusion::physical_expr_common::datum; +use datafusion::physical_plan::ColumnarValue; +use datafusion_common::{DataFusionError, ScalarValue}; +use std::sync::Arc; + +macro_rules! scalar_date_arithmetic { + ($start:expr, $days:expr, $op:expr) => {{ + let interval = IntervalDayTime::new(*$days as i32, 0); + let interval_cv = + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(interval))); + datum::apply($start, &interval_cv, $op) + }}; +} +macro_rules! array_date_arithmetic { + ($days:expr, $interval_builder:expr, $intType:ty) => {{ + for day in $days.as_primitive::<$intType>().into_iter() { + if let Some(non_null_day) = day { + $interval_builder + .append_value(IntervalDayTime::new(non_null_day as i32, 0)); + } else { + $interval_builder.append_null(); + } + } + }}; +} + +/// Spark-compatible `date_add` and `date_sub` expressions, which assumes days for the second +/// argument, but we cannot directly add that to a Date32. We generate an IntervalDayTime from the +/// second argument and use DataFusion's interface to apply Arrow's operators. +fn spark_date_arithmetic( + args: &[ColumnarValue], + op: impl Fn(&dyn Datum, &dyn Datum) -> Result, +) -> Result { + let start = &args[0]; + match &args[1] { + ColumnarValue::Scalar(ScalarValue::Int8(Some(days))) => { + scalar_date_arithmetic!(start, days, op) + } + ColumnarValue::Scalar(ScalarValue::Int16(Some(days))) => { + scalar_date_arithmetic!(start, days, op) + } + ColumnarValue::Scalar(ScalarValue::Int32(Some(days))) => { + scalar_date_arithmetic!(start, days, op) + } + ColumnarValue::Array(days) => { + let mut interval_builder = IntervalDayTimeBuilder::with_capacity(days.len()); + match days.data_type() { + DataType::Int8 => { + array_date_arithmetic!(days, interval_builder, Int8Type) + } + DataType::Int16 => { + array_date_arithmetic!(days, interval_builder, Int16Type) + } + DataType::Int32 => { + array_date_arithmetic!(days, interval_builder, Int32Type) + } + _ => { + return Err(DataFusionError::Internal(format!( + "Unsupported data types {:?} for date arithmetic.", + args, + ))) + } + } + let interval_cv = ColumnarValue::Array(Arc::new(interval_builder.finish())); + datum::apply(start, &interval_cv, op) + } + _ => Err(DataFusionError::Internal(format!( + "Unsupported data types {:?} for date arithmetic.", + args, + ))), + } +} + +pub fn spark_date_add(args: &[ColumnarValue]) -> Result { + spark_date_arithmetic(args, add) +} + +pub fn spark_date_sub(args: &[ColumnarValue]) -> Result { + spark_date_arithmetic(args, sub) +} diff --git a/datafusion/spark/src/datetime_funcs/date_trunc.rs b/datafusion/spark/src/datetime_funcs/date_trunc.rs new file mode 100644 index 000000000000..a3b06e6a1c0f --- /dev/null +++ b/datafusion/spark/src/datetime_funcs/date_trunc.rs @@ -0,0 +1,113 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::record_batch::RecordBatch; +use arrow_schema::{DataType, Schema}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{DataFusionError, ScalarValue::Utf8}; +use datafusion_physical_expr::PhysicalExpr; +use std::hash::Hash; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter}, + sync::Arc, +}; + +use crate::kernels::temporal::{date_trunc_array_fmt_dyn, date_trunc_dyn}; + +#[derive(Debug, Eq)] +pub struct DateTruncExpr { + /// An array with DataType::Date32 + child: Arc, + /// Scalar UTF8 string matching the valid values in Spark SQL: + format: Arc, +} + +impl Hash for DateTruncExpr { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.format.hash(state); + } +} +impl PartialEq for DateTruncExpr { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) && self.format.eq(&other.format) + } +} + +impl DateTruncExpr { + pub fn new(child: Arc, format: Arc) -> Self { + DateTruncExpr { child, format } + } +} + +impl Display for DateTruncExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "DateTrunc [child:{}, format: {}]", + self.child, self.format + ) + } +} + +impl PhysicalExpr for DateTruncExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + self.child.data_type(input_schema) + } + + fn nullable(&self, _: &Schema) -> datafusion_common::Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + let date = self.child.evaluate(batch)?; + let format = self.format.evaluate(batch)?; + match (date, format) { + (ColumnarValue::Array(date), ColumnarValue::Scalar(Utf8(Some(format)))) => { + let result = date_trunc_dyn(&date, format)?; + Ok(ColumnarValue::Array(result)) + } + (ColumnarValue::Array(date), ColumnarValue::Array(formats)) => { + let result = date_trunc_array_fmt_dyn(&date, &formats)?; + Ok(ColumnarValue::Array(result)) + } + _ => Err(DataFusionError::Execution( + "Invalid input to function DateTrunc. Expected (PrimitiveArray, Scalar) or \ + (PrimitiveArray, StringArray)".to_string(), + )), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result, DataFusionError> { + Ok(Arc::new(DateTruncExpr::new( + Arc::clone(&children[0]), + Arc::clone(&self.format), + ))) + } +} diff --git a/datafusion/spark/src/datetime_funcs/hour.rs b/datafusion/spark/src/datetime_funcs/hour.rs new file mode 100644 index 000000000000..faf9529a5130 --- /dev/null +++ b/datafusion/spark/src/datetime_funcs/hour.rs @@ -0,0 +1,122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::utils::array_with_timezone; +use arrow::{ + compute::{date_part, DatePart}, + record_batch::RecordBatch, +}; +use arrow_schema::{DataType, Schema, TimeUnit::Microsecond}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::DataFusionError; +use datafusion_physical_expr::PhysicalExpr; +use std::hash::Hash; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter}, + sync::Arc, +}; + +#[derive(Debug, Eq)] +pub struct HourExpr { + /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) + child: Arc, + timezone: String, +} + +impl Hash for HourExpr { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.timezone.hash(state); + } +} +impl PartialEq for HourExpr { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) && self.timezone.eq(&other.timezone) + } +} + +impl HourExpr { + pub fn new(child: Arc, timezone: String) -> Self { + HourExpr { child, timezone } + } +} + +impl Display for HourExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Hour [timezone:{}, child: {}]", + self.timezone, self.child + ) + } +} + +impl PhysicalExpr for HourExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + match self.child.data_type(input_schema).unwrap() { + DataType::Dictionary(key_type, _) => { + Ok(DataType::Dictionary(key_type, Box::new(DataType::Int32))) + } + _ => Ok(DataType::Int32), + } + } + + fn nullable(&self, _: &Schema) -> datafusion_common::Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + let arg = self.child.evaluate(batch)?; + match arg { + ColumnarValue::Array(array) => { + let array = array_with_timezone( + array, + self.timezone.clone(), + Some(&DataType::Timestamp( + Microsecond, + Some(self.timezone.clone().into()), + )), + )?; + let result = date_part(&array, DatePart::Hour)?; + + Ok(ColumnarValue::Array(result)) + } + _ => Err(DataFusionError::Execution( + "Hour(scalar) should be fold in Spark JVM side.".to_string(), + )), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result, DataFusionError> { + Ok(Arc::new(HourExpr::new( + Arc::clone(&children[0]), + self.timezone.clone(), + ))) + } +} diff --git a/datafusion/spark/src/datetime_funcs/minute.rs b/datafusion/spark/src/datetime_funcs/minute.rs new file mode 100644 index 000000000000..b7facc167334 --- /dev/null +++ b/datafusion/spark/src/datetime_funcs/minute.rs @@ -0,0 +1,122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::utils::array_with_timezone; +use arrow::{ + compute::{date_part, DatePart}, + record_batch::RecordBatch, +}; +use arrow_schema::{DataType, Schema, TimeUnit::Microsecond}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::DataFusionError; +use datafusion_physical_expr::PhysicalExpr; +use std::hash::Hash; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter}, + sync::Arc, +}; + +#[derive(Debug, Eq)] +pub struct MinuteExpr { + /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) + child: Arc, + timezone: String, +} + +impl Hash for MinuteExpr { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.timezone.hash(state); + } +} +impl PartialEq for MinuteExpr { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) && self.timezone.eq(&other.timezone) + } +} + +impl MinuteExpr { + pub fn new(child: Arc, timezone: String) -> Self { + MinuteExpr { child, timezone } + } +} + +impl Display for MinuteExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Minute [timezone:{}, child: {}]", + self.timezone, self.child + ) + } +} + +impl PhysicalExpr for MinuteExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + match self.child.data_type(input_schema).unwrap() { + DataType::Dictionary(key_type, _) => { + Ok(DataType::Dictionary(key_type, Box::new(DataType::Int32))) + } + _ => Ok(DataType::Int32), + } + } + + fn nullable(&self, _: &Schema) -> datafusion_common::Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + let arg = self.child.evaluate(batch)?; + match arg { + ColumnarValue::Array(array) => { + let array = array_with_timezone( + array, + self.timezone.clone(), + Some(&DataType::Timestamp( + Microsecond, + Some(self.timezone.clone().into()), + )), + )?; + let result = date_part(&array, DatePart::Minute)?; + + Ok(ColumnarValue::Array(result)) + } + _ => Err(DataFusionError::Execution( + "Minute(scalar) should be fold in Spark JVM side.".to_string(), + )), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result, DataFusionError> { + Ok(Arc::new(MinuteExpr::new( + Arc::clone(&children[0]), + self.timezone.clone(), + ))) + } +} diff --git a/datafusion/spark/src/datetime_funcs/mod.rs b/datafusion/spark/src/datetime_funcs/mod.rs new file mode 100644 index 000000000000..1f4d427282a3 --- /dev/null +++ b/datafusion/spark/src/datetime_funcs/mod.rs @@ -0,0 +1,30 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod date_arithmetic; +mod date_trunc; +mod hour; +mod minute; +mod second; +mod timestamp_trunc; + +pub use date_arithmetic::{spark_date_add, spark_date_sub}; +pub use date_trunc::DateTruncExpr; +pub use hour::HourExpr; +pub use minute::MinuteExpr; +pub use second::SecondExpr; +pub use timestamp_trunc::TimestampTruncExpr; diff --git a/datafusion/spark/src/datetime_funcs/second.rs b/datafusion/spark/src/datetime_funcs/second.rs new file mode 100644 index 000000000000..76a4dd9a2ca8 --- /dev/null +++ b/datafusion/spark/src/datetime_funcs/second.rs @@ -0,0 +1,122 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
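+
+//! `SecondExpr`: extracts the second-of-minute from a microsecond timestamp column. The
+//! values are first re-interpreted in the expression's time zone, then passed to Arrow's
+//! `date_part` kernel.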
+ +use crate::utils::array_with_timezone; +use arrow::{ + compute::{date_part, DatePart}, + record_batch::RecordBatch, +}; +use arrow_schema::{DataType, Schema, TimeUnit::Microsecond}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::DataFusionError; +use datafusion_physical_expr::PhysicalExpr; +use std::hash::Hash; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter}, + sync::Arc, +}; + +#[derive(Debug, Eq)] +pub struct SecondExpr { + /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) + child: Arc, + timezone: String, +} + +impl Hash for SecondExpr { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.timezone.hash(state); + } +} +impl PartialEq for SecondExpr { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) && self.timezone.eq(&other.timezone) + } +} + +impl SecondExpr { + pub fn new(child: Arc, timezone: String) -> Self { + SecondExpr { child, timezone } + } +} + +impl Display for SecondExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Second (timezone:{}, child: {}]", + self.timezone, self.child + ) + } +} + +impl PhysicalExpr for SecondExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + match self.child.data_type(input_schema).unwrap() { + DataType::Dictionary(key_type, _) => { + Ok(DataType::Dictionary(key_type, Box::new(DataType::Int32))) + } + _ => Ok(DataType::Int32), + } + } + + fn nullable(&self, _: &Schema) -> datafusion_common::Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + let arg = self.child.evaluate(batch)?; + match arg { + ColumnarValue::Array(array) => { + let array = array_with_timezone( + array, + self.timezone.clone(), + Some(&DataType::Timestamp( + Microsecond, + Some(self.timezone.clone().into()), + )), + )?; + let result = date_part(&array, DatePart::Second)?; + + Ok(ColumnarValue::Array(result)) + } + _ => Err(DataFusionError::Execution( + "Second(scalar) should be fold in Spark JVM side.".to_string(), + )), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result, DataFusionError> { + Ok(Arc::new(SecondExpr::new( + Arc::clone(&children[0]), + self.timezone.clone(), + ))) + } +} diff --git a/datafusion/spark/src/datetime_funcs/timestamp_trunc.rs b/datafusion/spark/src/datetime_funcs/timestamp_trunc.rs new file mode 100644 index 000000000000..bca9b8e8daab --- /dev/null +++ b/datafusion/spark/src/datetime_funcs/timestamp_trunc.rs @@ -0,0 +1,152 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
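+
+//! `TimestampTruncExpr`: truncates microsecond timestamps to the unit named by the format
+//! argument, which may be a scalar string or a per-row string array; values are shifted
+//! into the expression's time zone before truncation.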
+ +use crate::utils::array_with_timezone; +use arrow::record_batch::RecordBatch; +use arrow_schema::{DataType, Schema, TimeUnit::Microsecond}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{DataFusionError, ScalarValue::Utf8}; +use datafusion_physical_expr::PhysicalExpr; +use std::hash::Hash; +use std::{ + any::Any, + fmt::{Debug, Display, Formatter}, + sync::Arc, +}; + +use crate::kernels::temporal::{timestamp_trunc_array_fmt_dyn, timestamp_trunc_dyn}; + +#[derive(Debug, Eq)] +pub struct TimestampTruncExpr { + /// An array with DataType::Timestamp(TimeUnit::Microsecond, None) + child: Arc, + /// Scalar UTF8 string matching the valid values in Spark SQL: + format: Arc, + /// String containing a timezone name. The name must be found in the standard timezone + /// database (). The string is + /// later parsed into a chrono::TimeZone. + /// Timestamp arrays in this implementation are kept in arrays of UTC timestamps (in micros) + /// along with a single value for the associated TimeZone. The timezone offset is applied + /// just before any operations on the timestamp + timezone: String, +} + +impl Hash for TimestampTruncExpr { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.format.hash(state); + self.timezone.hash(state); + } +} +impl PartialEq for TimestampTruncExpr { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) + && self.format.eq(&other.format) + && self.timezone.eq(&other.timezone) + } +} + +impl TimestampTruncExpr { + pub fn new( + child: Arc, + format: Arc, + timezone: String, + ) -> Self { + TimestampTruncExpr { + child, + format, + timezone, + } + } +} + +impl Display for TimestampTruncExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "TimestampTrunc [child:{}, format:{}, timezone: {}]", + self.child, self.format, self.timezone + ) + } +} + +impl PhysicalExpr for TimestampTruncExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + match self.child.data_type(input_schema)? { + DataType::Dictionary(key_type, _) => Ok(DataType::Dictionary( + key_type, + Box::new(DataType::Timestamp(Microsecond, None)), + )), + _ => Ok(DataType::Timestamp(Microsecond, None)), + } + } + + fn nullable(&self, _: &Schema) -> datafusion_common::Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + let timestamp = self.child.evaluate(batch)?; + let format = self.format.evaluate(batch)?; + let tz = self.timezone.clone(); + match (timestamp, format) { + (ColumnarValue::Array(ts), ColumnarValue::Scalar(Utf8(Some(format)))) => { + let ts = array_with_timezone( + ts, + tz.clone(), + Some(&DataType::Timestamp(Microsecond, Some(tz.into()))), + )?; + let result = timestamp_trunc_dyn(&ts, format)?; + Ok(ColumnarValue::Array(result)) + } + (ColumnarValue::Array(ts), ColumnarValue::Array(formats)) => { + let ts = array_with_timezone( + ts, + tz.clone(), + Some(&DataType::Timestamp(Microsecond, Some(tz.into()))), + )?; + let result = timestamp_trunc_array_fmt_dyn(&ts, &formats)?; + Ok(ColumnarValue::Array(result)) + } + _ => Err(DataFusionError::Execution( + "Invalid input to function TimestampTrunc. 
\ + Expected (PrimitiveArray, Scalar, String) or \ + (PrimitiveArray, StringArray, String)" + .to_string(), + )), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result, DataFusionError> { + Ok(Arc::new(TimestampTruncExpr::new( + Arc::clone(&children[0]), + Arc::clone(&self.format), + self.timezone.clone(), + ))) + } +} diff --git a/datafusion/spark/src/error.rs b/datafusion/spark/src/error.rs new file mode 100644 index 000000000000..728a35a9d2e0 --- /dev/null +++ b/datafusion/spark/src/error.rs @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_schema::ArrowError; +use datafusion_common::DataFusionError; + +#[derive(thiserror::Error, Debug)] +pub enum SparkError { + // Note that this message format is based on Spark 3.4 and is more detailed than the message + // returned by Spark 3.3 + #[error("[CAST_INVALID_INPUT] The value '{value}' of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ + because it is malformed. Correct the value as per the syntax, or change its target type. \ + Use `try_cast` to tolerate malformed input and return NULL instead. If necessary \ + set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + CastInvalidValue { + value: String, + from_type: String, + to_type: String, + }, + + #[error("[NUMERIC_VALUE_OUT_OF_RANGE] {value} cannot be represented as Decimal({precision}, {scale}). If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error, and return NULL instead.")] + NumericValueOutOfRange { + value: String, + precision: u8, + scale: i8, + }, + + #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \ + due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \ + set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + CastOverFlow { + value: String, + from_type: String, + to_type: String, + }, + + #[error("[ARITHMETIC_OVERFLOW] {from_type} overflow. 
If necessary set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")] + ArithmeticOverflow { from_type: String }, + + #[error("ArrowError: {0}.")] + Arrow(ArrowError), + + #[error("InternalError: {0}.")] + Internal(String), +} + +pub type SparkResult = Result; + +impl From for SparkError { + fn from(value: ArrowError) -> Self { + SparkError::Arrow(value) + } +} + +impl From for DataFusionError { + fn from(value: SparkError) -> Self { + DataFusionError::External(Box::new(value)) + } +} diff --git a/datafusion/spark/src/hash_funcs/mod.rs b/datafusion/spark/src/hash_funcs/mod.rs new file mode 100644 index 000000000000..7649c4c5476f --- /dev/null +++ b/datafusion/spark/src/hash_funcs/mod.rs @@ -0,0 +1,25 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod murmur3; +mod sha2; +pub(super) mod utils; +mod xxhash64; + +pub use murmur3::spark_murmur3_hash; +pub use sha2::{spark_sha224, spark_sha256, spark_sha384, spark_sha512}; +pub use xxhash64::spark_xxhash64; diff --git a/datafusion/spark/src/hash_funcs/murmur3.rs b/datafusion/spark/src/hash_funcs/murmur3.rs new file mode 100644 index 000000000000..2590f716e5ab --- /dev/null +++ b/datafusion/spark/src/hash_funcs/murmur3.rs @@ -0,0 +1,287 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
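+//! Spark-compatible murmur3 hash support.
+//!
+//! A minimal calling sketch for the `spark_murmur3_hash` entry point defined below
+//! (illustrative only; the input column is an assumption, and the trailing argument
+//! is the Int32 seed, 42 being the same default seed Spark uses):
+//!
+//! ```ignore
+//! use std::sync::Arc;
+//! use arrow_array::Int64Array;
+//! use datafusion_common::ScalarValue;
+//! use datafusion_expr::ColumnarValue;
+//!
+//! let col = ColumnarValue::Array(Arc::new(Int64Array::from(vec![1_i64, 2, 3])));
+//! let seed = ColumnarValue::Scalar(ScalarValue::Int32(Some(42)));
+//! // Returns one Int32 hash per input row.
+//! let hashes = spark_murmur3_hash(&[col, seed])?;
+//! ```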
+ +use crate::create_hashes_internal; +use arrow::compute::take; +use arrow_array::types::ArrowDictionaryKeyType; +use arrow_array::{Array, ArrayRef, ArrowNativeTypeOp, DictionaryArray, Int32Array}; +use arrow_buffer::ArrowNativeType; +use datafusion_common::{internal_err, DataFusionError, ScalarValue}; +use datafusion_expr::ColumnarValue; +use std::sync::Arc; + +/// Spark compatible murmur3 hash (just `hash` in Spark) in vectorized execution fashion +pub fn spark_murmur3_hash( + args: &[ColumnarValue], +) -> Result { + let length = args.len(); + let seed = &args[length - 1]; + match seed { + ColumnarValue::Scalar(ScalarValue::Int32(Some(seed))) => { + // iterate over the arguments to find out the length of the array + let num_rows = args[0..args.len() - 1] + .iter() + .find_map(|arg| match arg { + ColumnarValue::Array(array) => Some(array.len()), + ColumnarValue::Scalar(_) => None, + }) + .unwrap_or(1); + let mut hashes: Vec = vec![0_u32; num_rows]; + hashes.fill(*seed as u32); + let arrays = args[0..args.len() - 1] + .iter() + .map(|arg| match arg { + ColumnarValue::Array(array) => Arc::clone(array), + ColumnarValue::Scalar(scalar) => { + scalar.clone().to_array_of_size(num_rows).unwrap() + } + }) + .collect::>(); + create_murmur3_hashes(&arrays, &mut hashes)?; + if num_rows == 1 { + Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some( + hashes[0] as i32, + )))) + } else { + let hashes: Vec = hashes.into_iter().map(|x| x as i32).collect(); + Ok(ColumnarValue::Array(Arc::new(Int32Array::from(hashes)))) + } + } + _ => { + internal_err!( + "The seed of function murmur3_hash must be an Int32 scalar value, but got: {:?}.", + seed + ) + } + } +} + +/// Spark-compatible murmur3 hash function +#[inline] +pub fn spark_compatible_murmur3_hash>(data: T, seed: u32) -> u32 { + #[inline] + fn mix_k1(mut k1: i32) -> i32 { + k1 = k1.mul_wrapping(0xcc9e2d51u32 as i32); + k1 = k1.rotate_left(15); + k1 = k1.mul_wrapping(0x1b873593u32 as i32); + k1 + } + + #[inline] + fn mix_h1(mut h1: i32, k1: i32) -> i32 { + h1 ^= k1; + h1 = h1.rotate_left(13); + h1 = h1.mul_wrapping(5).add_wrapping(0xe6546b64u32 as i32); + h1 + } + + #[inline] + fn fmix(mut h1: i32, len: i32) -> i32 { + h1 ^= len; + h1 ^= (h1 as u32 >> 16) as i32; + h1 = h1.mul_wrapping(0x85ebca6bu32 as i32); + h1 ^= (h1 as u32 >> 13) as i32; + h1 = h1.mul_wrapping(0xc2b2ae35u32 as i32); + h1 ^= (h1 as u32 >> 16) as i32; + h1 + } + + #[inline] + unsafe fn hash_bytes_by_int(data: &[u8], seed: u32) -> i32 { + // safety: data length must be aligned to 4 bytes + let mut h1 = seed as i32; + for i in (0..data.len()).step_by(4) { + let ints = data.as_ptr().add(i) as *const i32; + let mut half_word = ints.read_unaligned(); + if cfg!(target_endian = "big") { + half_word = half_word.reverse_bits(); + } + h1 = mix_h1(h1, mix_k1(half_word)); + } + h1 + } + let data = data.as_ref(); + let len = data.len(); + let len_aligned = len - len % 4; + + // safety: + // avoid boundary checking in performance critical codes. 
+ // all operations are guaranteed to be safe + // data is &[u8] so we do not need to check for proper alignment + unsafe { + let mut h1 = if len_aligned > 0 { + hash_bytes_by_int(&data[0..len_aligned], seed) + } else { + seed as i32 + }; + + for i in len_aligned..len { + let half_word = *data.get_unchecked(i) as i8 as i32; + h1 = mix_h1(h1, mix_k1(half_word)); + } + fmix(h1, len as i32) as u32 + } +} + +/// Hash the values in a dictionary array +fn create_hashes_dictionary( + array: &ArrayRef, + hashes_buffer: &mut [u32], + first_col: bool, +) -> datafusion_common::Result<()> { + let dict_array = array.as_any().downcast_ref::>().unwrap(); + if !first_col { + // unpack the dictionary array as each row may have a different hash input + let unpacked = take(dict_array.values().as_ref(), dict_array.keys(), None)?; + create_murmur3_hashes(&[unpacked], hashes_buffer)?; + } else { + // For the first column, hash each dictionary value once, and then use + // that computed hash for each key value to avoid a potentially + // expensive redundant hashing for large dictionary elements (e.g. strings) + let dict_values = Arc::clone(dict_array.values()); + // same initial seed as Spark + let mut dict_hashes = vec![42; dict_values.len()]; + create_murmur3_hashes(&[dict_values], &mut dict_hashes)?; + for (hash, key) in hashes_buffer.iter_mut().zip(dict_array.keys().iter()) { + if let Some(key) = key { + let idx = key.to_usize().ok_or_else(|| { + DataFusionError::Internal(format!( + "Can not convert key value {:?} to usize in dictionary of type {:?}", + key, + dict_array.data_type() + )) + })?; + *hash = dict_hashes[idx] + } // no update for Null, consistent with other hashes + } + } + Ok(()) +} + +/// Creates hash values for every row, based on the values in the +/// columns. +/// +/// The number of rows to hash is determined by `hashes_buffer.len()`. 
+/// `hashes_buffer` should be pre-sized appropriately +pub fn create_murmur3_hashes<'a>( + arrays: &[ArrayRef], + hashes_buffer: &'a mut [u32], +) -> datafusion_common::Result<&'a mut [u32]> { + create_hashes_internal!( + arrays, + hashes_buffer, + spark_compatible_murmur3_hash, + create_hashes_dictionary + ); + Ok(hashes_buffer) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Float32Array, Float64Array}; + use std::sync::Arc; + + use crate::murmur3::create_murmur3_hashes; + use crate::test_hashes_with_nulls; + use datafusion::arrow::array::{ + ArrayRef, Int32Array, Int64Array, Int8Array, StringArray, + }; + + fn test_murmur3_hash< + I: Clone, + T: arrow_array::Array + From>> + 'static, + >( + values: Vec>, + expected: Vec, + ) { + test_hashes_with_nulls!(create_murmur3_hashes, T, values, expected, u32); + } + + #[test] + fn test_i8() { + test_murmur3_hash::( + vec![Some(1), Some(0), Some(-1), Some(i8::MAX), Some(i8::MIN)], + vec![0xdea578e3, 0x379fae8f, 0xa0590e3d, 0x43b4d8ed, 0x422a1365], + ); + } + + #[test] + fn test_i32() { + test_murmur3_hash::( + vec![Some(1), Some(0), Some(-1), Some(i32::MAX), Some(i32::MIN)], + vec![0xdea578e3, 0x379fae8f, 0xa0590e3d, 0x07fb67e7, 0x2b1f0fc6], + ); + } + + #[test] + fn test_i64() { + test_murmur3_hash::( + vec![Some(1), Some(0), Some(-1), Some(i64::MAX), Some(i64::MIN)], + vec![0x99f0149d, 0x9c67b85d, 0xc8008529, 0xa05b5d7b, 0xcd1e64fb], + ); + } + + #[test] + fn test_f32() { + test_murmur3_hash::( + vec![ + Some(1.0), + Some(0.0), + Some(-0.0), + Some(-1.0), + Some(99999999999.99999999999), + Some(-99999999999.99999999999), + ], + vec![ + 0xe434cc39, 0x379fae8f, 0x379fae8f, 0xdc0da8eb, 0xcbdc340f, 0xc0361c86, + ], + ); + } + + #[test] + fn test_f64() { + test_murmur3_hash::( + vec![ + Some(1.0), + Some(0.0), + Some(-0.0), + Some(-1.0), + Some(99999999999.99999999999), + Some(-99999999999.99999999999), + ], + vec![ + 0xe4876492, 0x9c67b85d, 0x9c67b85d, 0x13d81357, 0xb87e1595, 0xa0eef9f9, + ], + ); + } + + #[test] + fn test_str() { + let input = [ + "hello", "bar", "", "😁", "天地", "a", "ab", "abc", "abcd", "abcde", + ] + .iter() + .map(|s| Some(s.to_string())) + .collect::>>(); + let expected: Vec = vec![ + 3286402344, 2486176763, 142593372, 885025535, 2395000894, 1485273170, + 0xfa37157b, 1322437556, 0xe860e5cc, 814637928, + ]; + + test_murmur3_hash::(input.clone(), expected); + } +} diff --git a/datafusion/spark/src/hash_funcs/sha2.rs b/datafusion/spark/src/hash_funcs/sha2.rs new file mode 100644 index 000000000000..40d8def3a615 --- /dev/null +++ b/datafusion/spark/src/hash_funcs/sha2.rs @@ -0,0 +1,76 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
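+//! Spark-compatible `sha2` wrappers.
+//!
+//! A minimal calling sketch for the functions defined below (illustrative only;
+//! the input column is an assumption). Results are returned as hex-encoded
+//! strings rather than raw binary digests, matching Spark's `sha2` expression:
+//!
+//! ```ignore
+//! use std::sync::Arc;
+//! use arrow_array::StringArray;
+//! use datafusion_expr::ColumnarValue;
+//!
+//! let input = ColumnarValue::Array(Arc::new(StringArray::from(vec!["Spark"])));
+//! // Utf8 column containing the hex-encoded SHA-256 digest of each row.
+//! let digests = spark_sha256(&[input])?;
+//! ```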
+ +use crate::math_funcs::hex::hex_strings; +use arrow_array::{Array, StringArray}; +use datafusion::functions::crypto::{sha224, sha256, sha384, sha512}; +use datafusion_common::cast::as_binary_array; +use datafusion_common::{exec_err, DataFusionError, ScalarValue}; +use datafusion_expr::{ColumnarValue, ScalarUDF}; +use std::sync::Arc; + +/// `sha224` function that simulates Spark's `sha2` expression with bit width 224 +pub fn spark_sha224(args: &[ColumnarValue]) -> Result { + wrap_digest_result_as_hex_string(args, sha224()) +} + +/// `sha256` function that simulates Spark's `sha2` expression with bit width 0 or 256 +pub fn spark_sha256(args: &[ColumnarValue]) -> Result { + wrap_digest_result_as_hex_string(args, sha256()) +} + +/// `sha384` function that simulates Spark's `sha2` expression with bit width 384 +pub fn spark_sha384(args: &[ColumnarValue]) -> Result { + wrap_digest_result_as_hex_string(args, sha384()) +} + +/// `sha512` function that simulates Spark's `sha2` expression with bit width 512 +pub fn spark_sha512(args: &[ColumnarValue]) -> Result { + wrap_digest_result_as_hex_string(args, sha512()) +} + +// Spark requires hex string as the result of sha2 functions, we have to wrap the +// result of digest functions as hex string +fn wrap_digest_result_as_hex_string( + args: &[ColumnarValue], + digest: Arc, +) -> Result { + let row_count = match &args[0] { + ColumnarValue::Array(array) => array.len(), + ColumnarValue::Scalar(_) => 1, + }; + let value = digest.invoke_batch(args, row_count)?; + match value { + ColumnarValue::Array(array) => { + let binary_array = as_binary_array(&array)?; + let string_array: StringArray = binary_array + .iter() + .map(|opt| opt.map(hex_strings::<_>)) + .collect(); + Ok(ColumnarValue::Array(Arc::new(string_array))) + } + ColumnarValue::Scalar(ScalarValue::Binary(opt)) => Ok(ColumnarValue::Scalar( + ScalarValue::Utf8(opt.map(hex_strings::<_>)), + )), + _ => { + exec_err!( + "digest function should return binary value, but got: {:?}", + value.data_type() + ) + } + } +} diff --git a/datafusion/spark/src/hash_funcs/utils.rs b/datafusion/spark/src/hash_funcs/utils.rs new file mode 100644 index 000000000000..f1db9b4810e1 --- /dev/null +++ b/datafusion/spark/src/hash_funcs/utils.rs @@ -0,0 +1,428 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! This includes utilities for hashing and murmur3 hashing. + +#[macro_export] +macro_rules! 
hash_array { + ($array_type: ident, $column: ident, $hashes: ident, $hash_method: ident) => { + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + if array.null_count() == 0 { + for (i, hash) in $hashes.iter_mut().enumerate() { + *hash = $hash_method(&array.value(i), *hash); + } + } else { + for (i, hash) in $hashes.iter_mut().enumerate() { + if !array.is_null(i) { + *hash = $hash_method(&array.value(i), *hash); + } + } + } + }; +} + +#[macro_export] +macro_rules! hash_array_boolean { + ($array_type: ident, $column: ident, $hash_input_type: ident, $hashes: ident, $hash_method: ident) => { + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + if array.null_count() == 0 { + for (i, hash) in $hashes.iter_mut().enumerate() { + *hash = $hash_method( + $hash_input_type::from(array.value(i)).to_le_bytes(), + *hash, + ); + } + } else { + for (i, hash) in $hashes.iter_mut().enumerate() { + if !array.is_null(i) { + *hash = $hash_method( + $hash_input_type::from(array.value(i)).to_le_bytes(), + *hash, + ); + } + } + } + }; +} + +#[macro_export] +macro_rules! hash_array_primitive { + ($array_type: ident, $column: ident, $ty: ident, $hashes: ident, $hash_method: ident) => { + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + let values = array.values(); + + if array.null_count() == 0 { + for (hash, value) in $hashes.iter_mut().zip(values.iter()) { + *hash = $hash_method((*value as $ty).to_le_bytes(), *hash); + } + } else { + for (i, (hash, value)) in $hashes.iter_mut().zip(values.iter()).enumerate() { + if !array.is_null(i) { + *hash = $hash_method((*value as $ty).to_le_bytes(), *hash); + } + } + } + }; +} + +#[macro_export] +macro_rules! hash_array_primitive_float { + ($array_type: ident, $column: ident, $ty: ident, $ty2: ident, $hashes: ident, $hash_method: ident) => { + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + let values = array.values(); + + if array.null_count() == 0 { + for (hash, value) in $hashes.iter_mut().zip(values.iter()) { + // Spark uses 0 as hash for -0.0, see `Murmur3Hash` expression. + if *value == 0.0 && value.is_sign_negative() { + *hash = $hash_method((0 as $ty2).to_le_bytes(), *hash); + } else { + *hash = $hash_method((*value as $ty).to_le_bytes(), *hash); + } + } + } else { + for (i, (hash, value)) in $hashes.iter_mut().zip(values.iter()).enumerate() { + if !array.is_null(i) { + // Spark uses 0 as hash for -0.0, see `Murmur3Hash` expression. + if *value == 0.0 && value.is_sign_negative() { + *hash = $hash_method((0 as $ty2).to_le_bytes(), *hash); + } else { + *hash = $hash_method((*value as $ty).to_le_bytes(), *hash); + } + } + } + } + }; +} + +#[macro_export] +macro_rules! hash_array_small_decimal { + ($array_type:ident, $column: ident, $hashes: ident, $hash_method: ident) => { + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + + if array.null_count() == 0 { + for (i, hash) in $hashes.iter_mut().enumerate() { + *hash = $hash_method( + i64::try_from(array.value(i)).unwrap().to_le_bytes(), + *hash, + ); + } + } else { + for (i, hash) in $hashes.iter_mut().enumerate() { + if !array.is_null(i) { + *hash = $hash_method( + i64::try_from(array.value(i)).unwrap().to_le_bytes(), + *hash, + ); + } + } + } + }; +} + +#[macro_export] +macro_rules! 
hash_array_decimal { + ($array_type:ident, $column: ident, $hashes: ident, $hash_method: ident) => { + let array = $column.as_any().downcast_ref::<$array_type>().unwrap(); + + if array.null_count() == 0 { + for (i, hash) in $hashes.iter_mut().enumerate() { + *hash = $hash_method(array.value(i).to_le_bytes(), *hash); + } + } else { + for (i, hash) in $hashes.iter_mut().enumerate() { + if !array.is_null(i) { + *hash = $hash_method(array.value(i).to_le_bytes(), *hash); + } + } + } + }; +} + +/// Creates hash values for every row, based on the values in the +/// columns. +/// +/// The number of rows to hash is determined by `hashes_buffer.len()`. +/// `hashes_buffer` should be pre-sized appropriately +/// +/// `hash_method` is the hash function to use. +/// `create_dictionary_hash_method` is the function to create hashes for dictionary arrays input. +#[macro_export] +macro_rules! create_hashes_internal { + ($arrays: ident, $hashes_buffer: ident, $hash_method: ident, $create_dictionary_hash_method: ident) => { + use arrow::datatypes::{DataType, TimeUnit}; + use arrow_array::{types::*, *}; + + for (i, col) in $arrays.iter().enumerate() { + let first_col = i == 0; + match col.data_type() { + DataType::Boolean => { + $crate::hash_array_boolean!( + BooleanArray, + col, + i32, + $hashes_buffer, + $hash_method + ); + } + DataType::Int8 => { + $crate::hash_array_primitive!( + Int8Array, + col, + i32, + $hashes_buffer, + $hash_method + ); + } + DataType::Int16 => { + $crate::hash_array_primitive!( + Int16Array, + col, + i32, + $hashes_buffer, + $hash_method + ); + } + DataType::Int32 => { + $crate::hash_array_primitive!( + Int32Array, + col, + i32, + $hashes_buffer, + $hash_method + ); + } + DataType::Int64 => { + $crate::hash_array_primitive!( + Int64Array, + col, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Float32 => { + $crate::hash_array_primitive_float!( + Float32Array, + col, + f32, + i32, + $hashes_buffer, + $hash_method + ); + } + DataType::Float64 => { + $crate::hash_array_primitive_float!( + Float64Array, + col, + f64, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Timestamp(TimeUnit::Second, _) => { + $crate::hash_array_primitive!( + TimestampSecondArray, + col, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + $crate::hash_array_primitive!( + TimestampMillisecondArray, + col, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + $crate::hash_array_primitive!( + TimestampMicrosecondArray, + col, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + $crate::hash_array_primitive!( + TimestampNanosecondArray, + col, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Date32 => { + $crate::hash_array_primitive!( + Date32Array, + col, + i32, + $hashes_buffer, + $hash_method + ); + } + DataType::Date64 => { + $crate::hash_array_primitive!( + Date64Array, + col, + i64, + $hashes_buffer, + $hash_method + ); + } + DataType::Utf8 => { + $crate::hash_array!(StringArray, col, $hashes_buffer, $hash_method); + } + DataType::LargeUtf8 => { + $crate::hash_array!(LargeStringArray, col, $hashes_buffer, $hash_method); + } + DataType::Binary => { + $crate::hash_array!(BinaryArray, col, $hashes_buffer, $hash_method); + } + DataType::LargeBinary => { + $crate::hash_array!(LargeBinaryArray, col, $hashes_buffer, $hash_method); + } + DataType::FixedSizeBinary(_) => { + $crate::hash_array!(FixedSizeBinaryArray, col, 
$hashes_buffer, $hash_method); + } + // Apache Spark: if it's a small decimal, i.e. precision <= 18, turn it into long and hash it. + // Else, turn it into bytes and hash it. + DataType::Decimal128(precision, _) if *precision <= 18 => { + $crate::hash_array_small_decimal!(Decimal128Array, col, $hashes_buffer, $hash_method); + } + DataType::Decimal128(_, _) => { + $crate::hash_array_decimal!(Decimal128Array, col, $hashes_buffer, $hash_method); + } + DataType::Dictionary(index_type, _) => match **index_type { + DataType::Int8 => { + $create_dictionary_hash_method::(col, $hashes_buffer, first_col)?; + } + DataType::Int16 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::Int32 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::Int64 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::UInt8 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::UInt16 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::UInt32 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + DataType::UInt64 => { + $create_dictionary_hash_method::( + col, + $hashes_buffer, + first_col, + )?; + } + _ => { + return Err(DataFusionError::Internal(format!( + "Unsupported dictionary type in hasher hashing: {}", + col.data_type(), + ))) + } + }, + _ => { + // This is internal because we should have caught this before. + return Err(DataFusionError::Internal(format!( + "Unsupported data type in hasher: {}", + col.data_type() + ))); + } + } + } + }; +} + +pub(crate) mod test_utils { + + #[macro_export] + macro_rules! test_hashes_internal { + ($hash_method: ident, $input: expr, $initial_seeds: expr, $expected: expr) => { + let i = $input; + let mut hashes = $initial_seeds.clone(); + $hash_method(&[i], &mut hashes).unwrap(); + assert_eq!(hashes, $expected); + }; + } + + #[macro_export] + macro_rules! test_hashes_with_nulls { + ($method: ident, $t: ty, $values: ident, $expected: ident, $seed_type: ty) => { + // copied before inserting nulls + let mut input_with_nulls = $values.clone(); + let mut expected_with_nulls = $expected.clone(); + // test before inserting nulls + let len = $values.len(); + let initial_seeds = vec![42 as $seed_type; len]; + let i = Arc::new(<$t>::from($values)) as ArrayRef; + $crate::test_hashes_internal!($method, i, initial_seeds, $expected); + + // test with nulls + let median = len / 2; + input_with_nulls.insert(0, None); + input_with_nulls.insert(median, None); + expected_with_nulls.insert(0, 42 as $seed_type); + expected_with_nulls.insert(median, 42 as $seed_type); + let len_with_nulls = len + 2; + let initial_seeds_with_nulls = vec![42 as $seed_type; len_with_nulls]; + let nullable_input = Arc::new(<$t>::from(input_with_nulls)) as ArrayRef; + $crate::test_hashes_internal!( + $method, + nullable_input, + initial_seeds_with_nulls, + expected_with_nulls + ); + }; + } +} diff --git a/datafusion/spark/src/hash_funcs/xxhash64.rs b/datafusion/spark/src/hash_funcs/xxhash64.rs new file mode 100644 index 000000000000..cc33c535010a --- /dev/null +++ b/datafusion/spark/src/hash_funcs/xxhash64.rs @@ -0,0 +1,269 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::compute::take; +use twox_hash::XxHash64; + +use datafusion::{ + arrow::{ + array::*, + datatypes::{ArrowDictionaryKeyType, ArrowNativeType}, + }, + common::{internal_err, ScalarValue}, + error::{DataFusionError, Result}, +}; + +use crate::create_hashes_internal; +use arrow_array::{Array, ArrayRef, Int64Array}; +use datafusion_expr::ColumnarValue; +use std::sync::Arc; + +/// Spark compatible xxhash64 in vectorized execution fashion +pub fn spark_xxhash64(args: &[ColumnarValue]) -> Result { + let length = args.len(); + let seed = &args[length - 1]; + match seed { + ColumnarValue::Scalar(ScalarValue::Int64(Some(seed))) => { + // iterate over the arguments to find out the length of the array + let num_rows = args[0..args.len() - 1] + .iter() + .find_map(|arg| match arg { + ColumnarValue::Array(array) => Some(array.len()), + ColumnarValue::Scalar(_) => None, + }) + .unwrap_or(1); + let mut hashes: Vec = vec![0_u64; num_rows]; + hashes.fill(*seed as u64); + let arrays = args[0..args.len() - 1] + .iter() + .map(|arg| match arg { + ColumnarValue::Array(array) => Arc::clone(array), + ColumnarValue::Scalar(scalar) => { + scalar.clone().to_array_of_size(num_rows).unwrap() + } + }) + .collect::>(); + create_xxhash64_hashes(&arrays, &mut hashes)?; + if num_rows == 1 { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(Some( + hashes[0] as i64, + )))) + } else { + let hashes: Vec = hashes.into_iter().map(|x| x as i64).collect(); + Ok(ColumnarValue::Array(Arc::new(Int64Array::from(hashes)))) + } + } + _ => { + internal_err!( + "The seed of function xxhash64 must be an Int64 scalar value, but got: {:?}.", + seed + ) + } + } +} + +#[inline] +fn spark_compatible_xxhash64>(data: T, seed: u64) -> u64 { + XxHash64::oneshot(seed, data.as_ref()) +} + +// Hash the values in a dictionary array using xxhash64 +fn create_xxhash64_hashes_dictionary( + array: &ArrayRef, + hashes_buffer: &mut [u64], + first_col: bool, +) -> Result<()> { + let dict_array = array.as_any().downcast_ref::>().unwrap(); + if !first_col { + let unpacked = take(dict_array.values().as_ref(), dict_array.keys(), None)?; + create_xxhash64_hashes(&[unpacked], hashes_buffer)?; + } else { + // Hash each dictionary value once, and then use that computed + // hash for each key value to avoid a potentially expensive + // redundant hashing for large dictionary elements (e.g. 
strings) + let dict_values = Arc::clone(dict_array.values()); + // same initial seed as Spark + let mut dict_hashes = vec![42u64; dict_values.len()]; + create_xxhash64_hashes(&[dict_values], &mut dict_hashes)?; + + for (hash, key) in hashes_buffer.iter_mut().zip(dict_array.keys().iter()) { + if let Some(key) = key { + let idx = key.to_usize().ok_or_else(|| { + DataFusionError::Internal(format!( + "Can not convert key value {:?} to usize in dictionary of type {:?}", + key, + dict_array.data_type() + )) + })?; + *hash = dict_hashes[idx] + } // no update for Null, consistent with other hashes + } + } + Ok(()) +} + +/// Creates xxhash64 hash values for every row, based on the values in the +/// columns. +/// +/// The number of rows to hash is determined by `hashes_buffer.len()`. +/// `hashes_buffer` should be pre-sized appropriately +fn create_xxhash64_hashes<'a>( + arrays: &[ArrayRef], + hashes_buffer: &'a mut [u64], +) -> Result<&'a mut [u64]> { + create_hashes_internal!( + arrays, + hashes_buffer, + spark_compatible_xxhash64, + create_xxhash64_hashes_dictionary + ); + Ok(hashes_buffer) +} + +#[cfg(test)] +mod tests { + use arrow::array::{Float32Array, Float64Array}; + use std::sync::Arc; + + use super::create_xxhash64_hashes; + use crate::test_hashes_with_nulls; + use datafusion::arrow::array::{ + ArrayRef, Int32Array, Int64Array, Int8Array, StringArray, + }; + + fn test_xxhash64_hash< + I: Clone, + T: arrow_array::Array + From>> + 'static, + >( + values: Vec>, + expected: Vec, + ) { + test_hashes_with_nulls!(create_xxhash64_hashes, T, values, expected, u64); + } + + #[test] + fn test_i8() { + test_xxhash64_hash::( + vec![Some(1), Some(0), Some(-1), Some(i8::MAX), Some(i8::MIN)], + vec![ + 0xa309b38455455929, + 0x3229fbc4681e48f3, + 0x1bfdda8861c06e45, + 0x77cc15d9f9f2cdc2, + 0x39bc22b9e94d81d0, + ], + ); + } + + #[test] + fn test_i32() { + test_xxhash64_hash::( + vec![Some(1), Some(0), Some(-1), Some(i32::MAX), Some(i32::MIN)], + vec![ + 0xa309b38455455929, + 0x3229fbc4681e48f3, + 0x1bfdda8861c06e45, + 0x14f0ac009c21721c, + 0x1cc7cb8d034769cd, + ], + ); + } + + #[test] + fn test_i64() { + test_xxhash64_hash::( + vec![Some(1), Some(0), Some(-1), Some(i64::MAX), Some(i64::MIN)], + vec![ + 0x9ed50fd59358d232, + 0xb71b47ebda15746c, + 0x358ae035bfb46fd2, + 0xd2f1c616ae7eb306, + 0x88608019c494c1f4, + ], + ); + } + + #[test] + fn test_f32() { + test_xxhash64_hash::( + vec![ + Some(1.0), + Some(0.0), + Some(-0.0), + Some(-1.0), + Some(99999999999.99999999999), + Some(-99999999999.99999999999), + ], + vec![ + 0x9b92689757fcdbd, + 0x3229fbc4681e48f3, + 0x3229fbc4681e48f3, + 0xa2becc0e61bb3823, + 0x8f20ab82d4f3687f, + 0xdce4982d97f7ac4, + ], + ) + } + + #[test] + fn test_f64() { + test_xxhash64_hash::( + vec![ + Some(1.0), + Some(0.0), + Some(-0.0), + Some(-1.0), + Some(99999999999.99999999999), + Some(-99999999999.99999999999), + ], + vec![ + 0xe1fd6e07fee8ad53, + 0xb71b47ebda15746c, + 0xb71b47ebda15746c, + 0x8cdde022746f8f1f, + 0x793c5c88d313eac7, + 0xc5e60e7b75d9b232, + ], + ) + } + + #[test] + fn test_str() { + let input = [ + "hello", "bar", "", "😁", "天地", "a", "ab", "abc", "abcd", "abcde", + ] + .iter() + .map(|s| Some(s.to_string())) + .collect::>>(); + + test_xxhash64_hash::( + input, + vec![ + 0xc3629e6318d53932, + 0xe7097b6a54378d8a, + 0x98b1582b0977e704, + 0xa80d9d5a6a523bd5, + 0xfcba5f61ac666c61, + 0x88e4fe59adf7b0cc, + 0x259dd873209a3fe3, + 0x13c1d910702770e6, + 0xa17b5eb5dc364dff, + 0xf241303e4a90f299, + ], + ) + } +} diff --git a/datafusion/spark/src/json_funcs/mod.rs 
b/datafusion/spark/src/json_funcs/mod.rs new file mode 100644 index 000000000000..de3037590dba --- /dev/null +++ b/datafusion/spark/src/json_funcs/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod to_json; + +pub use to_json::ToJson; diff --git a/datafusion/spark/src/json_funcs/to_json.rs b/datafusion/spark/src/json_funcs/to_json.rs new file mode 100644 index 000000000000..3389ea3a0e24 --- /dev/null +++ b/datafusion/spark/src/json_funcs/to_json.rs @@ -0,0 +1,356 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
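+//! Spark-compatible `to_json` physical expression.
+//!
+//! A minimal sketch of the output shape, based on the unit tests further below
+//! (the field values are illustrative): struct fields are rendered as `"name":value`
+//! pairs, string values are quoted and escaped, and null fields are omitted.
+//!
+//! ```ignore
+//! // to_json(struct(a: true, b: null, c: "foo"), timezone = "UTC")
+//! //   => {"a":true,"c":"foo"}
+//! ```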
+ +// TODO upstream this to DataFusion as long as we have a way to specify all +// of the Spark-specific compatibility features that we need (including +// being able to specify Spark-compatible cast from all types to string) + +use crate::SparkCastOptions; +use crate::{spark_cast, EvalMode}; +use arrow_array::builder::StringBuilder; +use arrow_array::{Array, ArrayRef, RecordBatch, StringArray, StructArray}; +use arrow_schema::{DataType, Schema}; +use datafusion_common::Result; +use datafusion_expr::ColumnarValue; +use datafusion_physical_expr::PhysicalExpr; +use std::any::Any; +use std::fmt::{Debug, Display, Formatter}; +use std::hash::Hash; +use std::sync::Arc; + +/// to_json function +#[derive(Debug, Eq)] +pub struct ToJson { + /// The input to convert to JSON + expr: Arc, + /// Timezone to use when converting timestamps to JSON + timezone: String, +} + +impl Hash for ToJson { + fn hash(&self, state: &mut H) { + self.expr.hash(state); + self.timezone.hash(state); + } +} +impl PartialEq for ToJson { + fn eq(&self, other: &Self) -> bool { + self.expr.eq(&other.expr) && self.timezone.eq(&other.timezone) + } +} + +impl ToJson { + pub fn new(expr: Arc, timezone: &str) -> Self { + Self { + expr, + timezone: timezone.to_owned(), + } + } +} + +impl Display for ToJson { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "to_json({}, timezone={})", self.expr, self.timezone) + } +} + +impl PartialEq for ToJson { + fn eq(&self, other: &dyn Any) -> bool { + if let Some(other) = other.downcast_ref::() { + self.expr.eq(&other.expr) && self.timezone.eq(&other.timezone) + } else { + false + } + } +} + +impl PhysicalExpr for ToJson { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _: &Schema) -> Result { + Ok(DataType::Utf8) + } + + fn nullable(&self, input_schema: &Schema) -> Result { + self.expr.nullable(input_schema) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + let input = self.expr.evaluate(batch)?.into_array(batch.num_rows())?; + Ok(ColumnarValue::Array(array_to_json_string( + &input, + &self.timezone, + )?)) + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.expr] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + assert!(children.len() == 1); + Ok(Arc::new(Self::new( + Arc::clone(&children[0]), + &self.timezone, + ))) + } +} + +/// Convert an array into a JSON value string representation +fn array_to_json_string(arr: &Arc, timezone: &str) -> Result { + if let Some(struct_array) = arr.as_any().downcast_ref::() { + struct_to_json(struct_array, timezone) + } else { + spark_cast( + ColumnarValue::Array(Arc::clone(arr)), + &DataType::Utf8, + &SparkCastOptions::new(EvalMode::Legacy, timezone, false), + )? 
+ .into_array(arr.len()) + } +} + +fn escape_string(input: &str) -> String { + let mut escaped_string = String::with_capacity(input.len()); + let mut is_escaped = false; + for c in input.chars() { + match c { + '\"' | '\\' if !is_escaped => { + escaped_string.push('\\'); + escaped_string.push(c); + is_escaped = false; + } + '\t' => { + escaped_string.push('\\'); + escaped_string.push('t'); + is_escaped = false; + } + '\r' => { + escaped_string.push('\\'); + escaped_string.push('r'); + is_escaped = false; + } + '\n' => { + escaped_string.push('\\'); + escaped_string.push('n'); + is_escaped = false; + } + '\x0C' => { + escaped_string.push('\\'); + escaped_string.push('f'); + is_escaped = false; + } + '\x08' => { + escaped_string.push('\\'); + escaped_string.push('b'); + is_escaped = false; + } + '\\' => { + escaped_string.push('\\'); + is_escaped = true; + } + _ => { + escaped_string.push(c); + is_escaped = false; + } + } + } + escaped_string +} + +fn struct_to_json(array: &StructArray, timezone: &str) -> Result { + // get field names and escape any quotes + let field_names: Vec = array + .fields() + .iter() + .map(|f| escape_string(f.name().as_str())) + .collect(); + // determine which fields need to have their values quoted + let is_string: Vec = array + .fields() + .iter() + .map(|f| match f.data_type() { + DataType::Utf8 | DataType::LargeUtf8 => true, + DataType::Dictionary(_, dt) => { + matches!(dt.as_ref(), DataType::Utf8 | DataType::LargeUtf8) + } + _ => false, + }) + .collect(); + // create JSON string representation of each column + let string_arrays: Vec = array + .columns() + .iter() + .map(|arr| array_to_json_string(arr, timezone)) + .collect::>>()?; + let string_arrays: Vec<&StringArray> = string_arrays + .iter() + .map(|arr| { + arr.as_any() + .downcast_ref::() + .expect("string array") + }) + .collect(); + // build the JSON string containing entries in the format `"field_name":field_value` + let mut builder = StringBuilder::with_capacity(array.len(), array.len() * 16); + let mut json = String::with_capacity(array.len() * 16); + for row_index in 0..array.len() { + if array.is_null(row_index) { + builder.append_null(); + } else { + json.clear(); + let mut any_fields_written = false; + json.push('{'); + for col_index in 0..string_arrays.len() { + if !string_arrays[col_index].is_null(row_index) { + if any_fields_written { + json.push(','); + } + // quoted field name + json.push('"'); + json.push_str(&field_names[col_index]); + json.push_str("\":"); + // value + let string_value = string_arrays[col_index].value(row_index); + if is_string[col_index] { + json.push('"'); + json.push_str(&escape_string(string_value)); + json.push('"'); + } else { + json.push_str(string_value); + } + any_fields_written = true; + } + } + json.push('}'); + builder.append_value(&json); + } + } + Ok(Arc::new(builder.finish())) +} + +#[cfg(test)] +mod test { + use crate::json_funcs::to_json::struct_to_json; + use arrow_array::types::Int32Type; + use arrow_array::{Array, PrimitiveArray, StringArray}; + use arrow_array::{ArrayRef, BooleanArray, Int32Array, StructArray}; + use arrow_schema::{DataType, Field}; + use datafusion_common::Result; + use std::sync::Arc; + + #[test] + fn test_primitives() -> Result<()> { + let bools: ArrayRef = create_bools(); + let ints: ArrayRef = create_ints(); + let strings: ArrayRef = create_strings(); + let struct_array = StructArray::from(vec![ + (Arc::new(Field::new("a", DataType::Boolean, true)), bools), + (Arc::new(Field::new("b", DataType::Int32, true)), ints), + 
(Arc::new(Field::new("c", DataType::Utf8, true)), strings), + ]); + let json = struct_to_json(&struct_array, "UTC")?; + let json = json + .as_any() + .downcast_ref::() + .expect("string array"); + assert_eq!(4, json.len()); + assert_eq!(r#"{"b":123}"#, json.value(0)); + assert_eq!(r#"{"a":true,"c":"foo"}"#, json.value(1)); + assert_eq!(r#"{"a":false,"b":2147483647,"c":"bar"}"#, json.value(2)); + assert_eq!(r#"{"a":false,"b":-2147483648,"c":""}"#, json.value(3)); + Ok(()) + } + + #[test] + fn test_nested_struct() -> Result<()> { + let bools: ArrayRef = create_bools(); + let ints: ArrayRef = create_ints(); + + // create first struct array + let struct_fields = vec![ + Arc::new(Field::new("a", DataType::Boolean, true)), + Arc::new(Field::new("b", DataType::Int32, true)), + ]; + let struct_values = vec![bools, ints]; + let struct_array = StructArray::from( + struct_fields + .clone() + .into_iter() + .zip(struct_values) + .collect::>(), + ); + + // create second struct array containing the first struct array + let struct_fields2 = vec![Arc::new(Field::new( + "a", + DataType::Struct(struct_fields.into()), + true, + ))]; + let struct_values2: Vec = vec![Arc::new(struct_array.clone())]; + let struct_array2 = StructArray::from( + struct_fields2 + .into_iter() + .zip(struct_values2) + .collect::>(), + ); + + let json = struct_to_json(&struct_array2, "UTC")?; + let json = json + .as_any() + .downcast_ref::() + .expect("string array"); + assert_eq!(4, json.len()); + assert_eq!(r#"{"a":{"b":123}}"#, json.value(0)); + assert_eq!(r#"{"a":{"a":true}}"#, json.value(1)); + assert_eq!(r#"{"a":{"a":false,"b":2147483647}}"#, json.value(2)); + assert_eq!(r#"{"a":{"a":false,"b":-2147483648}}"#, json.value(3)); + Ok(()) + } + + fn create_ints() -> Arc> { + Arc::new(Int32Array::from(vec![ + Some(123), + None, + Some(i32::MAX), + Some(i32::MIN), + ])) + } + + fn create_bools() -> Arc { + Arc::new(BooleanArray::from(vec![ + None, + Some(true), + Some(false), + Some(false), + ])) + } + + fn create_strings() -> Arc { + Arc::new(StringArray::from(vec![ + None, + Some("foo"), + Some("bar"), + Some(""), + ])) + } +} diff --git a/datafusion/spark/src/kernels/mod.rs b/datafusion/spark/src/kernels/mod.rs new file mode 100644 index 000000000000..3669ff13ad0e --- /dev/null +++ b/datafusion/spark/src/kernels/mod.rs @@ -0,0 +1,21 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Kernels + +pub mod strings; +pub(crate) mod temporal; diff --git a/datafusion/spark/src/kernels/strings.rs b/datafusion/spark/src/kernels/strings.rs new file mode 100644 index 000000000000..8b0b3dea32f9 --- /dev/null +++ b/datafusion/spark/src/kernels/strings.rs @@ -0,0 +1,132 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! String kernels + +use std::sync::Arc; + +use arrow::{ + array::*, + buffer::MutableBuffer, + compute::kernels::substring::{substring as arrow_substring, substring_by_char}, + datatypes::{DataType, Int32Type}, +}; +use datafusion_common::DataFusionError; + +/// Returns an ArrayRef with a string consisting of `length` spaces. +/// +/// # Preconditions +/// +/// - elements in `length` must not be negative +pub fn string_space(length: &dyn Array) -> Result { + match length.data_type() { + DataType::Int32 => { + let array = length.as_any().downcast_ref::().unwrap(); + Ok(generic_string_space::(array)) + } + DataType::Dictionary(_, _) => { + let dict = as_dictionary_array::(length); + let values = string_space(dict.values())?; + let result = DictionaryArray::try_new(dict.keys().clone(), values)?; + Ok(Arc::new(result)) + } + dt => panic!( + "Unsupported input type for function 'string_space': {:?}", + dt + ), + } +} + +pub fn substring( + array: &dyn Array, + start: i64, + length: u64, +) -> Result { + match array.data_type() { + DataType::LargeUtf8 => substring_by_char( + array + .as_any() + .downcast_ref::() + .expect("A large string is expected"), + start, + Some(length), + ) + .map_err(|e| e.into()) + .map(|t| make_array(t.into_data())), + DataType::Utf8 => substring_by_char( + array + .as_any() + .downcast_ref::() + .expect("A string is expected"), + start, + Some(length), + ) + .map_err(|e| e.into()) + .map(|t| make_array(t.into_data())), + DataType::Binary | DataType::LargeBinary => { + arrow_substring(array, start, Some(length)).map_err(|e| e.into()) + } + DataType::Dictionary(_, _) => { + let dict = as_dictionary_array::(array); + let values = substring(dict.values(), start, length)?; + let result = DictionaryArray::try_new(dict.keys().clone(), values)?; + Ok(Arc::new(result)) + } + dt => panic!("Unsupported input type for function 'substring': {:?}", dt), + } +} + +fn generic_string_space(length: &Int32Array) -> ArrayRef { + let array_len = length.len(); + let mut offsets = + MutableBuffer::new((array_len + 1) * std::mem::size_of::()); + let mut length_so_far = OffsetSize::zero(); + + // compute null bitmap (copy) + let null_bit_buffer = length.to_data().nulls().map(|b| b.buffer().clone()); + + // Gets slice of length array to access it directly for performance. 
+ let length_data = length.to_data(); + let lengths = length_data.buffers()[0].typed_data::(); + let total = lengths.iter().map(|l| *l as usize).sum::(); + let mut values = MutableBuffer::new(total); + + offsets.push(length_so_far); + + let blank = " ".as_bytes()[0]; + values.resize(total, blank); + + (0..array_len).for_each(|i| { + let current_len = lengths[i] as usize; + + length_so_far += OffsetSize::from_usize(current_len).unwrap(); + offsets.push(length_so_far); + }); + + let data = unsafe { + ArrayData::new_unchecked( + GenericStringArray::::DATA_TYPE, + array_len, + None, + null_bit_buffer, + 0, + vec![offsets.into(), values.into()], + vec![], + ) + }; + make_array(data) +} diff --git a/datafusion/spark/src/kernels/temporal.rs b/datafusion/spark/src/kernels/temporal.rs new file mode 100644 index 000000000000..9cfa2d04b628 --- /dev/null +++ b/datafusion/spark/src/kernels/temporal.rs @@ -0,0 +1,1177 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! temporal kernels + +use chrono::{DateTime, Datelike, Duration, NaiveDateTime, Timelike, Utc}; + +use std::sync::Arc; + +use arrow::{array::*, datatypes::DataType}; +use arrow_array::{ + downcast_dictionary_array, downcast_temporal_array, + temporal_conversions::*, + timezone::Tz, + types::{ + ArrowDictionaryKeyType, ArrowTemporalType, Date32Type, TimestampMicrosecondType, + }, + ArrowNumericType, +}; + +use arrow_schema::TimeUnit; + +use crate::SparkError; + +// Copied from arrow_arith/temporal.rs +macro_rules! 
return_compute_error_with { + ($msg:expr, $param:expr) => { + return { Err(SparkError::Internal(format!("{}: {:?}", $msg, $param))) } + }; +} + +// The number of days between the beginning of the proleptic gregorian calendar (0001-01-01) +// and the beginning of the Unix Epoch (1970-01-01) +const DAYS_TO_UNIX_EPOCH: i32 = 719_163; + +// Copied from arrow_arith/temporal.rs with modification to the output datatype +// Transforms a array of NaiveDate to an array of Date32 after applying an operation +fn as_datetime_with_op, T: ArrowTemporalType, F>( + iter: ArrayIter, + mut builder: PrimitiveBuilder, + op: F, +) -> Date32Array +where + F: Fn(NaiveDateTime) -> i32, + i64: From, +{ + iter.into_iter().for_each(|value| { + if let Some(value) = value { + match as_datetime::(i64::from(value)) { + Some(dt) => builder.append_value(op(dt)), + None => builder.append_null(), + } + } else { + builder.append_null(); + } + }); + + builder.finish() +} + +#[inline] +fn as_datetime_with_op_single( + value: Option, + builder: &mut PrimitiveBuilder, + op: F, +) where + F: Fn(NaiveDateTime) -> i32, +{ + if let Some(value) = value { + match as_datetime::(i64::from(value)) { + Some(dt) => builder.append_value(op(dt)), + None => builder.append_null(), + } + } else { + builder.append_null(); + } +} + +// Based on arrow_arith/temporal.rs:extract_component_from_datetime_array +// Transforms an array of DateTime to an arrayOf TimeStampMicrosecond after applying an +// operation +fn as_timestamp_tz_with_op, T: ArrowTemporalType, F>( + iter: ArrayIter, + mut builder: PrimitiveBuilder, + tz: &str, + op: F, +) -> Result +where + F: Fn(DateTime) -> i64, + i64: From, +{ + let tz: Tz = tz.parse()?; + for value in iter { + match value { + Some(value) => match as_datetime_with_timezone::(value.into(), tz) { + Some(time) => builder.append_value(op(time)), + _ => { + return Err(SparkError::Internal( + "Unable to read value as datetime".to_string(), + )); + } + }, + None => builder.append_null(), + } + } + Ok(builder.finish()) +} + +fn as_timestamp_tz_with_op_single( + value: Option, + builder: &mut PrimitiveBuilder, + tz: &Tz, + op: F, +) -> Result<(), SparkError> +where + F: Fn(DateTime) -> i64, + i64: From, +{ + match value { + Some(value) => match as_datetime_with_timezone::(value.into(), *tz) { + Some(time) => builder.append_value(op(time)), + _ => { + return Err(SparkError::Internal( + "Unable to read value as datetime".to_string(), + )); + } + }, + None => builder.append_null(), + } + Ok(()) +} + +#[inline] +fn as_days_from_unix_epoch(dt: Option) -> i32 { + dt.unwrap().num_days_from_ce() - DAYS_TO_UNIX_EPOCH +} + +// Apply the Tz to the Naive Date Time,,convert to UTC, and return as microseconds in Unix epoch +#[inline] +fn as_micros_from_unix_epoch_utc(dt: Option>) -> i64 { + dt.unwrap().with_timezone(&Utc).timestamp_micros() +} + +#[inline] +fn trunc_date_to_year(dt: T) -> Option { + Some(dt) + .and_then(|d| d.with_nanosecond(0)) + .and_then(|d| d.with_second(0)) + .and_then(|d| d.with_minute(0)) + .and_then(|d| d.with_hour(0)) + .and_then(|d| d.with_day0(0)) + .and_then(|d| d.with_month0(0)) +} + +/// returns the month of the beginning of the quarter +#[inline] +fn quarter_month(dt: &T) -> u32 { + 1 + 3 * ((dt.month() - 1) / 3) +} + +#[inline] +fn trunc_date_to_quarter(dt: T) -> Option { + Some(dt) + .and_then(|d| d.with_nanosecond(0)) + .and_then(|d| d.with_second(0)) + .and_then(|d| d.with_minute(0)) + .and_then(|d| d.with_hour(0)) + .and_then(|d| d.with_day0(0)) + .and_then(|d| d.with_month(quarter_month(&d))) +} + 
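+// Worked example for the quarter truncation above (the date is illustrative):
+// for 2024-05-20 12:34:56, `quarter_month` maps month 5 to 1 + 3 * ((5 - 1) / 3) = 4,
+// so the value is truncated to 2024-04-01 00:00:00.
+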
+#[inline] +fn trunc_date_to_month(dt: T) -> Option { + Some(dt) + .and_then(|d| d.with_nanosecond(0)) + .and_then(|d| d.with_second(0)) + .and_then(|d| d.with_minute(0)) + .and_then(|d| d.with_hour(0)) + .and_then(|d| d.with_day0(0)) +} + +#[inline] +fn trunc_date_to_week(dt: T) -> Option +where + T: Datelike + Timelike + std::ops::Sub + Copy, +{ + Some(dt) + .map(|d| d - Duration::try_seconds(60 * 60 * 24 * d.weekday() as i64).unwrap()) + .and_then(|d| d.with_nanosecond(0)) + .and_then(|d| d.with_second(0)) + .and_then(|d| d.with_minute(0)) + .and_then(|d| d.with_hour(0)) +} + +#[inline] +fn trunc_date_to_day(dt: T) -> Option { + Some(dt) + .and_then(|d| d.with_nanosecond(0)) + .and_then(|d| d.with_second(0)) + .and_then(|d| d.with_minute(0)) + .and_then(|d| d.with_hour(0)) +} + +#[inline] +fn trunc_date_to_hour(dt: T) -> Option { + Some(dt) + .and_then(|d| d.with_nanosecond(0)) + .and_then(|d| d.with_second(0)) + .and_then(|d| d.with_minute(0)) +} + +#[inline] +fn trunc_date_to_minute(dt: T) -> Option { + Some(dt) + .and_then(|d| d.with_nanosecond(0)) + .and_then(|d| d.with_second(0)) +} + +#[inline] +fn trunc_date_to_second(dt: T) -> Option { + Some(dt).and_then(|d| d.with_nanosecond(0)) +} + +#[inline] +fn trunc_date_to_ms(dt: T) -> Option { + Some(dt).and_then(|d| d.with_nanosecond(1_000_000 * (d.nanosecond() / 1_000_000))) +} + +#[inline] +fn trunc_date_to_microsec(dt: T) -> Option { + Some(dt).and_then(|d| d.with_nanosecond(1_000 * (d.nanosecond() / 1_000))) +} + +/// +/// Implements the spark [TRUNC](https://spark.apache.org/docs/latest/api/sql/index.html#trunc) +/// function where the specified format is a scalar value +/// +/// array is an array of Date32 values. The array may be a dictionary array. +/// +/// format is a scalar string specifying the format to apply to the timestamp value. 
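+///
+/// A minimal calling sketch (the dates shown are illustrative):
+///
+/// ```ignore
+/// // Day 19523 is 2023-06-15; truncating to "YEAR" yields day 19358 (2023-01-01).
+/// let dates = Date32Array::from(vec![Some(19_523)]);
+/// let truncated = date_trunc_dyn(&dates, "YEAR".to_string())?;
+/// ```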
+pub(crate) fn date_trunc_dyn( + array: &dyn Array, + format: String, +) -> Result { + match array.data_type().clone() { + DataType::Dictionary(_, _) => { + downcast_dictionary_array!( + array => { + let truncated_values = date_trunc_dyn(array.values(), format)?; + Ok(Arc::new(array.with_values(truncated_values))) + } + dt => return_compute_error_with!("date_trunc does not support", dt), + ) + } + _ => { + downcast_temporal_array!( + array => { + date_trunc(array, format) + .map(|a| Arc::new(a) as ArrayRef) + } + dt => return_compute_error_with!("date_trunc does not support", dt), + ) + } + } +} + +pub(crate) fn date_trunc( + array: &PrimitiveArray, + format: String, +) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + let builder = Date32Builder::with_capacity(array.len()); + let iter = ArrayIter::new(array); + match array.data_type() { + DataType::Date32 => match format.to_uppercase().as_str() { + "YEAR" | "YYYY" | "YY" => Ok( + as_datetime_with_op::<&PrimitiveArray, T, _>(iter, builder, |dt| { + as_days_from_unix_epoch(trunc_date_to_year(dt)) + }), + ), + "QUARTER" => Ok(as_datetime_with_op::<&PrimitiveArray, T, _>( + iter, + builder, + |dt| as_days_from_unix_epoch(trunc_date_to_quarter(dt)), + )), + "MONTH" | "MON" | "MM" => Ok( + as_datetime_with_op::<&PrimitiveArray, T, _>(iter, builder, |dt| { + as_days_from_unix_epoch(trunc_date_to_month(dt)) + }), + ), + "WEEK" => Ok(as_datetime_with_op::<&PrimitiveArray, T, _>( + iter, + builder, + |dt| as_days_from_unix_epoch(trunc_date_to_week(dt)), + )), + _ => Err(SparkError::Internal(format!( + "Unsupported format: {:?} for function 'date_trunc'", + format + ))), + }, + dt => return_compute_error_with!( + "Unsupported input type '{:?}' for function 'date_trunc'", + dt + ), + } +} + +/// +/// Implements the spark [TRUNC](https://spark.apache.org/docs/latest/api/sql/index.html#trunc) +/// function where the specified format may be an array +/// +/// array is an array of Date32 values. The array may be a dictionary array. +/// +/// format is an array of strings specifying the format to apply to the corresponding date value. +/// The array may be a dictionary array. 
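+///
+/// A minimal calling sketch (the dates and formats shown are illustrative):
+///
+/// ```ignore
+/// let dates = Date32Array::from(vec![Some(19_523), Some(19_523)]);
+/// let formats = StringArray::from(vec!["YEAR", "MONTH"]);
+/// // Row 0 truncates 2023-06-15 to 2023-01-01; row 1 truncates it to 2023-06-01.
+/// let truncated = date_trunc_array_fmt_dyn(&dates, &formats)?;
+/// ```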
+pub(crate) fn date_trunc_array_fmt_dyn( + array: &dyn Array, + formats: &dyn Array, +) -> Result { + match (array.data_type().clone(), formats.data_type().clone()) { + (DataType::Dictionary(_, v), DataType::Dictionary(_, f)) => { + if !matches!(*v, DataType::Date32) { + return_compute_error_with!("date_trunc does not support", v) + } + if !matches!(*f, DataType::Utf8) { + return_compute_error_with!("date_trunc does not support format type ", f) + } + downcast_dictionary_array!( + formats => { + downcast_dictionary_array!( + array => { + date_trunc_array_fmt_dict_dict( + &array.downcast_dict::().unwrap(), + &formats.downcast_dict::().unwrap()) + .map(|a| Arc::new(a) as ArrayRef) + } + dt => return_compute_error_with!("date_trunc does not support", dt) + ) + } + fmt => return_compute_error_with!("date_trunc does not support format type", fmt), + ) + } + (DataType::Dictionary(_, v), DataType::Utf8) => { + if !matches!(*v, DataType::Date32) { + return_compute_error_with!("date_trunc does not support", v) + } + downcast_dictionary_array!( + array => { + date_trunc_array_fmt_dict_plain( + &array.downcast_dict::().unwrap(), + formats.as_any().downcast_ref::() + .expect("Unexpected value type in formats")) + .map(|a| Arc::new(a) as ArrayRef) + } + dt => return_compute_error_with!("date_trunc does not support", dt), + ) + } + (DataType::Date32, DataType::Dictionary(_, f)) => { + if !matches!(*f, DataType::Utf8) { + return_compute_error_with!("date_trunc does not support format type ", f) + } + downcast_dictionary_array!( + formats => { + downcast_temporal_array!(array => { + date_trunc_array_fmt_plain_dict( + array.as_any().downcast_ref::() + .expect("Unexpected error in casting date array"), + &formats.downcast_dict::().unwrap()) + .map(|a| Arc::new(a) as ArrayRef) + } + dt => return_compute_error_with!("date_trunc does not support", dt), + ) + } + fmt => return_compute_error_with!("date_trunc does not support format type", fmt), + ) + } + (DataType::Date32, DataType::Utf8) => date_trunc_array_fmt_plain_plain( + array + .as_any() + .downcast_ref::() + .expect("Unexpected error in casting date array"), + formats + .as_any() + .downcast_ref::() + .expect("Unexpected value type in formats"), + ) + .map(|a| Arc::new(a) as ArrayRef), + (dt, fmt) => Err(SparkError::Internal(format!( + "Unsupported datatype: {:}, format: {:?} for function 'date_trunc'", + dt, fmt + ))), + } +} + +macro_rules! date_trunc_array_fmt_helper { + ($array: ident, $formats: ident, $datatype: ident) => {{ + let mut builder = Date32Builder::with_capacity($array.len()); + let iter = $array.into_iter(); + match $datatype { + DataType::Date32 => { + for (index, val) in iter.enumerate() { + let op_result = match $formats.value(index).to_uppercase().as_str() { + "YEAR" | "YYYY" | "YY" => { + Ok(as_datetime_with_op_single(val, &mut builder, |dt| { + as_days_from_unix_epoch(trunc_date_to_year(dt)) + })) + } + "QUARTER" => { + Ok(as_datetime_with_op_single(val, &mut builder, |dt| { + as_days_from_unix_epoch(trunc_date_to_quarter(dt)) + })) + } + "MONTH" | "MON" | "MM" => { + Ok(as_datetime_with_op_single(val, &mut builder, |dt| { + as_days_from_unix_epoch(trunc_date_to_month(dt)) + })) + } + "WEEK" => { + Ok(as_datetime_with_op_single(val, &mut builder, |dt| { + as_days_from_unix_epoch(trunc_date_to_week(dt)) + })) + } + _ => Err(SparkError::Internal(format!( + "Unsupported format: {:?} for function 'date_trunc'", + $formats.value(index) + ))), + }; + op_result? 
+ } + Ok(builder.finish()) + } + dt => return_compute_error_with!( + "Unsupported input type '{:?}' for function 'date_trunc'", + dt + ), + } + }}; +} + +fn date_trunc_array_fmt_plain_plain( + array: &Date32Array, + formats: &StringArray, +) -> Result +where +{ + let data_type = array.data_type(); + date_trunc_array_fmt_helper!(array, formats, data_type) +} + +fn date_trunc_array_fmt_plain_dict( + array: &Date32Array, + formats: &TypedDictionaryArray, +) -> Result +where + K: ArrowDictionaryKeyType, +{ + let data_type = array.data_type(); + date_trunc_array_fmt_helper!(array, formats, data_type) +} + +fn date_trunc_array_fmt_dict_plain( + array: &TypedDictionaryArray, + formats: &StringArray, +) -> Result +where + K: ArrowDictionaryKeyType, +{ + let data_type = array.values().data_type(); + date_trunc_array_fmt_helper!(array, formats, data_type) +} + +fn date_trunc_array_fmt_dict_dict( + array: &TypedDictionaryArray, + formats: &TypedDictionaryArray, +) -> Result +where + K: ArrowDictionaryKeyType, + F: ArrowDictionaryKeyType, +{ + let data_type = array.values().data_type(); + date_trunc_array_fmt_helper!(array, formats, data_type) +} + +/// +/// Implements the spark [DATE_TRUNC](https://spark.apache.org/docs/latest/api/sql/index.html#date_trunc) +/// function where the specified format is a scalar value +/// +/// array is an array of Timestamp(Microsecond) values. Timestamp values must have a valid +/// timezone or no timezone. The array may be a dictionary array. +/// +/// format is a scalar string specifying the format to apply to the timestamp value. +pub(crate) fn timestamp_trunc_dyn( + array: &dyn Array, + format: String, +) -> Result { + match array.data_type().clone() { + DataType::Dictionary(_, _) => { + downcast_dictionary_array!( + array => { + let truncated_values = timestamp_trunc_dyn(array.values(), format)?; + Ok(Arc::new(array.with_values(truncated_values))) + } + dt => return_compute_error_with!("timestamp_trunc does not support", dt), + ) + } + _ => { + downcast_temporal_array!( + array => { + timestamp_trunc(array, format) + .map(|a| Arc::new(a) as ArrayRef) + } + dt => return_compute_error_with!("timestamp_trunc does not support", dt), + ) + } + } +} + +pub(crate) fn timestamp_trunc( + array: &PrimitiveArray, + format: String, +) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + let builder = TimestampMicrosecondBuilder::with_capacity(array.len()); + let iter = ArrayIter::new(array); + match array.data_type() { + DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => { + match format.to_uppercase().as_str() { + "YEAR" | "YYYY" | "YY" => as_timestamp_tz_with_op::< + &PrimitiveArray, + T, + _, + >(iter, builder, tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_year(dt)) + }), + "QUARTER" => as_timestamp_tz_with_op::<&PrimitiveArray, T, _>( + iter, + builder, + tz, + |dt| as_micros_from_unix_epoch_utc(trunc_date_to_quarter(dt)), + ), + "MONTH" | "MON" | "MM" => as_timestamp_tz_with_op::< + &PrimitiveArray, + T, + _, + >(iter, builder, tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_month(dt)) + }), + "WEEK" => as_timestamp_tz_with_op::<&PrimitiveArray, T, _>( + iter, + builder, + tz, + |dt| as_micros_from_unix_epoch_utc(trunc_date_to_week(dt)), + ), + "DAY" | "DD" => as_timestamp_tz_with_op::<&PrimitiveArray, T, _>( + iter, + builder, + tz, + |dt| as_micros_from_unix_epoch_utc(trunc_date_to_day(dt)), + ), + "HOUR" => as_timestamp_tz_with_op::<&PrimitiveArray, T, _>( + iter, + builder, + tz, + |dt| 
as_micros_from_unix_epoch_utc(trunc_date_to_hour(dt)), + ), + "MINUTE" => as_timestamp_tz_with_op::<&PrimitiveArray, T, _>( + iter, + builder, + tz, + |dt| as_micros_from_unix_epoch_utc(trunc_date_to_minute(dt)), + ), + "SECOND" => as_timestamp_tz_with_op::<&PrimitiveArray, T, _>( + iter, + builder, + tz, + |dt| as_micros_from_unix_epoch_utc(trunc_date_to_second(dt)), + ), + "MILLISECOND" => as_timestamp_tz_with_op::<&PrimitiveArray, T, _>( + iter, + builder, + tz, + |dt| as_micros_from_unix_epoch_utc(trunc_date_to_ms(dt)), + ), + "MICROSECOND" => as_timestamp_tz_with_op::<&PrimitiveArray, T, _>( + iter, + builder, + tz, + |dt| as_micros_from_unix_epoch_utc(trunc_date_to_microsec(dt)), + ), + _ => Err(SparkError::Internal(format!( + "Unsupported format: {:?} for function 'timestamp_trunc'", + format + ))), + } + } + dt => return_compute_error_with!( + "Unsupported input type '{:?}' for function 'timestamp_trunc'", + dt + ), + } +} + +/// +/// Implements the spark [DATE_TRUNC](https://spark.apache.org/docs/latest/api/sql/index.html#date_trunc) +/// function where the specified format may be an array +/// +/// array is an array of Timestamp(Microsecond) values. Timestamp values must have a valid +/// timezone or no timezone. The array may be a dictionary array. +/// +/// format is an array of strings specifying the format to apply to the corresponding timestamp +/// value. The array may be a dictionary array. +pub(crate) fn timestamp_trunc_array_fmt_dyn( + array: &dyn Array, + formats: &dyn Array, +) -> Result { + match (array.data_type().clone(), formats.data_type().clone()) { + (DataType::Dictionary(_, _), DataType::Dictionary(_, _)) => { + downcast_dictionary_array!( + formats => { + downcast_dictionary_array!( + array => { + timestamp_trunc_array_fmt_dict_dict( + &array.downcast_dict::().unwrap(), + &formats.downcast_dict::().unwrap()) + .map(|a| Arc::new(a) as ArrayRef) + } + dt => return_compute_error_with!("timestamp_trunc does not support", dt) + ) + } + fmt => return_compute_error_with!("timestamp_trunc does not support format type", fmt), + ) + } + (DataType::Dictionary(_, _), DataType::Utf8) => { + downcast_dictionary_array!( + array => { + timestamp_trunc_array_fmt_dict_plain( + &array.downcast_dict::>().unwrap(), + formats.as_any().downcast_ref::() + .expect("Unexpected value type in formats")) + .map(|a| Arc::new(a) as ArrayRef) + } + dt => return_compute_error_with!("timestamp_trunc does not support", dt), + ) + } + (DataType::Timestamp(TimeUnit::Microsecond, _), DataType::Dictionary(_, _)) => { + downcast_dictionary_array!( + formats => { + downcast_temporal_array!(array => { + timestamp_trunc_array_fmt_plain_dict( + array, + &formats.downcast_dict::().unwrap()) + .map(|a| Arc::new(a) as ArrayRef) + } + dt => return_compute_error_with!("timestamp_trunc does not support", dt), + ) + } + fmt => return_compute_error_with!("timestamp_trunc does not support format type", fmt), + ) + } + (DataType::Timestamp(TimeUnit::Microsecond, _), DataType::Utf8) => { + downcast_temporal_array!( + array => { + timestamp_trunc_array_fmt_plain_plain(array, + formats.as_any().downcast_ref::().expect("Unexpected value type in formats")) + .map(|a| Arc::new(a) as ArrayRef) + }, + dt => return_compute_error_with!("timestamp_trunc does not support", dt), + ) + } + (dt, fmt) => Err(SparkError::Internal(format!( + "Unsupported datatype: {:}, format: {:?} for function 'timestamp_trunc'", + dt, fmt + ))), + } +} + +macro_rules! 
timestamp_trunc_array_fmt_helper { + ($array: ident, $formats: ident, $datatype: ident) => {{ + let mut builder = TimestampMicrosecondBuilder::with_capacity($array.len()); + let iter = $array.into_iter(); + assert_eq!( + $array.len(), + $formats.len(), + "lengths of values array and format array must be the same" + ); + match $datatype { + DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => { + let tz: Tz = tz.parse()?; + for (index, val) in iter.enumerate() { + let op_result = match $formats.value(index).to_uppercase().as_str() { + "YEAR" | "YYYY" | "YY" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_year(dt)) + }) + } + "QUARTER" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_quarter(dt)) + }) + } + "MONTH" | "MON" | "MM" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_month(dt)) + }) + } + "WEEK" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_week(dt)) + }) + } + "DAY" | "DD" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_day(dt)) + }) + } + "HOUR" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_hour(dt)) + }) + } + "MINUTE" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_minute(dt)) + }) + } + "SECOND" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_second(dt)) + }) + } + "MILLISECOND" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_ms(dt)) + }) + } + "MICROSECOND" => { + as_timestamp_tz_with_op_single::(val, &mut builder, &tz, |dt| { + as_micros_from_unix_epoch_utc(trunc_date_to_microsec(dt)) + }) + } + _ => Err(SparkError::Internal(format!( + "Unsupported format: {:?} for function 'timestamp_trunc'", + $formats.value(index) + ))), + }; + op_result? 
+ } + Ok(builder.finish()) + } + dt => { + return_compute_error_with!( + "Unsupported input type '{:?}' for function 'timestamp_trunc'", + dt + ) + } + } + }}; +} + +fn timestamp_trunc_array_fmt_plain_plain( + array: &PrimitiveArray, + formats: &StringArray, +) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, +{ + let data_type = array.data_type(); + timestamp_trunc_array_fmt_helper!(array, formats, data_type) +} +fn timestamp_trunc_array_fmt_plain_dict( + array: &PrimitiveArray, + formats: &TypedDictionaryArray, +) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, + K: ArrowDictionaryKeyType, +{ + let data_type = array.data_type(); + timestamp_trunc_array_fmt_helper!(array, formats, data_type) +} + +fn timestamp_trunc_array_fmt_dict_plain( + array: &TypedDictionaryArray>, + formats: &StringArray, +) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, + K: ArrowDictionaryKeyType, +{ + let data_type = array.values().data_type(); + timestamp_trunc_array_fmt_helper!(array, formats, data_type) +} + +fn timestamp_trunc_array_fmt_dict_dict( + array: &TypedDictionaryArray>, + formats: &TypedDictionaryArray, +) -> Result +where + T: ArrowTemporalType + ArrowNumericType, + i64: From, + K: ArrowDictionaryKeyType, + F: ArrowDictionaryKeyType, +{ + let data_type = array.values().data_type(); + timestamp_trunc_array_fmt_helper!(array, formats, data_type) +} + +#[cfg(test)] +mod tests { + use crate::kernels::temporal::{ + date_trunc, date_trunc_array_fmt_dyn, timestamp_trunc, + timestamp_trunc_array_fmt_dyn, + }; + use arrow_array::{ + builder::{PrimitiveDictionaryBuilder, StringDictionaryBuilder}, + iterator::ArrayIter, + types::{Date32Type, Int32Type, TimestampMicrosecondType}, + Array, Date32Array, PrimitiveArray, StringArray, TimestampMicrosecondArray, + }; + use std::sync::Arc; + + #[test] + #[cfg_attr(miri, ignore)] // test takes too long with miri + fn test_date_trunc() { + let size = 1000; + let mut vec: Vec = Vec::with_capacity(size); + for i in 0..size { + vec.push(i as i32); + } + let array = Date32Array::from(vec); + for fmt in [ + "YEAR", "YYYY", "YY", "QUARTER", "MONTH", "MON", "MM", "WEEK", + ] { + match date_trunc(&array, fmt.to_string()) { + Ok(a) => { + for i in 0..size { + assert!(array.values().get(i) >= a.values().get(i)) + } + } + _ => unreachable!(), + } + } + } + + #[test] + // This test only verifies that the various input array types work. 
Actually correctness to + // ensure this produces the same results as spark is verified in the JVM tests + fn test_date_trunc_array_fmt_dyn() { + let size = 10; + let formats = [ + "YEAR", "YYYY", "YY", "QUARTER", "MONTH", "MON", "MM", "WEEK", + ]; + let mut vec: Vec = Vec::with_capacity(size * formats.len()); + let mut fmt_vec: Vec<&str> = Vec::with_capacity(size * formats.len()); + for i in 0..size { + for fmt_value in &formats { + vec.push(i as i32 * 1_000_001); + fmt_vec.push(fmt_value); + } + } + + // timestamp array + let array = Date32Array::from(vec); + + // formats array + let fmt_array = StringArray::from(fmt_vec); + + // timestamp dictionary array + let mut date_dict_builder = + PrimitiveDictionaryBuilder::::new(); + for v in array.iter() { + date_dict_builder + .append(v.unwrap()) + .expect("Error in building timestamp array"); + } + let mut array_dict = date_dict_builder.finish(); + // apply timezone + array_dict = array_dict.with_values(Arc::new( + array_dict + .values() + .as_any() + .downcast_ref::() + .unwrap() + .clone(), + )); + + // formats dictionary array + let mut formats_dict_builder = StringDictionaryBuilder::::new(); + for v in fmt_array.iter() { + formats_dict_builder + .append(v.unwrap()) + .expect("Error in building formats array"); + } + let fmt_dict = formats_dict_builder.finish(); + + // verify input arrays + let iter = ArrayIter::new(&array); + let mut dict_iter = array_dict + .downcast_dict::>() + .unwrap() + .into_iter(); + for val in iter { + assert_eq!( + dict_iter + .next() + .expect("array and dictionary array do not match"), + val + ) + } + + // verify input format arrays + let fmt_iter = ArrayIter::new(&fmt_array); + let mut fmt_dict_iter = + fmt_dict.downcast_dict::().unwrap().into_iter(); + for val in fmt_iter { + assert_eq!( + fmt_dict_iter + .next() + .expect("formats and dictionary formats do not match"), + val + ) + } + + // test cases + if let Ok(a) = date_trunc_array_fmt_dyn(&array, &fmt_array) { + for i in 0..array.len() { + assert!( + array.value(i) + >= a.as_any().downcast_ref::().unwrap().value(i) + ) + } + } else { + unreachable!() + } + if let Ok(a) = date_trunc_array_fmt_dyn(&array_dict, &fmt_array) { + for i in 0..array.len() { + assert!( + array.value(i) + >= a.as_any().downcast_ref::().unwrap().value(i) + ) + } + } else { + unreachable!() + } + if let Ok(a) = date_trunc_array_fmt_dyn(&array, &fmt_dict) { + for i in 0..array.len() { + assert!( + array.value(i) + >= a.as_any().downcast_ref::().unwrap().value(i) + ) + } + } else { + unreachable!() + } + if let Ok(a) = date_trunc_array_fmt_dyn(&array_dict, &fmt_dict) { + for i in 0..array.len() { + assert!( + array.value(i) + >= a.as_any().downcast_ref::().unwrap().value(i) + ) + } + } else { + unreachable!() + } + } + + #[test] + #[cfg_attr(miri, ignore)] // test takes too long with miri + fn test_timestamp_trunc() { + let size = 1000; + let mut vec: Vec = Vec::with_capacity(size); + for i in 0..size { + vec.push(i as i64); + } + let array = TimestampMicrosecondArray::from(vec).with_timezone_utc(); + for fmt in [ + "YEAR", + "YYYY", + "YY", + "QUARTER", + "MONTH", + "MON", + "MM", + "WEEK", + "DAY", + "DD", + "HOUR", + "MINUTE", + "SECOND", + "MILLISECOND", + "MICROSECOND", + ] { + match timestamp_trunc(&array, fmt.to_string()) { + Ok(a) => { + for i in 0..size { + assert!(array.values().get(i) >= a.values().get(i)) + } + } + _ => unreachable!(), + } + } + } + + #[test] + // test takes too long with miri + #[cfg_attr(miri, ignore)] + // This test only verifies that the various input 
array types work. Actually correctness to + // ensure this produces the same results as spark is verified in the JVM tests + fn test_timestamp_trunc_array_fmt_dyn() { + let size = 10; + let formats = [ + "YEAR", + "YYYY", + "YY", + "QUARTER", + "MONTH", + "MON", + "MM", + "WEEK", + "DAY", + "DD", + "HOUR", + "MINUTE", + "SECOND", + "MILLISECOND", + "MICROSECOND", + ]; + let mut vec: Vec = Vec::with_capacity(size * formats.len()); + let mut fmt_vec: Vec<&str> = Vec::with_capacity(size * formats.len()); + for i in 0..size { + for fmt_value in &formats { + vec.push(i as i64 * 1_000_000_001); + fmt_vec.push(fmt_value); + } + } + + // timestamp array + let array = TimestampMicrosecondArray::from(vec).with_timezone_utc(); + + // formats array + let fmt_array = StringArray::from(fmt_vec); + + // timestamp dictionary array + let mut timestamp_dict_builder = + PrimitiveDictionaryBuilder::::new(); + for v in array.iter() { + timestamp_dict_builder + .append(v.unwrap()) + .expect("Error in building timestamp array"); + } + let mut array_dict = timestamp_dict_builder.finish(); + // apply timezone + array_dict = array_dict.with_values(Arc::new( + array_dict + .values() + .as_any() + .downcast_ref::() + .unwrap() + .clone() + .with_timezone_utc(), + )); + + // formats dictionary array + let mut formats_dict_builder = StringDictionaryBuilder::::new(); + for v in fmt_array.iter() { + formats_dict_builder + .append(v.unwrap()) + .expect("Error in building formats array"); + } + let fmt_dict = formats_dict_builder.finish(); + + // verify input arrays + let iter = ArrayIter::new(&array); + let mut dict_iter = array_dict + .downcast_dict::>() + .unwrap() + .into_iter(); + for val in iter { + assert_eq!( + dict_iter + .next() + .expect("array and dictionary array do not match"), + val + ) + } + + // verify input format arrays + let fmt_iter = ArrayIter::new(&fmt_array); + let mut fmt_dict_iter = + fmt_dict.downcast_dict::().unwrap().into_iter(); + for val in fmt_iter { + assert_eq!( + fmt_dict_iter + .next() + .expect("formats and dictionary formats do not match"), + val + ) + } + + // test cases + if let Ok(a) = timestamp_trunc_array_fmt_dyn(&array, &fmt_array) { + for i in 0..array.len() { + assert!( + array.value(i) + >= a.as_any() + .downcast_ref::() + .unwrap() + .value(i) + ) + } + } else { + unreachable!() + } + if let Ok(a) = timestamp_trunc_array_fmt_dyn(&array_dict, &fmt_array) { + for i in 0..array.len() { + assert!( + array.value(i) + >= a.as_any() + .downcast_ref::() + .unwrap() + .value(i) + ) + } + } else { + unreachable!() + } + if let Ok(a) = timestamp_trunc_array_fmt_dyn(&array, &fmt_dict) { + for i in 0..array.len() { + assert!( + array.value(i) + >= a.as_any() + .downcast_ref::() + .unwrap() + .value(i) + ) + } + } else { + unreachable!() + } + if let Ok(a) = timestamp_trunc_array_fmt_dyn(&array_dict, &fmt_dict) { + for i in 0..array.len() { + assert!( + array.value(i) + >= a.as_any() + .downcast_ref::() + .unwrap() + .value(i) + ) + } + } else { + unreachable!() + } + } +} diff --git a/datafusion/spark/src/lib.rs b/datafusion/spark/src/lib.rs new file mode 100644 index 000000000000..f87648a1c9a7 --- /dev/null +++ b/datafusion/spark/src/lib.rs @@ -0,0 +1,99 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Clippy raises an error if a reference clone is not wrapped in `Arc::clone`.
+// The lint makes it easier for readers and reviewers to distinguish cheap reference clones from more heavyweight ones.
+#![deny(clippy::clone_on_ref_ptr)]
+
+mod error;
+
+mod kernels;
+mod static_invoke;
+pub use static_invoke::*;
+
+mod struct_funcs;
+pub use struct_funcs::{CreateNamedStruct, GetStructField};
+
+mod json_funcs;
+pub mod test_common;
+pub mod timezone;
+mod unbound;
+pub use unbound::UnboundColumn;
+mod predicate_funcs;
+pub mod utils;
+pub use predicate_funcs::{spark_isnan, RLike};
+
+mod agg_funcs;
+mod array_funcs;
+mod bitwise_funcs;
+mod comet_scalar_funcs;
+pub mod hash_funcs;
+
+mod string_funcs;
+
+mod datetime_funcs;
+pub use agg_funcs::*;
+
+pub use cast::{spark_cast, Cast, SparkCastOptions};
+mod conditional_funcs;
+mod conversion_funcs;
+mod math_funcs;
+
+pub use array_funcs::*;
+pub use bitwise_funcs::*;
+pub use conditional_funcs::*;
+pub use conversion_funcs::*;
+
+pub use comet_scalar_funcs::create_comet_physical_fun;
+pub use datetime_funcs::{
+    spark_date_add, spark_date_sub, DateTruncExpr, HourExpr, MinuteExpr, SecondExpr,
+    TimestampTruncExpr,
+};
+pub use error::{SparkError, SparkResult};
+pub use hash_funcs::*;
+pub use json_funcs::ToJson;
+pub use math_funcs::{
+    create_negate_expr, spark_ceil, spark_decimal_div, spark_floor, spark_hex,
+    spark_make_decimal, spark_round, spark_unhex, spark_unscaled_value, CheckOverflow,
+    NegativeExpr, NormalizeNaNAndZero,
+};
+pub use string_funcs::*;
+
+/// Spark supports three evaluation modes when evaluating expressions, which affect
+/// the behavior when processing input values that are invalid or would result in an
+/// error, such as division by zero, and also affect behavior when converting
+/// between types.
+#[derive(Debug, Hash, PartialEq, Eq, Clone, Copy)]
+pub enum EvalMode {
+    /// Legacy is the default behavior in Spark prior to Spark 4.0. This mode silently ignores
+    /// or replaces errors during SQL operations. Operations resulting in errors (like
+    /// division by zero) will produce NULL values instead of failing. Legacy mode also
+    /// enables implicit type conversions.
+    Legacy,
+    /// Adheres to the ANSI SQL standard for error handling by throwing exceptions for
+    /// operations that result in errors. Does not perform implicit type conversions.
+    Ansi,
+    /// Same as Ansi mode, except that it converts errors to NULL values without
+    /// failing the entire query.
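+    /// For example, a division by zero evaluates to NULL under Try mode instead of
+    /// raising the error that Ansi mode would produce.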
+ Try, +} + +pub(crate) fn arithmetic_overflow_error(from_type: &str) -> SparkError { + SparkError::ArithmeticOverflow { + from_type: from_type.to_string(), + } +} diff --git a/datafusion/spark/src/math_funcs/ceil.rs b/datafusion/spark/src/math_funcs/ceil.rs new file mode 100644 index 000000000000..923a1d0cb10f --- /dev/null +++ b/datafusion/spark/src/math_funcs/ceil.rs @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::downcast_compute_op; +use crate::math_funcs::utils::{ + get_precision_scale, make_decimal_array, make_decimal_scalar, +}; +use arrow::array::{Float32Array, Float64Array, Int64Array}; +use arrow_array::{Array, ArrowNativeTypeOp}; +use arrow_schema::DataType; +use datafusion::physical_plan::ColumnarValue; +use datafusion_common::{DataFusionError, ScalarValue}; +use num::integer::div_ceil; +use std::sync::Arc; + +/// `ceil` function that simulates Spark `ceil` expression +pub fn spark_ceil( + args: &[ColumnarValue], + data_type: &DataType, +) -> Result { + let value = &args[0]; + match value { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Float32 => { + let result = + downcast_compute_op!(array, "ceil", ceil, Float32Array, Int64Array); + Ok(ColumnarValue::Array(result?)) + } + DataType::Float64 => { + let result = + downcast_compute_op!(array, "ceil", ceil, Float64Array, Int64Array); + Ok(ColumnarValue::Array(result?)) + } + DataType::Int64 => { + let result = array.as_any().downcast_ref::().unwrap(); + Ok(ColumnarValue::Array(Arc::new(result.clone()))) + } + DataType::Decimal128(_, scale) if *scale > 0 => { + let f = decimal_ceil_f(scale); + let (precision, scale) = get_precision_scale(data_type); + make_decimal_array(array, precision, scale, &f) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function ceil", + other, + ))), + }, + ColumnarValue::Scalar(a) => match a { + ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( + a.map(|x| x.ceil() as i64), + ))), + ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( + a.map(|x| x.ceil() as i64), + ))), + ScalarValue::Int64(a) => { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(a.map(|x| x)))) + } + ScalarValue::Decimal128(a, _, scale) if *scale > 0 => { + let f = decimal_ceil_f(scale); + let (precision, scale) = get_precision_scale(data_type); + make_decimal_scalar(a, precision, scale, &f) + } + _ => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function ceil", + value.data_type(), + ))), + }, + } +} + +#[inline] +fn decimal_ceil_f(scale: &i8) -> impl Fn(i128) -> i128 { + let div = 10_i128.pow_wrapping(*scale as u32); + move |x: i128| div_ceil(x, div) +} diff --git a/datafusion/spark/src/math_funcs/div.rs 
b/datafusion/spark/src/math_funcs/div.rs new file mode 100644 index 000000000000..1b6d99065a9a --- /dev/null +++ b/datafusion/spark/src/math_funcs/div.rs @@ -0,0 +1,96 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::math_funcs::utils::get_precision_scale; +use arrow::{ + array::{ArrayRef, AsArray}, + datatypes::Decimal128Type, +}; +use arrow_array::{Array, Decimal128Array}; +use arrow_schema::{DataType, DECIMAL128_MAX_PRECISION}; +use datafusion::physical_plan::ColumnarValue; +use datafusion_common::DataFusionError; +use num::{BigInt, Signed, ToPrimitive}; +use std::sync::Arc; + +// Let Decimal(p3, s3) as return type i.e. Decimal(p1, s1) / Decimal(p2, s2) = Decimal(p3, s3). +// Conversely, Decimal(p1, s1) = Decimal(p2, s2) * Decimal(p3, s3). This means that, in order to +// get enough scale that matches with Spark behavior, it requires to widen s1 to s2 + s3 + 1. Since +// both s2 and s3 are 38 at max., s1 is 77 at max. DataFusion division cannot handle such scale > +// Decimal256Type::MAX_SCALE. Therefore, we need to implement this decimal division using BigInt. +pub fn spark_decimal_div( + args: &[ColumnarValue], + data_type: &DataType, +) -> Result { + let left = &args[0]; + let right = &args[1]; + let (p3, s3) = get_precision_scale(data_type); + + let (left, right): (ArrayRef, ArrayRef) = match (left, right) { + (ColumnarValue::Array(l), ColumnarValue::Array(r)) => { + (Arc::clone(l), Arc::clone(r)) + } + (ColumnarValue::Scalar(l), ColumnarValue::Array(r)) => { + (l.to_array_of_size(r.len())?, Arc::clone(r)) + } + (ColumnarValue::Array(l), ColumnarValue::Scalar(r)) => { + (Arc::clone(l), r.to_array_of_size(l.len())?) + } + (ColumnarValue::Scalar(l), ColumnarValue::Scalar(r)) => { + (l.to_array()?, r.to_array()?) + } + }; + let left = left.as_primitive::(); + let right = right.as_primitive::(); + let (p1, s1) = get_precision_scale(left.data_type()); + let (p2, s2) = get_precision_scale(right.data_type()); + + let l_exp = ((s2 + s3 + 1) as u32).saturating_sub(s1 as u32); + let r_exp = (s1 as u32).saturating_sub((s2 + s3 + 1) as u32); + let result: Decimal128Array = if p1 as u32 + l_exp > DECIMAL128_MAX_PRECISION as u32 + || p2 as u32 + r_exp > DECIMAL128_MAX_PRECISION as u32 + { + let ten = BigInt::from(10); + let l_mul = ten.pow(l_exp); + let r_mul = ten.pow(r_exp); + let five = BigInt::from(5); + let zero = BigInt::from(0); + arrow::compute::kernels::arity::binary(left, right, |l, r| { + let l = BigInt::from(l) * &l_mul; + let r = BigInt::from(r) * &r_mul; + let div = if r.eq(&zero) { zero.clone() } else { &l / &r }; + let res = if div.is_negative() { + div - &five + } else { + div + &five + } / &ten; + res.to_i128().unwrap_or(i128::MAX) + })? 
+ } else { + let l_mul = 10_i128.pow(l_exp); + let r_mul = 10_i128.pow(r_exp); + arrow::compute::kernels::arity::binary(left, right, |l, r| { + let l = l * l_mul; + let r = r * r_mul; + let div = if r == 0 { 0 } else { l / r }; + let res = if div.is_negative() { div - 5 } else { div + 5 } / 10; + res.to_i128().unwrap_or(i128::MAX) + })? + }; + let result = result.with_data_type(DataType::Decimal128(p3, s3)); + Ok(ColumnarValue::Array(Arc::new(result))) +} diff --git a/datafusion/spark/src/math_funcs/floor.rs b/datafusion/spark/src/math_funcs/floor.rs new file mode 100644 index 000000000000..06755493bcb0 --- /dev/null +++ b/datafusion/spark/src/math_funcs/floor.rs @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::downcast_compute_op; +use crate::math_funcs::utils::{ + get_precision_scale, make_decimal_array, make_decimal_scalar, +}; +use arrow::array::{Float32Array, Float64Array, Int64Array}; +use arrow_array::{Array, ArrowNativeTypeOp}; +use arrow_schema::DataType; +use datafusion::physical_plan::ColumnarValue; +use datafusion_common::{DataFusionError, ScalarValue}; +use num::integer::div_floor; +use std::sync::Arc; + +/// `floor` function that simulates Spark `floor` expression +pub fn spark_floor( + args: &[ColumnarValue], + data_type: &DataType, +) -> Result { + let value = &args[0]; + match value { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Float32 => { + let result = + downcast_compute_op!(array, "floor", floor, Float32Array, Int64Array); + Ok(ColumnarValue::Array(result?)) + } + DataType::Float64 => { + let result = + downcast_compute_op!(array, "floor", floor, Float64Array, Int64Array); + Ok(ColumnarValue::Array(result?)) + } + DataType::Int64 => { + let result = array.as_any().downcast_ref::().unwrap(); + Ok(ColumnarValue::Array(Arc::new(result.clone()))) + } + DataType::Decimal128(_, scale) if *scale > 0 => { + let f = decimal_floor_f(scale); + let (precision, scale) = get_precision_scale(data_type); + make_decimal_array(array, precision, scale, &f) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function floor", + other, + ))), + }, + ColumnarValue::Scalar(a) => match a { + ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( + a.map(|x| x.floor() as i64), + ))), + ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Int64( + a.map(|x| x.floor() as i64), + ))), + ScalarValue::Int64(a) => { + Ok(ColumnarValue::Scalar(ScalarValue::Int64(a.map(|x| x)))) + } + ScalarValue::Decimal128(a, _, scale) if *scale > 0 => { + let f = decimal_floor_f(scale); + let (precision, scale) = get_precision_scale(data_type); + make_decimal_scalar(a, precision, scale, &f) + } + _ => Err(DataFusionError::Internal(format!( + 
"Unsupported data type {:?} for function floor", + value.data_type(), + ))), + }, + } +} + +#[inline] +fn decimal_floor_f(scale: &i8) -> impl Fn(i128) -> i128 { + let div = 10_i128.pow_wrapping(*scale as u32); + move |x: i128| div_floor(x, div) +} diff --git a/datafusion/spark/src/math_funcs/hex.rs b/datafusion/spark/src/math_funcs/hex.rs new file mode 100644 index 000000000000..bedcfd679d16 --- /dev/null +++ b/datafusion/spark/src/math_funcs/hex.rs @@ -0,0 +1,297 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow::{ + array::{as_dictionary_array, as_largestring_array, as_string_array}, + datatypes::Int32Type, +}; +use arrow_array::StringArray; +use arrow_schema::DataType; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{ + cast::{as_binary_array, as_fixed_size_binary_array, as_int64_array}, + exec_err, DataFusionError, +}; +use std::fmt::Write; + +fn hex_int64(num: i64) -> String { + format!("{:X}", num) +} + +#[inline(always)] +fn hex_encode>(data: T, lower_case: bool) -> String { + let mut s = String::with_capacity(data.as_ref().len() * 2); + if lower_case { + for b in data.as_ref() { + // Writing to a string never errors, so we can unwrap here. + write!(&mut s, "{b:02x}").unwrap(); + } + } else { + for b in data.as_ref() { + // Writing to a string never errors, so we can unwrap here. 
+ write!(&mut s, "{b:02X}").unwrap(); + } + } + s +} + +#[inline(always)] +pub(crate) fn hex_strings>(data: T) -> String { + hex_encode(data, true) +} + +#[inline(always)] +fn hex_bytes>(bytes: T) -> Result { + let hex_string = hex_encode(bytes, false); + Ok(hex_string) +} + +/// Spark-compatible `hex` function +pub fn spark_hex(args: &[ColumnarValue]) -> Result { + if args.len() != 1 { + return Err(DataFusionError::Internal( + "hex expects exactly one argument".to_string(), + )); + } + + match &args[0] { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Int64 => { + let array = as_int64_array(array)?; + + let hexed_array: StringArray = + array.iter().map(|v| v.map(hex_int64)).collect(); + + Ok(ColumnarValue::Array(Arc::new(hexed_array))) + } + DataType::Utf8 => { + let array = as_string_array(array); + + let hexed: StringArray = array + .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?; + + Ok(ColumnarValue::Array(Arc::new(hexed))) + } + DataType::LargeUtf8 => { + let array = as_largestring_array(array); + + let hexed: StringArray = array + .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?; + + Ok(ColumnarValue::Array(Arc::new(hexed))) + } + DataType::Binary => { + let array = as_binary_array(array)?; + + let hexed: StringArray = array + .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?; + + Ok(ColumnarValue::Array(Arc::new(hexed))) + } + DataType::FixedSizeBinary(_) => { + let array = as_fixed_size_binary_array(array)?; + + let hexed: StringArray = array + .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?; + + Ok(ColumnarValue::Array(Arc::new(hexed))) + } + DataType::Dictionary(_, value_type) => { + let dict = as_dictionary_array::(&array); + + let values = match **value_type { + DataType::Int64 => as_int64_array(dict.values())? + .iter() + .map(|v| v.map(hex_int64)) + .collect::>(), + DataType::Utf8 => as_string_array(dict.values()) + .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?, + DataType::Binary => as_binary_array(dict.values())? 
+ .iter() + .map(|v| v.map(hex_bytes).transpose()) + .collect::>()?, + _ => exec_err!( + "hex got an unexpected argument type: {:?}", + array.data_type() + )?, + }; + + let new_values: Vec> = dict + .keys() + .iter() + .map(|key| key.map(|k| values[k as usize].clone()).unwrap_or(None)) + .collect(); + + let string_array_values = StringArray::from(new_values); + + Ok(ColumnarValue::Array(Arc::new(string_array_values))) + } + _ => exec_err!( + "hex got an unexpected argument type: {:?}", + array.data_type() + ), + }, + _ => exec_err!("native hex does not support scalar values at this time"), + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::{ + array::{ + as_string_array, BinaryDictionaryBuilder, PrimitiveDictionaryBuilder, + StringBuilder, StringDictionaryBuilder, + }, + datatypes::{Int32Type, Int64Type}, + }; + use arrow_array::{Int64Array, StringArray}; + use datafusion::logical_expr::ColumnarValue; + + #[test] + fn test_dictionary_hex_utf8() { + let mut input_builder = StringDictionaryBuilder::::new(); + input_builder.append_value("hi"); + input_builder.append_value("bye"); + input_builder.append_null(); + input_builder.append_value("rust"); + let input = input_builder.finish(); + + let mut string_builder = StringBuilder::new(); + string_builder.append_value("6869"); + string_builder.append_value("627965"); + string_builder.append_null(); + string_builder.append_value("72757374"); + let expected = string_builder.finish(); + + let columnar_value = ColumnarValue::Array(Arc::new(input)); + let result = super::spark_hex(&[columnar_value]).unwrap(); + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let result = as_string_array(&result); + + assert_eq!(result, &expected); + } + + #[test] + fn test_dictionary_hex_int64() { + let mut input_builder = PrimitiveDictionaryBuilder::::new(); + input_builder.append_value(1); + input_builder.append_value(2); + input_builder.append_null(); + input_builder.append_value(3); + let input = input_builder.finish(); + + let mut string_builder = StringBuilder::new(); + string_builder.append_value("1"); + string_builder.append_value("2"); + string_builder.append_null(); + string_builder.append_value("3"); + let expected = string_builder.finish(); + + let columnar_value = ColumnarValue::Array(Arc::new(input)); + let result = super::spark_hex(&[columnar_value]).unwrap(); + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let result = as_string_array(&result); + + assert_eq!(result, &expected); + } + + #[test] + fn test_dictionary_hex_binary() { + let mut input_builder = BinaryDictionaryBuilder::::new(); + input_builder.append_value("1"); + input_builder.append_value("j"); + input_builder.append_null(); + input_builder.append_value("3"); + let input = input_builder.finish(); + + let mut expected_builder = StringBuilder::new(); + expected_builder.append_value("31"); + expected_builder.append_value("6A"); + expected_builder.append_null(); + expected_builder.append_value("33"); + let expected = expected_builder.finish(); + + let columnar_value = ColumnarValue::Array(Arc::new(input)); + let result = super::spark_hex(&[columnar_value]).unwrap(); + + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let result = as_string_array(&result); + + assert_eq!(result, &expected); + } + + #[test] + fn test_hex_int64() { + let num = 1234; + let hexed = super::hex_int64(num); + 
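+        // 1234 in decimal is 0x4D2; a negative i64 is formatted as its
+        // two's-complement bit pattern, as the second assertion below shows.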
assert_eq!(hexed, "4D2".to_string()); + + let num = -1; + let hexed = super::hex_int64(num); + assert_eq!(hexed, "FFFFFFFFFFFFFFFF".to_string()); + } + + #[test] + fn test_spark_hex_int64() { + let int_array = Int64Array::from(vec![Some(1), Some(2), None, Some(3)]); + let columnar_value = ColumnarValue::Array(Arc::new(int_array)); + + let result = super::spark_hex(&[columnar_value]).unwrap(); + let result = match result { + ColumnarValue::Array(array) => array, + _ => panic!("Expected array"), + }; + + let string_array = as_string_array(&result); + let expected_array = StringArray::from(vec![ + Some("1".to_string()), + Some("2".to_string()), + None, + Some("3".to_string()), + ]); + + assert_eq!(string_array, &expected_array); + } +} diff --git a/datafusion/spark/src/math_funcs/internal/checkoverflow.rs b/datafusion/spark/src/math_funcs/internal/checkoverflow.rs new file mode 100644 index 000000000000..2d38f5e2f39f --- /dev/null +++ b/datafusion/spark/src/math_funcs/internal/checkoverflow.rs @@ -0,0 +1,170 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::{ + array::{as_primitive_array, Array, ArrayRef, Decimal128Array}, + datatypes::{Decimal128Type, DecimalType}, + record_batch::RecordBatch, +}; +use arrow_schema::{DataType, Schema}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{DataFusionError, ScalarValue}; +use datafusion_physical_expr::PhysicalExpr; +use std::hash::Hash; +use std::{ + any::Any, + fmt::{Display, Formatter}, + sync::Arc, +}; + +/// This is from Spark `CheckOverflow` expression. Spark `CheckOverflow` expression rounds decimals +/// to given scale and check if the decimals can fit in given precision. As `cast` kernel rounds +/// decimals already, Comet `CheckOverflow` expression only checks if the decimals can fit in the +/// precision. 
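+///
+/// For example (illustrative values), the decimal `123.45` stored with scale 2 fits in
+/// `Decimal128(5, 2)` but overflows `Decimal128(4, 2)`; when a value overflows, it is
+/// replaced with NULL unless `fail_on_error` is set, in which case an error is returned.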
+#[derive(Debug, Eq)] +pub struct CheckOverflow { + pub child: Arc, + pub data_type: DataType, + pub fail_on_error: bool, +} + +impl Hash for CheckOverflow { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.data_type.hash(state); + self.fail_on_error.hash(state); + } +} + +impl PartialEq for CheckOverflow { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) + && self.data_type.eq(&other.data_type) + && self.fail_on_error.eq(&other.fail_on_error) + } +} + +impl CheckOverflow { + pub fn new( + child: Arc, + data_type: DataType, + fail_on_error: bool, + ) -> Self { + Self { + child, + data_type, + fail_on_error, + } + } +} + +impl Display for CheckOverflow { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "CheckOverflow [datatype: {}, fail_on_error: {}, child: {}]", + self.data_type, self.fail_on_error, self.child + ) + } +} + +impl PhysicalExpr for CheckOverflow { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _: &Schema) -> datafusion_common::Result { + Ok(self.data_type.clone()) + } + + fn nullable(&self, _: &Schema) -> datafusion_common::Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + let arg = self.child.evaluate(batch)?; + match arg { + ColumnarValue::Array(array) + if matches!(array.data_type(), DataType::Decimal128(_, _)) => + { + let (precision, scale) = match &self.data_type { + DataType::Decimal128(p, s) => (p, s), + dt => { + return Err(DataFusionError::Execution(format!( + "CheckOverflow expects only Decimal128, but got {:?}", + dt + ))) + } + }; + + let decimal_array = as_primitive_array::(&array); + + let casted_array = if self.fail_on_error { + // Returning error if overflow + decimal_array.validate_decimal_precision(*precision)?; + decimal_array + } else { + // Overflowing gets null value + &decimal_array.null_if_overflow_precision(*precision) + }; + + let new_array = Decimal128Array::from(casted_array.into_data()) + .with_precision_and_scale(*precision, *scale) + .map(|a| Arc::new(a) as ArrayRef)?; + + Ok(ColumnarValue::Array(new_array)) + } + ColumnarValue::Scalar(ScalarValue::Decimal128(v, precision, scale)) => { + // `fail_on_error` is only true when ANSI is enabled, which we don't support yet + // (Java side will simply fallback to Spark when it is enabled) + assert!( + !self.fail_on_error, + "fail_on_error (ANSI mode) is not supported yet" + ); + + let new_v: Option = v.and_then(|v| { + Decimal128Type::validate_decimal_precision(v, precision) + .map(|_| v) + .ok() + }); + + Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( + new_v, precision, scale, + ))) + } + v => Err(DataFusionError::Execution(format!( + "CheckOverflow's child expression should be decimal array, but found {:?}", + v + ))), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion_common::Result> { + Ok(Arc::new(CheckOverflow::new( + Arc::clone(&children[0]), + self.data_type.clone(), + self.fail_on_error, + ))) + } +} diff --git a/datafusion/spark/src/math_funcs/internal/make_decimal.rs b/datafusion/spark/src/math_funcs/internal/make_decimal.rs new file mode 100644 index 000000000000..9fdb61c4d0ce --- /dev/null +++ b/datafusion/spark/src/math_funcs/internal/make_decimal.rs @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::math_funcs::utils::get_precision_scale; +use arrow::{ + array::{AsArray, Decimal128Builder}, + datatypes::{validate_decimal_precision, Int64Type}, +}; +use arrow_schema::DataType; +use datafusion::physical_plan::ColumnarValue; +use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue}; +use std::sync::Arc; + +/// Spark-compatible `MakeDecimal` expression (internal to Spark optimizer) +pub fn spark_make_decimal( + args: &[ColumnarValue], + data_type: &DataType, +) -> DataFusionResult { + let (precision, scale) = get_precision_scale(data_type); + match &args[0] { + ColumnarValue::Scalar(v) => match v { + ScalarValue::Int64(n) => Ok(ColumnarValue::Scalar(ScalarValue::Decimal128( + long_to_decimal(n, precision), + precision, + scale, + ))), + sv => internal_err!("Expected Int64 but found {sv:?}"), + }, + ColumnarValue::Array(a) => { + let arr = a.as_primitive::(); + let mut result = Decimal128Builder::new(); + for v in arr.into_iter() { + result.append_option(long_to_decimal(&v, precision)) + } + let result_type = DataType::Decimal128(precision, scale); + + Ok(ColumnarValue::Array(Arc::new( + result.finish().with_data_type(result_type), + ))) + } + } +} + +/// Convert the input long to decimal with the given maximum precision. If overflows, returns null +/// instead. +#[inline] +fn long_to_decimal(v: &Option, precision: u8) -> Option { + match v { + Some(v) if validate_decimal_precision(*v as i128, precision).is_ok() => { + Some(*v as i128) + } + _ => None, + } +} diff --git a/datafusion/spark/src/math_funcs/internal/mod.rs b/datafusion/spark/src/math_funcs/internal/mod.rs new file mode 100644 index 000000000000..29295f0d524d --- /dev/null +++ b/datafusion/spark/src/math_funcs/internal/mod.rs @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
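+
+//! Internal helper expressions that mirror constructs Spark inserts during planning
+//! rather than user-facing SQL functions: decimal overflow checking, building decimals
+//! from unscaled longs, NaN and negative-zero normalization, and unscaled-value extraction.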
+ +mod checkoverflow; +mod make_decimal; +mod normalize_nan; +mod unscaled_value; + +pub use checkoverflow::CheckOverflow; +pub use make_decimal::spark_make_decimal; +pub use normalize_nan::NormalizeNaNAndZero; +pub use unscaled_value::spark_unscaled_value; diff --git a/datafusion/spark/src/math_funcs/internal/normalize_nan.rs b/datafusion/spark/src/math_funcs/internal/normalize_nan.rs new file mode 100644 index 000000000000..078ce4b5a4b1 --- /dev/null +++ b/datafusion/spark/src/math_funcs/internal/normalize_nan.rs @@ -0,0 +1,161 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::{ + array::{as_primitive_array, ArrayAccessor, ArrayIter, Float32Array, Float64Array}, + datatypes::{ArrowNativeType, Float32Type, Float64Type}, + record_batch::RecordBatch, +}; +use arrow_schema::{DataType, Schema}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_physical_expr::PhysicalExpr; +use std::hash::Hash; +use std::{ + any::Any, + fmt::{Display, Formatter}, + sync::Arc, +}; + +#[derive(Debug, Eq)] +pub struct NormalizeNaNAndZero { + pub data_type: DataType, + pub child: Arc, +} + +impl PartialEq for NormalizeNaNAndZero { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) && self.data_type.eq(&other.data_type) + } +} + +impl Hash for NormalizeNaNAndZero { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.data_type.hash(state); + } +} + +impl NormalizeNaNAndZero { + pub fn new(data_type: DataType, child: Arc) -> Self { + Self { data_type, child } + } +} + +impl PhysicalExpr for NormalizeNaNAndZero { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + self.child.data_type(input_schema) + } + + fn nullable(&self, input_schema: &Schema) -> datafusion_common::Result { + self.child.nullable(input_schema) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + let cv = self.child.evaluate(batch)?; + let array = cv.into_array(batch.num_rows())?; + + match &self.data_type { + DataType::Float32 => { + let v = eval_typed(as_primitive_array::(&array)); + let new_array = Float32Array::from(v); + Ok(ColumnarValue::Array(Arc::new(new_array))) + } + DataType::Float64 => { + let v = eval_typed(as_primitive_array::(&array)); + let new_array = Float64Array::from(v); + Ok(ColumnarValue::Array(Arc::new(new_array))) + } + dt => panic!("Unexpected data type {:?}", dt), + } + } + + fn children(&self) -> Vec<&Arc> { + self.child.children() + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion_common::Result> { + Ok(Arc::new(NormalizeNaNAndZero::new( + self.data_type.clone(), + Arc::clone(&children[0]), + ))) + } +} + +fn eval_typed>(input: T) -> Vec> { + let iter = ArrayIter::new(input); + 
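+    // Canonicalize each element: any NaN becomes the canonical NaN constant and negative
+    // zero becomes positive zero; null entries are passed through unchanged.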
iter.map(|o| { + o.map(|v| { + if v.is_nan() { + v.nan() + } else if v.is_neg_zero() { + v.zero() + } else { + v + } + }) + }) + .collect() +} + +impl Display for NormalizeNaNAndZero { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "FloatNormalize [child: {}]", self.child) + } +} + +trait FloatDouble: ArrowNativeType { + fn is_nan(&self) -> bool; + fn nan(&self) -> Self; + fn is_neg_zero(&self) -> bool; + fn zero(&self) -> Self; +} + +impl FloatDouble for f32 { + fn is_nan(&self) -> bool { + f32::is_nan(*self) + } + fn nan(&self) -> Self { + f32::NAN + } + fn is_neg_zero(&self) -> bool { + *self == -0.0 + } + fn zero(&self) -> Self { + 0.0 + } +} +impl FloatDouble for f64 { + fn is_nan(&self) -> bool { + f64::is_nan(*self) + } + fn nan(&self) -> Self { + f64::NAN + } + fn is_neg_zero(&self) -> bool { + *self == -0.0 + } + fn zero(&self) -> Self { + 0.0 + } +} diff --git a/datafusion/spark/src/math_funcs/internal/unscaled_value.rs b/datafusion/spark/src/math_funcs/internal/unscaled_value.rs new file mode 100644 index 000000000000..f45c047ba270 --- /dev/null +++ b/datafusion/spark/src/math_funcs/internal/unscaled_value.rs @@ -0,0 +1,44 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::{ + array::{AsArray, Int64Builder}, + datatypes::Decimal128Type, +}; +use datafusion::physical_plan::ColumnarValue; +use datafusion_common::{internal_err, Result as DataFusionResult, ScalarValue}; +use std::sync::Arc; + +/// Spark-compatible `UnscaledValue` expression (internal to Spark optimizer) +pub fn spark_unscaled_value(args: &[ColumnarValue]) -> DataFusionResult { + match &args[0] { + ColumnarValue::Scalar(v) => match v { + ScalarValue::Decimal128(d, _, _) => Ok(ColumnarValue::Scalar( + ScalarValue::Int64(d.map(|n| n as i64)), + )), + dt => internal_err!("Expected Decimal128 but found {dt:}"), + }, + ColumnarValue::Array(a) => { + let arr = a.as_primitive::(); + let mut result = Int64Builder::new(); + for v in arr.into_iter() { + result.append_option(v.map(|v| v as i64)); + } + Ok(ColumnarValue::Array(Arc::new(result.finish()))) + } + } +} diff --git a/datafusion/spark/src/math_funcs/mod.rs b/datafusion/spark/src/math_funcs/mod.rs new file mode 100644 index 000000000000..c559ae15c0c3 --- /dev/null +++ b/datafusion/spark/src/math_funcs/mod.rs @@ -0,0 +1,35 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod ceil; +mod div; +mod floor; +pub(crate) mod hex; +pub mod internal; +mod negative; +mod round; +pub(crate) mod unhex; +mod utils; + +pub use ceil::spark_ceil; +pub use div::spark_decimal_div; +pub use floor::spark_floor; +pub use hex::spark_hex; +pub use internal::*; +pub use negative::{create_negate_expr, NegativeExpr}; +pub use round::spark_round; +pub use unhex::spark_unhex; diff --git a/datafusion/spark/src/math_funcs/negative.rs b/datafusion/spark/src/math_funcs/negative.rs new file mode 100644 index 000000000000..b7d66ab96bcb --- /dev/null +++ b/datafusion/spark/src/math_funcs/negative.rs @@ -0,0 +1,281 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::arithmetic_overflow_error; +use crate::SparkError; +use arrow::{compute::kernels::numeric::neg_wrapping, datatypes::IntervalDayTimeType}; +use arrow_array::RecordBatch; +use arrow_buffer::IntervalDayTime; +use arrow_schema::{DataType, Schema}; +use datafusion::{ + logical_expr::{interval_arithmetic::Interval, ColumnarValue}, + physical_expr::PhysicalExpr, +}; +use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_expr::sort_properties::ExprProperties; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; + +pub fn create_negate_expr( + expr: Arc, + fail_on_error: bool, +) -> Result, DataFusionError> { + Ok(Arc::new(NegativeExpr::new(expr, fail_on_error))) +} + +/// Negative expression +#[derive(Debug, Eq)] +pub struct NegativeExpr { + /// Input expression + arg: Arc, + fail_on_error: bool, +} + +impl Hash for NegativeExpr { + fn hash(&self, state: &mut H) { + self.arg.hash(state); + self.fail_on_error.hash(state); + } +} + +impl PartialEq for NegativeExpr { + fn eq(&self, other: &Self) -> bool { + self.arg.eq(&other.arg) && self.fail_on_error.eq(&other.fail_on_error) + } +} + +macro_rules! 
check_overflow { + ($array:expr, $array_type:ty, $min_val:expr, $type_name:expr) => {{ + let typed_array = $array + .as_any() + .downcast_ref::<$array_type>() + .expect(concat!(stringify!($array_type), " expected")); + for i in 0..typed_array.len() { + if typed_array.value(i) == $min_val { + if $type_name == "byte" || $type_name == "short" { + let value = format!("{:?} caused", typed_array.value(i)); + return Err(arithmetic_overflow_error(value.as_str()).into()); + } + return Err(arithmetic_overflow_error($type_name).into()); + } + } + }}; +} + +impl NegativeExpr { + /// Create new not expression + pub fn new(arg: Arc, fail_on_error: bool) -> Self { + Self { arg, fail_on_error } + } + + /// Get the input expression + pub fn arg(&self) -> &Arc { + &self.arg + } +} + +impl std::fmt::Display for NegativeExpr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "(- {})", self.arg) + } +} + +impl PhysicalExpr for NegativeExpr { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> Result { + self.arg.data_type(input_schema) + } + + fn nullable(&self, input_schema: &Schema) -> Result { + self.arg.nullable(input_schema) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + let arg = self.arg.evaluate(batch)?; + + // overflow checks only apply in ANSI mode + // datatypes supported are byte, short, integer, long, float, interval + match arg { + ColumnarValue::Array(array) => { + if self.fail_on_error { + match array.data_type() { + DataType::Int8 => { + check_overflow!( + array, + arrow::array::Int8Array, + i8::MIN, + "byte" + ) + } + DataType::Int16 => { + check_overflow!( + array, + arrow::array::Int16Array, + i16::MIN, + "short" + ) + } + DataType::Int32 => { + check_overflow!( + array, + arrow::array::Int32Array, + i32::MIN, + "integer" + ) + } + DataType::Int64 => { + check_overflow!( + array, + arrow::array::Int64Array, + i64::MIN, + "long" + ) + } + DataType::Interval(value) => match value { + arrow::datatypes::IntervalUnit::YearMonth => check_overflow!( + array, + arrow::array::IntervalYearMonthArray, + i32::MIN, + "interval" + ), + arrow::datatypes::IntervalUnit::DayTime => check_overflow!( + array, + arrow::array::IntervalDayTimeArray, + IntervalDayTime::MIN, + "interval" + ), + arrow::datatypes::IntervalUnit::MonthDayNano => { + // Overflow checks are not supported + } + }, + _ => { + // Overflow checks are not supported for other datatypes + } + } + } + let result = neg_wrapping(array.as_ref())?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Scalar(scalar) => { + if self.fail_on_error { + match scalar { + ScalarValue::Int8(value) => { + if value == Some(i8::MIN) { + return Err(arithmetic_overflow_error(" caused").into()); + } + } + ScalarValue::Int16(value) => { + if value == Some(i16::MIN) { + return Err(arithmetic_overflow_error(" caused").into()); + } + } + ScalarValue::Int32(value) => { + if value == Some(i32::MIN) { + return Err(arithmetic_overflow_error("integer").into()); + } + } + ScalarValue::Int64(value) => { + if value == Some(i64::MIN) { + return Err(arithmetic_overflow_error("long").into()); + } + } + ScalarValue::IntervalDayTime(value) => { + let (days, ms) = + IntervalDayTimeType::to_parts(value.unwrap_or_default()); + if days == i32::MIN || ms == i32::MIN { + return Err(arithmetic_overflow_error("interval").into()); + } + } + ScalarValue::IntervalYearMonth(value) => { + if value == Some(i32::MIN) { + return 
Err(arithmetic_overflow_error("interval").into()); + } + } + _ => { + // Overflow checks are not supported for other datatypes + } + } + } + Ok(ColumnarValue::Scalar((scalar.arithmetic_negate())?)) + } + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.arg] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + Ok(Arc::new(NegativeExpr::new( + Arc::clone(&children[0]), + self.fail_on_error, + ))) + } + + /// Given the child interval of a NegativeExpr, it calculates the NegativeExpr's interval. + /// It replaces the upper and lower bounds after multiplying them with -1. + /// Ex: `(a, b]` => `[-b, -a)` + fn evaluate_bounds(&self, children: &[&Interval]) -> Result { + Interval::try_new( + children[0].upper().arithmetic_negate()?, + children[0].lower().arithmetic_negate()?, + ) + } + + /// Returns a new [`Interval`] of a NegativeExpr that has the existing `interval` given that + /// given the input interval is known to be `children`. + fn propagate_constraints( + &self, + interval: &Interval, + children: &[&Interval], + ) -> Result>> { + let child_interval = children[0]; + + if child_interval.lower() == &ScalarValue::Int32(Some(i32::MIN)) + || child_interval.upper() == &ScalarValue::Int32(Some(i32::MIN)) + || child_interval.lower() == &ScalarValue::Int64(Some(i64::MIN)) + || child_interval.upper() == &ScalarValue::Int64(Some(i64::MIN)) + { + return Err(SparkError::ArithmeticOverflow { + from_type: "long".to_string(), + } + .into()); + } + + let negated_interval = Interval::try_new( + interval.upper().arithmetic_negate()?, + interval.lower().arithmetic_negate()?, + )?; + + Ok(child_interval + .intersect(negated_interval)? + .map(|result| vec![result])) + } + + /// The ordering of a [`NegativeExpr`] is simply the reverse of its child. + fn get_properties(&self, children: &[ExprProperties]) -> Result { + let properties = children[0].clone().with_order(children[0].sort_properties); + Ok(properties) + } +} diff --git a/datafusion/spark/src/math_funcs/round.rs b/datafusion/spark/src/math_funcs/round.rs new file mode 100644 index 000000000000..eae523007795 --- /dev/null +++ b/datafusion/spark/src/math_funcs/round.rs @@ -0,0 +1,151 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::math_funcs::utils::{ + get_precision_scale, make_decimal_array, make_decimal_scalar, +}; +use arrow::array::{Int16Array, Int32Array, Int64Array, Int8Array}; +use arrow_array::{Array, ArrowNativeTypeOp}; +use arrow_schema::DataType; +use datafusion::{functions::math::round::round, physical_plan::ColumnarValue}; +use datafusion_common::{exec_err, internal_err, DataFusionError, ScalarValue}; +use std::{cmp::min, sync::Arc}; + +macro_rules! 
integer_round { + ($X:expr, $DIV:expr, $HALF:expr) => {{ + let rem = $X % $DIV; + if rem <= -$HALF { + ($X - rem).sub_wrapping($DIV) + } else if rem >= $HALF { + ($X - rem).add_wrapping($DIV) + } else { + $X - rem + } + }}; +} + +macro_rules! round_integer_array { + ($ARRAY:expr, $POINT:expr, $TYPE:ty, $NATIVE:ty) => {{ + let array = $ARRAY.as_any().downcast_ref::<$TYPE>().unwrap(); + let ten: $NATIVE = 10; + let result: $TYPE = if let Some(div) = ten.checked_pow((-(*$POINT)) as u32) { + let half = div / 2; + arrow::compute::kernels::arity::unary(array, |x| integer_round!(x, div, half)) + } else { + arrow::compute::kernels::arity::unary(array, |_| 0) + }; + Ok(ColumnarValue::Array(Arc::new(result))) + }}; +} + +macro_rules! round_integer_scalar { + ($SCALAR:expr, $POINT:expr, $TYPE:expr, $NATIVE:ty) => {{ + let ten: $NATIVE = 10; + if let Some(div) = ten.checked_pow((-(*$POINT)) as u32) { + let half = div / 2; + Ok(ColumnarValue::Scalar($TYPE( + $SCALAR.map(|x| integer_round!(x, div, half)), + ))) + } else { + Ok(ColumnarValue::Scalar($TYPE(Some(0)))) + } + }}; +} + +/// `round` function that simulates Spark `round` expression +pub fn spark_round( + args: &[ColumnarValue], + data_type: &DataType, +) -> Result { + let value = &args[0]; + let point = &args[1]; + let ColumnarValue::Scalar(ScalarValue::Int64(Some(point))) = point else { + return internal_err!("Invalid point argument for Round(): {:#?}", point); + }; + match value { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Int64 if *point < 0 => { + round_integer_array!(array, point, Int64Array, i64) + } + DataType::Int32 if *point < 0 => { + round_integer_array!(array, point, Int32Array, i32) + } + DataType::Int16 if *point < 0 => { + round_integer_array!(array, point, Int16Array, i16) + } + DataType::Int8 if *point < 0 => { + round_integer_array!(array, point, Int8Array, i8) + } + DataType::Decimal128(_, scale) if *scale >= 0 => { + let f = decimal_round_f(scale, point); + let (precision, scale) = get_precision_scale(data_type); + make_decimal_array(array, precision, scale, &f) + } + DataType::Float32 | DataType::Float64 => { + Ok(ColumnarValue::Array(round(&[Arc::clone(array)])?)) + } + dt => exec_err!("Not supported datatype for ROUND: {dt}"), + }, + ColumnarValue::Scalar(a) => match a { + ScalarValue::Int64(a) if *point < 0 => { + round_integer_scalar!(a, point, ScalarValue::Int64, i64) + } + ScalarValue::Int32(a) if *point < 0 => { + round_integer_scalar!(a, point, ScalarValue::Int32, i32) + } + ScalarValue::Int16(a) if *point < 0 => { + round_integer_scalar!(a, point, ScalarValue::Int16, i16) + } + ScalarValue::Int8(a) if *point < 0 => { + round_integer_scalar!(a, point, ScalarValue::Int8, i8) + } + ScalarValue::Decimal128(a, _, scale) if *scale >= 0 => { + let f = decimal_round_f(scale, point); + let (precision, scale) = get_precision_scale(data_type); + make_decimal_scalar(a, precision, scale, &f) + } + ScalarValue::Float32(_) | ScalarValue::Float64(_) => { + Ok(ColumnarValue::Scalar(ScalarValue::try_from_array( + &round(&[a.to_array()?])?, + 0, + )?)) + } + dt => exec_err!("Not supported datatype for ROUND: {dt}"), + }, + } +} + +// Spark uses BigDecimal. See RoundBase implementation in Spark. 
Instead, we do the same by +// 1) add the half of divisor, 2) round down by division, 3) adjust precision by multiplication +#[inline] +fn decimal_round_f(scale: &i8, point: &i64) -> Box i128> { + if *point < 0 { + if let Some(div) = 10_i128.checked_pow((-(*point) as u32) + (*scale as u32)) { + let half = div / 2; + let mul = 10_i128.pow_wrapping((-(*point)) as u32); + // i128 can hold 39 digits of a base 10 number, adding half will not cause overflow + Box::new(move |x: i128| (x + x.signum() * half) / div * mul) + } else { + Box::new(move |_: i128| 0) + } + } else { + let div = + 10_i128.pow_wrapping((*scale as u32) - min(*scale as u32, *point as u32)); + let half = div / 2; + Box::new(move |x: i128| (x + x.signum() * half) / div) + } +} diff --git a/datafusion/spark/src/math_funcs/unhex.rs b/datafusion/spark/src/math_funcs/unhex.rs new file mode 100644 index 000000000000..8c1de2f4cd98 --- /dev/null +++ b/datafusion/spark/src/math_funcs/unhex.rs @@ -0,0 +1,264 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::sync::Arc; + +use arrow_array::OffsetSizeTrait; +use arrow_schema::DataType; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{ + cast::as_generic_string_array, exec_err, DataFusionError, ScalarValue, +}; + +/// Helper function to convert a hex digit to a binary value. +fn unhex_digit(c: u8) -> Result { + match c { + b'0'..=b'9' => Ok(c - b'0'), + b'A'..=b'F' => Ok(10 + c - b'A'), + b'a'..=b'f' => Ok(10 + c - b'a'), + _ => Err(DataFusionError::Execution( + "Input to unhex_digit is not a valid hex digit".to_string(), + )), + } +} + +/// Convert a hex string to binary and store the result in `result`. Returns an error if the input +/// is not a valid hex string. 
+fn unhex(hex_str: &str, result: &mut Vec) -> Result<(), DataFusionError> { + let bytes = hex_str.as_bytes(); + + let mut i = 0; + + if (bytes.len() & 0x01) != 0 { + let v = unhex_digit(bytes[0])?; + + result.push(v); + i += 1; + } + + while i < bytes.len() { + let first = unhex_digit(bytes[i])?; + let second = unhex_digit(bytes[i + 1])?; + result.push((first << 4) | second); + + i += 2; + } + + Ok(()) +} + +fn spark_unhex_inner( + array: &ColumnarValue, + fail_on_error: bool, +) -> Result { + match array { + ColumnarValue::Array(array) => { + let string_array = as_generic_string_array::(array)?; + + let mut encoded = Vec::new(); + let mut builder = arrow::array::BinaryBuilder::new(); + + for item in string_array.iter() { + if let Some(s) = item { + if unhex(s, &mut encoded).is_ok() { + builder.append_value(encoded.as_slice()); + } else if fail_on_error { + return exec_err!( + "Input to unhex is not a valid hex string: {s}" + ); + } else { + builder.append_null(); + } + encoded.clear(); + } else { + builder.append_null(); + } + } + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) + } + ColumnarValue::Scalar(ScalarValue::Utf8(Some(string))) => { + let mut encoded = Vec::new(); + + if unhex(string, &mut encoded).is_ok() { + Ok(ColumnarValue::Scalar(ScalarValue::Binary(Some(encoded)))) + } else if fail_on_error { + exec_err!("Input to unhex is not a valid hex string: {string}") + } else { + Ok(ColumnarValue::Scalar(ScalarValue::Binary(None))) + } + } + ColumnarValue::Scalar(ScalarValue::Utf8(None)) => { + Ok(ColumnarValue::Scalar(ScalarValue::Binary(None))) + } + _ => { + exec_err!( + "The first argument must be a string scalar or array, but got: {:?}", + array + ) + } + } +} + +/// Spark-compatible `unhex` expression +pub fn spark_unhex(args: &[ColumnarValue]) -> Result { + if args.len() > 2 { + return exec_err!("unhex takes at most 2 arguments, but got: {}", args.len()); + } + + let val_to_unhex = &args[0]; + let fail_on_error = if args.len() == 2 { + match &args[1] { + ColumnarValue::Scalar(ScalarValue::Boolean(Some(fail_on_error))) => { + *fail_on_error + } + _ => { + return exec_err!( + "The second argument must be boolean scalar, but got: {:?}", + args[1] + ); + } + } + } else { + false + }; + + match val_to_unhex.data_type() { + DataType::Utf8 => spark_unhex_inner::(val_to_unhex, fail_on_error), + DataType::LargeUtf8 => spark_unhex_inner::(val_to_unhex, fail_on_error), + other => exec_err!( + "The first argument must be a Utf8 or LargeUtf8: {:?}", + other + ), + } +} + +#[cfg(test)] +mod test { + use std::sync::Arc; + + use arrow::array::{BinaryBuilder, StringBuilder}; + use arrow_array::make_array; + use arrow_data::ArrayData; + use datafusion::logical_expr::ColumnarValue; + use datafusion_common::ScalarValue; + + use super::unhex; + + #[test] + fn test_spark_unhex_null() -> Result<(), Box> { + let input = ArrayData::new_null(&arrow_schema::DataType::Utf8, 2); + let output = ArrayData::new_null(&arrow_schema::DataType::Binary, 2); + + let input = ColumnarValue::Array(Arc::new(make_array(input))); + let expected = ColumnarValue::Array(Arc::new(make_array(output))); + + let result = super::spark_unhex(&[input])?; + + match (result, expected) { + (ColumnarValue::Array(result), ColumnarValue::Array(expected)) => { + assert_eq!(*result, *expected); + Ok(()) + } + _ => Err("Unexpected result type".into()), + } + } + + #[test] + fn test_partial_error() -> Result<(), Box> { + let mut input = StringBuilder::new(); + + input.append_value("1CGG"); // 1C is ok, but GG is invalid + 
input.append_value("537061726B2053514C"); // followed by valid + + let input = ColumnarValue::Array(Arc::new(input.finish())); + let fail_on_error = ColumnarValue::Scalar(ScalarValue::Boolean(Some(false))); + + let result = super::spark_unhex(&[input, fail_on_error])?; + + let mut expected = BinaryBuilder::new(); + expected.append_null(); + expected.append_value("Spark SQL".as_bytes()); + + match (result, ColumnarValue::Array(Arc::new(expected.finish()))) { + (ColumnarValue::Array(result), ColumnarValue::Array(expected)) => { + assert_eq!(*result, *expected); + + Ok(()) + } + _ => Err("Unexpected result type".into()), + } + } + + #[test] + fn test_unhex_valid() -> Result<(), Box> { + let mut result = Vec::new(); + + unhex("537061726B2053514C", &mut result)?; + let result_str = std::str::from_utf8(&result)?; + assert_eq!(result_str, "Spark SQL"); + result.clear(); + + unhex("1C", &mut result)?; + assert_eq!(result, vec![28]); + result.clear(); + + unhex("737472696E67", &mut result)?; + assert_eq!(result, "string".as_bytes()); + result.clear(); + + unhex("1", &mut result)?; + assert_eq!(result, vec![1]); + result.clear(); + + Ok(()) + } + + #[test] + fn test_odd_length() -> Result<(), Box> { + let mut result = Vec::new(); + + unhex("A1B", &mut result)?; + assert_eq!(result, vec![10, 27]); + result.clear(); + + unhex("0A1B", &mut result)?; + assert_eq!(result, vec![10, 27]); + result.clear(); + + Ok(()) + } + + #[test] + fn test_unhex_empty() { + let mut result = Vec::new(); + + // Empty hex string + unhex("", &mut result).unwrap(); + assert!(result.is_empty()); + } + + #[test] + fn test_unhex_invalid() { + let mut result = Vec::new(); + + // Invalid hex strings + assert!(unhex("##", &mut result).is_err()); + assert!(unhex("G123", &mut result).is_err()); + assert!(unhex("hello", &mut result).is_err()); + assert!(unhex("\0", &mut result).is_err()); + } +} diff --git a/datafusion/spark/src/math_funcs/utils.rs b/datafusion/spark/src/math_funcs/utils.rs new file mode 100644 index 000000000000..204b7139e4b8 --- /dev/null +++ b/datafusion/spark/src/math_funcs/utils.rs @@ -0,0 +1,74 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::cast::AsArray; +use arrow_array::types::Decimal128Type; +use arrow_array::{ArrayRef, Decimal128Array}; +use arrow_schema::DataType; +use datafusion_common::{DataFusionError, ScalarValue}; +use datafusion_expr_common::columnar_value::ColumnarValue; +use std::sync::Arc; + +#[macro_export] +macro_rules! 
downcast_compute_op { + ($ARRAY:expr, $NAME:expr, $FUNC:ident, $TYPE:ident, $RESULT:ident) => {{ + let n = $ARRAY.as_any().downcast_ref::<$TYPE>(); + match n { + Some(array) => { + let res: $RESULT = + arrow::compute::kernels::arity::unary(array, |x| x.$FUNC() as i64); + Ok(Arc::new(res)) + } + _ => Err(DataFusionError::Internal(format!( + "Invalid data type for {}", + $NAME + ))), + } + }}; +} + +#[inline] +pub(crate) fn make_decimal_scalar( + a: &Option, + precision: u8, + scale: i8, + f: &dyn Fn(i128) -> i128, +) -> Result { + let result = ScalarValue::Decimal128(a.map(f), precision, scale); + Ok(ColumnarValue::Scalar(result)) +} + +#[inline] +pub(crate) fn make_decimal_array( + array: &ArrayRef, + precision: u8, + scale: i8, + f: &dyn Fn(i128) -> i128, +) -> Result { + let array = array.as_primitive::(); + let result: Decimal128Array = arrow::compute::kernels::arity::unary(array, f); + let result = result.with_data_type(DataType::Decimal128(precision, scale)); + Ok(ColumnarValue::Array(Arc::new(result))) +} + +#[inline] +pub(crate) fn get_precision_scale(data_type: &DataType) -> (u8, i8) { + let DataType::Decimal128(precision, scale) = data_type else { + unreachable!() + }; + (*precision, *scale) +} diff --git a/datafusion/spark/src/predicate_funcs/is_nan.rs b/datafusion/spark/src/predicate_funcs/is_nan.rs new file mode 100644 index 000000000000..094079ddb04c --- /dev/null +++ b/datafusion/spark/src/predicate_funcs/is_nan.rs @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
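Before the implementation that follows, a rough sketch of the `spark_isnan` semantics it encodes (crate-root re-export assumed): null inputs produce `false`, not null, matching Spark's `IsNaN`.

use std::sync::Arc;
use arrow_array::Float64Array;
use datafusion::physical_plan::ColumnarValue;
use datafusion_common::Result;
use datafusion_comet_spark_expr::spark_isnan; // assumed re-export

fn isnan_sketch() -> Result<()> {
    let input = Float64Array::from(vec![Some(f64::NAN), Some(1.0), None]);
    let out = spark_isnan(&[ColumnarValue::Array(Arc::new(input))])?;
    // `out` is a non-null BooleanArray: [true, false, false].
    let _ = out;
    Ok(())
}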
+ +use arrow::array::{Float32Array, Float64Array}; +use arrow_array::{Array, BooleanArray}; +use arrow_schema::DataType; +use datafusion::physical_plan::ColumnarValue; +use datafusion_common::{DataFusionError, ScalarValue}; +use std::sync::Arc; + +/// Spark-compatible `isnan` expression +pub fn spark_isnan(args: &[ColumnarValue]) -> Result { + fn set_nulls_to_false(is_nan: BooleanArray) -> ColumnarValue { + match is_nan.nulls() { + Some(nulls) => { + let is_not_null = nulls.inner(); + ColumnarValue::Array(Arc::new(BooleanArray::new( + is_nan.values() & is_not_null, + None, + ))) + } + None => ColumnarValue::Array(Arc::new(is_nan)), + } + } + let value = &args[0]; + match value { + ColumnarValue::Array(array) => match array.data_type() { + DataType::Float64 => { + let array = array.as_any().downcast_ref::().unwrap(); + let is_nan = BooleanArray::from_unary(array, |x| x.is_nan()); + Ok(set_nulls_to_false(is_nan)) + } + DataType::Float32 => { + let array = array.as_any().downcast_ref::().unwrap(); + let is_nan = BooleanArray::from_unary(array, |x| x.is_nan()); + Ok(set_nulls_to_false(is_nan)) + } + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function isnan", + other, + ))), + }, + ColumnarValue::Scalar(a) => match a { + ScalarValue::Float64(a) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean( + Some(a.map(|x| x.is_nan()).unwrap_or(false)), + ))), + ScalarValue::Float32(a) => Ok(ColumnarValue::Scalar(ScalarValue::Boolean( + Some(a.map(|x| x.is_nan()).unwrap_or(false)), + ))), + _ => Err(DataFusionError::Internal(format!( + "Unsupported data type {:?} for function isnan", + value.data_type(), + ))), + }, + } +} diff --git a/datafusion/spark/src/predicate_funcs/mod.rs b/datafusion/spark/src/predicate_funcs/mod.rs new file mode 100644 index 000000000000..5f1f570c0541 --- /dev/null +++ b/datafusion/spark/src/predicate_funcs/mod.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod is_nan; +mod rlike; + +pub use is_nan::spark_isnan; +pub use rlike::RLike; diff --git a/datafusion/spark/src/predicate_funcs/rlike.rs b/datafusion/spark/src/predicate_funcs/rlike.rs new file mode 100644 index 000000000000..bfee0cc769cb --- /dev/null +++ b/datafusion/spark/src/predicate_funcs/rlike.rs @@ -0,0 +1,170 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::SparkError; +use arrow::compute::take; +use arrow_array::builder::BooleanBuilder; +use arrow_array::types::Int32Type; +use arrow_array::{Array, BooleanArray, DictionaryArray, RecordBatch, StringArray}; +use arrow_schema::{DataType, Schema}; +use datafusion::physical_expr_common::physical_expr::DynEq; +use datafusion_common::{internal_err, Result}; +use datafusion_expr::ColumnarValue; +use datafusion_physical_expr::PhysicalExpr; +use regex::Regex; +use std::any::Any; +use std::fmt::{Display, Formatter}; +use std::hash::{Hash, Hasher}; +use std::sync::Arc; + +/// Implementation of RLIKE operator. +/// +/// Note that this implementation is not yet Spark-compatible and simply delegates to +/// the Rust regexp crate. It will match Spark behavior for some simple cases but has +/// differences in whitespace handling and does not support all the features of Java's +/// regular expression engine, which are documented at: +/// +/// +#[derive(Debug)] +pub struct RLike { + child: Arc, + // Only scalar patterns are supported + pattern_str: String, + pattern: Regex, +} + +impl Hash for RLike { + fn hash(&self, state: &mut H) { + state.write(self.pattern_str.as_bytes()); + } +} + +impl DynEq for RLike { + fn dyn_eq(&self, other: &dyn Any) -> bool { + if let Some(other) = other.downcast_ref::() { + self.pattern_str == other.pattern_str + } else { + false + } + } +} + +impl RLike { + pub fn try_new(child: Arc, pattern: &str) -> Result { + Ok(Self { + child, + pattern_str: pattern.to_string(), + pattern: Regex::new(pattern).map_err(|e| { + SparkError::Internal(format!( + "Failed to compile pattern {}: {}", + pattern, e + )) + })?, + }) + } + + fn is_match(&self, inputs: &StringArray) -> BooleanArray { + let mut builder = BooleanBuilder::with_capacity(inputs.len()); + if inputs.is_nullable() { + for i in 0..inputs.len() { + if inputs.is_null(i) { + builder.append_null(); + } else { + builder.append_value(self.pattern.is_match(inputs.value(i))); + } + } + } else { + for i in 0..inputs.len() { + builder.append_value(self.pattern.is_match(inputs.value(i))); + } + } + builder.finish() + } +} + +impl Display for RLike { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "RLike [child: {}, pattern: {}] ", + self.child, self.pattern_str + ) + } +} + +impl PhysicalExpr for RLike { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _input_schema: &Schema) -> Result { + Ok(DataType::Boolean) + } + + fn nullable(&self, input_schema: &Schema) -> Result { + self.child.nullable(input_schema) + } + + fn evaluate(&self, batch: &RecordBatch) -> Result { + match self.child.evaluate(batch)? 
{ + ColumnarValue::Array(array) + if array.as_any().is::>() => + { + let dict_array = array + .as_any() + .downcast_ref::>() + .expect("dict array"); + let dict_values = dict_array + .values() + .as_any() + .downcast_ref::() + .expect("strings"); + // evaluate the regexp pattern against the dictionary values + let new_values = self.is_match(dict_values); + // convert to conventional (not dictionary-encoded) array + let result = take(&new_values, dict_array.keys(), None)?; + Ok(ColumnarValue::Array(result)) + } + ColumnarValue::Array(array) => { + let inputs = array + .as_any() + .downcast_ref::() + .expect("string array"); + let array = self.is_match(inputs); + Ok(ColumnarValue::Array(Arc::new(array))) + } + ColumnarValue::Scalar(_) => { + internal_err!("non scalar regexp patterns are not supported") + } + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> Result> { + assert!(children.len() == 1); + Ok(Arc::new(RLike::try_new( + Arc::clone(&children[0]), + &self.pattern_str, + )?)) + } +} diff --git a/datafusion/spark/src/static_invoke/char_varchar_utils/mod.rs b/datafusion/spark/src/static_invoke/char_varchar_utils/mod.rs new file mode 100644 index 000000000000..fff6134dab8b --- /dev/null +++ b/datafusion/spark/src/static_invoke/char_varchar_utils/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod read_side_padding; + +pub use read_side_padding::spark_read_side_padding; diff --git a/datafusion/spark/src/static_invoke/char_varchar_utils/read_side_padding.rs b/datafusion/spark/src/static_invoke/char_varchar_utils/read_side_padding.rs new file mode 100644 index 000000000000..2b334af3d699 --- /dev/null +++ b/datafusion/spark/src/static_invoke/char_varchar_utils/read_side_padding.rs @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
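A minimal usage sketch for the `RLike` expression above, with an illustrative batch and pattern (the crate-root re-export is assumed, as in the benchmarks):

use std::sync::Arc;
use arrow_array::{RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};
use datafusion_common::Result;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::PhysicalExpr;
use datafusion_comet_spark_expr::RLike; // assumed re-export

fn rlike_sketch() -> Result<()> {
    let schema = Arc::new(Schema::new(vec![Field::new("s", DataType::Utf8, true)]));
    let batch = RecordBatch::try_new(
        schema,
        vec![Arc::new(StringArray::from(vec![Some("Spark SQL"), None]))],
    )?;
    // Only literal patterns are supported; the regex is compiled once in try_new.
    let expr = RLike::try_new(Arc::new(Column::new("s", 0)), r"^Spark")?;
    let out = expr.evaluate(&batch)?; // BooleanArray: [true, null]
    let _ = out;
    Ok(())
}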
+ +use arrow::array::{ArrayRef, OffsetSizeTrait}; +use arrow_array::builder::GenericStringBuilder; +use arrow_array::Array; +use arrow_schema::DataType; +use datafusion::physical_plan::ColumnarValue; +use datafusion_common::{cast::as_generic_string_array, DataFusionError, ScalarValue}; +use std::fmt::Write; +use std::sync::Arc; + +/// Similar to DataFusion `rpad`, but not to truncate when the string is already longer than length +pub fn spark_read_side_padding( + args: &[ColumnarValue], +) -> Result { + match args { + [ColumnarValue::Array(array), ColumnarValue::Scalar(ScalarValue::Int32(Some(length)))] => + { + match array.data_type() { + DataType::Utf8 => spark_read_side_padding_internal::(array, *length), + DataType::LargeUtf8 => { + spark_read_side_padding_internal::(array, *length) + } + // TODO: handle Dictionary types + other => Err(DataFusionError::Internal(format!( + "Unsupported data type {other:?} for function read_side_padding", + ))), + } + } + other => Err(DataFusionError::Internal(format!( + "Unsupported arguments {other:?} for function read_side_padding", + ))), + } +} + +fn spark_read_side_padding_internal( + array: &ArrayRef, + length: i32, +) -> Result { + let string_array = as_generic_string_array::(array)?; + let length = 0.max(length) as usize; + let space_string = " ".repeat(length); + + let mut builder = GenericStringBuilder::::with_capacity( + string_array.len(), + string_array.len() * length, + ); + + for string in string_array.iter() { + match string { + Some(string) => { + // It looks Spark's UTF8String is closer to chars rather than graphemes + // https://stackoverflow.com/a/46290728 + let char_len = string.chars().count(); + if length <= char_len { + builder.append_value(string); + } else { + // write_str updates only the value buffer, not null nor offset buffer + // This is convenient for concatenating str(s) + builder.write_str(string)?; + builder.append_value(&space_string[char_len..]); + } + } + _ => builder.append_null(), + } + } + Ok(ColumnarValue::Array(Arc::new(builder.finish()))) +} diff --git a/datafusion/spark/src/static_invoke/mod.rs b/datafusion/spark/src/static_invoke/mod.rs new file mode 100644 index 000000000000..4072e13b7075 --- /dev/null +++ b/datafusion/spark/src/static_invoke/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod char_varchar_utils; + +pub use char_varchar_utils::spark_read_side_padding; diff --git a/datafusion/spark/src/string_funcs/chr.rs b/datafusion/spark/src/string_funcs/chr.rs new file mode 100644 index 000000000000..66470b62b315 --- /dev/null +++ b/datafusion/spark/src/string_funcs/chr.rs @@ -0,0 +1,127 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
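A short sketch of the `spark_read_side_padding` behavior above with hypothetical inputs (crate-root re-export assumed): short strings are space-padded to the target length, longer strings pass through untruncated, and the length is counted in characters rather than bytes.

use std::sync::Arc;
use arrow_array::StringArray;
use datafusion::physical_plan::ColumnarValue;
use datafusion_common::{Result, ScalarValue};
use datafusion_comet_spark_expr::spark_read_side_padding; // assumed re-export

fn read_side_padding_sketch() -> Result<()> {
    let strings = StringArray::from(vec![Some("ab"), Some("abcdef"), None]);
    let args = [
        ColumnarValue::Array(Arc::new(strings)),
        ColumnarValue::Scalar(ScalarValue::Int32(Some(5))),
    ];
    // "ab" becomes "ab   " (padded to 5 chars), "abcdef" is left as-is, null stays null.
    let _out = spark_read_side_padding(&args)?;
    Ok(())
}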
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{any::Any, sync::Arc}; + +use arrow::{ + array::{ArrayRef, StringArray}, + datatypes::{ + DataType, + DataType::{Int64, Utf8}, + }, +}; + +use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility}; +use datafusion_common::{cast::as_int64_array, exec_err, Result, ScalarValue}; + +fn chr(args: &[ArrayRef]) -> Result { + let integer_array = as_int64_array(&args[0])?; + + // first map is the iterator, second is for the `Option<_>` + let result = integer_array + .iter() + .map(|integer: Option| { + integer + .map(|integer| { + if integer < 0 { + return Ok("".to_string()); // Return empty string for negative integers + } + match core::char::from_u32((integer % 256) as u32) { + Some(ch) => Ok(ch.to_string()), + None => { + exec_err!("requested character not compatible for encoding.") + } + } + }) + .transpose() + }) + .collect::>()?; + + Ok(Arc::new(result) as ArrayRef) +} + +/// Spark-compatible `chr` expression +#[derive(Debug)] +pub struct SparkChrFunc { + signature: Signature, +} + +impl Default for SparkChrFunc { + fn default() -> Self { + Self::new() + } +} + +impl SparkChrFunc { + pub fn new() -> Self { + Self { + signature: Signature::uniform(1, vec![Int64], Volatility::Immutable), + } + } +} + +impl ScalarUDFImpl for SparkChrFunc { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + "chr" + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _arg_types: &[DataType]) -> Result { + Ok(Utf8) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + spark_chr(args) + } +} + +/// Returns the ASCII character having the binary equivalent to the input expression. +/// E.g., chr(65) = 'A'. +/// Compatible with Apache Spark's Chr function +fn spark_chr(args: &[ColumnarValue]) -> Result { + let array = args[0].clone(); + match array { + ColumnarValue::Array(array) => { + let array = chr(&[array])?; + Ok(ColumnarValue::Array(array)) + } + ColumnarValue::Scalar(ScalarValue::Int64(Some(value))) => { + if value < 0 { + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + "".to_string(), + )))) + } else { + match core::char::from_u32((value % 256) as u32) { + Some(ch) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some( + ch.to_string(), + )))), + None => { + exec_err!("requested character was incompatible for encoding.") + } + } + } + } + _ => exec_err!("The argument must be an Int64 array or scalar."), + } +} diff --git a/datafusion/spark/src/string_funcs/mod.rs b/datafusion/spark/src/string_funcs/mod.rs new file mode 100644 index 000000000000..d56b5662c323 --- /dev/null +++ b/datafusion/spark/src/string_funcs/mod.rs @@ -0,0 +1,26 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
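A brief sketch of the wrap-around behavior of the `chr` implementation above, calling `invoke` directly on the UDF impl; the input values are illustrative only:

use datafusion::logical_expr::{ColumnarValue, ScalarUDFImpl};
use datafusion_common::{Result, ScalarValue};
use datafusion_comet_spark_expr::SparkChrFunc; // assumed re-export

fn chr_sketch() -> Result<()> {
    let chr = SparkChrFunc::new();
    // Codes are reduced modulo 256: 321 % 256 == 65, so the result is "A".
    let _a = chr.invoke(&[ColumnarValue::Scalar(ScalarValue::Int64(Some(321)))])?;
    // Negative codes return an empty string instead of an error.
    let _empty = chr.invoke(&[ColumnarValue::Scalar(ScalarValue::Int64(Some(-5)))])?;
    Ok(())
}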
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod chr; +mod prediction; +mod string_space; +mod substring; + +pub use chr::SparkChrFunc; +pub use prediction::*; +pub use string_space::StringSpaceExpr; +pub use substring::SubstringExpr; diff --git a/datafusion/spark/src/string_funcs/prediction.rs b/datafusion/spark/src/string_funcs/prediction.rs new file mode 100644 index 000000000000..d75e3187df99 --- /dev/null +++ b/datafusion/spark/src/string_funcs/prediction.rs @@ -0,0 +1,145 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#![allow(deprecated)] + +use arrow::{ + compute::{ + contains_dyn, contains_utf8_scalar_dyn, ends_with_dyn, ends_with_utf8_scalar_dyn, + like_dyn, like_utf8_scalar_dyn, starts_with_dyn, starts_with_utf8_scalar_dyn, + }, + record_batch::RecordBatch, +}; +use arrow_schema::{DataType, Schema}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{DataFusionError, ScalarValue::Utf8}; +use datafusion_physical_expr::PhysicalExpr; +use std::{ + any::Any, + fmt::{Display, Formatter}, + hash::Hash, + sync::Arc, +}; + +macro_rules! 
make_predicate_function { + ($name: ident, $kernel: ident, $str_scalar_kernel: ident) => { + #[derive(Debug, Eq)] + pub struct $name { + left: Arc, + right: Arc, + } + + impl $name { + pub fn new( + left: Arc, + right: Arc, + ) -> Self { + Self { left, right } + } + } + + impl Display for $name { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "$name [left: {}, right: {}]", self.left, self.right) + } + } + + impl Hash for $name { + fn hash(&self, state: &mut H) { + self.left.hash(state); + self.right.hash(state); + } + } + + impl PartialEq for $name { + fn eq(&self, other: &Self) -> bool { + self.left.eq(&other.left) && self.right.eq(&other.right) + } + } + + impl PhysicalExpr for $name { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, _: &Schema) -> datafusion_common::Result { + Ok(DataType::Boolean) + } + + fn nullable(&self, _: &Schema) -> datafusion_common::Result { + Ok(true) + } + + fn evaluate( + &self, + batch: &RecordBatch, + ) -> datafusion_common::Result { + let left_arg = self.left.evaluate(batch)?; + let right_arg = self.right.evaluate(batch)?; + + let array = match (left_arg, right_arg) { + // array (op) scalar + ( + ColumnarValue::Array(array), + ColumnarValue::Scalar(Utf8(Some(string))), + ) => $str_scalar_kernel(&array, string.as_str()), + (ColumnarValue::Array(_), ColumnarValue::Scalar(other)) => { + return Err(DataFusionError::Execution(format!( + "Should be String but got: {:?}", + other + ))) + } + // array (op) array + (ColumnarValue::Array(array1), ColumnarValue::Array(array2)) => { + $kernel(&array1, &array2) + } + // scalar (op) scalar should be folded at Spark optimizer + _ => { + return Err(DataFusionError::Execution( + "Predicate on two literals should be folded at Spark" + .to_string(), + )) + } + }?; + + Ok(ColumnarValue::Array(Arc::new(array))) + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.left, &self.right] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion_common::Result> { + Ok(Arc::new($name::new( + children[0].clone(), + children[1].clone(), + ))) + } + } + }; +} + +make_predicate_function!(Like, like_dyn, like_utf8_scalar_dyn); + +make_predicate_function!(StartsWith, starts_with_dyn, starts_with_utf8_scalar_dyn); + +make_predicate_function!(EndsWith, ends_with_dyn, ends_with_utf8_scalar_dyn); + +make_predicate_function!(Contains, contains_dyn, contains_utf8_scalar_dyn); diff --git a/datafusion/spark/src/string_funcs/string_space.rs b/datafusion/spark/src/string_funcs/string_space.rs new file mode 100644 index 000000000000..db7092905780 --- /dev/null +++ b/datafusion/spark/src/string_funcs/string_space.rs @@ -0,0 +1,104 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
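For the predicate expressions generated by `make_predicate_function!` above (`Like`, `StartsWith`, `EndsWith`, `Contains`), a minimal construction sketch; the column name and crate-root re-export are assumptions:

use std::sync::Arc;
use arrow_array::RecordBatch;
use datafusion_common::{Result, ScalarValue};
use datafusion_physical_expr::expressions::{Column, Literal};
use datafusion_physical_expr::PhysicalExpr;
use datafusion_comet_spark_expr::StartsWith; // assumed re-export

fn starts_with_sketch(batch: &RecordBatch) -> Result<()> {
    // Column-vs-literal is the supported shape; literal-vs-literal is expected to be
    // constant-folded on the Spark side and is rejected at evaluation time.
    let expr = StartsWith::new(
        Arc::new(Column::new("name", 0)),
        Arc::new(Literal::new(ScalarValue::Utf8(Some("Spark".to_string())))),
    );
    let _out = expr.evaluate(batch)?; // BooleanArray
    Ok(())
}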
+ +#![allow(deprecated)] + +use crate::kernels::strings::string_space; +use arrow::record_batch::RecordBatch; +use arrow_schema::{DataType, Schema}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::DataFusionError; +use datafusion_physical_expr::PhysicalExpr; +use std::{ + any::Any, + fmt::{Display, Formatter}, + hash::Hash, + sync::Arc, +}; + +#[derive(Debug, Eq)] +pub struct StringSpaceExpr { + pub child: Arc, +} + +impl Hash for StringSpaceExpr { + fn hash(&self, state: &mut H) { + self.child.hash(state); + } +} + +impl PartialEq for StringSpaceExpr { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) + } +} + +impl StringSpaceExpr { + pub fn new(child: Arc) -> Self { + Self { child } + } +} + +impl Display for StringSpaceExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "StringSpace [child: {}] ", self.child) + } +} + +impl PhysicalExpr for StringSpaceExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + match self.child.data_type(input_schema)? { + DataType::Dictionary(key_type, _) => { + Ok(DataType::Dictionary(key_type, Box::new(DataType::Utf8))) + } + _ => Ok(DataType::Utf8), + } + } + + fn nullable(&self, _: &Schema) -> datafusion_common::Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + let arg = self.child.evaluate(batch)?; + match arg { + ColumnarValue::Array(array) => { + let result = string_space(&array)?; + + Ok(ColumnarValue::Array(result)) + } + _ => Err(DataFusionError::Execution( + "StringSpace(scalar) should be fold in Spark JVM side.".to_string(), + )), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion_common::Result> { + Ok(Arc::new(StringSpaceExpr::new(Arc::clone(&children[0])))) + } +} diff --git a/datafusion/spark/src/string_funcs/substring.rs b/datafusion/spark/src/string_funcs/substring.rs new file mode 100644 index 000000000000..c38001160be0 --- /dev/null +++ b/datafusion/spark/src/string_funcs/substring.rs @@ -0,0 +1,113 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#![allow(deprecated)] + +use crate::kernels::strings::substring; +use arrow::record_batch::RecordBatch; +use arrow_schema::{DataType, Schema}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::DataFusionError; +use datafusion_physical_expr::PhysicalExpr; +use std::{ + any::Any, + fmt::{Display, Formatter}, + hash::Hash, + sync::Arc, +}; + +#[derive(Debug, Eq)] +pub struct SubstringExpr { + pub child: Arc, + pub start: i64, + pub len: u64, +} + +impl Hash for SubstringExpr { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.start.hash(state); + self.len.hash(state); + } +} + +impl PartialEq for SubstringExpr { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) + && self.start.eq(&other.start) + && self.len.eq(&other.len) + } +} + +impl SubstringExpr { + pub fn new(child: Arc, start: i64, len: u64) -> Self { + Self { child, start, len } + } +} + +impl Display for SubstringExpr { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "StringSpace [start: {}, len: {}, child: {}]", + self.start, self.len, self.child + ) + } +} + +impl PhysicalExpr for SubstringExpr { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> datafusion_common::Result { + self.child.data_type(input_schema) + } + + fn nullable(&self, _: &Schema) -> datafusion_common::Result { + Ok(true) + } + + fn evaluate(&self, batch: &RecordBatch) -> datafusion_common::Result { + let arg = self.child.evaluate(batch)?; + match arg { + ColumnarValue::Array(array) => { + let result = substring(&array, self.start, self.len)?; + + Ok(ColumnarValue::Array(result)) + } + _ => Err(DataFusionError::Execution( + "Substring(scalar) should be fold in Spark JVM side.".to_string(), + )), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion_common::Result> { + Ok(Arc::new(SubstringExpr::new( + Arc::clone(&children[0]), + self.start, + self.len, + ))) + } +} diff --git a/datafusion/spark/src/struct_funcs/create_named_struct.rs b/datafusion/spark/src/struct_funcs/create_named_struct.rs new file mode 100644 index 000000000000..3212104133eb --- /dev/null +++ b/datafusion/spark/src/struct_funcs/create_named_struct.rs @@ -0,0 +1,156 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
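A construction sketch for the `SubstringExpr` above; Spark's `substring` uses a 1-based start index, and how `start`/`len` are interpreted is delegated to the `kernels::strings::substring` kernel (not shown in this diff), so the values here are purely illustrative:

use std::sync::Arc;
use arrow_array::RecordBatch;
use datafusion_common::Result;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::PhysicalExpr;
use datafusion_comet_spark_expr::SubstringExpr; // assumed re-export

fn substring_sketch(batch: &RecordBatch) -> Result<()> {
    // start and len are fixed at plan time; only the child column is evaluated per batch.
    let expr = SubstringExpr::new(Arc::new(Column::new("s", 0)), 1, 3);
    let _out = expr.evaluate(batch)?;
    Ok(())
}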
+ +use arrow::record_batch::RecordBatch; +use arrow_array::StructArray; +use arrow_schema::{DataType, Field, Schema}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::Result as DataFusionResult; +use datafusion_physical_expr::PhysicalExpr; +use std::{ + any::Any, + fmt::{Display, Formatter}, + hash::Hash, + sync::Arc, +}; + +#[derive(Debug, Hash, PartialEq, Eq)] +pub struct CreateNamedStruct { + values: Vec>, + names: Vec, +} + +impl CreateNamedStruct { + pub fn new(values: Vec>, names: Vec) -> Self { + Self { values, names } + } + + fn fields(&self, schema: &Schema) -> DataFusionResult> { + self.values + .iter() + .zip(&self.names) + .map(|(expr, name)| { + let data_type = expr.data_type(schema)?; + let nullable = expr.nullable(schema)?; + Ok(Field::new(name, data_type, nullable)) + }) + .collect() + } +} + +impl PhysicalExpr for CreateNamedStruct { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> DataFusionResult { + let fields = self.fields(input_schema)?; + Ok(DataType::Struct(fields.into())) + } + + fn nullable(&self, _input_schema: &Schema) -> DataFusionResult { + Ok(false) + } + + fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { + let values = self + .values + .iter() + .map(|expr| expr.evaluate(batch)) + .collect::>>()?; + let arrays = ColumnarValue::values_to_arrays(&values)?; + let fields = self.fields(&batch.schema())?; + Ok(ColumnarValue::Array(Arc::new(StructArray::new( + fields.into(), + arrays, + None, + )))) + } + + fn children(&self) -> Vec<&Arc> { + self.values.iter().collect() + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion_common::Result> { + Ok(Arc::new(CreateNamedStruct::new( + children.clone(), + self.names.clone(), + ))) + } +} + +impl Display for CreateNamedStruct { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "CreateNamedStruct [values: {:?}, names: {:?}]", + self.values, self.names + ) + } +} + +#[cfg(test)] +mod test { + use super::CreateNamedStruct; + use arrow_array::{Array, DictionaryArray, Int32Array, RecordBatch, StringArray}; + use arrow_schema::{DataType, Field, Schema}; + use datafusion_common::Result; + use datafusion_expr::ColumnarValue; + use datafusion_physical_expr::expressions::Column; + use datafusion_physical_expr::PhysicalExpr; + use std::sync::Arc; + + #[test] + fn test_create_struct_from_dict_encoded_i32() -> Result<()> { + let keys = Int32Array::from(vec![0, 1, 2]); + let values = Int32Array::from(vec![0, 111, 233]); + let dict = DictionaryArray::try_new(keys, Arc::new(values))?; + let data_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Int32)); + let schema = Schema::new(vec![Field::new("a", data_type, false)]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dict)])?; + let field_names = vec!["a".to_string()]; + let x = CreateNamedStruct::new(vec![Arc::new(Column::new("a", 0))], field_names); + let ColumnarValue::Array(x) = x.evaluate(&batch)? 
else { + unreachable!() + }; + assert_eq!(3, x.len()); + Ok(()) + } + + #[test] + fn test_create_struct_from_dict_encoded_string() -> Result<()> { + let keys = Int32Array::from(vec![0, 1, 2]); + let values = + StringArray::from(vec!["a".to_string(), "b".to_string(), "c".to_string()]); + let dict = DictionaryArray::try_new(keys, Arc::new(values))?; + let data_type = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(DataType::Utf8)); + let schema = Schema::new(vec![Field::new("a", data_type, false)]); + let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(dict)])?; + let field_names = vec!["a".to_string()]; + let x = CreateNamedStruct::new(vec![Arc::new(Column::new("a", 0))], field_names); + let ColumnarValue::Array(x) = x.evaluate(&batch)? else { + unreachable!() + }; + assert_eq!(3, x.len()); + Ok(()) + } +} diff --git a/datafusion/spark/src/struct_funcs/get_struct_field.rs b/datafusion/spark/src/struct_funcs/get_struct_field.rs new file mode 100644 index 000000000000..966419ee45e3 --- /dev/null +++ b/datafusion/spark/src/struct_funcs/get_struct_field.rs @@ -0,0 +1,125 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::record_batch::RecordBatch; +use arrow_array::{Array, StructArray}; +use arrow_schema::{DataType, Field, Schema}; +use datafusion::logical_expr::ColumnarValue; +use datafusion_common::{DataFusionError, Result as DataFusionResult, ScalarValue}; +use datafusion_physical_expr::PhysicalExpr; +use std::{ + any::Any, + fmt::{Display, Formatter}, + hash::Hash, + sync::Arc, +}; + +#[derive(Debug, Eq)] +pub struct GetStructField { + child: Arc, + ordinal: usize, +} + +impl Hash for GetStructField { + fn hash(&self, state: &mut H) { + self.child.hash(state); + self.ordinal.hash(state); + } +} +impl PartialEq for GetStructField { + fn eq(&self, other: &Self) -> bool { + self.child.eq(&other.child) && self.ordinal.eq(&other.ordinal) + } +} + +impl GetStructField { + pub fn new(child: Arc, ordinal: usize) -> Self { + Self { child, ordinal } + } + + fn child_field(&self, input_schema: &Schema) -> DataFusionResult> { + match self.child.data_type(input_schema)? 
{ + DataType::Struct(fields) => Ok(Arc::clone(&fields[self.ordinal])), + data_type => Err(DataFusionError::Plan(format!( + "Expect struct field, got {:?}", + data_type + ))), + } + } +} + +impl PhysicalExpr for GetStructField { + fn as_any(&self) -> &dyn Any { + self + } + + fn data_type(&self, input_schema: &Schema) -> DataFusionResult { + Ok(self.child_field(input_schema)?.data_type().clone()) + } + + fn nullable(&self, input_schema: &Schema) -> DataFusionResult { + Ok(self.child_field(input_schema)?.is_nullable()) + } + + fn evaluate(&self, batch: &RecordBatch) -> DataFusionResult { + let child_value = self.child.evaluate(batch)?; + + match child_value { + ColumnarValue::Array(array) => { + let struct_array = array + .as_any() + .downcast_ref::() + .expect("A struct is expected"); + + Ok(ColumnarValue::Array(Arc::clone( + struct_array.column(self.ordinal), + ))) + } + ColumnarValue::Scalar(ScalarValue::Struct(struct_array)) => Ok( + ColumnarValue::Array(Arc::clone(struct_array.column(self.ordinal))), + ), + value => Err(DataFusionError::Execution(format!( + "Expected a struct array, got {:?}", + value + ))), + } + } + + fn children(&self) -> Vec<&Arc> { + vec![&self.child] + } + + fn with_new_children( + self: Arc, + children: Vec>, + ) -> datafusion_common::Result> { + Ok(Arc::new(GetStructField::new( + Arc::clone(&children[0]), + self.ordinal, + ))) + } +} + +impl Display for GetStructField { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!( + f, + "GetStructField [child: {:?}, ordinal: {:?}]", + self.child, self.ordinal + ) + } +} diff --git a/datafusion/spark/src/struct_funcs/mod.rs b/datafusion/spark/src/struct_funcs/mod.rs new file mode 100644 index 000000000000..86edcceac918 --- /dev/null +++ b/datafusion/spark/src/struct_funcs/mod.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod create_named_struct; +mod get_struct_field; + +pub use create_named_struct::CreateNamedStruct; +pub use get_struct_field::GetStructField; diff --git a/datafusion/spark/src/test_common/file_util.rs b/datafusion/spark/src/test_common/file_util.rs new file mode 100644 index 000000000000..78e42d29e643 --- /dev/null +++ b/datafusion/spark/src/test_common/file_util.rs @@ -0,0 +1,53 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
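A minimal sketch of the `GetStructField` expression above, which projects one child of a struct column by ordinal (the column name and ordinal are illustrative; crate-root re-export assumed):

use std::sync::Arc;
use arrow_array::RecordBatch;
use datafusion_common::Result;
use datafusion_physical_expr::expressions::Column;
use datafusion_physical_expr::PhysicalExpr;
use datafusion_comet_spark_expr::GetStructField; // assumed re-export

fn get_struct_field_sketch(batch: &RecordBatch) -> Result<()> {
    // Reads field 1 of the struct column at position 0; the output type and
    // nullability come from that child field of the struct.
    let expr = GetStructField::new(Arc::new(Column::new("person", 0)), 1);
    let _out = expr.evaluate(batch)?;
    Ok(())
}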
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::{env, fs, io::Write, path::PathBuf}; + +/// Returns file handle for a temp file in 'target' directory with a provided content +pub fn get_temp_file(file_name: &str, content: &[u8]) -> fs::File { + // build tmp path to a file in "target/debug/testdata" + let mut path_buf = env::current_dir().unwrap(); + path_buf.push("target"); + path_buf.push("debug"); + path_buf.push("testdata"); + fs::create_dir_all(&path_buf).unwrap(); + path_buf.push(file_name); + + // write file content + let mut tmp_file = fs::File::create(path_buf.as_path()).unwrap(); + tmp_file.write_all(content).unwrap(); + tmp_file.sync_all().unwrap(); + + // return file handle for both read and write + let file = fs::OpenOptions::new() + .read(true) + .write(true) + .open(path_buf.as_path()); + assert!(file.is_ok()); + file.unwrap() +} + +pub fn get_temp_filename() -> PathBuf { + let mut path_buf = env::current_dir().unwrap(); + path_buf.push("target"); + path_buf.push("debug"); + path_buf.push("testdata"); + fs::create_dir_all(&path_buf).unwrap(); + path_buf.push(rand::random::().to_string()); + + path_buf +} diff --git a/datafusion/spark/src/test_common/mod.rs b/datafusion/spark/src/test_common/mod.rs new file mode 100644 index 000000000000..f2edb8035e03 --- /dev/null +++ b/datafusion/spark/src/test_common/mod.rs @@ -0,0 +1,18 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +pub mod file_util; diff --git a/datafusion/spark/src/timezone.rs b/datafusion/spark/src/timezone.rs new file mode 100644 index 000000000000..59bcb13a3022 --- /dev/null +++ b/datafusion/spark/src/timezone.rs @@ -0,0 +1,146 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +/// Utils for timezone. This is basically from arrow-array::timezone (private). +use arrow_schema::ArrowError; +use chrono::{ + format::{parse, Parsed, StrftimeItems}, + offset::TimeZone, + FixedOffset, LocalResult, NaiveDate, NaiveDateTime, Offset, +}; +use std::str::FromStr; + +/// Parses a fixed offset of the form "+09:00" +fn parse_fixed_offset(tz: &str) -> Result { + let mut parsed = Parsed::new(); + + if let Ok(fixed_offset) = parse(&mut parsed, tz, StrftimeItems::new("%:z")) + .and_then(|_| parsed.to_fixed_offset()) + { + return Ok(fixed_offset); + } + + if let Ok(fixed_offset) = parse(&mut parsed, tz, StrftimeItems::new("%#z")) + .and_then(|_| parsed.to_fixed_offset()) + { + return Ok(fixed_offset); + } + + Err(ArrowError::ParseError(format!( + "Invalid timezone \"{}\": Expected format [+-]XX:XX, [+-]XX, or [+-]XXXX", + tz + ))) +} + +/// An [`Offset`] for [`Tz`] +#[derive(Debug, Copy, Clone)] +pub struct TzOffset { + tz: Tz, + offset: FixedOffset, +} + +impl std::fmt::Display for TzOffset { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + self.offset.fmt(f) + } +} + +impl Offset for TzOffset { + fn fix(&self) -> FixedOffset { + self.offset + } +} + +/// An Arrow [`TimeZone`] +#[derive(Debug, Copy, Clone)] +pub struct Tz(TzInner); + +#[derive(Debug, Copy, Clone)] +enum TzInner { + Timezone(chrono_tz::Tz), + Offset(FixedOffset), +} + +impl FromStr for Tz { + type Err = ArrowError; + + fn from_str(tz: &str) -> Result { + if tz.starts_with('+') || tz.starts_with('-') { + Ok(Self(TzInner::Offset(parse_fixed_offset(tz)?))) + } else { + Ok(Self(TzInner::Timezone(tz.parse().map_err(|e| { + ArrowError::ParseError(format!("Invalid timezone \"{}\": {}", tz, e)) + })?))) + } + } +} + +macro_rules! tz { + ($s:ident, $tz:ident, $b:block) => { + match $s.0 { + TzInner::Timezone($tz) => $b, + TzInner::Offset($tz) => $b, + } + }; +} + +impl TimeZone for Tz { + type Offset = TzOffset; + + fn from_offset(offset: &Self::Offset) -> Self { + offset.tz + } + + fn offset_from_local_date(&self, local: &NaiveDate) -> LocalResult { + tz!(self, tz, { + tz.offset_from_local_date(local).map(|x| TzOffset { + tz: *self, + offset: x.fix(), + }) + }) + } + + fn offset_from_local_datetime( + &self, + local: &NaiveDateTime, + ) -> LocalResult { + tz!(self, tz, { + tz.offset_from_local_datetime(local).map(|x| TzOffset { + tz: *self, + offset: x.fix(), + }) + }) + } + + fn offset_from_utc_date(&self, utc: &NaiveDate) -> Self::Offset { + tz!(self, tz, { + TzOffset { + tz: *self, + offset: tz.offset_from_utc_date(utc).fix(), + } + }) + } + + fn offset_from_utc_datetime(&self, utc: &NaiveDateTime) -> Self::Offset { + tz!(self, tz, { + TzOffset { + tz: *self, + offset: tz.offset_from_utc_datetime(utc).fix(), + } + }) + } +} diff --git a/datafusion/spark/src/unbound.rs b/datafusion/spark/src/unbound.rs new file mode 100644 index 000000000000..14f68c9cd6fb --- /dev/null +++ b/datafusion/spark/src/unbound.rs @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow_array::RecordBatch;
+use arrow_schema::{DataType, Schema};
+use datafusion::physical_plan::ColumnarValue;
+use datafusion_common::{internal_err, Result};
+use datafusion_physical_expr::PhysicalExpr;
+use std::{hash::Hash, sync::Arc};
+
+/// This is similar to `UnKnownColumn` in DataFusion, but it carries a data type.
+/// It is only used when the column is not bound to a schema, for example, the
+/// inputs to aggregate functions in a final aggregation. In that case, we cannot
+/// bind the aggregate functions to the input schema, which in Spark consists of
+/// the grouping columns and the aggregate buffer attributes (DataFusion has a
+/// different design). However, when creating certain aggregate functions, we
+/// still need to know their input data types. Since `UnKnownColumn` does not
+/// carry a data type, `UnboundColumn` is implemented to do so.
+#[derive(Debug, Hash, PartialEq, Eq, Clone)]
+pub struct UnboundColumn {
+    name: String,
+    datatype: DataType,
+}
+
+impl UnboundColumn {
+    /// Create a new unbound column expression
+    pub fn new(name: &str, datatype: DataType) -> Self {
+        Self {
+            name: name.to_owned(),
+            datatype,
+        }
+    }
+
+    /// Get the column name
+    pub fn name(&self) -> &str {
+        &self.name
+    }
+}
+
+impl std::fmt::Display for UnboundColumn {
+    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+        write!(f, "{}, datatype: {}", self.name, self.datatype)
+    }
+}
+
+impl PhysicalExpr for UnboundColumn {
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn std::any::Any {
+        self
+    }
+
+    /// Get the data type of this expression, given the schema of the input
+    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
+        Ok(self.datatype.clone())
+    }
+
+    /// Decide whether this expression is nullable, given the schema of the input
+    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
+        Ok(true)
+    }
+
+    /// Evaluate the expression
+    fn evaluate(&self, _batch: &RecordBatch) -> Result<ColumnarValue> {
+        internal_err!("UnboundColumn::evaluate() should not be called")
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+        vec![]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        _children: Vec<Arc<dyn PhysicalExpr>>,
+    ) -> Result<Arc<dyn PhysicalExpr>> {
+        Ok(self)
+    }
+}
diff --git a/datafusion/spark/src/utils.rs b/datafusion/spark/src/utils.rs
new file mode 100644
index 000000000000..37d633e52549
--- /dev/null
+++ b/datafusion/spark/src/utils.rs
@@ -0,0 +1,260 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.
See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::{ + cast::as_primitive_array, + types::{Int32Type, TimestampMicrosecondType}, +}; +use arrow_schema::{ArrowError, DataType, TimeUnit, DECIMAL128_MAX_PRECISION}; +use std::sync::Arc; + +use crate::timezone::Tz; +use arrow::{ + array::{as_dictionary_array, Array, ArrayRef, PrimitiveArray}, + temporal_conversions::as_datetime, +}; +use arrow_array::types::TimestampMillisecondType; +use arrow_data::decimal::{ + MAX_DECIMAL_FOR_EACH_PRECISION, MIN_DECIMAL_FOR_EACH_PRECISION, +}; +use chrono::{DateTime, Offset, TimeZone}; + +/// Preprocesses input arrays to add timezone information from Spark to Arrow array datatype or +/// to apply timezone offset. +// +// We consider the following cases: +// +// | --------------------- | ------------ | ----------------- | -------------------------------- | +// | Conversion | Input array | Timezone | Output array | +// | --------------------- | ------------ | ----------------- | -------------------------------- | +// | Timestamp -> | Array in UTC | Timezone of input | A timestamp with the timezone | +// | Utf8 or Date32 | | | offset applied and timezone | +// | | | | removed | +// | --------------------- | ------------ | ----------------- | -------------------------------- | +// | Timestamp -> | Array in UTC | Timezone of input | Same as input array | +// | Timestamp w/Timezone| | | | +// | --------------------- | ------------ | ----------------- | -------------------------------- | +// | Timestamp_ntz -> | Array in | Timezone of input | Same as input array | +// | Utf8 or Date32 | timezone | | | +// | | session local| | | +// | | timezone | | | +// | --------------------- | ------------ | ----------------- | -------------------------------- | +// | Timestamp_ntz -> | Array in | Timezone of input | Array in UTC and timezone | +// | Timestamp w/Timezone | session local| | specified in input | +// | | timezone | | | +// | --------------------- | ------------ | ----------------- | -------------------------------- | +// | Timestamp(_ntz) -> | | +// | Any other type | Not Supported | +// | --------------------- | ------------ | ----------------- | -------------------------------- | +// +pub fn array_with_timezone( + array: ArrayRef, + timezone: String, + to_type: Option<&DataType>, +) -> Result { + match array.data_type() { + DataType::Timestamp(_, None) => { + assert!(!timezone.is_empty()); + match to_type { + Some(DataType::Utf8) | Some(DataType::Date32) => Ok(array), + Some(DataType::Timestamp(_, Some(_))) => timestamp_ntz_to_timestamp( + array, + timezone.as_str(), + Some(timezone.as_str()), + ), + Some(DataType::Timestamp(_, None)) => { + timestamp_ntz_to_timestamp(array, timezone.as_str(), None) + } + _ => { + // Not supported + panic!( + "Cannot convert from {:?} to {:?}", + array.data_type(), + to_type.unwrap() + ) + } + } + } + DataType::Timestamp(TimeUnit::Microsecond, Some(_)) => { + assert!(!timezone.is_empty()); + let array = as_primitive_array::(&array); + let array_with_timezone = array.clone().with_timezone(timezone.clone()); + let array = Arc::new(array_with_timezone) as ArrayRef; + match to_type { + Some(DataType::Utf8) | Some(DataType::Date32) => { + pre_timestamp_cast(array, timezone) + } + _ => Ok(array), + } + } + DataType::Timestamp(TimeUnit::Millisecond, Some(_)) => { + assert!(!timezone.is_empty()); + let array = as_primitive_array::(&array); + let array_with_timezone = array.clone().with_timezone(timezone.clone()); 
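+            // Note: `with_timezone` only attaches the timezone to the Timestamp data
+            // type's metadata; the underlying epoch values are not modified here. Any
+            // actual offset adjustment happens in `pre_timestamp_cast` below, which is
+            // applied only when casting to Utf8 or Date32.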
+            let array = Arc::new(array_with_timezone) as ArrayRef;
+            match to_type {
+                Some(DataType::Utf8) | Some(DataType::Date32) => {
+                    pre_timestamp_cast(array, timezone)
+                }
+                _ => Ok(array),
+            }
+        }
+        DataType::Dictionary(_, value_type)
+            if matches!(value_type.as_ref(), &DataType::Timestamp(_, _)) =>
+        {
+            let dict = as_dictionary_array::<Int32Type>(&array);
+            let array = as_primitive_array::<TimestampMicrosecondType>(dict.values());
+            let array_with_timezone = array_with_timezone(
+                Arc::new(array.clone()) as ArrayRef,
+                timezone,
+                to_type,
+            )?;
+            let dict = dict.with_values(array_with_timezone);
+            Ok(Arc::new(dict))
+        }
+        _ => Ok(array),
+    }
+}
+
+fn datetime_cast_err(value: i64) -> ArrowError {
+    ArrowError::CastError(format!(
+        "Cannot convert TimestampMicrosecondType {value} to datetime. Comet only supports dates between Jan 1, 262145 BCE and Dec 31, 262143 CE",
+    ))
+}
+
+/// Takes in a Timestamp(Microsecond, None) array and a timezone id, and returns
+/// a Timestamp(Microsecond, Some<_>) array.
+/// The input array is assumed to contain local times in the timezone specified by the
+/// second argument.
+/// Parameters:
+///     array - input array of timestamp without timezone
+///     tz - timezone of the values in the input array
+///     to_timezone - timezone to change the input values to
fn timestamp_ntz_to_timestamp(
+    array: ArrayRef,
+    tz: &str,
+    to_timezone: Option<&str>,
+) -> Result<ArrayRef, ArrowError> {
+    assert!(!tz.is_empty());
+    match array.data_type() {
+        DataType::Timestamp(TimeUnit::Microsecond, None) => {
+            let array = as_primitive_array::<TimestampMicrosecondType>(&array);
+            let tz: Tz = tz.parse()?;
+            let array: PrimitiveArray<TimestampMicrosecondType> =
+                array.try_unary(|value| {
+                    as_datetime::<TimestampMicrosecondType>(value)
+                        .ok_or_else(|| datetime_cast_err(value))
+                        .map(|local_datetime| {
+                            let datetime: DateTime<Tz> =
+                                tz.from_local_datetime(&local_datetime).unwrap();
+                            datetime.timestamp_micros()
+                        })
+                })?;
+            let array_with_tz = if let Some(to_tz) = to_timezone {
+                array.with_timezone(to_tz)
+            } else {
+                array
+            };
+            Ok(Arc::new(array_with_tz))
+        }
+        DataType::Timestamp(TimeUnit::Millisecond, None) => {
+            let array = as_primitive_array::<TimestampMillisecondType>(&array);
+            let tz: Tz = tz.parse()?;
+            let array: PrimitiveArray<TimestampMillisecondType> =
+                array.try_unary(|value| {
+                    as_datetime::<TimestampMillisecondType>(value)
+                        .ok_or_else(|| datetime_cast_err(value))
+                        .map(|local_datetime| {
+                            let datetime: DateTime<Tz> =
+                                tz.from_local_datetime(&local_datetime).unwrap();
+                            datetime.timestamp_millis()
+                        })
+                })?;
+            let array_with_tz = if let Some(to_tz) = to_timezone {
+                array.with_timezone(to_tz)
+            } else {
+                array
+            };
+            Ok(Arc::new(array_with_tz))
+        }
+        _ => Ok(array),
+    }
+}
+
+/// Handles special pre-casting cases for Spark, e.g., casting Timestamp to String.
+fn pre_timestamp_cast(array: ArrayRef, timezone: String) -> Result<ArrayRef, ArrowError> {
+    assert!(!timezone.is_empty());
+    match array.data_type() {
+        DataType::Timestamp(_, _) => {
+            // Spark doesn't output the timezone when casting a timestamp to string, but
+            // arrow's cast kernel does if a timezone is present. So we need to apply the
+            // timezone offset to the array's timestamp values and remove the timezone
+            // from the array's data type.
+            let array = as_primitive_array::<TimestampMicrosecondType>(&array);
+
+            let tz: Tz = timezone.parse()?;
+            let array: PrimitiveArray<TimestampMicrosecondType> =
+                array.try_unary(|value| {
+                    as_datetime::<TimestampMicrosecondType>(value)
+                        .ok_or_else(|| datetime_cast_err(value))
+                        .map(|datetime| {
+                            let offset = tz.offset_from_utc_datetime(&datetime).fix();
+                            let datetime = datetime + offset;
+                            datetime.and_utc().timestamp_micros()
+                        })
+                })?;
+
+            Ok(Arc::new(array))
+        }
+        _ => Ok(array),
+    }
+}
+
+/// Adapted from arrow-rs `validate_decimal_precision`, but returns a bool instead of
+/// an Err to avoid the cost of formatting error strings, and is optimized to remove a
+/// memcpy that exists in the original function. We can remove this code once we
+/// upgrade to a version of arrow-rs that includes this optimization.
+#[inline]
+pub fn is_valid_decimal_precision(value: i128, precision: u8) -> bool {
+    precision <= DECIMAL128_MAX_PRECISION
+        && value >= MIN_DECIMAL_FOR_EACH_PRECISION[precision as usize - 1]
+        && value <= MAX_DECIMAL_FOR_EACH_PRECISION[precision as usize - 1]
+}
+
+// These are borrowed from the hashbrown crate:
+// https://github.com/rust-lang/hashbrown/blob/master/src/raw/mod.rs
+
+// On stable we can use #[cold] to get an equivalent effect: this attribute
+// suggests that the function is unlikely to be called
+#[inline]
+#[cold]
+pub fn cold() {}
+
+#[inline]
+pub fn likely(b: bool) -> bool {
+    if !b {
+        cold();
+    }
+    b
+}
+#[inline]
+pub fn unlikely(b: bool) -> bool {
+    if b {
+        cold();
+    }
+    b
+}
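A quick illustration of the expected behavior of `is_valid_decimal_precision` may help reviewers. The test module below is only a sketch and is not part of the patch (the `doc_examples` name is made up); it follows directly from the arrow-data lookup tables used above, where precision `p` admits values in `[-(10^p - 1), 10^p - 1]` and `DECIMAL128_MAX_PRECISION` is 38:

#[cfg(test)]
mod doc_examples {
    use super::is_valid_decimal_precision;

    #[test]
    fn decimal_precision_bounds() {
        // Precision 5 admits values in [-99_999, 99_999].
        assert!(is_valid_decimal_precision(99_999, 5));
        assert!(is_valid_decimal_precision(-99_999, 5));
        assert!(!is_valid_decimal_precision(100_000, 5));

        // Precision is capped at DECIMAL128_MAX_PRECISION (38).
        assert!(is_valid_decimal_precision(10_i128.pow(38) - 1, 38));
        assert!(!is_valid_decimal_precision(1, 39));
    }
}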