From 39d1a72df99b8b4aee6916e55e79ef483dffc1cc Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 8 Apr 2025 15:22:18 -0400 Subject: [PATCH 01/25] Add in plumbing to pass around metadata for physical expressions --- .../physical_optimizer/projection_pushdown.rs | 5 +++ .../physical-expr-common/src/physical_expr.rs | 14 +++++++ .../src/equivalence/properties/dependency.rs | 3 ++ .../physical-expr/src/expressions/binary.rs | 17 ++++++-- .../physical-expr/src/expressions/case.rs | 17 ++++++-- .../physical-expr/src/expressions/cast.rs | 12 ++++++ .../physical-expr/src/expressions/column.rs | 12 ++++++ .../physical-expr/src/expressions/in_list.rs | 11 +++++ .../src/expressions/is_not_null.rs | 17 ++++++-- .../physical-expr/src/expressions/is_null.rs | 17 ++++++-- .../physical-expr/src/expressions/like.rs | 17 ++++++-- .../physical-expr/src/expressions/literal.rs | 12 ++++++ .../physical-expr/src/expressions/negative.rs | 12 ++++++ .../physical-expr/src/expressions/no_op.rs | 12 ++++++ .../physical-expr/src/expressions/not.rs | 12 ++++++ .../physical-expr/src/expressions/try_cast.rs | 12 ++++++ .../src/expressions/unknown_column.rs | 12 ++++++ .../physical-expr/src/scalar_function.rs | 40 ++++++++++++++++++- .../proto/src/physical_plan/from_proto.rs | 1 + .../tests/cases/roundtrip_physical_plan.rs | 15 +++++++ 20 files changed, 254 insertions(+), 16 deletions(-) diff --git a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs index 911d2c0cee05..f018a75f657f 100644 --- a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs @@ -16,6 +16,7 @@ // under the License. 
use std::any::Any; +use std::collections::HashMap; use std::sync::Arc; use arrow::compute::SortOptions; @@ -129,6 +130,7 @@ fn test_update_matching_exprs() -> Result<()> { )), ], DataType::Int32, + HashMap::default(), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 2))), @@ -194,6 +196,7 @@ fn test_update_matching_exprs() -> Result<()> { )), ], DataType::Int32, + HashMap::default(), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 3))), @@ -262,6 +265,7 @@ fn test_update_projected_exprs() -> Result<()> { )), ], DataType::Int32, + HashMap::default(), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 2))), @@ -327,6 +331,7 @@ fn test_update_projected_exprs() -> Result<()> { )), ], DataType::Int32, + HashMap::default(), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d_new", 3))), diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 3bc41d2652d9..4ad83553fbcc 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -16,6 +16,7 @@ // under the License. use std::any::Any; +use std::collections::HashMap; use std::fmt; use std::fmt::{Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; @@ -76,6 +77,17 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash { fn nullable(&self, input_schema: &Schema) -> Result; /// Evaluate an expression against a RecordBatch fn evaluate(&self, batch: &RecordBatch) -> Result; + /// Determine if this expression has any associated field metadata in the schema + /// In some circumstances we will get the metadata from the schema, and sometimes + /// we will get it from the physical expression itself. The lifetime of the result + /// must outlive both. 
+ fn metadata<'a, 'b, 'c>( + &'a self, + input_schema: &'b Schema, + ) -> Result>> + where + 'a: 'c, + 'b: 'c; /// Evaluate an expression against a RecordBatch after first applying a /// validity array fn evaluate_selection( @@ -453,6 +465,7 @@ where /// ``` /// # // The boiler plate needed to create a `PhysicalExpr` for the example /// # use std::any::Any; +/// use std::collections::HashMap; /// # use std::fmt::Formatter; /// # use std::sync::Arc; /// # use arrow::array::RecordBatch; @@ -466,6 +479,7 @@ where /// # fn data_type(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn nullable(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn evaluate(&self, batch: &RecordBatch) -> Result { unimplemented!() } +/// # fn metadata<'a>(&self, input_schema: &'a Schema) -> Result<&'a HashMap> { unimplemented!() } /// # fn children(&self) -> Vec<&Arc>{ unimplemented!() } /// # fn with_new_children(self: Arc, children: Vec>) -> Result> { unimplemented!() } /// # fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "CASE a > b THEN 1 ELSE 0 END") } diff --git a/datafusion/physical-expr/src/equivalence/properties/dependency.rs b/datafusion/physical-expr/src/equivalence/properties/dependency.rs index 9eba295e562e..feab0d1f2d3e 100644 --- a/datafusion/physical-expr/src/equivalence/properties/dependency.rs +++ b/datafusion/physical-expr/src/equivalence/properties/dependency.rs @@ -424,6 +424,7 @@ pub fn generate_dependency_orderings( #[cfg(test)] mod tests { + use std::collections::HashMap; use std::ops::Not; use std::sync::Arc; @@ -1225,6 +1226,7 @@ mod tests { concat(), vec![Arc::clone(&col_a), Arc::clone(&col_b)], DataType::Utf8, + HashMap::default(), )); // Assume existing ordering is [c ASC, a ASC, b ASC] @@ -1316,6 +1318,7 @@ mod tests { concat(), vec![Arc::clone(&col_a), Arc::clone(&col_b)], DataType::Utf8, + HashMap::default(), )); // Assume existing ordering is [concat(a, b) ASC, a ASC, b ASC] diff --git 
a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 84374f4a2970..80ef2c22b79d 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -17,12 +17,12 @@ mod kernels; -use std::hash::Hash; -use std::{any::Any, sync::Arc}; - use crate::expressions::binary::kernels::concat_elements_utf8view; use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; use crate::PhysicalExpr; +use std::collections::HashMap; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; use arrow::array::*; use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene}; @@ -433,6 +433,17 @@ impl PhysicalExpr for BinaryExpr { .map(ColumnarValue::Array) } + fn metadata<'a, 'b, 'c>( + &'a self, + _input_schema: &'b Schema, + ) -> Result>> + where + 'a: 'c, + 'b: 'c, + { + Ok(None) + } + fn children(&self) -> Vec<&Arc> { vec![&self.left, &self.right] } diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 854c715eb0a2..77df48280e6d 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -15,13 +15,13 @@ // specific language governing permissions and limitations // under the License. 
+use crate::expressions::try_cast; +use crate::PhysicalExpr; use std::borrow::Cow; +use std::collections::HashMap; use std::hash::Hash; use std::{any::Any, sync::Arc}; -use crate::expressions::try_cast; -use crate::PhysicalExpr; - use arrow::array::*; use arrow::compute::kernels::zip::zip; use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter}; @@ -514,6 +514,17 @@ impl PhysicalExpr for CaseExpr { } } + fn metadata<'a, 'b, 'c>( + &'a self, + _input_schema: &'b Schema, + ) -> Result>> + where + 'a: 'c, + 'b: 'c, + { + Ok(None) + } + fn children(&self) -> Vec<&Arc> { let mut children = vec![]; if let Some(expr) = &self.expr { diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index a6766687a881..508318b1b21a 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -16,6 +16,7 @@ // under the License. use std::any::Any; +use std::collections::HashMap; use std::fmt; use std::hash::Hash; use std::sync::Arc; @@ -144,6 +145,17 @@ impl PhysicalExpr for CastExpr { value.cast_to(&self.cast_type, Some(&self.cast_options)) } + fn metadata<'a, 'b, 'c>( + &'a self, + input_schema: &'b Schema, + ) -> Result>> + where + 'a: 'c, + 'b: 'c, + { + self.expr.metadata(input_schema) + } + fn children(&self) -> Vec<&Arc> { vec![&self.expr] } diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index ab5b35984753..8f2e015f734e 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -18,6 +18,7 @@ //! 
Physical column reference: [`Column`] use std::any::Any; +use std::collections::HashMap; use std::hash::Hash; use std::sync::Arc; @@ -127,6 +128,17 @@ impl PhysicalExpr for Column { Ok(ColumnarValue::Array(Arc::clone(batch.column(self.index)))) } + fn metadata<'a, 'b, 'c>( + &'a self, + input_schema: &'b Schema, + ) -> Result>> + where + 'a: 'c, + 'b: 'c, + { + Ok(Some(input_schema.field(self.index).metadata())) + } + fn children(&self) -> Vec<&Arc> { vec![] } diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 469f7bbee317..3d5548aa2e83 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -379,6 +379,17 @@ impl PhysicalExpr for InListExpr { Ok(ColumnarValue::Array(Arc::new(r))) } + fn metadata<'a, 'b, 'c>( + &'a self, + _input_schema: &'b Schema, + ) -> Result>> + where + 'a: 'c, + 'b: 'c, + { + Ok(None) + } + fn children(&self) -> Vec<&Arc> { let mut children = vec![]; children.push(&self.expr); diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index 0619e7248858..4df2db47f2ac 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -17,9 +17,6 @@ //! 
IS NOT NULL expression -use std::hash::Hash; -use std::{any::Any, sync::Arc}; - use crate::PhysicalExpr; use arrow::{ datatypes::{DataType, Schema}, @@ -28,6 +25,9 @@ use arrow::{ use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::ColumnarValue; +use std::collections::HashMap; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; /// IS NOT NULL expression #[derive(Debug, Eq)] @@ -94,6 +94,17 @@ impl PhysicalExpr for IsNotNullExpr { } } + fn metadata<'a, 'b, 'c>( + &'a self, + input_schema: &'b Schema, + ) -> Result>> + where + 'a: 'c, + 'b: 'c, + { + self.arg.metadata(input_schema) + } + fn children(&self) -> Vec<&Arc> { vec![&self.arg] } diff --git a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index 4c6081f35cad..bdc14e3286b2 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -17,9 +17,6 @@ //! IS NULL expression -use std::hash::Hash; -use std::{any::Any, sync::Arc}; - use crate::PhysicalExpr; use arrow::{ datatypes::{DataType, Schema}, @@ -28,6 +25,9 @@ use arrow::{ use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::ColumnarValue; +use std::collections::HashMap; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; /// IS NULL expression #[derive(Debug, Eq)] @@ -93,6 +93,17 @@ impl PhysicalExpr for IsNullExpr { } } + fn metadata<'a, 'b, 'c>( + &'a self, + input_schema: &'b Schema, + ) -> Result>> + where + 'a: 'c, + 'b: 'c, + { + self.arg.metadata(input_schema) + } + fn children(&self) -> Vec<&Arc> { vec![&self.arg] } diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index ebf9882665ba..271366638659 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -15,15 +15,15 @@ // specific language governing permissions and limitations 
// under the License. -use std::hash::Hash; -use std::{any::Any, sync::Arc}; - use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; use datafusion_physical_expr_common::datum::apply_cmp; +use std::collections::HashMap; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; // Like expression #[derive(Debug, Eq)] @@ -130,6 +130,17 @@ impl PhysicalExpr for LikeExpr { } } + fn metadata<'a, 'b, 'c>( + &'a self, + _input_schema: &'b Schema, + ) -> Result>> + where + 'a: 'c, + 'b: 'c, + { + Ok(None) + } + fn children(&self) -> Vec<&Arc> { vec![&self.expr, &self.pattern] } diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index 0d0c0ecc62c7..c1bca4985646 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -18,6 +18,7 @@ //! Literal expressions for physical operations use std::any::Any; +use std::collections::HashMap; use std::hash::Hash; use std::sync::Arc; @@ -75,6 +76,17 @@ impl PhysicalExpr for Literal { Ok(ColumnarValue::Scalar(self.value.clone())) } + fn metadata<'a, 'b, 'c>( + &'a self, + _input_schema: &'b Schema, + ) -> Result>> + where + 'a: 'c, + 'b: 'c, + { + Ok(None) + } + fn children(&self) -> Vec<&Arc> { vec![] } diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index 33a1bae14d42..c5e80aef24d3 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -18,6 +18,7 @@ //! 
Negation (-) expression use std::any::Any; +use std::collections::HashMap; use std::hash::Hash; use std::sync::Arc; @@ -103,6 +104,17 @@ impl PhysicalExpr for NegativeExpr { } } + fn metadata<'a, 'b, 'c>( + &'a self, + input_schema: &'b Schema, + ) -> Result>> + where + 'a: 'c, + 'b: 'c, + { + self.arg.metadata(input_schema) + } + fn children(&self) -> Vec<&Arc> { vec![&self.arg] } diff --git a/datafusion/physical-expr/src/expressions/no_op.rs b/datafusion/physical-expr/src/expressions/no_op.rs index 24d2f4d9e074..d432c215c8d8 100644 --- a/datafusion/physical-expr/src/expressions/no_op.rs +++ b/datafusion/physical-expr/src/expressions/no_op.rs @@ -18,6 +18,7 @@ //! NoOp placeholder for physical operations use std::any::Any; +use std::collections::HashMap; use std::hash::Hash; use std::sync::Arc; @@ -67,6 +68,17 @@ impl PhysicalExpr for NoOp { internal_err!("NoOp::evaluate() should not be called") } + fn metadata<'a, 'b, 'c>( + &'a self, + _input_schema: &'b Schema, + ) -> Result>> + where + 'a: 'c, + 'b: 'c, + { + Ok(None) + } + fn children(&self) -> Vec<&Arc> { vec![] } diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index 8a3348b43d20..e136090407dc 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -18,6 +18,7 @@ //! 
Not expression use std::any::Any; +use std::collections::HashMap; use std::fmt; use std::hash::Hash; use std::sync::Arc; @@ -101,6 +102,17 @@ impl PhysicalExpr for NotExpr { } } + fn metadata<'a, 'b, 'c>( + &'a self, + input_schema: &'b Schema, + ) -> Result>> + where + 'a: 'c, + 'b: 'c, + { + self.arg.metadata(input_schema) + } + fn children(&self) -> Vec<&Arc> { vec![&self.arg] } diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index e49815cd8b64..07740ab797dd 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -16,6 +16,7 @@ // under the License. use std::any::Any; +use std::collections::HashMap; use std::fmt; use std::hash::Hash; use std::sync::Arc; @@ -110,6 +111,17 @@ impl PhysicalExpr for TryCastExpr { } } + fn metadata<'a, 'b, 'c>( + &'a self, + input_schema: &'b Schema, + ) -> Result>> + where + 'a: 'c, + 'b: 'c, + { + self.expr.metadata(input_schema) + } + fn children(&self) -> Vec<&Arc> { vec![&self.expr] } diff --git a/datafusion/physical-expr/src/expressions/unknown_column.rs b/datafusion/physical-expr/src/expressions/unknown_column.rs index 2face4eb6bdb..839a5cd6fca8 100644 --- a/datafusion/physical-expr/src/expressions/unknown_column.rs +++ b/datafusion/physical-expr/src/expressions/unknown_column.rs @@ -18,6 +18,7 @@ //! 
UnKnownColumn expression use std::any::Any; +use std::collections::HashMap; use std::hash::{Hash, Hasher}; use std::sync::Arc; @@ -76,6 +77,17 @@ impl PhysicalExpr for UnKnownColumn { internal_err!("UnKnownColumn::evaluate() should not be called") } + fn metadata<'a, 'b, 'c>( + &'a self, + _input_schema: &'b Schema, + ) -> Result>> + where + 'a: 'c, + 'b: 'c, + { + Ok(None) + } + fn children(&self) -> Vec<&Arc> { vec![] } diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 44bbcc4928c6..b222ce3df1fc 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -30,6 +30,7 @@ //! to a function that supports f64, it is coerced to f64. use std::any::Any; +use std::collections::HashMap; use std::fmt::{self, Debug, Formatter}; use std::hash::Hash; use std::sync::Arc; @@ -48,13 +49,35 @@ use datafusion_expr::{ }; /// Physical expression of a scalar function -#[derive(Eq, PartialEq, Hash)] +#[derive(Eq, PartialEq)] pub struct ScalarFunctionExpr { fun: Arc, name: String, args: Vec>, return_type: DataType, nullable: bool, + metadata: HashMap, +} + +impl Hash for ScalarFunctionExpr { + fn hash(&self, state: &mut H) { + // Sort keys for deterministic hashing + let mut keys: Vec<&String> = self.metadata.keys().collect(); + keys.sort(); + + for key in keys { + key.hash(state); + if let Some(value) = self.metadata.get(key) { + value.hash(state); + } + } + + self.fun.hash(state); + self.name.hash(state); + self.args.hash(state); + self.return_type.hash(state); + self.nullable.hash(state); + } } impl Debug for ScalarFunctionExpr { @@ -75,6 +98,7 @@ impl ScalarFunctionExpr { fun: Arc, args: Vec>, return_type: DataType, + metadata: HashMap, ) -> Self { Self { fun, @@ -82,6 +106,7 @@ impl ScalarFunctionExpr { args, return_type, nullable: true, + metadata, } } @@ -125,6 +150,7 @@ impl ScalarFunctionExpr { args, return_type, nullable, + metadata: HashMap::new(), }) } 
@@ -214,6 +240,17 @@ impl PhysicalExpr for ScalarFunctionExpr { Ok(output) } + fn metadata<'a, 'b, 'c>( + &'a self, + _input_schema: &'b Schema, + ) -> Result>> + where + 'a: 'c, + 'b: 'c, + { + Ok(Some(&self.metadata)) + } + fn children(&self) -> Vec<&Arc> { self.args.iter().collect() } @@ -228,6 +265,7 @@ impl PhysicalExpr for ScalarFunctionExpr { Arc::clone(&self.fun), children, self.return_type().clone(), + self.metadata.clone(), ) .with_nullable(self.nullable), )) diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index a886fc242545..ea824f0b035a 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -366,6 +366,7 @@ pub fn parse_physical_expr( scalar_fun_def, args, convert_required!(e.return_type)?, + std::collections::hash_map::HashMap::new(), ) .with_nullable(e.nullable), ) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index be90497a6e21..4b9efb92656b 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -16,6 +16,7 @@ // under the License. 
use std::any::Any; +use std::collections::HashMap; use std::fmt::{Display, Formatter}; use std::ops::Deref; use std::sync::Arc; @@ -864,6 +865,17 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { unreachable!() } + fn metadata<'a, 'b, 'c>( + &'a self, + _input_schema: &'b Schema, + ) -> Result>> + where + 'a: 'c, + 'b: 'c, + { + Ok(None) + } + fn children(&self) -> Vec<&Arc> { vec![&self.inner] } @@ -969,6 +981,7 @@ fn roundtrip_scalar_udf() -> Result<()> { fun_def, vec![col("a", &schema)?], DataType::Int64, + HashMap::default(), ); let project = @@ -1097,6 +1110,7 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { Arc::new(ScalarUDF::from(MyRegexUdf::new(".*".to_string()))), vec![col("text", &schema)?], DataType::Int64, + HashMap::default(), )); let filter = Arc::new(FilterExec::try_new( @@ -1199,6 +1213,7 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { Arc::new(ScalarUDF::from(MyRegexUdf::new(".*".to_string()))), vec![col("text", &schema)?], DataType::Int64, + HashMap::default(), )); let udaf = Arc::new(AggregateUDF::from(MyAggregateUDF::new( From 9bcd4b5898c2ea928e671a72ee3458269a8b651c Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 8 Apr 2025 16:12:53 -0400 Subject: [PATCH 02/25] Adding argument metadata to scalar argument struct --- datafusion/expr/src/udf.rs | 5 +- datafusion/ffi/src/udf/mod.rs | 53 +++++++++++++++++-- .../functions/benches/character_length.rs | 4 ++ datafusion/functions/benches/date_bin.rs | 1 + .../functions/src/core/union_extract.rs | 3 ++ datafusion/functions/src/core/version.rs | 1 + datafusion/functions/src/datetime/date_bin.rs | 15 ++++++ .../functions/src/datetime/date_trunc.rs | 2 + .../functions/src/datetime/from_unixtime.rs | 2 + .../functions/src/datetime/make_date.rs | 8 +++ datafusion/functions/src/datetime/to_char.rs | 6 +++ datafusion/functions/src/datetime/to_date.rs | 8 +++ .../functions/src/datetime/to_local_time.rs | 2 + .../functions/src/datetime/to_timestamp.rs | 2 
+ datafusion/functions/src/math/log.rs | 10 ++++ datafusion/functions/src/math/power.rs | 2 + datafusion/functions/src/math/signum.rs | 2 + datafusion/functions/src/regex/regexpcount.rs | 12 +++++ datafusion/functions/src/string/concat.rs | 1 + datafusion/functions/src/string/concat_ws.rs | 2 + datafusion/functions/src/string/contains.rs | 1 + datafusion/functions/src/string/lower.rs | 1 + datafusion/functions/src/string/upper.rs | 1 + .../functions/src/unicode/find_in_set.rs | 2 + datafusion/functions/src/utils.rs | 6 ++- .../physical-expr/src/scalar_function.rs | 7 +++ 26 files changed, 152 insertions(+), 7 deletions(-) diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 9b2400774a3d..70f91f88384f 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -26,6 +26,7 @@ use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue}; use datafusion_expr_common::interval_arithmetic::Interval; use std::any::Any; use std::cmp::Ordering; +use std::collections::HashMap; use std::fmt::Debug; use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; @@ -293,9 +294,11 @@ where /// Arguments passed to [`ScalarUDFImpl::invoke_with_args`] when invoking a /// scalar function. -pub struct ScalarFunctionArgs<'a> { +pub struct ScalarFunctionArgs<'a, 'b> { /// The evaluated arguments to the function pub args: Vec, + /// Metadata associated with each arg, if it exists + pub arg_metadata: Vec>>, /// The number of rows in record batch being evaluated pub number_rows: usize, /// The return type of the scalar function returned (from `return_type` or `return_type_from_args`) diff --git a/datafusion/ffi/src/udf/mod.rs b/datafusion/ffi/src/udf/mod.rs index 706b9fabedcb..e0350db1bd1b 100644 --- a/datafusion/ffi/src/udf/mod.rs +++ b/datafusion/ffi/src/udf/mod.rs @@ -15,8 +15,7 @@ // specific language governing permissions and limitations // under the License. 
-use std::{ffi::c_void, sync::Arc}; - +use abi_stable::std_types::{RHashMap, ROption}; use abi_stable::{ std_types::{RResult, RString, RVec}, StableAbi, @@ -43,6 +42,8 @@ use return_info::FFI_ReturnInfo; use return_type_args::{ FFI_ReturnTypeArgs, ForeignReturnTypeArgs, ForeignReturnTypeArgsOwned, }; +use std::collections::HashMap; +use std::{ffi::c_void, sync::Arc}; use crate::{ arrow_wrappers::{WrappedArray, WrappedSchema}, @@ -88,6 +89,7 @@ pub struct FFI_ScalarUDF { pub invoke_with_args: unsafe extern "C" fn( udf: &Self, args: RVec, + arg_metadata: RVec>>, num_rows: usize, return_type: WrappedSchema, ) -> RResult, @@ -174,6 +176,7 @@ unsafe extern "C" fn coerce_types_fn_wrapper( unsafe extern "C" fn invoke_with_args_fn_wrapper( udf: &FFI_ScalarUDF, args: RVec, + arg_metadata: RVec>>, number_rows: usize, return_type: WrappedSchema, ) -> RResult { @@ -191,8 +194,27 @@ unsafe extern "C" fn invoke_with_args_fn_wrapper( let args = rresult_return!(args); let return_type = rresult_return!(DataType::try_from(&return_type.0)); + let arg_metadata_owned: Vec>> = arg_metadata + .into_iter() + .map(|maybe_map| { + maybe_map + .map(|hashmap| { + hashmap + .into_iter() + .map(|kv| (String::from(kv.0), String::from(kv.1))) + .collect::>() + }) + .into() + }) + .collect(); + let arg_metadata = arg_metadata_owned + .iter() + .map(|maybe_map| maybe_map.as_ref()) + .collect::>(); + let args = ScalarFunctionArgs { args, + arg_metadata, number_rows, return_type: &return_type, }; @@ -329,6 +351,7 @@ impl ScalarUDFImpl for ForeignScalarUDF { fn invoke_with_args(&self, invoke_args: ScalarFunctionArgs) -> Result { let ScalarFunctionArgs { args, + arg_metadata, number_rows, return_type, } = invoke_args; @@ -347,10 +370,32 @@ impl ScalarUDFImpl for ForeignScalarUDF { .collect::, ArrowError>>()? 
.into(); + let arg_metadata = arg_metadata + .into_iter() + .map(|maybe_map| { + maybe_map + .map(|hashmap| { + hashmap + .into_iter() + .map(|(k, v)| { + (RString::from(k.clone()), RString::from(v.clone())) + }) + .collect::>() + }) + .into() + }) + .collect::>(); + let return_type = WrappedSchema(FFI_ArrowSchema::try_from(return_type)?); let result = unsafe { - (self.udf.invoke_with_args)(&self.udf, args, number_rows, return_type) + (self.udf.invoke_with_args)( + &self.udf, + args, + arg_metadata, + number_rows, + return_type, + ) }; let result = df_result!(result)?; @@ -389,7 +434,7 @@ mod tests { let foreign_udf: ForeignScalarUDF = (&local_udf).try_into()?; - assert!(original_udf.name() == foreign_udf.name()); + assert_eq!(original_udf.name(), foreign_udf.name()); Ok(()) } diff --git a/datafusion/functions/benches/character_length.rs b/datafusion/functions/benches/character_length.rs index bbcfed021064..d4c28e4153e1 100644 --- a/datafusion/functions/benches/character_length.rs +++ b/datafusion/functions/benches/character_length.rs @@ -40,6 +40,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_ascii.clone(), + arg_metadata: vec![None; args_string_ascii.len()], number_rows: n_rows, return_type: &return_type, })) @@ -55,6 +56,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_utf8.clone(), + arg_metadata: vec![None; args_string_utf8.len()], number_rows: n_rows, return_type: &return_type, })) @@ -70,6 +72,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_view_ascii.clone(), + arg_metadata: vec![None; args_string_view_ascii.len()], number_rows: n_rows, return_type: &return_type, })) @@ -85,6 +88,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { 
black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_view_utf8.clone(), + arg_metadata: vec![None; args_string_view_utf8.len()], number_rows: n_rows, return_type: &return_type, })) diff --git a/datafusion/functions/benches/date_bin.rs b/datafusion/functions/benches/date_bin.rs index 7ea5fdcb2be2..a30c29db2f31 100644 --- a/datafusion/functions/benches/date_bin.rs +++ b/datafusion/functions/benches/date_bin.rs @@ -53,6 +53,7 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![interval.clone(), timestamps.clone()], + arg_metadata: vec![None; 2], number_rows: batch_len, return_type: &return_type, }) diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs index 420eeed42cc3..1b12b3c42518 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs @@ -198,6 +198,7 @@ mod tests { )), ColumnarValue::Scalar(ScalarValue::new_utf8("str")), ], + arg_metadata: vec![None; 2], number_rows: 1, return_type: &DataType::Utf8, })?; @@ -213,6 +214,7 @@ mod tests { )), ColumnarValue::Scalar(ScalarValue::new_utf8("str")), ], + arg_metadata: vec![None; 2], number_rows: 1, return_type: &DataType::Utf8, })?; @@ -228,6 +230,7 @@ mod tests { )), ColumnarValue::Scalar(ScalarValue::new_utf8("str")), ], + arg_metadata: vec![None; 2], number_rows: 1, return_type: &DataType::Utf8, })?; diff --git a/datafusion/functions/src/core/version.rs b/datafusion/functions/src/core/version.rs index 34038022f2dc..4b3499eb0d2c 100644 --- a/datafusion/functions/src/core/version.rs +++ b/datafusion/functions/src/core/version.rs @@ -105,6 +105,7 @@ mod test { let version = version_udf .invoke_with_args(ScalarFunctionArgs { args: vec![], + arg_metadata: vec![], number_rows: 0, return_type: &DataType::Utf8, }) diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index 
5ffae46dde48..6000a3a14f10 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -526,6 +526,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], + arg_metadata: vec![None; 3], number_rows: 1, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -545,6 +546,7 @@ mod tests { ColumnarValue::Array(timestamps), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], + arg_metadata: vec![None; 3], number_rows: batch_len, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -561,6 +563,7 @@ mod tests { ))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], + arg_metadata: vec![None; 2], number_rows: 1, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -580,6 +583,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], + arg_metadata: vec![None; 3], number_rows: 1, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -598,6 +602,7 @@ mod tests { milliseconds: 1, }, )))], + arg_metadata: vec![None; 1], number_rows: 1, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -614,6 +619,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], + arg_metadata: vec![None; 3], number_rows: 1, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -636,6 +642,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], + arg_metadata: vec![None; 3], number_rows: 1, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -655,6 +662,7 @@ mod tests { 
ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], + arg_metadata: vec![None; 3], number_rows: 1, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -671,6 +679,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], + arg_metadata: vec![None; 3], number_rows: 1, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -687,6 +696,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], + arg_metadata: vec![None; 3], number_rows: 1, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -708,6 +718,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), ], + arg_metadata: vec![None; 3], number_rows: 1, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -728,6 +739,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], + arg_metadata: vec![None; 3], number_rows: 1, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -751,6 +763,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], + arg_metadata: vec![None; 3], number_rows: 1, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -774,6 +787,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Array(timestamps), ], + arg_metadata: vec![None; 3], number_rows: batch_len, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -902,6 
+916,7 @@ mod tests { tz_opt.clone(), )), ], + arg_metadata: vec![None; 3], number_rows: batch_len, return_type: &DataType::Timestamp( TimeUnit::Nanosecond, diff --git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs index ed3eb228bf03..82d89235eb71 100644 --- a/datafusion/functions/src/datetime/date_trunc.rs +++ b/datafusion/functions/src/datetime/date_trunc.rs @@ -731,6 +731,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::from("day")), ColumnarValue::Array(Arc::new(input)), ], + arg_metadata: vec![None; 2], number_rows: batch_len, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), }; @@ -893,6 +894,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::from("hour")), ColumnarValue::Array(Arc::new(input)), ], + arg_metadata: vec![None; 2], number_rows: batch_len, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), }; diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index ed8181452dbd..05812ceb43a2 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -172,6 +172,7 @@ mod test { fn test_without_timezone() { let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(Int64(Some(1729900800)))], + arg_metadata: vec![None; 1], number_rows: 1, return_type: &DataType::Timestamp(Second, None), }; @@ -194,6 +195,7 @@ mod test { "America/New_York".to_string(), ))), ], + arg_metadata: vec![None; 2], number_rows: 2, return_type: &DataType::Timestamp( Second, diff --git a/datafusion/functions/src/datetime/make_date.rs b/datafusion/functions/src/datetime/make_date.rs index 929fa601f107..9cea84a756d1 100644 --- a/datafusion/functions/src/datetime/make_date.rs +++ b/datafusion/functions/src/datetime/make_date.rs @@ -236,6 +236,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), 
ColumnarValue::Scalar(ScalarValue::UInt32(Some(14))), ], + arg_metadata: vec![None; 3], number_rows: 1, return_type: &DataType::Date32, }; @@ -255,6 +256,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::UInt64(Some(1))), ColumnarValue::Scalar(ScalarValue::UInt32(Some(14))), ], + arg_metadata: vec![None; 3], number_rows: 1, return_type: &DataType::Date32, }; @@ -274,6 +276,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("1".to_string()))), ColumnarValue::Scalar(ScalarValue::Utf8(Some("14".to_string()))), ], + arg_metadata: vec![None; 3], number_rows: 1, return_type: &DataType::Date32, }; @@ -297,6 +300,7 @@ mod tests { ColumnarValue::Array(months), ColumnarValue::Array(days), ], + arg_metadata: vec![None; 3], number_rows: batch_len, return_type: &DataType::Date32, }; @@ -323,6 +327,7 @@ mod tests { // invalid number of arguments let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))], + arg_metadata: vec![None; 1], number_rows: 1, return_type: &DataType::Date32, }; @@ -339,6 +344,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], + arg_metadata: vec![None; 3], number_rows: 1, return_type: &DataType::Date32, }; @@ -355,6 +361,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::UInt64(Some(u64::MAX))), ColumnarValue::Scalar(ScalarValue::Int32(Some(22))), ], + arg_metadata: vec![None; 3], number_rows: 1, return_type: &DataType::Date32, }; @@ -371,6 +378,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Int32(Some(22))), ColumnarValue::Scalar(ScalarValue::UInt32(Some(u32::MAX))), ], + arg_metadata: vec![None; 3], number_rows: 1, return_type: &DataType::Date32, }; diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index 8b2e5ad87471..62202fab171b 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ 
b/datafusion/functions/src/datetime/to_char.rs @@ -387,6 +387,7 @@ mod tests { for (value, format, expected) in scalar_data { let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(value), ColumnarValue::Scalar(format)], + arg_metadata: vec![None; 2], number_rows: 1, return_type: &DataType::Utf8, }; @@ -470,6 +471,7 @@ mod tests { ColumnarValue::Scalar(value), ColumnarValue::Array(Arc::new(format) as ArrayRef), ], + arg_metadata: vec![None; 2], number_rows: batch_len, return_type: &DataType::Utf8, }; @@ -601,6 +603,7 @@ mod tests { ColumnarValue::Array(value as ArrayRef), ColumnarValue::Scalar(format), ], + arg_metadata: vec![None; 2], number_rows: batch_len, return_type: &DataType::Utf8, }; @@ -623,6 +626,7 @@ mod tests { ColumnarValue::Array(value), ColumnarValue::Array(Arc::new(format) as ArrayRef), ], + arg_metadata: vec![None; 2], number_rows: batch_len, return_type: &DataType::Utf8, }; @@ -645,6 +649,7 @@ mod tests { // invalid number of arguments let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))], + arg_metadata: vec![None; 1], number_rows: 1, return_type: &DataType::Utf8, }; @@ -660,6 +665,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Int32(Some(1))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], + arg_metadata: vec![None; 2], number_rows: 1, return_type: &DataType::Utf8, }; diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs index 91740b2c31c1..f058217c179f 100644 --- a/datafusion/functions/src/datetime/to_date.rs +++ b/datafusion/functions/src/datetime/to_date.rs @@ -210,6 +210,7 @@ mod tests { fn test_scalar(sv: ScalarValue, tc: &TestCase) { let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(sv)], + arg_metadata: vec![None; 1], number_rows: 1, return_type: &DataType::Date32, }; @@ -236,6 +237,7 @@ mod tests { let batch_len = date_array.len(); let args 
= datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Array(Arc::new(date_array))], + arg_metadata: vec![None; 1], number_rows: batch_len, return_type: &DataType::Date32, }; @@ -333,6 +335,7 @@ mod tests { ColumnarValue::Scalar(sv), ColumnarValue::Scalar(format_scalar), ], + arg_metadata: vec![None; 2], number_rows: 1, return_type: &DataType::Date32, }; @@ -363,6 +366,7 @@ mod tests { ColumnarValue::Array(Arc::new(date_array)), ColumnarValue::Array(Arc::new(format_array)), ], + arg_metadata: vec![None; 2], number_rows: batch_len, return_type: &DataType::Date32, }; @@ -404,6 +408,7 @@ mod tests { ColumnarValue::Scalar(format1_scalar), ColumnarValue::Scalar(format2_scalar), ], + arg_metadata: vec![None; 3], number_rows: 1, return_type: &DataType::Date32, }; @@ -433,6 +438,7 @@ mod tests { let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(formatted_date_scalar)], + arg_metadata: vec![None; 1], number_rows: 1, return_type: &DataType::Date32, }; @@ -455,6 +461,7 @@ mod tests { let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(date_scalar)], + arg_metadata: vec![None; 1], number_rows: 1, return_type: &DataType::Date32, }; @@ -480,6 +487,7 @@ mod tests { let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(date_scalar)], + arg_metadata: vec![None; 1], number_rows: 1, return_type: &DataType::Date32, }; diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index 8dbef90cdc3f..8eab0ad9d293 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -541,6 +541,7 @@ mod tests { let res = ToLocalTimeFunc::new() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(input)], + arg_metadata: vec![None; 1], number_rows: 1, return_type: &expected.data_type(), }) @@ -604,6 +605,7 @@ mod tests { let batch_size = input.len(); let args = 
ScalarFunctionArgs { args: vec![ColumnarValue::Array(Arc::new(input))], + arg_metadata: vec![None; 1], number_rows: batch_size, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index 52c86733f332..55d6d0944ce0 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -1015,6 +1015,7 @@ mod tests { assert!(matches!(rt, Timestamp(_, Some(_)))); let args = datafusion_expr::ScalarFunctionArgs { args: vec![array.clone()], + arg_metadata: vec![None; 1], number_rows: 4, return_type: &rt, }; @@ -1062,6 +1063,7 @@ mod tests { assert!(matches!(rt, Timestamp(_, None))); let args = datafusion_expr::ScalarFunctionArgs { args: vec![array.clone()], + arg_metadata: vec![None; 1], number_rows: 5, return_type: &rt, }; diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index fd135f4c5ec0..d666a7778037 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -271,6 +271,7 @@ mod tests { ]))), // num ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))), ], + arg_metadata: vec![None; 2], number_rows: 4, return_type: &DataType::Float64, }; @@ -283,6 +284,7 @@ mod tests { args: vec![ ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num ], + arg_metadata: vec![None; 1], number_rows: 1, return_type: &DataType::Float64, }; @@ -297,6 +299,7 @@ mod tests { args: vec![ ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num ], + arg_metadata: vec![None; 1], number_rows: 1, return_type: &DataType::Float32, }; @@ -324,6 +327,7 @@ mod tests { args: vec![ ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num ], + arg_metadata: vec![None; 1], number_rows: 1, return_type: &DataType::Float64, }; @@ -352,6 +356,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // 
num ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num ], + arg_metadata: vec![None; 2], number_rows: 1, return_type: &DataType::Float32, }; @@ -380,6 +385,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // num ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num ], + arg_metadata: vec![None; 2], number_rows: 1, return_type: &DataType::Float64, }; @@ -409,6 +415,7 @@ mod tests { 10.0, 100.0, 1000.0, 10000.0, ]))), // num ], + arg_metadata: vec![None; 1], number_rows: 4, return_type: &DataType::Float64, }; @@ -441,6 +448,7 @@ mod tests { 10.0, 100.0, 1000.0, 10000.0, ]))), // num ], + arg_metadata: vec![None; 1], number_rows: 4, return_type: &DataType::Float32, }; @@ -476,6 +484,7 @@ mod tests { 8.0, 4.0, 81.0, 625.0, ]))), // num ], + arg_metadata: vec![None; 2], number_rows: 4, return_type: &DataType::Float64, }; @@ -511,6 +520,7 @@ mod tests { 8.0, 4.0, 81.0, 625.0, ]))), // num ], + arg_metadata: vec![None; 2], number_rows: 4, return_type: &DataType::Float32, }; diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 028ec2fef793..a23097040817 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -202,6 +202,7 @@ mod tests { 3.0, 2.0, 4.0, 4.0, ]))), // exponent ], + arg_metadata: vec![None; 2], number_rows: 4, return_type: &DataType::Float64, }; @@ -232,6 +233,7 @@ mod tests { ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent ], + arg_metadata: vec![None; 2], number_rows: 4, return_type: &DataType::Int64, }; diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs index ba5422afa768..a2b69bb9ce65 100644 --- a/datafusion/functions/src/math/signum.rs +++ b/datafusion/functions/src/math/signum.rs @@ -159,6 +159,7 @@ mod test { ])); let args = ScalarFunctionArgs { args: 
vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], + arg_metadata: vec![None; 1], number_rows: array.len(), return_type: &DataType::Float32, }; @@ -203,6 +204,7 @@ mod test { ])); let args = ScalarFunctionArgs { args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], + arg_metadata: vec![None; 1], number_rows: array.len(), return_type: &DataType::Float64, }; diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index 8cb1a4ff3d60..955cb2b2a748 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -659,6 +659,7 @@ mod tests { let expected = expected.get(pos).cloned(); let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], + arg_metadata: vec![None; 2], number_rows: 2, return_type: &Int64, }); @@ -674,6 +675,7 @@ mod tests { let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], + arg_metadata: vec![None; 2], number_rows: 2, return_type: &Int64, }); @@ -689,6 +691,7 @@ mod tests { let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], + arg_metadata: vec![None; 2], number_rows: 2, return_type: &Int64, }); @@ -719,6 +722,7 @@ mod tests { ColumnarValue::Scalar(regex_sv), ColumnarValue::Scalar(start_sv.clone()), ], + arg_metadata: vec![None; 3], number_rows: 3, return_type: &Int64, }); @@ -738,6 +742,7 @@ mod tests { ColumnarValue::Scalar(regex_sv), ColumnarValue::Scalar(start_sv.clone()), ], + arg_metadata: vec![None; 3], number_rows: 3, return_type: &Int64, }); @@ -757,6 +762,7 @@ mod tests { ColumnarValue::Scalar(regex_sv), 
ColumnarValue::Scalar(start_sv.clone()), ], + arg_metadata: vec![None; 3], number_rows: 3, return_type: &Int64, }); @@ -790,6 +796,7 @@ mod tests { ColumnarValue::Scalar(start_sv.clone()), ColumnarValue::Scalar(flags_sv.clone()), ], + arg_metadata: vec![None; 4], number_rows: 4, return_type: &Int64, }); @@ -811,6 +818,7 @@ mod tests { ColumnarValue::Scalar(start_sv.clone()), ColumnarValue::Scalar(flags_sv.clone()), ], + arg_metadata: vec![None; 4], number_rows: 4, return_type: &Int64, }); @@ -832,6 +840,7 @@ mod tests { ColumnarValue::Scalar(start_sv.clone()), ColumnarValue::Scalar(flags_sv.clone()), ], + arg_metadata: vec![None; 4], number_rows: 4, return_type: &Int64, }); @@ -914,6 +923,7 @@ mod tests { ColumnarValue::Scalar(start_sv.clone()), ColumnarValue::Scalar(flags_sv.clone()), ], + arg_metadata: vec![None; 4], number_rows: 4, return_type: &Int64, }); @@ -935,6 +945,7 @@ mod tests { ColumnarValue::Scalar(start_sv.clone()), ColumnarValue::Scalar(flags_sv.clone()), ], + arg_metadata: vec![None; 4], number_rows: 4, return_type: &Int64, }); @@ -956,6 +967,7 @@ mod tests { ColumnarValue::Scalar(start_sv.clone()), ColumnarValue::Scalar(flags_sv.clone()), ], + arg_metadata: vec![None; 4], number_rows: 4, return_type: &Int64, }); diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index c47d08d579e4..4f05752b820b 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -471,6 +471,7 @@ mod tests { let args = ScalarFunctionArgs { args: vec![c0, c1, c2, c3, c4], + arg_metadata: vec![None; 5], number_rows: 3, return_type: &Utf8, }; diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index c2bad206db15..0c00db841acc 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -483,6 +483,7 @@ mod tests { let args = ScalarFunctionArgs { args: vec![c0, c1, c2], + 
arg_metadata: vec![None; 3], number_rows: 3, return_type: &Utf8, }; @@ -513,6 +514,7 @@ mod tests { let args = ScalarFunctionArgs { args: vec![c0, c1, c2], + arg_metadata: vec![None; 3], number_rows: 3, return_type: &Utf8, }; diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index 05a3edf61c5a..4e90c216b9f3 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -167,6 +167,7 @@ mod test { let args = ScalarFunctionArgs { args: vec![array, scalar], + arg_metadata: vec![None; 2], number_rows: 2, return_type: &DataType::Boolean, }; diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index 226275b13999..fe7af7bd2d56 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -106,6 +106,7 @@ mod tests { let args = ScalarFunctionArgs { number_rows: input.len(), args: vec![ColumnarValue::Array(input)], + arg_metadata: vec![None; 1], return_type: &DataType::Utf8, }; diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index 2fec7305d183..5a36ad6cbc85 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -105,6 +105,7 @@ mod tests { let args = ScalarFunctionArgs { number_rows: input.len(), args: vec![ColumnarValue::Array(input)], + arg_metadata: vec![None; 1], return_type: &DataType::Utf8, }; diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs index c4a9f067e9f4..3ca81f017e28 100644 --- a/datafusion/functions/src/unicode/find_in_set.rs +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -471,8 +471,10 @@ mod tests { }) .unwrap_or(1); let return_type = fis.return_type(&type_array)?; + let arg_metadata = vec![None; args.len()]; let result = fis.invoke_with_args(ScalarFunctionArgs { args, + arg_metadata, number_rows: 
cardinality, return_type: &return_type, }); diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index 47f3121ba2ce..65d01d218630 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -165,7 +165,8 @@ pub mod test { let (return_type, _nullable) = return_info.unwrap().into_parts(); assert_eq!(return_type, $EXPECTED_DATA_TYPE); - let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_type: &return_type}); + let arg_metadata = vec![None; $ARGS.len()]; + let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, arg_metadata, number_rows: cardinality, return_type: &return_type}); assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err()); let result = result.unwrap().to_array(cardinality).expect("Failed to convert to array"); @@ -188,8 +189,9 @@ pub mod test { else { let (return_type, _nullable) = return_info.unwrap().into_parts(); + let arg_metadata = vec![None; $ARGS.len()]; // invoke is expected error - cannot use .expect_err() due to Debug not being implemented - match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_type: &return_type}) { + match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, arg_metadata, number_rows: cardinality, return_type: &return_type}) { Ok(_) => assert!(false, "expected error"), Err(error) => { assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace())); diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index b222ce3df1fc..01746e29d1bb 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -211,6 +211,12 @@ impl PhysicalExpr for ScalarFunctionExpr { .map(|e| e.evaluate(batch)) .collect::>>()?; + let arg_metadata = self + .args + .iter() + 
.map(|e| e.metadata(batch.schema_ref())) + .collect::>>()?; + let input_empty = args.is_empty(); let input_all_scalar = args .iter() @@ -219,6 +225,7 @@ impl PhysicalExpr for ScalarFunctionExpr { // evaluate the function let output = self.fun.invoke_with_args(ScalarFunctionArgs { args, + arg_metadata, number_rows: batch.num_rows(), return_type: &self.return_type, })?; From ea561b05ba4b950712462b03f9c22ca89c6c8db3 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 8 Apr 2025 18:57:58 -0400 Subject: [PATCH 03/25] Since everywhere we use this we immediately clone, go ahead and return an owned version of the metadata for simplicity --- .../user_defined_scalar_functions.rs | 136 +++++++++++++++++- datafusion/expr/src/udf.rs | 6 +- .../physical-expr-common/src/physical_expr.rs | 18 +-- .../physical-expr-common/src/sort_expr.rs | 2 + .../physical-expr/src/expressions/binary.rs | 11 +- .../physical-expr/src/expressions/case.rs | 11 +- .../physical-expr/src/expressions/cast.rs | 11 +- .../physical-expr/src/expressions/column.rs | 13 +- .../physical-expr/src/expressions/in_list.rs | 11 +- .../src/expressions/is_not_null.rs | 11 +- .../physical-expr/src/expressions/is_null.rs | 11 +- .../physical-expr/src/expressions/like.rs | 11 +- .../physical-expr/src/expressions/literal.rs | 11 +- .../physical-expr/src/expressions/negative.rs | 11 +- .../physical-expr/src/expressions/no_op.rs | 11 +- .../physical-expr/src/expressions/not.rs | 11 +- .../physical-expr/src/expressions/try_cast.rs | 11 +- .../src/expressions/unknown_column.rs | 11 +- .../physical-expr/src/scalar_function.rs | 19 +-- .../physical-plan/src/aggregates/mod.rs | 3 +- datafusion/physical-plan/src/projection.rs | 22 +-- .../tests/cases/roundtrip_physical_plan.rs | 11 +- 22 files changed, 221 insertions(+), 152 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 264bd6b66a60..bd3a0040e9b9 100644
--- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -15,11 +15,12 @@ // specific language governing permissions and limitations // under the License. +use std::collections::HashMap; use std::any::Any; use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; -use arrow::array::as_string_array; +use arrow::array::{as_string_array, record_batch, UInt64Array}; use arrow::array::{ builder::BooleanBuilder, cast::AsArray, Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, StringArray, @@ -35,7 +36,7 @@ use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::utils::take_function_args; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, not_impl_err, - plan_err, DFSchema, DataFusionError, HashMap, Result, ScalarValue, + plan_err, DFSchema, DataFusionError, Result, ScalarValue, }; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ @@ -1367,3 +1368,134 @@ async fn register_alltypes_parquet(ctx: &SessionContext) -> Result<()> { async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result> { ctx.sql(sql).await?.collect().await } + +#[derive(Debug)] +struct MetadataBasedUdf { + name: String, + signature: Signature, + output_metadata: HashMap, +} + +impl MetadataBasedUdf { + fn new(metadata: HashMap) -> Self { + // The name we return must be unique. Otherwise we will not call distinct + // instances of this UDF. This is a small hack for the unit tests to get unique + // names, but you could do something more elegant with the metadata. 
+ let name = format!("metadata_based_udf_{}", metadata.len()); + Self { + name, + signature: Signature::exact(vec![DataType::UInt64], Volatility::Immutable), + output_metadata: metadata, + } + } +} + +impl ScalarUDFImpl for MetadataBasedUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _args: &[DataType]) -> Result { + Ok(DataType::UInt64) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + assert_eq!(args.arg_metadata.len(), 1); + let should_double = match &args.arg_metadata[0] { + Some(hashmap) => hashmap.get("modify_values").map(|v| v == "double_output").unwrap_or(false), + None => false + }; + let mulitplier = if should_double { 2 } else { 1 }; + + match &args.args[0] { + ColumnarValue::Array(array) => { + let array_values: Vec<_> = array.as_any().downcast_ref::().unwrap().iter().map(|v| v.map(|x| x * mulitplier)).collect(); + let array_ref = Arc::new(UInt64Array::from(array_values)) as ArrayRef; + Ok(ColumnarValue::Array(array_ref)) + } + ColumnarValue::Scalar(value) => { + let ScalarValue::UInt64(value) = value else { + return exec_err!("incorrect data type") + }; + + Ok(ColumnarValue::Scalar(ScalarValue::UInt64(value.map(|v| v * mulitplier)))) + } + } + } + + fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + self.name == other.name() + } + + fn metadata(&self, _input_schema: &Schema) -> Option> { + Some(self.output_metadata.clone()) + } +} + + +#[tokio::test] +async fn test_metadata_based_udf() -> Result<()> { + let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef; + let schema = Arc::new(Schema::new(vec![ + Field::new("no_metadata", DataType::UInt64, true), + Field::new("with_metadata", DataType::UInt64, true).with_metadata([("modify_values".to_string(), "double_output".to_string())].into_iter().collect()), + ])); + let batch = RecordBatch::try_new(schema, 
vec![Arc::clone(&data_array), Arc::clone(&data_array)])?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let t = ctx.table("t").await?; + let no_output_meta_udf = ScalarUDF::from(MetadataBasedUdf::new(HashMap::new())); + let with_output_meta_udf = ScalarUDF::from(MetadataBasedUdf::new([("output_metatype".to_string(), "custom_value".to_string())].into_iter().collect())); + + let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?) + .project(vec![ + no_output_meta_udf + .call(vec![col("no_metadata")]) + .alias("meta_no_in_no_out"), + no_output_meta_udf + .call(vec![col("with_metadata")]) + .alias("meta_with_in_no_out"), + with_output_meta_udf + .call(vec![col("no_metadata")]) + .alias("meta_no_in_with_out"), + with_output_meta_udf + .call(vec![col("with_metadata")]) + .alias("meta_with_in_with_out"), + ] + )? + .build()?; + + let actual = DataFrame::new(ctx.state(), plan).collect().await?; + + // To test for output metadata handling, we set the expected values on the result + // To test for input metadata handling, we check the values returned + let mut output_meta = HashMap::new(); + let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string()); + let expected_schema = Schema::new(vec![ + Field::new("meta_no_in_no_out", DataType::UInt64, true), + Field::new("meta_with_in_no_out", DataType::UInt64, true), + Field::new("meta_no_in_with_out", DataType::UInt64, true).with_metadata(output_meta.clone()), + Field::new("meta_with_in_with_out", DataType::UInt64, true).with_metadata(output_meta.clone()) + ]); + + let expected = record_batch!( + ("meta_no_in_no_out", UInt64, [0, 5, 10, 15, 20]), + ("meta_with_in_no_out", UInt64, [0, 10, 20, 30, 40]), + ("meta_no_in_with_out", UInt64, [0, 5, 10, 15, 20]), + ("meta_with_in_with_out", UInt64, [0, 10, 20, 30, 40]) + )?.with_schema(Arc::new(expected_schema))?; + + assert_eq!(expected, actual[0]); + + ctx.deregister_table("t")?; + Ok(()) +} diff --git 
a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 70f91f88384f..7588379c0ff9 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -21,7 +21,7 @@ use crate::expr::schema_name_from_exprs_comma_separated_without_space; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::sort_properties::{ExprProperties, SortProperties}; use crate::{ColumnarValue, Documentation, Expr, Signature}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Schema}; use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue}; use datafusion_expr_common::interval_arithmetic::Interval; use std::any::Any; @@ -722,6 +722,10 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { fn documentation(&self) -> Option<&Documentation> { None } + + /// This describes the output metadata associated with this UDF. + /// Input field metadata is handled through `ScalarFunctionArgs` + fn metadata(&self, _input_schema: &Schema) -> Option> { None } } /// ScalarUDF that adds an alias to the underlying function. It is better to diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 4ad83553fbcc..537c8ffa4852 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -78,16 +78,10 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash { /// Evaluate an expression against a RecordBatch fn evaluate(&self, batch: &RecordBatch) -> Result; /// Determine if this expression has any associated field metadata in the schema - /// In some circumstances we will get the metadata from the schema, and sometimes - /// we will get it from the physical expression itself. The lifetime of the result - /// must outlive both. 
- fn metadata<'a, 'b, 'c>( - &'a self, - input_schema: &'b Schema, - ) -> Result>> - where - 'a: 'c, - 'b: 'c; + fn metadata( + &self, + input_schema: &Schema, + ) -> Result>>; /// Evaluate an expression against a RecordBatch after first applying a /// validity array fn evaluate_selection( @@ -474,12 +468,12 @@ where /// # use datafusion_expr_common::columnar_value::ColumnarValue; /// # use datafusion_physical_expr_common::physical_expr::{fmt_sql, DynEq, PhysicalExpr}; /// # #[derive(Debug, Hash, PartialOrd, PartialEq)] -/// # struct MyExpr {}; +/// # struct MyExpr {} /// # impl PhysicalExpr for MyExpr {fn as_any(&self) -> &dyn Any { unimplemented!() } /// # fn data_type(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn nullable(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn evaluate(&self, batch: &RecordBatch) -> Result { unimplemented!() } -/// # fn metadata<'a>(&self, input_schema: &'a Schema) -> Result<&'a HashMap> { unimplemented!() } +/// # fn metadata(&self, input_schema: &Schema) -> Result>> { unimplemented!() } /// # fn children(&self) -> Vec<&Arc>{ unimplemented!() } /// # fn with_new_children(self: Arc, children: Vec>) -> Result> { unimplemented!() } /// # fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "CASE a > b THEN 1 ELSE 0 END") } diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs index 3a54b5b40399..0fa1f53f31fb 100644 --- a/datafusion/physical-expr-common/src/sort_expr.rs +++ b/datafusion/physical-expr-common/src/sort_expr.rs @@ -37,6 +37,7 @@ use itertools::Itertools; /// Example: /// ``` /// # use std::any::Any; +/// # use std::collections::HashMap; /// # use std::fmt::{Display, Formatter}; /// # use std::hash::Hasher; /// # use std::sync::Arc; @@ -56,6 +57,7 @@ use itertools::Itertools; /// # fn data_type(&self, input_schema: &Schema) -> Result {todo!()} /// # fn nullable(&self, input_schema: &Schema) -> Result 
{todo!() } /// # fn evaluate(&self, batch: &RecordBatch) -> Result {todo!() } +/// # fn metadata(&self, input_schema: &Schema) -> Result>> { unimplemented!() } /// # fn children(&self) -> Vec<&Arc> {todo!()} /// # fn with_new_children(self: Arc, children: Vec>) -> Result> {todo!()} /// # fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { todo!() } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 80ef2c22b79d..bd97cdb24e89 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -433,13 +433,10 @@ impl PhysicalExpr for BinaryExpr { .map(ColumnarValue::Array) } - fn metadata<'a, 'b, 'c>( - &'a self, - _input_schema: &'b Schema, - ) -> Result>> - where - 'a: 'c, - 'b: 'c, + fn metadata( + &self, + _input_schema: &Schema, + ) -> Result>> { Ok(None) } diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 77df48280e6d..1cac5a5bb32e 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -514,13 +514,10 @@ impl PhysicalExpr for CaseExpr { } } - fn metadata<'a, 'b, 'c>( - &'a self, - _input_schema: &'b Schema, - ) -> Result>> - where - 'a: 'c, - 'b: 'c, + fn metadata( + &self, + _input_schema: &Schema, + ) -> Result>> { Ok(None) } diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 508318b1b21a..547f580c1522 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -145,13 +145,10 @@ impl PhysicalExpr for CastExpr { value.cast_to(&self.cast_type, Some(&self.cast_options)) } - fn metadata<'a, 'b, 'c>( - &'a self, - input_schema: &'b Schema, - ) -> Result>> - where - 'a: 'c, - 'b: 'c, + fn metadata( + &self, + input_schema: &Schema, + ) -> Result>> { 
self.expr.metadata(input_schema) } diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index 8f2e015f734e..db77a2995754 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -128,15 +128,12 @@ impl PhysicalExpr for Column { Ok(ColumnarValue::Array(Arc::clone(batch.column(self.index)))) } - fn metadata<'a, 'b, 'c>( - &'a self, - input_schema: &'b Schema, - ) -> Result>> - where - 'a: 'c, - 'b: 'c, + fn metadata( + &self, + input_schema: &Schema, + ) -> Result>> { - Ok(Some(input_schema.field(self.index).metadata())) + Ok(Some(input_schema.field(self.index).metadata().clone())) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 3d5548aa2e83..d9e0d120b9b2 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -379,13 +379,10 @@ impl PhysicalExpr for InListExpr { Ok(ColumnarValue::Array(Arc::new(r))) } - fn metadata<'a, 'b, 'c>( - &'a self, - _input_schema: &'b Schema, - ) -> Result>> - where - 'a: 'c, - 'b: 'c, + fn metadata( + &self, + _input_schema: &Schema, + ) -> Result>> { Ok(None) } diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index 4df2db47f2ac..70ea1f938ba4 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -94,13 +94,10 @@ impl PhysicalExpr for IsNotNullExpr { } } - fn metadata<'a, 'b, 'c>( - &'a self, - input_schema: &'b Schema, - ) -> Result>> - where - 'a: 'c, - 'b: 'c, + fn metadata( + &self, + input_schema: &Schema, + ) -> Result>> { self.arg.metadata(input_schema) } diff --git a/datafusion/physical-expr/src/expressions/is_null.rs 
b/datafusion/physical-expr/src/expressions/is_null.rs index bdc14e3286b2..d6dc26990eec 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -93,13 +93,10 @@ impl PhysicalExpr for IsNullExpr { } } - fn metadata<'a, 'b, 'c>( - &'a self, - input_schema: &'b Schema, - ) -> Result>> - where - 'a: 'c, - 'b: 'c, + fn metadata( + &self, + input_schema: &Schema, + ) -> Result>> { self.arg.metadata(input_schema) } diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index 271366638659..2ba4eec221b9 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -130,13 +130,10 @@ impl PhysicalExpr for LikeExpr { } } - fn metadata<'a, 'b, 'c>( - &'a self, - _input_schema: &'b Schema, - ) -> Result>> - where - 'a: 'c, - 'b: 'c, + fn metadata( + &self, + _input_schema: &Schema, + ) -> Result>> { Ok(None) } diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index c1bca4985646..26ff71d254f1 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -76,13 +76,10 @@ impl PhysicalExpr for Literal { Ok(ColumnarValue::Scalar(self.value.clone())) } - fn metadata<'a, 'b, 'c>( - &'a self, - _input_schema: &'b Schema, - ) -> Result>> - where - 'a: 'c, - 'b: 'c, + fn metadata( + &self, + _input_schema: &Schema, + ) -> Result>> { Ok(None) } diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index c5e80aef24d3..cf198b7c227f 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -104,13 +104,10 @@ impl PhysicalExpr for NegativeExpr { } } - fn metadata<'a, 'b, 'c>( - &'a self, - input_schema: &'b Schema, - ) -> Result>> - where - 'a: 
'c, - 'b: 'c, + fn metadata( + &self, + input_schema: &Schema, + ) -> Result>> { self.arg.metadata(input_schema) } diff --git a/datafusion/physical-expr/src/expressions/no_op.rs b/datafusion/physical-expr/src/expressions/no_op.rs index d432c215c8d8..2b70f797da7c 100644 --- a/datafusion/physical-expr/src/expressions/no_op.rs +++ b/datafusion/physical-expr/src/expressions/no_op.rs @@ -68,13 +68,10 @@ impl PhysicalExpr for NoOp { internal_err!("NoOp::evaluate() should not be called") } - fn metadata<'a, 'b, 'c>( - &'a self, - _input_schema: &'b Schema, - ) -> Result>> - where - 'a: 'c, - 'b: 'c, + fn metadata( + &self, + _input_schema: &Schema, + ) -> Result>> { Ok(None) } diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index e136090407dc..23142966147b 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -102,13 +102,10 @@ impl PhysicalExpr for NotExpr { } } - fn metadata<'a, 'b, 'c>( - &'a self, - input_schema: &'b Schema, - ) -> Result>> - where - 'a: 'c, - 'b: 'c, + fn metadata( + &self, + input_schema: &Schema, + ) -> Result>> { self.arg.metadata(input_schema) } diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index 07740ab797dd..a4ea09d56a7d 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -111,13 +111,10 @@ impl PhysicalExpr for TryCastExpr { } } - fn metadata<'a, 'b, 'c>( - &'a self, - input_schema: &'b Schema, - ) -> Result>> - where - 'a: 'c, - 'b: 'c, + fn metadata( + &self, + input_schema: &Schema, + ) -> Result>> { self.expr.metadata(input_schema) } diff --git a/datafusion/physical-expr/src/expressions/unknown_column.rs b/datafusion/physical-expr/src/expressions/unknown_column.rs index 839a5cd6fca8..6d3905f88430 100644 --- a/datafusion/physical-expr/src/expressions/unknown_column.rs 
+++ b/datafusion/physical-expr/src/expressions/unknown_column.rs @@ -77,13 +77,10 @@ impl PhysicalExpr for UnKnownColumn { internal_err!("UnKnownColumn::evaluate() should not be called") } - fn metadata<'a, 'b, 'c>( - &'a self, - _input_schema: &'b Schema, - ) -> Result>> - where - 'a: 'c, - 'b: 'c, + fn metadata( + &self, + _input_schema: &Schema, + ) -> Result>> { Ok(None) } diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 01746e29d1bb..deab423683d0 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -211,11 +211,15 @@ impl PhysicalExpr for ScalarFunctionExpr { .map(|e| e.evaluate(batch)) .collect::>>()?; - let arg_metadata = self + let arg_metadata_owned = self .args .iter() .map(|e| e.metadata(batch.schema_ref())) .collect::>>()?; + let arg_metadata = arg_metadata_owned + .iter() + .map(|opt_map| opt_map.as_ref()) + .collect::>(); let input_empty = args.is_empty(); let input_all_scalar = args @@ -247,15 +251,12 @@ impl PhysicalExpr for ScalarFunctionExpr { Ok(output) } - fn metadata<'a, 'b, 'c>( - &'a self, - _input_schema: &'b Schema, - ) -> Result>> - where - 'a: 'c, - 'b: 'c, + fn metadata( + &self, + input_schema: &Schema, + ) -> Result>> { - Ok(Some(&self.metadata)) + Ok(self.fun.as_ref().inner().metadata(input_schema)) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 8906468f68db..dec84a1634f7 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -27,7 +27,6 @@ use crate::aggregates::{ }; use crate::execution_plan::{CardinalityEffect, EmissionType}; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use crate::projection::get_field_metadata; use crate::windows::get_ordered_partition_by_indices; use crate::{ DisplayFormatType, Distribution, ExecutionPlan, 
InputOrderMode, @@ -286,7 +285,7 @@ impl PhysicalGroupBy { group_expr_nullable || expr.nullable(input_schema)?, ) .with_metadata( - get_field_metadata(expr, input_schema).unwrap_or_default(), + expr.metadata(input_schema)?.clone().unwrap_or_default(), ), ); } diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index 72934c74446e..358e72025b72 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -26,7 +26,7 @@ use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use super::expressions::{CastExpr, Column, Literal}; +use super::expressions::{Column, Literal}; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use super::{ DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream, @@ -85,7 +85,7 @@ impl ProjectionExec { e.nullable(&input_schema)?, ); field.set_metadata( - get_field_metadata(e, &input_schema).unwrap_or_default(), + e.metadata(&input_schema)?.clone().unwrap_or_default(), ); Ok(field) @@ -273,24 +273,6 @@ impl ExecutionPlan for ProjectionExec { } } -/// If 'e' is a direct column reference, returns the field level -/// metadata for that field, if any. 
Otherwise returns None -pub(crate) fn get_field_metadata( - e: &Arc, - input_schema: &Schema, -) -> Option> { - if let Some(cast) = e.as_any().downcast_ref::() { - return get_field_metadata(cast.expr(), input_schema); - } - - // Look up field by index in schema (not NAME as there can be more than one - // column with the same name) - e.as_any() - .downcast_ref::() - .map(|column| input_schema.field(column.index()).metadata()) - .cloned() -} - fn stats_projection( mut stats: Statistics, exprs: impl Iterator>, diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 4b9efb92656b..291f76d0689a 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -865,13 +865,10 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { unreachable!() } - fn metadata<'a, 'b, 'c>( - &'a self, - _input_schema: &'b Schema, - ) -> Result>> - where - 'a: 'c, - 'b: 'c, + fn metadata( + &self, + _input_schema: &Schema, + ) -> Result>> { Ok(None) } From 8daa3561652fd0d468c0c45e45055c2e8eb1ae27 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 8 Apr 2025 19:11:05 -0400 Subject: [PATCH 04/25] Cargo fmt --- .../user_defined_scalar_functions.rs | 53 +++++++++++++------ datafusion/expr/src/udf.rs | 4 +- datafusion/functions/benches/date_trunc.rs | 1 + datafusion/functions/benches/find_in_set.rs | 4 ++ datafusion/functions/benches/gcd.rs | 3 ++ .../physical-expr-common/src/physical_expr.rs | 5 +- .../physical-expr/src/expressions/binary.rs | 3 +- .../physical-expr/src/expressions/case.rs | 3 +- .../physical-expr/src/expressions/cast.rs | 6 +-- .../physical-expr/src/expressions/column.rs | 6 +-- .../physical-expr/src/expressions/in_list.rs | 3 +- .../src/expressions/is_not_null.rs | 6 +-- .../physical-expr/src/expressions/is_null.rs | 6 +-- .../physical-expr/src/expressions/like.rs | 3 +- 
.../physical-expr/src/expressions/literal.rs | 3 +- .../physical-expr/src/expressions/negative.rs | 6 +-- .../physical-expr/src/expressions/no_op.rs | 3 +- .../physical-expr/src/expressions/not.rs | 6 +-- .../physical-expr/src/expressions/try_cast.rs | 6 +-- .../src/expressions/unknown_column.rs | 3 +- .../physical-expr/src/scalar_function.rs | 6 +-- .../physical-plan/src/aggregates/mod.rs | 4 +- datafusion/physical-plan/src/projection.rs | 5 +- .../tests/cases/roundtrip_physical_plan.rs | 3 +- 24 files changed, 69 insertions(+), 82 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index bd3a0040e9b9..26ef91126e1e 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -15,8 +15,8 @@ // specific language governing permissions and limitations // under the License. -use std::collections::HashMap; use std::any::Any; +use std::collections::HashMap; use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; @@ -1410,23 +1410,34 @@ impl ScalarUDFImpl for MetadataBasedUdf { fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { assert_eq!(args.arg_metadata.len(), 1); let should_double = match &args.arg_metadata[0] { - Some(hashmap) => hashmap.get("modify_values").map(|v| v == "double_output").unwrap_or(false), - None => false + Some(hashmap) => hashmap + .get("modify_values") + .map(|v| v == "double_output") + .unwrap_or(false), + None => false, }; let mulitplier = if should_double { 2 } else { 1 }; match &args.args[0] { ColumnarValue::Array(array) => { - let array_values: Vec<_> = array.as_any().downcast_ref::().unwrap().iter().map(|v| v.map(|x| x * mulitplier)).collect(); + let array_values: Vec<_> = array + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| v.map(|x| x * mulitplier)) + .collect(); let array_ref = 
Arc::new(UInt64Array::from(array_values)) as ArrayRef; Ok(ColumnarValue::Array(array_ref)) } ColumnarValue::Scalar(value) => { let ScalarValue::UInt64(value) = value else { - return exec_err!("incorrect data type") + return exec_err!("incorrect data type"); }; - Ok(ColumnarValue::Scalar(ScalarValue::UInt64(value.map(|v| v * mulitplier)))) + Ok(ColumnarValue::Scalar(ScalarValue::UInt64( + value.map(|v| v * mulitplier), + ))) } } } @@ -1440,21 +1451,31 @@ impl ScalarUDFImpl for MetadataBasedUdf { } } - #[tokio::test] async fn test_metadata_based_udf() -> Result<()> { let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef; let schema = Arc::new(Schema::new(vec![ Field::new("no_metadata", DataType::UInt64, true), - Field::new("with_metadata", DataType::UInt64, true).with_metadata([("modify_values".to_string(), "double_output".to_string())].into_iter().collect()), + Field::new("with_metadata", DataType::UInt64, true).with_metadata( + [("modify_values".to_string(), "double_output".to_string())] + .into_iter() + .collect(), + ), ])); - let batch = RecordBatch::try_new(schema, vec![Arc::clone(&data_array), Arc::clone(&data_array)])?; + let batch = RecordBatch::try_new( + schema, + vec![Arc::clone(&data_array), Arc::clone(&data_array)], + )?; let ctx = SessionContext::new(); ctx.register_batch("t", batch)?; let t = ctx.table("t").await?; let no_output_meta_udf = ScalarUDF::from(MetadataBasedUdf::new(HashMap::new())); - let with_output_meta_udf = ScalarUDF::from(MetadataBasedUdf::new([("output_metatype".to_string(), "custom_value".to_string())].into_iter().collect())); + let with_output_meta_udf = ScalarUDF::from(MetadataBasedUdf::new( + [("output_metatype".to_string(), "custom_value".to_string())] + .into_iter() + .collect(), + )); let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?) 
.project(vec![ @@ -1470,8 +1491,7 @@ async fn test_metadata_based_udf() -> Result<()> { with_output_meta_udf .call(vec![col("with_metadata")]) .alias("meta_with_in_with_out"), - ] - )? + ])? .build()?; let actual = DataFrame::new(ctx.state(), plan).collect().await?; @@ -1483,8 +1503,10 @@ async fn test_metadata_based_udf() -> Result<()> { let expected_schema = Schema::new(vec![ Field::new("meta_no_in_no_out", DataType::UInt64, true), Field::new("meta_with_in_no_out", DataType::UInt64, true), - Field::new("meta_no_in_with_out", DataType::UInt64, true).with_metadata(output_meta.clone()), - Field::new("meta_with_in_with_out", DataType::UInt64, true).with_metadata(output_meta.clone()) + Field::new("meta_no_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + Field::new("meta_with_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), ]); let expected = record_batch!( @@ -1492,7 +1514,8 @@ async fn test_metadata_based_udf() -> Result<()> { ("meta_with_in_no_out", UInt64, [0, 10, 20, 30, 40]), ("meta_no_in_with_out", UInt64, [0, 5, 10, 15, 20]), ("meta_with_in_with_out", UInt64, [0, 10, 20, 30, 40]) - )?.with_schema(Arc::new(expected_schema))?; + )? + .with_schema(Arc::new(expected_schema))?; assert_eq!(expected, actual[0]); diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 7588379c0ff9..0d63bdf9a6c8 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -725,7 +725,9 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// This describes the output metadata associated with this UDF. /// Input field metadata is handled through `ScalarFunctionArgs` - fn metadata(&self, _input_schema: &Schema) -> Option> { None } + fn metadata(&self, _input_schema: &Schema) -> Option> { + None + } } /// ScalarUDF that adds an alias to the underlying function. 
It is better to diff --git a/datafusion/functions/benches/date_trunc.rs b/datafusion/functions/benches/date_trunc.rs index e7e96fb7a9fa..0bcd7235c44c 100644 --- a/datafusion/functions/benches/date_trunc.rs +++ b/datafusion/functions/benches/date_trunc.rs @@ -54,6 +54,7 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: batch_len, return_type, }) diff --git a/datafusion/functions/benches/find_in_set.rs b/datafusion/functions/benches/find_in_set.rs index 9307525482c2..21ad978a749d 100644 --- a/datafusion/functions/benches/find_in_set.rs +++ b/datafusion/functions/benches/find_in_set.rs @@ -157,6 +157,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: n_rows, return_type: &DataType::Int32, })) @@ -168,6 +169,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: n_rows, return_type: &DataType::Int32, })) @@ -183,6 +185,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: n_rows, return_type: &DataType::Int32, })) @@ -194,6 +197,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: n_rows, return_type: &DataType::Int32, })) diff --git a/datafusion/functions/benches/gcd.rs b/datafusion/functions/benches/gcd.rs index f8c855c82ad4..c5e460c00e89 100644 --- a/datafusion/functions/benches/gcd.rs +++ b/datafusion/functions/benches/gcd.rs @@ -47,6 +47,7 @@ fn criterion_benchmark(c: &mut Criterion) { 
black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![array_a.clone(), array_b.clone()], + arg_metadata: vec![None; 2], number_rows: 0, return_type: &DataType::Int64, }) @@ -63,6 +64,7 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![array_a.clone(), scalar_b.clone()], + arg_metadata: vec![None; 2], number_rows: 0, return_type: &DataType::Int64, }) @@ -79,6 +81,7 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![scalar_a.clone(), scalar_b.clone()], + arg_metadata: vec![None; 2], number_rows: 0, return_type: &DataType::Int64, }) diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 537c8ffa4852..8a00b3b7c018 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -78,10 +78,7 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash { /// Evaluate an expression against a RecordBatch fn evaluate(&self, batch: &RecordBatch) -> Result; /// Determine if this expression has any associated field metadata in the schema - fn metadata( - &self, - input_schema: &Schema, - ) -> Result>>; + fn metadata(&self, input_schema: &Schema) -> Result>>; /// Evaluate an expression against a RecordBatch after first applying a /// validity array fn evaluate_selection( diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index bd97cdb24e89..e0296a38a08b 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -436,8 +436,7 @@ impl PhysicalExpr for BinaryExpr { fn metadata( &self, _input_schema: &Schema, - ) -> Result>> - { + ) -> Result>> { Ok(None) } diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 
1cac5a5bb32e..00f0b4272457 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -517,8 +517,7 @@ impl PhysicalExpr for CaseExpr { fn metadata( &self, _input_schema: &Schema, - ) -> Result>> - { + ) -> Result>> { Ok(None) } diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 547f580c1522..7ca394ef6049 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -145,11 +145,7 @@ impl PhysicalExpr for CastExpr { value.cast_to(&self.cast_type, Some(&self.cast_options)) } - fn metadata( - &self, - input_schema: &Schema, - ) -> Result>> - { + fn metadata(&self, input_schema: &Schema) -> Result>> { self.expr.metadata(input_schema) } diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index db77a2995754..73abed8c130b 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -128,11 +128,7 @@ impl PhysicalExpr for Column { Ok(ColumnarValue::Array(Arc::clone(batch.column(self.index)))) } - fn metadata( - &self, - input_schema: &Schema, - ) -> Result>> - { + fn metadata(&self, input_schema: &Schema) -> Result>> { Ok(Some(input_schema.field(self.index).metadata().clone())) } diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index d9e0d120b9b2..fd4c692d41f3 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -382,8 +382,7 @@ impl PhysicalExpr for InListExpr { fn metadata( &self, _input_schema: &Schema, - ) -> Result>> - { + ) -> Result>> { Ok(None) } diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index 70ea1f938ba4..d2d722d9efe5 100644 --- 
a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -94,11 +94,7 @@ impl PhysicalExpr for IsNotNullExpr { } } - fn metadata( - &self, - input_schema: &Schema, - ) -> Result>> - { + fn metadata(&self, input_schema: &Schema) -> Result>> { self.arg.metadata(input_schema) } diff --git a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index d6dc26990eec..cb6d7b63c5d3 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -93,11 +93,7 @@ impl PhysicalExpr for IsNullExpr { } } - fn metadata( - &self, - input_schema: &Schema, - ) -> Result>> - { + fn metadata(&self, input_schema: &Schema) -> Result>> { self.arg.metadata(input_schema) } diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index 2ba4eec221b9..3f82834df817 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -133,8 +133,7 @@ impl PhysicalExpr for LikeExpr { fn metadata( &self, _input_schema: &Schema, - ) -> Result>> - { + ) -> Result>> { Ok(None) } diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index 26ff71d254f1..5fa1b99bf2af 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -79,8 +79,7 @@ impl PhysicalExpr for Literal { fn metadata( &self, _input_schema: &Schema, - ) -> Result>> - { + ) -> Result>> { Ok(None) } diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index cf198b7c227f..8d2431d1f2a4 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -104,11 +104,7 @@ impl PhysicalExpr for NegativeExpr { 
} } - fn metadata( - &self, - input_schema: &Schema, - ) -> Result>> - { + fn metadata(&self, input_schema: &Schema) -> Result>> { self.arg.metadata(input_schema) } diff --git a/datafusion/physical-expr/src/expressions/no_op.rs b/datafusion/physical-expr/src/expressions/no_op.rs index 2b70f797da7c..9213657d3123 100644 --- a/datafusion/physical-expr/src/expressions/no_op.rs +++ b/datafusion/physical-expr/src/expressions/no_op.rs @@ -71,8 +71,7 @@ impl PhysicalExpr for NoOp { fn metadata( &self, _input_schema: &Schema, - ) -> Result>> - { + ) -> Result>> { Ok(None) } diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index 23142966147b..317ae98585c0 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -102,11 +102,7 @@ impl PhysicalExpr for NotExpr { } } - fn metadata( - &self, - input_schema: &Schema, - ) -> Result>> - { + fn metadata(&self, input_schema: &Schema) -> Result>> { self.arg.metadata(input_schema) } diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index a4ea09d56a7d..327ed4a0e627 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -111,11 +111,7 @@ impl PhysicalExpr for TryCastExpr { } } - fn metadata( - &self, - input_schema: &Schema, - ) -> Result>> - { + fn metadata(&self, input_schema: &Schema) -> Result>> { self.expr.metadata(input_schema) } diff --git a/datafusion/physical-expr/src/expressions/unknown_column.rs b/datafusion/physical-expr/src/expressions/unknown_column.rs index 6d3905f88430..c719e7f844c8 100644 --- a/datafusion/physical-expr/src/expressions/unknown_column.rs +++ b/datafusion/physical-expr/src/expressions/unknown_column.rs @@ -80,8 +80,7 @@ impl PhysicalExpr for UnKnownColumn { fn metadata( &self, _input_schema: &Schema, - ) -> Result>> - { + ) -> Result>> { 
Ok(None) } diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index deab423683d0..08a3d561b7ed 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -251,11 +251,7 @@ impl PhysicalExpr for ScalarFunctionExpr { Ok(output) } - fn metadata( - &self, - input_schema: &Schema, - ) -> Result>> - { + fn metadata(&self, input_schema: &Schema) -> Result>> { Ok(self.fun.as_ref().inner().metadata(input_schema)) } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index dec84a1634f7..1bee720b3680 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -284,9 +284,7 @@ impl PhysicalGroupBy { expr.data_type(input_schema)?, group_expr_nullable || expr.nullable(input_schema)?, ) - .with_metadata( - expr.metadata(input_schema)?.clone().unwrap_or_default(), - ), + .with_metadata(expr.metadata(input_schema)?.clone().unwrap_or_default()), ); } if !self.is_single() { diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index 358e72025b72..5eebf307820e 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -84,9 +84,8 @@ impl ProjectionExec { e.data_type(&input_schema)?, e.nullable(&input_schema)?, ); - field.set_metadata( - e.metadata(&input_schema)?.clone().unwrap_or_default(), - ); + field + .set_metadata(e.metadata(&input_schema)?.clone().unwrap_or_default()); Ok(field) }) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 291f76d0689a..5452dc09c48c 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -868,8 +868,7 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { 
fn metadata( &self, _input_schema: &Schema, - ) -> Result>> - { + ) -> Result>> { Ok(None) } From a2d5f9e30d1848c921cd1a8a6c5bab2352f3e890 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 8 Apr 2025 19:21:01 -0400 Subject: [PATCH 05/25] Benchmarks required args_metadata in tests --- datafusion/functions-nested/benches/map.rs | 1 + datafusion/functions/benches/chr.rs | 1 + datafusion/functions/benches/concat.rs | 1 + datafusion/functions/benches/cot.rs | 2 ++ datafusion/functions/benches/encoding.rs | 4 ++++ datafusion/functions/benches/initcap.rs | 3 +++ datafusion/functions/benches/isnan.rs | 2 ++ datafusion/functions/benches/iszero.rs | 2 ++ datafusion/functions/benches/lower.rs | 6 ++++++ datafusion/functions/benches/ltrim.rs | 1 + datafusion/functions/benches/make_date.rs | 4 ++++ datafusion/functions/benches/nullif.rs | 1 + datafusion/functions/benches/pad.rs | 6 ++++++ datafusion/functions/benches/random.rs | 2 ++ datafusion/functions/benches/repeat.rs | 7 +++++++ datafusion/functions/benches/reverse.rs | 4 ++++ datafusion/functions/benches/signum.rs | 2 ++ datafusion/functions/benches/strpos.rs | 4 ++++ datafusion/functions/benches/substr.rs | 9 +++++++++ datafusion/functions/benches/substr_index.rs | 1 + datafusion/functions/benches/to_char.rs | 3 +++ datafusion/functions/benches/to_hex.rs | 2 ++ datafusion/functions/benches/to_timestamp.rs | 6 ++++++ datafusion/functions/benches/trunc.rs | 2 ++ datafusion/functions/benches/upper.rs | 1 + datafusion/functions/benches/uuid.rs | 1 + 26 files changed, 78 insertions(+) diff --git a/datafusion/functions-nested/benches/map.rs b/datafusion/functions-nested/benches/map.rs index 2774b24b902a..579dd155a3cc 100644 --- a/datafusion/functions-nested/benches/map.rs +++ b/datafusion/functions-nested/benches/map.rs @@ -103,6 +103,7 @@ fn criterion_benchmark(c: &mut Criterion) { map_udf() .invoke_with_args(ScalarFunctionArgs { args: vec![keys.clone(), values.clone()], + arg_metadata: vec![None; 2], number_rows: 1, 
return_type, }) diff --git a/datafusion/functions/benches/chr.rs b/datafusion/functions/benches/chr.rs index 8575809c21c8..5f27c1f25a1c 100644 --- a/datafusion/functions/benches/chr.rs +++ b/datafusion/functions/benches/chr.rs @@ -56,6 +56,7 @@ fn criterion_benchmark(c: &mut Criterion) { cot_fn .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, }) diff --git a/datafusion/functions/benches/concat.rs b/datafusion/functions/benches/concat.rs index 45ca076e754f..18425d2146bc 100644 --- a/datafusion/functions/benches/concat.rs +++ b/datafusion/functions/benches/concat.rs @@ -45,6 +45,7 @@ fn criterion_benchmark(c: &mut Criterion) { concat() .invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, }) diff --git a/datafusion/functions/benches/cot.rs b/datafusion/functions/benches/cot.rs index b2a9ca0b9f47..8801cf3a5995 100644 --- a/datafusion/functions/benches/cot.rs +++ b/datafusion/functions/benches/cot.rs @@ -39,6 +39,7 @@ fn criterion_benchmark(c: &mut Criterion) { cot_fn .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), + arg_metadata: vec![None; f32_args.len()], number_rows: size, return_type: &DataType::Float32, }) @@ -54,6 +55,7 @@ fn criterion_benchmark(c: &mut Criterion) { cot_fn .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), + arg_metadata: vec![None; f64_args.len()], number_rows: size, return_type: &DataType::Float64, }) diff --git a/datafusion/functions/benches/encoding.rs b/datafusion/functions/benches/encoding.rs index cf8f8d2fd62c..a961f17c7eee 100644 --- a/datafusion/functions/benches/encoding.rs +++ b/datafusion/functions/benches/encoding.rs @@ -33,6 +33,7 @@ fn criterion_benchmark(c: &mut Criterion) { let encoded = encoding::encode() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Array(str_array.clone()), method.clone()], + 
arg_metadata: vec![None; 2], number_rows: size, return_type: &DataType::Utf8, }) @@ -44,6 +45,7 @@ fn criterion_benchmark(c: &mut Criterion) { decode .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, }) @@ -57,6 +59,7 @@ fn criterion_benchmark(c: &mut Criterion) { let encoded = encoding::encode() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Array(str_array.clone()), method.clone()], + arg_metadata: vec![None; 2], number_rows: size, return_type: &DataType::Utf8, }) @@ -68,6 +71,7 @@ fn criterion_benchmark(c: &mut Criterion) { decode .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, }) diff --git a/datafusion/functions/benches/initcap.rs b/datafusion/functions/benches/initcap.rs index 97c76831b33c..58c7ddbbb43d 100644 --- a/datafusion/functions/benches/initcap.rs +++ b/datafusion/functions/benches/initcap.rs @@ -55,6 +55,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8View, })) @@ -69,6 +70,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8View, })) @@ -81,6 +83,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, })) diff --git a/datafusion/functions/benches/isnan.rs b/datafusion/functions/benches/isnan.rs index 42004cc24f69..0323fa138ce6 100644 --- a/datafusion/functions/benches/isnan.rs +++ b/datafusion/functions/benches/isnan.rs @@ -38,6 
+38,7 @@ fn criterion_benchmark(c: &mut Criterion) { isnan .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), + arg_metadata: vec![None; f32_args.len()], number_rows: size, return_type: &DataType::Boolean, }) @@ -53,6 +54,7 @@ fn criterion_benchmark(c: &mut Criterion) { isnan .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), + arg_metadata: vec![None; f64_args.len()], number_rows: size, return_type: &DataType::Boolean, }) diff --git a/datafusion/functions/benches/iszero.rs b/datafusion/functions/benches/iszero.rs index 9e5f6a84804b..602567418c4e 100644 --- a/datafusion/functions/benches/iszero.rs +++ b/datafusion/functions/benches/iszero.rs @@ -39,6 +39,7 @@ fn criterion_benchmark(c: &mut Criterion) { iszero .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), + arg_metadata: vec![None; f32_args.len()], number_rows: batch_len, return_type: &DataType::Boolean, }) @@ -55,6 +56,7 @@ fn criterion_benchmark(c: &mut Criterion) { iszero .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), + arg_metadata: vec![None; f64_args.len()], number_rows: batch_len, return_type: &DataType::Boolean, }) diff --git a/datafusion/functions/benches/lower.rs b/datafusion/functions/benches/lower.rs index 534e5739225d..e5c0d2112737 100644 --- a/datafusion/functions/benches/lower.rs +++ b/datafusion/functions/benches/lower.rs @@ -129,6 +129,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, })) @@ -143,6 +144,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, })) @@ -158,6 +160,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); 
black_box(lower.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, })) @@ -183,6 +186,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs{ args: args_cloned, + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, })) @@ -197,6 +201,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs{ args: args_cloned, + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, })) @@ -211,6 +216,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs{ args: args_cloned, + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, })) diff --git a/datafusion/functions/benches/ltrim.rs b/datafusion/functions/benches/ltrim.rs index 457fb499f5a1..818d657ecfc4 100644 --- a/datafusion/functions/benches/ltrim.rs +++ b/datafusion/functions/benches/ltrim.rs @@ -145,6 +145,7 @@ fn run_with_string_type( let args_cloned = args.clone(); black_box(ltrim.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, })) diff --git a/datafusion/functions/benches/make_date.rs b/datafusion/functions/benches/make_date.rs index 8dd7a7a59773..2d67ec8551e6 100644 --- a/datafusion/functions/benches/make_date.rs +++ b/datafusion/functions/benches/make_date.rs @@ -69,6 +69,7 @@ fn criterion_benchmark(c: &mut Criterion) { make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![years.clone(), months.clone(), days.clone()], + arg_metadata: vec![None; 3], number_rows: batch_len, return_type: &DataType::Date32, }) @@ -90,6 +91,7 @@ fn criterion_benchmark(c: &mut Criterion) { make_date() 
.invoke_with_args(ScalarFunctionArgs { args: vec![year.clone(), months.clone(), days.clone()], + arg_metadata: vec![None; 3], number_rows: batch_len, return_type: &DataType::Date32, }) @@ -111,6 +113,7 @@ fn criterion_benchmark(c: &mut Criterion) { make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![year.clone(), month.clone(), days.clone()], + arg_metadata: vec![None; 3], number_rows: batch_len, return_type: &DataType::Date32, }) @@ -129,6 +132,7 @@ fn criterion_benchmark(c: &mut Criterion) { make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![year.clone(), month.clone(), day.clone()], + arg_metadata: vec![None; 3], number_rows: 1, return_type: &DataType::Date32, }) diff --git a/datafusion/functions/benches/nullif.rs b/datafusion/functions/benches/nullif.rs index 9096c976bf31..e8421ab54a01 100644 --- a/datafusion/functions/benches/nullif.rs +++ b/datafusion/functions/benches/nullif.rs @@ -39,6 +39,7 @@ fn criterion_benchmark(c: &mut Criterion) { nullif .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, }) diff --git a/datafusion/functions/benches/pad.rs b/datafusion/functions/benches/pad.rs index f78a53fbee19..94e2178b5e9d 100644 --- a/datafusion/functions/benches/pad.rs +++ b/datafusion/functions/benches/pad.rs @@ -106,6 +106,7 @@ fn criterion_benchmark(c: &mut Criterion) { lpad() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, }) @@ -121,6 +122,7 @@ fn criterion_benchmark(c: &mut Criterion) { lpad() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::LargeUtf8, }) @@ -136,6 +138,7 @@ fn criterion_benchmark(c: &mut Criterion) { lpad() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: size, 
return_type: &DataType::Utf8, }) @@ -155,6 +158,7 @@ fn criterion_benchmark(c: &mut Criterion) { rpad() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, }) @@ -170,6 +174,7 @@ fn criterion_benchmark(c: &mut Criterion) { rpad() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::LargeUtf8, }) @@ -186,6 +191,7 @@ fn criterion_benchmark(c: &mut Criterion) { rpad() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, }) diff --git a/datafusion/functions/benches/random.rs b/datafusion/functions/benches/random.rs index 78ebf23e02e0..1ad1dc750ef7 100644 --- a/datafusion/functions/benches/random.rs +++ b/datafusion/functions/benches/random.rs @@ -34,6 +34,7 @@ fn criterion_benchmark(c: &mut Criterion) { random_func .invoke_with_args(ScalarFunctionArgs { args: vec![], + arg_metadata: vec![], number_rows: 8192, return_type: &DataType::Float64, }) @@ -52,6 +53,7 @@ fn criterion_benchmark(c: &mut Criterion) { random_func .invoke_with_args(ScalarFunctionArgs { args: vec![], + arg_metadata: vec![], number_rows: 128, return_type: &DataType::Float64, }) diff --git a/datafusion/functions/benches/repeat.rs b/datafusion/functions/benches/repeat.rs index 5cc6a177d9d9..8db229f9d592 100644 --- a/datafusion/functions/benches/repeat.rs +++ b/datafusion/functions/benches/repeat.rs @@ -77,6 +77,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(repeat.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_metadata: vec![None; args.len()], number_rows: repeat_times as usize, return_type: &DataType::Utf8, })) @@ -95,6 +96,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(repeat.invoke_with_args(ScalarFunctionArgs { args: 
args_cloned, + arg_metadata: vec![None; args.len()], number_rows: repeat_times as usize, return_type: &DataType::Utf8, })) @@ -113,6 +115,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(repeat.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_metadata: vec![None; args.len()], number_rows: repeat_times as usize, return_type: &DataType::Utf8, })) @@ -140,6 +143,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(repeat.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_metadata: vec![None; args.len()], number_rows: repeat_times as usize, return_type: &DataType::Utf8, })) @@ -158,6 +162,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(repeat.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_metadata: vec![None; args.len()], number_rows: repeat_times as usize, return_type: &DataType::Utf8, })) @@ -176,6 +181,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(repeat.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_metadata: vec![None; args.len()], number_rows: repeat_times as usize, return_type: &DataType::Utf8, })) @@ -203,6 +209,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(repeat.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_metadata: vec![None; args.len()], number_rows: repeat_times as usize, return_type: &DataType::Utf8, })) diff --git a/datafusion/functions/benches/reverse.rs b/datafusion/functions/benches/reverse.rs index d61f8fb80517..49ef9797947b 100644 --- a/datafusion/functions/benches/reverse.rs +++ b/datafusion/functions/benches/reverse.rs @@ -46,6 +46,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_ascii.clone(), + arg_metadata: vec![None; args_string_ascii.len()], number_rows: N_ROWS, return_type: 
&DataType::Utf8, })) @@ -65,6 +66,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_utf8.clone(), + arg_metadata: vec![None; args_string_utf8.len()], number_rows: N_ROWS, return_type: &DataType::Utf8, })) @@ -86,6 +88,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_view_ascii.clone(), + arg_metadata: vec![None; args_string_view_ascii.len()], number_rows: N_ROWS, return_type: &DataType::Utf8, })) @@ -105,6 +108,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_view_utf8.clone(), + arg_metadata: vec![None; args_string_view_utf8.len()], number_rows: N_ROWS, return_type: &DataType::Utf8, })) diff --git a/datafusion/functions/benches/signum.rs b/datafusion/functions/benches/signum.rs index 01939fad5f34..609e284edc2c 100644 --- a/datafusion/functions/benches/signum.rs +++ b/datafusion/functions/benches/signum.rs @@ -39,6 +39,7 @@ fn criterion_benchmark(c: &mut Criterion) { signum .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), + arg_metadata: vec![None; f32_args.len()], number_rows: batch_len, return_type: &DataType::Float32, }) @@ -56,6 +57,7 @@ fn criterion_benchmark(c: &mut Criterion) { signum .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), + arg_metadata: vec![None; f64_args.len()], number_rows: batch_len, return_type: &DataType::Float64, }) diff --git a/datafusion/functions/benches/strpos.rs b/datafusion/functions/benches/strpos.rs index df57c229e0ad..ed21e82d60b4 100644 --- a/datafusion/functions/benches/strpos.rs +++ b/datafusion/functions/benches/strpos.rs @@ -117,6 +117,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_ascii.clone(), + arg_metadata: vec![None; args_string_ascii.len()], number_rows: 
n_rows, return_type: &DataType::Int32, })) @@ -132,6 +133,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_utf8.clone(), + arg_metadata: vec![None; args_string_utf8.len()], number_rows: n_rows, return_type: &DataType::Int32, })) @@ -147,6 +149,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_view_ascii.clone(), + arg_metadata: vec![None; args_string_view_ascii.len()], number_rows: n_rows, return_type: &DataType::Int32, })) @@ -162,6 +165,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_view_utf8.clone(), + arg_metadata: vec![None; args_string_view_utf8.len()], number_rows: n_rows, return_type: &DataType::Int32, })) diff --git a/datafusion/functions/benches/substr.rs b/datafusion/functions/benches/substr.rs index 80ab70ef71b0..ff3af5b0eec2 100644 --- a/datafusion/functions/benches/substr.rs +++ b/datafusion/functions/benches/substr.rs @@ -112,6 +112,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(substr.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8View, })) @@ -126,6 +127,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(substr.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8View, })) @@ -140,6 +142,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(substr.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8View, })) @@ -166,6 +169,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(substr.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + 
arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8View, })) @@ -183,6 +187,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(substr.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8View, })) @@ -200,6 +205,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(substr.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8View, })) @@ -226,6 +232,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(substr.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8View, })) @@ -243,6 +250,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(substr.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8View, })) @@ -260,6 +268,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(substr.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8View, })) diff --git a/datafusion/functions/benches/substr_index.rs b/datafusion/functions/benches/substr_index.rs index b1c1c3c34a95..204f105212e8 100644 --- a/datafusion/functions/benches/substr_index.rs +++ b/datafusion/functions/benches/substr_index.rs @@ -96,6 +96,7 @@ fn criterion_benchmark(c: &mut Criterion) { substr_index() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: batch_len, return_type: &DataType::Utf8, }) diff --git a/datafusion/functions/benches/to_char.rs b/datafusion/functions/benches/to_char.rs index 6f20a20dc219..b3a7692b9f4b 100644 --- a/datafusion/functions/benches/to_char.rs +++ 
b/datafusion/functions/benches/to_char.rs @@ -93,6 +93,7 @@ fn criterion_benchmark(c: &mut Criterion) { to_char() .invoke_with_args(ScalarFunctionArgs { args: vec![data.clone(), patterns.clone()], + arg_metadata: vec![None; 2], number_rows: batch_len, return_type: &DataType::Utf8, }) @@ -114,6 +115,7 @@ fn criterion_benchmark(c: &mut Criterion) { to_char() .invoke_with_args(ScalarFunctionArgs { args: vec![data.clone(), patterns.clone()], + arg_metadata: vec![None; 2], number_rows: batch_len, return_type: &DataType::Utf8, }) @@ -141,6 +143,7 @@ fn criterion_benchmark(c: &mut Criterion) { to_char() .invoke_with_args(ScalarFunctionArgs { args: vec![data.clone(), pattern.clone()], + arg_metadata: vec![None; 2], number_rows: 1, return_type: &DataType::Utf8, }) diff --git a/datafusion/functions/benches/to_hex.rs b/datafusion/functions/benches/to_hex.rs index a45d936c0a52..cc18dc21c798 100644 --- a/datafusion/functions/benches/to_hex.rs +++ b/datafusion/functions/benches/to_hex.rs @@ -36,6 +36,7 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( hex.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_metadata: vec![None; i32_args.len()], number_rows: batch_len, return_type: &DataType::Utf8, }) @@ -52,6 +53,7 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( hex.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_metadata: vec![None; i64_args.len()], number_rows: batch_len, return_type: &DataType::Utf8, }) diff --git a/datafusion/functions/benches/to_timestamp.rs b/datafusion/functions/benches/to_timestamp.rs index aec56697691f..470e18e373c9 100644 --- a/datafusion/functions/benches/to_timestamp.rs +++ b/datafusion/functions/benches/to_timestamp.rs @@ -120,6 +120,7 @@ fn criterion_benchmark(c: &mut Criterion) { to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: vec![string_array.clone()], + arg_metadata: vec![None; 1], number_rows: batch_len, return_type, }) @@ -138,6 +139,7 @@ fn criterion_benchmark(c: &mut Criterion) { 
to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: vec![string_array.clone()], + arg_metadata: vec![None; 1], number_rows: batch_len, return_type, }) @@ -156,6 +158,7 @@ fn criterion_benchmark(c: &mut Criterion) { to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: vec![string_array.clone()], + arg_metadata: vec![None; 1], number_rows: batch_len, return_type, }) @@ -179,6 +182,7 @@ fn criterion_benchmark(c: &mut Criterion) { to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: batch_len, return_type, }) @@ -210,6 +214,7 @@ fn criterion_benchmark(c: &mut Criterion) { to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: batch_len, return_type, }) @@ -242,6 +247,7 @@ fn criterion_benchmark(c: &mut Criterion) { to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_metadata: vec![None; args.len()], number_rows: batch_len, return_type, }) diff --git a/datafusion/functions/benches/trunc.rs b/datafusion/functions/benches/trunc.rs index 7fc93921d2e7..2475f6f78444 100644 --- a/datafusion/functions/benches/trunc.rs +++ b/datafusion/functions/benches/trunc.rs @@ -39,6 +39,7 @@ fn criterion_benchmark(c: &mut Criterion) { trunc .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), + arg_metadata: vec![None; f32_args.len()], number_rows: size, return_type: &DataType::Float32, }) @@ -54,6 +55,7 @@ fn criterion_benchmark(c: &mut Criterion) { trunc .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), + arg_metadata: vec![None; f64_args.len()], number_rows: size, return_type: &DataType::Float64, }) diff --git a/datafusion/functions/benches/upper.rs b/datafusion/functions/benches/upper.rs index f0bee89c7d37..be01ab540927 100644 --- a/datafusion/functions/benches/upper.rs +++ b/datafusion/functions/benches/upper.rs @@ -42,6 +42,7 @@ fn criterion_benchmark(c: &mut Criterion) { let 
args_cloned = args.clone(); black_box(upper.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_metadata: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, })) diff --git a/datafusion/functions/benches/uuid.rs b/datafusion/functions/benches/uuid.rs index 7b8d156fec21..5b4b3bd3489e 100644 --- a/datafusion/functions/benches/uuid.rs +++ b/datafusion/functions/benches/uuid.rs @@ -28,6 +28,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(uuid.invoke_with_args(ScalarFunctionArgs { args: vec![], + arg_metadata: vec![], number_rows: 1024, return_type: &DataType::Utf8, })) From a3514dee90f0746ab24ebe66fbdd03ec26bdffb9 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 8 Apr 2025 19:21:44 -0400 Subject: [PATCH 06/25] Clippy warnings --- datafusion/ffi/src/udf/mod.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/datafusion/ffi/src/udf/mod.rs b/datafusion/ffi/src/udf/mod.rs index e0350db1bd1b..7c100bd40c3c 100644 --- a/datafusion/ffi/src/udf/mod.rs +++ b/datafusion/ffi/src/udf/mod.rs @@ -86,6 +86,7 @@ pub struct FFI_ScalarUDF { /// Execute the underlying [`ScalarUDF`] and return the result as a `FFI_ArrowArray` /// within an AbiStable wrapper. 
+ #[allow(clippy::type_complexity)] pub invoke_with_args: unsafe extern "C" fn( udf: &Self, args: RVec, @@ -376,7 +377,7 @@ impl ScalarUDFImpl for ForeignScalarUDF { maybe_map .map(|hashmap| { hashmap - .into_iter() + .iter() .map(|(k, v)| { (RString::from(k.clone()), RString::from(v.clone())) }) From 4e3b7bc6613f22b5b6eefbb4d0c6835adf78a308 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 9 Apr 2025 13:50:40 -0400 Subject: [PATCH 07/25] Switching over to passing Field around instead of metadata so we can handle extension types directly --- Cargo.lock | 1 + datafusion/core/Cargo.toml | 2 +- .../user_defined_scalar_functions.rs | 209 +++++++++++++++++- datafusion/expr/src/udf.rs | 21 +- datafusion/ffi/src/udf/mod.rs | 56 ++--- datafusion/functions-nested/benches/map.rs | 2 +- .../functions/benches/character_length.rs | 8 +- datafusion/functions/benches/chr.rs | 2 +- datafusion/functions/benches/concat.rs | 2 +- datafusion/functions/benches/cot.rs | 4 +- datafusion/functions/benches/date_bin.rs | 2 +- datafusion/functions/benches/date_trunc.rs | 2 +- datafusion/functions/benches/encoding.rs | 8 +- datafusion/functions/benches/find_in_set.rs | 8 +- datafusion/functions/benches/gcd.rs | 6 +- datafusion/functions/benches/initcap.rs | 6 +- datafusion/functions/benches/isnan.rs | 4 +- datafusion/functions/benches/iszero.rs | 4 +- datafusion/functions/benches/lower.rs | 12 +- datafusion/functions/benches/ltrim.rs | 2 +- datafusion/functions/benches/make_date.rs | 8 +- datafusion/functions/benches/nullif.rs | 2 +- datafusion/functions/benches/pad.rs | 12 +- datafusion/functions/benches/random.rs | 4 +- datafusion/functions/benches/repeat.rs | 14 +- datafusion/functions/benches/reverse.rs | 8 +- datafusion/functions/benches/signum.rs | 4 +- datafusion/functions/benches/strpos.rs | 8 +- datafusion/functions/benches/substr.rs | 18 +- datafusion/functions/benches/substr_index.rs | 2 +- datafusion/functions/benches/to_char.rs | 6 +- datafusion/functions/benches/to_hex.rs | 4 
+- datafusion/functions/benches/to_timestamp.rs | 12 +- datafusion/functions/benches/trunc.rs | 4 +- datafusion/functions/benches/upper.rs | 2 +- datafusion/functions/benches/uuid.rs | 2 +- .../functions/src/core/union_extract.rs | 6 +- datafusion/functions/src/core/version.rs | 2 +- datafusion/functions/src/datetime/date_bin.rs | 30 +-- .../functions/src/datetime/date_trunc.rs | 4 +- .../functions/src/datetime/from_unixtime.rs | 4 +- .../functions/src/datetime/make_date.rs | 16 +- datafusion/functions/src/datetime/to_char.rs | 12 +- datafusion/functions/src/datetime/to_date.rs | 16 +- .../functions/src/datetime/to_local_time.rs | 4 +- .../functions/src/datetime/to_timestamp.rs | 4 +- datafusion/functions/src/math/log.rs | 20 +- datafusion/functions/src/math/power.rs | 4 +- datafusion/functions/src/math/signum.rs | 4 +- datafusion/functions/src/regex/regexpcount.rs | 24 +- datafusion/functions/src/string/concat.rs | 2 +- datafusion/functions/src/string/concat_ws.rs | 4 +- datafusion/functions/src/string/contains.rs | 2 +- datafusion/functions/src/string/lower.rs | 2 +- datafusion/functions/src/string/upper.rs | 2 +- .../functions/src/unicode/find_in_set.rs | 4 +- datafusion/functions/src/utils.rs | 8 +- .../physical-expr-common/src/physical_expr.rs | 11 +- .../physical-expr-common/src/sort_expr.rs | 4 +- .../physical-expr/src/expressions/binary.rs | 6 +- .../physical-expr/src/expressions/case.rs | 10 +- .../physical-expr/src/expressions/cast.rs | 7 +- .../physical-expr/src/expressions/column.rs | 7 +- .../physical-expr/src/expressions/in_list.rs | 5 +- .../src/expressions/is_not_null.rs | 6 +- .../physical-expr/src/expressions/is_null.rs | 7 +- .../physical-expr/src/expressions/like.rs | 8 +- .../physical-expr/src/expressions/literal.rs | 9 +- .../physical-expr/src/expressions/negative.rs | 6 +- .../physical-expr/src/expressions/no_op.rs | 10 +- .../physical-expr/src/expressions/not.rs | 7 +- .../physical-expr/src/expressions/try_cast.rs | 7 +- 
.../src/expressions/unknown_column.rs | 7 +- .../physical-expr/src/scalar_function.rs | 14 +- .../physical-plan/src/aggregates/mod.rs | 6 +- datafusion/physical-plan/src/projection.rs | 42 ++-- .../tests/cases/roundtrip_physical_plan.rs | 5 +- 77 files changed, 491 insertions(+), 338 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a30d0b4a7bd4..236f4aec83dd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -456,6 +456,7 @@ checksum = "7450c76ab7c5a6805be3440dc2e2096010da58f7cab301fdc996a4ee3ee74e49" dependencies = [ "bitflags 2.8.0", "serde", + "serde_json", ] [[package]] diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index edc0d34b539a..3ace3e14ec25 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -95,7 +95,7 @@ extended_tests = [] [dependencies] arrow = { workspace = true } arrow-ipc = { workspace = true } -arrow-schema = { workspace = true } +arrow-schema = { workspace = true, features = ["canonical_extension_types"] } async-trait = { workspace = true } bytes = { workspace = true } bzip2 = { version = "0.5.2", optional = true } diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 26ef91126e1e..c0b5c8d0d23e 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -20,13 +20,15 @@ use std::collections::HashMap; use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; -use arrow::array::{as_string_array, record_batch, UInt64Array}; +use arrow::array::{as_string_array, record_batch, Int8Array, UInt64Array}; use arrow::array::{ builder::BooleanBuilder, cast::AsArray, Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, StringArray, }; use arrow::compute::kernels::numeric::add; use arrow::datatypes::{DataType, Field, Schema}; +use arrow_schema::extension::{Bool8, CanonicalExtensionType, 
ExtensionType}; +use arrow_schema::ArrowError; use datafusion::common::test_util::batches_to_string; use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState}; use datafusion::prelude::*; @@ -1373,7 +1375,7 @@ async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result, + output_field: Field, } impl MetadataBasedUdf { @@ -1382,10 +1384,12 @@ impl MetadataBasedUdf { // instances of this UDF. This is a small hack for the unit tests to get unique // names, but you could do something more elegant with the metadata. let name = format!("metadata_based_udf_{}", metadata.len()); + let output_field = + Field::new(&name, DataType::UInt64, true).with_metadata(metadata); Self { name, signature: Signature::exact(vec![DataType::UInt64], Volatility::Immutable), - output_metadata: metadata, + output_field, } } } @@ -1408,9 +1412,10 @@ impl ScalarUDFImpl for MetadataBasedUdf { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - assert_eq!(args.arg_metadata.len(), 1); - let should_double = match &args.arg_metadata[0] { - Some(hashmap) => hashmap + assert_eq!(args.arg_fields.len(), 1); + let should_double = match &args.arg_fields[0] { + Some(field) => field + .metadata() .get("modify_values") .map(|v| v == "double_output") .unwrap_or(false), @@ -1446,8 +1451,8 @@ impl ScalarUDFImpl for MetadataBasedUdf { self.name == other.name() } - fn metadata(&self, _input_schema: &Schema) -> Option> { - Some(self.output_metadata.clone()) + fn output_field(&self, _input_schema: &Schema) -> Option { + Some(self.output_field.clone()) } } @@ -1497,7 +1502,7 @@ async fn test_metadata_based_udf() -> Result<()> { let actual = DataFrame::new(ctx.state(), plan).collect().await?; // To test for output metadata handling, we set the expected values on the result - // To test for input metadata handling, we check the values returned + // To test for input metadata handling, we check the numbers returned let mut output_meta = HashMap::new(); let _ = 
output_meta.insert("output_metatype".to_string(), "custom_value".to_string()); let expected_schema = Schema::new(vec![ @@ -1522,3 +1527,189 @@ async fn test_metadata_based_udf() -> Result<()> { ctx.deregister_table("t")?; Ok(()) } + +/// This UDF is to test extension handling, both on the input and output +/// sides. For the input, we will handle the data differently if there is +/// the canonical extension type Bool8. For the output we will add a +/// user defined extension type. +#[derive(Debug)] +struct ExtensionBasedUdf { + name: String, + signature: Signature, +} + +impl Default for ExtensionBasedUdf { + fn default() -> Self { + Self { + name: "canonical_extension_udf".to_string(), + signature: Signature::exact(vec![DataType::Int8], Volatility::Immutable), + } + } +} +impl ScalarUDFImpl for ExtensionBasedUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _args: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + assert_eq!(args.arg_fields.len(), 1); + let input_field = args.arg_fields[0].unwrap(); + + let output_as_bool = matches!( + CanonicalExtensionType::try_from(input_field), + Ok(CanonicalExtensionType::Bool8(_)) + ); + + // If we have the extension type set, we are outputting a boolean value. + // Otherwise we output a string representation of the numeric value. 
+ fn print_value(v: Option, as_bool: bool) -> Option { + v.map(|x| match as_bool { + true => format!("{}", x != 0), + false => format!("{x}"), + }) + } + + match &args.args[0] { + ColumnarValue::Array(array) => { + let array_values: Vec<_> = array + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| print_value(v, output_as_bool)) + .collect(); + let array_ref = Arc::new(StringArray::from(array_values)) as ArrayRef; + Ok(ColumnarValue::Array(array_ref)) + } + ColumnarValue::Scalar(value) => { + let ScalarValue::Int8(value) = value else { + return exec_err!("incorrect data type"); + }; + + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(print_value( + *value, + output_as_bool, + )))) + } + } + } + + fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + self.name == other.name() + } + + fn output_field(&self, _input_schema: &Schema) -> Option { + Some( + Field::new("canonical_extension_udf", DataType::Utf8, true) + .with_extension_type(MyUserExtentionType {}), + ) + } +} + +struct MyUserExtentionType {} + +impl ExtensionType for MyUserExtentionType { + const NAME: &'static str = "my_user_extention_type"; + type Metadata = (); + + fn metadata(&self) -> &Self::Metadata { + &() + } + + fn serialize_metadata(&self) -> Option { + None + } + + fn deserialize_metadata( + _metadata: Option<&str>, + ) -> std::result::Result { + Ok(()) + } + + fn supports_data_type( + &self, + data_type: &DataType, + ) -> std::result::Result<(), ArrowError> { + if let DataType::Utf8 = data_type { + Ok(()) + } else { + Err(ArrowError::InvalidArgumentError( + "only utf8 supported".to_string(), + )) + } + } + + fn try_new( + _data_type: &DataType, + _metadata: Self::Metadata, + ) -> std::result::Result { + Ok(Self {}) + } +} + +#[tokio::test] +async fn test_extension_based_udf() -> Result<()> { + let data_array = Arc::new(Int8Array::from(vec![0, 0, 10, 20])) as ArrayRef; + let schema = Arc::new(Schema::new(vec![ + Field::new("no_extension", DataType::Int8, true), + 
Field::new("with_extension", DataType::Int8, true).with_extension_type(Bool8), + ])); + let batch = RecordBatch::try_new( + schema, + vec![Arc::clone(&data_array), Arc::clone(&data_array)], + )?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let t = ctx.table("t").await?; + let extension_based_udf = ScalarUDF::from(ExtensionBasedUdf::default()); + + let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?) + .project(vec![ + extension_based_udf + .call(vec![col("no_extension")]) + .alias("without_bool8_extension"), + extension_based_udf + .call(vec![col("with_extension")]) + .alias("with_bool8_extension"), + ])? + .build()?; + + let actual = DataFrame::new(ctx.state(), plan).collect().await?; + + // To test for output extension handling, we set the expected values on the result + // To test for input extensions handling, we check the strings returned + let expected_schema = Schema::new(vec![ + Field::new("without_bool8_extension", DataType::Utf8, true) + .with_extension_type(MyUserExtentionType {}), + Field::new("with_bool8_extension", DataType::Utf8, true) + .with_extension_type(MyUserExtentionType {}), + ]); + + let expected = record_batch!( + ("without_bool8_extension", Utf8, ["0", "0", "10", "20"]), + ( + "with_bool8_extension", + Utf8, + ["false", "false", "true", "true"] + ) + )? 
+ .with_schema(Arc::new(expected_schema))?; + + assert_eq!(expected, actual[0]); + + ctx.deregister_table("t")?; + Ok(()) +} diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 0d63bdf9a6c8..9d5aa274da9d 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -21,12 +21,11 @@ use crate::expr::schema_name_from_exprs_comma_separated_without_space; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::sort_properties::{ExprProperties, SortProperties}; use crate::{ColumnarValue, Documentation, Expr, Signature}; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue}; use datafusion_expr_common::interval_arithmetic::Interval; use std::any::Any; use std::cmp::Ordering; -use std::collections::HashMap; use std::fmt::Debug; use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; @@ -297,8 +296,8 @@ where pub struct ScalarFunctionArgs<'a, 'b> { /// The evaluated arguments to the function pub args: Vec, - /// Metadata associated with each arg, if it exists - pub arg_metadata: Vec>>, + /// Field associated with each arg, if it exists + pub arg_fields: Vec>, /// The number of rows in record batch being evaluated pub number_rows: usize, /// The return type of the scalar function returned (from `return_type` or `return_type_from_args`) @@ -723,9 +722,9 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { None } - /// This describes the output metadata associated with this UDF. - /// Input field metadata is handled through `ScalarFunctionArgs` - fn metadata(&self, _input_schema: &Schema) -> Option> { + /// This describes the output field associated with this UDF. 
+ /// Input field is handled through `ScalarFunctionArgs` + fn output_field(&self, _input_schema: &Schema) -> Option { None } } @@ -774,10 +773,6 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.return_type(arg_types) } - fn aliases(&self) -> &[String] { - &self.aliases - } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { self.inner.return_type_from_args(args) } @@ -786,6 +781,10 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.invoke_with_args(args) } + fn aliases(&self) -> &[String] { + &self.aliases + } + fn simplify( &self, args: Vec, diff --git a/datafusion/ffi/src/udf/mod.rs b/datafusion/ffi/src/udf/mod.rs index 7c100bd40c3c..07faed7c0ec1 100644 --- a/datafusion/ffi/src/udf/mod.rs +++ b/datafusion/ffi/src/udf/mod.rs @@ -15,12 +15,12 @@ // specific language governing permissions and limitations // under the License. -use abi_stable::std_types::{RHashMap, ROption}; +use abi_stable::std_types::ROption; use abi_stable::{ std_types::{RResult, RString, RVec}, StableAbi, }; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::{ array::ArrayRef, error::ArrowError, @@ -42,7 +42,6 @@ use return_info::FFI_ReturnInfo; use return_type_args::{ FFI_ReturnTypeArgs, ForeignReturnTypeArgs, ForeignReturnTypeArgsOwned, }; -use std::collections::HashMap; use std::{ffi::c_void, sync::Arc}; use crate::{ @@ -90,7 +89,7 @@ pub struct FFI_ScalarUDF { pub invoke_with_args: unsafe extern "C" fn( udf: &Self, args: RVec, - arg_metadata: RVec>>, + arg_fields: RVec>, num_rows: usize, return_type: WrappedSchema, ) -> RResult, @@ -177,7 +176,7 @@ unsafe extern "C" fn coerce_types_fn_wrapper( unsafe extern "C" fn invoke_with_args_fn_wrapper( udf: &FFI_ScalarUDF, args: RVec, - arg_metadata: RVec>>, + arg_fields: RVec>, number_rows: usize, return_type: WrappedSchema, ) -> RResult { @@ -195,27 +194,24 @@ unsafe extern "C" fn invoke_with_args_fn_wrapper( let args = rresult_return!(args); let return_type = 
rresult_return!(DataType::try_from(&return_type.0)); - let arg_metadata_owned: Vec>> = arg_metadata + let arg_fields_owned = arg_fields .into_iter() - .map(|maybe_map| { - maybe_map - .map(|hashmap| { - hashmap - .into_iter() - .map(|kv| (String::from(kv.0), String::from(kv.1))) - .collect::>() - }) - .into() + .map(|maybe_field| { + Option::from(maybe_field.as_ref().map(|wrapped_field| { + (&wrapped_field.0).try_into().map_err(DataFusionError::from) + })) + .transpose() }) - .collect(); - let arg_metadata = arg_metadata_owned + .collect::>>>(); + let arg_fields_owned = rresult_return!(arg_fields_owned); + let arg_fields = arg_fields_owned .iter() .map(|maybe_map| maybe_map.as_ref()) .collect::>(); let args = ScalarFunctionArgs { args, - arg_metadata, + arg_fields, number_rows, return_type: &return_type, }; @@ -352,7 +348,7 @@ impl ScalarUDFImpl for ForeignScalarUDF { fn invoke_with_args(&self, invoke_args: ScalarFunctionArgs) -> Result { let ScalarFunctionArgs { args, - arg_metadata, + arg_fields, number_rows, return_type, } = invoke_args; @@ -371,20 +367,14 @@ impl ScalarUDFImpl for ForeignScalarUDF { .collect::, ArrowError>>()? 
.into(); - let arg_metadata = arg_metadata + let arg_fields_wrapped = arg_fields + .iter() + .map(|maybe_field| maybe_field.map(FFI_ArrowSchema::try_from).transpose()) + .collect::, ArrowError>>()?; + + let arg_fields = arg_fields_wrapped .into_iter() - .map(|maybe_map| { - maybe_map - .map(|hashmap| { - hashmap - .iter() - .map(|(k, v)| { - (RString::from(k.clone()), RString::from(v.clone())) - }) - .collect::>() - }) - .into() - }) + .map(|maybe_field| maybe_field.map(WrappedSchema).into()) .collect::>(); let return_type = WrappedSchema(FFI_ArrowSchema::try_from(return_type)?); @@ -393,7 +383,7 @@ impl ScalarUDFImpl for ForeignScalarUDF { (self.udf.invoke_with_args)( &self.udf, args, - arg_metadata, + arg_fields, number_rows, return_type, ) diff --git a/datafusion/functions-nested/benches/map.rs b/datafusion/functions-nested/benches/map.rs index 579dd155a3cc..5813dd4109b3 100644 --- a/datafusion/functions-nested/benches/map.rs +++ b/datafusion/functions-nested/benches/map.rs @@ -103,7 +103,7 @@ fn criterion_benchmark(c: &mut Criterion) { map_udf() .invoke_with_args(ScalarFunctionArgs { args: vec![keys.clone(), values.clone()], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: 1, return_type, }) diff --git a/datafusion/functions/benches/character_length.rs b/datafusion/functions/benches/character_length.rs index d4c28e4153e1..26d9aa93d92f 100644 --- a/datafusion/functions/benches/character_length.rs +++ b/datafusion/functions/benches/character_length.rs @@ -40,7 +40,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_ascii.clone(), - arg_metadata: vec![None; args_string_ascii.len()], + arg_fields: vec![None; args_string_ascii.len()], number_rows: n_rows, return_type: &return_type, })) @@ -56,7 +56,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: 
args_string_utf8.clone(), - arg_metadata: vec![None; args_string_utf8.len()], + arg_fields: vec![None; args_string_utf8.len()], number_rows: n_rows, return_type: &return_type, })) @@ -72,7 +72,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_view_ascii.clone(), - arg_metadata: vec![None; args_string_view_ascii.len()], + arg_fields: vec![None; args_string_view_ascii.len()], number_rows: n_rows, return_type: &return_type, })) @@ -88,7 +88,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_view_utf8.clone(), - arg_metadata: vec![None; args_string_view_utf8.len()], + arg_fields: vec![None; args_string_view_utf8.len()], number_rows: n_rows, return_type: &return_type, })) diff --git a/datafusion/functions/benches/chr.rs b/datafusion/functions/benches/chr.rs index 5f27c1f25a1c..568a8a507e81 100644 --- a/datafusion/functions/benches/chr.rs +++ b/datafusion/functions/benches/chr.rs @@ -56,7 +56,7 @@ fn criterion_benchmark(c: &mut Criterion) { cot_fn .invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, }) diff --git a/datafusion/functions/benches/concat.rs b/datafusion/functions/benches/concat.rs index 18425d2146bc..47992eb28989 100644 --- a/datafusion/functions/benches/concat.rs +++ b/datafusion/functions/benches/concat.rs @@ -45,7 +45,7 @@ fn criterion_benchmark(c: &mut Criterion) { concat() .invoke_with_args(ScalarFunctionArgs { args: args_cloned, - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, }) diff --git a/datafusion/functions/benches/cot.rs b/datafusion/functions/benches/cot.rs index 8801cf3a5995..44519f6dc595 100644 --- a/datafusion/functions/benches/cot.rs +++ 
b/datafusion/functions/benches/cot.rs @@ -39,7 +39,7 @@ fn criterion_benchmark(c: &mut Criterion) { cot_fn .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), - arg_metadata: vec![None; f32_args.len()], + arg_fields: vec![None; f32_args.len()], number_rows: size, return_type: &DataType::Float32, }) @@ -55,7 +55,7 @@ fn criterion_benchmark(c: &mut Criterion) { cot_fn .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), - arg_metadata: vec![None; f64_args.len()], + arg_fields: vec![None; f64_args.len()], number_rows: size, return_type: &DataType::Float64, }) diff --git a/datafusion/functions/benches/date_bin.rs b/datafusion/functions/benches/date_bin.rs index a30c29db2f31..9a741668a9d8 100644 --- a/datafusion/functions/benches/date_bin.rs +++ b/datafusion/functions/benches/date_bin.rs @@ -53,7 +53,7 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![interval.clone(), timestamps.clone()], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: batch_len, return_type: &return_type, }) diff --git a/datafusion/functions/benches/date_trunc.rs b/datafusion/functions/benches/date_trunc.rs index 0bcd7235c44c..c4e9b77f976a 100644 --- a/datafusion/functions/benches/date_trunc.rs +++ b/datafusion/functions/benches/date_trunc.rs @@ -54,7 +54,7 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: batch_len, return_type, }) diff --git a/datafusion/functions/benches/encoding.rs b/datafusion/functions/benches/encoding.rs index a961f17c7eee..5985d1477d07 100644 --- a/datafusion/functions/benches/encoding.rs +++ b/datafusion/functions/benches/encoding.rs @@ -33,7 +33,7 @@ fn criterion_benchmark(c: &mut Criterion) { let encoded = encoding::encode() .invoke_with_args(ScalarFunctionArgs { args: 
vec![ColumnarValue::Array(str_array.clone()), method.clone()], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: size, return_type: &DataType::Utf8, }) @@ -45,7 +45,7 @@ fn criterion_benchmark(c: &mut Criterion) { decode .invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, }) @@ -59,7 +59,7 @@ fn criterion_benchmark(c: &mut Criterion) { let encoded = encoding::encode() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Array(str_array.clone()), method.clone()], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: size, return_type: &DataType::Utf8, }) @@ -71,7 +71,7 @@ fn criterion_benchmark(c: &mut Criterion) { decode .invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, }) diff --git a/datafusion/functions/benches/find_in_set.rs b/datafusion/functions/benches/find_in_set.rs index 21ad978a749d..5f580d855d8f 100644 --- a/datafusion/functions/benches/find_in_set.rs +++ b/datafusion/functions/benches/find_in_set.rs @@ -157,7 +157,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: n_rows, return_type: &DataType::Int32, })) @@ -169,7 +169,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: n_rows, return_type: &DataType::Int32, })) @@ -185,7 +185,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), - 
arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: n_rows, return_type: &DataType::Int32, })) @@ -197,7 +197,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: n_rows, return_type: &DataType::Int32, })) diff --git a/datafusion/functions/benches/gcd.rs b/datafusion/functions/benches/gcd.rs index c5e460c00e89..e208c2c092f0 100644 --- a/datafusion/functions/benches/gcd.rs +++ b/datafusion/functions/benches/gcd.rs @@ -47,7 +47,7 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![array_a.clone(), array_b.clone()], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: 0, return_type: &DataType::Int64, }) @@ -64,7 +64,7 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![array_a.clone(), scalar_b.clone()], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: 0, return_type: &DataType::Int64, }) @@ -81,7 +81,7 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![scalar_a.clone(), scalar_b.clone()], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: 0, return_type: &DataType::Int64, }) diff --git a/datafusion/functions/benches/initcap.rs b/datafusion/functions/benches/initcap.rs index 58c7ddbbb43d..fcebe74fd65a 100644 --- a/datafusion/functions/benches/initcap.rs +++ b/datafusion/functions/benches/initcap.rs @@ -55,7 +55,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8View, })) @@ -70,7 +70,7 @@ fn 
criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8View, })) @@ -83,7 +83,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, })) diff --git a/datafusion/functions/benches/isnan.rs b/datafusion/functions/benches/isnan.rs index 0323fa138ce6..b77abcea2c78 100644 --- a/datafusion/functions/benches/isnan.rs +++ b/datafusion/functions/benches/isnan.rs @@ -38,7 +38,7 @@ fn criterion_benchmark(c: &mut Criterion) { isnan .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), - arg_metadata: vec![None; f32_args.len()], + arg_fields: vec![None; f32_args.len()], number_rows: size, return_type: &DataType::Boolean, }) @@ -54,7 +54,7 @@ fn criterion_benchmark(c: &mut Criterion) { isnan .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), - arg_metadata: vec![None; f64_args.len()], + arg_fields: vec![None; f64_args.len()], number_rows: size, return_type: &DataType::Boolean, }) diff --git a/datafusion/functions/benches/iszero.rs b/datafusion/functions/benches/iszero.rs index 602567418c4e..d0e7148b7364 100644 --- a/datafusion/functions/benches/iszero.rs +++ b/datafusion/functions/benches/iszero.rs @@ -39,7 +39,7 @@ fn criterion_benchmark(c: &mut Criterion) { iszero .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), - arg_metadata: vec![None; f32_args.len()], + arg_fields: vec![None; f32_args.len()], number_rows: batch_len, return_type: &DataType::Boolean, }) @@ -56,7 +56,7 @@ fn criterion_benchmark(c: &mut Criterion) { iszero .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), - arg_metadata: vec![None; f64_args.len()], + arg_fields: 
vec![None; f64_args.len()], number_rows: batch_len, return_type: &DataType::Boolean, }) diff --git a/datafusion/functions/benches/lower.rs b/datafusion/functions/benches/lower.rs index e5c0d2112737..ae172e9766af 100644 --- a/datafusion/functions/benches/lower.rs +++ b/datafusion/functions/benches/lower.rs @@ -129,7 +129,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs { args: args_cloned, - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, })) @@ -144,7 +144,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs { args: args_cloned, - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, })) @@ -160,7 +160,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs { args: args_cloned, - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, })) @@ -186,7 +186,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs{ args: args_cloned, - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, })) @@ -201,7 +201,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs{ args: args_cloned, - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, })) @@ -216,7 +216,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs{ args: args_cloned, - arg_metadata: vec![None; 
args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, })) diff --git a/datafusion/functions/benches/ltrim.rs b/datafusion/functions/benches/ltrim.rs index 818d657ecfc4..938febfddeaf 100644 --- a/datafusion/functions/benches/ltrim.rs +++ b/datafusion/functions/benches/ltrim.rs @@ -145,7 +145,7 @@ fn run_with_string_type( let args_cloned = args.clone(); black_box(ltrim.invoke_with_args(ScalarFunctionArgs { args: args_cloned, - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, })) diff --git a/datafusion/functions/benches/make_date.rs b/datafusion/functions/benches/make_date.rs index 2d67ec8551e6..5037da919730 100644 --- a/datafusion/functions/benches/make_date.rs +++ b/datafusion/functions/benches/make_date.rs @@ -69,7 +69,7 @@ fn criterion_benchmark(c: &mut Criterion) { make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![years.clone(), months.clone(), days.clone()], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: batch_len, return_type: &DataType::Date32, }) @@ -91,7 +91,7 @@ fn criterion_benchmark(c: &mut Criterion) { make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![year.clone(), months.clone(), days.clone()], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: batch_len, return_type: &DataType::Date32, }) @@ -113,7 +113,7 @@ fn criterion_benchmark(c: &mut Criterion) { make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![year.clone(), month.clone(), days.clone()], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: batch_len, return_type: &DataType::Date32, }) @@ -132,7 +132,7 @@ fn criterion_benchmark(c: &mut Criterion) { make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![year.clone(), month.clone(), day.clone()], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: 1, return_type: &DataType::Date32, }) diff --git 
a/datafusion/functions/benches/nullif.rs b/datafusion/functions/benches/nullif.rs index e8421ab54a01..080c2890cee8 100644 --- a/datafusion/functions/benches/nullif.rs +++ b/datafusion/functions/benches/nullif.rs @@ -39,7 +39,7 @@ fn criterion_benchmark(c: &mut Criterion) { nullif .invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, }) diff --git a/datafusion/functions/benches/pad.rs b/datafusion/functions/benches/pad.rs index 94e2178b5e9d..2a488b369fe0 100644 --- a/datafusion/functions/benches/pad.rs +++ b/datafusion/functions/benches/pad.rs @@ -106,7 +106,7 @@ fn criterion_benchmark(c: &mut Criterion) { lpad() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, }) @@ -122,7 +122,7 @@ fn criterion_benchmark(c: &mut Criterion) { lpad() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::LargeUtf8, }) @@ -138,7 +138,7 @@ fn criterion_benchmark(c: &mut Criterion) { lpad() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, }) @@ -158,7 +158,7 @@ fn criterion_benchmark(c: &mut Criterion) { rpad() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, }) @@ -174,7 +174,7 @@ fn criterion_benchmark(c: &mut Criterion) { rpad() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::LargeUtf8, }) @@ 
-191,7 +191,7 @@ fn criterion_benchmark(c: &mut Criterion) { rpad() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, }) diff --git a/datafusion/functions/benches/random.rs b/datafusion/functions/benches/random.rs index 1ad1dc750ef7..a09b03affa11 100644 --- a/datafusion/functions/benches/random.rs +++ b/datafusion/functions/benches/random.rs @@ -34,7 +34,7 @@ fn criterion_benchmark(c: &mut Criterion) { random_func .invoke_with_args(ScalarFunctionArgs { args: vec![], - arg_metadata: vec![], + arg_fields: vec![], number_rows: 8192, return_type: &DataType::Float64, }) @@ -53,7 +53,7 @@ fn criterion_benchmark(c: &mut Criterion) { random_func .invoke_with_args(ScalarFunctionArgs { args: vec![], - arg_metadata: vec![], + arg_fields: vec![], number_rows: 128, return_type: &DataType::Float64, }) diff --git a/datafusion/functions/benches/repeat.rs b/datafusion/functions/benches/repeat.rs index 8db229f9d592..9b2e52aaac35 100644 --- a/datafusion/functions/benches/repeat.rs +++ b/datafusion/functions/benches/repeat.rs @@ -77,7 +77,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(repeat.invoke_with_args(ScalarFunctionArgs { args: args_cloned, - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: repeat_times as usize, return_type: &DataType::Utf8, })) @@ -96,7 +96,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(repeat.invoke_with_args(ScalarFunctionArgs { args: args_cloned, - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: repeat_times as usize, return_type: &DataType::Utf8, })) @@ -115,7 +115,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(repeat.invoke_with_args(ScalarFunctionArgs { args: args_cloned, - arg_metadata: vec![None; args.len()], 
+ arg_fields: vec![None; args.len()], number_rows: repeat_times as usize, return_type: &DataType::Utf8, })) @@ -143,7 +143,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(repeat.invoke_with_args(ScalarFunctionArgs { args: args_cloned, - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: repeat_times as usize, return_type: &DataType::Utf8, })) @@ -162,7 +162,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(repeat.invoke_with_args(ScalarFunctionArgs { args: args_cloned, - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: repeat_times as usize, return_type: &DataType::Utf8, })) @@ -181,7 +181,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(repeat.invoke_with_args(ScalarFunctionArgs { args: args_cloned, - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: repeat_times as usize, return_type: &DataType::Utf8, })) @@ -209,7 +209,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(repeat.invoke_with_args(ScalarFunctionArgs { args: args_cloned, - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: repeat_times as usize, return_type: &DataType::Utf8, })) diff --git a/datafusion/functions/benches/reverse.rs b/datafusion/functions/benches/reverse.rs index 49ef9797947b..6056d313d3dc 100644 --- a/datafusion/functions/benches/reverse.rs +++ b/datafusion/functions/benches/reverse.rs @@ -46,7 +46,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_ascii.clone(), - arg_metadata: vec![None; args_string_ascii.len()], + arg_fields: vec![None; args_string_ascii.len()], number_rows: N_ROWS, return_type: &DataType::Utf8, })) @@ -66,7 +66,7 @@ fn criterion_benchmark(c: &mut Criterion) { 
b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_utf8.clone(), - arg_metadata: vec![None; args_string_utf8.len()], + arg_fields: vec![None; args_string_utf8.len()], number_rows: N_ROWS, return_type: &DataType::Utf8, })) @@ -88,7 +88,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_view_ascii.clone(), - arg_metadata: vec![None; args_string_view_ascii.len()], + arg_fields: vec![None; args_string_view_ascii.len()], number_rows: N_ROWS, return_type: &DataType::Utf8, })) @@ -108,7 +108,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_view_utf8.clone(), - arg_metadata: vec![None; args_string_view_utf8.len()], + arg_fields: vec![None; args_string_view_utf8.len()], number_rows: N_ROWS, return_type: &DataType::Utf8, })) diff --git a/datafusion/functions/benches/signum.rs b/datafusion/functions/benches/signum.rs index 609e284edc2c..46e499542a4f 100644 --- a/datafusion/functions/benches/signum.rs +++ b/datafusion/functions/benches/signum.rs @@ -39,7 +39,7 @@ fn criterion_benchmark(c: &mut Criterion) { signum .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), - arg_metadata: vec![None; f32_args.len()], + arg_fields: vec![None; f32_args.len()], number_rows: batch_len, return_type: &DataType::Float32, }) @@ -57,7 +57,7 @@ fn criterion_benchmark(c: &mut Criterion) { signum .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), - arg_metadata: vec![None; f64_args.len()], + arg_fields: vec![None; f64_args.len()], number_rows: batch_len, return_type: &DataType::Float64, }) diff --git a/datafusion/functions/benches/strpos.rs b/datafusion/functions/benches/strpos.rs index ed21e82d60b4..72902ea16b92 100644 --- a/datafusion/functions/benches/strpos.rs +++ b/datafusion/functions/benches/strpos.rs @@ -117,7 +117,7 @@ fn criterion_benchmark(c: &mut Criterion) { 
b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_ascii.clone(), - arg_metadata: vec![None; args_string_ascii.len()], + arg_fields: vec![None; args_string_ascii.len()], number_rows: n_rows, return_type: &DataType::Int32, })) @@ -133,7 +133,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_utf8.clone(), - arg_metadata: vec![None; args_string_utf8.len()], + arg_fields: vec![None; args_string_utf8.len()], number_rows: n_rows, return_type: &DataType::Int32, })) @@ -149,7 +149,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_view_ascii.clone(), - arg_metadata: vec![None; args_string_view_ascii.len()], + arg_fields: vec![None; args_string_view_ascii.len()], number_rows: n_rows, return_type: &DataType::Int32, })) @@ -165,7 +165,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_view_utf8.clone(), - arg_metadata: vec![None; args_string_view_utf8.len()], + arg_fields: vec![None; args_string_view_utf8.len()], number_rows: n_rows, return_type: &DataType::Int32, })) diff --git a/datafusion/functions/benches/substr.rs b/datafusion/functions/benches/substr.rs index ff3af5b0eec2..4b79d958669e 100644 --- a/datafusion/functions/benches/substr.rs +++ b/datafusion/functions/benches/substr.rs @@ -112,7 +112,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(substr.invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8View, })) @@ -127,7 +127,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(substr.invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], 
number_rows: size, return_type: &DataType::Utf8View, })) @@ -142,7 +142,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(substr.invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8View, })) @@ -169,7 +169,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(substr.invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8View, })) @@ -187,7 +187,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(substr.invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8View, })) @@ -205,7 +205,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(substr.invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8View, })) @@ -232,7 +232,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(substr.invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8View, })) @@ -250,7 +250,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(substr.invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8View, })) @@ -268,7 +268,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(substr.invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; 
args.len()], number_rows: size, return_type: &DataType::Utf8View, })) diff --git a/datafusion/functions/benches/substr_index.rs b/datafusion/functions/benches/substr_index.rs index 204f105212e8..c8ef989c866f 100644 --- a/datafusion/functions/benches/substr_index.rs +++ b/datafusion/functions/benches/substr_index.rs @@ -96,7 +96,7 @@ fn criterion_benchmark(c: &mut Criterion) { substr_index() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: batch_len, return_type: &DataType::Utf8, }) diff --git a/datafusion/functions/benches/to_char.rs b/datafusion/functions/benches/to_char.rs index b3a7692b9f4b..2b6857f6e60b 100644 --- a/datafusion/functions/benches/to_char.rs +++ b/datafusion/functions/benches/to_char.rs @@ -93,7 +93,7 @@ fn criterion_benchmark(c: &mut Criterion) { to_char() .invoke_with_args(ScalarFunctionArgs { args: vec![data.clone(), patterns.clone()], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: batch_len, return_type: &DataType::Utf8, }) @@ -115,7 +115,7 @@ fn criterion_benchmark(c: &mut Criterion) { to_char() .invoke_with_args(ScalarFunctionArgs { args: vec![data.clone(), patterns.clone()], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: batch_len, return_type: &DataType::Utf8, }) @@ -143,7 +143,7 @@ fn criterion_benchmark(c: &mut Criterion) { to_char() .invoke_with_args(ScalarFunctionArgs { args: vec![data.clone(), pattern.clone()], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: 1, return_type: &DataType::Utf8, }) diff --git a/datafusion/functions/benches/to_hex.rs b/datafusion/functions/benches/to_hex.rs index cc18dc21c798..4f89a710146b 100644 --- a/datafusion/functions/benches/to_hex.rs +++ b/datafusion/functions/benches/to_hex.rs @@ -36,7 +36,7 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( hex.invoke_with_args(ScalarFunctionArgs { args: args_cloned, - arg_metadata: 
vec![None; i32_args.len()], + arg_fields: vec![None; i32_args.len()], number_rows: batch_len, return_type: &DataType::Utf8, }) @@ -53,7 +53,7 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( hex.invoke_with_args(ScalarFunctionArgs { args: args_cloned, - arg_metadata: vec![None; i64_args.len()], + arg_fields: vec![None; i64_args.len()], number_rows: batch_len, return_type: &DataType::Utf8, }) diff --git a/datafusion/functions/benches/to_timestamp.rs b/datafusion/functions/benches/to_timestamp.rs index 470e18e373c9..7a5c1b99fb85 100644 --- a/datafusion/functions/benches/to_timestamp.rs +++ b/datafusion/functions/benches/to_timestamp.rs @@ -120,7 +120,7 @@ fn criterion_benchmark(c: &mut Criterion) { to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: vec![string_array.clone()], - arg_metadata: vec![None; 1], + arg_fields: vec![None; 1], number_rows: batch_len, return_type, }) @@ -139,7 +139,7 @@ fn criterion_benchmark(c: &mut Criterion) { to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: vec![string_array.clone()], - arg_metadata: vec![None; 1], + arg_fields: vec![None; 1], number_rows: batch_len, return_type, }) @@ -158,7 +158,7 @@ fn criterion_benchmark(c: &mut Criterion) { to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: vec![string_array.clone()], - arg_metadata: vec![None; 1], + arg_fields: vec![None; 1], number_rows: batch_len, return_type, }) @@ -182,7 +182,7 @@ fn criterion_benchmark(c: &mut Criterion) { to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: batch_len, return_type, }) @@ -214,7 +214,7 @@ fn criterion_benchmark(c: &mut Criterion) { to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: batch_len, return_type, }) @@ -247,7 +247,7 @@ fn criterion_benchmark(c: &mut Criterion) { to_timestamp() 
.invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: batch_len, return_type, }) diff --git a/datafusion/functions/benches/trunc.rs b/datafusion/functions/benches/trunc.rs index 2475f6f78444..71383f9b12b8 100644 --- a/datafusion/functions/benches/trunc.rs +++ b/datafusion/functions/benches/trunc.rs @@ -39,7 +39,7 @@ fn criterion_benchmark(c: &mut Criterion) { trunc .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), - arg_metadata: vec![None; f32_args.len()], + arg_fields: vec![None; f32_args.len()], number_rows: size, return_type: &DataType::Float32, }) @@ -55,7 +55,7 @@ fn criterion_benchmark(c: &mut Criterion) { trunc .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), - arg_metadata: vec![None; f64_args.len()], + arg_fields: vec![None; f64_args.len()], number_rows: size, return_type: &DataType::Float64, }) diff --git a/datafusion/functions/benches/upper.rs b/datafusion/functions/benches/upper.rs index be01ab540927..9b730b49b4f7 100644 --- a/datafusion/functions/benches/upper.rs +++ b/datafusion/functions/benches/upper.rs @@ -42,7 +42,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(upper.invoke_with_args(ScalarFunctionArgs { args: args_cloned, - arg_metadata: vec![None; args.len()], + arg_fields: vec![None; args.len()], number_rows: size, return_type: &DataType::Utf8, })) diff --git a/datafusion/functions/benches/uuid.rs b/datafusion/functions/benches/uuid.rs index 5b4b3bd3489e..4b31477e20f9 100644 --- a/datafusion/functions/benches/uuid.rs +++ b/datafusion/functions/benches/uuid.rs @@ -28,7 +28,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(uuid.invoke_with_args(ScalarFunctionArgs { args: vec![], - arg_metadata: vec![], + arg_fields: vec![], number_rows: 1024, return_type: &DataType::Utf8, })) diff --git a/datafusion/functions/src/core/union_extract.rs 
b/datafusion/functions/src/core/union_extract.rs index 1b12b3c42518..d993f4536f4f 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs @@ -198,7 +198,7 @@ mod tests { )), ColumnarValue::Scalar(ScalarValue::new_utf8("str")), ], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: 1, return_type: &DataType::Utf8, })?; @@ -214,7 +214,7 @@ mod tests { )), ColumnarValue::Scalar(ScalarValue::new_utf8("str")), ], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: 1, return_type: &DataType::Utf8, })?; @@ -230,7 +230,7 @@ mod tests { )), ColumnarValue::Scalar(ScalarValue::new_utf8("str")), ], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: 1, return_type: &DataType::Utf8, })?; diff --git a/datafusion/functions/src/core/version.rs b/datafusion/functions/src/core/version.rs index 4b3499eb0d2c..865489e4517a 100644 --- a/datafusion/functions/src/core/version.rs +++ b/datafusion/functions/src/core/version.rs @@ -105,7 +105,7 @@ mod test { let version = version_udf .invoke_with_args(ScalarFunctionArgs { args: vec![], - arg_metadata: vec![], + arg_fields: vec![], number_rows: 0, return_type: &DataType::Utf8, }) diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index 6000a3a14f10..eafeee70f4a7 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -526,7 +526,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: 1, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -546,7 +546,7 @@ mod tests { ColumnarValue::Array(timestamps), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], - arg_metadata: vec![None; 3], + arg_fields: 
vec![None; 3], number_rows: batch_len, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -563,7 +563,7 @@ mod tests { ))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: 1, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -583,7 +583,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: 1, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -602,7 +602,7 @@ mod tests { milliseconds: 1, }, )))], - arg_metadata: vec![None; 1], + arg_fields: vec![None; 1], number_rows: 1, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -619,7 +619,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: 1, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -642,7 +642,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: 1, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -662,7 +662,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: 1, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -679,7 +679,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), 
ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: 1, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -696,7 +696,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: 1, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -718,7 +718,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), ], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: 1, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -739,7 +739,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: 1, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -763,7 +763,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: 1, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -787,7 +787,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Array(timestamps), ], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: batch_len, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; @@ -916,7 +916,7 @@ mod tests { tz_opt.clone(), )), ], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: batch_len, return_type: &DataType::Timestamp( TimeUnit::Nanosecond, diff 
--git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs index 82d89235eb71..63cb239a4285 100644 --- a/datafusion/functions/src/datetime/date_trunc.rs +++ b/datafusion/functions/src/datetime/date_trunc.rs @@ -731,7 +731,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::from("day")), ColumnarValue::Array(Arc::new(input)), ], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: batch_len, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), }; @@ -894,7 +894,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::from("hour")), ColumnarValue::Array(Arc::new(input)), ], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: batch_len, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), }; diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index 05812ceb43a2..274ac437dd67 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -172,7 +172,7 @@ mod test { fn test_without_timezone() { let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(Int64(Some(1729900800)))], - arg_metadata: vec![None; 1], + arg_fields: vec![None; 1], number_rows: 1, return_type: &DataType::Timestamp(Second, None), }; @@ -195,7 +195,7 @@ mod test { "America/New_York".to_string(), ))), ], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: 2, return_type: &DataType::Timestamp( Second, diff --git a/datafusion/functions/src/datetime/make_date.rs b/datafusion/functions/src/datetime/make_date.rs index 9cea84a756d1..8f8f65fee0b0 100644 --- a/datafusion/functions/src/datetime/make_date.rs +++ b/datafusion/functions/src/datetime/make_date.rs @@ -236,7 +236,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), ColumnarValue::Scalar(ScalarValue::UInt32(Some(14))), ], - arg_metadata: 
vec![None; 3], + arg_fields: vec![None; 3], number_rows: 1, return_type: &DataType::Date32, }; @@ -256,7 +256,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::UInt64(Some(1))), ColumnarValue::Scalar(ScalarValue::UInt32(Some(14))), ], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: 1, return_type: &DataType::Date32, }; @@ -276,7 +276,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("1".to_string()))), ColumnarValue::Scalar(ScalarValue::Utf8(Some("14".to_string()))), ], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: 1, return_type: &DataType::Date32, }; @@ -300,7 +300,7 @@ mod tests { ColumnarValue::Array(months), ColumnarValue::Array(days), ], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: batch_len, return_type: &DataType::Date32, }; @@ -327,7 +327,7 @@ mod tests { // invalid number of arguments let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))], - arg_metadata: vec![None; 1], + arg_fields: vec![None; 1], number_rows: 1, return_type: &DataType::Date32, }; @@ -344,7 +344,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: 1, return_type: &DataType::Date32, }; @@ -361,7 +361,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::UInt64(Some(u64::MAX))), ColumnarValue::Scalar(ScalarValue::Int32(Some(22))), ], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: 1, return_type: &DataType::Date32, }; @@ -378,7 +378,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Int32(Some(22))), ColumnarValue::Scalar(ScalarValue::UInt32(Some(u32::MAX))), ], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: 1, return_type: &DataType::Date32, }; diff --git a/datafusion/functions/src/datetime/to_char.rs 
b/datafusion/functions/src/datetime/to_char.rs index 62202fab171b..f56fc53f6af7 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -387,7 +387,7 @@ mod tests { for (value, format, expected) in scalar_data { let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(value), ColumnarValue::Scalar(format)], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: 1, return_type: &DataType::Utf8, }; @@ -471,7 +471,7 @@ mod tests { ColumnarValue::Scalar(value), ColumnarValue::Array(Arc::new(format) as ArrayRef), ], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: batch_len, return_type: &DataType::Utf8, }; @@ -603,7 +603,7 @@ mod tests { ColumnarValue::Array(value as ArrayRef), ColumnarValue::Scalar(format), ], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: batch_len, return_type: &DataType::Utf8, }; @@ -626,7 +626,7 @@ mod tests { ColumnarValue::Array(value), ColumnarValue::Array(Arc::new(format) as ArrayRef), ], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: batch_len, return_type: &DataType::Utf8, }; @@ -649,7 +649,7 @@ mod tests { // invalid number of arguments let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))], - arg_metadata: vec![None; 1], + arg_fields: vec![None; 1], number_rows: 1, return_type: &DataType::Utf8, }; @@ -665,7 +665,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Int32(Some(1))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: 1, return_type: &DataType::Utf8, }; diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs index f058217c179f..a07ce87f84f0 100644 --- a/datafusion/functions/src/datetime/to_date.rs +++ b/datafusion/functions/src/datetime/to_date.rs @@ -210,7 
+210,7 @@ mod tests { fn test_scalar(sv: ScalarValue, tc: &TestCase) { let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(sv)], - arg_metadata: vec![None; 1], + arg_fields: vec![None; 1], number_rows: 1, return_type: &DataType::Date32, }; @@ -237,7 +237,7 @@ mod tests { let batch_len = date_array.len(); let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Array(Arc::new(date_array))], - arg_metadata: vec![None; 1], + arg_fields: vec![None; 1], number_rows: batch_len, return_type: &DataType::Date32, }; @@ -335,7 +335,7 @@ mod tests { ColumnarValue::Scalar(sv), ColumnarValue::Scalar(format_scalar), ], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: 1, return_type: &DataType::Date32, }; @@ -366,7 +366,7 @@ mod tests { ColumnarValue::Array(Arc::new(date_array)), ColumnarValue::Array(Arc::new(format_array)), ], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: batch_len, return_type: &DataType::Date32, }; @@ -408,7 +408,7 @@ mod tests { ColumnarValue::Scalar(format1_scalar), ColumnarValue::Scalar(format2_scalar), ], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: 1, return_type: &DataType::Date32, }; @@ -438,7 +438,7 @@ mod tests { let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(formatted_date_scalar)], - arg_metadata: vec![None; 1], + arg_fields: vec![None; 1], number_rows: 1, return_type: &DataType::Date32, }; @@ -461,7 +461,7 @@ mod tests { let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(date_scalar)], - arg_metadata: vec![None; 1], + arg_fields: vec![None; 1], number_rows: 1, return_type: &DataType::Date32, }; @@ -487,7 +487,7 @@ mod tests { let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(date_scalar)], - arg_metadata: vec![None; 1], + arg_fields: vec![None; 1], number_rows: 1, return_type: &DataType::Date32, }; diff --git 
a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index 8eab0ad9d293..b82da1aa1edb 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -541,7 +541,7 @@ mod tests { let res = ToLocalTimeFunc::new() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(input)], - arg_metadata: vec![None; 1], + arg_fields: vec![None; 1], number_rows: 1, return_type: &expected.data_type(), }) @@ -605,7 +605,7 @@ mod tests { let batch_size = input.len(); let args = ScalarFunctionArgs { args: vec![ColumnarValue::Array(Arc::new(input))], - arg_metadata: vec![None; 1], + arg_fields: vec![None; 1], number_rows: batch_size, return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), }; diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index 55d6d0944ce0..5e0d0ca903a5 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -1015,7 +1015,7 @@ mod tests { assert!(matches!(rt, Timestamp(_, Some(_)))); let args = datafusion_expr::ScalarFunctionArgs { args: vec![array.clone()], - arg_metadata: vec![None; 1], + arg_fields: vec![None; 1], number_rows: 4, return_type: &rt, }; @@ -1063,7 +1063,7 @@ mod tests { assert!(matches!(rt, Timestamp(_, None))); let args = datafusion_expr::ScalarFunctionArgs { args: vec![array.clone()], - arg_metadata: vec![None; 1], + arg_fields: vec![None; 1], number_rows: 5, return_type: &rt, }; diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index d666a7778037..7b57fdd7a798 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -271,7 +271,7 @@ mod tests { ]))), // num ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))), ], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: 4, 
return_type: &DataType::Float64, }; @@ -284,7 +284,7 @@ mod tests { args: vec![ ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num ], - arg_metadata: vec![None; 1], + arg_fields: vec![None; 1], number_rows: 1, return_type: &DataType::Float64, }; @@ -299,7 +299,7 @@ mod tests { args: vec![ ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num ], - arg_metadata: vec![None; 1], + arg_fields: vec![None; 1], number_rows: 1, return_type: &DataType::Float32, }; @@ -327,7 +327,7 @@ mod tests { args: vec![ ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num ], - arg_metadata: vec![None; 1], + arg_fields: vec![None; 1], number_rows: 1, return_type: &DataType::Float64, }; @@ -356,7 +356,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // num ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num ], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: 1, return_type: &DataType::Float32, }; @@ -385,7 +385,7 @@ mod tests { ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // num ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num ], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: 1, return_type: &DataType::Float64, }; @@ -415,7 +415,7 @@ mod tests { 10.0, 100.0, 1000.0, 10000.0, ]))), // num ], - arg_metadata: vec![None; 1], + arg_fields: vec![None; 1], number_rows: 4, return_type: &DataType::Float64, }; @@ -448,7 +448,7 @@ mod tests { 10.0, 100.0, 1000.0, 10000.0, ]))), // num ], - arg_metadata: vec![None; 1], + arg_fields: vec![None; 1], number_rows: 4, return_type: &DataType::Float32, }; @@ -484,7 +484,7 @@ mod tests { 8.0, 4.0, 81.0, 625.0, ]))), // num ], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: 4, return_type: &DataType::Float64, }; @@ -520,7 +520,7 @@ mod tests { 8.0, 4.0, 81.0, 625.0, ]))), // num ], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: 4, return_type: &DataType::Float32, 
}; diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index a23097040817..f549f18e777e 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -202,7 +202,7 @@ mod tests { 3.0, 2.0, 4.0, 4.0, ]))), // exponent ], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: 4, return_type: &DataType::Float64, }; @@ -233,7 +233,7 @@ mod tests { ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent ], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: 4, return_type: &DataType::Int64, }; diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs index a2b69bb9ce65..f57492ad231e 100644 --- a/datafusion/functions/src/math/signum.rs +++ b/datafusion/functions/src/math/signum.rs @@ -159,7 +159,7 @@ mod test { ])); let args = ScalarFunctionArgs { args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], - arg_metadata: vec![None; 1], + arg_fields: vec![None; 1], number_rows: array.len(), return_type: &DataType::Float32, }; @@ -204,7 +204,7 @@ mod test { ])); let args = ScalarFunctionArgs { args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], - arg_metadata: vec![None; 1], + arg_fields: vec![None; 1], number_rows: array.len(), return_type: &DataType::Float64, }; diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index 955cb2b2a748..044d2229dac1 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -659,7 +659,7 @@ mod tests { let expected = expected.get(pos).cloned(); let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: 2, 
return_type: &Int64, }); @@ -675,7 +675,7 @@ mod tests { let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: 2, return_type: &Int64, }); @@ -691,7 +691,7 @@ mod tests { let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: 2, return_type: &Int64, }); @@ -722,7 +722,7 @@ mod tests { ColumnarValue::Scalar(regex_sv), ColumnarValue::Scalar(start_sv.clone()), ], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: 3, return_type: &Int64, }); @@ -742,7 +742,7 @@ mod tests { ColumnarValue::Scalar(regex_sv), ColumnarValue::Scalar(start_sv.clone()), ], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: 3, return_type: &Int64, }); @@ -762,7 +762,7 @@ mod tests { ColumnarValue::Scalar(regex_sv), ColumnarValue::Scalar(start_sv.clone()), ], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: 3, return_type: &Int64, }); @@ -796,7 +796,7 @@ mod tests { ColumnarValue::Scalar(start_sv.clone()), ColumnarValue::Scalar(flags_sv.clone()), ], - arg_metadata: vec![None; 4], + arg_fields: vec![None; 4], number_rows: 4, return_type: &Int64, }); @@ -818,7 +818,7 @@ mod tests { ColumnarValue::Scalar(start_sv.clone()), ColumnarValue::Scalar(flags_sv.clone()), ], - arg_metadata: vec![None; 4], + arg_fields: vec![None; 4], number_rows: 4, return_type: &Int64, }); @@ -840,7 +840,7 @@ mod tests { ColumnarValue::Scalar(start_sv.clone()), ColumnarValue::Scalar(flags_sv.clone()), ], - arg_metadata: vec![None; 4], + arg_fields: vec![None; 4], number_rows: 4, return_type: &Int64, }); @@ -923,7 
+923,7 @@ mod tests { ColumnarValue::Scalar(start_sv.clone()), ColumnarValue::Scalar(flags_sv.clone()), ], - arg_metadata: vec![None; 4], + arg_fields: vec![None; 4], number_rows: 4, return_type: &Int64, }); @@ -945,7 +945,7 @@ mod tests { ColumnarValue::Scalar(start_sv.clone()), ColumnarValue::Scalar(flags_sv.clone()), ], - arg_metadata: vec![None; 4], + arg_fields: vec![None; 4], number_rows: 4, return_type: &Int64, }); @@ -967,7 +967,7 @@ mod tests { ColumnarValue::Scalar(start_sv.clone()), ColumnarValue::Scalar(flags_sv.clone()), ], - arg_metadata: vec![None; 4], + arg_fields: vec![None; 4], number_rows: 4, return_type: &Int64, }); diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 4f05752b820b..220db2d4b655 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -471,7 +471,7 @@ mod tests { let args = ScalarFunctionArgs { args: vec![c0, c1, c2, c3, c4], - arg_metadata: vec![None; 5], + arg_fields: vec![None; 5], number_rows: 3, return_type: &Utf8, }; diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 0c00db841acc..8bf25587d9aa 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -483,7 +483,7 @@ mod tests { let args = ScalarFunctionArgs { args: vec![c0, c1, c2], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: 3, return_type: &Utf8, }; @@ -514,7 +514,7 @@ mod tests { let args = ScalarFunctionArgs { args: vec![c0, c1, c2], - arg_metadata: vec![None; 3], + arg_fields: vec![None; 3], number_rows: 3, return_type: &Utf8, }; diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index 4e90c216b9f3..aef5a345fb5c 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -167,7 +167,7 @@ mod test { let args = ScalarFunctionArgs 
{ args: vec![array, scalar], - arg_metadata: vec![None; 2], + arg_fields: vec![None; 2], number_rows: 2, return_type: &DataType::Boolean, }; diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index fe7af7bd2d56..8dd5eef5ea78 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -106,7 +106,7 @@ mod tests { let args = ScalarFunctionArgs { number_rows: input.len(), args: vec![ColumnarValue::Array(input)], - arg_metadata: vec![None; 1], + arg_fields: vec![None; 1], return_type: &DataType::Utf8, }; diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index 5a36ad6cbc85..475632f8d4f9 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -105,7 +105,7 @@ mod tests { let args = ScalarFunctionArgs { number_rows: input.len(), args: vec![ColumnarValue::Array(input)], - arg_metadata: vec![None; 1], + arg_fields: vec![None; 1], return_type: &DataType::Utf8, }; diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs index 3ca81f017e28..35a0682bdc04 100644 --- a/datafusion/functions/src/unicode/find_in_set.rs +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -471,10 +471,10 @@ mod tests { }) .unwrap_or(1); let return_type = fis.return_type(&type_array)?; - let arg_metadata = vec![None; args.len()]; + let arg_fields = vec![None; args.len()]; let result = fis.invoke_with_args(ScalarFunctionArgs { args, - arg_metadata, + arg_fields, number_rows: cardinality, return_type: &return_type, }); diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index 65d01d218630..ebc3be99dbab 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -165,8 +165,8 @@ pub mod test { let (return_type, _nullable) = return_info.unwrap().into_parts(); assert_eq!(return_type, $EXPECTED_DATA_TYPE); - 
let arg_metadata = vec![None; $ARGS.len()]; - let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, arg_metadata, number_rows: cardinality, return_type: &return_type}); + let arg_fields = vec![None; $ARGS.len()]; + let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, arg_fields, number_rows: cardinality, return_type: &return_type}); assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err()); let result = result.unwrap().to_array(cardinality).expect("Failed to convert to array"); @@ -189,9 +189,9 @@ pub mod test { else { let (return_type, _nullable) = return_info.unwrap().into_parts(); - let arg_metadata = vec![None; $ARGS.len()]; + let arg_fields = vec![None; $ARGS.len()]; // invoke is expected error - cannot use .expect_err() due to Debug not being implemented - match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, arg_metadata, number_rows: cardinality, return_type: &return_type}) { + match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, arg_fields, number_rows: cardinality, return_type: &return_type}) { Ok(_) => assert!(false, "expected error"), Err(error) => { assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace())); diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 8a00b3b7c018..09d372db2425 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -16,7 +16,6 @@ // under the License. 
use std::any::Any; -use std::collections::HashMap; use std::fmt; use std::fmt::{Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; @@ -26,7 +25,7 @@ use crate::utils::scatter; use arrow::array::BooleanArray; use arrow::compute::filter_record_batch; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue}; @@ -77,8 +76,8 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash { fn nullable(&self, input_schema: &Schema) -> Result; /// Evaluate an expression against a RecordBatch fn evaluate(&self, batch: &RecordBatch) -> Result; - /// Determine if this expression has any associated field metadata in the schema - fn metadata(&self, input_schema: &Schema) -> Result>>; + /// The output field associated with this expression + fn output_field(&self, input_schema: &Schema) -> Result>; /// Evaluate an expression against a RecordBatch after first applying a /// validity array fn evaluate_selection( @@ -460,7 +459,7 @@ where /// # use std::fmt::Formatter; /// # use std::sync::Arc; /// # use arrow::array::RecordBatch; -/// # use arrow::datatypes::{DataType, Schema}; +/// # use arrow::datatypes::{DataType, Field, Schema}; /// # use datafusion_common::Result; /// # use datafusion_expr_common::columnar_value::ColumnarValue; /// # use datafusion_physical_expr_common::physical_expr::{fmt_sql, DynEq, PhysicalExpr}; @@ -470,7 +469,7 @@ where /// # fn data_type(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn nullable(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn evaluate(&self, batch: &RecordBatch) -> Result { unimplemented!() } -/// # fn metadata(&self, input_schema: &Schema) -> Result>> { unimplemented!() } +/// # fn output_field(&self, input_schema: &Schema) -> Result> { unimplemented!() } 
/// # fn children(&self) -> Vec<&Arc>{ unimplemented!() } /// # fn with_new_children(self: Arc, children: Vec>) -> Result> { unimplemented!() } /// # fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "CASE a > b THEN 1 ELSE 0 END") } diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs index 0fa1f53f31fb..04c7719ecfa1 100644 --- a/datafusion/physical-expr-common/src/sort_expr.rs +++ b/datafusion/physical-expr-common/src/sort_expr.rs @@ -44,7 +44,7 @@ use itertools::Itertools; /// # use arrow::array::RecordBatch; /// # use datafusion_common::Result; /// # use arrow::compute::SortOptions; -/// # use arrow::datatypes::{DataType, Schema}; +/// # use arrow::datatypes::{DataType, Field, Schema}; /// # use datafusion_expr_common::columnar_value::ColumnarValue; /// # use datafusion_physical_expr_common::physical_expr::PhysicalExpr; /// # use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; @@ -57,7 +57,7 @@ use itertools::Itertools; /// # fn data_type(&self, input_schema: &Schema) -> Result {todo!()} /// # fn nullable(&self, input_schema: &Schema) -> Result {todo!() } /// # fn evaluate(&self, batch: &RecordBatch) -> Result {todo!() } -/// # fn metadata(&self, input_schema: &Schema) -> Result>> { unimplemented!() } +/// # fn output_field(&self, input_schema: &Schema) -> Result> { unimplemented!() } /// # fn children(&self) -> Vec<&Arc> {todo!()} /// # fn with_new_children(self: Arc, children: Vec>) -> Result> {todo!()} /// # fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { todo!() } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index e0296a38a08b..0b0b25e68ef8 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -20,7 +20,6 @@ mod kernels; use crate::expressions::binary::kernels::concat_elements_utf8view; use 
crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; use crate::PhysicalExpr; -use std::collections::HashMap; use std::hash::Hash; use std::{any::Any, sync::Arc}; @@ -433,10 +432,7 @@ impl PhysicalExpr for BinaryExpr { .map(ColumnarValue::Array) } - fn metadata( - &self, - _input_schema: &Schema, - ) -> Result>> { + fn output_field(&self, _input_schema: &Schema) -> Result> { Ok(None) } diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 00f0b4272457..291dc7788e3b 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -18,14 +18,13 @@ use crate::expressions::try_cast; use crate::PhysicalExpr; use std::borrow::Cow; -use std::collections::HashMap; use std::hash::Hash; use std::{any::Any, sync::Arc}; use arrow::array::*; use arrow::compute::kernels::zip::zip; use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter}; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::cast::as_boolean_array; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, @@ -514,10 +513,7 @@ impl PhysicalExpr for CaseExpr { } } - fn metadata( - &self, - _input_schema: &Schema, - ) -> Result>> { + fn output_field(&self, _input_schema: &Schema) -> Result> { Ok(None) } @@ -610,7 +606,7 @@ mod tests { use crate::expressions::{binary, cast, col, lit, BinaryExpr}; use arrow::buffer::Buffer; use arrow::datatypes::DataType::Float64; - use arrow::datatypes::*; + use datafusion_common::cast::{as_float64_array, as_int32_array}; use datafusion_common::plan_err; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 7ca394ef6049..57f2d61aa968 100644 --- 
a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -16,7 +16,6 @@ // under the License. use std::any::Any; -use std::collections::HashMap; use std::fmt; use std::hash::Hash; use std::sync::Arc; @@ -24,7 +23,7 @@ use std::sync::Arc; use crate::physical_expr::PhysicalExpr; use arrow::compute::{can_cast_types, CastOptions}; -use arrow::datatypes::{DataType, DataType::*, Schema}; +use arrow::datatypes::{DataType, DataType::*, Field, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; use datafusion_common::{not_impl_err, Result}; @@ -145,8 +144,8 @@ impl PhysicalExpr for CastExpr { value.cast_to(&self.cast_type, Some(&self.cast_options)) } - fn metadata(&self, input_schema: &Schema) -> Result>> { - self.expr.metadata(input_schema) + fn output_field(&self, input_schema: &Schema) -> Result> { + self.expr.output_field(input_schema) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index 73abed8c130b..b88880ea7670 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -18,13 +18,12 @@ //! 
Physical column reference: [`Column`] use std::any::Any; -use std::collections::HashMap; use std::hash::Hash; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; use arrow::{ - datatypes::{DataType, Schema, SchemaRef}, + datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::RecordBatch, }; use datafusion_common::tree_node::{Transformed, TreeNode}; @@ -128,8 +127,8 @@ impl PhysicalExpr for Column { Ok(ColumnarValue::Array(Arc::clone(batch.column(self.index)))) } - fn metadata(&self, input_schema: &Schema) -> Result>> { - Ok(Some(input_schema.field(self.index).metadata().clone())) + fn output_field(&self, input_schema: &Schema) -> Result> { + Ok(Some(input_schema.field(self.index).clone())) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index fd4c692d41f3..704cf47879f9 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -379,10 +379,7 @@ impl PhysicalExpr for InListExpr { Ok(ColumnarValue::Array(Arc::new(r))) } - fn metadata( - &self, - _input_schema: &Schema, - ) -> Result>> { + fn output_field(&self, _input_schema: &Schema) -> Result> { Ok(None) } diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index d2d722d9efe5..f16114259c0f 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -18,6 +18,7 @@ //! 
IS NOT NULL expression use crate::PhysicalExpr; +use arrow::datatypes::Field; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, @@ -25,7 +26,6 @@ use arrow::{ use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::ColumnarValue; -use std::collections::HashMap; use std::hash::Hash; use std::{any::Any, sync::Arc}; @@ -94,8 +94,8 @@ impl PhysicalExpr for IsNotNullExpr { } } - fn metadata(&self, input_schema: &Schema) -> Result>> { - self.arg.metadata(input_schema) + fn output_field(&self, input_schema: &Schema) -> Result> { + self.arg.output_field(input_schema) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index cb6d7b63c5d3..32ef64fe0230 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -19,13 +19,12 @@ use crate::PhysicalExpr; use arrow::{ - datatypes::{DataType, Schema}, + datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::ColumnarValue; -use std::collections::HashMap; use std::hash::Hash; use std::{any::Any, sync::Arc}; @@ -93,8 +92,8 @@ impl PhysicalExpr for IsNullExpr { } } - fn metadata(&self, input_schema: &Schema) -> Result>> { - self.arg.metadata(input_schema) + fn output_field(&self, input_schema: &Schema) -> Result> { + self.arg.output_field(input_schema) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index 3f82834df817..fa51def9eb09 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -16,12 +16,11 @@ // under the License. 
use crate::PhysicalExpr; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; use datafusion_physical_expr_common::datum::apply_cmp; -use std::collections::HashMap; use std::hash::Hash; use std::{any::Any, sync::Arc}; @@ -130,10 +129,7 @@ impl PhysicalExpr for LikeExpr { } } - fn metadata( - &self, - _input_schema: &Schema, - ) -> Result>> { + fn output_field(&self, _input_schema: &Schema) -> Result> { Ok(None) } diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index 5fa1b99bf2af..08c4bd38b383 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -18,12 +18,12 @@ //! Literal expressions for physical operations use std::any::Any; -use std::collections::HashMap; use std::hash::Hash; use std::sync::Arc; use crate::physical_expr::PhysicalExpr; +use arrow::datatypes::Field; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, @@ -76,10 +76,7 @@ impl PhysicalExpr for Literal { Ok(ColumnarValue::Scalar(self.value.clone())) } - fn metadata( - &self, - _input_schema: &Schema, - ) -> Result>> { + fn output_field(&self, _input_schema: &Schema) -> Result> { Ok(None) } @@ -120,7 +117,7 @@ mod tests { use super::*; use arrow::array::Int32Array; - use arrow::datatypes::*; + use datafusion_common::cast::as_int32_array; use datafusion_physical_expr_common::physical_expr::fmt_sql; diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index 8d2431d1f2a4..fe7b7cab429f 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -18,12 +18,12 @@ //! 
Negation (-) expression use std::any::Any; -use std::collections::HashMap; use std::hash::Hash; use std::sync::Arc; use crate::PhysicalExpr; +use arrow::datatypes::Field; use arrow::{ compute::kernels::numeric::neg_wrapping, datatypes::{DataType, Schema}, @@ -104,8 +104,8 @@ impl PhysicalExpr for NegativeExpr { } } - fn metadata(&self, input_schema: &Schema) -> Result>> { - self.arg.metadata(input_schema) + fn output_field(&self, input_schema: &Schema) -> Result> { + self.arg.output_field(input_schema) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-expr/src/expressions/no_op.rs b/datafusion/physical-expr/src/expressions/no_op.rs index 9213657d3123..c918aa5f23f2 100644 --- a/datafusion/physical-expr/src/expressions/no_op.rs +++ b/datafusion/physical-expr/src/expressions/no_op.rs @@ -18,16 +18,15 @@ //! NoOp placeholder for physical operations use std::any::Any; -use std::collections::HashMap; use std::hash::Hash; use std::sync::Arc; +use crate::PhysicalExpr; +use arrow::datatypes::Field; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, }; - -use crate::PhysicalExpr; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; @@ -68,10 +67,7 @@ impl PhysicalExpr for NoOp { internal_err!("NoOp::evaluate() should not be called") } - fn metadata( - &self, - _input_schema: &Schema, - ) -> Result>> { + fn output_field(&self, _input_schema: &Schema) -> Result> { Ok(None) } diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index 317ae98585c0..d23b708efd36 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -18,14 +18,13 @@ //! 
Not expression use std::any::Any; -use std::collections::HashMap; use std::fmt; use std::hash::Hash; use std::sync::Arc; use crate::PhysicalExpr; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::{cast::as_boolean_array, internal_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; @@ -102,8 +101,8 @@ impl PhysicalExpr for NotExpr { } } - fn metadata(&self, input_schema: &Schema) -> Result>> { - self.arg.metadata(input_schema) + fn output_field(&self, input_schema: &Schema) -> Result> { + self.arg.output_field(input_schema) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index 327ed4a0e627..e12f5af94360 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -16,7 +16,6 @@ // under the License. 
use std::any::Any; -use std::collections::HashMap; use std::fmt; use std::hash::Hash; use std::sync::Arc; @@ -24,7 +23,7 @@ use std::sync::Arc; use crate::PhysicalExpr; use arrow::compute; use arrow::compute::{cast_with_options, CastOptions}; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use compute::can_cast_types; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; @@ -111,8 +110,8 @@ impl PhysicalExpr for TryCastExpr { } } - fn metadata(&self, input_schema: &Schema) -> Result>> { - self.expr.metadata(input_schema) + fn output_field(&self, input_schema: &Schema) -> Result> { + self.expr.output_field(input_schema) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-expr/src/expressions/unknown_column.rs b/datafusion/physical-expr/src/expressions/unknown_column.rs index c719e7f844c8..f0c18e785f28 100644 --- a/datafusion/physical-expr/src/expressions/unknown_column.rs +++ b/datafusion/physical-expr/src/expressions/unknown_column.rs @@ -18,12 +18,12 @@ //! 
UnKnownColumn expression use std::any::Any; -use std::collections::HashMap; use std::hash::{Hash, Hasher}; use std::sync::Arc; use crate::PhysicalExpr; +use arrow::datatypes::Field; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, @@ -77,10 +77,7 @@ impl PhysicalExpr for UnKnownColumn { internal_err!("UnKnownColumn::evaluate() should not be called") } - fn metadata( - &self, - _input_schema: &Schema, - ) -> Result>> { + fn output_field(&self, _input_schema: &Schema) -> Result> { Ok(None) } diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 08a3d561b7ed..12758a473fe9 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -39,7 +39,7 @@ use crate::expressions::Literal; use crate::PhysicalExpr; use arrow::array::{Array, RecordBatch}; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; @@ -211,12 +211,12 @@ impl PhysicalExpr for ScalarFunctionExpr { .map(|e| e.evaluate(batch)) .collect::>>()?; - let arg_metadata_owned = self + let arg_fields_owned = self .args .iter() - .map(|e| e.metadata(batch.schema_ref())) + .map(|e| e.output_field(batch.schema_ref())) .collect::>>()?; - let arg_metadata = arg_metadata_owned + let arg_fields = arg_fields_owned .iter() .map(|opt_map| opt_map.as_ref()) .collect::>(); @@ -229,7 +229,7 @@ impl PhysicalExpr for ScalarFunctionExpr { // evaluate the function let output = self.fun.invoke_with_args(ScalarFunctionArgs { args, - arg_metadata, + arg_fields, number_rows: batch.num_rows(), return_type: &self.return_type, })?; @@ -251,8 +251,8 @@ impl PhysicalExpr for ScalarFunctionExpr { Ok(output) } - fn metadata(&self, input_schema: &Schema) -> Result>> { - 
Ok(self.fun.as_ref().inner().metadata(input_schema)) + fn output_field(&self, input_schema: &Schema) -> Result> { + Ok(self.fun.as_ref().inner().output_field(input_schema)) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 1bee720b3680..785752c4cd08 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -284,7 +284,11 @@ impl PhysicalGroupBy { expr.data_type(input_schema)?, group_expr_nullable || expr.nullable(input_schema)?, ) - .with_metadata(expr.metadata(input_schema)?.clone().unwrap_or_default()), + .with_metadata( + expr.output_field(input_schema)? + .map(|f| f.metadata().clone()) + .unwrap_or_default(), + ), ); } if !self.is_single() { diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index 5eebf307820e..b5be14427bd0 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -79,13 +79,17 @@ impl ProjectionExec { let fields: Result> = expr .iter() .map(|(e, name)| { - let mut field = Field::new( + let metadata = e + .output_field(&input_schema)? 
+ .map(|field| field.metadata().clone()) + .unwrap_or_default(); + + let field = Field::new( name, e.data_type(&input_schema)?, e.nullable(&input_schema)?, - ); - field - .set_metadata(e.metadata(&input_schema)?.clone().unwrap_or_default()); + ) + .with_metadata(metadata); Ok(field) }) @@ -197,23 +201,11 @@ impl ExecutionPlan for ProjectionExec { &self.cache } - fn children(&self) -> Vec<&Arc> { - vec![&self.input] - } - fn maintains_input_order(&self) -> Vec { // Tell optimizer this operator doesn't reorder its input vec![true] } - fn with_new_children( - self: Arc, - mut children: Vec>, - ) -> Result> { - ProjectionExec::try_new(self.expr.clone(), children.swap_remove(0)) - .map(|p| Arc::new(p) as _) - } - fn benefits_from_input_partitioning(&self) -> Vec { let all_simple_exprs = self .expr @@ -224,6 +216,18 @@ impl ExecutionPlan for ProjectionExec { vec![!all_simple_exprs] } + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> Result> { + ProjectionExec::try_new(self.expr.clone(), children.swap_remove(0)) + .map(|p| Arc::new(p) as _) + } + fn execute( &self, partition: usize, @@ -1074,13 +1078,11 @@ mod tests { let task_ctx = Arc::new(TaskContext::default()); let exec = test::scan_partitioned(1); - let expected = collect(exec.execute(0, Arc::clone(&task_ctx))?) 
- .await - .unwrap(); + let expected = collect(exec.execute(0, Arc::clone(&task_ctx))?).await?; let projection = ProjectionExec::try_new(vec![], exec)?; let stream = projection.execute(0, Arc::clone(&task_ctx))?; - let output = collect(stream).await.unwrap(); + let output = collect(stream).await?; assert_eq!(output.len(), expected.len()); Ok(()) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 5452dc09c48c..6387578be1fa 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -865,10 +865,7 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { unreachable!() } - fn metadata( - &self, - _input_schema: &Schema, - ) -> Result>> { + fn output_field(&self, _input_schema: &Schema) -> Result> { Ok(None) } From 281a83e717cda9b1e80635b50b32126b9c32eaad Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 15 Apr 2025 15:35:42 -0400 Subject: [PATCH 08/25] Switching return_type_from_args to return_field_from_args --- .../user_defined_scalar_functions.rs | 18 +++-- datafusion/expr/src/expr_schema.rs | 23 +++--- datafusion/expr/src/lib.rs | 2 +- datafusion/expr/src/udf.rs | 77 ++++--------------- datafusion/ffi/src/udf/mod.rs | 39 +++++----- datafusion/ffi/src/udf/return_info.rs | 53 ------------- datafusion/ffi/src/udf/return_type_args.rs | 66 +++++++--------- datafusion/ffi/src/util.rs | 27 ++++++- datafusion/functions/src/core/arrow_cast.rs | 15 ++-- datafusion/functions/src/core/coalesce.rs | 17 ++-- datafusion/functions/src/core/getfield.rs | 18 ++--- datafusion/functions/src/core/named_struct.rs | 14 ++-- .../functions/src/core/union_extract.rs | 20 ++--- .../functions/src/datetime/date_part.rs | 12 +-- .../functions/src/datetime/from_unixtime.rs | 17 ++-- datafusion/functions/src/datetime/now.rs | 15 ++-- datafusion/functions/src/unicode/strpos.rs | 23 +++--- datafusion/functions/src/utils.rs | 
7 +- .../physical-expr/src/scalar_function.rs | 33 ++++---- 19 files changed, 203 insertions(+), 293 deletions(-) delete mode 100644 datafusion/ffi/src/udf/return_info.rs diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index c0b5c8d0d23e..333b3339d652 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -43,8 +43,8 @@ use datafusion_common::{ use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, LogicalPlanBuilder, - OperateFunctionArg, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, - ScalarUDFImpl, Signature, Volatility, + OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, + Signature, Volatility, }; use datafusion_functions_nested::range::range_udf; use parking_lot::Mutex; @@ -806,7 +806,7 @@ impl ScalarUDFImpl for TakeUDF { &self.signature } fn return_type(&self, _args: &[DataType]) -> Result { - not_impl_err!("Not called because the return_type_from_args is implemented") + not_impl_err!("Not called because the return_field_from_args is implemented") } /// This function returns the type of the first or second argument based on @@ -814,9 +814,9 @@ impl ScalarUDFImpl for TakeUDF { /// /// 1. If the third argument is '0', return the type of the first argument /// 2. 
If the third argument is '1', return the type of the second argument - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - if args.arg_types.len() != 3 { - return plan_err!("Expected 3 arguments, got {}.", args.arg_types.len()); + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + if args.arg_fields.len() != 3 { + return plan_err!("Expected 3 arguments, got {}.", args.arg_fields.len()); } let take_idx = if let Some(take_idx) = args.scalar_arguments.get(2) { @@ -841,8 +841,10 @@ impl ScalarUDFImpl for TakeUDF { ); }; - Ok(ReturnInfo::new_nullable( - args.arg_types[take_idx].to_owned(), + Ok(Field::new( + self.name(), + args.arg_fields[take_idx].data_type().to_owned(), + true, )) } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index a349c83a4934..b93389679e60 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -24,7 +24,7 @@ use crate::expr::{ use crate::type_coercion::functions::{ data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf, }; -use crate::udf::ReturnTypeArgs; +use crate::udf::ReturnFieldArgs; use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition}; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field}; @@ -418,11 +418,12 @@ impl ExprSchemable for Expr { self.data_type_and_nullable_with_window_function(schema, window_function) } Expr::ScalarFunction(ScalarFunction { func, args }) => { - let (arg_types, nullables): (Vec, Vec) = args + let (arg_types, fields): (Vec, Vec>) = args .iter() - .map(|e| e.data_type_and_nullable(schema)) + .map(|e| e.to_field(schema).map(|(_, f)| f)) .collect::>>()? 
.into_iter() + .map(|f| (f.data_type().clone(), f)) .unzip(); // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` let new_data_types = data_types_with_scalar_udf(&arg_types, func) @@ -440,6 +441,12 @@ impl ExprSchemable for Expr { ) ) })?; + let new_fields = fields.into_iter() + .zip(new_data_types) + .map(|(f, d)| { + f.as_ref().clone().with_data_type(d) + }) + .collect::>(); let arguments = args .iter() @@ -448,15 +455,13 @@ impl ExprSchemable for Expr { _ => None, }) .collect::>(); - let args = ReturnTypeArgs { - arg_types: &new_data_types, + let args = ReturnFieldArgs { + arg_fields: &new_fields, scalar_arguments: &arguments, - nullables: &nullables, }; - let (return_type, nullable) = - func.return_type_from_args(args)?.into_parts(); - Ok((return_type, nullable)) + let return_field = func.return_field_from_args(args)?; + Ok((return_field.data_type().clone(), return_field.is_nullable())) } _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), } diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index d3cc881af361..258c28dc89c5 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -104,7 +104,7 @@ pub use udaf::{ SetMonotonicity, StatisticsArgs, }; pub use udf::{ - scalar_doc_sections, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, + scalar_doc_sections, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, }; pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl}; diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 9d5aa274da9d..b06a1b273216 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -170,7 +170,7 @@ impl ScalarUDF { /// /// # Notes /// - /// If a function implement [`ScalarUDFImpl::return_type_from_args`], + /// If a function implement [`ScalarUDFImpl::return_field_from_args`], /// its [`ScalarUDFImpl::return_type`] should raise an error. 
/// /// See [`ScalarUDFImpl::return_type`] for more details. @@ -180,9 +180,9 @@ impl ScalarUDF { /// Return the datatype this function returns given the input argument types. /// - /// See [`ScalarUDFImpl::return_type_from_args`] for more details. - pub fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - self.inner.return_type_from_args(args) + /// See [`ScalarUDFImpl::return_field_from_args`] for more details. + pub fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + self.inner.return_field_from_args(args) } /// Do the function rewrite @@ -300,7 +300,7 @@ pub struct ScalarFunctionArgs<'a, 'b> { pub arg_fields: Vec>, /// The number of rows in record batch being evaluated pub number_rows: usize, - /// The return type of the scalar function returned (from `return_type` or `return_type_from_args`) + /// The return type of the scalar function returned (from `return_type` or `return_field_from_args`) /// when creating the physical expression from the logical expression pub return_type: &'a DataType, } @@ -311,11 +311,11 @@ pub struct ScalarFunctionArgs<'a, 'b> { /// such as the type of the arguments, any scalar arguments and if the /// arguments can (ever) be null /// -/// See [`ScalarUDFImpl::return_type_from_args`] for more information +/// See [`ScalarUDFImpl::return_field_from_args`] for more information #[derive(Debug)] -pub struct ReturnTypeArgs<'a> { +pub struct ReturnFieldArgs<'a> { /// The data types of the arguments to the function - pub arg_types: &'a [DataType], + pub arg_fields: &'a [Field], /// Is argument `i` to the function a scalar (constant) /// /// If argument `i` is not a scalar, it will be None @@ -323,52 +323,6 @@ pub struct ReturnTypeArgs<'a> { /// For example, if a function is called like `my_function(column_a, 5)` /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]` pub scalar_arguments: &'a [Option<&'a ScalarValue>], - /// Can argument `i` (ever) null? 
- pub nullables: &'a [bool], -} - -/// Return metadata for this function. -/// -/// See [`ScalarUDFImpl::return_type_from_args`] for more information -#[derive(Debug)] -pub struct ReturnInfo { - return_type: DataType, - nullable: bool, -} - -impl ReturnInfo { - pub fn new(return_type: DataType, nullable: bool) -> Self { - Self { - return_type, - nullable, - } - } - - pub fn new_nullable(return_type: DataType) -> Self { - Self { - return_type, - nullable: true, - } - } - - pub fn new_non_nullable(return_type: DataType) -> Self { - Self { - return_type, - nullable: false, - } - } - - pub fn return_type(&self) -> &DataType { - &self.return_type - } - - pub fn nullable(&self) -> bool { - self.nullable - } - - pub fn into_parts(self) -> (DataType, bool) { - (self.return_type, self.nullable) - } } /// Trait for implementing user defined scalar functions. @@ -482,7 +436,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// /// # Notes /// - /// If you provide an implementation for [`Self::return_type_from_args`], + /// If you provide an implementation for [`Self::return_field_from_args`], /// DataFusion will not call `return_type` (this function). In such cases /// is recommended to return [`DataFusionError::Internal`]. /// @@ -520,14 +474,15 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// This function **must** consistently return the same type for the same /// logical input even if the input is simplified (e.g. it must return the same /// value for `('foo' | 'bar')` as it does for ('foobar'). 
- fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - let return_type = self.return_type(args.arg_types)?; - Ok(ReturnInfo::new_nullable(return_type)) + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let data_types = args.arg_fields.iter().map(|f| f.data_type()).cloned().collect::>(); + let return_type = self.return_type(&data_types)?; + Ok(Field::new(self.name(), return_type, true)) } #[deprecated( since = "45.0.0", - note = "Use `return_type_from_args` instead. if you use `is_nullable` that returns non-nullable with `return_type`, you would need to switch to `return_type_from_args`, you might have error" + note = "Use `return_field_from_args` instead. if you use `is_nullable` that returns non-nullable with `return_type`, you would need to switch to `return_field_from_args`, you might have error" )] fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool { true @@ -773,8 +728,8 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.return_type(arg_types) } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - self.inner.return_type_from_args(args) + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + self.inner.return_field_from_args(args) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { diff --git a/datafusion/ffi/src/udf/mod.rs b/datafusion/ffi/src/udf/mod.rs index 07faed7c0ec1..b233a52410ba 100644 --- a/datafusion/ffi/src/udf/mod.rs +++ b/datafusion/ffi/src/udf/mod.rs @@ -29,7 +29,7 @@ use arrow::{ use datafusion::{ error::DataFusionError, logical_expr::{ - type_coercion::functions::data_types_with_scalar_udf, ReturnInfo, ReturnTypeArgs, + type_coercion::functions::data_types_with_scalar_udf, }, }; use datafusion::{ @@ -38,12 +38,11 @@ use datafusion::{ ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, }, }; -use return_info::FFI_ReturnInfo; use return_type_args::{ - FFI_ReturnTypeArgs, ForeignReturnTypeArgs, 
ForeignReturnTypeArgsOwned, + FFI_ReturnFieldArgs, ForeignReturnFieldArgs, ForeignReturnFieldArgsOwned, }; use std::{ffi::c_void, sync::Arc}; - +use datafusion::logical_expr::ReturnFieldArgs; use crate::{ arrow_wrappers::{WrappedArray, WrappedSchema}, df_result, rresult, rresult_return, @@ -51,7 +50,6 @@ use crate::{ volatility::FFI_Volatility, }; -pub mod return_info; pub mod return_type_args; /// A stable struct for sharing a [`ScalarUDF`] across FFI boundaries. @@ -77,11 +75,11 @@ pub struct FFI_ScalarUDF { /// Determines the return info of the underlying [`ScalarUDF`]. Either this /// or return_type may be implemented on a UDF. - pub return_type_from_args: unsafe extern "C" fn( + pub return_field_from_args: unsafe extern "C" fn( udf: &Self, - args: FFI_ReturnTypeArgs, + args: FFI_ReturnFieldArgs, ) - -> RResult, + -> RResult, /// Execute the underlying [`ScalarUDF`] and return the result as a `FFI_ArrowArray` /// within an AbiStable wrapper. @@ -142,19 +140,20 @@ unsafe extern "C" fn return_type_fn_wrapper( rresult!(return_type) } -unsafe extern "C" fn return_type_from_args_fn_wrapper( +unsafe extern "C" fn return_field_from_args_fn_wrapper( udf: &FFI_ScalarUDF, - args: FFI_ReturnTypeArgs, -) -> RResult { + args: FFI_ReturnFieldArgs, +) -> RResult { let private_data = udf.private_data as *const ScalarUDFPrivateData; let udf = &(*private_data).udf; - let args: ForeignReturnTypeArgsOwned = rresult_return!((&args).try_into()); - let args_ref: ForeignReturnTypeArgs = (&args).into(); + let args: ForeignReturnFieldArgsOwned = rresult_return!((&args).try_into()); + let args_ref: ForeignReturnFieldArgs = (&args).into(); let return_type = udf - .return_type_from_args((&args_ref).into()) - .and_then(FFI_ReturnInfo::try_from); + .return_field_from_args((&args_ref).into()) + .and_then(|f| FFI_ArrowSchema::try_from(f).map_err(DataFusionError::from)) + .map(WrappedSchema); rresult!(return_type) } @@ -262,7 +261,7 @@ impl From> for FFI_ScalarUDF { short_circuits, 
invoke_with_args: invoke_with_args_fn_wrapper, return_type: return_type_fn_wrapper, - return_type_from_args: return_type_from_args_fn_wrapper, + return_field_from_args: return_field_from_args_fn_wrapper, coerce_types: coerce_types_fn_wrapper, clone: clone_fn_wrapper, release: release_fn_wrapper, @@ -335,14 +334,14 @@ impl ScalarUDFImpl for ForeignScalarUDF { result.and_then(|r| (&r.0).try_into().map_err(DataFusionError::from)) } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - let args: FFI_ReturnTypeArgs = args.try_into()?; + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let args: FFI_ReturnFieldArgs = args.try_into()?; - let result = unsafe { (self.udf.return_type_from_args)(&self.udf, args) }; + let result = unsafe { (self.udf.return_field_from_args)(&self.udf, args) }; let result = df_result!(result); - result.and_then(|r| r.try_into()) + result.and_then(|r| (&r.0).try_into().map_err(DataFusionError::from)) } fn invoke_with_args(&self, invoke_args: ScalarFunctionArgs) -> Result { diff --git a/datafusion/ffi/src/udf/return_info.rs b/datafusion/ffi/src/udf/return_info.rs deleted file mode 100644 index cf76ddd1db76..000000000000 --- a/datafusion/ffi/src/udf/return_info.rs +++ /dev/null @@ -1,53 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. 
See the License for the -// specific language governing permissions and limitations -// under the License. - -use abi_stable::StableAbi; -use arrow::{datatypes::DataType, ffi::FFI_ArrowSchema}; -use datafusion::{error::DataFusionError, logical_expr::ReturnInfo}; - -use crate::arrow_wrappers::WrappedSchema; - -/// A stable struct for sharing a [`ReturnInfo`] across FFI boundaries. -#[repr(C)] -#[derive(Debug, StableAbi)] -#[allow(non_camel_case_types)] -pub struct FFI_ReturnInfo { - return_type: WrappedSchema, - nullable: bool, -} - -impl TryFrom for FFI_ReturnInfo { - type Error = DataFusionError; - - fn try_from(value: ReturnInfo) -> Result { - let return_type = WrappedSchema(FFI_ArrowSchema::try_from(value.return_type())?); - Ok(Self { - return_type, - nullable: value.nullable(), - }) - } -} - -impl TryFrom for ReturnInfo { - type Error = DataFusionError; - - fn try_from(value: FFI_ReturnInfo) -> Result { - let return_type = DataType::try_from(&value.return_type.0)?; - - Ok(ReturnInfo::new(return_type, value.nullable)) - } -} diff --git a/datafusion/ffi/src/udf/return_type_args.rs b/datafusion/ffi/src/udf/return_type_args.rs index a0897630e2ea..40e577591c34 100644 --- a/datafusion/ffi/src/udf/return_type_args.rs +++ b/datafusion/ffi/src/udf/return_type_args.rs @@ -19,33 +19,30 @@ use abi_stable::{ std_types::{ROption, RVec}, StableAbi, }; -use arrow::datatypes::DataType; +use arrow::datatypes::Field; use datafusion::{ - common::exec_datafusion_err, error::DataFusionError, logical_expr::ReturnTypeArgs, + common::exec_datafusion_err, error::DataFusionError, logical_expr::ReturnFieldArgs, scalar::ScalarValue, }; -use crate::{ - arrow_wrappers::WrappedSchema, - util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, -}; +use crate::arrow_wrappers::WrappedSchema; +use crate::util::{rvec_wrapped_to_vec_field, vec_field_to_rvec_wrapped}; use prost::Message; -/// A stable struct for sharing a [`ReturnTypeArgs`] across FFI boundaries. 
+/// A stable struct for sharing a [`ReturnFieldArgs`] across FFI boundaries. #[repr(C)] #[derive(Debug, StableAbi)] #[allow(non_camel_case_types)] -pub struct FFI_ReturnTypeArgs { - arg_types: RVec, +pub struct FFI_ReturnFieldArgs { + arg_fields: RVec, scalar_arguments: RVec>>, - nullables: RVec, } -impl TryFrom> for FFI_ReturnTypeArgs { +impl TryFrom> for FFI_ReturnFieldArgs { type Error = DataFusionError; - fn try_from(value: ReturnTypeArgs) -> Result { - let arg_types = vec_datatype_to_rvec_wrapped(value.arg_types)?; + fn try_from(value: ReturnFieldArgs) -> Result { + let arg_fields = vec_field_to_rvec_wrapped(value.arg_fields)?; let scalar_arguments: Result, Self::Error> = value .scalar_arguments .iter() @@ -62,35 +59,31 @@ impl TryFrom> for FFI_ReturnTypeArgs { .collect(); let scalar_arguments = scalar_arguments?.into_iter().map(ROption::from).collect(); - let nullables = value.nullables.into(); Ok(Self { - arg_types, + arg_fields, scalar_arguments, - nullables, }) } } // TODO(tsaucer) It would be good to find a better way around this, but it // appears a restriction based on the need to have a borrowed ScalarValue -// in the arguments when converted to ReturnTypeArgs -pub struct ForeignReturnTypeArgsOwned { - arg_types: Vec, +// in the arguments when converted to ReturnFieldArgs +pub struct ForeignReturnFieldArgsOwned { + arg_fields: Vec, scalar_arguments: Vec>, - nullables: Vec, } -pub struct ForeignReturnTypeArgs<'a> { - arg_types: &'a [DataType], +pub struct ForeignReturnFieldArgs<'a> { + arg_fields: &'a [Field], scalar_arguments: Vec>, - nullables: &'a [bool], } -impl TryFrom<&FFI_ReturnTypeArgs> for ForeignReturnTypeArgsOwned { +impl TryFrom<&FFI_ReturnFieldArgs> for ForeignReturnFieldArgsOwned { type Error = DataFusionError; - fn try_from(value: &FFI_ReturnTypeArgs) -> Result { - let arg_types = rvec_wrapped_to_vec_datatype(&value.arg_types)?; + fn try_from(value: &FFI_ReturnFieldArgs) -> Result { + let arg_fields = 
rvec_wrapped_to_vec_field(&value.arg_fields)?; let scalar_arguments: Result, Self::Error> = value .scalar_arguments .iter() @@ -107,36 +100,31 @@ impl TryFrom<&FFI_ReturnTypeArgs> for ForeignReturnTypeArgsOwned { .collect(); let scalar_arguments = scalar_arguments?.into_iter().collect(); - let nullables = value.nullables.iter().cloned().collect(); - Ok(Self { - arg_types, + arg_fields, scalar_arguments, - nullables, }) } } -impl<'a> From<&'a ForeignReturnTypeArgsOwned> for ForeignReturnTypeArgs<'a> { - fn from(value: &'a ForeignReturnTypeArgsOwned) -> Self { +impl<'a> From<&'a ForeignReturnFieldArgsOwned> for ForeignReturnFieldArgs<'a> { + fn from(value: &'a ForeignReturnFieldArgsOwned) -> Self { Self { - arg_types: &value.arg_types, + arg_fields: &value.arg_fields, scalar_arguments: value .scalar_arguments .iter() .map(|opt| opt.as_ref()) .collect(), - nullables: &value.nullables, } } } -impl<'a> From<&'a ForeignReturnTypeArgs<'a>> for ReturnTypeArgs<'a> { - fn from(value: &'a ForeignReturnTypeArgs) -> Self { - ReturnTypeArgs { - arg_types: value.arg_types, +impl<'a> From<&'a ForeignReturnFieldArgs<'a>> for ReturnFieldArgs<'a> { + fn from(value: &'a ForeignReturnFieldArgs) -> Self { + ReturnFieldArgs { + arg_fields: value.arg_fields, scalar_arguments: &value.scalar_arguments, - nullables: value.nullables, } } } diff --git a/datafusion/ffi/src/util.rs b/datafusion/ffi/src/util.rs index 9d5f2aefe324..6992b2fd545f 100644 --- a/datafusion/ffi/src/util.rs +++ b/datafusion/ffi/src/util.rs @@ -17,7 +17,7 @@ use abi_stable::std_types::RVec; use arrow::{datatypes::DataType, ffi::FFI_ArrowSchema}; - +use arrow::datatypes::Field; use crate::arrow_wrappers::WrappedSchema; /// This macro is a helpful conversion utility to conver from an abi_stable::RResult to a @@ -64,6 +64,31 @@ macro_rules! 
rresult_return { }; } +/// This is a utility function to convert a slice of [`Field`] to its equivalent +/// FFI friendly counterpart, [`WrappedSchema`] +pub fn vec_field_to_rvec_wrapped( + fields: &[Field], +) -> Result, arrow::error::ArrowError> { + Ok(fields + .iter() + .map(FFI_ArrowSchema::try_from) + .collect::, arrow::error::ArrowError>>()? + .into_iter() + .map(WrappedSchema) + .collect()) +} + +/// This is a utility function to convert an FFI friendly vector of [`WrappedSchema`] +/// to their equivalent [`Field`]. +pub fn rvec_wrapped_to_vec_field( + fields: &RVec, +) -> Result, arrow::error::ArrowError> { + fields + .iter() + .map(|d| Field::try_from(&d.0)) + .collect() +} + /// This is a utility function to convert a slice of [`DataType`] to its equivalent /// FFI friendly counterpart, [`WrappedSchema`] pub fn vec_datatype_to_rvec_wrapped( diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 2686dbf8be3c..71f45dece741 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -17,7 +17,7 @@ //! 
[`ArrowCastFunc`]: Implementation of the `arrow_cast` -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::error::ArrowError; use datafusion_common::{ arrow_datafusion_err, exec_err, internal_err, Result, ScalarValue, @@ -28,10 +28,7 @@ use datafusion_common::{ use std::any::Any; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, - ScalarUDFImpl, Signature, Volatility, -}; +use datafusion_expr::{ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; /// Implements casting to arbitrary arrow types (rather than SQL types) @@ -113,11 +110,11 @@ impl ScalarUDFImpl for ArrowCastFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type_from_args should be called instead") + internal_err!("return_field_from_args should be called instead") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - let nullable = args.nullables.iter().any(|&nullable| nullable); + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); let [_, type_arg] = take_function_args(self.name(), args.scalar_arguments)?; @@ -131,7 +128,7 @@ impl ScalarUDFImpl for ArrowCastFunc { ) }, |casted_type| match casted_type.parse::() { - Ok(data_type) => Ok(ReturnInfo::new(data_type, nullable)), + Ok(data_type) => Ok(Field::new(self.name(), data_type, nullable)), Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")), Err(e) => Err(arrow_datafusion_err!(e)), }, diff --git a/datafusion/functions/src/core/coalesce.rs b/datafusion/functions/src/core/coalesce.rs index ba20c23828eb..12b87b67f1b1 100644 --- a/datafusion/functions/src/core/coalesce.rs +++ b/datafusion/functions/src/core/coalesce.rs @@ -18,12 +18,10 @@ use 
arrow::array::{new_null_array, BooleanArray}; use arrow::compute::kernels::zip::zip; use arrow::compute::{and, is_not_null, is_null}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use datafusion_common::{exec_err, internal_err, Result}; use datafusion_expr::binary::try_type_union_resolution; -use datafusion_expr::{ - ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, -}; +use datafusion_expr::{ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; use itertools::Itertools; @@ -79,19 +77,20 @@ impl ScalarUDFImpl for CoalesceFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type_from_args should be called instead") + internal_err!("return_field_from_args should be called instead") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { // If any the arguments in coalesce is non-null, the result is non-null - let nullable = args.nullables.iter().all(|&nullable| nullable); + let nullable = args.arg_fields.iter().all(|f| f.is_nullable()); let return_type = args - .arg_types + .arg_fields .iter() + .map(|f| f.data_type()) .find_or_first(|d| !d.is_null()) .unwrap() .clone(); - Ok(ReturnInfo::new(return_type, nullable)) + Ok(Field::new(self.name(), return_type, nullable)) } /// coalesce evaluates to the first value which is not NULL diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 3ac26b98359b..b016249ff723 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -20,16 +20,14 @@ use arrow::array::{ Scalar, }; use arrow::compute::SortOptions; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow_buffer::NullBuffer; use 
datafusion_common::cast::{as_map_array, as_struct_array}; use datafusion_common::{ exec_err, internal_err, plan_datafusion_err, utils::take_function_args, Result, ScalarValue, }; -use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, -}; +use datafusion_expr::{ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; use std::any::Any; @@ -130,14 +128,14 @@ impl ScalarUDFImpl for GetFieldFunc { } fn return_type(&self, _: &[DataType]) -> Result { - internal_err!("return_type_from_args should be called instead") + internal_err!("return_field_from_args should be called instead") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { // Length check handled in the signature debug_assert_eq!(args.scalar_arguments.len(), 2); - match (&args.arg_types[0], args.scalar_arguments[1].as_ref()) { + match (&args.arg_fields[0].data_type(), args.scalar_arguments[1].as_ref()) { (DataType::Map(fields, _), _) => { match fields.data_type() { DataType::Struct(fields) if fields.len() == 2 => { @@ -146,7 +144,7 @@ impl ScalarUDFImpl for GetFieldFunc { // instead, we assume that the second column is the "value" column both here and in // execution. 
let value_field = fields.get(1).expect("fields should have exactly two members"); - Ok(ReturnInfo::new_nullable(value_field.data_type().clone())) + Ok(Field::new(self.name(), value_field.data_type().clone(), true)) }, _ => exec_err!("Map fields must contain a Struct with exactly 2 fields"), } @@ -158,10 +156,10 @@ impl ScalarUDFImpl for GetFieldFunc { |field_name| { fields.iter().find(|f| f.name() == field_name) .ok_or(plan_datafusion_err!("Field {field_name} not found in struct")) - .map(|f| ReturnInfo::new_nullable(f.data_type().to_owned())) + .map(|f| Field::new(self.name(), f.data_type().to_owned(), true)) }) }, - (DataType::Null, _) => Ok(ReturnInfo::new_nullable(DataType::Null)), + (DataType::Null, _) => Ok(Field::new(self.name(), DataType::Null, true)), (other, _) => exec_err!("The expression to get an indexed field is only valid for `Struct`, `Map` or `Null` types, got {other}"), } } diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index bba884d96483..2aea9108a1eb 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -18,9 +18,7 @@ use arrow::array::StructArray; use arrow::datatypes::{DataType, Field, Fields}; use datafusion_common::{exec_err, internal_err, Result}; -use datafusion_expr::{ - ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, -}; +use datafusion_expr::{ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; use std::any::Any; @@ -91,10 +89,10 @@ impl ScalarUDFImpl for NamedStructFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("named_struct: return_type called instead of return_type_from_args") + internal_err!("named_struct: return_type called instead of return_field_from_args") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn 
return_field_from_args(&self, args: ReturnFieldArgs) -> Result { // do not accept 0 arguments. if args.scalar_arguments.is_empty() { return exec_err!( @@ -126,7 +124,7 @@ impl ScalarUDFImpl for NamedStructFunc { ) ) .collect::>>()?; - let types = args.arg_types.iter().skip(1).step_by(2).collect::>(); + let types = args.arg_fields.iter().skip(1).step_by(2).map(|f| f.data_type()).collect::>(); let return_fields = names .into_iter() @@ -134,9 +132,9 @@ impl ScalarUDFImpl for NamedStructFunc { .map(|(name, data_type)| Ok(Field::new(name, data_type.to_owned(), true))) .collect::>>()?; - Ok(ReturnInfo::new_nullable(DataType::Struct(Fields::from( + Ok(Field::new(self.name(), DataType::Struct(Fields::from( return_fields, - )))) + )), true)) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs index d993f4536f4f..968d443a0a42 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs @@ -16,14 +16,14 @@ // under the License. 
use arrow::array::Array; -use arrow::datatypes::{DataType, FieldRef, UnionFields}; +use arrow::datatypes::{DataType, Field, FieldRef, UnionFields}; use datafusion_common::cast::as_union_array; use datafusion_common::utils::take_function_args; use datafusion_common::{ exec_datafusion_err, exec_err, internal_err, Result, ScalarValue, }; use datafusion_doc::Documentation; -use datafusion_expr::{ColumnarValue, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs}; +use datafusion_expr::{ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -82,35 +82,35 @@ impl ScalarUDFImpl for UnionExtractFun { } fn return_type(&self, _: &[DataType]) -> Result { - // should be using return_type_from_args and not calling the default implementation + // should be using return_field_from_args and not calling the default implementation internal_err!("union_extract should return type from args") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - if args.arg_types.len() != 2 { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + if args.arg_fields.len() != 2 { return exec_err!( "union_extract expects 2 arguments, got {} instead", - args.arg_types.len() + args.arg_fields.len() ); } - let DataType::Union(fields, _) = &args.arg_types[0] else { + let DataType::Union(fields, _) = &args.arg_fields[0].data_type() else { return exec_err!( "union_extract first argument must be a union, got {} instead", - args.arg_types[0] + args.arg_fields[0].data_type() ); }; let Some(ScalarValue::Utf8(Some(field_name))) = &args.scalar_arguments[1] else { return exec_err!( "union_extract second argument must be a non-null string literal, got {} instead", - args.arg_types[1] + args.arg_fields[1].data_type() ); }; let field = find_field(fields, field_name)?.1; - Ok(ReturnInfo::new_nullable(field.data_type().clone())) + Ok(Field::new(self.name(), field.data_type().clone(), true)) } fn 
invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index bfd06b39d206..91f983e0acc3 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -26,7 +26,7 @@ use arrow::datatypes::DataType::{ Date32, Date64, Duration, Interval, Time32, Time64, Timestamp, }; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; -use arrow::datatypes::{DataType, TimeUnit}; +use arrow::datatypes::{DataType, Field, TimeUnit}; use datafusion_common::types::{logical_date, NativeType}; use datafusion_common::{ @@ -42,7 +42,7 @@ use datafusion_common::{ Result, ScalarValue, }; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, Signature, + ColumnarValue, Documentation, ReturnFieldArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; @@ -142,10 +142,10 @@ impl ScalarUDFImpl for DatePartFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type_from_args should be called instead") + internal_err!("return_field_from_args should be called instead") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { let [field, _] = take_function_args(self.name(), args.scalar_arguments)?; field @@ -155,9 +155,9 @@ impl ScalarUDFImpl for DatePartFunc { .filter(|s| !s.is_empty()) .map(|part| { if is_epoch(part) { - ReturnInfo::new_nullable(DataType::Float64) + Field::new(self.name(), DataType::Float64, true) } else { - ReturnInfo::new_nullable(DataType::Int32) + Field::new(self.name(), DataType::Int32, true) } }) }) diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index 274ac437dd67..6ac5cee42450 
100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -18,15 +18,12 @@ use std::any::Any; use std::sync::Arc; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::datatypes::DataType::{Int64, Timestamp, Utf8}; use arrow::datatypes::TimeUnit::Second; use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ - ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, Signature, - Volatility, -}; +use datafusion_expr::{ColumnarValue, Documentation, ReturnFieldArgs, ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; #[user_doc( @@ -82,12 +79,12 @@ impl ScalarUDFImpl for FromUnixtimeFunc { &self.signature } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { // Length check handled in the signature debug_assert!(matches!(args.scalar_arguments.len(), 1 | 2)); if args.scalar_arguments.len() == 1 { - Ok(ReturnInfo::new_nullable(Timestamp(Second, None))) + Ok(Field::new(self.name(), Timestamp(Second, None), true)) } else { args.scalar_arguments[1] .and_then(|sv| { @@ -95,10 +92,10 @@ impl ScalarUDFImpl for FromUnixtimeFunc { .flatten() .filter(|s| !s.is_empty()) .map(|tz| { - ReturnInfo::new_nullable(Timestamp( + Field::new(self.name(), Timestamp( Second, Some(Arc::from(tz.to_string())), - )) + ), true) }) }) .map_or_else( @@ -114,7 +111,7 @@ impl ScalarUDFImpl for FromUnixtimeFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("call return_type_from_args instead") + internal_err!("call return_field_from_args instead") } fn invoke_with_args( diff --git a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs index b26dc52cee4d..ec71df7c8d0f 100644 --- a/datafusion/functions/src/datetime/now.rs +++ 
b/datafusion/functions/src/datetime/now.rs @@ -15,17 +15,14 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::datatypes::DataType::Timestamp; use arrow::datatypes::TimeUnit::Nanosecond; use std::any::Any; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, - Signature, Volatility, -}; +use datafusion_expr::{ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; #[user_doc( @@ -77,15 +74,15 @@ impl ScalarUDFImpl for NowFunc { &self.signature } - fn return_type_from_args(&self, _args: ReturnTypeArgs) -> Result { - Ok(ReturnInfo::new_non_nullable(Timestamp( + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { + Ok(Field::new(self.name(), Timestamp( Nanosecond, Some("+00:00".into()), - ))) + ), false)) } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type_from_args should be called instead") + internal_err!("return_field_from_args should be called instead") } fn invoke_with_args( diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index b3bc73a29585..35163b81bc93 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -22,7 +22,7 @@ use crate::utils::{make_scalar_function, utf8_to_int_type}; use arrow::array::{ ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray, StringArrayType, }; -use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; +use arrow::datatypes::{ArrowNativeType, DataType, Field, Int32Type, Int64Type}; use datafusion_common::types::logical_string; use datafusion_common::{exec_err, internal_err, Result}; use 
datafusion_expr::{ @@ -88,15 +88,15 @@ impl ScalarUDFImpl for StrposFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type_from_args should be used instead") + internal_err!("return_field_from_args should be used instead") } - fn return_type_from_args( + fn return_field_from_args( &self, - args: datafusion_expr::ReturnTypeArgs, - ) -> Result { - utf8_to_int_type(&args.arg_types[0], "strpos/instr/position").map(|data_type| { - datafusion_expr::ReturnInfo::new(data_type, args.nullables.iter().any(|x| *x)) + args: datafusion_expr::ReturnFieldArgs, + ) -> Result { + utf8_to_int_type(args.arg_fields[0].data_type(), "strpos/instr/position").map(|data_type| { + Field::new(self.name(), data_type, args.arg_fields.iter().any(|x| x.is_nullable())) }) } @@ -228,7 +228,7 @@ mod tests { use arrow::array::{Array, Int32Array, Int64Array}; use arrow::datatypes::DataType::{Int32, Int64}; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -321,13 +321,12 @@ mod tests { fn nullable_return_type() { fn get_nullable(string_array_nullable: bool, substring_nullable: bool) -> bool { let strpos = StrposFunc::new(); - let args = datafusion_expr::ReturnTypeArgs { - arg_types: &[DataType::Utf8, DataType::Utf8], - nullables: &[string_array_nullable, substring_nullable], + let args = datafusion_expr::ReturnFieldArgs { + arg_fields: &[Field::new("f1", DataType::Utf8, string_array_nullable), Field::new("f2", DataType::Utf8, substring_nullable)], scalar_arguments: &[None::<&ScalarValue>, None::<&ScalarValue>], }; - let (_, nullable) = strpos.return_type_from_args(args).unwrap().into_parts(); + let (_, nullable) = strpos.return_field_from_args(args).unwrap().into_parts(); nullable } diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index ebc3be99dbab..e505150d82a6 100644 --- 
a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -133,7 +133,7 @@ pub mod test { let expected: Result> = $EXPECTED; let func = $FUNC; - let type_array = $ARGS.iter().map(|arg| arg.data_type()).collect::>(); + let type_array = $ARGS.iter().map(|arg| arg.field()).collect::>(); let cardinality = $ARGS .iter() .fold(Option::::None, |acc, arg| match arg { @@ -153,10 +153,9 @@ pub mod test { ColumnarValue::Array(a) => a.null_count() > 0, }).collect::>(); - let return_info = func.return_type_from_args(datafusion_expr::ReturnTypeArgs { - arg_types: &type_array, + let return_info = func.return_field_from_args(datafusion_expr::ReturnFieldArgs { + arg_fields: &field_array, scalar_arguments: &scalar_arguments_refs, - nullables: &nullables }); match expected { diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 12758a473fe9..a86442e55761 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -45,7 +45,7 @@ use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; use datafusion_expr::{ - expr_vec_fmt, ColumnarValue, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, + expr_vec_fmt, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, }; /// Physical expression of a scalar function @@ -117,18 +117,24 @@ impl ScalarFunctionExpr { schema: &Schema, ) -> Result { let name = fun.name().to_string(); - let arg_types = args + let arg_fields = args .iter() - .map(|e| e.data_type(schema)) + .enumerate() + .map(|(idx, e)| { + e.output_field(schema).and_then(|maybe_field| { + Ok(maybe_field.unwrap_or({ + Field::new(format!("field_{idx}"), e.data_type(schema)?, true) + })) + }) + }) .collect::>>()?; // verify that input data types is consistent with function's `TypeSignature` - 
data_types_with_scalar_udf(&arg_types, &fun)?; - - let nullables = args + let arg_types = arg_fields .iter() - .map(|e| e.nullable(schema)) - .collect::>>()?; + .map(|f| f.data_type().clone()) + .collect::>(); + data_types_with_scalar_udf(&arg_types, &fun)?; let arguments = args .iter() @@ -138,18 +144,17 @@ impl ScalarFunctionExpr { .map(|literal| literal.value()) }) .collect::>(); - let ret_args = ReturnTypeArgs { - arg_types: &arg_types, + let ret_args = ReturnFieldArgs { + arg_fields: &arg_fields, scalar_arguments: &arguments, - nullables: &nullables, }; - let (return_type, nullable) = fun.return_type_from_args(ret_args)?.into_parts(); + let return_field = fun.return_field_from_args(ret_args)?; Ok(Self { fun, name, args, - return_type, - nullable, + return_type: return_field.data_type().clone(), + nullable: return_field.is_nullable(), metadata: HashMap::new(), }) } From 03ddfe78638b7f32046c0aa60ab23117677853df Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 15 Apr 2025 15:45:06 -0400 Subject: [PATCH 09/25] Updates to unit tests for switching to field instead of data_type --- datafusion/expr/src/expr_schema.rs | 7 +++--- datafusion/expr/src/lib.rs | 3 +-- datafusion/expr/src/udf.rs | 7 +++++- datafusion/ffi/src/udf/mod.rs | 18 +++++++------- datafusion/ffi/src/util.rs | 9 +++---- datafusion/functions/src/core/arrow_cast.rs | 5 +++- datafusion/functions/src/core/coalesce.rs | 4 +++- datafusion/functions/src/core/getfield.rs | 4 +++- datafusion/functions/src/core/named_struct.rs | 24 ++++++++++++++----- .../functions/src/datetime/from_unixtime.rs | 15 +++++++----- datafusion/functions/src/datetime/now.rs | 16 ++++++++----- datafusion/functions/src/unicode/strpos.rs | 21 ++++++++++------ datafusion/functions/src/utils.rs | 14 +++++++---- 13 files changed, 92 insertions(+), 55 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index b93389679e60..bd8dd7bd7b2c 100644 --- a/datafusion/expr/src/expr_schema.rs +++ 
b/datafusion/expr/src/expr_schema.rs @@ -441,11 +441,10 @@ impl ExprSchemable for Expr { ) ) })?; - let new_fields = fields.into_iter() + let new_fields = fields + .into_iter() .zip(new_data_types) - .map(|(f, d)| { - f.as_ref().clone().with_data_type(d) - }) + .map(|(f, d)| f.as_ref().clone().with_data_type(d)) .collect::>(); let arguments = args diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 258c28dc89c5..48931d6525af 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -104,8 +104,7 @@ pub use udaf::{ SetMonotonicity, StatisticsArgs, }; pub use udf::{ - scalar_doc_sections, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, - ScalarUDFImpl, + scalar_doc_sections, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, }; pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index b06a1b273216..8e225611d476 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -475,7 +475,12 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// logical input even if the input is simplified (e.g. it must return the same /// value for `('foo' | 'bar')` as it does for ('foobar'). fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { - let data_types = args.arg_fields.iter().map(|f| f.data_type()).cloned().collect::>(); + let data_types = args + .arg_fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); let return_type = self.return_type(&data_types)?; Ok(Field::new(self.name(), return_type, true)) } diff --git a/datafusion/ffi/src/udf/mod.rs b/datafusion/ffi/src/udf/mod.rs index b233a52410ba..05e4ef8c7e94 100644 --- a/datafusion/ffi/src/udf/mod.rs +++ b/datafusion/ffi/src/udf/mod.rs @@ -15,6 +15,12 @@ // specific language governing permissions and limitations // under the License. 
+use crate::{ + arrow_wrappers::{WrappedArray, WrappedSchema}, + df_result, rresult, rresult_return, + util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, + volatility::FFI_Volatility, +}; use abi_stable::std_types::ROption; use abi_stable::{ std_types::{RResult, RString, RVec}, @@ -26,11 +32,10 @@ use arrow::{ error::ArrowError, ffi::{from_ffi, to_ffi, FFI_ArrowSchema}, }; +use datafusion::logical_expr::ReturnFieldArgs; use datafusion::{ error::DataFusionError, - logical_expr::{ - type_coercion::functions::data_types_with_scalar_udf, - }, + logical_expr::type_coercion::functions::data_types_with_scalar_udf, }; use datafusion::{ error::Result, @@ -42,13 +47,6 @@ use return_type_args::{ FFI_ReturnFieldArgs, ForeignReturnFieldArgs, ForeignReturnFieldArgsOwned, }; use std::{ffi::c_void, sync::Arc}; -use datafusion::logical_expr::ReturnFieldArgs; -use crate::{ - arrow_wrappers::{WrappedArray, WrappedSchema}, - df_result, rresult, rresult_return, - util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, - volatility::FFI_Volatility, -}; pub mod return_type_args; diff --git a/datafusion/ffi/src/util.rs b/datafusion/ffi/src/util.rs index 6992b2fd545f..6b9f373939ea 100644 --- a/datafusion/ffi/src/util.rs +++ b/datafusion/ffi/src/util.rs @@ -15,10 +15,10 @@ // specific language governing permissions and limitations // under the License. +use crate::arrow_wrappers::WrappedSchema; use abi_stable::std_types::RVec; -use arrow::{datatypes::DataType, ffi::FFI_ArrowSchema}; use arrow::datatypes::Field; -use crate::arrow_wrappers::WrappedSchema; +use arrow::{datatypes::DataType, ffi::FFI_ArrowSchema}; /// This macro is a helpful conversion utility to conver from an abi_stable::RResult to a /// DataFusion result. 
@@ -83,10 +83,7 @@ pub fn vec_field_to_rvec_wrapped( pub fn rvec_wrapped_to_vec_field( fields: &RVec, ) -> Result, arrow::error::ArrowError> { - fields - .iter() - .map(|d| Field::try_from(&d.0)) - .collect() + fields.iter().map(|d| Field::try_from(&d.0)).collect() } /// This is a utility function to convert a slice of [`DataType`] to its equivalent diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 71f45dece741..0e18ec180cef 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -28,7 +28,10 @@ use datafusion_common::{ use std::any::Any; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, + ScalarUDFImpl, Signature, Volatility, +}; use datafusion_macros::user_doc; /// Implements casting to arbitrary arrow types (rather than SQL types) diff --git a/datafusion/functions/src/core/coalesce.rs b/datafusion/functions/src/core/coalesce.rs index 12b87b67f1b1..b2ca3692c1d3 100644 --- a/datafusion/functions/src/core/coalesce.rs +++ b/datafusion/functions/src/core/coalesce.rs @@ -21,7 +21,9 @@ use arrow::compute::{and, is_not_null, is_null}; use arrow::datatypes::{DataType, Field}; use datafusion_common::{exec_err, internal_err, Result}; use datafusion_expr::binary::try_type_union_resolution; -use datafusion_expr::{ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs}; +use datafusion_expr::{ + ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, +}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; use itertools::Itertools; diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 
b016249ff723..7989317efb2b 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -27,7 +27,9 @@ use datafusion_common::{ exec_err, internal_err, plan_datafusion_err, utils::take_function_args, Result, ScalarValue, }; -use datafusion_expr::{ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs}; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, +}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; use std::any::Any; diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index 2aea9108a1eb..97af6df7bc6c 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -18,7 +18,9 @@ use arrow::array::StructArray; use arrow::datatypes::{DataType, Field, Fields}; use datafusion_common::{exec_err, internal_err, Result}; -use datafusion_expr::{ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs}; +use datafusion_expr::{ + ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, +}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; use std::any::Any; @@ -89,7 +91,9 @@ impl ScalarUDFImpl for NamedStructFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("named_struct: return_type called instead of return_field_from_args") + internal_err!( + "named_struct: return_type called instead of return_field_from_args" + ) } fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { @@ -124,7 +128,13 @@ impl ScalarUDFImpl for NamedStructFunc { ) ) .collect::>>()?; - let types = args.arg_fields.iter().skip(1).step_by(2).map(|f| f.data_type()).collect::>(); + let types = args + .arg_fields + .iter() + .skip(1) + .step_by(2) + .map(|f| f.data_type()) + .collect::>(); let return_fields = names .into_iter() @@ -132,9 
+142,11 @@ impl ScalarUDFImpl for NamedStructFunc { .map(|(name, data_type)| Ok(Field::new(name, data_type.to_owned(), true))) .collect::>>()?; - Ok(Field::new(self.name(), DataType::Struct(Fields::from( - return_fields, - )), true)) + Ok(Field::new( + self.name(), + DataType::Struct(Fields::from(return_fields)), + true, + )) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index 6ac5cee42450..3db0500bfeb2 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -18,12 +18,14 @@ use std::any::Any; use std::sync::Arc; -use arrow::datatypes::{DataType, Field}; use arrow::datatypes::DataType::{Int64, Timestamp, Utf8}; use arrow::datatypes::TimeUnit::Second; +use arrow::datatypes::{DataType, Field}; use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; use datafusion_expr::TypeSignature::Exact; -use datafusion_expr::{ColumnarValue, Documentation, ReturnFieldArgs, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, ReturnFieldArgs, ScalarUDFImpl, Signature, Volatility, +}; use datafusion_macros::user_doc; #[user_doc( @@ -92,10 +94,11 @@ impl ScalarUDFImpl for FromUnixtimeFunc { .flatten() .filter(|s| !s.is_empty()) .map(|tz| { - Field::new(self.name(), Timestamp( - Second, - Some(Arc::from(tz.to_string())), - ), true) + Field::new( + self.name(), + Timestamp(Second, Some(Arc::from(tz.to_string()))), + true, + ) }) }) .map_or_else( diff --git a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs index ec71df7c8d0f..867442df45ad 100644 --- a/datafusion/functions/src/datetime/now.rs +++ b/datafusion/functions/src/datetime/now.rs @@ -15,14 +15,17 @@ // specific language governing permissions and limitations // under the License. 
-use arrow::datatypes::{DataType, Field}; use arrow::datatypes::DataType::Timestamp; use arrow::datatypes::TimeUnit::Nanosecond; +use arrow::datatypes::{DataType, Field}; use std::any::Any; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; -use datafusion_expr::{ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarUDFImpl, Signature, Volatility}; +use datafusion_expr::{ + ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarUDFImpl, Signature, + Volatility, +}; use datafusion_macros::user_doc; #[user_doc( @@ -75,10 +78,11 @@ impl ScalarUDFImpl for NowFunc { } fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { - Ok(Field::new(self.name(), Timestamp( - Nanosecond, - Some("+00:00".into()), - ), false)) + Ok(Field::new( + self.name(), + Timestamp(Nanosecond, Some("+00:00".into())), + false, + )) } fn return_type(&self, _arg_types: &[DataType]) -> Result { diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index 35163b81bc93..b33a1ca7713a 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -95,9 +95,15 @@ impl ScalarUDFImpl for StrposFunc { &self, args: datafusion_expr::ReturnFieldArgs, ) -> Result { - utf8_to_int_type(args.arg_fields[0].data_type(), "strpos/instr/position").map(|data_type| { - Field::new(self.name(), data_type, args.arg_fields.iter().any(|x| x.is_nullable())) - }) + utf8_to_int_type(args.arg_fields[0].data_type(), "strpos/instr/position").map( + |data_type| { + Field::new( + self.name(), + data_type, + args.arg_fields.iter().any(|x| x.is_nullable()), + ) + }, + ) } fn invoke_with_args( @@ -322,13 +328,14 @@ mod tests { fn get_nullable(string_array_nullable: bool, substring_nullable: bool) -> bool { let strpos = StrposFunc::new(); let args = datafusion_expr::ReturnFieldArgs { - arg_fields: &[Field::new("f1", DataType::Utf8, 
string_array_nullable), Field::new("f2", DataType::Utf8, substring_nullable)], + arg_fields: &[ + Field::new("f1", DataType::Utf8, string_array_nullable), + Field::new("f2", DataType::Utf8, substring_nullable), + ], scalar_arguments: &[None::<&ScalarValue>, None::<&ScalarValue>], }; - let (_, nullable) = strpos.return_field_from_args(args).unwrap().into_parts(); - - nullable + strpos.return_field_from_args(args).unwrap().is_nullable() } assert!(!get_nullable(false, false)); diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index e505150d82a6..a967fdf7c2ce 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -133,7 +133,7 @@ pub mod test { let expected: Result> = $EXPECTED; let func = $FUNC; - let type_array = $ARGS.iter().map(|arg| arg.field()).collect::>(); + let data_array = $ARGS.iter().map(|arg| arg.data_type()).collect::>(); let cardinality = $ARGS .iter() .fold(Option::::None, |acc, arg| match arg { @@ -153,6 +153,10 @@ pub mod test { ColumnarValue::Array(a) => a.null_count() > 0, }).collect::>(); + let field_array = data_array.into_iter().zip(nullables).enumerate() + .map(|(idx, (data_type, nullable))| arrow::datatypes::Field::new(format!("field_{idx}"), data_type, nullable)) + .collect::>(); + let return_info = func.return_field_from_args(datafusion_expr::ReturnFieldArgs { arg_fields: &field_array, scalar_arguments: &scalar_arguments_refs, @@ -161,8 +165,9 @@ pub mod test { match expected { Ok(expected) => { assert_eq!(return_info.is_ok(), true); - let (return_type, _nullable) = return_info.unwrap().into_parts(); - assert_eq!(return_type, $EXPECTED_DATA_TYPE); + let return_info = return_info.unwrap(); + let return_type = return_info.data_type(); + assert_eq!(return_type, &$EXPECTED_DATA_TYPE); let arg_fields = vec![None; $ARGS.len()]; let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, arg_fields, number_rows: cardinality, return_type: &return_type}); @@ 
-186,7 +191,8 @@ pub mod test { } } else { - let (return_type, _nullable) = return_info.unwrap().into_parts(); + let return_info = return_info.unwrap(); + let return_type = return_info.data_type(); let arg_fields = vec![None; $ARGS.len()]; // invoke is expected error - cannot use .expect_err() due to Debug not being implemented From 58933dfa48ff21fa90d4d3046f92810e741fbd11 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Tue, 15 Apr 2025 17:08:04 -0400 Subject: [PATCH 10/25] Resolve unit test issues --- datafusion/core/tests/tpc-ds/49.sql | 2 +- datafusion/physical-expr/src/expressions/cast.rs | 5 ++++- datafusion/physical-expr/src/scalar_function.rs | 6 +++++- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/datafusion/core/tests/tpc-ds/49.sql b/datafusion/core/tests/tpc-ds/49.sql index 090e9746c0d8..219877719f22 100644 --- a/datafusion/core/tests/tpc-ds/49.sql +++ b/datafusion/core/tests/tpc-ds/49.sql @@ -110,7 +110,7 @@ select channel, item, return_ratio, return_rank, currency_rank from where sr.sr_return_amt > 10000 and sts.ss_net_profit > 1 - and sts.ss_net_paid > 0 + and sts.ss_net_paid > 0 and sts.ss_quantity > 0 and ss_sold_date_sk = d_date_sk and d_year = 2000 diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 57f2d61aa968..8183a36ca52a 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -145,7 +145,10 @@ impl PhysicalExpr for CastExpr { } fn output_field(&self, input_schema: &Schema) -> Result> { - self.expr.output_field(input_schema) + Ok(self + .expr + .output_field(input_schema)? 
+ .map(|f| f.with_data_type(self.cast_type.clone()))) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index a86442e55761..bbaa9568120c 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -123,7 +123,11 @@ impl ScalarFunctionExpr { .map(|(idx, e)| { e.output_field(schema).and_then(|maybe_field| { Ok(maybe_field.unwrap_or({ - Field::new(format!("field_{idx}"), e.data_type(schema)?, true) + Field::new( + format!("field_{idx}"), + e.data_type(schema)?, + e.nullable(schema)?, + ) })) }) }) From 6924e4e111b3e744118bd84ced76cfd699a2ba63 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 16 Apr 2025 07:57:11 -0400 Subject: [PATCH 11/25] Update after rebase on main --- datafusion/physical-expr/src/expressions/dynamic_filters.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/expressions/dynamic_filters.rs b/datafusion/physical-expr/src/expressions/dynamic_filters.rs index c0a3285f0e78..2cf17eb72a86 100644 --- a/datafusion/physical-expr/src/expressions/dynamic_filters.rs +++ b/datafusion/physical-expr/src/expressions/dynamic_filters.rs @@ -23,7 +23,7 @@ use std::{ }; use crate::PhysicalExpr; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode}, Result, @@ -291,6 +291,10 @@ impl PhysicalExpr for DynamicFilterPhysicalExpr { // Return the current expression as a snapshot. 
Ok(Some(self.current()?)) } + + fn output_field(&self, _input_schema: &Schema) -> Result> { + Ok(None) + } } #[cfg(test)] From 68f435668c571dfab70e70d269a557dbbb9c6c67 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Wed, 16 Apr 2025 08:03:44 -0400 Subject: [PATCH 12/25] GetFieldFunc should return the field it finds instead of creating a new one --- datafusion/functions/src/core/getfield.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 7989317efb2b..a673a07df8d9 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -146,7 +146,7 @@ impl ScalarUDFImpl for GetFieldFunc { // instead, we assume that the second column is the "value" column both here and in // execution. let value_field = fields.get(1).expect("fields should have exactly two members"); - Ok(Field::new(self.name(), value_field.data_type().clone(), true)) + Ok(value_field.as_ref().clone()) }, _ => exec_err!("Map fields must contain a Struct with exactly 2 fields"), } @@ -158,7 +158,7 @@ impl ScalarUDFImpl for GetFieldFunc { |field_name| { fields.iter().find(|f| f.name() == field_name) .ok_or(plan_datafusion_err!("Field {field_name} not found in struct")) - .map(|f| Field::new(self.name(), f.data_type().to_owned(), true)) + .map(|f| f.as_ref().clone()) }) }, (DataType::Null, _) => Ok(Field::new(self.name(), DataType::Null, true)), From d6af7e3bed4ca47b10abc613cf4b9219937797f0 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 17 Apr 2025 13:34:29 -0400 Subject: [PATCH 13/25] Get metadata from scalar functions --- datafusion/expr/src/expr_schema.rs | 26 +++++++++++++++++++++++ datafusion/functions/src/core/getfield.rs | 3 ++- 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index bd8dd7bd7b2c..90ef5a3d01f6 100644 --- a/datafusion/expr/src/expr_schema.rs +++ 
b/datafusion/expr/src/expr_schema.rs @@ -354,6 +354,32 @@ impl ExprSchemable for Expr { Ok(ret) } Expr::Cast(Cast { expr, .. }) => expr.metadata(schema), + Expr::ScalarFunction(func) => { + let arg_fields = func + .args + .iter() + .map(|e| { + e.to_field(schema).map(|(_, f)| f.as_ref().clone()).or({ + let (data_type, nullable) = + e.data_type_and_nullable(schema)?; + Ok(Field::new("arg", data_type, nullable)) + }) + }) + .collect::>>()?; + let scalar_arguments = func + .args + .iter() + .map(|e| match e { + Expr::Literal(sv) => Some(sv), + _ => None, + }) + .collect::>(); + let output_field = func.func.return_field_from_args(ReturnFieldArgs { + arg_fields: &arg_fields, + scalar_arguments: &scalar_arguments, + })?; + Ok(output_field.metadata().clone()) + } _ => Ok(HashMap::new()), } } diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index a673a07df8d9..913767f3b835 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -158,7 +158,8 @@ impl ScalarUDFImpl for GetFieldFunc { |field_name| { fields.iter().find(|f| f.name() == field_name) .ok_or(plan_datafusion_err!("Field {field_name} not found in struct")) - .map(|f| f.as_ref().clone()) + .map(|f| { + f.as_ref().clone()}) }) }, (DataType::Null, _) => Ok(Field::new(self.name(), DataType::Null, true)), From 07b7ec862dc9c57bfb38c8874cd131be6d82de86 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Thu, 17 Apr 2025 16:20:30 -0400 Subject: [PATCH 14/25] Change expr_schema to use to_field primarily instead of individual calls for getting data type, nullability, and schema --- datafusion/common/src/dfschema.rs | 43 ++--- datafusion/expr/src/expr_schema.rs | 206 +++++++++++----------- datafusion/functions/src/core/getfield.rs | 14 +- 3 files changed, 140 insertions(+), 123 deletions(-) diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 66a26a18c0dc..da670eb93167 100644 --- 
a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -472,7 +472,7 @@ impl DFSchema { let matches = self.qualified_fields_with_unqualified_name(name); match matches.len() { 0 => Err(unqualified_field_not_found(name, self)), - 1 => Ok((matches[0].0, (matches[0].1))), + 1 => Ok((matches[0].0, matches[0].1)), _ => { // When `matches` size > 1, it doesn't necessarily mean an `ambiguous name` problem. // Because name may generate from Alias/... . It means that it don't own qualifier. @@ -969,16 +969,28 @@ impl Display for DFSchema { /// widely used in the DataFusion codebase. pub trait ExprSchema: std::fmt::Debug { /// Is this column reference nullable? - fn nullable(&self, col: &Column) -> Result; + fn nullable(&self, col: &Column) -> Result { + Ok(self.to_field(col)?.is_nullable()) + } /// What is the datatype of this column? - fn data_type(&self, col: &Column) -> Result<&DataType>; + fn data_type(&self, col: &Column) -> Result<&DataType> { + Ok(self.to_field(col)?.data_type()) + } /// Returns the column's optional metadata. 
- fn metadata(&self, col: &Column) -> Result<&HashMap>; + fn metadata(&self, col: &Column) -> Result<&HashMap> { + Ok(self.to_field(col)?.metadata()) + } /// Return the column's datatype and nullability - fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)>; + fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { + let field = self.to_field(col)?; + Ok((field.data_type(), field.is_nullable())) + } + + // Return the column's field + fn to_field(&self, col: &Column) -> Result<&Field>; } // Implement `ExprSchema` for `Arc` @@ -998,24 +1010,15 @@ impl + std::fmt::Debug> ExprSchema for P { fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { self.as_ref().data_type_and_nullable(col) } -} -impl ExprSchema for DFSchema { - fn nullable(&self, col: &Column) -> Result { - Ok(self.field_from_column(col)?.is_nullable()) - } - - fn data_type(&self, col: &Column) -> Result<&DataType> { - Ok(self.field_from_column(col)?.data_type()) - } - - fn metadata(&self, col: &Column) -> Result<&HashMap> { - Ok(self.field_from_column(col)?.metadata()) + fn to_field(&self, col: &Column) -> Result<&Field> { + self.as_ref().to_field(col) } +} - fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { - let field = self.field_from_column(col)?; - Ok((field.data_type(), field.is_nullable())) +impl ExprSchema for DFSchema { + fn to_field(&self, col: &Column) -> Result<&Field> { + self.field_from_column(col) } } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 90ef5a3d01f6..80eb91323c01 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -341,47 +341,8 @@ impl ExprSchemable for Expr { } fn metadata(&self, schema: &dyn ExprSchema) -> Result> { - match self { - Expr::Column(c) => Ok(schema.metadata(c)?.clone()), - Expr::Alias(Alias { expr, metadata, .. 
}) => { - let mut ret = expr.metadata(schema)?; - if let Some(metadata) = metadata { - if !metadata.is_empty() { - ret.extend(metadata.clone()); - return Ok(ret); - } - } - Ok(ret) - } - Expr::Cast(Cast { expr, .. }) => expr.metadata(schema), - Expr::ScalarFunction(func) => { - let arg_fields = func - .args - .iter() - .map(|e| { - e.to_field(schema).map(|(_, f)| f.as_ref().clone()).or({ - let (data_type, nullable) = - e.data_type_and_nullable(schema)?; - Ok(Field::new("arg", data_type, nullable)) - }) - }) - .collect::>>()?; - let scalar_arguments = func - .args - .iter() - .map(|e| match e { - Expr::Literal(sv) => Some(sv), - _ => None, - }) - .collect::>(); - let output_field = func.func.return_field_from_args(ReturnFieldArgs { - arg_fields: &arg_fields, - scalar_arguments: &scalar_arguments, - })?; - Ok(output_field.metadata().clone()) - } - _ => Ok(HashMap::new()), - } + self.to_field(schema) + .map(|(_, field)| field.metadata().clone()) } /// Returns the datatype and nullability of the expression based on [ExprSchema]. @@ -398,23 +359,62 @@ impl ExprSchemable for Expr { &self, schema: &dyn ExprSchema, ) -> Result<(DataType, bool)> { - match self { - Expr::Alias(Alias { expr, name, .. }) => match &**expr { - Expr::Placeholder(Placeholder { data_type, .. 
}) => match &data_type { - None => schema - .data_type_and_nullable(&Column::from_name(name)) - .map(|(d, n)| (d.clone(), n)), - Some(dt) => Ok((dt.clone(), expr.nullable(schema)?)), - }, - _ => expr.data_type_and_nullable(schema), - }, - Expr::Negative(expr) => expr.data_type_and_nullable(schema), - Expr::Column(c) => schema - .data_type_and_nullable(c) - .map(|(d, n)| (d.clone(), n)), - Expr::OuterReferenceColumn(ty, _) => Ok((ty.clone(), true)), - Expr::ScalarVariable(ty, _) => Ok((ty.clone(), true)), - Expr::Literal(l) => Ok((l.data_type(), l.is_null())), + let field = self.to_field(schema)?.1; + + Ok((field.data_type().clone(), field.is_nullable())) + } + + /// Returns a [arrow::datatypes::Field] compatible with this expression. + /// + /// So for example, a projected expression `col(c1) + col(c2)` is + /// placed in an output field **named** col("c1 + c2") + fn to_field( + &self, + schema: &dyn ExprSchema, + ) -> Result<(Option, Arc)> { + let (relation, schema_name) = self.qualified_name(); + #[allow(deprecated)] + let field = match self { + Expr::Alias(Alias { + expr, + name, + metadata, + .. + }) => { + let field = match &**expr { + Expr::Placeholder(Placeholder { data_type, .. 
}) => { + match &data_type { + None => schema + .data_type_and_nullable(&Column::from_name(name)) + .map(|(d, n)| Field::new(&schema_name, d.clone(), n)), + Some(dt) => Ok(Field::new( + &schema_name, + dt.clone(), + expr.nullable(schema)?, + )), + } + } + _ => expr.to_field(schema).map(|(_, f)| f.as_ref().clone()), + }?; + + let mut combined_metadata = expr.metadata(schema)?; + if let Some(metadata) = metadata { + if !metadata.is_empty() { + combined_metadata.extend(metadata.clone()); + } + } + + Ok(field.with_metadata(combined_metadata)) + } + Expr::Negative(expr) => { + expr.to_field(schema).map(|(_, f)| f.as_ref().clone()) + } + Expr::Column(c) => schema.to_field(c).cloned(), + Expr::OuterReferenceColumn(ty, _) => { + Ok(Field::new(&schema_name, ty.clone(), true)) + } + Expr::ScalarVariable(ty, _) => Ok(Field::new(&schema_name, ty.clone(), true)), + Expr::Literal(l) => Ok(Field::new(&schema_name, l.data_type(), l.is_null())), Expr::IsNull(_) | Expr::IsNotNull(_) | Expr::IsTrue(_) @@ -423,8 +423,11 @@ impl ExprSchemable for Expr { | Expr::IsNotTrue(_) | Expr::IsNotFalse(_) | Expr::IsNotUnknown(_) - | Expr::Exists { .. } => Ok((DataType::Boolean, false)), - Expr::ScalarSubquery(subquery) => Ok(( + | Expr::Exists { .. 
} => { + Ok(Field::new(&schema_name, DataType::Boolean, false)) + } + Expr::ScalarSubquery(subquery) => Ok(Field::new( + &schema_name, subquery.subquery.schema().field(0).data_type().clone(), subquery.subquery.schema().field(0).is_nullable(), )), @@ -438,10 +441,18 @@ impl ExprSchemable for Expr { let mut coercer = BinaryTypeCoercer::new(&lhs_type, op, &rhs_type); coercer.set_lhs_spans(left.spans().cloned().unwrap_or_default()); coercer.set_rhs_spans(right.spans().cloned().unwrap_or_default()); - Ok((coercer.get_result_type()?, lhs_nullable || rhs_nullable)) + Ok(Field::new( + &schema_name, + coercer.get_result_type()?, + lhs_nullable || rhs_nullable, + )) } Expr::WindowFunction(window_function) => { - self.data_type_and_nullable_with_window_function(schema, window_function) + let (dt, nullable) = self.data_type_and_nullable_with_window_function( + schema, + window_function, + )?; + Ok(Field::new(&schema_name, dt, nullable)) } Expr::ScalarFunction(ScalarFunction { func, args }) => { let (arg_types, fields): (Vec, Vec>) = args @@ -485,27 +496,32 @@ impl ExprSchemable for Expr { scalar_arguments: &arguments, }; - let return_field = func.return_field_from_args(args)?; - Ok((return_field.data_type().clone(), return_field.is_nullable())) + func.return_field_from_args(args) } - _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), - } - } + // _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), + Expr::Cast(Cast { expr, data_type }) => expr + .to_field(schema) + .map(|(_, f)| f.as_ref().clone().with_data_type(data_type.clone())), + Expr::Like(_) + | Expr::SimilarTo(_) + | Expr::Not(_) + | Expr::Between(_) + | Expr::Case(_) + | Expr::TryCast(_) + | Expr::AggregateFunction(_) + | Expr::InList(_) + | Expr::InSubquery(_) + | Expr::Wildcard { .. 
} + | Expr::GroupingSet(_) + | Expr::Placeholder(_) + | Expr::Unnest(_) => Ok(Field::new( + &schema_name, + self.get_type(schema)?, + self.nullable(schema)?, + )), + }?; - /// Returns a [arrow::datatypes::Field] compatible with this expression. - /// - /// So for example, a projected expression `col(c1) + col(c2)` is - /// placed in an output field **named** col("c1 + c2") - fn to_field( - &self, - input_schema: &dyn ExprSchema, - ) -> Result<(Option, Arc)> { - let (relation, schema_name) = self.qualified_name(); - let (data_type, nullable) = self.data_type_and_nullable(input_schema)?; - let field = Field::new(schema_name, data_type, nullable) - .with_metadata(self.metadata(input_schema)?) - .into(); - Ok((relation, field)) + Ok((relation, Arc::new(field.with_name(schema_name)))) } /// Wraps this expression in a cast to a target [arrow::datatypes::DataType]. @@ -792,29 +808,25 @@ mod tests { #[derive(Debug)] struct MockExprSchema { - nullable: bool, - data_type: DataType, + field: Field, error_on_nullable: bool, - metadata: HashMap, } impl MockExprSchema { fn new() -> Self { Self { - nullable: false, - data_type: DataType::Null, + field: Field::new("mock_field", DataType::Null, false), error_on_nullable: false, - metadata: HashMap::new(), } } fn with_nullable(mut self, nullable: bool) -> Self { - self.nullable = nullable; + self.field = self.field.with_nullable(nullable); self } fn with_data_type(mut self, data_type: DataType) -> Self { - self.data_type = data_type; + self.field = self.field.with_data_type(data_type); self } @@ -824,7 +836,7 @@ mod tests { } fn with_metadata(mut self, metadata: HashMap) -> Self { - self.metadata = metadata; + self.field = self.field.with_metadata(metadata); self } } @@ -834,20 +846,12 @@ mod tests { if self.error_on_nullable { internal_err!("nullable error") } else { - Ok(self.nullable) + Ok(self.field.is_nullable()) } } - fn data_type(&self, _col: &Column) -> Result<&DataType> { - Ok(&self.data_type) - } - - fn metadata(&self, 
_col: &Column) -> Result<&HashMap> { - Ok(&self.metadata) - } - - fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { - Ok((self.data_type(col)?, self.nullable(col)?)) + fn to_field(&self, _col: &Column) -> Result<&Field> { + Ok(&self.field) } } } diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 913767f3b835..97df76eaac58 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -146,7 +146,8 @@ impl ScalarUDFImpl for GetFieldFunc { // instead, we assume that the second column is the "value" column both here and in // execution. let value_field = fields.get(1).expect("fields should have exactly two members"); - Ok(value_field.as_ref().clone()) + + Ok(value_field.as_ref().clone().with_nullable(true)) }, _ => exec_err!("Map fields must contain a Struct with exactly 2 fields"), } @@ -159,7 +160,16 @@ impl ScalarUDFImpl for GetFieldFunc { fields.iter().find(|f| f.name() == field_name) .ok_or(plan_datafusion_err!("Field {field_name} not found in struct")) .map(|f| { - f.as_ref().clone()}) + let mut child_field = f.as_ref().clone(); + + // If the parent is nullable, then getting the child must be nullable, + // so potentially override the return value + + if args.arg_fields[0].is_nullable() { + child_field = child_field.with_nullable(true); + } + child_field + }) }) }, (DataType::Null, _) => Ok(Field::new(self.name(), DataType::Null, true)), From caad0217122633daa38915db568f7288ac88b900 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 18 Apr 2025 08:42:58 -0400 Subject: [PATCH 15/25] Scalar function arguments should take return field instead of return data type now --- .../physical_optimizer/projection_pushdown.rs | 13 ++-- datafusion/expr/src/udf.rs | 2 +- datafusion/ffi/src/udf/mod.rs | 14 ++-- datafusion/functions-nested/benches/map.rs | 5 +- .../functions/benches/character_length.rs | 12 ++-- datafusion/functions/benches/chr.rs | 4 +- 
datafusion/functions/benches/concat.rs | 4 +- datafusion/functions/benches/cot.rs | 6 +- datafusion/functions/benches/date_bin.rs | 4 +- datafusion/functions/benches/date_trunc.rs | 6 +- datafusion/functions/benches/encoding.rs | 10 +-- datafusion/functions/benches/find_in_set.rs | 10 +-- datafusion/functions/benches/gcd.rs | 7 +- datafusion/functions/benches/initcap.rs | 8 +-- datafusion/functions/benches/isnan.rs | 6 +- datafusion/functions/benches/iszero.rs | 6 +- datafusion/functions/benches/lower.rs | 14 ++-- datafusion/functions/benches/ltrim.rs | 4 +- datafusion/functions/benches/make_date.rs | 10 +-- datafusion/functions/benches/nullif.rs | 4 +- datafusion/functions/benches/pad.rs | 14 ++-- datafusion/functions/benches/random.rs | 6 +- datafusion/functions/benches/repeat.rs | 16 ++--- datafusion/functions/benches/reverse.rs | 10 +-- datafusion/functions/benches/signum.rs | 6 +- datafusion/functions/benches/strpos.rs | 10 +-- datafusion/functions/benches/substr.rs | 20 +++--- datafusion/functions/benches/substr_index.rs | 4 +- datafusion/functions/benches/to_char.rs | 8 +-- datafusion/functions/benches/to_hex.rs | 6 +- datafusion/functions/benches/to_timestamp.rs | 17 ++--- datafusion/functions/benches/trunc.rs | 6 +- datafusion/functions/benches/upper.rs | 4 +- datafusion/functions/benches/uuid.rs | 4 +- datafusion/functions/src/core/named_struct.rs | 2 +- datafusion/functions/src/core/struct.rs | 2 +- .../functions/src/core/union_extract.rs | 6 +- datafusion/functions/src/core/version.rs | 3 +- datafusion/functions/src/datetime/date_bin.rs | 40 ++++++----- .../functions/src/datetime/date_trunc.rs | 14 +++- .../functions/src/datetime/from_unixtime.rs | 11 +-- .../functions/src/datetime/make_date.rs | 18 ++--- datafusion/functions/src/datetime/to_char.rs | 14 ++-- datafusion/functions/src/datetime/to_date.rs | 18 ++--- .../functions/src/datetime/to_local_time.rs | 10 ++- .../functions/src/datetime/to_timestamp.rs | 6 +- datafusion/functions/src/math/log.rs | 
21 +++--- datafusion/functions/src/math/power.rs | 5 +- datafusion/functions/src/math/signum.rs | 6 +- datafusion/functions/src/regex/regexpcount.rs | 25 +++---- datafusion/functions/src/string/concat.rs | 3 +- datafusion/functions/src/string/concat_ws.rs | 8 +-- datafusion/functions/src/string/contains.rs | 4 +- datafusion/functions/src/string/lower.rs | 4 +- datafusion/functions/src/string/upper.rs | 4 +- .../functions/src/unicode/find_in_set.rs | 4 +- datafusion/functions/src/utils.rs | 19 +++-- .../src/equivalence/properties/dependency.rs | 7 +- .../physical-expr/src/scalar_function.rs | 69 +++++-------------- .../proto/src/physical_plan/from_proto.rs | 4 +- .../tests/cases/roundtrip_physical_plan.rs | 10 +-- 61 files changed, 302 insertions(+), 315 deletions(-) diff --git a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs index f018a75f657f..e00a44188e57 100644 --- a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs @@ -16,7 +16,6 @@ // under the License. 
use std::any::Any; -use std::collections::HashMap; use std::sync::Arc; use arrow::compute::SortOptions; @@ -129,8 +128,7 @@ fn test_update_matching_exprs() -> Result<()> { Arc::new(Column::new("b", 1)), )), ], - DataType::Int32, - HashMap::default(), + Field::new("f", DataType::Int32, true), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 2))), @@ -195,8 +193,7 @@ fn test_update_matching_exprs() -> Result<()> { Arc::new(Column::new("b", 1)), )), ], - DataType::Int32, - HashMap::default(), + Field::new("f", DataType::Int32, true), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 3))), @@ -264,8 +261,7 @@ fn test_update_projected_exprs() -> Result<()> { Arc::new(Column::new("b", 1)), )), ], - DataType::Int32, - HashMap::default(), + Field::new("f", DataType::Int32, true), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 2))), @@ -330,8 +326,7 @@ fn test_update_projected_exprs() -> Result<()> { Arc::new(Column::new("b_new", 1)), )), ], - DataType::Int32, - HashMap::default(), + Field::new("f", DataType::Int32, true), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d_new", 3))), diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 8e225611d476..cb22cd2a07cc 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -302,7 +302,7 @@ pub struct ScalarFunctionArgs<'a, 'b> { pub number_rows: usize, /// The return type of the scalar function returned (from `return_type` or `return_field_from_args`) /// when creating the physical expression from the logical expression - pub return_type: &'a DataType, + pub return_field: &'a Field, } /// Information about arguments passed to the function diff --git a/datafusion/ffi/src/udf/mod.rs b/datafusion/ffi/src/udf/mod.rs index 05e4ef8c7e94..95441df77c13 100644 --- a/datafusion/ffi/src/udf/mod.rs +++ b/datafusion/ffi/src/udf/mod.rs @@ -87,7 +87,7 @@ pub struct FFI_ScalarUDF { args: RVec, arg_fields: RVec>, num_rows: usize, - return_type: 
WrappedSchema, + return_field: WrappedSchema, ) -> RResult, /// See [`ScalarUDFImpl`] for details on short_circuits @@ -175,7 +175,7 @@ unsafe extern "C" fn invoke_with_args_fn_wrapper( args: RVec, arg_fields: RVec>, number_rows: usize, - return_type: WrappedSchema, + return_field: WrappedSchema, ) -> RResult { let private_data = udf.private_data as *const ScalarUDFPrivateData; let udf = &(*private_data).udf; @@ -189,7 +189,7 @@ unsafe extern "C" fn invoke_with_args_fn_wrapper( .collect::>(); let args = rresult_return!(args); - let return_type = rresult_return!(DataType::try_from(&return_type.0)); + let return_field = rresult_return!(Field::try_from(&return_field.0)); let arg_fields_owned = arg_fields .into_iter() @@ -210,7 +210,7 @@ unsafe extern "C" fn invoke_with_args_fn_wrapper( args, arg_fields, number_rows, - return_type: &return_type, + return_field: &return_field, }; let result = rresult_return!(udf @@ -347,7 +347,7 @@ impl ScalarUDFImpl for ForeignScalarUDF { args, arg_fields, number_rows, - return_type, + return_field, } = invoke_args; let args = args @@ -374,7 +374,7 @@ impl ScalarUDFImpl for ForeignScalarUDF { .map(|maybe_field| maybe_field.map(WrappedSchema).into()) .collect::>(); - let return_type = WrappedSchema(FFI_ArrowSchema::try_from(return_type)?); + let return_field = WrappedSchema(FFI_ArrowSchema::try_from(return_field)?); let result = unsafe { (self.udf.invoke_with_args)( @@ -382,7 +382,7 @@ impl ScalarUDFImpl for ForeignScalarUDF { args, arg_fields, number_rows, - return_type, + return_field, ) }; diff --git a/datafusion/functions-nested/benches/map.rs b/datafusion/functions-nested/benches/map.rs index 5813dd4109b3..e24965941b81 100644 --- a/datafusion/functions-nested/benches/map.rs +++ b/datafusion/functions-nested/benches/map.rs @@ -94,9 +94,10 @@ fn criterion_benchmark(c: &mut Criterion) { let keys = ColumnarValue::Scalar(ScalarValue::List(Arc::new(key_list))); let values = ColumnarValue::Scalar(ScalarValue::List(Arc::new(value_list))); 
- let return_type = &map_udf() + let return_type = map_udf() .return_type(&[DataType::Utf8, DataType::Int32]) .expect("should get return type"); + let return_field = &Field::new("f", return_type, true); b.iter(|| { black_box( @@ -105,7 +106,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: vec![keys.clone(), values.clone()], arg_fields: vec![None; 2], number_rows: 1, - return_type, + return_field, }) .expect("map should work on valid values"), ); diff --git a/datafusion/functions/benches/character_length.rs b/datafusion/functions/benches/character_length.rs index 26d9aa93d92f..4ab741874a08 100644 --- a/datafusion/functions/benches/character_length.rs +++ b/datafusion/functions/benches/character_length.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::ScalarFunctionArgs; use helper::gen_string_array; @@ -28,7 +28,7 @@ fn criterion_benchmark(c: &mut Criterion) { // All benches are single batch run with 8192 rows let character_length = datafusion_functions::unicode::character_length(); - let return_type = DataType::Utf8; + let return_field = Field::new("f", DataType::Utf8, true); let n_rows = 8192; for str_len in [8, 32, 128, 4096] { @@ -42,7 +42,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_string_ascii.clone(), arg_fields: vec![None; args_string_ascii.len()], number_rows: n_rows, - return_type: &return_type, + return_field: &return_field, })) }) }, @@ -58,7 +58,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_string_utf8.clone(), arg_fields: vec![None; args_string_utf8.len()], number_rows: n_rows, - return_type: &return_type, + return_field: &return_field, })) }) }, @@ -74,7 +74,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_string_view_ascii.clone(), arg_fields: vec![None; args_string_view_ascii.len()], number_rows: n_rows, - return_type: &return_type, + 
return_field: &return_field, })) }) }, @@ -90,7 +90,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_string_view_utf8.clone(), arg_fields: vec![None; args_string_view_utf8.len()], number_rows: n_rows, - return_type: &return_type, + return_field: &return_field, })) }) }, diff --git a/datafusion/functions/benches/chr.rs b/datafusion/functions/benches/chr.rs index 568a8a507e81..11a6950ff5e6 100644 --- a/datafusion/functions/benches/chr.rs +++ b/datafusion/functions/benches/chr.rs @@ -23,7 +23,7 @@ use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string::chr; use rand::{Rng, SeedableRng}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use rand::rngs::StdRng; use std::sync::Arc; @@ -58,7 +58,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .unwrap(), ) diff --git a/datafusion/functions/benches/concat.rs b/datafusion/functions/benches/concat.rs index 47992eb28989..05c4ece9b0c5 100644 --- a/datafusion/functions/benches/concat.rs +++ b/datafusion/functions/benches/concat.rs @@ -16,7 +16,7 @@ // under the License. 
use arrow::array::ArrayRef; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use datafusion_common::ScalarValue; @@ -47,7 +47,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_cloned, arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .unwrap(), ) diff --git a/datafusion/functions/benches/cot.rs b/datafusion/functions/benches/cot.rs index 44519f6dc595..508f53027c1c 100644 --- a/datafusion/functions/benches/cot.rs +++ b/datafusion/functions/benches/cot.rs @@ -25,7 +25,7 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::math::cot; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { @@ -41,7 +41,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: f32_args.clone(), arg_fields: vec![None; f32_args.len()], number_rows: size, - return_type: &DataType::Float32, + return_field: &Field::new("f", DataType::Float32, true), }) .unwrap(), ) @@ -57,7 +57,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: f64_args.clone(), arg_fields: vec![None; f64_args.len()], number_rows: size, - return_type: &DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), }) .unwrap(), ) diff --git a/datafusion/functions/benches/date_bin.rs b/datafusion/functions/benches/date_bin.rs index 9a741668a9d8..36356654b0d1 100644 --- a/datafusion/functions/benches/date_bin.rs +++ b/datafusion/functions/benches/date_bin.rs @@ -20,6 +20,7 @@ extern crate criterion; use std::sync::Arc; use arrow::array::{Array, ArrayRef, TimestampSecondArray}; +use arrow::datatypes::Field; use criterion::{black_box, criterion_group, criterion_main, 
Criterion}; use datafusion_common::ScalarValue; use rand::rngs::ThreadRng; @@ -48,6 +49,7 @@ fn criterion_benchmark(c: &mut Criterion) { let return_type = udf .return_type(&[interval.data_type(), timestamps.data_type()]) .unwrap(); + let return_field = Field::new("f", return_type, true); b.iter(|| { black_box( @@ -55,7 +57,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: vec![interval.clone(), timestamps.clone()], arg_fields: vec![None; 2], number_rows: batch_len, - return_type: &return_type, + return_field: &return_field, }) .expect("date_bin should work on valid values"), ) diff --git a/datafusion/functions/benches/date_trunc.rs b/datafusion/functions/benches/date_trunc.rs index c4e9b77f976a..7820c268b93b 100644 --- a/datafusion/functions/benches/date_trunc.rs +++ b/datafusion/functions/benches/date_trunc.rs @@ -20,6 +20,7 @@ extern crate criterion; use std::sync::Arc; use arrow::array::{Array, ArrayRef, TimestampSecondArray}; +use arrow::datatypes::Field; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_common::ScalarValue; use rand::rngs::ThreadRng; @@ -47,16 +48,17 @@ fn criterion_benchmark(c: &mut Criterion) { let timestamps = ColumnarValue::Array(timestamps_array); let udf = date_trunc(); let args = vec![precision, timestamps]; - let return_type = &udf + let return_type = udf .return_type(&args.iter().map(|arg| arg.data_type()).collect::>()) .unwrap(); + let return_field = Field::new("f", return_type, true); b.iter(|| { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: batch_len, - return_type, + return_field: &return_field, }) .expect("date_trunc should work on valid values"), ) diff --git a/datafusion/functions/benches/encoding.rs b/datafusion/functions/benches/encoding.rs index 5985d1477d07..56f235687451 100644 --- a/datafusion/functions/benches/encoding.rs +++ b/datafusion/functions/benches/encoding.rs @@ -17,7 +17,7 @@ extern crate 
criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; @@ -35,7 +35,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: vec![ColumnarValue::Array(str_array.clone()), method.clone()], arg_fields: vec![None; 2], number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .unwrap(); @@ -47,7 +47,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .unwrap(), ) @@ -61,7 +61,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: vec![ColumnarValue::Array(str_array.clone()), method.clone()], arg_fields: vec![None; 2], number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .unwrap(); @@ -73,7 +73,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .unwrap(), ) diff --git a/datafusion/functions/benches/find_in_set.rs b/datafusion/functions/benches/find_in_set.rs index 5f580d855d8f..f82ff775cd1c 100644 --- a/datafusion/functions/benches/find_in_set.rs +++ b/datafusion/functions/benches/find_in_set.rs @@ -18,7 +18,7 @@ extern crate criterion; use arrow::array::{StringArray, StringViewArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; @@ -159,7 +159,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: n_rows, - return_type: &DataType::Int32, + 
return_field: &Field::new("f", DataType::Int32, true), })) }) }); @@ -171,7 +171,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: n_rows, - return_type: &DataType::Int32, + return_field: &Field::new("f", DataType::Int32, true), })) }) }); @@ -187,7 +187,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: n_rows, - return_type: &DataType::Int32, + return_field: &Field::new("f", DataType::Int32, true), })) }) }); @@ -199,7 +199,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: n_rows, - return_type: &DataType::Int32, + return_field: &Field::new("f", DataType::Int32, true), })) }) }); diff --git a/datafusion/functions/benches/gcd.rs b/datafusion/functions/benches/gcd.rs index e208c2c092f0..ab8be59f74fb 100644 --- a/datafusion/functions/benches/gcd.rs +++ b/datafusion/functions/benches/gcd.rs @@ -17,6 +17,7 @@ extern crate criterion; +use arrow::datatypes::Field; use arrow::{ array::{ArrayRef, Int64Array}, datatypes::DataType, @@ -49,7 +50,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: vec![array_a.clone(), array_b.clone()], arg_fields: vec![None; 2], number_rows: 0, - return_type: &DataType::Int64, + return_field: &Field::new("f", DataType::Int64, true), }) .expect("date_bin should work on valid values"), ) @@ -66,7 +67,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: vec![array_a.clone(), scalar_b.clone()], arg_fields: vec![None; 2], number_rows: 0, - return_type: &DataType::Int64, + return_field: &Field::new("f", DataType::Int64, true), }) .expect("date_bin should work on valid values"), ) @@ -83,7 +84,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: vec![scalar_a.clone(), scalar_b.clone()], arg_fields: vec![None; 2], number_rows: 0, - return_type: &DataType::Int64, + return_field: &Field::new("f", DataType::Int64, true), }) .expect("date_bin 
should work on valid values"), ) diff --git a/datafusion/functions/benches/initcap.rs b/datafusion/functions/benches/initcap.rs index fcebe74fd65a..45ee5c435f5c 100644 --- a/datafusion/functions/benches/initcap.rs +++ b/datafusion/functions/benches/initcap.rs @@ -18,7 +18,7 @@ extern crate criterion; use arrow::array::OffsetSizeTrait; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; @@ -57,7 +57,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8View, + return_field: &Field::new("f", DataType::Utf8View, true), })) }) }, @@ -72,7 +72,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8View, + return_field: &Field::new("f", DataType::Utf8View, true), })) }) }, @@ -85,7 +85,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }); diff --git a/datafusion/functions/benches/isnan.rs b/datafusion/functions/benches/isnan.rs index b77abcea2c78..0bab5d00cbc8 100644 --- a/datafusion/functions/benches/isnan.rs +++ b/datafusion/functions/benches/isnan.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::{ datatypes::{Float32Type, Float64Type}, util::bench_util::create_primitive_array, @@ -40,7 +40,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: f32_args.clone(), arg_fields: vec![None; f32_args.len()], number_rows: size, - return_type: &DataType::Boolean, + return_field: &Field::new("f", DataType::Boolean, true), }) .unwrap(), ) @@ -56,7 +56,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: 
f64_args.clone(), arg_fields: vec![None; f64_args.len()], number_rows: size, - return_type: &DataType::Boolean, + return_field: &Field::new("f", DataType::Boolean, true), }) .unwrap(), ) diff --git a/datafusion/functions/benches/iszero.rs b/datafusion/functions/benches/iszero.rs index d0e7148b7364..f9f85d6d8b7c 100644 --- a/datafusion/functions/benches/iszero.rs +++ b/datafusion/functions/benches/iszero.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::{ datatypes::{Float32Type, Float64Type}, util::bench_util::create_primitive_array, @@ -41,7 +41,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: f32_args.clone(), arg_fields: vec![None; f32_args.len()], number_rows: batch_len, - return_type: &DataType::Boolean, + return_field: &Field::new("f", DataType::Boolean, true), }) .unwrap(), ) @@ -58,7 +58,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: f64_args.clone(), arg_fields: vec![None; f64_args.len()], number_rows: batch_len, - return_type: &DataType::Boolean, + return_field: &Field::new("f", DataType::Boolean, true), }) .unwrap(), ) diff --git a/datafusion/functions/benches/lower.rs b/datafusion/functions/benches/lower.rs index ae172e9766af..348c6bf2b6d9 100644 --- a/datafusion/functions/benches/lower.rs +++ b/datafusion/functions/benches/lower.rs @@ -18,7 +18,7 @@ extern crate criterion; use arrow::array::{ArrayRef, StringArray, StringViewBuilder}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; @@ -131,7 +131,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_cloned, arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }); @@ -146,7 +146,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_cloned, arg_fields: vec![None; args.len()], 
number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }, @@ -162,7 +162,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_cloned, arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }, @@ -188,7 +188,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_cloned, arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }), ); @@ -203,7 +203,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_cloned, arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }), ); @@ -218,7 +218,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_cloned, arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }), ); diff --git a/datafusion/functions/benches/ltrim.rs b/datafusion/functions/benches/ltrim.rs index 938febfddeaf..ebe576481bcf 100644 --- a/datafusion/functions/benches/ltrim.rs +++ b/datafusion/functions/benches/ltrim.rs @@ -18,7 +18,7 @@ extern crate criterion; use arrow::array::{ArrayRef, LargeStringArray, StringArray, StringViewArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{ black_box, criterion_group, criterion_main, measurement::Measurement, BenchmarkGroup, Criterion, SamplingMode, @@ -147,7 +147,7 @@ fn run_with_string_type( args: args_cloned, arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }, diff --git a/datafusion/functions/benches/make_date.rs b/datafusion/functions/benches/make_date.rs index 5037da919730..c24fc3bb9ad6 100644 --- 
a/datafusion/functions/benches/make_date.rs +++ b/datafusion/functions/benches/make_date.rs @@ -20,7 +20,7 @@ extern crate criterion; use std::sync::Arc; use arrow::array::{Array, ArrayRef, Int32Array}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use rand::rngs::ThreadRng; use rand::Rng; @@ -71,7 +71,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: vec![years.clone(), months.clone(), days.clone()], arg_fields: vec![None; 3], number_rows: batch_len, - return_type: &DataType::Date32, + return_field: &Field::new("f", DataType::Date32, true), }) .expect("make_date should work on valid values"), ) @@ -93,7 +93,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: vec![year.clone(), months.clone(), days.clone()], arg_fields: vec![None; 3], number_rows: batch_len, - return_type: &DataType::Date32, + return_field: &Field::new("f", DataType::Date32, true), }) .expect("make_date should work on valid values"), ) @@ -115,7 +115,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: vec![year.clone(), month.clone(), days.clone()], arg_fields: vec![None; 3], number_rows: batch_len, - return_type: &DataType::Date32, + return_field: &Field::new("f", DataType::Date32, true), }) .expect("make_date should work on valid values"), ) @@ -134,7 +134,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: vec![year.clone(), month.clone(), day.clone()], arg_fields: vec![None; 3], number_rows: 1, - return_type: &DataType::Date32, + return_field: &Field::new("f", DataType::Date32, true), }) .expect("make_date should work on valid values"), ) diff --git a/datafusion/functions/benches/nullif.rs b/datafusion/functions/benches/nullif.rs index 080c2890cee8..53f8a3684d87 100644 --- a/datafusion/functions/benches/nullif.rs +++ b/datafusion/functions/benches/nullif.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use 
arrow::util::bench_util::create_string_array_with_len; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_common::ScalarValue; @@ -41,7 +41,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .unwrap(), ) diff --git a/datafusion/functions/benches/pad.rs b/datafusion/functions/benches/pad.rs index 2a488b369fe0..d494320a1462 100644 --- a/datafusion/functions/benches/pad.rs +++ b/datafusion/functions/benches/pad.rs @@ -16,7 +16,7 @@ // under the License. use arrow::array::{ArrayRef, ArrowPrimitiveType, OffsetSizeTrait, PrimitiveArray}; -use arrow::datatypes::{DataType, Int64Type}; +use arrow::datatypes::{DataType, Field, Int64Type}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; @@ -108,7 +108,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .unwrap(), ) @@ -124,7 +124,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::LargeUtf8, + return_field: &Field::new("f", DataType::LargeUtf8, true), }) .unwrap(), ) @@ -140,7 +140,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .unwrap(), ) @@ -160,7 +160,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .unwrap(), ) @@ -176,7 +176,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: 
args.clone(), arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::LargeUtf8, + return_field: &Field::new("f", DataType::LargeUtf8, true), }) .unwrap(), ) @@ -193,7 +193,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .unwrap(), ) diff --git a/datafusion/functions/benches/random.rs b/datafusion/functions/benches/random.rs index a09b03affa11..5a80bbd1f11b 100644 --- a/datafusion/functions/benches/random.rs +++ b/datafusion/functions/benches/random.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl}; use datafusion_functions::math::random::RandomFunc; @@ -36,7 +36,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: vec![], arg_fields: vec![], number_rows: 8192, - return_type: &DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), }) .unwrap(), ); @@ -55,7 +55,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: vec![], arg_fields: vec![], number_rows: 128, - return_type: &DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), }) .unwrap(), ); diff --git a/datafusion/functions/benches/repeat.rs b/datafusion/functions/benches/repeat.rs index 9b2e52aaac35..94514f24af9f 100644 --- a/datafusion/functions/benches/repeat.rs +++ b/datafusion/functions/benches/repeat.rs @@ -18,7 +18,7 @@ extern crate criterion; use arrow::array::{ArrayRef, Int64Array, OffsetSizeTrait}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; @@ -79,7 +79,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_cloned, arg_fields: 
vec![None; args.len()], number_rows: repeat_times as usize, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }, @@ -98,7 +98,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_cloned, arg_fields: vec![None; args.len()], number_rows: repeat_times as usize, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }, @@ -117,7 +117,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_cloned, arg_fields: vec![None; args.len()], number_rows: repeat_times as usize, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }, @@ -145,7 +145,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_cloned, arg_fields: vec![None; args.len()], number_rows: repeat_times as usize, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }, @@ -164,7 +164,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_cloned, arg_fields: vec![None; args.len()], number_rows: repeat_times as usize, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }, @@ -183,7 +183,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_cloned, arg_fields: vec![None; args.len()], number_rows: repeat_times as usize, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }, @@ -211,7 +211,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_cloned, arg_fields: vec![None; args.len()], number_rows: repeat_times as usize, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }, diff --git a/datafusion/functions/benches/reverse.rs b/datafusion/functions/benches/reverse.rs index 6056d313d3dc..6fcefbf33714 100644 --- a/datafusion/functions/benches/reverse.rs +++ b/datafusion/functions/benches/reverse.rs @@ -18,7 +18,7 @@ extern crate criterion; mod helper; -use arrow::datatypes::DataType; +use 
arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::ScalarFunctionArgs; use helper::gen_string_array; @@ -48,7 +48,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_string_ascii.clone(), arg_fields: vec![None; args_string_ascii.len()], number_rows: N_ROWS, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }, @@ -68,7 +68,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_string_utf8.clone(), arg_fields: vec![None; args_string_utf8.len()], number_rows: N_ROWS, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }, @@ -90,7 +90,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_string_view_ascii.clone(), arg_fields: vec![None; args_string_view_ascii.len()], number_rows: N_ROWS, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }, @@ -110,7 +110,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_string_view_utf8.clone(), arg_fields: vec![None; args_string_view_utf8.len()], number_rows: N_ROWS, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }, diff --git a/datafusion/functions/benches/signum.rs b/datafusion/functions/benches/signum.rs index 46e499542a4f..d21acbe36414 100644 --- a/datafusion/functions/benches/signum.rs +++ b/datafusion/functions/benches/signum.rs @@ -19,7 +19,7 @@ extern crate criterion; use arrow::datatypes::DataType; use arrow::{ - datatypes::{Float32Type, Float64Type}, + datatypes::{Field, Float32Type, Float64Type}, util::bench_util::create_primitive_array, }; use criterion::{black_box, criterion_group, criterion_main, Criterion}; @@ -41,7 +41,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: f32_args.clone(), arg_fields: vec![None; f32_args.len()], number_rows: batch_len, - return_type: &DataType::Float32, + return_field: &Field::new("f", 
DataType::Float32, true), }) .unwrap(), ) @@ -59,7 +59,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: f64_args.clone(), arg_fields: vec![None; f64_args.len()], number_rows: batch_len, - return_type: &DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), }) .unwrap(), ) diff --git a/datafusion/functions/benches/strpos.rs b/datafusion/functions/benches/strpos.rs index 72902ea16b92..cd7e763bcba0 100644 --- a/datafusion/functions/benches/strpos.rs +++ b/datafusion/functions/benches/strpos.rs @@ -18,7 +18,7 @@ extern crate criterion; use arrow::array::{StringArray, StringViewArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use rand::distributions::Alphanumeric; @@ -119,7 +119,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_string_ascii.clone(), arg_fields: vec![None; args_string_ascii.len()], number_rows: n_rows, - return_type: &DataType::Int32, + return_field: &Field::new("f", DataType::Int32, true), })) }) }, @@ -135,7 +135,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_string_utf8.clone(), arg_fields: vec![None; args_string_utf8.len()], number_rows: n_rows, - return_type: &DataType::Int32, + return_field: &Field::new("f", DataType::Int32, true), })) }) }, @@ -151,7 +151,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_string_view_ascii.clone(), arg_fields: vec![None; args_string_view_ascii.len()], number_rows: n_rows, - return_type: &DataType::Int32, + return_field: &Field::new("f", DataType::Int32, true), })) }) }, @@ -167,7 +167,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_string_view_utf8.clone(), arg_fields: vec![None; args_string_view_utf8.len()], number_rows: n_rows, - return_type: &DataType::Int32, + return_field: &Field::new("f", DataType::Int32, true), })) }) }, diff --git a/datafusion/functions/benches/substr.rs 
b/datafusion/functions/benches/substr.rs index 4b79d958669e..d4240ef2b7af 100644 --- a/datafusion/functions/benches/substr.rs +++ b/datafusion/functions/benches/substr.rs @@ -18,7 +18,7 @@ extern crate criterion; use arrow::array::{ArrayRef, Int64Array, OffsetSizeTrait}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; @@ -114,7 +114,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8View, + return_field: &Field::new("f", DataType::Utf8View, true), })) }) }, @@ -129,7 +129,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8View, + return_field: &Field::new("f", DataType::Utf8View, true), })) }) }, @@ -144,7 +144,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8View, + return_field: &Field::new("f", DataType::Utf8View, true), })) }) }, @@ -171,7 +171,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8View, + return_field: &Field::new("f", DataType::Utf8View, true), })) }) }, @@ -189,7 +189,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8View, + return_field: &Field::new("f", DataType::Utf8View, true), })) }) }, @@ -207,7 +207,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8View, + return_field: &Field::new("f", DataType::Utf8View, true), })) }) }, @@ -234,7 +234,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), 
arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8View, + return_field: &Field::new("f", DataType::Utf8View, true), })) }) }, @@ -252,7 +252,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8View, + return_field: &Field::new("f", DataType::Utf8View, true), })) }) }, @@ -270,7 +270,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8View, + return_field: &Field::new("f", DataType::Utf8View, true), })) }) }, diff --git a/datafusion/functions/benches/substr_index.rs b/datafusion/functions/benches/substr_index.rs index c8ef989c866f..106b62b1c998 100644 --- a/datafusion/functions/benches/substr_index.rs +++ b/datafusion/functions/benches/substr_index.rs @@ -20,7 +20,7 @@ extern crate criterion; use std::sync::Arc; use arrow::array::{ArrayRef, Int64Array, StringArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use rand::distributions::{Alphanumeric, Uniform}; use rand::prelude::Distribution; @@ -98,7 +98,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .expect("substr_index should work on valid values"), ) diff --git a/datafusion/functions/benches/to_char.rs b/datafusion/functions/benches/to_char.rs index 2b6857f6e60b..dc7df4d6698f 100644 --- a/datafusion/functions/benches/to_char.rs +++ b/datafusion/functions/benches/to_char.rs @@ -20,7 +20,7 @@ extern crate criterion; use std::sync::Arc; use arrow::array::{ArrayRef, Date32Array, StringArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use chrono::prelude::*; use chrono::TimeDelta; 
use criterion::{black_box, criterion_group, criterion_main, Criterion}; @@ -95,7 +95,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: vec![data.clone(), patterns.clone()], arg_fields: vec![None; 2], number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .expect("to_char should work on valid values"), ) @@ -117,7 +117,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: vec![data.clone(), patterns.clone()], arg_fields: vec![None; 2], number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .expect("to_char should work on valid values"), ) @@ -145,7 +145,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: vec![data.clone(), pattern.clone()], arg_fields: vec![None; 2], number_rows: 1, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .expect("to_char should work on valid values"), ) diff --git a/datafusion/functions/benches/to_hex.rs b/datafusion/functions/benches/to_hex.rs index 4f89a710146b..3514da83d9de 100644 --- a/datafusion/functions/benches/to_hex.rs +++ b/datafusion/functions/benches/to_hex.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::{DataType, Int32Type, Int64Type}; +use arrow::datatypes::{DataType, Field, Int32Type, Int64Type}; use arrow::util::bench_util::create_primitive_array; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; @@ -38,7 +38,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_cloned, arg_fields: vec![None; i32_args.len()], number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .unwrap(), ) @@ -55,7 +55,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_cloned, arg_fields: vec![None; i64_args.len()], number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: &Field::new("f", 
DataType::Utf8, true), }) .unwrap(), ) diff --git a/datafusion/functions/benches/to_timestamp.rs b/datafusion/functions/benches/to_timestamp.rs index 7a5c1b99fb85..47f125585b4b 100644 --- a/datafusion/functions/benches/to_timestamp.rs +++ b/datafusion/functions/benches/to_timestamp.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use arrow::array::builder::StringBuilder; use arrow::array::{Array, ArrayRef, StringArray}; use arrow::compute::cast; -use arrow::datatypes::{DataType, TimeUnit}; +use arrow::datatypes::{DataType, Field, TimeUnit}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; @@ -109,7 +109,8 @@ fn data_with_formats() -> (StringArray, StringArray, StringArray, StringArray) { ) } fn criterion_benchmark(c: &mut Criterion) { - let return_type = &DataType::Timestamp(TimeUnit::Nanosecond, None); + let return_field = + &Field::new("f", DataType::Timestamp(TimeUnit::Nanosecond, None), true); c.bench_function("to_timestamp_no_formats_utf8", |b| { let arr_data = data(); let batch_len = arr_data.len(); @@ -122,7 +123,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: vec![string_array.clone()], arg_fields: vec![None; 1], number_rows: batch_len, - return_type, + return_field, }) .expect("to_timestamp should work on valid values"), ) @@ -141,7 +142,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: vec![string_array.clone()], arg_fields: vec![None; 1], number_rows: batch_len, - return_type, + return_field, }) .expect("to_timestamp should work on valid values"), ) @@ -160,7 +161,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: vec![string_array.clone()], arg_fields: vec![None; 1], number_rows: batch_len, - return_type, + return_field, }) .expect("to_timestamp should work on valid values"), ) @@ -184,7 +185,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: batch_len, - return_type, + return_field, }) 
.expect("to_timestamp should work on valid values"), ) @@ -216,7 +217,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: batch_len, - return_type, + return_field, }) .expect("to_timestamp should work on valid values"), ) @@ -249,7 +250,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args.clone(), arg_fields: vec![None; args.len()], number_rows: batch_len, - return_type, + return_field, }) .expect("to_timestamp should work on valid values"), ) diff --git a/datafusion/functions/benches/trunc.rs b/datafusion/functions/benches/trunc.rs index 71383f9b12b8..2102d5985637 100644 --- a/datafusion/functions/benches/trunc.rs +++ b/datafusion/functions/benches/trunc.rs @@ -18,7 +18,7 @@ extern crate criterion; use arrow::{ - datatypes::{Float32Type, Float64Type}, + datatypes::{Field, Float32Type, Float64Type}, util::bench_util::create_primitive_array, }; use criterion::{black_box, criterion_group, criterion_main, Criterion}; @@ -41,7 +41,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: f32_args.clone(), arg_fields: vec![None; f32_args.len()], number_rows: size, - return_type: &DataType::Float32, + return_field: &Field::new("f", DataType::Float32, true), }) .unwrap(), ) @@ -57,7 +57,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: f64_args.clone(), arg_fields: vec![None; f64_args.len()], number_rows: size, - return_type: &DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), }) .unwrap(), ) diff --git a/datafusion/functions/benches/upper.rs b/datafusion/functions/benches/upper.rs index 9b730b49b4f7..d58fbe0e16f8 100644 --- a/datafusion/functions/benches/upper.rs +++ b/datafusion/functions/benches/upper.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use 
datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; @@ -44,7 +44,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: args_cloned, arg_fields: vec![None; args.len()], number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }); diff --git a/datafusion/functions/benches/uuid.rs b/datafusion/functions/benches/uuid.rs index 4b31477e20f9..dfed6871d6c2 100644 --- a/datafusion/functions/benches/uuid.rs +++ b/datafusion/functions/benches/uuid.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::ScalarFunctionArgs; use datafusion_functions::string; @@ -30,7 +30,7 @@ fn criterion_benchmark(c: &mut Criterion) { args: vec![], arg_fields: vec![], number_rows: 1024, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }); diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index 97af6df7bc6c..5fb118a8a2fa 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -150,7 +150,7 @@ impl ScalarUDFImpl for NamedStructFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let DataType::Struct(fields) = args.return_type else { + let DataType::Struct(fields) = args.return_field.data_type() else { return internal_err!("incorrect named_struct return type"); }; diff --git a/datafusion/functions/src/core/struct.rs b/datafusion/functions/src/core/struct.rs index 8792bf1bd1b9..416ad288fc38 100644 --- a/datafusion/functions/src/core/struct.rs +++ b/datafusion/functions/src/core/struct.rs @@ -117,7 +117,7 @@ impl ScalarUDFImpl for StructFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let DataType::Struct(fields) = args.return_type else { + let DataType::Struct(fields) = 
args.return_field.data_type() else { return internal_err!("incorrect struct return type"); }; diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs index 968d443a0a42..5cb22d9d1b9e 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs @@ -200,7 +200,7 @@ mod tests { ], arg_fields: vec![None; 2], number_rows: 1, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })?; assert_scalar(result, ScalarValue::Utf8(None)); @@ -216,7 +216,7 @@ mod tests { ], arg_fields: vec![None; 2], number_rows: 1, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })?; assert_scalar(result, ScalarValue::Utf8(None)); @@ -232,7 +232,7 @@ mod tests { ], arg_fields: vec![None; 2], number_rows: 1, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })?; assert_scalar(result, ScalarValue::new_utf8("42")); diff --git a/datafusion/functions/src/core/version.rs b/datafusion/functions/src/core/version.rs index 865489e4517a..9e243dd0adb8 100644 --- a/datafusion/functions/src/core/version.rs +++ b/datafusion/functions/src/core/version.rs @@ -97,6 +97,7 @@ impl ScalarUDFImpl for VersionFunc { #[cfg(test)] mod test { use super::*; + use arrow::datatypes::Field; use datafusion_expr::ScalarUDF; #[tokio::test] @@ -107,7 +108,7 @@ mod test { args: vec![], arg_fields: vec![], number_rows: 0, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .unwrap(); diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index eafeee70f4a7..b58262103d40 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -505,7 +505,7 @@ mod tests { use arrow::array::types::TimestampNanosecondType; use arrow::array::{Array, IntervalDayTimeArray, 
TimestampNanosecondArray}; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; - use arrow::datatypes::{DataType, TimeUnit}; + use arrow::datatypes::{DataType, Field, TimeUnit}; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; use datafusion_common::ScalarValue; @@ -515,6 +515,9 @@ mod tests { #[test] fn test_date_bin() { + let return_field = + Field::new("f", DataType::Timestamp(TimeUnit::Nanosecond, None), true); + let mut args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( @@ -528,7 +531,7 @@ mod tests { ], arg_fields: vec![None; 3], number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), + return_field: &return_field, }; let res = DateBinFunc::new().invoke_with_args(args); assert!(res.is_ok()); @@ -548,7 +551,7 @@ mod tests { ], arg_fields: vec![None; 3], number_rows: batch_len, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), + return_field: &return_field, }; let res = DateBinFunc::new().invoke_with_args(args); assert!(res.is_ok()); @@ -565,7 +568,7 @@ mod tests { ], arg_fields: vec![None; 2], number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), + return_field: &return_field, }; let res = DateBinFunc::new().invoke_with_args(args); assert!(res.is_ok()); @@ -585,7 +588,7 @@ mod tests { ], arg_fields: vec![None; 3], number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), + return_field: &return_field, }; let res = DateBinFunc::new().invoke_with_args(args); assert!(res.is_ok()); @@ -604,7 +607,7 @@ mod tests { )))], arg_fields: vec![None; 1], number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), + return_field: &return_field, }; let res = DateBinFunc::new().invoke_with_args(args); assert_eq!( @@ -621,7 +624,7 @@ mod tests { ], arg_fields: vec![None; 3], number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), + return_field: 
&return_field, }; let res = DateBinFunc::new().invoke_with_args(args); assert_eq!( @@ -644,7 +647,7 @@ mod tests { ], arg_fields: vec![None; 3], number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), + return_field: &return_field, }; let res = DateBinFunc::new().invoke_with_args(args); @@ -664,7 +667,7 @@ mod tests { ], arg_fields: vec![None; 3], number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), + return_field: &return_field, }; let res = DateBinFunc::new().invoke_with_args(args); assert_eq!( @@ -681,7 +684,7 @@ mod tests { ], arg_fields: vec![None; 3], number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), + return_field: &return_field, }; let res = DateBinFunc::new().invoke_with_args(args); assert_eq!( @@ -698,7 +701,7 @@ mod tests { ], arg_fields: vec![None; 3], number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), + return_field: &return_field, }; let res = DateBinFunc::new().invoke_with_args(args); assert_eq!( @@ -720,7 +723,7 @@ mod tests { ], arg_fields: vec![None; 3], number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), + return_field: &return_field, }; let res = DateBinFunc::new().invoke_with_args(args); assert_eq!( @@ -741,7 +744,7 @@ mod tests { ], arg_fields: vec![None; 3], number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), + return_field: &return_field, }; let res = DateBinFunc::new().invoke_with_args(args); assert!(res.is_ok()); @@ -765,7 +768,7 @@ mod tests { ], arg_fields: vec![None; 3], number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), + return_field: &return_field, }; let res = DateBinFunc::new().invoke_with_args(args); assert_eq!( @@ -789,7 +792,7 @@ mod tests { ], arg_fields: vec![None; 3], number_rows: batch_len, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), + return_field: &return_field, }; let res = 
DateBinFunc::new().invoke_with_args(args); assert_eq!( @@ -918,9 +921,10 @@ mod tests { ], arg_fields: vec![None; 3], number_rows: batch_len, - return_type: &DataType::Timestamp( - TimeUnit::Nanosecond, - tz_opt.clone(), + return_field: &Field::new( + "f", + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), + true, ), }; let result = DateBinFunc::new().invoke_with_args(args).unwrap(); diff --git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs index 63cb239a4285..1b5bac7ea5f7 100644 --- a/datafusion/functions/src/datetime/date_trunc.rs +++ b/datafusion/functions/src/datetime/date_trunc.rs @@ -487,7 +487,7 @@ mod tests { use arrow::array::types::TimestampNanosecondType; use arrow::array::{Array, TimestampNanosecondArray}; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; - use arrow::datatypes::{DataType, TimeUnit}; + use arrow::datatypes::{DataType, Field, TimeUnit}; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -733,7 +733,11 @@ mod tests { ], arg_fields: vec![None; 2], number_rows: batch_len, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), + return_field: &Field::new( + "f", + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), + true, + ), }; let result = DateTruncFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { @@ -896,7 +900,11 @@ mod tests { ], arg_fields: vec![None; 2], number_rows: batch_len, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), + return_field: &Field::new( + "f", + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), + true, + ), }; let result = DateTruncFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index 3db0500bfeb2..885285f2473c 100644 --- 
a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -161,8 +161,8 @@ impl ScalarUDFImpl for FromUnixtimeFunc { #[cfg(test)] mod test { use crate::datetime::from_unixtime::FromUnixtimeFunc; - use arrow::datatypes::DataType; use arrow::datatypes::TimeUnit::Second; + use arrow::datatypes::{DataType, Field}; use datafusion_common::ScalarValue; use datafusion_common::ScalarValue::Int64; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -174,7 +174,7 @@ mod test { args: vec![ColumnarValue::Scalar(Int64(Some(1729900800)))], arg_fields: vec![None; 1], number_rows: 1, - return_type: &DataType::Timestamp(Second, None), + return_field: &Field::new("f", DataType::Timestamp(Second, None), true), }; let result = FromUnixtimeFunc::new().invoke_with_args(args).unwrap(); @@ -197,9 +197,10 @@ mod test { ], arg_fields: vec![None; 2], number_rows: 2, - return_type: &DataType::Timestamp( - Second, - Some(Arc::from("America/New_York")), + return_field: &Field::new( + "f", + DataType::Timestamp(Second, Some(Arc::from("America/New_York"))), + true, ), }; let result = FromUnixtimeFunc::new().invoke_with_args(args).unwrap(); diff --git a/datafusion/functions/src/datetime/make_date.rs b/datafusion/functions/src/datetime/make_date.rs index 8f8f65fee0b0..e6930a4bb708 100644 --- a/datafusion/functions/src/datetime/make_date.rs +++ b/datafusion/functions/src/datetime/make_date.rs @@ -223,7 +223,7 @@ fn make_date_inner( mod tests { use crate::datetime::make_date::MakeDateFunc; use arrow::array::{Array, Date32Array, Int32Array, Int64Array, UInt32Array}; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field}; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; use std::sync::Arc; @@ -238,7 +238,7 @@ mod tests { ], arg_fields: vec![None; 3], number_rows: 1, - return_type: &DataType::Date32, + return_field: &Field::new("f", DataType::Date32, true), }; let res = 
MakeDateFunc::new() .invoke_with_args(args) @@ -258,7 +258,7 @@ mod tests { ], arg_fields: vec![None; 3], number_rows: 1, - return_type: &DataType::Date32, + return_field: &Field::new("f", DataType::Date32, true), }; let res = MakeDateFunc::new() .invoke_with_args(args) @@ -278,7 +278,7 @@ mod tests { ], arg_fields: vec![None; 3], number_rows: 1, - return_type: &DataType::Date32, + return_field: &Field::new("f", DataType::Date32, true), }; let res = MakeDateFunc::new() .invoke_with_args(args) @@ -302,7 +302,7 @@ mod tests { ], arg_fields: vec![None; 3], number_rows: batch_len, - return_type: &DataType::Date32, + return_field: &Field::new("f", DataType::Date32, true), }; let res = MakeDateFunc::new() .invoke_with_args(args) @@ -329,7 +329,7 @@ mod tests { args: vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))], arg_fields: vec![None; 1], number_rows: 1, - return_type: &DataType::Date32, + return_field: &Field::new("f", DataType::Date32, true), }; let res = MakeDateFunc::new().invoke_with_args(args); assert_eq!( @@ -346,7 +346,7 @@ mod tests { ], arg_fields: vec![None; 3], number_rows: 1, - return_type: &DataType::Date32, + return_field: &Field::new("f", DataType::Date32, true), }; let res = MakeDateFunc::new().invoke_with_args(args); assert_eq!( @@ -363,7 +363,7 @@ mod tests { ], arg_fields: vec![None; 3], number_rows: 1, - return_type: &DataType::Date32, + return_field: &Field::new("f", DataType::Date32, true), }; let res = MakeDateFunc::new().invoke_with_args(args); assert_eq!( @@ -380,7 +380,7 @@ mod tests { ], arg_fields: vec![None; 3], number_rows: 1, - return_type: &DataType::Date32, + return_field: &Field::new("f", DataType::Date32, true), }; let res = MakeDateFunc::new().invoke_with_args(args); assert_eq!( diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index f56fc53f6af7..6d51a06d2bf7 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ 
-303,7 +303,7 @@ mod tests { TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, }; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field}; use chrono::{NaiveDateTime, Timelike}; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -389,7 +389,7 @@ mod tests { args: vec![ColumnarValue::Scalar(value), ColumnarValue::Scalar(format)], arg_fields: vec![None; 2], number_rows: 1, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -473,7 +473,7 @@ mod tests { ], arg_fields: vec![None; 2], number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -605,7 +605,7 @@ mod tests { ], arg_fields: vec![None; 2], number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -628,7 +628,7 @@ mod tests { ], arg_fields: vec![None; 2], number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -651,7 +651,7 @@ mod tests { args: vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))], arg_fields: vec![None; 1], number_rows: 1, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }; let result = ToCharFunc::new().invoke_with_args(args); assert_eq!( @@ -667,7 +667,7 @@ mod tests { ], arg_fields: vec![None; 2], number_rows: 1, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }; let result = ToCharFunc::new().invoke_with_args(args); assert_eq!( diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs index a07ce87f84f0..a6af727291d3 
100644 --- a/datafusion/functions/src/datetime/to_date.rs +++ b/datafusion/functions/src/datetime/to_date.rs @@ -163,7 +163,7 @@ impl ScalarUDFImpl for ToDateFunc { #[cfg(test)] mod tests { use arrow::array::{Array, Date32Array, GenericStringArray, StringViewArray}; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field}; use arrow::{compute::kernels::cast_utils::Parser, datatypes::Date32Type}; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -212,7 +212,7 @@ mod tests { args: vec![ColumnarValue::Scalar(sv)], arg_fields: vec![None; 1], number_rows: 1, - return_type: &DataType::Date32, + return_field: &Field::new("f", DataType::Date32, true), }; let to_date_result = ToDateFunc::new().invoke_with_args(args); @@ -239,7 +239,7 @@ mod tests { args: vec![ColumnarValue::Array(Arc::new(date_array))], arg_fields: vec![None; 1], number_rows: batch_len, - return_type: &DataType::Date32, + return_field: &Field::new("f", DataType::Date32, true), }; let to_date_result = ToDateFunc::new().invoke_with_args(args); @@ -337,7 +337,7 @@ mod tests { ], arg_fields: vec![None; 2], number_rows: 1, - return_type: &DataType::Date32, + return_field: &Field::new("f", DataType::Date32, true), }; let to_date_result = ToDateFunc::new().invoke_with_args(args); @@ -368,7 +368,7 @@ mod tests { ], arg_fields: vec![None; 2], number_rows: batch_len, - return_type: &DataType::Date32, + return_field: &Field::new("f", DataType::Date32, true), }; let to_date_result = ToDateFunc::new().invoke_with_args(args); @@ -410,7 +410,7 @@ mod tests { ], arg_fields: vec![None; 3], number_rows: 1, - return_type: &DataType::Date32, + return_field: &Field::new("f", DataType::Date32, true), }; let to_date_result = ToDateFunc::new().invoke_with_args(args); @@ -440,7 +440,7 @@ mod tests { args: vec![ColumnarValue::Scalar(formatted_date_scalar)], arg_fields: vec![None; 1], number_rows: 1, - return_type: &DataType::Date32, + return_field: &Field::new("f", 
DataType::Date32, true), }; let to_date_result = ToDateFunc::new().invoke_with_args(args); @@ -463,7 +463,7 @@ mod tests { args: vec![ColumnarValue::Scalar(date_scalar)], arg_fields: vec![None; 1], number_rows: 1, - return_type: &DataType::Date32, + return_field: &Field::new("f", DataType::Date32, true), }; let to_date_result = ToDateFunc::new().invoke_with_args(args); @@ -489,7 +489,7 @@ mod tests { args: vec![ColumnarValue::Scalar(date_scalar)], arg_fields: vec![None; 1], number_rows: 1, - return_type: &DataType::Date32, + return_field: &Field::new("f", DataType::Date32, true), }; let to_date_result = ToDateFunc::new().invoke_with_args(args); diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index b82da1aa1edb..3e94008cead6 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -409,7 +409,7 @@ mod tests { use arrow::array::{types::TimestampNanosecondType, TimestampNanosecondArray}; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; - use arrow::datatypes::{DataType, TimeUnit}; + use arrow::datatypes::{DataType, Field, TimeUnit}; use chrono::NaiveDateTime; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; @@ -543,7 +543,7 @@ mod tests { args: vec![ColumnarValue::Scalar(input)], arg_fields: vec![None; 1], number_rows: 1, - return_type: &expected.data_type(), + return_field: &Field::new("f", expected.data_type(), true), }) .unwrap(); match res { @@ -607,7 +607,11 @@ mod tests { args: vec![ColumnarValue::Array(Arc::new(input))], arg_fields: vec![None; 1], number_rows: batch_size, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), + return_field: &Field::new( + "f", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + ), }; let result = ToLocalTimeFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { 
diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index 5e0d0ca903a5..d42467fcde4c 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -639,7 +639,7 @@ mod tests { TimestampNanosecondArray, TimestampSecondArray, }; use arrow::array::{ArrayRef, Int64Array, StringBuilder}; - use arrow::datatypes::TimeUnit; + use arrow::datatypes::{Field, TimeUnit}; use chrono::Utc; use datafusion_common::{assert_contains, DataFusionError, ScalarValue}; use datafusion_expr::ScalarFunctionImplementation; @@ -1017,7 +1017,7 @@ mod tests { args: vec![array.clone()], arg_fields: vec![None; 1], number_rows: 4, - return_type: &rt, + return_field: &Field::new("f", rt, true), }; let res = udf .invoke_with_args(args) @@ -1065,7 +1065,7 @@ mod tests { args: vec![array.clone()], arg_fields: vec![None; 1], number_rows: 5, - return_type: &rt, + return_field: &Field::new("f", rt, true), }; let res = udf .invoke_with_args(args) diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 7b57fdd7a798..4e7d48f19d31 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -256,6 +256,7 @@ mod tests { use arrow::array::{Float32Array, Float64Array, Int64Array}; use arrow::compute::SortOptions; + use arrow::datatypes::Field; use datafusion_common::cast::{as_float32_array, as_float64_array}; use datafusion_common::DFSchema; use datafusion_expr::execution_props::ExecutionProps; @@ -273,7 +274,7 @@ mod tests { ], arg_fields: vec![None; 2], number_rows: 4, - return_type: &DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), }; let _ = LogFunc::new().invoke_with_args(args); } @@ -286,7 +287,7 @@ mod tests { ], arg_fields: vec![None; 1], number_rows: 1, - return_type: &DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), }; let result = 
LogFunc::new().invoke_with_args(args); @@ -301,7 +302,7 @@ mod tests { ], arg_fields: vec![None; 1], number_rows: 1, - return_type: &DataType::Float32, + return_field: &Field::new("f", DataType::Float32, true), }; let result = LogFunc::new() .invoke_with_args(args) @@ -329,7 +330,7 @@ mod tests { ], arg_fields: vec![None; 1], number_rows: 1, - return_type: &DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), }; let result = LogFunc::new() .invoke_with_args(args) @@ -358,7 +359,7 @@ mod tests { ], arg_fields: vec![None; 2], number_rows: 1, - return_type: &DataType::Float32, + return_field: &Field::new("f", DataType::Float32, true), }; let result = LogFunc::new() .invoke_with_args(args) @@ -387,7 +388,7 @@ mod tests { ], arg_fields: vec![None; 2], number_rows: 1, - return_type: &DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), }; let result = LogFunc::new() .invoke_with_args(args) @@ -417,7 +418,7 @@ mod tests { ], arg_fields: vec![None; 1], number_rows: 4, - return_type: &DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), }; let result = LogFunc::new() .invoke_with_args(args) @@ -450,7 +451,7 @@ mod tests { ], arg_fields: vec![None; 1], number_rows: 4, - return_type: &DataType::Float32, + return_field: &Field::new("f", DataType::Float32, true), }; let result = LogFunc::new() .invoke_with_args(args) @@ -486,7 +487,7 @@ mod tests { ], arg_fields: vec![None; 2], number_rows: 4, - return_type: &DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), }; let result = LogFunc::new() .invoke_with_args(args) @@ -522,7 +523,7 @@ mod tests { ], arg_fields: vec![None; 2], number_rows: 4, - return_type: &DataType::Float32, + return_field: &Field::new("f", DataType::Float32, true), }; let result = LogFunc::new() .invoke_with_args(args) diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index f549f18e777e..e3bc0972e592 100644 --- 
a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -187,6 +187,7 @@ fn is_log(func: &ScalarUDF) -> bool { #[cfg(test)] mod tests { use arrow::array::Float64Array; + use arrow::datatypes::Field; use datafusion_common::cast::{as_float64_array, as_int64_array}; use super::*; @@ -204,7 +205,7 @@ mod tests { ], arg_fields: vec![None; 2], number_rows: 4, - return_type: &DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), }; let result = PowerFunc::new() .invoke_with_args(args) @@ -235,7 +236,7 @@ mod tests { ], arg_fields: vec![None; 2], number_rows: 4, - return_type: &DataType::Int64, + return_field: &Field::new("f", DataType::Int64, true), }; let result = PowerFunc::new() .invoke_with_args(args) diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs index f57492ad231e..6d75af27a3e0 100644 --- a/datafusion/functions/src/math/signum.rs +++ b/datafusion/functions/src/math/signum.rs @@ -138,7 +138,7 @@ mod test { use std::sync::Arc; use arrow::array::{ArrayRef, Float32Array, Float64Array}; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field}; use datafusion_common::cast::{as_float32_array, as_float64_array}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; @@ -161,7 +161,7 @@ mod test { args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], arg_fields: vec![None; 1], number_rows: array.len(), - return_type: &DataType::Float32, + return_field: &Field::new("f", DataType::Float32, true), }; let result = SignumFunc::new() .invoke_with_args(args) @@ -206,7 +206,7 @@ mod test { args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], arg_fields: vec![None; 1], number_rows: array.len(), - return_type: &DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), }; let result = SignumFunc::new() .invoke_with_args(args) diff --git a/datafusion/functions/src/regex/regexpcount.rs 
b/datafusion/functions/src/regex/regexpcount.rs index 044d2229dac1..bfb203e2f665 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -619,6 +619,7 @@ fn count_matches( mod tests { use super::*; use arrow::array::{GenericStringArray, StringViewArray}; + use arrow::datatypes::Field; use datafusion_expr::ScalarFunctionArgs; #[test] @@ -661,7 +662,7 @@ mod tests { args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], arg_fields: vec![None; 2], number_rows: 2, - return_type: &Int64, + return_field: &Field::new("f", Int64, true), }); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { @@ -677,7 +678,7 @@ mod tests { args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], arg_fields: vec![None; 2], number_rows: 2, - return_type: &Int64, + return_field: &Field::new("f", Int64, true), }); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { @@ -693,7 +694,7 @@ mod tests { args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], arg_fields: vec![None; 2], number_rows: 2, - return_type: &Int64, + return_field: &Field::new("f", Int64, true), }); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { @@ -724,7 +725,7 @@ mod tests { ], arg_fields: vec![None; 3], number_rows: 3, - return_type: &Int64, + return_field: &Field::new("f", Int64, true), }); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { @@ -744,7 +745,7 @@ mod tests { ], arg_fields: vec![None; 3], number_rows: 3, - return_type: &Int64, + return_field: &Field::new("f", Int64, true), }); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { @@ -764,7 +765,7 @@ mod tests { ], arg_fields: vec![None; 3], number_rows: 3, - return_type: &Int64, + return_field: &Field::new("f", Int64, true), }); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { @@ -798,7 +799,7 @@ mod tests { ], arg_fields: vec![None; 4], number_rows: 4, - 
return_type: &Int64, + return_field: &Field::new("f", Int64, true), }); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { @@ -820,7 +821,7 @@ mod tests { ], arg_fields: vec![None; 4], number_rows: 4, - return_type: &Int64, + return_field: &Field::new("f", Int64, true), }); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { @@ -842,7 +843,7 @@ mod tests { ], arg_fields: vec![None; 4], number_rows: 4, - return_type: &Int64, + return_field: &Field::new("f", Int64, true), }); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { @@ -925,7 +926,7 @@ mod tests { ], arg_fields: vec![None; 4], number_rows: 4, - return_type: &Int64, + return_field: &Field::new("f", Int64, true), }); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { @@ -947,7 +948,7 @@ mod tests { ], arg_fields: vec![None; 4], number_rows: 4, - return_type: &Int64, + return_field: &Field::new("f", Int64, true), }); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { @@ -969,7 +970,7 @@ mod tests { ], arg_fields: vec![None; 4], number_rows: 4, - return_type: &Int64, + return_field: &Field::new("f", Int64, true), }); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index 220db2d4b655..a46b79be74dc 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -376,6 +376,7 @@ mod tests { use crate::utils::test::test_function; use arrow::array::{Array, LargeStringArray, StringViewArray}; use arrow::array::{ArrayRef, StringArray}; + use arrow::datatypes::Field; use DataType::*; #[test] @@ -473,7 +474,7 @@ mod tests { args: vec![c0, c1, c2, c3, c4], arg_fields: vec![None; 5], number_rows: 3, - return_type: &Utf8, + return_field: &Field::new("f", Utf8, true), }; let result = ConcatFunc::new().invoke_with_args(args)?; diff --git a/datafusion/functions/src/string/concat_ws.rs 
b/datafusion/functions/src/string/concat_ws.rs index 8bf25587d9aa..0de16ad4fec8 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -403,10 +403,10 @@ fn is_null(expr: &Expr) -> bool { mod tests { use std::sync::Arc; + use crate::string::concat_ws::ConcatWsFunc; use arrow::array::{Array, ArrayRef, StringArray}; use arrow::datatypes::DataType::Utf8; - - use crate::string::concat_ws::ConcatWsFunc; + use arrow::datatypes::Field; use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; @@ -485,7 +485,7 @@ mod tests { args: vec![c0, c1, c2], arg_fields: vec![None; 3], number_rows: 3, - return_type: &Utf8, + return_field: &Field::new("f", Utf8, true), }; let result = ConcatWsFunc::new().invoke_with_args(args)?; @@ -516,7 +516,7 @@ mod tests { args: vec![c0, c1, c2], arg_fields: vec![None; 3], number_rows: 3, - return_type: &Utf8, + return_field: &Field::new("f", Utf8, true), }; let result = ConcatWsFunc::new().invoke_with_args(args)?; diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index aef5a345fb5c..f60514427431 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -151,7 +151,7 @@ fn contains(args: &[ArrayRef]) -> Result { mod test { use super::ContainsFunc; use arrow::array::{BooleanArray, StringArray}; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field}; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use std::sync::Arc; @@ -169,7 +169,7 @@ mod test { args: vec![array, scalar], arg_fields: vec![None; 2], number_rows: 2, - return_type: &DataType::Boolean, + return_field: &Field::new("f", DataType::Boolean, true), }; let actual = udf.invoke_with_args(args).unwrap(); diff --git a/datafusion/functions/src/string/lower.rs 
b/datafusion/functions/src/string/lower.rs index 8dd5eef5ea78..df1d378d0c16 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -98,6 +98,8 @@ impl ScalarUDFImpl for LowerFunc { mod tests { use super::*; use arrow::array::{Array, ArrayRef, StringArray}; + use arrow::datatypes::DataType::Utf8; + use arrow::datatypes::Field; use std::sync::Arc; fn to_lower(input: ArrayRef, expected: ArrayRef) -> Result<()> { @@ -107,7 +109,7 @@ mod tests { number_rows: input.len(), args: vec![ColumnarValue::Array(input)], arg_fields: vec![None; 1], - return_type: &DataType::Utf8, + return_field: &Field::new("f", Utf8, true), }; let result = match func.invoke_with_args(args)? { diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index 475632f8d4f9..acff5049ab14 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -97,6 +97,8 @@ impl ScalarUDFImpl for UpperFunc { mod tests { use super::*; use arrow::array::{Array, ArrayRef, StringArray}; + use arrow::datatypes::DataType::Utf8; + use arrow::datatypes::Field; use std::sync::Arc; fn to_upper(input: ArrayRef, expected: ArrayRef) -> Result<()> { @@ -106,7 +108,7 @@ mod tests { number_rows: input.len(), args: vec![ColumnarValue::Array(input)], arg_fields: vec![None; 1], - return_type: &DataType::Utf8, + return_field: &Field::new("f", Utf8, true), }; let result = match func.invoke_with_args(args)? 
{ diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs index 35a0682bdc04..8e80dd843511 100644 --- a/datafusion/functions/src/unicode/find_in_set.rs +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -348,7 +348,7 @@ mod tests { use crate::unicode::find_in_set::FindInSetFunc; use crate::utils::test::test_function; use arrow::array::{Array, Int32Array, StringArray}; - use arrow::datatypes::DataType::Int32; + use arrow::datatypes::{DataType::Int32, Field}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use std::sync::Arc; @@ -476,7 +476,7 @@ mod tests { args, arg_fields, number_rows: cardinality, - return_type: &return_type, + return_field: &Field::new("f", return_type, true), }); assert!(result.is_ok()); diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index a967fdf7c2ce..a81ea96951d8 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -157,20 +157,20 @@ pub mod test { .map(|(idx, (data_type, nullable))| arrow::datatypes::Field::new(format!("field_{idx}"), data_type, nullable)) .collect::>(); - let return_info = func.return_field_from_args(datafusion_expr::ReturnFieldArgs { + let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs { arg_fields: &field_array, scalar_arguments: &scalar_arguments_refs, }); match expected { Ok(expected) => { - assert_eq!(return_info.is_ok(), true); - let return_info = return_info.unwrap(); - let return_type = return_info.data_type(); + assert_eq!(return_field.is_ok(), true); + let return_field = return_field.unwrap(); + let return_type = return_field.data_type(); assert_eq!(return_type, &$EXPECTED_DATA_TYPE); let arg_fields = vec![None; $ARGS.len()]; - let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, arg_fields, number_rows: cardinality, return_type: &return_type}); + let 
result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, arg_fields, number_rows: cardinality, return_field: &return_field}); assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err()); let result = result.unwrap().to_array(cardinality).expect("Failed to convert to array"); @@ -184,19 +184,18 @@ pub mod test { }; } Err(expected_error) => { - if return_info.is_err() { - match return_info { + if return_field.is_err() { + match return_field { Ok(_) => assert!(false, "expected error"), Err(error) => { datafusion_common::assert_contains!(expected_error.strip_backtrace(), error.strip_backtrace()); } } } else { - let return_info = return_info.unwrap(); - let return_type = return_info.data_type(); + let return_field = return_field.unwrap(); let arg_fields = vec![None; $ARGS.len()]; // invoke is expected error - cannot use .expect_err() due to Debug not being implemented - match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, arg_fields, number_rows: cardinality, return_type: &return_type}) { + match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, arg_fields, number_rows: cardinality, return_field: &return_field}) { Ok(_) => assert!(false, "expected error"), Err(error) => { assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace())); diff --git a/datafusion/physical-expr/src/equivalence/properties/dependency.rs b/datafusion/physical-expr/src/equivalence/properties/dependency.rs index feab0d1f2d3e..8b17db04844d 100644 --- a/datafusion/physical-expr/src/equivalence/properties/dependency.rs +++ b/datafusion/physical-expr/src/equivalence/properties/dependency.rs @@ -424,7 +424,6 @@ pub fn generate_dependency_orderings( #[cfg(test)] mod tests { - use std::collections::HashMap; use std::ops::Not; use std::sync::Arc; @@ -1225,8 +1224,7 @@ mod tests { "concat", concat(), vec![Arc::clone(&col_a), Arc::clone(&col_b)], - DataType::Utf8, - HashMap::default(), + 
Field::new("f", DataType::Utf8, true), )); // Assume existing ordering is [c ASC, a ASC, b ASC] @@ -1317,8 +1315,7 @@ mod tests { "concat", concat(), vec![Arc::clone(&col_a), Arc::clone(&col_b)], - DataType::Utf8, - HashMap::default(), + Field::new("f", DataType::Utf8, true), )); // Assume existing ordering is [concat(a, b) ASC, a ASC, b ASC] diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index bbaa9568120c..941b799e3325 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -30,7 +30,6 @@ //! to a function that supports f64, it is coerced to f64. use std::any::Any; -use std::collections::HashMap; use std::fmt::{self, Debug, Formatter}; use std::hash::Hash; use std::sync::Arc; @@ -49,35 +48,12 @@ use datafusion_expr::{ }; /// Physical expression of a scalar function -#[derive(Eq, PartialEq)] +#[derive(Eq, PartialEq, Hash)] pub struct ScalarFunctionExpr { fun: Arc, name: String, args: Vec>, - return_type: DataType, - nullable: bool, - metadata: HashMap, -} - -impl Hash for ScalarFunctionExpr { - fn hash(&self, state: &mut H) { - // Sort keys for deterministic hashing - let mut keys: Vec<&String> = self.metadata.keys().collect(); - keys.sort(); - - for key in keys { - key.hash(state); - if let Some(value) = self.metadata.get(key) { - value.hash(state); - } - } - - self.fun.hash(state); - self.name.hash(state); - self.args.hash(state); - self.return_type.hash(state); - self.nullable.hash(state); - } + return_field: Field, } impl Debug for ScalarFunctionExpr { @@ -86,7 +62,7 @@ impl Debug for ScalarFunctionExpr { .field("fun", &"") .field("name", &self.name) .field("args", &self.args) - .field("return_type", &self.return_type) + .field("return_field", &self.return_field) .finish() } } @@ -97,16 +73,13 @@ impl ScalarFunctionExpr { name: &str, fun: Arc, args: Vec>, - return_type: DataType, - metadata: HashMap, + return_field: Field, ) -> Self { 
Self { fun, name: name.to_owned(), args, - return_type, - nullable: true, - metadata, + return_field, } } @@ -157,9 +130,7 @@ impl ScalarFunctionExpr { fun, name, args, - return_type: return_field.data_type().clone(), - nullable: return_field.is_nullable(), - metadata: HashMap::new(), + return_field, }) } @@ -180,16 +151,16 @@ impl ScalarFunctionExpr { /// Data type produced by this expression pub fn return_type(&self) -> &DataType { - &self.return_type + self.return_field.data_type() } pub fn with_nullable(mut self, nullable: bool) -> Self { - self.nullable = nullable; + self.return_field = self.return_field.with_nullable(nullable); self } pub fn nullable(&self) -> bool { - self.nullable + self.return_field.is_nullable() } } @@ -206,11 +177,11 @@ impl PhysicalExpr for ScalarFunctionExpr { } fn data_type(&self, _input_schema: &Schema) -> Result { - Ok(self.return_type.clone()) + Ok(self.return_field.data_type().clone()) } fn nullable(&self, _input_schema: &Schema) -> Result { - Ok(self.nullable) + Ok(self.return_field.is_nullable()) } fn evaluate(&self, batch: &RecordBatch) -> Result { @@ -240,7 +211,7 @@ impl PhysicalExpr for ScalarFunctionExpr { args, arg_fields, number_rows: batch.num_rows(), - return_type: &self.return_type, + return_field: &self.return_field, })?; if let ColumnarValue::Array(array) = &output { @@ -272,16 +243,12 @@ impl PhysicalExpr for ScalarFunctionExpr { self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new( - ScalarFunctionExpr::new( - &self.name, - Arc::clone(&self.fun), - children, - self.return_type().clone(), - self.metadata.clone(), - ) - .with_nullable(self.nullable), - )) + Ok(Arc::new(ScalarFunctionExpr::new( + &self.name, + Arc::clone(&self.fun), + children, + self.return_field.clone(), + ))) } fn evaluate_bounds(&self, children: &[&Interval]) -> Result { diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index ea824f0b035a..c412dfed5d03 100644 --- 
a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use arrow::compute::SortOptions; +use arrow::datatypes::Field; use chrono::{TimeZone, Utc}; use datafusion_expr::dml::InsertOp; use object_store::path::Path; @@ -365,8 +366,7 @@ pub fn parse_physical_expr( e.name.as_str(), scalar_fun_def, args, - convert_required!(e.return_type)?, - std::collections::hash_map::HashMap::new(), + Field::new("f", convert_required!(e.return_type)?, true), ) .with_nullable(e.nullable), ) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 6387578be1fa..7b03939d754c 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -16,7 +16,6 @@ // under the License. use std::any::Any; -use std::collections::HashMap; use std::fmt::{Display, Formatter}; use std::ops::Deref; use std::sync::Arc; @@ -973,8 +972,7 @@ fn roundtrip_scalar_udf() -> Result<()> { "dummy", fun_def, vec![col("a", &schema)?], - DataType::Int64, - HashMap::default(), + Field::new("f", DataType::Int64, true), ); let project = @@ -1102,8 +1100,7 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { "regex_udf", Arc::new(ScalarUDF::from(MyRegexUdf::new(".*".to_string()))), vec![col("text", &schema)?], - DataType::Int64, - HashMap::default(), + Field::new("f", DataType::Int64, true), )); let filter = Arc::new(FilterExec::try_new( @@ -1205,8 +1202,7 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { "regex_udf", Arc::new(ScalarUDF::from(MyRegexUdf::new(".*".to_string()))), vec![col("text", &schema)?], - DataType::Int64, - HashMap::default(), + Field::new("f", DataType::Int64, true), )); let udaf = Arc::new(AggregateUDF::from(MyAggregateUDF::new( From 9aa5227b330ab187bf82eee7a3f8c93366bb38bf Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 18 Apr 2025 09:00:43 
-0400 Subject: [PATCH 16/25] subquery should just get the field from below and not lose potential metadata --- datafusion/expr/src/expr_schema.rs | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 80eb91323c01..9572e03d14c4 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -426,11 +426,9 @@ impl ExprSchemable for Expr { | Expr::Exists { .. } => { Ok(Field::new(&schema_name, DataType::Boolean, false)) } - Expr::ScalarSubquery(subquery) => Ok(Field::new( - &schema_name, - subquery.subquery.schema().field(0).data_type().clone(), - subquery.subquery.schema().field(0).is_nullable(), - )), + Expr::ScalarSubquery(subquery) => { + Ok(subquery.subquery.schema().field(0).clone()) + } Expr::BinaryExpr(BinaryExpr { ref left, ref right, From 2fab67baa7ee10e458b0dbbb3aea59992beac103 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Fri, 18 Apr 2025 09:18:09 -0400 Subject: [PATCH 17/25] Update comment --- datafusion/expr/src/udf.rs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index cb22cd2a07cc..137ef746090e 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -300,8 +300,9 @@ pub struct ScalarFunctionArgs<'a, 'b> { pub arg_fields: Vec>, /// The number of rows in record batch being evaluated pub number_rows: usize, - /// The return type of the scalar function returned (from `return_type` or `return_field_from_args`) - /// when creating the physical expression from the logical expression + /// The return field of the scalar function returned (from `return_type` + /// or `return_field_from_args`) when creating the physical expression + /// from the logical expression pub return_field: &'a Field, } From 02d6f4756fa6e2ff92b46285ae798be96763aaa2 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 21 Apr 2025 08:36:03 -0400 Subject: [PATCH 
18/25] Remove output_field now that we've determined it using return_field_from_args --- .../user_defined/user_defined_scalar_functions.rs | 15 --------------- datafusion/expr/src/udf.rs | 8 +------- datafusion/physical-expr/src/scalar_function.rs | 4 ++-- 3 files changed, 3 insertions(+), 24 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 333b3339d652..ced5dce36804 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -1377,7 +1377,6 @@ async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result bool { self.name == other.name() } - - fn output_field(&self, _input_schema: &Schema) -> Option { - Some(self.output_field.clone()) - } } #[tokio::test] @@ -1611,13 +1603,6 @@ impl ScalarUDFImpl for ExtensionBasedUdf { fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { self.name == other.name() } - - fn output_field(&self, _input_schema: &Schema) -> Option { - Some( - Field::new("canonical_extension_udf", DataType::Utf8, true) - .with_extension_type(MyUserExtentionType {}), - ) - } } struct MyUserExtentionType {} diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 137ef746090e..9660866406db 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -21,7 +21,7 @@ use crate::expr::schema_name_from_exprs_comma_separated_without_space; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::sort_properties::{ExprProperties, SortProperties}; use crate::{ColumnarValue, Documentation, Expr, Signature}; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, Field}; use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue}; use datafusion_expr_common::interval_arithmetic::Interval; use std::any::Any; @@ -682,12 +682,6 @@ pub trait ScalarUDFImpl: Debug + 
Send + Sync { fn documentation(&self) -> Option<&Documentation> { None } - - /// This describes the output field associated with this UDF. - /// Input field is handled through `ScalarFunctionArgs` - fn output_field(&self, _input_schema: &Schema) -> Option { - None - } } /// ScalarUDF that adds an alias to the underlying function. It is better to diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 941b799e3325..b3a66dafb8bc 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -231,8 +231,8 @@ impl PhysicalExpr for ScalarFunctionExpr { Ok(output) } - fn output_field(&self, input_schema: &Schema) -> Result> { - Ok(self.fun.as_ref().inner().output_field(input_schema)) + fn output_field(&self, _input_schema: &Schema) -> Result> { + Ok(Some(self.return_field.clone())) } fn children(&self) -> Vec<&Arc> { From 871c3829e708127346426b9ec6a5e34fd7e44194 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 21 Apr 2025 09:00:04 -0400 Subject: [PATCH 19/25] Change name to_field to field_from_column to be more consistent with the usage and prevent misconception about if we are doing some conversion --- datafusion/common/src/dfschema.rs | 29 +++++++++------------ datafusion/expr/src/expr_schema.rs | 4 +-- datafusion/expr/src/logical_plan/builder.rs | 16 ++++++++++-- 3 files changed, 28 insertions(+), 21 deletions(-) diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index da670eb93167..a4fa502189f8 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -515,14 +515,6 @@ impl DFSchema { Ok(self.field(idx)) } - /// Find the field with the given qualified column - pub fn field_from_column(&self, column: &Column) -> Result<&Field> { - match &column.relation { - Some(r) => self.field_with_qualified_name(r, &column.name), - None => self.field_with_unqualified_name(&column.name), - } - } - /// 
Find the field with the given qualified column pub fn qualified_field_from_column( &self, @@ -970,27 +962,27 @@ impl Display for DFSchema { pub trait ExprSchema: std::fmt::Debug { /// Is this column reference nullable? fn nullable(&self, col: &Column) -> Result { - Ok(self.to_field(col)?.is_nullable()) + Ok(self.field_from_column(col)?.is_nullable()) } /// What is the datatype of this column? fn data_type(&self, col: &Column) -> Result<&DataType> { - Ok(self.to_field(col)?.data_type()) + Ok(self.field_from_column(col)?.data_type()) } /// Returns the column's optional metadata. fn metadata(&self, col: &Column) -> Result<&HashMap> { - Ok(self.to_field(col)?.metadata()) + Ok(self.field_from_column(col)?.metadata()) } /// Return the column's datatype and nullability fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { - let field = self.to_field(col)?; + let field = self.field_from_column(col)?; Ok((field.data_type(), field.is_nullable())) } // Return the column's field - fn to_field(&self, col: &Column) -> Result<&Field>; + fn field_from_column(&self, col: &Column) -> Result<&Field>; } // Implement `ExprSchema` for `Arc` @@ -1011,14 +1003,17 @@ impl + std::fmt::Debug> ExprSchema for P { self.as_ref().data_type_and_nullable(col) } - fn to_field(&self, col: &Column) -> Result<&Field> { - self.as_ref().to_field(col) + fn field_from_column(&self, col: &Column) -> Result<&Field> { + self.as_ref().field_from_column(col) } } impl ExprSchema for DFSchema { - fn to_field(&self, col: &Column) -> Result<&Field> { - self.field_from_column(col) + fn field_from_column(&self, col: &Column) -> Result<&Field> { + match &col.relation { + Some(r) => self.field_with_qualified_name(r, &col.name), + None => self.field_with_unqualified_name(&col.name), + } } } diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index 9572e03d14c4..3786180e2cfa 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ 
-409,7 +409,7 @@ impl ExprSchemable for Expr { Expr::Negative(expr) => { expr.to_field(schema).map(|(_, f)| f.as_ref().clone()) } - Expr::Column(c) => schema.to_field(c).cloned(), + Expr::Column(c) => schema.field_from_column(c).cloned(), Expr::OuterReferenceColumn(ty, _) => { Ok(Field::new(&schema_name, ty.clone(), true)) } @@ -848,7 +848,7 @@ mod tests { } } - fn to_field(&self, _col: &Column) -> Result<&Field> { + fn field_from_column(&self, _col: &Column) -> Result<&Field> { Ok(&self.field) } } diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 64931df5a83f..f371c4a4f99a 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1124,12 +1124,24 @@ impl LogicalPlanBuilder { for (l, r) in &on { if self.plan.schema().has_column(l) && right.schema().has_column(r) - && can_hash(self.plan.schema().field_from_column(l)?.data_type()) + && can_hash( + datafusion_common::ExprSchema::field_from_column( + self.plan.schema(), + l, + )? + .data_type(), + ) { join_on.push((Expr::Column(l.clone()), Expr::Column(r.clone()))); } else if self.plan.schema().has_column(l) && right.schema().has_column(r) - && can_hash(self.plan.schema().field_from_column(r)?.data_type()) + && can_hash( + datafusion_common::ExprSchema::field_from_column( + self.plan.schema(), + r, + )? 
+ .data_type(), + ) { join_on.push((Expr::Column(r.clone()), Expr::Column(l.clone()))); } else { From c4392243956f9ceff986560ea9bb2db2eaff551c Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 21 Apr 2025 09:11:00 -0400 Subject: [PATCH 20/25] Minor moving around of the explicit lifetimes in the struct definition --- datafusion/expr/src/udf.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 9660866406db..07bcfc4e5d68 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -297,13 +297,13 @@ pub struct ScalarFunctionArgs<'a, 'b> { /// The evaluated arguments to the function pub args: Vec, /// Field associated with each arg, if it exists - pub arg_fields: Vec>, + pub arg_fields: Vec>, /// The number of rows in record batch being evaluated pub number_rows: usize, /// The return field of the scalar function returned (from `return_type` /// or `return_field_from_args`) when creating the physical expression /// from the logical expression - pub return_field: &'a Field, + pub return_field: &'b Field, } /// Information about arguments passed to the function From d69deee575f6ad268e5d34a295d58cdde5f45b09 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 21 Apr 2025 13:54:30 -0400 Subject: [PATCH 21/25] Change physical expression to require to output a field which requires a lot of unit test updates, especially because the scalar arguments pass around borrowed values --- .../user_defined_scalar_functions.rs | 41 +- datafusion/expr/src/udf.rs | 2 +- datafusion/ffi/src/udf/mod.rs | 23 +- datafusion/functions-nested/benches/map.rs | 5 +- .../functions/benches/character_length.rs | 32 +- datafusion/functions/benches/chr.rs | 9 +- datafusion/functions/benches/concat.rs | 9 +- datafusion/functions/benches/cot.rs | 18 +- datafusion/functions/benches/date_bin.rs | 5 +- datafusion/functions/benches/date_trunc.rs | 9 +- datafusion/functions/benches/encoding.rs | 26 +- 
datafusion/functions/benches/find_in_set.rs | 28 +- datafusion/functions/benches/gcd.rs | 15 +- datafusion/functions/benches/initcap.rs | 13 +- datafusion/functions/benches/isnan.rs | 17 +- datafusion/functions/benches/iszero.rs | 18 +- datafusion/functions/benches/lower.rs | 42 +- datafusion/functions/benches/ltrim.rs | 8 +- datafusion/functions/benches/make_date.rs | 24 +- datafusion/functions/benches/nullif.rs | 9 +- datafusion/functions/benches/pad.rs | 82 ++-- datafusion/functions/benches/repeat.rs | 70 ++-- datafusion/functions/benches/reverse.rs | 24 +- datafusion/functions/benches/signum.rs | 18 +- datafusion/functions/benches/strpos.rs | 24 +- datafusion/functions/benches/substr.rs | 120 ++---- datafusion/functions/benches/substr_index.rs | 9 +- datafusion/functions/benches/to_char.rs | 15 +- datafusion/functions/benches/to_hex.rs | 4 +- datafusion/functions/benches/to_timestamp.rs | 35 +- datafusion/functions/benches/trunc.rs | 4 +- datafusion/functions/benches/upper.rs | 2 +- .../functions/src/core/union_extract.rs | 71 ++-- datafusion/functions/src/datetime/date_bin.rs | 369 +++++++----------- .../functions/src/datetime/date_trunc.rs | 12 +- .../functions/src/datetime/from_unixtime.rs | 9 +- .../functions/src/datetime/make_date.rs | 119 +++--- datafusion/functions/src/datetime/to_char.rs | 35 +- datafusion/functions/src/datetime/to_date.rs | 102 +++-- .../functions/src/datetime/to_local_time.rs | 8 +- .../functions/src/datetime/to_timestamp.rs | 6 +- datafusion/functions/src/math/log.rs | 45 ++- datafusion/functions/src/math/power.rs | 12 +- datafusion/functions/src/math/signum.rs | 6 +- datafusion/functions/src/regex/regexpcount.rs | 180 ++++----- datafusion/functions/src/string/concat.rs | 9 +- datafusion/functions/src/string/concat_ws.rs | 14 +- datafusion/functions/src/string/contains.rs | 6 +- datafusion/functions/src/string/lower.rs | 3 +- datafusion/functions/src/string/upper.rs | 3 +- .../functions/src/unicode/find_in_set.rs | 7 +- 
datafusion/functions/src/utils.rs | 8 +- .../physical-expr-common/src/physical_expr.rs | 18 +- .../physical-expr-common/src/sort_expr.rs | 2 +- .../physical-expr/src/expressions/binary.rs | 4 - .../physical-expr/src/expressions/case.rs | 8 +- .../physical-expr/src/expressions/cast.rs | 4 +- .../physical-expr/src/expressions/column.rs | 4 +- .../src/expressions/dynamic_filters.rs | 6 +- .../physical-expr/src/expressions/in_list.rs | 4 - .../src/expressions/is_not_null.rs | 2 +- .../physical-expr/src/expressions/is_null.rs | 2 +- .../physical-expr/src/expressions/like.rs | 6 +- .../physical-expr/src/expressions/literal.rs | 7 +- .../physical-expr/src/expressions/negative.rs | 2 +- .../physical-expr/src/expressions/no_op.rs | 5 - .../physical-expr/src/expressions/not.rs | 2 +- .../physical-expr/src/expressions/try_cast.rs | 2 +- .../src/expressions/unknown_column.rs | 5 - .../physical-expr/src/scalar_function.rs | 22 +- .../physical-plan/src/aggregates/mod.rs | 6 +- datafusion/physical-plan/src/projection.rs | 5 +- .../tests/cases/roundtrip_physical_plan.rs | 4 - 73 files changed, 1019 insertions(+), 885 deletions(-) diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index ced5dce36804..fbef524440e4 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -60,7 +60,7 @@ async fn csv_query_custom_udf_with_cast() -> Result<()> { let ctx = create_udf_context(); register_aggregate_csv(&ctx).await?; let sql = "SELECT avg(custom_sqrt(c11)) FROM aggregate_test_100"; - let actual = plan_and_collect(&ctx, sql).await.unwrap(); + let actual = plan_and_collect(&ctx, sql).await?; insta::assert_snapshot!(batches_to_string(&actual), @r###" +------------------------------------------+ @@ -79,7 +79,7 @@ async fn csv_query_avg_sqrt() -> Result<()> { register_aggregate_csv(&ctx).await?; // 
Note it is a different column (c12) than above (c11) let sql = "SELECT avg(custom_sqrt(c12)) FROM aggregate_test_100"; - let actual = plan_and_collect(&ctx, sql).await.unwrap(); + let actual = plan_and_collect(&ctx, sql).await?; insta::assert_snapshot!(batches_to_string(&actual), @r###" +------------------------------------------+ @@ -392,7 +392,7 @@ async fn udaf_as_window_func() -> Result<()> { WindowAggr: windowExpr=[[my_acc(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] TableScan: my_table"#; - let dataframe = context.sql(sql).await.unwrap(); + let dataframe = context.sql(sql).await?; assert_eq!(format!("{}", dataframe.logical_plan()), expected); Ok(()) } @@ -402,7 +402,7 @@ async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> { let ctx = SessionContext::new(); let arr = Int32Array::from(vec![1]); let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?; - ctx.register_batch("t", batch).unwrap(); + ctx.register_batch("t", batch)?; let myfunc = Arc::new(|args: &[ColumnarValue]| { let ColumnarValue::Array(array) = &args[0] else { @@ -446,7 +446,7 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { let ctx = SessionContext::new(); let arr = Int32Array::from(vec![1]); let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?; - ctx.register_batch("t", batch).unwrap(); + ctx.register_batch("t", batch)?; let myfunc = Arc::new(|args: &[ColumnarValue]| { let ColumnarValue::Array(array) = &args[0] else { @@ -1377,6 +1377,7 @@ async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result, } impl MetadataBasedUdf { @@ -1388,6 +1389,7 @@ impl MetadataBasedUdf { Self { name, signature: Signature::exact(vec![DataType::UInt64], Volatility::Immutable), + metadata, } } } @@ -1406,19 +1408,23 @@ impl ScalarUDFImpl for MetadataBasedUdf { } fn return_type(&self, _args: &[DataType]) -> Result { - Ok(DataType::UInt64) + unimplemented!( + "this should never 
be called since return_field_from_args is implemented" + ); + } + + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { + Ok(Field::new(self.name(), DataType::UInt64, true) + .with_metadata(self.metadata.clone())) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { assert_eq!(args.arg_fields.len(), 1); - let should_double = match &args.arg_fields[0] { - Some(field) => field - .metadata() - .get("modify_values") - .map(|v| v == "double_output") - .unwrap_or(false), - None => false, - }; + let should_double = args.arg_fields[0] + .metadata() + .get("modify_values") + .map(|v| v == "double_output") + .unwrap_or(false); let mulitplier = if should_double { 2 } else { 1 }; match &args.args[0] { @@ -1557,9 +1563,14 @@ impl ScalarUDFImpl for ExtensionBasedUdf { Ok(DataType::Utf8) } + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { + Ok(Field::new("canonical_extension_udf", DataType::Utf8, true) + .with_extension_type(MyUserExtentionType {})) + } + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { assert_eq!(args.arg_fields.len(), 1); - let input_field = args.arg_fields[0].unwrap(); + let input_field = args.arg_fields[0]; let output_as_bool = matches!( CanonicalExtensionType::try_from(input_field), diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 07bcfc4e5d68..c1b74fedcc32 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -297,7 +297,7 @@ pub struct ScalarFunctionArgs<'a, 'b> { /// The evaluated arguments to the function pub args: Vec, /// Field associated with each arg, if it exists - pub arg_fields: Vec>, + pub arg_fields: Vec<&'a Field>, /// The number of rows in record batch being evaluated pub number_rows: usize, /// The return field of the scalar function returned (from `return_type` diff --git a/datafusion/ffi/src/udf/mod.rs b/datafusion/ffi/src/udf/mod.rs index 95441df77c13..c35a53205eb4 100644 --- a/datafusion/ffi/src/udf/mod.rs +++ 
b/datafusion/ffi/src/udf/mod.rs @@ -21,7 +21,6 @@ use crate::{ util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, volatility::FFI_Volatility, }; -use abi_stable::std_types::ROption; use abi_stable::{ std_types::{RResult, RString, RVec}, StableAbi, @@ -85,7 +84,7 @@ pub struct FFI_ScalarUDF { pub invoke_with_args: unsafe extern "C" fn( udf: &Self, args: RVec, - arg_fields: RVec>, + arg_fields: RVec, num_rows: usize, return_field: WrappedSchema, ) -> RResult, @@ -173,7 +172,7 @@ unsafe extern "C" fn coerce_types_fn_wrapper( unsafe extern "C" fn invoke_with_args_fn_wrapper( udf: &FFI_ScalarUDF, args: RVec, - arg_fields: RVec>, + arg_fields: RVec, number_rows: usize, return_field: WrappedSchema, ) -> RResult { @@ -193,18 +192,10 @@ unsafe extern "C" fn invoke_with_args_fn_wrapper( let arg_fields_owned = arg_fields .into_iter() - .map(|maybe_field| { - Option::from(maybe_field.as_ref().map(|wrapped_field| { - (&wrapped_field.0).try_into().map_err(DataFusionError::from) - })) - .transpose() - }) - .collect::>>>(); + .map(|wrapped_field| (&wrapped_field.0).try_into().map_err(DataFusionError::from)) + .collect::>>(); let arg_fields_owned = rresult_return!(arg_fields_owned); - let arg_fields = arg_fields_owned - .iter() - .map(|maybe_map| maybe_map.as_ref()) - .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); let args = ScalarFunctionArgs { args, @@ -366,12 +357,12 @@ impl ScalarUDFImpl for ForeignScalarUDF { let arg_fields_wrapped = arg_fields .iter() - .map(|maybe_field| maybe_field.map(FFI_ArrowSchema::try_from).transpose()) + .map(|field| FFI_ArrowSchema::try_from(*field)) .collect::, ArrowError>>()?; let arg_fields = arg_fields_wrapped .into_iter() - .map(|maybe_field| maybe_field.map(WrappedSchema).into()) + .map(WrappedSchema) .collect::>(); let return_field = WrappedSchema(FFI_ArrowSchema::try_from(return_field)?); diff --git a/datafusion/functions-nested/benches/map.rs b/datafusion/functions-nested/benches/map.rs index 
e24965941b81..d8502e314b31 100644 --- a/datafusion/functions-nested/benches/map.rs +++ b/datafusion/functions-nested/benches/map.rs @@ -104,7 +104,10 @@ fn criterion_benchmark(c: &mut Criterion) { map_udf() .invoke_with_args(ScalarFunctionArgs { args: vec![keys.clone(), values.clone()], - arg_fields: vec![None; 2], + arg_fields: vec![ + &Field::new("a", keys.data_type(), true), + &Field::new("a", values.data_type(), true), + ], number_rows: 1, return_field, }) diff --git a/datafusion/functions/benches/character_length.rs b/datafusion/functions/benches/character_length.rs index 4ab741874a08..fb56b6dd5b97 100644 --- a/datafusion/functions/benches/character_length.rs +++ b/datafusion/functions/benches/character_length.rs @@ -34,13 +34,19 @@ fn criterion_benchmark(c: &mut Criterion) { for str_len in [8, 32, 128, 4096] { // StringArray ASCII only let args_string_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, false); + let arg_fields_owned = args_string_ascii + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); c.bench_function( &format!("character_length_StringArray_ascii_str_len_{}", str_len), |b| { b.iter(|| { black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_ascii.clone(), - arg_fields: vec![None; args_string_ascii.len()], + arg_fields: arg_fields.clone(), number_rows: n_rows, return_field: &return_field, })) @@ -50,13 +56,19 @@ fn criterion_benchmark(c: &mut Criterion) { // StringArray UTF8 let args_string_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, false); + let arg_fields_owned = args_string_utf8 + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); c.bench_function( &format!("character_length_StringArray_utf8_str_len_{}", str_len), |b| { b.iter(|| { 
black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_utf8.clone(), - arg_fields: vec![None; args_string_utf8.len()], + arg_fields: arg_fields.clone(), number_rows: n_rows, return_field: &return_field, })) @@ -66,13 +78,19 @@ fn criterion_benchmark(c: &mut Criterion) { // StringViewArray ASCII only let args_string_view_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, true); + let arg_fields_owned = args_string_view_ascii + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); c.bench_function( &format!("character_length_StringViewArray_ascii_str_len_{}", str_len), |b| { b.iter(|| { black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_view_ascii.clone(), - arg_fields: vec![None; args_string_view_ascii.len()], + arg_fields: arg_fields.clone(), number_rows: n_rows, return_field: &return_field, })) @@ -82,13 +100,19 @@ fn criterion_benchmark(c: &mut Criterion) { // StringViewArray UTF8 let args_string_view_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, true); + let arg_fields_owned = args_string_view_utf8 + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); c.bench_function( &format!("character_length_StringViewArray_utf8_str_len_{}", str_len), |b| { b.iter(|| { black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_view_utf8.clone(), - arg_fields: vec![None; args_string_view_utf8.len()], + arg_fields: arg_fields.clone(), number_rows: n_rows, return_field: &return_field, })) diff --git a/datafusion/functions/benches/chr.rs b/datafusion/functions/benches/chr.rs index 11a6950ff5e6..874c0d642361 100644 --- a/datafusion/functions/benches/chr.rs +++ b/datafusion/functions/benches/chr.rs @@ -50,13 +50,20 @@ fn criterion_benchmark(c: &mut 
Criterion) { }; let input = Arc::new(input); let args = vec![ColumnarValue::Array(input)]; + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + c.bench_function("chr", |b| { b.iter(|| { black_box( cot_fn .invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_fields: vec![None; args.len()], + arg_fields: arg_fields.clone(), number_rows: size, return_field: &Field::new("f", DataType::Utf8, true), }) diff --git a/datafusion/functions/benches/concat.rs b/datafusion/functions/benches/concat.rs index 05c4ece9b0c5..ea6a9c2eaf63 100644 --- a/datafusion/functions/benches/concat.rs +++ b/datafusion/functions/benches/concat.rs @@ -37,6 +37,13 @@ fn create_args(size: usize, str_len: usize) -> Vec { fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let args = create_args(size, 32); + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + let mut group = c.benchmark_group("concat function"); group.bench_function(BenchmarkId::new("concat", size), |b| { b.iter(|| { @@ -45,7 +52,7 @@ fn criterion_benchmark(c: &mut Criterion) { concat() .invoke_with_args(ScalarFunctionArgs { args: args_cloned, - arg_fields: vec![None; args.len()], + arg_fields: arg_fields.clone(), number_rows: size, return_field: &Field::new("f", DataType::Utf8, true), }) diff --git a/datafusion/functions/benches/cot.rs b/datafusion/functions/benches/cot.rs index 508f53027c1c..e7ceeca5e9fe 100644 --- a/datafusion/functions/benches/cot.rs +++ b/datafusion/functions/benches/cot.rs @@ -33,13 +33,20 @@ fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let f32_args = 
vec![ColumnarValue::Array(f32_array)]; + let arg_fields_owned = f32_args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + c.bench_function(&format!("cot f32 array: {}", size), |b| { b.iter(|| { black_box( cot_fn .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), - arg_fields: vec![None; f32_args.len()], + arg_fields: arg_fields.clone(), number_rows: size, return_field: &Field::new("f", DataType::Float32, true), }) @@ -49,13 +56,20 @@ fn criterion_benchmark(c: &mut Criterion) { }); let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let f64_args = vec![ColumnarValue::Array(f64_array)]; + let arg_fields_owned = f64_args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + c.bench_function(&format!("cot f64 array: {}", size), |b| { b.iter(|| { black_box( cot_fn .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), - arg_fields: vec![None; f64_args.len()], + arg_fields: arg_fields.clone(), number_rows: size, return_field: &Field::new("f", DataType::Float64, true), }) diff --git a/datafusion/functions/benches/date_bin.rs b/datafusion/functions/benches/date_bin.rs index 36356654b0d1..23eef1a014df 100644 --- a/datafusion/functions/benches/date_bin.rs +++ b/datafusion/functions/benches/date_bin.rs @@ -55,7 +55,10 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![interval.clone(), timestamps.clone()], - arg_fields: vec![None; 2], + arg_fields: vec![ + &Field::new("a", interval.data_type(), true), + &Field::new("b", timestamps.data_type(), true), + ], number_rows: batch_len, return_field: &return_field, }) diff --git a/datafusion/functions/benches/date_trunc.rs b/datafusion/functions/benches/date_trunc.rs index 
7820c268b93b..780cbae615c2 100644 --- a/datafusion/functions/benches/date_trunc.rs +++ b/datafusion/functions/benches/date_trunc.rs @@ -48,6 +48,13 @@ fn criterion_benchmark(c: &mut Criterion) { let timestamps = ColumnarValue::Array(timestamps_array); let udf = date_trunc(); let args = vec![precision, timestamps]; + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + let return_type = udf .return_type(&args.iter().map(|arg| arg.data_type()).collect::>()) .unwrap(); @@ -56,7 +63,7 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_fields: vec![None; args.len()], + arg_fields: arg_fields.clone(), number_rows: batch_len, return_field: &return_field, }) diff --git a/datafusion/functions/benches/encoding.rs b/datafusion/functions/benches/encoding.rs index 56f235687451..8210de82ef49 100644 --- a/datafusion/functions/benches/encoding.rs +++ b/datafusion/functions/benches/encoding.rs @@ -17,6 +17,7 @@ extern crate criterion; +use arrow::array::Array; use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{black_box, criterion_group, criterion_main, Criterion}; @@ -33,19 +34,27 @@ fn criterion_benchmark(c: &mut Criterion) { let encoded = encoding::encode() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Array(str_array.clone()), method.clone()], - arg_fields: vec![None; 2], + arg_fields: vec![ + &Field::new("a", str_array.data_type().to_owned(), true), + &Field::new("b", method.data_type().to_owned(), true), + ], number_rows: size, return_field: &Field::new("f", DataType::Utf8, true), }) .unwrap(); + let arg_fields = vec![ + Field::new("a", encoded.data_type().to_owned(), true), + Field::new("b", method.data_type().to_owned(), true), + ]; let args = vec![encoded, method]; + 
b.iter(|| { black_box( decode .invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_fields: vec![None; args.len()], + arg_fields: arg_fields.iter().collect(), number_rows: size, return_field: &Field::new("f", DataType::Utf8, true), }) @@ -56,22 +65,31 @@ fn criterion_benchmark(c: &mut Criterion) { c.bench_function(&format!("hex_decode/{size}"), |b| { let method = ColumnarValue::Scalar("hex".into()); + let arg_fields = vec![ + Field::new("a", str_array.data_type().to_owned(), true), + Field::new("b", method.data_type().to_owned(), true), + ]; let encoded = encoding::encode() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Array(str_array.clone()), method.clone()], - arg_fields: vec![None; 2], + arg_fields: arg_fields.iter().collect(), number_rows: size, return_field: &Field::new("f", DataType::Utf8, true), }) .unwrap(); + let arg_fields = vec![ + Field::new("a", encoded.data_type().to_owned(), true), + Field::new("b", method.data_type().to_owned(), true), + ]; let args = vec![encoded, method]; + b.iter(|| { black_box( decode .invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_fields: vec![None; args.len()], + arg_fields: arg_fields.iter().collect(), number_rows: size, return_field: &Field::new("f", DataType::Utf8, true), }) diff --git a/datafusion/functions/benches/find_in_set.rs b/datafusion/functions/benches/find_in_set.rs index f82ff775cd1c..6e60692fc581 100644 --- a/datafusion/functions/benches/find_in_set.rs +++ b/datafusion/functions/benches/find_in_set.rs @@ -153,11 +153,16 @@ fn criterion_benchmark(c: &mut Criterion) { group.measurement_time(Duration::from_secs(10)); let args = gen_args_array(n_rows, str_len, 0.1, 0.5, false); + let arg_fields_owned = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); group.bench_function(format!("string_len_{}", str_len), |b| { b.iter(|| { 
black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_fields: vec![None; args.len()], + arg_fields: arg_fields.clone(), number_rows: n_rows, return_field: &Field::new("f", DataType::Int32, true), })) @@ -165,11 +170,16 @@ fn criterion_benchmark(c: &mut Criterion) { }); let args = gen_args_array(n_rows, str_len, 0.1, 0.5, true); + let arg_fields_owned = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); group.bench_function(format!("string_view_len_{}", str_len), |b| { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_fields: vec![None; args.len()], + arg_fields: arg_fields.clone(), number_rows: n_rows, return_field: &Field::new("f", DataType::Int32, true), })) @@ -181,11 +191,16 @@ fn criterion_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("find_in_set_scalar"); let args = gen_args_scalar(n_rows, str_len, 0.1, false); + let arg_fields_owned = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); group.bench_function(format!("string_len_{}", str_len), |b| { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_fields: vec![None; args.len()], + arg_fields: arg_fields.clone(), number_rows: n_rows, return_field: &Field::new("f", DataType::Int32, true), })) @@ -193,11 +208,16 @@ fn criterion_benchmark(c: &mut Criterion) { }); let args = gen_args_scalar(n_rows, str_len, 0.1, true); + let arg_fields_owned = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); group.bench_function(format!("string_view_len_{}", str_len), |b| { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_fields: vec![None; 
args.len()], + arg_fields: arg_fields.clone(), number_rows: n_rows, return_field: &Field::new("f", DataType::Int32, true), })) diff --git a/datafusion/functions/benches/gcd.rs b/datafusion/functions/benches/gcd.rs index ab8be59f74fb..c56f6aae1c07 100644 --- a/datafusion/functions/benches/gcd.rs +++ b/datafusion/functions/benches/gcd.rs @@ -48,7 +48,10 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![array_a.clone(), array_b.clone()], - arg_fields: vec![None; 2], + arg_fields: vec![ + &Field::new("a", array_a.data_type(), true), + &Field::new("b", array_b.data_type(), true), + ], number_rows: 0, return_field: &Field::new("f", DataType::Int64, true), }) @@ -65,7 +68,10 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![array_a.clone(), scalar_b.clone()], - arg_fields: vec![None; 2], + arg_fields: vec![ + &Field::new("a", array_a.data_type(), true), + &Field::new("b", scalar_b.data_type(), true), + ], number_rows: 0, return_field: &Field::new("f", DataType::Int64, true), }) @@ -82,7 +88,10 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![scalar_a.clone(), scalar_b.clone()], - arg_fields: vec![None; 2], + arg_fields: vec![ + &Field::new("a", scalar_a.data_type(), true), + &Field::new("b", scalar_b.data_type(), true), + ], number_rows: 0, return_field: &Field::new("f", DataType::Int64, true), }) diff --git a/datafusion/functions/benches/initcap.rs b/datafusion/functions/benches/initcap.rs index 45ee5c435f5c..ecfafa31eec0 100644 --- a/datafusion/functions/benches/initcap.rs +++ b/datafusion/functions/benches/initcap.rs @@ -49,13 +49,20 @@ fn criterion_benchmark(c: &mut Criterion) { let initcap = unicode::initcap(); for size in [1024, 4096] { let args = create_args::(size, 8, true); + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), 
arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + c.bench_function( format!("initcap string view shorter than 12 [size={}]", size).as_str(), |b| { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_fields: vec![None; args.len()], + arg_fields: arg_fields.clone(), number_rows: size, return_field: &Field::new("f", DataType::Utf8View, true), })) @@ -70,7 +77,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_fields: vec![None; args.len()], + arg_fields: arg_fields.clone(), number_rows: size, return_field: &Field::new("f", DataType::Utf8View, true), })) @@ -83,7 +90,7 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_fields: vec![None; args.len()], + arg_fields: arg_fields.clone(), number_rows: size, return_field: &Field::new("f", DataType::Utf8, true), })) diff --git a/datafusion/functions/benches/isnan.rs b/datafusion/functions/benches/isnan.rs index 0bab5d00cbc8..d1b4412118fe 100644 --- a/datafusion/functions/benches/isnan.rs +++ b/datafusion/functions/benches/isnan.rs @@ -32,13 +32,20 @@ fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let f32_args = vec![ColumnarValue::Array(f32_array)]; + let arg_fields_owned = f32_args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + c.bench_function(&format!("isnan f32 array: {}", size), |b| { b.iter(|| { black_box( isnan .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), - arg_fields: vec![None; f32_args.len()], + arg_fields: arg_fields.clone(), number_rows: size, return_field: &Field::new("f", DataType::Boolean, true), }) @@ 
-48,13 +55,19 @@ fn criterion_benchmark(c: &mut Criterion) { }); let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let f64_args = vec![ColumnarValue::Array(f64_array)]; + let arg_fields_owned = f64_args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); c.bench_function(&format!("isnan f64 array: {}", size), |b| { b.iter(|| { black_box( isnan .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), - arg_fields: vec![None; f64_args.len()], + arg_fields: arg_fields.clone(), number_rows: size, return_field: &Field::new("f", DataType::Boolean, true), }) diff --git a/datafusion/functions/benches/iszero.rs b/datafusion/functions/benches/iszero.rs index f9f85d6d8b7c..3abe776d0201 100644 --- a/datafusion/functions/benches/iszero.rs +++ b/datafusion/functions/benches/iszero.rs @@ -33,13 +33,20 @@ fn criterion_benchmark(c: &mut Criterion) { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = f32_array.len(); let f32_args = vec![ColumnarValue::Array(f32_array)]; + let arg_fields_owned = f32_args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + c.bench_function(&format!("iszero f32 array: {}", size), |b| { b.iter(|| { black_box( iszero .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), - arg_fields: vec![None; f32_args.len()], + arg_fields: arg_fields.clone(), number_rows: batch_len, return_field: &Field::new("f", DataType::Boolean, true), }) @@ -50,13 +57,20 @@ fn criterion_benchmark(c: &mut Criterion) { let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = f64_array.len(); let f64_args = vec![ColumnarValue::Array(f64_array)]; + let arg_fields_owned = f64_args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), 
arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + c.bench_function(&format!("iszero f64 array: {}", size), |b| { b.iter(|| { black_box( iszero .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), - arg_fields: vec![None; f64_args.len()], + arg_fields: arg_fields.clone(), number_rows: batch_len, return_field: &Field::new("f", DataType::Boolean, true), }) diff --git a/datafusion/functions/benches/lower.rs b/datafusion/functions/benches/lower.rs index 348c6bf2b6d9..9cab9abbbb4e 100644 --- a/datafusion/functions/benches/lower.rs +++ b/datafusion/functions/benches/lower.rs @@ -124,12 +124,19 @@ fn criterion_benchmark(c: &mut Criterion) { let lower = string::lower(); for size in [1024, 4096, 8192] { let args = create_args1(size, 32); + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + c.bench_function(&format!("lower_all_values_are_ascii: {}", size), |b| { b.iter(|| { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs { args: args_cloned, - arg_fields: vec![None; args.len()], + arg_fields: arg_fields.clone(), number_rows: size, return_field: &Field::new("f", DataType::Utf8, true), })) @@ -137,6 +144,13 @@ fn criterion_benchmark(c: &mut Criterion) { }); let args = create_args2(size); + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + c.bench_function( &format!("lower_the_first_value_is_nonascii: {}", size), |b| { @@ -144,7 +158,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs { args: args_cloned, - arg_fields: vec![None; args.len()], + arg_fields: arg_fields.clone(), number_rows: size, 
return_field: &Field::new("f", DataType::Utf8, true), })) @@ -153,6 +167,13 @@ fn criterion_benchmark(c: &mut Criterion) { ); let args = create_args3(size); + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + c.bench_function( &format!("lower_the_middle_value_is_nonascii: {}", size), |b| { @@ -160,7 +181,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs { args: args_cloned, - arg_fields: vec![None; args.len()], + arg_fields: arg_fields.clone(), number_rows: size, return_field: &Field::new("f", DataType::Utf8, true), })) @@ -179,6 +200,15 @@ fn criterion_benchmark(c: &mut Criterion) { for &str_len in &str_lens { for &size in &sizes { let args = create_args4(size, str_len, *null_density, mixed); + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true) + }) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + c.bench_function( &format!("lower_all_values_are_ascii_string_views: size: {}, str_len: {}, null_density: {}, mixed: {}", size, str_len, null_density, mixed), @@ -186,7 +216,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs{ args: args_cloned, - arg_fields: vec![None; args.len()], + arg_fields: arg_fields.clone(), number_rows: size, return_field: &Field::new("f", DataType::Utf8, true), })) @@ -201,7 +231,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs{ args: args_cloned, - arg_fields: vec![None; args.len()], + arg_fields: arg_fields.clone(), number_rows: size, return_field: &Field::new("f", DataType::Utf8, true), })) @@ -216,7 +246,7 @@ fn criterion_benchmark(c: 
&mut Criterion) { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs{ args: args_cloned, - arg_fields: vec![None; args.len()], + arg_fields: arg_fields.clone(), number_rows: size, return_field: &Field::new("f", DataType::Utf8, true), })) diff --git a/datafusion/functions/benches/ltrim.rs b/datafusion/functions/benches/ltrim.rs index ebe576481bcf..b5371439c166 100644 --- a/datafusion/functions/benches/ltrim.rs +++ b/datafusion/functions/benches/ltrim.rs @@ -136,6 +136,12 @@ fn run_with_string_type( string_type: StringArrayType, ) { let args = create_args(size, characters, trimmed, remaining_len, string_type); + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); group.bench_function( format!( "{string_type} [size={size}, len_before={len}, len_after={remaining_len}]", @@ -145,7 +151,7 @@ fn run_with_string_type( let args_cloned = args.clone(); black_box(ltrim.invoke_with_args(ScalarFunctionArgs { args: args_cloned, - arg_fields: vec![None; args.len()], + arg_fields: arg_fields.clone(), number_rows: size, return_field: &Field::new("f", DataType::Utf8, true), })) diff --git a/datafusion/functions/benches/make_date.rs b/datafusion/functions/benches/make_date.rs index c24fc3bb9ad6..8e0fe5992ffc 100644 --- a/datafusion/functions/benches/make_date.rs +++ b/datafusion/functions/benches/make_date.rs @@ -69,7 +69,11 @@ fn criterion_benchmark(c: &mut Criterion) { make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![years.clone(), months.clone(), days.clone()], - arg_fields: vec![None; 3], + arg_fields: vec![ + &Field::new("a", years.data_type(), true), + &Field::new("a", months.data_type(), true), + &Field::new("a", days.data_type(), true), + ], number_rows: batch_len, return_field: &Field::new("f", DataType::Date32, true), }) @@ -91,7 +95,11 @@ fn criterion_benchmark(c: &mut Criterion) 
{ make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![year.clone(), months.clone(), days.clone()], - arg_fields: vec![None; 3], + arg_fields: vec![ + &Field::new("a", year.data_type(), true), + &Field::new("a", months.data_type(), true), + &Field::new("a", days.data_type(), true), + ], number_rows: batch_len, return_field: &Field::new("f", DataType::Date32, true), }) @@ -113,7 +121,11 @@ fn criterion_benchmark(c: &mut Criterion) { make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![year.clone(), month.clone(), days.clone()], - arg_fields: vec![None; 3], + arg_fields: vec![ + &Field::new("a", year.data_type(), true), + &Field::new("a", month.data_type(), true), + &Field::new("a", days.data_type(), true), + ], number_rows: batch_len, return_field: &Field::new("f", DataType::Date32, true), }) @@ -132,7 +144,11 @@ fn criterion_benchmark(c: &mut Criterion) { make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![year.clone(), month.clone(), day.clone()], - arg_fields: vec![None; 3], + arg_fields: vec![ + &Field::new("a", year.data_type(), true), + &Field::new("a", month.data_type(), true), + &Field::new("a", day.data_type(), true), + ], number_rows: 1, return_field: &Field::new("f", DataType::Date32, true), }) diff --git a/datafusion/functions/benches/nullif.rs b/datafusion/functions/benches/nullif.rs index 53f8a3684d87..0e04ad34a905 100644 --- a/datafusion/functions/benches/nullif.rs +++ b/datafusion/functions/benches/nullif.rs @@ -33,13 +33,20 @@ fn criterion_benchmark(c: &mut Criterion) { ColumnarValue::Scalar(ScalarValue::Utf8(Some("abcd".to_string()))), ColumnarValue::Array(array), ]; + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + c.bench_function(&format!("nullif scalar array: {}", size), |b| { b.iter(|| { black_box( nullif .invoke_with_args(ScalarFunctionArgs { args: 
args.clone(), - arg_fields: vec![None; args.len()], + arg_fields: arg_fields.clone(), number_rows: size, return_field: &Field::new("f", DataType::Utf8, true), }) diff --git a/datafusion/functions/benches/pad.rs b/datafusion/functions/benches/pad.rs index d494320a1462..c484a583584c 100644 --- a/datafusion/functions/benches/pad.rs +++ b/datafusion/functions/benches/pad.rs @@ -21,6 +21,7 @@ use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use datafusion_common::DataFusionError; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::unicode::{lpad, rpad}; use rand::distributions::{Distribution, Uniform}; @@ -95,22 +96,42 @@ fn create_args( } } +fn invoke_pad_with_args( + args: Vec, + number_rows: usize, + left_pad: bool, +) -> Result { + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + + let scalar_args = ScalarFunctionArgs { + args: args.clone(), + arg_fields, + number_rows, + return_field: &Field::new("f", DataType::Utf8, true), + }; + + if left_pad { + lpad().invoke_with_args(scalar_args) + } else { + rpad().invoke_with_args(scalar_args) + } +} + fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 2048] { let mut group = c.benchmark_group("lpad function"); let args = create_args::(size, 32, false); + group.bench_function(BenchmarkId::new("utf8 type", size), |b| { b.iter(|| { criterion::black_box( - lpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: vec![None; args.len()], - number_rows: size, - return_field: &Field::new("f", DataType::Utf8, true), - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, true).unwrap(), ) }) }); @@ -119,14 +140,7 @@ fn criterion_benchmark(c: &mut Criterion) { 
group.bench_function(BenchmarkId::new("largeutf8 type", size), |b| { b.iter(|| { criterion::black_box( - lpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: vec![None; args.len()], - number_rows: size, - return_field: &Field::new("f", DataType::LargeUtf8, true), - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, true).unwrap(), ) }) }); @@ -135,14 +149,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(BenchmarkId::new("stringview type", size), |b| { b.iter(|| { criterion::black_box( - lpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: vec![None; args.len()], - number_rows: size, - return_field: &Field::new("f", DataType::Utf8, true), - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, true).unwrap(), ) }) }); @@ -155,14 +162,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(BenchmarkId::new("utf8 type", size), |b| { b.iter(|| { criterion::black_box( - rpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: vec![None; args.len()], - number_rows: size, - return_field: &Field::new("f", DataType::Utf8, true), - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, false).unwrap(), ) }) }); @@ -171,14 +171,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(BenchmarkId::new("largeutf8 type", size), |b| { b.iter(|| { criterion::black_box( - rpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: vec![None; args.len()], - number_rows: size, - return_field: &Field::new("f", DataType::LargeUtf8, true), - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, false).unwrap(), ) }) }); @@ -188,14 +181,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(BenchmarkId::new("stringview type", size), |b| { b.iter(|| { criterion::black_box( - rpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: vec![None; args.len()], - 
number_rows: size, - return_field: &Field::new("f", DataType::Utf8, true), - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, false).unwrap(), ) }) }); diff --git a/datafusion/functions/benches/repeat.rs b/datafusion/functions/benches/repeat.rs index 94514f24af9f..a29dc494047e 100644 --- a/datafusion/functions/benches/repeat.rs +++ b/datafusion/functions/benches/repeat.rs @@ -23,6 +23,7 @@ use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; use criterion::{black_box, criterion_group, criterion_main, Criterion, SamplingMode}; +use datafusion_common::DataFusionError; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string; use std::sync::Arc; @@ -56,8 +57,26 @@ fn create_args( } } +fn invoke_repeat_with_args( + args: Vec, + repeat_times: i64, +) -> Result { + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + + string::repeat().invoke_with_args(ScalarFunctionArgs { + args, + arg_fields, + number_rows: repeat_times as usize, + return_field: &Field::new("f", DataType::Utf8, true), + }) +} + fn criterion_benchmark(c: &mut Criterion) { - let repeat = string::repeat(); for size in [1024, 4096] { // REPEAT 3 TIMES let repeat_times = 3; @@ -75,12 +94,7 @@ fn criterion_benchmark(c: &mut Criterion) { |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - arg_fields: vec![None; args.len()], - number_rows: repeat_times as usize, - return_field: &Field::new("f", DataType::Utf8, true), - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); @@ -94,12 +108,7 @@ fn criterion_benchmark(c: &mut Criterion) { |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - 
arg_fields: vec![None; args.len()], - number_rows: repeat_times as usize, - return_field: &Field::new("f", DataType::Utf8, true), - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); @@ -113,12 +122,7 @@ fn criterion_benchmark(c: &mut Criterion) { |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - arg_fields: vec![None; args.len()], - number_rows: repeat_times as usize, - return_field: &Field::new("f", DataType::Utf8, true), - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); @@ -141,12 +145,7 @@ fn criterion_benchmark(c: &mut Criterion) { |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - arg_fields: vec![None; args.len()], - number_rows: repeat_times as usize, - return_field: &Field::new("f", DataType::Utf8, true), - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); @@ -160,12 +159,7 @@ fn criterion_benchmark(c: &mut Criterion) { |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - arg_fields: vec![None; args.len()], - number_rows: repeat_times as usize, - return_field: &Field::new("f", DataType::Utf8, true), - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); @@ -179,12 +173,7 @@ fn criterion_benchmark(c: &mut Criterion) { |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - arg_fields: vec![None; args.len()], - number_rows: repeat_times as usize, - return_field: &Field::new("f", DataType::Utf8, true), - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); @@ -207,12 +196,7 @@ fn criterion_benchmark(c: &mut Criterion) { |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - 
args: args_cloned, - arg_fields: vec![None; args.len()], - number_rows: repeat_times as usize, - return_field: &Field::new("f", DataType::Utf8, true), - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); diff --git a/datafusion/functions/benches/reverse.rs b/datafusion/functions/benches/reverse.rs index 6fcefbf33714..102a4396a62b 100644 --- a/datafusion/functions/benches/reverse.rs +++ b/datafusion/functions/benches/reverse.rs @@ -46,7 +46,11 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_ascii.clone(), - arg_fields: vec![None; args_string_ascii.len()], + arg_fields: vec![&Field::new( + "a", + args_string_ascii[0].data_type(), + true, + )], number_rows: N_ROWS, return_field: &Field::new("f", DataType::Utf8, true), })) @@ -66,7 +70,11 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_utf8.clone(), - arg_fields: vec![None; args_string_utf8.len()], + arg_fields: vec![&Field::new( + "a", + args_string_utf8[0].data_type(), + true, + )], number_rows: N_ROWS, return_field: &Field::new("f", DataType::Utf8, true), })) @@ -88,7 +96,11 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_view_ascii.clone(), - arg_fields: vec![None; args_string_view_ascii.len()], + arg_fields: vec![&Field::new( + "a", + args_string_view_ascii[0].data_type(), + true, + )], number_rows: N_ROWS, return_field: &Field::new("f", DataType::Utf8, true), })) @@ -108,7 +120,11 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_view_utf8.clone(), - arg_fields: vec![None; args_string_view_utf8.len()], + arg_fields: vec![&Field::new( + "a", + args_string_view_utf8[0].data_type(), + true, + )], number_rows: N_ROWS, return_field: &Field::new("f", 
DataType::Utf8, true), })) diff --git a/datafusion/functions/benches/signum.rs b/datafusion/functions/benches/signum.rs index d21acbe36414..0cd8b9451f39 100644 --- a/datafusion/functions/benches/signum.rs +++ b/datafusion/functions/benches/signum.rs @@ -33,13 +33,20 @@ fn criterion_benchmark(c: &mut Criterion) { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = f32_array.len(); let f32_args = vec![ColumnarValue::Array(f32_array)]; + let arg_fields_owned = f32_args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + c.bench_function(&format!("signum f32 array: {}", size), |b| { b.iter(|| { black_box( signum .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), - arg_fields: vec![None; f32_args.len()], + arg_fields: arg_fields.clone(), number_rows: batch_len, return_field: &Field::new("f", DataType::Float32, true), }) @@ -51,13 +58,20 @@ fn criterion_benchmark(c: &mut Criterion) { let batch_len = f64_array.len(); let f64_args = vec![ColumnarValue::Array(f64_array)]; + let arg_fields_owned = f64_args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + c.bench_function(&format!("signum f64 array: {}", size), |b| { b.iter(|| { black_box( signum .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), - arg_fields: vec![None; f64_args.len()], + arg_fields: arg_fields.clone(), number_rows: batch_len, return_field: &Field::new("f", DataType::Float64, true), }) diff --git a/datafusion/functions/benches/strpos.rs b/datafusion/functions/benches/strpos.rs index cd7e763bcba0..894f2cdb65eb 100644 --- a/datafusion/functions/benches/strpos.rs +++ b/datafusion/functions/benches/strpos.rs @@ -117,7 +117,11 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { 
black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_ascii.clone(), - arg_fields: vec![None; args_string_ascii.len()], + arg_fields: vec![&Field::new( + "a", + args_string_ascii[0].data_type(), + true, + )], number_rows: n_rows, return_field: &Field::new("f", DataType::Int32, true), })) @@ -133,7 +137,11 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_utf8.clone(), - arg_fields: vec![None; args_string_utf8.len()], + arg_fields: vec![&Field::new( + "a", + args_string_utf8[0].data_type(), + true, + )], number_rows: n_rows, return_field: &Field::new("f", DataType::Int32, true), })) @@ -149,7 +157,11 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_view_ascii.clone(), - arg_fields: vec![None; args_string_view_ascii.len()], + arg_fields: vec![&Field::new( + "a", + args_string_view_ascii[0].data_type(), + true, + )], number_rows: n_rows, return_field: &Field::new("f", DataType::Int32, true), })) @@ -165,7 +177,11 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_view_utf8.clone(), - arg_fields: vec![None; args_string_view_utf8.len()], + arg_fields: vec![&Field::new( + "a", + args_string_view_utf8[0].data_type(), + true, + )], number_rows: n_rows, return_field: &Field::new("f", DataType::Int32, true), })) diff --git a/datafusion/functions/benches/substr.rs b/datafusion/functions/benches/substr.rs index d4240ef2b7af..e01e1c6d5b3b 100644 --- a/datafusion/functions/benches/substr.rs +++ b/datafusion/functions/benches/substr.rs @@ -23,6 +23,7 @@ use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; use criterion::{black_box, criterion_group, criterion_main, Criterion, SamplingMode}; +use datafusion_common::DataFusionError; use datafusion_expr::{ColumnarValue, 
ScalarFunctionArgs}; use datafusion_functions::unicode; use std::sync::Arc; @@ -96,8 +97,26 @@ fn create_args_with_count( } } +fn invoke_substr_with_args( + args: Vec, + number_rows: usize, +) -> Result { + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + + unicode::substr().invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields, + number_rows, + return_field: &Field::new("f", DataType::Utf8View, true), + }) +} + fn criterion_benchmark(c: &mut Criterion) { - let substr = unicode::substr(); for size in [1024, 4096] { // string_len = 12, substring_len=6 (see `create_args_without_count`) let len = 12; @@ -108,46 +127,19 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args_without_count::(size, len, true, true); group.bench_function( format!("substr_string_view [size={}, strlen={}]", size, len), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: vec![None; args.len()], - number_rows: size, - return_field: &Field::new("f", DataType::Utf8View, true), - })) - }) - }, + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = create_args_without_count::(size, len, false, false); group.bench_function( format!("substr_string [size={}, strlen={}]", size, len), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: vec![None; args.len()], - number_rows: size, - return_field: &Field::new("f", DataType::Utf8View, true), - })) - }) - }, + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = create_args_without_count::(size, len, true, false); group.bench_function( format!("substr_large_string [size={}, strlen={}]", size, len), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - 
args: args.clone(), - arg_fields: vec![None; args.len()], - number_rows: size, - return_field: &Field::new("f", DataType::Utf8View, true), - })) - }) - }, + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); group.finish(); @@ -165,16 +157,7 @@ fn criterion_benchmark(c: &mut Criterion) { "substr_string_view [size={}, count={}, strlen={}]", size, count, len, ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: vec![None; args.len()], - number_rows: size, - return_field: &Field::new("f", DataType::Utf8View, true), - })) - }) - }, + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = create_args_with_count::(size, len, count, false); @@ -183,16 +166,7 @@ fn criterion_benchmark(c: &mut Criterion) { "substr_string [size={}, count={}, strlen={}]", size, count, len, ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: vec![None; args.len()], - number_rows: size, - return_field: &Field::new("f", DataType::Utf8View, true), - })) - }) - }, + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = create_args_with_count::(size, len, count, false); @@ -201,16 +175,7 @@ fn criterion_benchmark(c: &mut Criterion) { "substr_large_string [size={}, count={}, strlen={}]", size, count, len, ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: vec![None; args.len()], - number_rows: size, - return_field: &Field::new("f", DataType::Utf8View, true), - })) - }) - }, + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); group.finish(); @@ -228,16 +193,7 @@ fn criterion_benchmark(c: &mut Criterion) { "substr_string_view [size={}, count={}, strlen={}]", size, count, len, ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: 
vec![None; args.len()], - number_rows: size, - return_field: &Field::new("f", DataType::Utf8View, true), - })) - }) - }, + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = create_args_with_count::(size, len, count, false); @@ -246,16 +202,7 @@ fn criterion_benchmark(c: &mut Criterion) { "substr_string [size={}, count={}, strlen={}]", size, count, len, ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: vec![None; args.len()], - number_rows: size, - return_field: &Field::new("f", DataType::Utf8View, true), - })) - }) - }, + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = create_args_with_count::(size, len, count, false); @@ -264,16 +211,7 @@ fn criterion_benchmark(c: &mut Criterion) { "substr_large_string [size={}, count={}, strlen={}]", size, count, len, ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - arg_fields: vec![None; args.len()], - number_rows: size, - return_field: &Field::new("f", DataType::Utf8View, true), - })) - }) - }, + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); group.finish(); diff --git a/datafusion/functions/benches/substr_index.rs b/datafusion/functions/benches/substr_index.rs index 106b62b1c998..5604f7b6a914 100644 --- a/datafusion/functions/benches/substr_index.rs +++ b/datafusion/functions/benches/substr_index.rs @@ -91,12 +91,19 @@ fn criterion_benchmark(c: &mut Criterion) { let counts = ColumnarValue::Array(Arc::new(counts) as ArrayRef); let args = vec![strings, delimiters, counts]; + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + b.iter(|| { black_box( substr_index() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_fields: vec![None; args.len()], 
+ arg_fields: arg_fields.clone(), number_rows: batch_len, return_field: &Field::new("f", DataType::Utf8, true), }) diff --git a/datafusion/functions/benches/to_char.rs b/datafusion/functions/benches/to_char.rs index dc7df4d6698f..9189455fd387 100644 --- a/datafusion/functions/benches/to_char.rs +++ b/datafusion/functions/benches/to_char.rs @@ -93,7 +93,10 @@ fn criterion_benchmark(c: &mut Criterion) { to_char() .invoke_with_args(ScalarFunctionArgs { args: vec![data.clone(), patterns.clone()], - arg_fields: vec![None; 2], + arg_fields: vec![ + &Field::new("a", data.data_type(), true), + &Field::new("b", patterns.data_type(), true), + ], number_rows: batch_len, return_field: &Field::new("f", DataType::Utf8, true), }) @@ -115,7 +118,10 @@ fn criterion_benchmark(c: &mut Criterion) { to_char() .invoke_with_args(ScalarFunctionArgs { args: vec![data.clone(), patterns.clone()], - arg_fields: vec![None; 2], + arg_fields: vec![ + &Field::new("a", data.data_type(), true), + &Field::new("b", patterns.data_type(), true), + ], number_rows: batch_len, return_field: &Field::new("f", DataType::Utf8, true), }) @@ -143,7 +149,10 @@ fn criterion_benchmark(c: &mut Criterion) { to_char() .invoke_with_args(ScalarFunctionArgs { args: vec![data.clone(), pattern.clone()], - arg_fields: vec![None; 2], + arg_fields: vec![ + &Field::new("a", data.data_type(), true), + &Field::new("b", pattern.data_type(), true), + ], number_rows: 1, return_field: &Field::new("f", DataType::Utf8, true), }) diff --git a/datafusion/functions/benches/to_hex.rs b/datafusion/functions/benches/to_hex.rs index 3514da83d9de..5ab49cb96c58 100644 --- a/datafusion/functions/benches/to_hex.rs +++ b/datafusion/functions/benches/to_hex.rs @@ -36,7 +36,7 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( hex.invoke_with_args(ScalarFunctionArgs { args: args_cloned, - arg_fields: vec![None; i32_args.len()], + arg_fields: vec![&Field::new("a", DataType::Int32, false)], number_rows: batch_len, return_field: 
&Field::new("f", DataType::Utf8, true), }) @@ -53,7 +53,7 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( hex.invoke_with_args(ScalarFunctionArgs { args: args_cloned, - arg_fields: vec![None; i64_args.len()], + arg_fields: vec![&Field::new("a", DataType::Int64, false)], number_rows: batch_len, return_field: &Field::new("f", DataType::Utf8, true), }) diff --git a/datafusion/functions/benches/to_timestamp.rs b/datafusion/functions/benches/to_timestamp.rs index 47f125585b4b..ae38b3ae9df6 100644 --- a/datafusion/functions/benches/to_timestamp.rs +++ b/datafusion/functions/benches/to_timestamp.rs @@ -111,6 +111,8 @@ fn data_with_formats() -> (StringArray, StringArray, StringArray, StringArray) { fn criterion_benchmark(c: &mut Criterion) { let return_field = &Field::new("f", DataType::Timestamp(TimeUnit::Nanosecond, None), true); + let arg_field = Field::new("a", DataType::Utf8, false); + let arg_fields = vec![&arg_field]; c.bench_function("to_timestamp_no_formats_utf8", |b| { let arr_data = data(); let batch_len = arr_data.len(); @@ -121,7 +123,7 @@ fn criterion_benchmark(c: &mut Criterion) { to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: vec![string_array.clone()], - arg_fields: vec![None; 1], + arg_fields: arg_fields.clone(), number_rows: batch_len, return_field, }) @@ -140,7 +142,7 @@ fn criterion_benchmark(c: &mut Criterion) { to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: vec![string_array.clone()], - arg_fields: vec![None; 1], + arg_fields: arg_fields.clone(), number_rows: batch_len, return_field, }) @@ -159,7 +161,7 @@ fn criterion_benchmark(c: &mut Criterion) { to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: vec![string_array.clone()], - arg_fields: vec![None; 1], + arg_fields: arg_fields.clone(), number_rows: batch_len, return_field, }) @@ -178,12 +180,19 @@ fn criterion_benchmark(c: &mut Criterion) { ColumnarValue::Array(Arc::new(format2) as ArrayRef), ColumnarValue::Array(Arc::new(format3) as ArrayRef), ]; + let 
arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + b.iter(|| { black_box( to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_fields: vec![None; args.len()], + arg_fields: arg_fields.clone(), number_rows: batch_len, return_field, }) @@ -210,12 +219,19 @@ fn criterion_benchmark(c: &mut Criterion) { Arc::new(cast(&format3, &DataType::LargeUtf8).unwrap()) as ArrayRef ), ]; + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + b.iter(|| { black_box( to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_fields: vec![None; args.len()], + arg_fields: arg_fields.clone(), number_rows: batch_len, return_field, }) @@ -243,12 +259,19 @@ fn criterion_benchmark(c: &mut Criterion) { Arc::new(cast(&format3, &DataType::Utf8View).unwrap()) as ArrayRef ), ]; + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + b.iter(|| { black_box( to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), - arg_fields: vec![None; args.len()], + arg_fields: arg_fields.clone(), number_rows: batch_len, return_field, }) diff --git a/datafusion/functions/benches/trunc.rs b/datafusion/functions/benches/trunc.rs index 2102d5985637..26e4b5f234a4 100644 --- a/datafusion/functions/benches/trunc.rs +++ b/datafusion/functions/benches/trunc.rs @@ -39,7 +39,7 @@ fn criterion_benchmark(c: &mut Criterion) { trunc .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), - arg_fields: vec![None; f32_args.len()], + arg_fields: vec![&Field::new("a", DataType::Float32, false)], number_rows: 
size, return_field: &Field::new("f", DataType::Float32, true), }) @@ -55,7 +55,7 @@ fn criterion_benchmark(c: &mut Criterion) { trunc .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), - arg_fields: vec![None; f64_args.len()], + arg_fields: vec![&Field::new("a", DataType::Float64, false)], number_rows: size, return_field: &Field::new("f", DataType::Float64, true), }) diff --git a/datafusion/functions/benches/upper.rs b/datafusion/functions/benches/upper.rs index d58fbe0e16f8..e218f6d0372a 100644 --- a/datafusion/functions/benches/upper.rs +++ b/datafusion/functions/benches/upper.rs @@ -42,7 +42,7 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(upper.invoke_with_args(ScalarFunctionArgs { args: args_cloned, - arg_fields: vec![None; args.len()], + arg_fields: vec![&Field::new("a", DataType::Utf8, true)], number_rows: size, return_field: &Field::new("f", DataType::Utf8, true), })) diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs index 5cb22d9d1b9e..b1544a9b357b 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs @@ -189,48 +189,65 @@ mod tests { ], ); + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Union( + None, + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true)) + .collect::>(); + let result = fun.invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::Union( - None, - fields.clone(), - UnionMode::Dense, - )), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ], - arg_fields: vec![None; 2], + args, + arg_fields: arg_fields.iter().collect(), number_rows: 1, return_field: &Field::new("f", DataType::Utf8, true), })?; assert_scalar(result, ScalarValue::Utf8(None)); + let args = vec![ + 
ColumnarValue::Scalar(ScalarValue::Union( + Some((3, Box::new(ScalarValue::Int32(Some(42))))), + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true)) + .collect::>(); + let result = fun.invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::Union( - Some((3, Box::new(ScalarValue::Int32(Some(42))))), - fields.clone(), - UnionMode::Dense, - )), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ], - arg_fields: vec![None; 2], + args, + arg_fields: arg_fields.iter().collect(), number_rows: 1, return_field: &Field::new("f", DataType::Utf8, true), })?; assert_scalar(result, ScalarValue::Utf8(None)); + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Union( + Some((1, Box::new(ScalarValue::new_utf8("42")))), + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true)) + .collect::>(); let result = fun.invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::Union( - Some((1, Box::new(ScalarValue::new_utf8("42")))), - fields.clone(), - UnionMode::Dense, - )), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ], - arg_fields: vec![None; 2], + args, + arg_fields: arg_fields.iter().collect(), number_rows: 1, return_field: &Field::new("f", DataType::Utf8, true), })?; diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index b58262103d40..ea9e3d091860 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -508,89 +508,82 @@ mod tests { use arrow::datatypes::{DataType, Field, TimeUnit}; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; - use datafusion_common::ScalarValue; + use 
datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; use chrono::TimeDelta; + fn invoke_date_bin_with_args( + args: Vec, + number_rows: usize, + return_field: &Field, + ) -> Result { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type(), true)) + .collect::>(); + + let args = datafusion_expr::ScalarFunctionArgs { + args, + arg_fields: arg_fields.iter().collect(), + number_rows, + return_field, + }; + DateBinFunc::new().invoke_with_args(args) + } + #[test] fn test_date_bin() { let return_field = - Field::new("f", DataType::Timestamp(TimeUnit::Nanosecond, None), true); - - let mut args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - arg_fields: vec![None; 3], - number_rows: 1, - return_field: &return_field, - }; - let res = DateBinFunc::new().invoke_with_args(args); + &Field::new("f", DataType::Timestamp(TimeUnit::Nanosecond, None), true); + + let mut args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert!(res.is_ok()); let timestamps = Arc::new((1..6).map(Some).collect::()); let batch_len = timestamps.len(); - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Array(timestamps), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - arg_fields: vec![None; 
3], - number_rows: batch_len, - return_field: &return_field, - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Array(timestamps), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, batch_len, return_field); assert!(res.is_ok()); - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - arg_fields: vec![None; 2], - number_rows: 1, - return_field: &return_field, - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert!(res.is_ok()); // stride supports month-day-nano - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(Some( - IntervalMonthDayNano { - months: 0, - days: 0, - nanoseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - arg_fields: vec![None; 3], - number_rows: 1, - return_field: &return_field, - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNano { + months: 0, + days: 0, + nanoseconds: 1, + }, + ))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = 
invoke_date_bin_with_args(args, 1, return_field); assert!(res.is_ok()); // @@ -598,35 +591,25 @@ mod tests { // // invalid number of arguments - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - )))], - arg_fields: vec![None; 1], - number_rows: 1, - return_field: &return_field, - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( + IntervalDayTime { + days: 0, + milliseconds: 1, + }, + )))]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN expected two or three arguments" ); // stride: invalid type - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some(1))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - arg_fields: vec![None; 3], - number_rows: 1, - return_field: &return_field, - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some(1))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN expects stride argument to be an INTERVAL but got Interval(YearMonth)" @@ -634,119 +617,83 @@ mod tests { // stride: invalid value - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 0, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - 
ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - arg_fields: vec![None; 3], - number_rows: 1, - return_field: &return_field, - }; + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 0, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; - let res = DateBinFunc::new().invoke_with_args(args); + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN stride must be non-zero" ); // stride: overflow of day-time interval - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime::MAX, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - arg_fields: vec![None; 3], - number_rows: 1, - return_field: &return_field, - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( + IntervalDayTime::MAX, + ))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN stride argument is too large" ); // stride: overflow of month-day-nano interval - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::new_interval_mdn(0, i32::MAX, 1)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - arg_fields: vec![None; 3], - number_rows: 1, - return_field: &return_field, - }; - let 
res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::new_interval_mdn(0, i32::MAX, 1)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN stride argument is too large" ); // stride: month intervals - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::new_interval_mdn(1, 1, 1)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - arg_fields: vec![None; 3], - number_rows: 1, - return_field: &return_field, - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::new_interval_mdn(1, 1, 1)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "This feature is not implemented: DATE_BIN stride does not support combination of month, day and nanosecond intervals" ); // origin: invalid type - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), - ], - arg_fields: vec![None; 3], - number_rows: 1, - return_field: &return_field, - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), 
+ ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN expects origin argument to be a TIMESTAMP with nanosecond precision but got Timestamp(Microsecond, None)" ); - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - arg_fields: vec![None; 3], - number_rows: 1, - return_field: &return_field, - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert!(res.is_ok()); // unsupported array type for stride @@ -760,17 +707,12 @@ mod tests { }) .collect::(), ); - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Array(intervals), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - arg_fields: vec![None; 3], - number_rows: 1, - return_field: &return_field, - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Array(intervals), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( 
res.err().unwrap().strip_backtrace(), "This feature is not implemented: DATE_BIN only supports literal values for the stride argument, not arrays" @@ -779,22 +721,15 @@ mod tests { // unsupported array type for origin let timestamps = Arc::new((1..6).map(Some).collect::()); let batch_len = timestamps.len(); - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Array(timestamps), - ], - arg_fields: vec![None; 3], - number_rows: batch_len, - return_field: &return_field, - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Array(timestamps), + ]; + let res = invoke_date_bin_with_args(args, batch_len, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "This feature is not implemented: DATE_BIN only supports literal values for the origin argument, not arrays" @@ -910,24 +845,22 @@ mod tests { .collect::() .with_timezone_opt(tz_opt.clone()); let batch_len = input.len(); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::new_interval_dt(1, 0)), - ColumnarValue::Array(Arc::new(input)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( - Some(string_to_timestamp_nanos(origin).unwrap()), - tz_opt.clone(), - )), - ], - arg_fields: vec![None; 3], - number_rows: batch_len, - return_field: &Field::new( - "f", - DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), - true, - ), - }; - let result = DateBinFunc::new().invoke_with_args(args).unwrap(); + let args = vec![ + ColumnarValue::Scalar(ScalarValue::new_interval_dt(1, 0)), + 
ColumnarValue::Array(Arc::new(input)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + Some(string_to_timestamp_nanos(origin).unwrap()), + tz_opt.clone(), + )), + ]; + let return_field = &Field::new( + "f", + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), + true, + ); + let result = + invoke_date_bin_with_args(args, batch_len, return_field).unwrap(); + if let ColumnarValue::Array(result) = result { assert_eq!( result.data_type(), diff --git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs index 1b5bac7ea5f7..331c4e093e0c 100644 --- a/datafusion/functions/src/datetime/date_trunc.rs +++ b/datafusion/functions/src/datetime/date_trunc.rs @@ -726,12 +726,16 @@ mod tests { .collect::() .with_timezone_opt(tz_opt.clone()); let batch_len = input.len(); + let arg_fields = vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", input.data_type().clone(), false), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::from("day")), ColumnarValue::Array(Arc::new(input)), ], - arg_fields: vec![None; 2], + arg_fields: arg_fields.iter().collect(), number_rows: batch_len, return_field: &Field::new( "f", @@ -893,12 +897,16 @@ mod tests { .collect::() .with_timezone_opt(tz_opt.clone()); let batch_len = input.len(); + let arg_fields = vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", input.data_type().clone(), false), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::from("hour")), ColumnarValue::Array(Arc::new(input)), ], - arg_fields: vec![None; 2], + arg_fields: arg_fields.iter().collect(), number_rows: batch_len, return_field: &Field::new( "f", diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index 885285f2473c..1afaf14d52e6 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ 
b/datafusion/functions/src/datetime/from_unixtime.rs @@ -170,9 +170,10 @@ mod test { #[test] fn test_without_timezone() { + let arg_field = Field::new("a", DataType::Int64, true); let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(Int64(Some(1729900800)))], - arg_fields: vec![None; 1], + arg_fields: vec![&arg_field], number_rows: 1, return_field: &Field::new("f", DataType::Timestamp(Second, None), true), }; @@ -188,6 +189,10 @@ mod test { #[test] fn test_with_timezone() { + let arg_fields = vec![ + Field::new("a", DataType::Int64, true), + Field::new("a", DataType::Utf8, true), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(Int64(Some(1729900800))), @@ -195,7 +200,7 @@ mod test { "America/New_York".to_string(), ))), ], - arg_fields: vec![None; 2], + arg_fields: arg_fields.iter().collect(), number_rows: 2, return_field: &Field::new( "f", diff --git a/datafusion/functions/src/datetime/make_date.rs b/datafusion/functions/src/datetime/make_date.rs index e6930a4bb708..ed901258cd62 100644 --- a/datafusion/functions/src/datetime/make_date.rs +++ b/datafusion/functions/src/datetime/make_date.rs @@ -224,25 +224,38 @@ mod tests { use crate::datetime::make_date::MakeDateFunc; use arrow::array::{Array, Date32Array, Int32Array, Int64Array, UInt32Array}; use arrow::datatypes::{DataType, Field}; - use datafusion_common::ScalarValue; + use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; use std::sync::Arc; + fn invoke_make_date_with_args( + args: Vec, + number_rows: usize, + ) -> Result { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type(), true)) + .collect::>(); + let args = datafusion_expr::ScalarFunctionArgs { + args, + arg_fields: arg_fields.iter().collect(), + number_rows, + return_field: &Field::new("f", DataType::Date32, true), + }; + MakeDateFunc::new().invoke_with_args(args) + } + #[test] fn test_make_date() { - let 
args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::Int32(Some(2024))), ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), ColumnarValue::Scalar(ScalarValue::UInt32(Some(14))), ], - arg_fields: vec![None; 3], - number_rows: 1, - return_field: &Field::new("f", DataType::Date32, true), - }; - let res = MakeDateFunc::new() - .invoke_with_args(args) - .expect("that make_date parsed values without error"); + 1, + ) + .expect("that make_date parsed values without error"); if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res { assert_eq!(19736, date.unwrap()); @@ -250,19 +263,15 @@ mod tests { panic!("Expected a scalar value") } - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::Int64(Some(2024))), ColumnarValue::Scalar(ScalarValue::UInt64(Some(1))), ColumnarValue::Scalar(ScalarValue::UInt32(Some(14))), ], - arg_fields: vec![None; 3], - number_rows: 1, - return_field: &Field::new("f", DataType::Date32, true), - }; - let res = MakeDateFunc::new() - .invoke_with_args(args) - .expect("that make_date parsed values without error"); + 1, + ) + .expect("that make_date parsed values without error"); if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res { assert_eq!(19736, date.unwrap()); @@ -270,19 +279,15 @@ mod tests { panic!("Expected a scalar value") } - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some("2024".to_string()))), ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("1".to_string()))), ColumnarValue::Scalar(ScalarValue::Utf8(Some("14".to_string()))), ], - arg_fields: vec![None; 3], - number_rows: 1, - return_field: &Field::new("f", DataType::Date32, true), - }; - let res = MakeDateFunc::new() - .invoke_with_args(args) - .expect("that make_date parsed values 
without error"); + 1, + ) + .expect("that make_date parsed values without error"); if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res { assert_eq!(19736, date.unwrap()); @@ -294,19 +299,15 @@ mod tests { let months = Arc::new((1..5).map(Some).collect::()); let days = Arc::new((11..15).map(Some).collect::()); let batch_len = years.len(); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Array(years), ColumnarValue::Array(months), ColumnarValue::Array(days), ], - arg_fields: vec![None; 3], - number_rows: batch_len, - return_field: &Field::new("f", DataType::Date32, true), - }; - let res = MakeDateFunc::new() - .invoke_with_args(args) - .expect("that make_date parsed values without error"); + batch_len, + ) + .unwrap(); if let ColumnarValue::Array(array) = res { assert_eq!(array.len(), 4); @@ -325,64 +326,52 @@ mod tests { // // invalid number of arguments - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))], - arg_fields: vec![None; 1], - number_rows: 1, - return_field: &Field::new("f", DataType::Date32, true), - }; - let res = MakeDateFunc::new().invoke_with_args(args); + let res = invoke_make_date_with_args( + vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))], + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: make_date function requires 3 arguments, got 1" ); // invalid type - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some(1))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], - arg_fields: vec![None; 3], - number_rows: 1, - return_field: &Field::new("f", DataType::Date32, true), - }; - let res = MakeDateFunc::new().invoke_with_args(args); + 1, + ); assert_eq!( 
res.err().unwrap().strip_backtrace(), "Arrow error: Cast error: Casting from Interval(YearMonth) to Int32 not supported" ); // overflow of month - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::Int32(Some(2023))), ColumnarValue::Scalar(ScalarValue::UInt64(Some(u64::MAX))), ColumnarValue::Scalar(ScalarValue::Int32(Some(22))), ], - arg_fields: vec![None; 3], - number_rows: 1, - return_field: &Field::new("f", DataType::Date32, true), - }; - let res = MakeDateFunc::new().invoke_with_args(args); + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Arrow error: Cast error: Can't cast value 18446744073709551615 to type Int32" ); // overflow of day - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::Int32(Some(2023))), ColumnarValue::Scalar(ScalarValue::Int32(Some(22))), ColumnarValue::Scalar(ScalarValue::UInt32(Some(u32::MAX))), ], - arg_fields: vec![None; 3], - number_rows: 1, - return_field: &Field::new("f", DataType::Date32, true), - }; - let res = MakeDateFunc::new().invoke_with_args(args); + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Arrow error: Cast error: Can't cast value 4294967295 to type Int32" diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index 6d51a06d2bf7..be3917092ba9 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -303,7 +303,7 @@ mod tests { TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, }; - use arrow::datatypes::{DataType, Field}; + use arrow::datatypes::{DataType, Field, TimeUnit}; use chrono::{NaiveDateTime, Timelike}; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -385,9 +385,13 @@ mod tests { ]; for (value, format, 
expected) in scalar_data { + let arg_fields = vec![ + Field::new("a", value.data_type(), false), + Field::new("a", format.data_type(), false), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(value), ColumnarValue::Scalar(format)], - arg_fields: vec![None; 2], + arg_fields: arg_fields.iter().collect(), number_rows: 1, return_field: &Field::new("f", DataType::Utf8, true), }; @@ -466,12 +470,16 @@ mod tests { for (value, format, expected) in scalar_array_data { let batch_len = format.len(); + let arg_fields = vec![ + Field::new("a", value.data_type(), false), + Field::new("a", format.data_type().to_owned(), false), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(value), ColumnarValue::Array(Arc::new(format) as ArrayRef), ], - arg_fields: vec![None; 2], + arg_fields: arg_fields.iter().collect(), number_rows: batch_len, return_field: &Field::new("f", DataType::Utf8, true), }; @@ -598,12 +606,16 @@ mod tests { for (value, format, expected) in array_scalar_data { let batch_len = value.len(); + let arg_fields = vec![ + Field::new("a", value.data_type().clone(), false), + Field::new("a", format.data_type(), false), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Array(value as ArrayRef), ColumnarValue::Scalar(format), ], - arg_fields: vec![None; 2], + arg_fields: arg_fields.iter().collect(), number_rows: batch_len, return_field: &Field::new("f", DataType::Utf8, true), }; @@ -621,12 +633,16 @@ mod tests { for (value, format, expected) in array_array_data { let batch_len = value.len(); + let arg_fields = vec![ + Field::new("a", value.data_type().clone(), false), + Field::new("a", format.data_type().clone(), false), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Array(value), ColumnarValue::Array(Arc::new(format) as ArrayRef), ], - arg_fields: vec![None; 2], + arg_fields: arg_fields.iter().collect(), number_rows: batch_len, 
return_field: &Field::new("f", DataType::Utf8, true), }; @@ -647,9 +663,10 @@ mod tests { // // invalid number of arguments + let arg_field = Field::new("a", DataType::Int32, true); let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))], - arg_fields: vec![None; 1], + arg_fields: vec![&arg_field], number_rows: 1, return_field: &Field::new("f", DataType::Utf8, true), }; @@ -660,12 +677,16 @@ mod tests { ); // invalid type + let arg_fields = vec![ + Field::new("a", DataType::Utf8, true), + Field::new("a", DataType::Timestamp(TimeUnit::Nanosecond, None), true), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Int32(Some(1))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], - arg_fields: vec![None; 2], + arg_fields: arg_fields.iter().collect(), number_rows: 1, return_field: &Field::new("f", DataType::Utf8, true), }; diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs index a6af727291d3..09635932760d 100644 --- a/datafusion/functions/src/datetime/to_date.rs +++ b/datafusion/functions/src/datetime/to_date.rs @@ -165,12 +165,30 @@ mod tests { use arrow::array::{Array, Date32Array, GenericStringArray, StringViewArray}; use arrow::datatypes::{DataType, Field}; use arrow::{compute::kernels::cast_utils::Parser, datatypes::Date32Type}; - use datafusion_common::ScalarValue; + use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; use std::sync::Arc; use super::ToDateFunc; + fn invoke_to_date_with_args( + args: Vec, + number_rows: usize, + ) -> Result { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type(), true)) + .collect::>(); + + let args = datafusion_expr::ScalarFunctionArgs { + args, + arg_fields: arg_fields.iter().collect(), + number_rows, + return_field: &Field::new("f", DataType::Date32, true), + }; + 
ToDateFunc::new().invoke_with_args(args) + } + #[test] fn test_to_date_without_format() { struct TestCase { @@ -208,13 +226,8 @@ mod tests { } fn test_scalar(sv: ScalarValue, tc: &TestCase) { - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(sv)], - arg_fields: vec![None; 1], - number_rows: 1, - return_field: &Field::new("f", DataType::Date32, true), - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + let to_date_result = + invoke_to_date_with_args(vec![ColumnarValue::Scalar(sv)], 1); match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { @@ -235,13 +248,10 @@ mod tests { { let date_array = A::from(vec![tc.date_str]); let batch_len = date_array.len(); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Array(Arc::new(date_array))], - arg_fields: vec![None; 1], - number_rows: batch_len, - return_field: &Field::new("f", DataType::Date32, true), - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + let to_date_result = invoke_to_date_with_args( + vec![ColumnarValue::Array(Arc::new(date_array))], + batch_len, + ); match to_date_result { Ok(ColumnarValue::Array(a)) => { @@ -330,16 +340,13 @@ mod tests { fn test_scalar(sv: ScalarValue, tc: &TestCase) { let format_scalar = ScalarValue::Utf8(Some(tc.format_str.to_string())); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let to_date_result = invoke_to_date_with_args( + vec![ ColumnarValue::Scalar(sv), ColumnarValue::Scalar(format_scalar), ], - arg_fields: vec![None; 2], - number_rows: 1, - return_field: &Field::new("f", DataType::Date32, true), - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + 1, + ); match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { @@ -361,16 +368,13 @@ mod tests { let format_array = A::from(vec![tc.format_str]); let batch_len = date_array.len(); - let args = datafusion_expr::ScalarFunctionArgs { - 
args: vec![ + let to_date_result = invoke_to_date_with_args( + vec![ ColumnarValue::Array(Arc::new(date_array)), ColumnarValue::Array(Arc::new(format_array)), ], - arg_fields: vec![None; 2], - number_rows: batch_len, - return_field: &Field::new("f", DataType::Date32, true), - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + batch_len, + ); match to_date_result { Ok(ColumnarValue::Array(a)) => { @@ -402,17 +406,14 @@ mod tests { let format1_scalar = ScalarValue::Utf8(Some("%Y-%m-%d".into())); let format2_scalar = ScalarValue::Utf8(Some("%Y/%m/%d".into())); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let to_date_result = invoke_to_date_with_args( + vec![ ColumnarValue::Scalar(formatted_date_scalar), ColumnarValue::Scalar(format1_scalar), ColumnarValue::Scalar(format2_scalar), ], - arg_fields: vec![None; 3], - number_rows: 1, - return_field: &Field::new("f", DataType::Date32, true), - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + 1, + ); match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { @@ -436,13 +437,10 @@ mod tests { for date_str in test_cases { let formatted_date_scalar = ScalarValue::Utf8(Some(date_str.into())); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(formatted_date_scalar)], - arg_fields: vec![None; 1], - number_rows: 1, - return_field: &Field::new("f", DataType::Date32, true), - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + let to_date_result = invoke_to_date_with_args( + vec![ColumnarValue::Scalar(formatted_date_scalar)], + 1, + ); match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { @@ -459,13 +457,8 @@ mod tests { let date_str = "20241231"; let date_scalar = ScalarValue::Utf8(Some(date_str.into())); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(date_scalar)], - arg_fields: vec![None; 1], - number_rows: 1, - 
return_field: &Field::new("f", DataType::Date32, true), - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + let to_date_result = + invoke_to_date_with_args(vec![ColumnarValue::Scalar(date_scalar)], 1); match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { @@ -485,13 +478,8 @@ mod tests { let date_str = "202412311"; let date_scalar = ScalarValue::Utf8(Some(date_str.into())); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(date_scalar)], - arg_fields: vec![None; 1], - number_rows: 1, - return_field: &Field::new("f", DataType::Date32, true), - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + let to_date_result = + invoke_to_date_with_args(vec![ColumnarValue::Scalar(date_scalar)], 1); if let Ok(ColumnarValue::Scalar(ScalarValue::Date32(_))) = to_date_result { panic!( diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index 3e94008cead6..5cf9b785b503 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -407,7 +407,7 @@ impl ScalarUDFImpl for ToLocalTimeFunc { mod tests { use std::sync::Arc; - use arrow::array::{types::TimestampNanosecondType, TimestampNanosecondArray}; + use arrow::array::{types::TimestampNanosecondType, Array, TimestampNanosecondArray}; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; use arrow::datatypes::{DataType, Field, TimeUnit}; use chrono::NaiveDateTime; @@ -538,10 +538,11 @@ mod tests { } fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) { + let arg_field = Field::new("a", input.data_type(), true); let res = ToLocalTimeFunc::new() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(input)], - arg_fields: vec![None; 1], + arg_fields: vec![&arg_field], number_rows: 1, return_field: &Field::new("f", expected.data_type(), true), }) @@ -603,9 
+604,10 @@ mod tests { .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) .collect::(); let batch_size = input.len(); + let arg_field = Field::new("a", input.data_type().clone(), true); let args = ScalarFunctionArgs { args: vec![ColumnarValue::Array(Arc::new(input))], - arg_fields: vec![None; 1], + arg_fields: vec![&arg_field], number_rows: batch_size, return_field: &Field::new( "f", diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index d42467fcde4c..c6aab61328eb 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -1012,10 +1012,11 @@ mod tests { for udf in &udfs { for array in arrays { let rt = udf.return_type(&[array.data_type()]).unwrap(); + let arg_field = Field::new("arg", array.data_type().clone(), true); assert!(matches!(rt, Timestamp(_, Some(_)))); let args = datafusion_expr::ScalarFunctionArgs { args: vec![array.clone()], - arg_fields: vec![None; 1], + arg_fields: vec![&arg_field], number_rows: 4, return_field: &Field::new("f", rt, true), }; @@ -1061,9 +1062,10 @@ mod tests { for array in arrays { let rt = udf.return_type(&[array.data_type()]).unwrap(); assert!(matches!(rt, Timestamp(_, None))); + let arg_field = Field::new("arg", array.data_type().clone(), true); let args = datafusion_expr::ScalarFunctionArgs { args: vec![array.clone()], - arg_fields: vec![None; 1], + arg_fields: vec![&arg_field], number_rows: 5, return_field: &Field::new("f", rt, true), }; diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index 4e7d48f19d31..d1f40e3b1ad1 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -265,6 +265,10 @@ mod tests { #[test] #[should_panic] fn test_log_invalid_base_type() { + let arg_fields = vec![ + Field::new("a", DataType::Float64, false), + Field::new("a", DataType::Int64, false), + ]; let args = ScalarFunctionArgs { args: vec![ 
ColumnarValue::Array(Arc::new(Float64Array::from(vec![ @@ -272,7 +276,7 @@ mod tests { ]))), // num ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))), ], - arg_fields: vec![None; 2], + arg_fields: arg_fields.iter().collect(), number_rows: 4, return_field: &Field::new("f", DataType::Float64, true), }; @@ -281,11 +285,12 @@ mod tests { #[test] fn test_log_invalid_value() { + let arg_field = Field::new("a", DataType::Int64, false); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num ], - arg_fields: vec![None; 1], + arg_fields: vec![&arg_field], number_rows: 1, return_field: &Field::new("f", DataType::Float64, true), }; @@ -296,11 +301,12 @@ mod tests { #[test] fn test_log_scalar_f32_unary() { + let arg_field = Field::new("a", DataType::Float32, false); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num ], - arg_fields: vec![None; 1], + arg_fields: vec![&arg_field], number_rows: 1, return_field: &Field::new("f", DataType::Float32, true), }; @@ -324,11 +330,12 @@ mod tests { #[test] fn test_log_scalar_f64_unary() { + let arg_field = Field::new("a", DataType::Float64, false); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num ], - arg_fields: vec![None; 1], + arg_fields: vec![&arg_field], number_rows: 1, return_field: &Field::new("f", DataType::Float64, true), }; @@ -352,12 +359,16 @@ mod tests { #[test] fn test_log_scalar_f32() { + let arg_fields = vec![ + Field::new("a", DataType::Float32, false), + Field::new("a", DataType::Float32, false), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // num ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num ], - arg_fields: vec![None; 2], + arg_fields: arg_fields.iter().collect(), number_rows: 1, return_field: &Field::new("f", DataType::Float32, true), }; @@ -381,12 
+392,16 @@ mod tests { #[test] fn test_log_scalar_f64() { + let arg_fields = vec![ + Field::new("a", DataType::Float64, false), + Field::new("a", DataType::Float64, false), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // num ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num ], - arg_fields: vec![None; 2], + arg_fields: arg_fields.iter().collect(), number_rows: 1, return_field: &Field::new("f", DataType::Float64, true), }; @@ -410,13 +425,14 @@ mod tests { #[test] fn test_log_f64_unary() { + let arg_field = Field::new("a", DataType::Float64, false); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float64Array::from(vec![ 10.0, 100.0, 1000.0, 10000.0, ]))), // num ], - arg_fields: vec![None; 1], + arg_fields: vec![&arg_field], number_rows: 4, return_field: &Field::new("f", DataType::Float64, true), }; @@ -443,13 +459,14 @@ mod tests { #[test] fn test_log_f32_unary() { + let arg_field = Field::new("a", DataType::Float32, false); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float32Array::from(vec![ 10.0, 100.0, 1000.0, 10000.0, ]))), // num ], - arg_fields: vec![None; 1], + arg_fields: vec![&arg_field], number_rows: 4, return_field: &Field::new("f", DataType::Float32, true), }; @@ -476,6 +493,10 @@ mod tests { #[test] fn test_log_f64() { + let arg_fields = vec![ + Field::new("a", DataType::Float64, false), + Field::new("a", DataType::Float64, false), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float64Array::from(vec![ @@ -485,7 +506,7 @@ mod tests { 8.0, 4.0, 81.0, 625.0, ]))), // num ], - arg_fields: vec![None; 2], + arg_fields: arg_fields.iter().collect(), number_rows: 4, return_field: &Field::new("f", DataType::Float64, true), }; @@ -512,6 +533,10 @@ mod tests { #[test] fn test_log_f32() { + let arg_fields = vec![ + Field::new("a", DataType::Float32, false), + Field::new("a", DataType::Float32, false), + 
]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float32Array::from(vec![ @@ -521,7 +546,7 @@ mod tests { 8.0, 4.0, 81.0, 625.0, ]))), // num ], - arg_fields: vec![None; 2], + arg_fields: arg_fields.iter().collect(), number_rows: 4, return_field: &Field::new("f", DataType::Float32, true), }; diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index e3bc0972e592..8876e3fe2787 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -194,6 +194,10 @@ mod tests { #[test] fn test_power_f64() { + let arg_fields = vec![ + Field::new("a", DataType::Float64, true), + Field::new("a", DataType::Float64, true), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float64Array::from(vec![ @@ -203,7 +207,7 @@ mod tests { 3.0, 2.0, 4.0, 4.0, ]))), // exponent ], - arg_fields: vec![None; 2], + arg_fields: arg_fields.iter().collect(), number_rows: 4, return_field: &Field::new("f", DataType::Float64, true), }; @@ -229,12 +233,16 @@ mod tests { #[test] fn test_power_i64() { + let arg_fields = vec![ + Field::new("a", DataType::Int64, true), + Field::new("a", DataType::Int64, true), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent ], - arg_fields: vec![None; 2], + arg_fields: arg_fields.iter().collect(), number_rows: 4, return_field: &Field::new("f", DataType::Int64, true), }; diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs index 6d75af27a3e0..7414e6e138ab 100644 --- a/datafusion/functions/src/math/signum.rs +++ b/datafusion/functions/src/math/signum.rs @@ -157,9 +157,10 @@ mod test { f32::INFINITY, f32::NEG_INFINITY, ])); + let arg_fields = [Field::new("a", DataType::Float32, false)]; let args = ScalarFunctionArgs { args: 
vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], - arg_fields: vec![None; 1], + arg_fields: arg_fields.iter().collect(), number_rows: array.len(), return_field: &Field::new("f", DataType::Float32, true), }; @@ -202,9 +203,10 @@ mod test { f64::INFINITY, f64::NEG_INFINITY, ])); + let arg_fields = [Field::new("a", DataType::Float64, false)]; let args = ScalarFunctionArgs { args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], - arg_fields: vec![None; 1], + arg_fields: arg_fields.iter().collect(), number_rows: array.len(), return_field: &Field::new("f", DataType::Float64, true), }; diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index bfb203e2f665..d536d3531af4 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -648,6 +648,27 @@ mod tests { test_case_regexp_count_cache_check::>(); } + fn regexp_count_with_scalar_values(args: &[ScalarValue]) -> Result { + let args_values = args + .iter() + .map(|sv| ColumnarValue::Scalar(sv.clone())) + .collect(); + + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, a)| Field::new(format!("arg_{idx}"), a.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + + RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { + args: args_values, + arg_fields, + number_rows: args.len(), + return_field: &Field::new("f", Int64, true), + }) + } + fn test_case_sensitive_regexp_count_scalar() { let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]; let regex = "abc"; @@ -658,12 +679,7 @@ mod tests { let v_sv = ScalarValue::Utf8(Some(v.to_string())); let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); let expected = expected.get(pos).cloned(); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], - arg_fields: vec![None; 2], - number_rows: 
2, - return_field: &Field::new("f", Int64, true), - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -674,12 +690,7 @@ mod tests { // largeutf8 let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], - arg_fields: vec![None; 2], - number_rows: 2, - return_field: &Field::new("f", Int64, true), - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -690,12 +701,7 @@ mod tests { // utf8view let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], - arg_fields: vec![None; 2], - number_rows: 2, - return_field: &Field::new("f", Int64, true), - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -717,16 +723,7 @@ mod tests { let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); let start_sv = ScalarValue::Int64(Some(start)); let expected = expected.get(pos).cloned(); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ], - arg_fields: vec![None; 3], - number_rows: 3, - return_field: &Field::new("f", Int64, true), - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv, 
start_sv.clone()]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -737,16 +734,7 @@ mod tests { // largeutf8 let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ], - arg_fields: vec![None; 3], - number_rows: 3, - return_field: &Field::new("f", Int64, true), - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -757,16 +745,7 @@ mod tests { // utf8view let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ], - arg_fields: vec![None; 3], - number_rows: 3, - return_field: &Field::new("f", Int64, true), - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -790,17 +769,13 @@ mod tests { let start_sv = ScalarValue::Int64(Some(start)); let flags_sv = ScalarValue::Utf8(Some(flags.to_string())); let expected = expected.get(pos).cloned(); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - arg_fields: vec![None; 4], - number_rows: 4, - return_field: &Field::new("f", Int64, 
true), - }); + + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -812,17 +787,13 @@ mod tests { let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); let flags_sv = ScalarValue::LargeUtf8(Some(flags.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - arg_fields: vec![None; 4], - number_rows: 4, - return_field: &Field::new("f", Int64, true), - }); + + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -834,17 +805,13 @@ mod tests { let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); let flags_sv = ScalarValue::Utf8View(Some(flags.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - arg_fields: vec![None; 4], - number_rows: 4, - return_field: &Field::new("f", Int64, true), - }); + + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -917,17 +884,12 @@ mod tests { let start_sv = ScalarValue::Int64(Some(start)); let flags_sv = ScalarValue::Utf8(flags.get(pos).map(|f| 
f.to_string())); let expected = expected.get(pos).cloned(); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - arg_fields: vec![None; 4], - number_rows: 4, - return_field: &Field::new("f", Int64, true), - }); + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -939,17 +901,12 @@ mod tests { let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(regex.get(pos).map(|s| s.to_string())); let flags_sv = ScalarValue::LargeUtf8(flags.get(pos).map(|f| f.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - arg_fields: vec![None; 4], - number_rows: 4, - return_field: &Field::new("f", Int64, true), - }); + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -961,17 +918,12 @@ mod tests { let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(regex.get(pos).map(|s| s.to_string())); let flags_sv = ScalarValue::Utf8View(flags.get(pos).map(|f| f.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - arg_fields: vec![None; 4], - number_rows: 4, 
- return_field: &Field::new("f", Int64, true), - }); + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index a46b79be74dc..fe0a5915fe20 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -469,10 +469,17 @@ mod tests { None, Some("b"), ]))); + let arg_fields = vec![ + Field::new("a", Utf8, true), + Field::new("a", Utf8, true), + Field::new("a", Utf8, true), + Field::new("a", Utf8View, true), + Field::new("a", Utf8View, true), + ]; let args = ScalarFunctionArgs { args: vec![c0, c1, c2, c3, c4], - arg_fields: vec![None; 5], + arg_fields: arg_fields.iter().collect(), number_rows: 3, return_field: &Field::new("f", Utf8, true), }; diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index 0de16ad4fec8..79a5d34fb4c4 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -481,9 +481,14 @@ mod tests { Some("z"), ]))); + let arg_fields = vec![ + Field::new("a", Utf8, true), + Field::new("a", Utf8, true), + Field::new("a", Utf8, true), + ]; let args = ScalarFunctionArgs { args: vec![c0, c1, c2], - arg_fields: vec![None; 3], + arg_fields: arg_fields.iter().collect(), number_rows: 3, return_field: &Field::new("f", Utf8, true), }; @@ -512,9 +517,14 @@ mod tests { Some("z"), ]))); + let arg_fields = vec![ + Field::new("a", Utf8, true), + Field::new("a", Utf8, true), + Field::new("a", Utf8, true), + ]; let args = ScalarFunctionArgs { args: vec![c0, c1, c2], - arg_fields: vec![None; 3], + arg_fields: arg_fields.iter().collect(), number_rows: 3, return_field: &Field::new("f", Utf8, true), }; diff --git 
a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index f60514427431..00fd20ff3479 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -164,10 +164,14 @@ mod test { Some("yyy?()"), ]))); let scalar = ColumnarValue::Scalar(ScalarValue::Utf8(Some("x?(".to_string()))); + let arg_fields = vec![ + Field::new("a", DataType::Utf8, true), + Field::new("a", DataType::Utf8, true), + ]; let args = ScalarFunctionArgs { args: vec![array, scalar], - arg_fields: vec![None; 2], + arg_fields: arg_fields.iter().collect(), number_rows: 2, return_field: &Field::new("f", DataType::Boolean, true), }; diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index df1d378d0c16..1dc6e9d28367 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -104,11 +104,12 @@ mod tests { fn to_lower(input: ArrayRef, expected: ArrayRef) -> Result<()> { let func = LowerFunc::new(); + let arg_fields = [Field::new("a", input.data_type().clone(), true)]; let args = ScalarFunctionArgs { number_rows: input.len(), args: vec![ColumnarValue::Array(input)], - arg_fields: vec![None; 1], + arg_fields: arg_fields.iter().collect(), return_field: &Field::new("f", Utf8, true), }; diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index acff5049ab14..06a9bd9720d6 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -104,10 +104,11 @@ mod tests { fn to_upper(input: ArrayRef, expected: ArrayRef) -> Result<()> { let func = UpperFunc::new(); + let arg_field = Field::new("a", input.data_type().clone(), true); let args = ScalarFunctionArgs { number_rows: input.len(), args: vec![ColumnarValue::Array(input)], - arg_fields: vec![None; 1], + arg_fields: vec![&arg_field], return_field: &Field::new("f", Utf8, true), }; diff --git 
a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs index 8e80dd843511..6b5df89e860f 100644 --- a/datafusion/functions/src/unicode/find_in_set.rs +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -471,7 +471,12 @@ mod tests { }) .unwrap_or(1); let return_type = fis.return_type(&type_array)?; - let arg_fields = vec![None; args.len()]; + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, a)| Field::new(format!("arg_{idx}"), a.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); let result = fis.invoke_with_args(ScalarFunctionArgs { args, arg_fields, diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index a81ea96951d8..0d6367565921 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -161,6 +161,10 @@ pub mod test { arg_fields: &field_array, scalar_arguments: &scalar_arguments_refs, }); + let arg_fields_owned = $ARGS.iter() + .enumerate() + .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true)) + .collect::>(); match expected { Ok(expected) => { @@ -169,7 +173,7 @@ pub mod test { let return_type = return_field.data_type(); assert_eq!(return_type, &$EXPECTED_DATA_TYPE); - let arg_fields = vec![None; $ARGS.len()]; + let arg_fields = arg_fields_owned.iter().collect::>(); let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, arg_fields, number_rows: cardinality, return_field: &return_field}); assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err()); @@ -193,7 +197,7 @@ pub mod test { else { let return_field = return_field.unwrap(); - let arg_fields = vec![None; $ARGS.len()]; + let arg_fields = arg_fields_owned.iter().collect::>(); // invoke is expected error - cannot use .expect_err() due to Debug not being implemented match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, 
arg_fields, number_rows: cardinality, return_field: &return_field}) { Ok(_) => assert!(false, "expected error"), diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 09d372db2425..b937534a6da8 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -71,13 +71,23 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash { /// downcast to a specific implementation. fn as_any(&self) -> &dyn Any; /// Get the data type of this expression, given the schema of the input - fn data_type(&self, input_schema: &Schema) -> Result; + fn data_type(&self, input_schema: &Schema) -> Result { + Ok(self.output_field(input_schema)?.data_type().to_owned()) + } /// Determine whether this expression is nullable, given the schema of the input - fn nullable(&self, input_schema: &Schema) -> Result; + fn nullable(&self, input_schema: &Schema) -> Result { + Ok(self.output_field(input_schema)?.is_nullable()) + } /// Evaluate an expression against a RecordBatch fn evaluate(&self, batch: &RecordBatch) -> Result; /// The output field associated with this expression - fn output_field(&self, input_schema: &Schema) -> Result>; + fn output_field(&self, input_schema: &Schema) -> Result { + Ok(Field::new( + format!("{self}"), + self.data_type(input_schema)?, + self.nullable(input_schema)?, + )) + } /// Evaluate an expression against a RecordBatch after first applying a /// validity array fn evaluate_selection( @@ -469,7 +479,7 @@ where /// # fn data_type(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn nullable(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn evaluate(&self, batch: &RecordBatch) -> Result { unimplemented!() } -/// # fn output_field(&self, input_schema: &Schema) -> Result> { unimplemented!() } +/// # fn output_field(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn 
children(&self) -> Vec<&Arc>{ unimplemented!() } /// # fn with_new_children(self: Arc, children: Vec>) -> Result> { unimplemented!() } /// # fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "CASE a > b THEN 1 ELSE 0 END") } diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs index 04c7719ecfa1..ce9550ada7e7 100644 --- a/datafusion/physical-expr-common/src/sort_expr.rs +++ b/datafusion/physical-expr-common/src/sort_expr.rs @@ -57,7 +57,7 @@ use itertools::Itertools; /// # fn data_type(&self, input_schema: &Schema) -> Result {todo!()} /// # fn nullable(&self, input_schema: &Schema) -> Result {todo!() } /// # fn evaluate(&self, batch: &RecordBatch) -> Result {todo!() } -/// # fn output_field(&self, input_schema: &Schema) -> Result> { unimplemented!() } +/// # fn output_field(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn children(&self) -> Vec<&Arc> {todo!()} /// # fn with_new_children(self: Arc, children: Vec>) -> Result> {todo!()} /// # fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { todo!() } diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 0b0b25e68ef8..3e59ed5a07d4 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -432,10 +432,6 @@ impl PhysicalExpr for BinaryExpr { .map(ColumnarValue::Array) } - fn output_field(&self, _input_schema: &Schema) -> Result> { - Ok(None) - } - fn children(&self) -> Vec<&Arc> { vec![&self.left, &self.right] } diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 291dc7788e3b..1a74e78f1075 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -24,7 +24,7 @@ use std::{any::Any, sync::Arc}; use arrow::array::*; use 
arrow::compute::kernels::zip::zip; use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter}; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, Schema}; use datafusion_common::cast::as_boolean_array; use datafusion_common::{ exec_err, internal_datafusion_err, internal_err, DataFusionError, Result, ScalarValue, @@ -513,10 +513,6 @@ impl PhysicalExpr for CaseExpr { } } - fn output_field(&self, _input_schema: &Schema) -> Result> { - Ok(None) - } - fn children(&self) -> Vec<&Arc> { let mut children = vec![]; if let Some(expr) = &self.expr { @@ -606,7 +602,7 @@ mod tests { use crate::expressions::{binary, cast, col, lit, BinaryExpr}; use arrow::buffer::Buffer; use arrow::datatypes::DataType::Float64; - + use arrow::datatypes::Field; use datafusion_common::cast::{as_float64_array, as_int32_array}; use datafusion_common::plan_err; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 8183a36ca52a..d937647e5556 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -144,11 +144,11 @@ impl PhysicalExpr for CastExpr { value.cast_to(&self.cast_type, Some(&self.cast_options)) } - fn output_field(&self, input_schema: &Schema) -> Result> { + fn output_field(&self, input_schema: &Schema) -> Result { Ok(self .expr .output_field(input_schema)? 
- .map(|f| f.with_data_type(self.cast_type.clone()))) + .with_data_type(self.cast_type.clone())) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index b88880ea7670..039924b95e71 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -127,8 +127,8 @@ impl PhysicalExpr for Column { Ok(ColumnarValue::Array(Arc::clone(batch.column(self.index)))) } - fn output_field(&self, input_schema: &Schema) -> Result> { - Ok(Some(input_schema.field(self.index).clone())) + fn output_field(&self, input_schema: &Schema) -> Result { + Ok(input_schema.field(self.index).clone()) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-expr/src/expressions/dynamic_filters.rs b/datafusion/physical-expr/src/expressions/dynamic_filters.rs index 2cf17eb72a86..c0a3285f0e78 100644 --- a/datafusion/physical-expr/src/expressions/dynamic_filters.rs +++ b/datafusion/physical-expr/src/expressions/dynamic_filters.rs @@ -23,7 +23,7 @@ use std::{ }; use crate::PhysicalExpr; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, Schema}; use datafusion_common::{ tree_node::{Transformed, TransformedResult, TreeNode}, Result, @@ -291,10 +291,6 @@ impl PhysicalExpr for DynamicFilterPhysicalExpr { // Return the current expression as a snapshot. 
Ok(Some(self.current()?)) } - - fn output_field(&self, _input_schema: &Schema) -> Result> { - Ok(None) - } } #[cfg(test)] diff --git a/datafusion/physical-expr/src/expressions/in_list.rs b/datafusion/physical-expr/src/expressions/in_list.rs index 704cf47879f9..469f7bbee317 100644 --- a/datafusion/physical-expr/src/expressions/in_list.rs +++ b/datafusion/physical-expr/src/expressions/in_list.rs @@ -379,10 +379,6 @@ impl PhysicalExpr for InListExpr { Ok(ColumnarValue::Array(Arc::new(r))) } - fn output_field(&self, _input_schema: &Schema) -> Result> { - Ok(None) - } - fn children(&self) -> Vec<&Arc> { let mut children = vec![]; children.push(&self.expr); diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index f16114259c0f..5cecadd40c41 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -94,7 +94,7 @@ impl PhysicalExpr for IsNotNullExpr { } } - fn output_field(&self, input_schema: &Schema) -> Result> { + fn output_field(&self, input_schema: &Schema) -> Result { self.arg.output_field(input_schema) } diff --git a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index 32ef64fe0230..befd3a2ea57a 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -92,7 +92,7 @@ impl PhysicalExpr for IsNullExpr { } } - fn output_field(&self, input_schema: &Schema) -> Result> { + fn output_field(&self, input_schema: &Schema) -> Result { self.arg.output_field(input_schema) } diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index fa51def9eb09..e86c778d5161 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -16,7 +16,7 @@ // under the License. 
use crate::PhysicalExpr; -use arrow::datatypes::{DataType, Field, Schema}; +use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; @@ -129,10 +129,6 @@ impl PhysicalExpr for LikeExpr { } } - fn output_field(&self, _input_schema: &Schema) -> Result> { - Ok(None) - } - fn children(&self) -> Vec<&Arc> { vec![&self.expr, &self.pattern] } diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index 08c4bd38b383..6f7caaea8d45 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -23,7 +23,6 @@ use std::sync::Arc; use crate::physical_expr::PhysicalExpr; -use arrow::datatypes::Field; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, @@ -76,10 +75,6 @@ impl PhysicalExpr for Literal { Ok(ColumnarValue::Scalar(self.value.clone())) } - fn output_field(&self, _input_schema: &Schema) -> Result> { - Ok(None) - } - fn children(&self) -> Vec<&Arc> { vec![] } @@ -117,7 +112,7 @@ mod tests { use super::*; use arrow::array::Int32Array; - + use arrow::datatypes::Field; use datafusion_common::cast::as_int32_array; use datafusion_physical_expr_common::physical_expr::fmt_sql; diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index fe7b7cab429f..5097ce0f6560 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -104,7 +104,7 @@ impl PhysicalExpr for NegativeExpr { } } - fn output_field(&self, input_schema: &Schema) -> Result> { + fn output_field(&self, input_schema: &Schema) -> Result { self.arg.output_field(input_schema) } diff --git a/datafusion/physical-expr/src/expressions/no_op.rs b/datafusion/physical-expr/src/expressions/no_op.rs index c918aa5f23f2..94610996c6b0 100644 --- 
a/datafusion/physical-expr/src/expressions/no_op.rs +++ b/datafusion/physical-expr/src/expressions/no_op.rs @@ -22,7 +22,6 @@ use std::hash::Hash; use std::sync::Arc; use crate::PhysicalExpr; -use arrow::datatypes::Field; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, @@ -67,10 +66,6 @@ impl PhysicalExpr for NoOp { internal_err!("NoOp::evaluate() should not be called") } - fn output_field(&self, _input_schema: &Schema) -> Result> { - Ok(None) - } - fn children(&self) -> Vec<&Arc> { vec![] } diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index d23b708efd36..852b4e4fa780 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -101,7 +101,7 @@ impl PhysicalExpr for NotExpr { } } - fn output_field(&self, input_schema: &Schema) -> Result> { + fn output_field(&self, input_schema: &Schema) -> Result { self.arg.output_field(input_schema) } diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index e12f5af94360..e56ba376d3c6 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -110,7 +110,7 @@ impl PhysicalExpr for TryCastExpr { } } - fn output_field(&self, input_schema: &Schema) -> Result> { + fn output_field(&self, input_schema: &Schema) -> Result { self.expr.output_field(input_schema) } diff --git a/datafusion/physical-expr/src/expressions/unknown_column.rs b/datafusion/physical-expr/src/expressions/unknown_column.rs index f0c18e785f28..2face4eb6bdb 100644 --- a/datafusion/physical-expr/src/expressions/unknown_column.rs +++ b/datafusion/physical-expr/src/expressions/unknown_column.rs @@ -23,7 +23,6 @@ use std::sync::Arc; use crate::PhysicalExpr; -use arrow::datatypes::Field; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, @@ -77,10 +76,6 @@ impl PhysicalExpr 
for UnKnownColumn { internal_err!("UnKnownColumn::evaluate() should not be called") } - fn output_field(&self, _input_schema: &Schema) -> Result> { - Ok(None) - } - fn children(&self) -> Vec<&Arc> { vec![] } diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index b3a66dafb8bc..fb3b0ca9d179 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -92,18 +92,7 @@ impl ScalarFunctionExpr { let name = fun.name().to_string(); let arg_fields = args .iter() - .enumerate() - .map(|(idx, e)| { - e.output_field(schema).and_then(|maybe_field| { - Ok(maybe_field.unwrap_or({ - Field::new( - format!("field_{idx}"), - e.data_type(schema)?, - e.nullable(schema)?, - ) - })) - }) - }) + .map(|e| e.output_field(schema)) .collect::>>()?; // verify that input data types is consistent with function's `TypeSignature` @@ -196,10 +185,7 @@ impl PhysicalExpr for ScalarFunctionExpr { .iter() .map(|e| e.output_field(batch.schema_ref())) .collect::>>()?; - let arg_fields = arg_fields_owned - .iter() - .map(|opt_map| opt_map.as_ref()) - .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); let input_empty = args.is_empty(); let input_all_scalar = args @@ -231,8 +217,8 @@ impl PhysicalExpr for ScalarFunctionExpr { Ok(output) } - fn output_field(&self, _input_schema: &Schema) -> Result> { - Ok(Some(self.return_field.clone())) + fn output_field(&self, _input_schema: &Schema) -> Result { + Ok(self.return_field.clone()) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 785752c4cd08..6a24a5732147 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -284,11 +284,7 @@ impl PhysicalGroupBy { expr.data_type(input_schema)?, group_expr_nullable || expr.nullable(input_schema)?, ) - .with_metadata( - 
expr.output_field(input_schema)? - .map(|f| f.metadata().clone()) - .unwrap_or_default(), - ), + .with_metadata(expr.output_field(input_schema)?.metadata().clone()), ); } if !self.is_single() { diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index b5be14427bd0..a6698a4e4728 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -79,10 +79,7 @@ impl ProjectionExec { let fields: Result> = expr .iter() .map(|(e, name)| { - let metadata = e - .output_field(&input_schema)? - .map(|field| field.metadata().clone()) - .unwrap_or_default(); + let metadata = e.output_field(&input_schema)?.metadata().clone(); let field = Field::new( name, diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index 7b03939d754c..c26fdaa9fca8 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -864,10 +864,6 @@ fn roundtrip_parquet_exec_with_custom_predicate_expr() -> Result<()> { unreachable!() } - fn output_field(&self, _input_schema: &Schema) -> Result> { - Ok(None) - } - fn children(&self) -> Vec<&Arc> { vec![&self.inner] } From 317ed22445c36ec7da476b8e452672a058746793 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 21 Apr 2025 15:00:51 -0400 Subject: [PATCH 22/25] Change name from output_field to return_field to be more consistent --- datafusion/physical-expr-common/src/physical_expr.rs | 8 ++++---- datafusion/physical-expr/src/expressions/cast.rs | 4 ++-- datafusion/physical-expr/src/expressions/column.rs | 2 +- datafusion/physical-expr/src/expressions/is_not_null.rs | 4 ++-- datafusion/physical-expr/src/expressions/is_null.rs | 4 ++-- datafusion/physical-expr/src/expressions/negative.rs | 4 ++-- datafusion/physical-expr/src/expressions/not.rs | 4 ++-- datafusion/physical-expr/src/expressions/try_cast.rs | 6 ++++-- 
datafusion/physical-expr/src/scalar_function.rs | 6 +++--- datafusion/physical-plan/src/aggregates/mod.rs | 2 +- datafusion/physical-plan/src/projection.rs | 2 +- 11 files changed, 24 insertions(+), 22 deletions(-) diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index b937534a6da8..685398c352e2 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -72,16 +72,16 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash { fn as_any(&self) -> &dyn Any; /// Get the data type of this expression, given the schema of the input fn data_type(&self, input_schema: &Schema) -> Result { - Ok(self.output_field(input_schema)?.data_type().to_owned()) + Ok(self.return_field(input_schema)?.data_type().to_owned()) } /// Determine whether this expression is nullable, given the schema of the input fn nullable(&self, input_schema: &Schema) -> Result { - Ok(self.output_field(input_schema)?.is_nullable()) + Ok(self.return_field(input_schema)?.is_nullable()) } /// Evaluate an expression against a RecordBatch fn evaluate(&self, batch: &RecordBatch) -> Result; /// The output field associated with this expression - fn output_field(&self, input_schema: &Schema) -> Result { + fn return_field(&self, input_schema: &Schema) -> Result { Ok(Field::new( format!("{self}"), self.data_type(input_schema)?, @@ -479,7 +479,7 @@ where /// # fn data_type(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn nullable(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn evaluate(&self, batch: &RecordBatch) -> Result { unimplemented!() } -/// # fn output_field(&self, input_schema: &Schema) -> Result { unimplemented!() } +/// # fn return_field(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn children(&self) -> Vec<&Arc>{ unimplemented!() } /// # fn with_new_children(self: Arc, children: Vec>) -> 
Result> { unimplemented!() } /// # fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "CASE a > b THEN 1 ELSE 0 END") } diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index d937647e5556..88923d9c6cee 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -144,10 +144,10 @@ impl PhysicalExpr for CastExpr { value.cast_to(&self.cast_type, Some(&self.cast_options)) } - fn output_field(&self, input_schema: &Schema) -> Result { + fn return_field(&self, input_schema: &Schema) -> Result { Ok(self .expr - .output_field(input_schema)? + .return_field(input_schema)? .with_data_type(self.cast_type.clone())) } diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index 039924b95e71..80af0b84c5d1 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -127,7 +127,7 @@ impl PhysicalExpr for Column { Ok(ColumnarValue::Array(Arc::clone(batch.column(self.index)))) } - fn output_field(&self, input_schema: &Schema) -> Result { + fn return_field(&self, input_schema: &Schema) -> Result { Ok(input_schema.field(self.index).clone()) } diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index 5cecadd40c41..1de8c17a373a 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -94,8 +94,8 @@ impl PhysicalExpr for IsNotNullExpr { } } - fn output_field(&self, input_schema: &Schema) -> Result { - self.arg.output_field(input_schema) + fn return_field(&self, input_schema: &Schema) -> Result { + self.arg.return_field(input_schema) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-expr/src/expressions/is_null.rs 
b/datafusion/physical-expr/src/expressions/is_null.rs index befd3a2ea57a..7707075ce653 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -92,8 +92,8 @@ impl PhysicalExpr for IsNullExpr { } } - fn output_field(&self, input_schema: &Schema) -> Result { - self.arg.output_field(input_schema) + fn return_field(&self, input_schema: &Schema) -> Result { + self.arg.return_field(input_schema) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index 5097ce0f6560..597cbf1dac9e 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -104,8 +104,8 @@ impl PhysicalExpr for NegativeExpr { } } - fn output_field(&self, input_schema: &Schema) -> Result { - self.arg.output_field(input_schema) + fn return_field(&self, input_schema: &Schema) -> Result { + self.arg.return_field(input_schema) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index 852b4e4fa780..1f3ae9e25ffb 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -101,8 +101,8 @@ impl PhysicalExpr for NotExpr { } } - fn output_field(&self, input_schema: &Schema) -> Result { - self.arg.output_field(input_schema) + fn return_field(&self, input_schema: &Schema) -> Result { + self.arg.return_field(input_schema) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index e56ba376d3c6..e4fe027c7918 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -110,8 +110,10 @@ impl PhysicalExpr for TryCastExpr { } } - fn output_field(&self, input_schema: 
&Schema) -> Result { - self.expr.output_field(input_schema) + fn return_field(&self, input_schema: &Schema) -> Result { + self.expr + .return_field(input_schema) + .map(|f| f.with_data_type(self.cast_type.clone())) } fn children(&self) -> Vec<&Arc> { diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index fb3b0ca9d179..d6e070e38948 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -92,7 +92,7 @@ impl ScalarFunctionExpr { let name = fun.name().to_string(); let arg_fields = args .iter() - .map(|e| e.output_field(schema)) + .map(|e| e.return_field(schema)) .collect::>>()?; // verify that input data types is consistent with function's `TypeSignature` @@ -183,7 +183,7 @@ impl PhysicalExpr for ScalarFunctionExpr { let arg_fields_owned = self .args .iter() - .map(|e| e.output_field(batch.schema_ref())) + .map(|e| e.return_field(batch.schema_ref())) .collect::>>()?; let arg_fields = arg_fields_owned.iter().collect::>(); @@ -217,7 +217,7 @@ impl PhysicalExpr for ScalarFunctionExpr { Ok(output) } - fn output_field(&self, _input_schema: &Schema) -> Result { + fn return_field(&self, _input_schema: &Schema) -> Result { Ok(self.return_field.clone()) } diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 6a24a5732147..628506a17b82 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -284,7 +284,7 @@ impl PhysicalGroupBy { expr.data_type(input_schema)?, group_expr_nullable || expr.nullable(input_schema)?, ) - .with_metadata(expr.output_field(input_schema)?.metadata().clone()), + .with_metadata(expr.return_field(input_schema)?.metadata().clone()), ); } if !self.is_single() { diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index a6698a4e4728..534dbd71b40a 100644 --- 
a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -79,7 +79,7 @@ impl ProjectionExec { let fields: Result> = expr .iter() .map(|(e, name)| { - let metadata = e.output_field(&input_schema)?.metadata().clone(); + let metadata = e.return_field(&input_schema)?.metadata().clone(); let field = Field::new( name, From 6481b3e278e5a73da8603ad7c30080186609e4e7 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 21 Apr 2025 15:12:55 -0400 Subject: [PATCH 23/25] Update migration guide for DF48 with user defined functions --- docs/source/library-user-guide/upgrading.md | 27 +++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/docs/source/library-user-guide/upgrading.md b/docs/source/library-user-guide/upgrading.md index 11fd49566522..4bf2206dbd8f 100644 --- a/docs/source/library-user-guide/upgrading.md +++ b/docs/source/library-user-guide/upgrading.md @@ -19,6 +19,33 @@ # Upgrade Guides +## DataFusion `48.0.0` + +### Processing `Field` instead of `DataType` for user defined functions + +In order to support metadata handling and extension types, user defined functions are +now switching to traits which use `Field` rather than a `DataType` and nullability. +This gives a single interface to both of these parameters and additionally allows +access to metadata fields, which can be used for extension types. + +To upgrade structs which implement `ScalarUDFImpl`, if you have implemented +`return_type_from_args` you need instead to implement `return_field_from_args`. +If your functions do not need to handle metadata, this should be straightforward +repackaging of the output data into a `Field`. The name you specify on the +field is not important. It will be overwritten during planning. `ReturnInfo` +has been removed, so you will need to remove all references to it. + +`ScalarFunctionArgs` now contains a field called `arg_fields`. You can use this +to access the metadata associated with the columnar values during invocation. 
+ +### Physical Expression return field + +To support the changes to user defined functions processing metadata, the +`PhysicalExpr` trait now must specify a return `Field` based on the input +schema. To upgrade structs which implement `PhysicalExpr` you need to implement +the `return_field` function. There are numerous examples in the `physical-expr` +crate. + ## DataFusion `46.0.0` ### Use `invoke_with_args` instead of `invoke()` and `invoke_batch()` From cddb52afb4bd65d068151752a89278ab46a9b746 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 21 Apr 2025 15:29:05 -0400 Subject: [PATCH 24/25] Whitespace --- docs/source/library-user-guide/upgrading.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/library-user-guide/upgrading.md b/docs/source/library-user-guide/upgrading.md index c4b4c454cc98..35f0577b0c82 100644 --- a/docs/source/library-user-guide/upgrading.md +++ b/docs/source/library-user-guide/upgrading.md @@ -31,7 +31,7 @@ access to metadata fields, which can be used for extension types. To upgrade structs which implement `ScalarUDFImpl`, if you have implemented `return_type_from_args` you need instead to implement `return_field_from_args`. If your functions do not need to handle metadata, this should be straightforward -repackaging of the output data into a `Field`. The name you specify on the +repackaging of the output data into a `Field`. The name you specify on the field is not important. It will be overwritten during planning. `ReturnInfo` has been removed, so you will need to remove all references to it. 
From e4d5846dddefed5bf6934acc63da7f4bd4644868 Mon Sep 17 00:00:00 2001 From: Tim Saucer Date: Mon, 21 Apr 2025 16:38:42 -0400 Subject: [PATCH 25/25] Docstring correction --- datafusion/physical-expr-common/src/sort_expr.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs index ce9550ada7e7..ee575603683a 100644 --- a/datafusion/physical-expr-common/src/sort_expr.rs +++ b/datafusion/physical-expr-common/src/sort_expr.rs @@ -57,7 +57,7 @@ use itertools::Itertools; /// # fn data_type(&self, input_schema: &Schema) -> Result {todo!()} /// # fn nullable(&self, input_schema: &Schema) -> Result {todo!() } /// # fn evaluate(&self, batch: &RecordBatch) -> Result {todo!() } -/// # fn output_field(&self, input_schema: &Schema) -> Result { unimplemented!() } +/// # fn return_field(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn children(&self) -> Vec<&Arc> {todo!()} /// # fn with_new_children(self: Arc, children: Vec>) -> Result> {todo!()} /// # fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { todo!() }