diff --git a/Cargo.lock b/Cargo.lock index d392e88d3459..9ac159486a76 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -456,6 +456,7 @@ checksum = "7450c76ab7c5a6805be3440dc2e2096010da58f7cab301fdc996a4ee3ee74e49" dependencies = [ "bitflags 2.8.0", "serde", + "serde_json", ] [[package]] diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 66a26a18c0dc..a4fa502189f8 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -472,7 +472,7 @@ impl DFSchema { let matches = self.qualified_fields_with_unqualified_name(name); match matches.len() { 0 => Err(unqualified_field_not_found(name, self)), - 1 => Ok((matches[0].0, (matches[0].1))), + 1 => Ok((matches[0].0, matches[0].1)), _ => { // When `matches` size > 1, it doesn't necessarily mean an `ambiguous name` problem. // Because name may generate from Alias/... . It means that it don't own qualifier. @@ -515,14 +515,6 @@ impl DFSchema { Ok(self.field(idx)) } - /// Find the field with the given qualified column - pub fn field_from_column(&self, column: &Column) -> Result<&Field> { - match &column.relation { - Some(r) => self.field_with_qualified_name(r, &column.name), - None => self.field_with_unqualified_name(&column.name), - } - } - /// Find the field with the given qualified column pub fn qualified_field_from_column( &self, @@ -969,16 +961,28 @@ impl Display for DFSchema { /// widely used in the DataFusion codebase. pub trait ExprSchema: std::fmt::Debug { /// Is this column reference nullable? - fn nullable(&self, col: &Column) -> Result; + fn nullable(&self, col: &Column) -> Result { + Ok(self.field_from_column(col)?.is_nullable()) + } /// What is the datatype of this column? - fn data_type(&self, col: &Column) -> Result<&DataType>; + fn data_type(&self, col: &Column) -> Result<&DataType> { + Ok(self.field_from_column(col)?.data_type()) + } /// Returns the column's optional metadata. 
- fn metadata(&self, col: &Column) -> Result<&HashMap>; + fn metadata(&self, col: &Column) -> Result<&HashMap> { + Ok(self.field_from_column(col)?.metadata()) + } /// Return the column's datatype and nullability - fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)>; + fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { + let field = self.field_from_column(col)?; + Ok((field.data_type(), field.is_nullable())) + } + + // Return the column's field + fn field_from_column(&self, col: &Column) -> Result<&Field>; } // Implement `ExprSchema` for `Arc` @@ -998,24 +1002,18 @@ impl + std::fmt::Debug> ExprSchema for P { fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { self.as_ref().data_type_and_nullable(col) } -} - -impl ExprSchema for DFSchema { - fn nullable(&self, col: &Column) -> Result { - Ok(self.field_from_column(col)?.is_nullable()) - } - - fn data_type(&self, col: &Column) -> Result<&DataType> { - Ok(self.field_from_column(col)?.data_type()) - } - fn metadata(&self, col: &Column) -> Result<&HashMap> { - Ok(self.field_from_column(col)?.metadata()) + fn field_from_column(&self, col: &Column) -> Result<&Field> { + self.as_ref().field_from_column(col) } +} - fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { - let field = self.field_from_column(col)?; - Ok((field.data_type(), field.is_nullable())) +impl ExprSchema for DFSchema { + fn field_from_column(&self, col: &Column) -> Result<&Field> { + match &col.relation { + Some(r) => self.field_with_qualified_name(r, &col.name), + None => self.field_with_unqualified_name(&col.name), + } } } diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index edc0d34b539a..3ace3e14ec25 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -95,7 +95,7 @@ extended_tests = [] [dependencies] arrow = { workspace = true } arrow-ipc = { workspace = true } -arrow-schema = { workspace = true } 
+arrow-schema = { workspace = true, features = ["canonical_extension_types"] } async-trait = { workspace = true } bytes = { workspace = true } bzip2 = { version = "0.5.2", optional = true } diff --git a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs index 911d2c0cee05..e00a44188e57 100644 --- a/datafusion/core/tests/physical_optimizer/projection_pushdown.rs +++ b/datafusion/core/tests/physical_optimizer/projection_pushdown.rs @@ -128,7 +128,7 @@ fn test_update_matching_exprs() -> Result<()> { Arc::new(Column::new("b", 1)), )), ], - DataType::Int32, + Field::new("f", DataType::Int32, true), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 2))), @@ -193,7 +193,7 @@ fn test_update_matching_exprs() -> Result<()> { Arc::new(Column::new("b", 1)), )), ], - DataType::Int32, + Field::new("f", DataType::Int32, true), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 3))), @@ -261,7 +261,7 @@ fn test_update_projected_exprs() -> Result<()> { Arc::new(Column::new("b", 1)), )), ], - DataType::Int32, + Field::new("f", DataType::Int32, true), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d", 2))), @@ -326,7 +326,7 @@ fn test_update_projected_exprs() -> Result<()> { Arc::new(Column::new("b_new", 1)), )), ], - DataType::Int32, + Field::new("f", DataType::Int32, true), )), Arc::new(CaseExpr::try_new( Some(Arc::new(Column::new("d_new", 3))), diff --git a/datafusion/core/tests/tpc-ds/49.sql b/datafusion/core/tests/tpc-ds/49.sql index 090e9746c0d8..219877719f22 100644 --- a/datafusion/core/tests/tpc-ds/49.sql +++ b/datafusion/core/tests/tpc-ds/49.sql @@ -110,7 +110,7 @@ select channel, item, return_ratio, return_rank, currency_rank from where sr.sr_return_amt > 10000 and sts.ss_net_profit > 1 - and sts.ss_net_paid > 0 + and sts.ss_net_paid > 0 and sts.ss_quantity > 0 and ss_sold_date_sk = d_date_sk and d_year = 2000 diff --git 
a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 264bd6b66a60..fbef524440e4 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -16,16 +16,19 @@ // under the License. use std::any::Any; +use std::collections::HashMap; use std::hash::{DefaultHasher, Hash, Hasher}; use std::sync::Arc; -use arrow::array::as_string_array; +use arrow::array::{as_string_array, record_batch, Int8Array, UInt64Array}; use arrow::array::{ builder::BooleanBuilder, cast::AsArray, Array, ArrayRef, Float32Array, Float64Array, Int32Array, RecordBatch, StringArray, }; use arrow::compute::kernels::numeric::add; use arrow::datatypes::{DataType, Field, Schema}; +use arrow_schema::extension::{Bool8, CanonicalExtensionType, ExtensionType}; +use arrow_schema::ArrowError; use datafusion::common::test_util::batches_to_string; use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState}; use datafusion::prelude::*; @@ -35,13 +38,13 @@ use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::utils::take_function_args; use datafusion_common::{ assert_batches_eq, assert_batches_sorted_eq, assert_contains, exec_err, not_impl_err, - plan_err, DFSchema, DataFusionError, HashMap, Result, ScalarValue, + plan_err, DFSchema, DataFusionError, Result, ScalarValue, }; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ Accumulator, ColumnarValue, CreateFunction, CreateFunctionBody, LogicalPlanBuilder, - OperateFunctionArg, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, - ScalarUDFImpl, Signature, Volatility, + OperateFunctionArg, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, + Signature, Volatility, }; use datafusion_functions_nested::range::range_udf; use parking_lot::Mutex; @@ -57,7 +60,7 @@ async fn 
csv_query_custom_udf_with_cast() -> Result<()> { let ctx = create_udf_context(); register_aggregate_csv(&ctx).await?; let sql = "SELECT avg(custom_sqrt(c11)) FROM aggregate_test_100"; - let actual = plan_and_collect(&ctx, sql).await.unwrap(); + let actual = plan_and_collect(&ctx, sql).await?; insta::assert_snapshot!(batches_to_string(&actual), @r###" +------------------------------------------+ @@ -76,7 +79,7 @@ async fn csv_query_avg_sqrt() -> Result<()> { register_aggregate_csv(&ctx).await?; // Note it is a different column (c12) than above (c11) let sql = "SELECT avg(custom_sqrt(c12)) FROM aggregate_test_100"; - let actual = plan_and_collect(&ctx, sql).await.unwrap(); + let actual = plan_and_collect(&ctx, sql).await?; insta::assert_snapshot!(batches_to_string(&actual), @r###" +------------------------------------------+ @@ -389,7 +392,7 @@ async fn udaf_as_window_func() -> Result<()> { WindowAggr: windowExpr=[[my_acc(my_table.b) PARTITION BY [my_table.a] ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING]] TableScan: my_table"#; - let dataframe = context.sql(sql).await.unwrap(); + let dataframe = context.sql(sql).await?; assert_eq!(format!("{}", dataframe.logical_plan()), expected); Ok(()) } @@ -399,7 +402,7 @@ async fn case_sensitive_identifiers_user_defined_functions() -> Result<()> { let ctx = SessionContext::new(); let arr = Int32Array::from(vec![1]); let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?; - ctx.register_batch("t", batch).unwrap(); + ctx.register_batch("t", batch)?; let myfunc = Arc::new(|args: &[ColumnarValue]| { let ColumnarValue::Array(array) = &args[0] else { @@ -443,7 +446,7 @@ async fn test_user_defined_functions_with_alias() -> Result<()> { let ctx = SessionContext::new(); let arr = Int32Array::from(vec![1]); let batch = RecordBatch::try_from_iter(vec![("i", Arc::new(arr) as _)])?; - ctx.register_batch("t", batch).unwrap(); + ctx.register_batch("t", batch)?; let myfunc = Arc::new(|args: &[ColumnarValue]| { let 
ColumnarValue::Array(array) = &args[0] else { @@ -803,7 +806,7 @@ impl ScalarUDFImpl for TakeUDF { &self.signature } fn return_type(&self, _args: &[DataType]) -> Result { - not_impl_err!("Not called because the return_type_from_args is implemented") + not_impl_err!("Not called because the return_field_from_args is implemented") } /// This function returns the type of the first or second argument based on @@ -811,9 +814,9 @@ impl ScalarUDFImpl for TakeUDF { /// /// 1. If the third argument is '0', return the type of the first argument /// 2. If the third argument is '1', return the type of the second argument - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - if args.arg_types.len() != 3 { - return plan_err!("Expected 3 arguments, got {}.", args.arg_types.len()); + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + if args.arg_fields.len() != 3 { + return plan_err!("Expected 3 arguments, got {}.", args.arg_fields.len()); } let take_idx = if let Some(take_idx) = args.scalar_arguments.get(2) { @@ -838,8 +841,10 @@ impl ScalarUDFImpl for TakeUDF { ); }; - Ok(ReturnInfo::new_nullable( - args.arg_types[take_idx].to_owned(), + Ok(Field::new( + self.name(), + args.arg_fields[take_idx].data_type().to_owned(), + true, )) } @@ -1367,3 +1372,342 @@ async fn register_alltypes_parquet(ctx: &SessionContext) -> Result<()> { async fn plan_and_collect(ctx: &SessionContext, sql: &str) -> Result> { ctx.sql(sql).await?.collect().await } + +#[derive(Debug)] +struct MetadataBasedUdf { + name: String, + signature: Signature, + metadata: HashMap, +} + +impl MetadataBasedUdf { + fn new(metadata: HashMap) -> Self { + // The name we return must be unique. Otherwise we will not call distinct + // instances of this UDF. This is a small hack for the unit tests to get unique + // names, but you could do something more elegant with the metadata. 
+ let name = format!("metadata_based_udf_{}", metadata.len()); + Self { + name, + signature: Signature::exact(vec![DataType::UInt64], Volatility::Immutable), + metadata, + } + } +} + +impl ScalarUDFImpl for MetadataBasedUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _args: &[DataType]) -> Result { + unimplemented!( + "this should never be called since return_field_from_args is implemented" + ); + } + + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { + Ok(Field::new(self.name(), DataType::UInt64, true) + .with_metadata(self.metadata.clone())) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + assert_eq!(args.arg_fields.len(), 1); + let should_double = args.arg_fields[0] + .metadata() + .get("modify_values") + .map(|v| v == "double_output") + .unwrap_or(false); + let mulitplier = if should_double { 2 } else { 1 }; + + match &args.args[0] { + ColumnarValue::Array(array) => { + let array_values: Vec<_> = array + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| v.map(|x| x * mulitplier)) + .collect(); + let array_ref = Arc::new(UInt64Array::from(array_values)) as ArrayRef; + Ok(ColumnarValue::Array(array_ref)) + } + ColumnarValue::Scalar(value) => { + let ScalarValue::UInt64(value) = value else { + return exec_err!("incorrect data type"); + }; + + Ok(ColumnarValue::Scalar(ScalarValue::UInt64( + value.map(|v| v * mulitplier), + ))) + } + } + } + + fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + self.name == other.name() + } +} + +#[tokio::test] +async fn test_metadata_based_udf() -> Result<()> { + let data_array = Arc::new(UInt64Array::from(vec![0, 5, 10, 15, 20])) as ArrayRef; + let schema = Arc::new(Schema::new(vec![ + Field::new("no_metadata", DataType::UInt64, true), + Field::new("with_metadata", DataType::UInt64, true).with_metadata( + [("modify_values".to_string(), 
"double_output".to_string())] + .into_iter() + .collect(), + ), + ])); + let batch = RecordBatch::try_new( + schema, + vec![Arc::clone(&data_array), Arc::clone(&data_array)], + )?; + + let ctx = SessionContext::new(); + ctx.register_batch("t", batch)?; + let t = ctx.table("t").await?; + let no_output_meta_udf = ScalarUDF::from(MetadataBasedUdf::new(HashMap::new())); + let with_output_meta_udf = ScalarUDF::from(MetadataBasedUdf::new( + [("output_metatype".to_string(), "custom_value".to_string())] + .into_iter() + .collect(), + )); + + let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?) + .project(vec![ + no_output_meta_udf + .call(vec![col("no_metadata")]) + .alias("meta_no_in_no_out"), + no_output_meta_udf + .call(vec![col("with_metadata")]) + .alias("meta_with_in_no_out"), + with_output_meta_udf + .call(vec![col("no_metadata")]) + .alias("meta_no_in_with_out"), + with_output_meta_udf + .call(vec![col("with_metadata")]) + .alias("meta_with_in_with_out"), + ])? + .build()?; + + let actual = DataFrame::new(ctx.state(), plan).collect().await?; + + // To test for output metadata handling, we set the expected values on the result + // To test for input metadata handling, we check the numbers returned + let mut output_meta = HashMap::new(); + let _ = output_meta.insert("output_metatype".to_string(), "custom_value".to_string()); + let expected_schema = Schema::new(vec![ + Field::new("meta_no_in_no_out", DataType::UInt64, true), + Field::new("meta_with_in_no_out", DataType::UInt64, true), + Field::new("meta_no_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + Field::new("meta_with_in_with_out", DataType::UInt64, true) + .with_metadata(output_meta.clone()), + ]); + + let expected = record_batch!( + ("meta_no_in_no_out", UInt64, [0, 5, 10, 15, 20]), + ("meta_with_in_no_out", UInt64, [0, 10, 20, 30, 40]), + ("meta_no_in_with_out", UInt64, [0, 5, 10, 15, 20]), + ("meta_with_in_with_out", UInt64, [0, 10, 20, 30, 40]) + )? 
+ .with_schema(Arc::new(expected_schema))?; + + assert_eq!(expected, actual[0]); + + ctx.deregister_table("t")?; + Ok(()) +} + +/// This UDF is to test extension handling, both on the input and output +/// sides. For the input, we will handle the data differently if there is +/// the canonical extension type Bool8. For the output we will add a +/// user defined extension type. +#[derive(Debug)] +struct ExtensionBasedUdf { + name: String, + signature: Signature, +} + +impl Default for ExtensionBasedUdf { + fn default() -> Self { + Self { + name: "canonical_extension_udf".to_string(), + signature: Signature::exact(vec![DataType::Int8], Volatility::Immutable), + } + } +} +impl ScalarUDFImpl for ExtensionBasedUdf { + fn as_any(&self) -> &dyn Any { + self + } + + fn name(&self) -> &str { + &self.name + } + + fn signature(&self) -> &Signature { + &self.signature + } + + fn return_type(&self, _args: &[DataType]) -> Result { + Ok(DataType::Utf8) + } + + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { + Ok(Field::new("canonical_extension_udf", DataType::Utf8, true) + .with_extension_type(MyUserExtentionType {})) + } + + fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { + assert_eq!(args.arg_fields.len(), 1); + let input_field = args.arg_fields[0]; + + let output_as_bool = matches!( + CanonicalExtensionType::try_from(input_field), + Ok(CanonicalExtensionType::Bool8(_)) + ); + + // If we have the extension type set, we are outputting a boolean value. + // Otherwise we output a string representation of the numeric value. 
+ fn print_value(v: Option, as_bool: bool) -> Option { + v.map(|x| match as_bool { + true => format!("{}", x != 0), + false => format!("{x}"), + }) + } + + match &args.args[0] { + ColumnarValue::Array(array) => { + let array_values: Vec<_> = array + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| print_value(v, output_as_bool)) + .collect(); + let array_ref = Arc::new(StringArray::from(array_values)) as ArrayRef; + Ok(ColumnarValue::Array(array_ref)) + } + ColumnarValue::Scalar(value) => { + let ScalarValue::Int8(value) = value else { + return exec_err!("incorrect data type"); + }; + + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(print_value( + *value, + output_as_bool, + )))) + } + } + } + + fn equals(&self, other: &dyn ScalarUDFImpl) -> bool { + self.name == other.name() + } +} + +struct MyUserExtentionType {} + +impl ExtensionType for MyUserExtentionType { + const NAME: &'static str = "my_user_extention_type"; + type Metadata = (); + + fn metadata(&self) -> &Self::Metadata { + &() + } + + fn serialize_metadata(&self) -> Option { + None + } + + fn deserialize_metadata( + _metadata: Option<&str>, + ) -> std::result::Result { + Ok(()) + } + + fn supports_data_type( + &self, + data_type: &DataType, + ) -> std::result::Result<(), ArrowError> { + if let DataType::Utf8 = data_type { + Ok(()) + } else { + Err(ArrowError::InvalidArgumentError( + "only utf8 supported".to_string(), + )) + } + } + + fn try_new( + _data_type: &DataType, + _metadata: Self::Metadata, + ) -> std::result::Result { + Ok(Self {}) + } +} + +#[tokio::test] +async fn test_extension_based_udf() -> Result<()> { + let data_array = Arc::new(Int8Array::from(vec![0, 0, 10, 20])) as ArrayRef; + let schema = Arc::new(Schema::new(vec![ + Field::new("no_extension", DataType::Int8, true), + Field::new("with_extension", DataType::Int8, true).with_extension_type(Bool8), + ])); + let batch = RecordBatch::try_new( + schema, + vec![Arc::clone(&data_array), Arc::clone(&data_array)], + )?; + + let ctx = 
SessionContext::new(); + ctx.register_batch("t", batch)?; + let t = ctx.table("t").await?; + let extension_based_udf = ScalarUDF::from(ExtensionBasedUdf::default()); + + let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?) + .project(vec![ + extension_based_udf + .call(vec![col("no_extension")]) + .alias("without_bool8_extension"), + extension_based_udf + .call(vec![col("with_extension")]) + .alias("with_bool8_extension"), + ])? + .build()?; + + let actual = DataFrame::new(ctx.state(), plan).collect().await?; + + // To test for output extension handling, we set the expected values on the result + // To test for input extensions handling, we check the strings returned + let expected_schema = Schema::new(vec![ + Field::new("without_bool8_extension", DataType::Utf8, true) + .with_extension_type(MyUserExtentionType {}), + Field::new("with_bool8_extension", DataType::Utf8, true) + .with_extension_type(MyUserExtentionType {}), + ]); + + let expected = record_batch!( + ("without_bool8_extension", Utf8, ["0", "0", "10", "20"]), + ( + "with_bool8_extension", + Utf8, + ["false", "false", "true", "true"] + ) + )? 
+ .with_schema(Arc::new(expected_schema))?; + + assert_eq!(expected, actual[0]); + + ctx.deregister_table("t")?; + Ok(()) +} diff --git a/datafusion/expr/src/expr_schema.rs b/datafusion/expr/src/expr_schema.rs index a349c83a4934..3786180e2cfa 100644 --- a/datafusion/expr/src/expr_schema.rs +++ b/datafusion/expr/src/expr_schema.rs @@ -24,7 +24,7 @@ use crate::expr::{ use crate::type_coercion::functions::{ data_types_with_aggregate_udf, data_types_with_scalar_udf, data_types_with_window_udf, }; -use crate::udf::ReturnTypeArgs; +use crate::udf::ReturnFieldArgs; use crate::{utils, LogicalPlan, Projection, Subquery, WindowFunctionDefinition}; use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field}; @@ -341,21 +341,8 @@ impl ExprSchemable for Expr { } fn metadata(&self, schema: &dyn ExprSchema) -> Result> { - match self { - Expr::Column(c) => Ok(schema.metadata(c)?.clone()), - Expr::Alias(Alias { expr, metadata, .. }) => { - let mut ret = expr.metadata(schema)?; - if let Some(metadata) = metadata { - if !metadata.is_empty() { - ret.extend(metadata.clone()); - return Ok(ret); - } - } - Ok(ret) - } - Expr::Cast(Cast { expr, .. }) => expr.metadata(schema), - _ => Ok(HashMap::new()), - } + self.to_field(schema) + .map(|(_, field)| field.metadata().clone()) } /// Returns the datatype and nullability of the expression based on [ExprSchema]. @@ -372,23 +359,62 @@ impl ExprSchemable for Expr { &self, schema: &dyn ExprSchema, ) -> Result<(DataType, bool)> { - match self { - Expr::Alias(Alias { expr, name, .. }) => match &**expr { - Expr::Placeholder(Placeholder { data_type, .. 
}) => match &data_type { - None => schema - .data_type_and_nullable(&Column::from_name(name)) - .map(|(d, n)| (d.clone(), n)), - Some(dt) => Ok((dt.clone(), expr.nullable(schema)?)), - }, - _ => expr.data_type_and_nullable(schema), - }, - Expr::Negative(expr) => expr.data_type_and_nullable(schema), - Expr::Column(c) => schema - .data_type_and_nullable(c) - .map(|(d, n)| (d.clone(), n)), - Expr::OuterReferenceColumn(ty, _) => Ok((ty.clone(), true)), - Expr::ScalarVariable(ty, _) => Ok((ty.clone(), true)), - Expr::Literal(l) => Ok((l.data_type(), l.is_null())), + let field = self.to_field(schema)?.1; + + Ok((field.data_type().clone(), field.is_nullable())) + } + + /// Returns a [arrow::datatypes::Field] compatible with this expression. + /// + /// So for example, a projected expression `col(c1) + col(c2)` is + /// placed in an output field **named** col("c1 + c2") + fn to_field( + &self, + schema: &dyn ExprSchema, + ) -> Result<(Option, Arc)> { + let (relation, schema_name) = self.qualified_name(); + #[allow(deprecated)] + let field = match self { + Expr::Alias(Alias { + expr, + name, + metadata, + .. + }) => { + let field = match &**expr { + Expr::Placeholder(Placeholder { data_type, .. 
}) => { + match &data_type { + None => schema + .data_type_and_nullable(&Column::from_name(name)) + .map(|(d, n)| Field::new(&schema_name, d.clone(), n)), + Some(dt) => Ok(Field::new( + &schema_name, + dt.clone(), + expr.nullable(schema)?, + )), + } + } + _ => expr.to_field(schema).map(|(_, f)| f.as_ref().clone()), + }?; + + let mut combined_metadata = expr.metadata(schema)?; + if let Some(metadata) = metadata { + if !metadata.is_empty() { + combined_metadata.extend(metadata.clone()); + } + } + + Ok(field.with_metadata(combined_metadata)) + } + Expr::Negative(expr) => { + expr.to_field(schema).map(|(_, f)| f.as_ref().clone()) + } + Expr::Column(c) => schema.field_from_column(c).cloned(), + Expr::OuterReferenceColumn(ty, _) => { + Ok(Field::new(&schema_name, ty.clone(), true)) + } + Expr::ScalarVariable(ty, _) => Ok(Field::new(&schema_name, ty.clone(), true)), + Expr::Literal(l) => Ok(Field::new(&schema_name, l.data_type(), l.is_null())), Expr::IsNull(_) | Expr::IsNotNull(_) | Expr::IsTrue(_) @@ -397,11 +423,12 @@ impl ExprSchemable for Expr { | Expr::IsNotTrue(_) | Expr::IsNotFalse(_) | Expr::IsNotUnknown(_) - | Expr::Exists { .. } => Ok((DataType::Boolean, false)), - Expr::ScalarSubquery(subquery) => Ok(( - subquery.subquery.schema().field(0).data_type().clone(), - subquery.subquery.schema().field(0).is_nullable(), - )), + | Expr::Exists { .. 
} => { + Ok(Field::new(&schema_name, DataType::Boolean, false)) + } + Expr::ScalarSubquery(subquery) => { + Ok(subquery.subquery.schema().field(0).clone()) + } Expr::BinaryExpr(BinaryExpr { ref left, ref right, @@ -412,17 +439,26 @@ impl ExprSchemable for Expr { let mut coercer = BinaryTypeCoercer::new(&lhs_type, op, &rhs_type); coercer.set_lhs_spans(left.spans().cloned().unwrap_or_default()); coercer.set_rhs_spans(right.spans().cloned().unwrap_or_default()); - Ok((coercer.get_result_type()?, lhs_nullable || rhs_nullable)) + Ok(Field::new( + &schema_name, + coercer.get_result_type()?, + lhs_nullable || rhs_nullable, + )) } Expr::WindowFunction(window_function) => { - self.data_type_and_nullable_with_window_function(schema, window_function) + let (dt, nullable) = self.data_type_and_nullable_with_window_function( + schema, + window_function, + )?; + Ok(Field::new(&schema_name, dt, nullable)) } Expr::ScalarFunction(ScalarFunction { func, args }) => { - let (arg_types, nullables): (Vec, Vec) = args + let (arg_types, fields): (Vec, Vec>) = args .iter() - .map(|e| e.data_type_and_nullable(schema)) + .map(|e| e.to_field(schema).map(|(_, f)| f)) .collect::>>()? 
.into_iter() + .map(|f| (f.data_type().clone(), f)) .unzip(); // Verify that function is invoked with correct number and type of arguments as defined in `TypeSignature` let new_data_types = data_types_with_scalar_udf(&arg_types, func) @@ -440,6 +476,11 @@ impl ExprSchemable for Expr { ) ) })?; + let new_fields = fields + .into_iter() + .zip(new_data_types) + .map(|(f, d)| f.as_ref().clone().with_data_type(d)) + .collect::>(); let arguments = args .iter() @@ -448,34 +489,37 @@ impl ExprSchemable for Expr { _ => None, }) .collect::>(); - let args = ReturnTypeArgs { - arg_types: &new_data_types, + let args = ReturnFieldArgs { + arg_fields: &new_fields, scalar_arguments: &arguments, - nullables: &nullables, }; - let (return_type, nullable) = - func.return_type_from_args(args)?.into_parts(); - Ok((return_type, nullable)) + func.return_field_from_args(args) } - _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), - } - } + // _ => Ok((self.get_type(schema)?, self.nullable(schema)?)), + Expr::Cast(Cast { expr, data_type }) => expr + .to_field(schema) + .map(|(_, f)| f.as_ref().clone().with_data_type(data_type.clone())), + Expr::Like(_) + | Expr::SimilarTo(_) + | Expr::Not(_) + | Expr::Between(_) + | Expr::Case(_) + | Expr::TryCast(_) + | Expr::AggregateFunction(_) + | Expr::InList(_) + | Expr::InSubquery(_) + | Expr::Wildcard { .. } + | Expr::GroupingSet(_) + | Expr::Placeholder(_) + | Expr::Unnest(_) => Ok(Field::new( + &schema_name, + self.get_type(schema)?, + self.nullable(schema)?, + )), + }?; - /// Returns a [arrow::datatypes::Field] compatible with this expression. 
- /// - /// So for example, a projected expression `col(c1) + col(c2)` is - /// placed in an output field **named** col("c1 + c2") - fn to_field( - &self, - input_schema: &dyn ExprSchema, - ) -> Result<(Option, Arc)> { - let (relation, schema_name) = self.qualified_name(); - let (data_type, nullable) = self.data_type_and_nullable(input_schema)?; - let field = Field::new(schema_name, data_type, nullable) - .with_metadata(self.metadata(input_schema)?) - .into(); - Ok((relation, field)) + Ok((relation, Arc::new(field.with_name(schema_name)))) } /// Wraps this expression in a cast to a target [arrow::datatypes::DataType]. @@ -762,29 +806,25 @@ mod tests { #[derive(Debug)] struct MockExprSchema { - nullable: bool, - data_type: DataType, + field: Field, error_on_nullable: bool, - metadata: HashMap, } impl MockExprSchema { fn new() -> Self { Self { - nullable: false, - data_type: DataType::Null, + field: Field::new("mock_field", DataType::Null, false), error_on_nullable: false, - metadata: HashMap::new(), } } fn with_nullable(mut self, nullable: bool) -> Self { - self.nullable = nullable; + self.field = self.field.with_nullable(nullable); self } fn with_data_type(mut self, data_type: DataType) -> Self { - self.data_type = data_type; + self.field = self.field.with_data_type(data_type); self } @@ -794,7 +834,7 @@ mod tests { } fn with_metadata(mut self, metadata: HashMap) -> Self { - self.metadata = metadata; + self.field = self.field.with_metadata(metadata); self } } @@ -804,20 +844,12 @@ mod tests { if self.error_on_nullable { internal_err!("nullable error") } else { - Ok(self.nullable) + Ok(self.field.is_nullable()) } } - fn data_type(&self, _col: &Column) -> Result<&DataType> { - Ok(&self.data_type) - } - - fn metadata(&self, _col: &Column) -> Result<&HashMap> { - Ok(&self.metadata) - } - - fn data_type_and_nullable(&self, col: &Column) -> Result<(&DataType, bool)> { - Ok((self.data_type(col)?, self.nullable(col)?)) + fn field_from_column(&self, _col: &Column) -> 
Result<&Field> { + Ok(&self.field) } } } diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index d3cc881af361..48931d6525af 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -104,8 +104,7 @@ pub use udaf::{ SetMonotonicity, StatisticsArgs, }; pub use udf::{ - scalar_doc_sections, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, - ScalarUDFImpl, + scalar_doc_sections, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, }; pub use udwf::{window_doc_sections, ReversedUDWF, WindowUDF, WindowUDFImpl}; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 05a43444d4ae..706770d5a66f 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -1122,12 +1122,24 @@ impl LogicalPlanBuilder { for (l, r) in &on { if self.plan.schema().has_column(l) && right.schema().has_column(r) - && can_hash(self.plan.schema().field_from_column(l)?.data_type()) + && can_hash( + datafusion_common::ExprSchema::field_from_column( + self.plan.schema(), + l, + )? + .data_type(), + ) { join_on.push((Expr::Column(l.clone()), Expr::Column(r.clone()))); } else if self.plan.schema().has_column(l) && right.schema().has_column(r) - && can_hash(self.plan.schema().field_from_column(r)?.data_type()) + && can_hash( + datafusion_common::ExprSchema::field_from_column( + self.plan.schema(), + r, + )? 
+ .data_type(), + ) { join_on.push((Expr::Column(r.clone()), Expr::Column(l.clone()))); } else { diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 9b2400774a3d..c1b74fedcc32 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -21,7 +21,7 @@ use crate::expr::schema_name_from_exprs_comma_separated_without_space; use crate::simplify::{ExprSimplifyResult, SimplifyInfo}; use crate::sort_properties::{ExprProperties, SortProperties}; use crate::{ColumnarValue, Documentation, Expr, Signature}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use datafusion_common::{not_impl_err, ExprSchema, Result, ScalarValue}; use datafusion_expr_common::interval_arithmetic::Interval; use std::any::Any; @@ -170,7 +170,7 @@ impl ScalarUDF { /// /// # Notes /// - /// If a function implement [`ScalarUDFImpl::return_type_from_args`], + /// If a function implement [`ScalarUDFImpl::return_field_from_args`], /// its [`ScalarUDFImpl::return_type`] should raise an error. /// /// See [`ScalarUDFImpl::return_type`] for more details. @@ -180,9 +180,9 @@ impl ScalarUDF { /// Return the datatype this function returns given the input argument types. /// - /// See [`ScalarUDFImpl::return_type_from_args`] for more details. - pub fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - self.inner.return_type_from_args(args) + /// See [`ScalarUDFImpl::return_field_from_args`] for more details. + pub fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + self.inner.return_field_from_args(args) } /// Do the function rewrite @@ -293,14 +293,17 @@ where /// Arguments passed to [`ScalarUDFImpl::invoke_with_args`] when invoking a /// scalar function. 
-pub struct ScalarFunctionArgs<'a> { +pub struct ScalarFunctionArgs<'a, 'b> { /// The evaluated arguments to the function pub args: Vec, + /// Field associated with each arg, if it exists + pub arg_fields: Vec<&'a Field>, /// The number of rows in record batch being evaluated pub number_rows: usize, - /// The return type of the scalar function returned (from `return_type` or `return_type_from_args`) - /// when creating the physical expression from the logical expression - pub return_type: &'a DataType, + /// The return field of the scalar function returned (from `return_type` + /// or `return_field_from_args`) when creating the physical expression + /// from the logical expression + pub return_field: &'b Field, } /// Information about arguments passed to the function @@ -309,11 +312,11 @@ pub struct ScalarFunctionArgs<'a> { /// such as the type of the arguments, any scalar arguments and if the /// arguments can (ever) be null /// -/// See [`ScalarUDFImpl::return_type_from_args`] for more information +/// See [`ScalarUDFImpl::return_field_from_args`] for more information #[derive(Debug)] -pub struct ReturnTypeArgs<'a> { +pub struct ReturnFieldArgs<'a> { /// The data types of the arguments to the function - pub arg_types: &'a [DataType], + pub arg_fields: &'a [Field], /// Is argument `i` to the function a scalar (constant) /// /// If argument `i` is not a scalar, it will be None @@ -321,52 +324,6 @@ pub struct ReturnTypeArgs<'a> { /// For example, if a function is called like `my_function(column_a, 5)` /// this field will be `[None, Some(ScalarValue::Int32(Some(5)))]` pub scalar_arguments: &'a [Option<&'a ScalarValue>], - /// Can argument `i` (ever) null? - pub nullables: &'a [bool], -} - -/// Return metadata for this function. 
-/// -/// See [`ScalarUDFImpl::return_type_from_args`] for more information -#[derive(Debug)] -pub struct ReturnInfo { - return_type: DataType, - nullable: bool, -} - -impl ReturnInfo { - pub fn new(return_type: DataType, nullable: bool) -> Self { - Self { - return_type, - nullable, - } - } - - pub fn new_nullable(return_type: DataType) -> Self { - Self { - return_type, - nullable: true, - } - } - - pub fn new_non_nullable(return_type: DataType) -> Self { - Self { - return_type, - nullable: false, - } - } - - pub fn return_type(&self) -> &DataType { - &self.return_type - } - - pub fn nullable(&self) -> bool { - self.nullable - } - - pub fn into_parts(self) -> (DataType, bool) { - (self.return_type, self.nullable) - } } /// Trait for implementing user defined scalar functions. @@ -480,7 +437,7 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// /// # Notes /// - /// If you provide an implementation for [`Self::return_type_from_args`], + /// If you provide an implementation for [`Self::return_field_from_args`], /// DataFusion will not call `return_type` (this function). In such cases /// is recommended to return [`DataFusionError::Internal`]. /// @@ -518,14 +475,20 @@ pub trait ScalarUDFImpl: Debug + Send + Sync { /// This function **must** consistently return the same type for the same /// logical input even if the input is simplified (e.g. it must return the same /// value for `('foo' | 'bar')` as it does for ('foobar'). - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - let return_type = self.return_type(args.arg_types)?; - Ok(ReturnInfo::new_nullable(return_type)) + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let data_types = args + .arg_fields + .iter() + .map(|f| f.data_type()) + .cloned() + .collect::>(); + let return_type = self.return_type(&data_types)?; + Ok(Field::new(self.name(), return_type, true)) } #[deprecated( since = "45.0.0", - note = "Use `return_type_from_args` instead. 
if you use `is_nullable` that returns non-nullable with `return_type`, you would need to switch to `return_type_from_args`, you might have error" + note = "Use `return_field_from_args` instead. if you use `is_nullable` that returns non-nullable with `return_type`, you would need to switch to `return_field_from_args`, you might have error" )] fn is_nullable(&self, _args: &[Expr], _schema: &dyn ExprSchema) -> bool { true @@ -765,18 +728,18 @@ impl ScalarUDFImpl for AliasedScalarUDFImpl { self.inner.return_type(arg_types) } - fn aliases(&self) -> &[String] { - &self.aliases - } - - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - self.inner.return_type_from_args(args) + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + self.inner.return_field_from_args(args) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { self.inner.invoke_with_args(args) } + fn aliases(&self) -> &[String] { + &self.aliases + } + fn simplify( &self, args: Vec, diff --git a/datafusion/ffi/src/udf/mod.rs b/datafusion/ffi/src/udf/mod.rs index 706b9fabedcb..c35a53205eb4 100644 --- a/datafusion/ffi/src/udf/mod.rs +++ b/datafusion/ffi/src/udf/mod.rs @@ -15,23 +15,26 @@ // specific language governing permissions and limitations // under the License. 
-use std::{ffi::c_void, sync::Arc}; - +use crate::{ + arrow_wrappers::{WrappedArray, WrappedSchema}, + df_result, rresult, rresult_return, + util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, + volatility::FFI_Volatility, +}; use abi_stable::{ std_types::{RResult, RString, RVec}, StableAbi, }; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::{ array::ArrayRef, error::ArrowError, ffi::{from_ffi, to_ffi, FFI_ArrowSchema}, }; +use datafusion::logical_expr::ReturnFieldArgs; use datafusion::{ error::DataFusionError, - logical_expr::{ - type_coercion::functions::data_types_with_scalar_udf, ReturnInfo, ReturnTypeArgs, - }, + logical_expr::type_coercion::functions::data_types_with_scalar_udf, }; use datafusion::{ error::Result, @@ -39,19 +42,11 @@ use datafusion::{ ColumnarValue, ScalarFunctionArgs, ScalarUDF, ScalarUDFImpl, Signature, }, }; -use return_info::FFI_ReturnInfo; use return_type_args::{ - FFI_ReturnTypeArgs, ForeignReturnTypeArgs, ForeignReturnTypeArgsOwned, -}; - -use crate::{ - arrow_wrappers::{WrappedArray, WrappedSchema}, - df_result, rresult, rresult_return, - util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, - volatility::FFI_Volatility, + FFI_ReturnFieldArgs, ForeignReturnFieldArgs, ForeignReturnFieldArgsOwned, }; +use std::{ffi::c_void, sync::Arc}; -pub mod return_info; pub mod return_type_args; /// A stable struct for sharing a [`ScalarUDF`] across FFI boundaries. @@ -77,19 +72,21 @@ pub struct FFI_ScalarUDF { /// Determines the return info of the underlying [`ScalarUDF`]. Either this /// or return_type may be implemented on a UDF. - pub return_type_from_args: unsafe extern "C" fn( + pub return_field_from_args: unsafe extern "C" fn( udf: &Self, - args: FFI_ReturnTypeArgs, + args: FFI_ReturnFieldArgs, ) - -> RResult, + -> RResult, /// Execute the underlying [`ScalarUDF`] and return the result as a `FFI_ArrowArray` /// within an AbiStable wrapper. 
+ #[allow(clippy::type_complexity)] pub invoke_with_args: unsafe extern "C" fn( udf: &Self, args: RVec, + arg_fields: RVec, num_rows: usize, - return_type: WrappedSchema, + return_field: WrappedSchema, ) -> RResult, /// See [`ScalarUDFImpl`] for details on short_circuits @@ -140,19 +137,20 @@ unsafe extern "C" fn return_type_fn_wrapper( rresult!(return_type) } -unsafe extern "C" fn return_type_from_args_fn_wrapper( +unsafe extern "C" fn return_field_from_args_fn_wrapper( udf: &FFI_ScalarUDF, - args: FFI_ReturnTypeArgs, -) -> RResult { + args: FFI_ReturnFieldArgs, +) -> RResult { let private_data = udf.private_data as *const ScalarUDFPrivateData; let udf = &(*private_data).udf; - let args: ForeignReturnTypeArgsOwned = rresult_return!((&args).try_into()); - let args_ref: ForeignReturnTypeArgs = (&args).into(); + let args: ForeignReturnFieldArgsOwned = rresult_return!((&args).try_into()); + let args_ref: ForeignReturnFieldArgs = (&args).into(); let return_type = udf - .return_type_from_args((&args_ref).into()) - .and_then(FFI_ReturnInfo::try_from); + .return_field_from_args((&args_ref).into()) + .and_then(|f| FFI_ArrowSchema::try_from(f).map_err(DataFusionError::from)) + .map(WrappedSchema); rresult!(return_type) } @@ -174,8 +172,9 @@ unsafe extern "C" fn coerce_types_fn_wrapper( unsafe extern "C" fn invoke_with_args_fn_wrapper( udf: &FFI_ScalarUDF, args: RVec, + arg_fields: RVec, number_rows: usize, - return_type: WrappedSchema, + return_field: WrappedSchema, ) -> RResult { let private_data = udf.private_data as *const ScalarUDFPrivateData; let udf = &(*private_data).udf; @@ -189,12 +188,20 @@ unsafe extern "C" fn invoke_with_args_fn_wrapper( .collect::>(); let args = rresult_return!(args); - let return_type = rresult_return!(DataType::try_from(&return_type.0)); + let return_field = rresult_return!(Field::try_from(&return_field.0)); + + let arg_fields_owned = arg_fields + .into_iter() + .map(|wrapped_field| 
(&wrapped_field.0).try_into().map_err(DataFusionError::from)) + .collect::>>(); + let arg_fields_owned = rresult_return!(arg_fields_owned); + let arg_fields = arg_fields_owned.iter().collect::>(); let args = ScalarFunctionArgs { args, + arg_fields, number_rows, - return_type: &return_type, + return_field: &return_field, }; let result = rresult_return!(udf @@ -243,7 +250,7 @@ impl From> for FFI_ScalarUDF { short_circuits, invoke_with_args: invoke_with_args_fn_wrapper, return_type: return_type_fn_wrapper, - return_type_from_args: return_type_from_args_fn_wrapper, + return_field_from_args: return_field_from_args_fn_wrapper, coerce_types: coerce_types_fn_wrapper, clone: clone_fn_wrapper, release: release_fn_wrapper, @@ -316,21 +323,22 @@ impl ScalarUDFImpl for ForeignScalarUDF { result.and_then(|r| (&r.0).try_into().map_err(DataFusionError::from)) } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - let args: FFI_ReturnTypeArgs = args.try_into()?; + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let args: FFI_ReturnFieldArgs = args.try_into()?; - let result = unsafe { (self.udf.return_type_from_args)(&self.udf, args) }; + let result = unsafe { (self.udf.return_field_from_args)(&self.udf, args) }; let result = df_result!(result); - result.and_then(|r| r.try_into()) + result.and_then(|r| (&r.0).try_into().map_err(DataFusionError::from)) } fn invoke_with_args(&self, invoke_args: ScalarFunctionArgs) -> Result { let ScalarFunctionArgs { args, + arg_fields, number_rows, - return_type, + return_field, } = invoke_args; let args = args @@ -347,10 +355,26 @@ impl ScalarUDFImpl for ForeignScalarUDF { .collect::, ArrowError>>()? 
.into(); - let return_type = WrappedSchema(FFI_ArrowSchema::try_from(return_type)?); + let arg_fields_wrapped = arg_fields + .iter() + .map(|field| FFI_ArrowSchema::try_from(*field)) + .collect::, ArrowError>>()?; + + let arg_fields = arg_fields_wrapped + .into_iter() + .map(WrappedSchema) + .collect::>(); + + let return_field = WrappedSchema(FFI_ArrowSchema::try_from(return_field)?); let result = unsafe { - (self.udf.invoke_with_args)(&self.udf, args, number_rows, return_type) + (self.udf.invoke_with_args)( + &self.udf, + args, + arg_fields, + number_rows, + return_field, + ) }; let result = df_result!(result)?; @@ -389,7 +413,7 @@ mod tests { let foreign_udf: ForeignScalarUDF = (&local_udf).try_into()?; - assert!(original_udf.name() == foreign_udf.name()); + assert_eq!(original_udf.name(), foreign_udf.name()); Ok(()) } diff --git a/datafusion/ffi/src/udf/return_info.rs b/datafusion/ffi/src/udf/return_info.rs deleted file mode 100644 index cf76ddd1db76..000000000000 --- a/datafusion/ffi/src/udf/return_info.rs +++ /dev/null @@ -1,53 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -use abi_stable::StableAbi; -use arrow::{datatypes::DataType, ffi::FFI_ArrowSchema}; -use datafusion::{error::DataFusionError, logical_expr::ReturnInfo}; - -use crate::arrow_wrappers::WrappedSchema; - -/// A stable struct for sharing a [`ReturnInfo`] across FFI boundaries. -#[repr(C)] -#[derive(Debug, StableAbi)] -#[allow(non_camel_case_types)] -pub struct FFI_ReturnInfo { - return_type: WrappedSchema, - nullable: bool, -} - -impl TryFrom for FFI_ReturnInfo { - type Error = DataFusionError; - - fn try_from(value: ReturnInfo) -> Result { - let return_type = WrappedSchema(FFI_ArrowSchema::try_from(value.return_type())?); - Ok(Self { - return_type, - nullable: value.nullable(), - }) - } -} - -impl TryFrom for ReturnInfo { - type Error = DataFusionError; - - fn try_from(value: FFI_ReturnInfo) -> Result { - let return_type = DataType::try_from(&value.return_type.0)?; - - Ok(ReturnInfo::new(return_type, value.nullable)) - } -} diff --git a/datafusion/ffi/src/udf/return_type_args.rs b/datafusion/ffi/src/udf/return_type_args.rs index a0897630e2ea..40e577591c34 100644 --- a/datafusion/ffi/src/udf/return_type_args.rs +++ b/datafusion/ffi/src/udf/return_type_args.rs @@ -19,33 +19,30 @@ use abi_stable::{ std_types::{ROption, RVec}, StableAbi, }; -use arrow::datatypes::DataType; +use arrow::datatypes::Field; use datafusion::{ - common::exec_datafusion_err, error::DataFusionError, logical_expr::ReturnTypeArgs, + common::exec_datafusion_err, error::DataFusionError, logical_expr::ReturnFieldArgs, scalar::ScalarValue, }; -use crate::{ - arrow_wrappers::WrappedSchema, - util::{rvec_wrapped_to_vec_datatype, vec_datatype_to_rvec_wrapped}, -}; +use crate::arrow_wrappers::WrappedSchema; +use crate::util::{rvec_wrapped_to_vec_field, vec_field_to_rvec_wrapped}; use prost::Message; -/// A stable struct for sharing a [`ReturnTypeArgs`] across FFI boundaries. +/// A stable struct for sharing a [`ReturnFieldArgs`] across FFI boundaries. 
#[repr(C)] #[derive(Debug, StableAbi)] #[allow(non_camel_case_types)] -pub struct FFI_ReturnTypeArgs { - arg_types: RVec, +pub struct FFI_ReturnFieldArgs { + arg_fields: RVec, scalar_arguments: RVec>>, - nullables: RVec, } -impl TryFrom> for FFI_ReturnTypeArgs { +impl TryFrom> for FFI_ReturnFieldArgs { type Error = DataFusionError; - fn try_from(value: ReturnTypeArgs) -> Result { - let arg_types = vec_datatype_to_rvec_wrapped(value.arg_types)?; + fn try_from(value: ReturnFieldArgs) -> Result { + let arg_fields = vec_field_to_rvec_wrapped(value.arg_fields)?; let scalar_arguments: Result, Self::Error> = value .scalar_arguments .iter() @@ -62,35 +59,31 @@ impl TryFrom> for FFI_ReturnTypeArgs { .collect(); let scalar_arguments = scalar_arguments?.into_iter().map(ROption::from).collect(); - let nullables = value.nullables.into(); Ok(Self { - arg_types, + arg_fields, scalar_arguments, - nullables, }) } } // TODO(tsaucer) It would be good to find a better way around this, but it // appears a restriction based on the need to have a borrowed ScalarValue -// in the arguments when converted to ReturnTypeArgs -pub struct ForeignReturnTypeArgsOwned { - arg_types: Vec, +// in the arguments when converted to ReturnFieldArgs +pub struct ForeignReturnFieldArgsOwned { + arg_fields: Vec, scalar_arguments: Vec>, - nullables: Vec, } -pub struct ForeignReturnTypeArgs<'a> { - arg_types: &'a [DataType], +pub struct ForeignReturnFieldArgs<'a> { + arg_fields: &'a [Field], scalar_arguments: Vec>, - nullables: &'a [bool], } -impl TryFrom<&FFI_ReturnTypeArgs> for ForeignReturnTypeArgsOwned { +impl TryFrom<&FFI_ReturnFieldArgs> for ForeignReturnFieldArgsOwned { type Error = DataFusionError; - fn try_from(value: &FFI_ReturnTypeArgs) -> Result { - let arg_types = rvec_wrapped_to_vec_datatype(&value.arg_types)?; + fn try_from(value: &FFI_ReturnFieldArgs) -> Result { + let arg_fields = rvec_wrapped_to_vec_field(&value.arg_fields)?; let scalar_arguments: Result, Self::Error> = value 
.scalar_arguments .iter() @@ -107,36 +100,31 @@ impl TryFrom<&FFI_ReturnTypeArgs> for ForeignReturnTypeArgsOwned { .collect(); let scalar_arguments = scalar_arguments?.into_iter().collect(); - let nullables = value.nullables.iter().cloned().collect(); - Ok(Self { - arg_types, + arg_fields, scalar_arguments, - nullables, }) } } -impl<'a> From<&'a ForeignReturnTypeArgsOwned> for ForeignReturnTypeArgs<'a> { - fn from(value: &'a ForeignReturnTypeArgsOwned) -> Self { +impl<'a> From<&'a ForeignReturnFieldArgsOwned> for ForeignReturnFieldArgs<'a> { + fn from(value: &'a ForeignReturnFieldArgsOwned) -> Self { Self { - arg_types: &value.arg_types, + arg_fields: &value.arg_fields, scalar_arguments: value .scalar_arguments .iter() .map(|opt| opt.as_ref()) .collect(), - nullables: &value.nullables, } } } -impl<'a> From<&'a ForeignReturnTypeArgs<'a>> for ReturnTypeArgs<'a> { - fn from(value: &'a ForeignReturnTypeArgs) -> Self { - ReturnTypeArgs { - arg_types: value.arg_types, +impl<'a> From<&'a ForeignReturnFieldArgs<'a>> for ReturnFieldArgs<'a> { + fn from(value: &'a ForeignReturnFieldArgs) -> Self { + ReturnFieldArgs { + arg_fields: value.arg_fields, scalar_arguments: &value.scalar_arguments, - nullables: value.nullables, } } } diff --git a/datafusion/ffi/src/util.rs b/datafusion/ffi/src/util.rs index 9d5f2aefe324..6b9f373939ea 100644 --- a/datafusion/ffi/src/util.rs +++ b/datafusion/ffi/src/util.rs @@ -15,11 +15,11 @@ // specific language governing permissions and limitations // under the License. +use crate::arrow_wrappers::WrappedSchema; use abi_stable::std_types::RVec; +use arrow::datatypes::Field; use arrow::{datatypes::DataType, ffi::FFI_ArrowSchema}; -use crate::arrow_wrappers::WrappedSchema; - /// This macro is a helpful conversion utility to conver from an abi_stable::RResult to a /// DataFusion result. #[macro_export] @@ -64,6 +64,28 @@ macro_rules! 
rresult_return { }; } +/// This is a utility function to convert a slice of [`Field`] to its equivalent +/// FFI friendly counterpart, [`WrappedSchema`] +pub fn vec_field_to_rvec_wrapped( + fields: &[Field], +) -> Result, arrow::error::ArrowError> { + Ok(fields + .iter() + .map(FFI_ArrowSchema::try_from) + .collect::, arrow::error::ArrowError>>()? + .into_iter() + .map(WrappedSchema) + .collect()) +} + +/// This is a utility function to convert an FFI friendly vector of [`WrappedSchema`] +/// to their equivalent [`Field`]. +pub fn rvec_wrapped_to_vec_field( + fields: &RVec, +) -> Result, arrow::error::ArrowError> { + fields.iter().map(|d| Field::try_from(&d.0)).collect() +} + /// This is a utility function to convert a slice of [`DataType`] to its equivalent /// FFI friendly counterpart, [`WrappedSchema`] pub fn vec_datatype_to_rvec_wrapped( diff --git a/datafusion/functions-nested/benches/map.rs b/datafusion/functions-nested/benches/map.rs index 2774b24b902a..d8502e314b31 100644 --- a/datafusion/functions-nested/benches/map.rs +++ b/datafusion/functions-nested/benches/map.rs @@ -94,17 +94,22 @@ fn criterion_benchmark(c: &mut Criterion) { let keys = ColumnarValue::Scalar(ScalarValue::List(Arc::new(key_list))); let values = ColumnarValue::Scalar(ScalarValue::List(Arc::new(value_list))); - let return_type = &map_udf() + let return_type = map_udf() .return_type(&[DataType::Utf8, DataType::Int32]) .expect("should get return type"); + let return_field = &Field::new("f", return_type, true); b.iter(|| { black_box( map_udf() .invoke_with_args(ScalarFunctionArgs { args: vec![keys.clone(), values.clone()], + arg_fields: vec![ + &Field::new("a", keys.data_type(), true), + &Field::new("a", values.data_type(), true), + ], number_rows: 1, - return_type, + return_field, }) .expect("map should work on valid values"), ); diff --git a/datafusion/functions/benches/character_length.rs b/datafusion/functions/benches/character_length.rs index bbcfed021064..fb56b6dd5b97 100644 --- 
a/datafusion/functions/benches/character_length.rs +++ b/datafusion/functions/benches/character_length.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::ScalarFunctionArgs; use helper::gen_string_array; @@ -28,20 +28,27 @@ fn criterion_benchmark(c: &mut Criterion) { // All benches are single batch run with 8192 rows let character_length = datafusion_functions::unicode::character_length(); - let return_type = DataType::Utf8; + let return_field = Field::new("f", DataType::Utf8, true); let n_rows = 8192; for str_len in [8, 32, 128, 4096] { // StringArray ASCII only let args_string_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, false); + let arg_fields_owned = args_string_ascii + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); c.bench_function( &format!("character_length_StringArray_ascii_str_len_{}", str_len), |b| { b.iter(|| { black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_ascii.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &return_type, + return_field: &return_field, })) }) }, @@ -49,14 +56,21 @@ fn criterion_benchmark(c: &mut Criterion) { // StringArray UTF8 let args_string_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, false); + let arg_fields_owned = args_string_utf8 + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); c.bench_function( &format!("character_length_StringArray_utf8_str_len_{}", str_len), |b| { b.iter(|| { black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_utf8.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - 
return_type: &return_type, + return_field: &return_field, })) }) }, @@ -64,14 +78,21 @@ fn criterion_benchmark(c: &mut Criterion) { // StringViewArray ASCII only let args_string_view_ascii = gen_string_array(n_rows, str_len, 0.1, 0.0, true); + let arg_fields_owned = args_string_view_ascii + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); c.bench_function( &format!("character_length_StringViewArray_ascii_str_len_{}", str_len), |b| { b.iter(|| { black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_view_ascii.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &return_type, + return_field: &return_field, })) }) }, @@ -79,14 +100,21 @@ fn criterion_benchmark(c: &mut Criterion) { // StringViewArray UTF8 let args_string_view_utf8 = gen_string_array(n_rows, str_len, 0.1, 0.5, true); + let arg_fields_owned = args_string_view_utf8 + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); c.bench_function( &format!("character_length_StringViewArray_utf8_str_len_{}", str_len), |b| { b.iter(|| { black_box(character_length.invoke_with_args(ScalarFunctionArgs { args: args_string_view_utf8.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &return_type, + return_field: &return_field, })) }) }, diff --git a/datafusion/functions/benches/chr.rs b/datafusion/functions/benches/chr.rs index 8575809c21c8..874c0d642361 100644 --- a/datafusion/functions/benches/chr.rs +++ b/datafusion/functions/benches/chr.rs @@ -23,7 +23,7 @@ use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string::chr; use rand::{Rng, SeedableRng}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use rand::rngs::StdRng; use 
std::sync::Arc; @@ -50,14 +50,22 @@ fn criterion_benchmark(c: &mut Criterion) { }; let input = Arc::new(input); let args = vec![ColumnarValue::Array(input)]; + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + c.bench_function("chr", |b| { b.iter(|| { black_box( cot_fn .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .unwrap(), ) diff --git a/datafusion/functions/benches/concat.rs b/datafusion/functions/benches/concat.rs index 45ca076e754f..ea6a9c2eaf63 100644 --- a/datafusion/functions/benches/concat.rs +++ b/datafusion/functions/benches/concat.rs @@ -16,7 +16,7 @@ // under the License. use arrow::array::ArrayRef; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use datafusion_common::ScalarValue; @@ -37,6 +37,13 @@ fn create_args(size: usize, str_len: usize) -> Vec { fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let args = create_args(size, 32); + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + let mut group = c.benchmark_group("concat function"); group.bench_function(BenchmarkId::new("concat", size), |b| { b.iter(|| { @@ -45,8 +52,9 @@ fn criterion_benchmark(c: &mut Criterion) { concat() .invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .unwrap(), ) diff --git 
a/datafusion/functions/benches/cot.rs b/datafusion/functions/benches/cot.rs index b2a9ca0b9f47..e7ceeca5e9fe 100644 --- a/datafusion/functions/benches/cot.rs +++ b/datafusion/functions/benches/cot.rs @@ -25,7 +25,7 @@ use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::math::cot; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use std::sync::Arc; fn criterion_benchmark(c: &mut Criterion) { @@ -33,14 +33,22 @@ fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let f32_args = vec![ColumnarValue::Array(f32_array)]; + let arg_fields_owned = f32_args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + c.bench_function(&format!("cot f32 array: {}", size), |b| { b.iter(|| { black_box( cot_fn .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Float32, + return_field: &Field::new("f", DataType::Float32, true), }) .unwrap(), ) @@ -48,14 +56,22 @@ fn criterion_benchmark(c: &mut Criterion) { }); let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let f64_args = vec![ColumnarValue::Array(f64_array)]; + let arg_fields_owned = f64_args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + c.bench_function(&format!("cot f64 array: {}", size), |b| { b.iter(|| { black_box( cot_fn .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), }) .unwrap(), ) diff --git 
a/datafusion/functions/benches/date_bin.rs b/datafusion/functions/benches/date_bin.rs index 7ea5fdcb2be2..23eef1a014df 100644 --- a/datafusion/functions/benches/date_bin.rs +++ b/datafusion/functions/benches/date_bin.rs @@ -20,6 +20,7 @@ extern crate criterion; use std::sync::Arc; use arrow::array::{Array, ArrayRef, TimestampSecondArray}; +use arrow::datatypes::Field; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_common::ScalarValue; use rand::rngs::ThreadRng; @@ -48,13 +49,18 @@ fn criterion_benchmark(c: &mut Criterion) { let return_type = udf .return_type(&[interval.data_type(), timestamps.data_type()]) .unwrap(); + let return_field = Field::new("f", return_type, true); b.iter(|| { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![interval.clone(), timestamps.clone()], + arg_fields: vec![ + &Field::new("a", interval.data_type(), true), + &Field::new("b", timestamps.data_type(), true), + ], number_rows: batch_len, - return_type: &return_type, + return_field: &return_field, }) .expect("date_bin should work on valid values"), ) diff --git a/datafusion/functions/benches/date_trunc.rs b/datafusion/functions/benches/date_trunc.rs index e7e96fb7a9fa..780cbae615c2 100644 --- a/datafusion/functions/benches/date_trunc.rs +++ b/datafusion/functions/benches/date_trunc.rs @@ -20,6 +20,7 @@ extern crate criterion; use std::sync::Arc; use arrow::array::{Array, ArrayRef, TimestampSecondArray}; +use arrow::datatypes::Field; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_common::ScalarValue; use rand::rngs::ThreadRng; @@ -47,15 +48,24 @@ fn criterion_benchmark(c: &mut Criterion) { let timestamps = ColumnarValue::Array(timestamps_array); let udf = date_trunc(); let args = vec![precision, timestamps]; - let return_type = &udf + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields 
= arg_fields_owned.iter().collect::>(); + + let return_type = udf .return_type(&args.iter().map(|arg| arg.data_type()).collect::>()) .unwrap(); + let return_field = Field::new("f", return_type, true); b.iter(|| { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field: &return_field, }) .expect("date_trunc should work on valid values"), ) diff --git a/datafusion/functions/benches/encoding.rs b/datafusion/functions/benches/encoding.rs index cf8f8d2fd62c..8210de82ef49 100644 --- a/datafusion/functions/benches/encoding.rs +++ b/datafusion/functions/benches/encoding.rs @@ -17,7 +17,8 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::array::Array; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; @@ -33,19 +34,29 @@ fn criterion_benchmark(c: &mut Criterion) { let encoded = encoding::encode() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Array(str_array.clone()), method.clone()], + arg_fields: vec![ + &Field::new("a", str_array.data_type().to_owned(), true), + &Field::new("b", method.data_type().to_owned(), true), + ], number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .unwrap(); + let arg_fields = vec![ + Field::new("a", encoded.data_type().to_owned(), true), + Field::new("b", method.data_type().to_owned(), true), + ]; let args = vec![encoded, method]; + b.iter(|| { black_box( decode .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.iter().collect(), number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .unwrap(), ) @@ -54,22 +65,33 @@ fn criterion_benchmark(c: &mut Criterion) { 
c.bench_function(&format!("hex_decode/{size}"), |b| { let method = ColumnarValue::Scalar("hex".into()); + let arg_fields = vec![ + Field::new("a", str_array.data_type().to_owned(), true), + Field::new("b", method.data_type().to_owned(), true), + ]; let encoded = encoding::encode() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Array(str_array.clone()), method.clone()], + arg_fields: arg_fields.iter().collect(), number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .unwrap(); + let arg_fields = vec![ + Field::new("a", encoded.data_type().to_owned(), true), + Field::new("b", method.data_type().to_owned(), true), + ]; let args = vec![encoded, method]; + b.iter(|| { black_box( decode .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.iter().collect(), number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .unwrap(), ) diff --git a/datafusion/functions/benches/find_in_set.rs b/datafusion/functions/benches/find_in_set.rs index 9307525482c2..6e60692fc581 100644 --- a/datafusion/functions/benches/find_in_set.rs +++ b/datafusion/functions/benches/find_in_set.rs @@ -18,7 +18,7 @@ extern crate criterion; use arrow::array::{StringArray, StringViewArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; @@ -153,23 +153,35 @@ fn criterion_benchmark(c: &mut Criterion) { group.measurement_time(Duration::from_secs(10)); let args = gen_args_array(n_rows, str_len, 0.1, 0.5, false); + let arg_fields_owned = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); group.bench_function(format!("string_len_{}", str_len), |b| { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: 
args.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &DataType::Int32, + return_field: &Field::new("f", DataType::Int32, true), })) }) }); let args = gen_args_array(n_rows, str_len, 0.1, 0.5, true); + let arg_fields_owned = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); group.bench_function(format!("string_view_len_{}", str_len), |b| { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &DataType::Int32, + return_field: &Field::new("f", DataType::Int32, true), })) }) }); @@ -179,23 +191,35 @@ fn criterion_benchmark(c: &mut Criterion) { let mut group = c.benchmark_group("find_in_set_scalar"); let args = gen_args_scalar(n_rows, str_len, 0.1, false); + let arg_fields_owned = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); group.bench_function(format!("string_len_{}", str_len), |b| { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &DataType::Int32, + return_field: &Field::new("f", DataType::Int32, true), })) }) }); let args = gen_args_scalar(n_rows, str_len, 0.1, true); + let arg_fields_owned = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); group.bench_function(format!("string_view_len_{}", str_len), |b| { b.iter(|| { black_box(find_in_set.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: n_rows, - return_type: &DataType::Int32, + return_field: &Field::new("f", DataType::Int32, true), })) }) }); diff --git a/datafusion/functions/benches/gcd.rs 
b/datafusion/functions/benches/gcd.rs index f8c855c82ad4..c56f6aae1c07 100644 --- a/datafusion/functions/benches/gcd.rs +++ b/datafusion/functions/benches/gcd.rs @@ -17,6 +17,7 @@ extern crate criterion; +use arrow::datatypes::Field; use arrow::{ array::{ArrayRef, Int64Array}, datatypes::DataType, @@ -47,8 +48,12 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![array_a.clone(), array_b.clone()], + arg_fields: vec![ + &Field::new("a", array_a.data_type(), true), + &Field::new("b", array_b.data_type(), true), + ], number_rows: 0, - return_type: &DataType::Int64, + return_field: &Field::new("f", DataType::Int64, true), }) .expect("date_bin should work on valid values"), ) @@ -63,8 +68,12 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![array_a.clone(), scalar_b.clone()], + arg_fields: vec![ + &Field::new("a", array_a.data_type(), true), + &Field::new("b", scalar_b.data_type(), true), + ], number_rows: 0, - return_type: &DataType::Int64, + return_field: &Field::new("f", DataType::Int64, true), }) .expect("date_bin should work on valid values"), ) @@ -79,8 +88,12 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( udf.invoke_with_args(ScalarFunctionArgs { args: vec![scalar_a.clone(), scalar_b.clone()], + arg_fields: vec![ + &Field::new("a", scalar_a.data_type(), true), + &Field::new("b", scalar_b.data_type(), true), + ], number_rows: 0, - return_type: &DataType::Int64, + return_field: &Field::new("f", DataType::Int64, true), }) .expect("date_bin should work on valid values"), ) diff --git a/datafusion/functions/benches/initcap.rs b/datafusion/functions/benches/initcap.rs index 97c76831b33c..ecfafa31eec0 100644 --- a/datafusion/functions/benches/initcap.rs +++ b/datafusion/functions/benches/initcap.rs @@ -18,7 +18,7 @@ extern crate criterion; use arrow::array::OffsetSizeTrait; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, 
Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; @@ -49,14 +49,22 @@ fn criterion_benchmark(c: &mut Criterion) { let initcap = unicode::initcap(); for size in [1024, 4096] { let args = create_args::(size, 8, true); + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + c.bench_function( format!("initcap string view shorter than 12 [size={}]", size).as_str(), |b| { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8View, + return_field: &Field::new("f", DataType::Utf8View, true), })) }) }, @@ -69,8 +77,9 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8View, + return_field: &Field::new("f", DataType::Utf8View, true), })) }) }, @@ -81,8 +90,9 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(initcap.invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }); diff --git a/datafusion/functions/benches/isnan.rs b/datafusion/functions/benches/isnan.rs index 42004cc24f69..d1b4412118fe 100644 --- a/datafusion/functions/benches/isnan.rs +++ b/datafusion/functions/benches/isnan.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::{ datatypes::{Float32Type, Float64Type}, util::bench_util::create_primitive_array, @@ -32,14 +32,22 @@ fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 4096, 8192] { let f32_array = 
Arc::new(create_primitive_array::(size, 0.2)); let f32_args = vec![ColumnarValue::Array(f32_array)]; + let arg_fields_owned = f32_args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + c.bench_function(&format!("isnan f32 array: {}", size), |b| { b.iter(|| { black_box( isnan .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Boolean, + return_field: &Field::new("f", DataType::Boolean, true), }) .unwrap(), ) @@ -47,14 +55,21 @@ fn criterion_benchmark(c: &mut Criterion) { }); let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let f64_args = vec![ColumnarValue::Array(f64_array)]; + let arg_fields_owned = f64_args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); c.bench_function(&format!("isnan f64 array: {}", size), |b| { b.iter(|| { black_box( isnan .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Boolean, + return_field: &Field::new("f", DataType::Boolean, true), }) .unwrap(), ) diff --git a/datafusion/functions/benches/iszero.rs b/datafusion/functions/benches/iszero.rs index 9e5f6a84804b..3abe776d0201 100644 --- a/datafusion/functions/benches/iszero.rs +++ b/datafusion/functions/benches/iszero.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::{ datatypes::{Float32Type, Float64Type}, util::bench_util::create_primitive_array, @@ -33,14 +33,22 @@ fn criterion_benchmark(c: &mut Criterion) { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = f32_array.len(); let f32_args = vec![ColumnarValue::Array(f32_array)]; + let 
arg_fields_owned = f32_args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + c.bench_function(&format!("iszero f32 array: {}", size), |b| { b.iter(|| { black_box( iszero .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Boolean, + return_field: &Field::new("f", DataType::Boolean, true), }) .unwrap(), ) @@ -49,14 +57,22 @@ fn criterion_benchmark(c: &mut Criterion) { let f64_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = f64_array.len(); let f64_args = vec![ColumnarValue::Array(f64_array)]; + let arg_fields_owned = f64_args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + c.bench_function(&format!("iszero f64 array: {}", size), |b| { b.iter(|| { black_box( iszero .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Boolean, + return_field: &Field::new("f", DataType::Boolean, true), }) .unwrap(), ) diff --git a/datafusion/functions/benches/lower.rs b/datafusion/functions/benches/lower.rs index 534e5739225d..9cab9abbbb4e 100644 --- a/datafusion/functions/benches/lower.rs +++ b/datafusion/functions/benches/lower.rs @@ -18,7 +18,7 @@ extern crate criterion; use arrow::array::{ArrayRef, StringArray, StringViewBuilder}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; @@ -124,18 +124,33 @@ fn criterion_benchmark(c: &mut Criterion) { let lower = string::lower(); for size in [1024, 4096, 8192] { let args = create_args1(size, 32); + let arg_fields_owned = args + .iter() + .enumerate() + 
.map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + c.bench_function(&format!("lower_all_values_are_ascii: {}", size), |b| { b.iter(|| { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }); let args = create_args2(size); + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + c.bench_function( &format!("lower_the_first_value_is_nonascii: {}", size), |b| { @@ -143,14 +158,22 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }, ); let args = create_args3(size); + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + c.bench_function( &format!("lower_the_middle_value_is_nonascii: {}", size), |b| { @@ -158,8 +181,9 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }, @@ -176,6 +200,15 @@ fn criterion_benchmark(c: &mut Criterion) { for &str_len in &str_lens { for &size in &sizes { let args = create_args4(size, str_len, *null_density, mixed); + let arg_fields_owned = args + .iter() + 
.enumerate() + .map(|(idx, arg)| { + Field::new(format!("arg_{idx}"), arg.data_type(), true) + }) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + c.bench_function( &format!("lower_all_values_are_ascii_string_views: size: {}, str_len: {}, null_density: {}, mixed: {}", size, str_len, null_density, mixed), @@ -183,8 +216,9 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs{ args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }), ); @@ -197,8 +231,9 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs{ args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }), ); @@ -211,8 +246,9 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(lower.invoke_with_args(ScalarFunctionArgs{ args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }), ); diff --git a/datafusion/functions/benches/ltrim.rs b/datafusion/functions/benches/ltrim.rs index 457fb499f5a1..b5371439c166 100644 --- a/datafusion/functions/benches/ltrim.rs +++ b/datafusion/functions/benches/ltrim.rs @@ -18,7 +18,7 @@ extern crate criterion; use arrow::array::{ArrayRef, LargeStringArray, StringArray, StringViewArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{ black_box, criterion_group, criterion_main, measurement::Measurement, BenchmarkGroup, Criterion, SamplingMode, @@ -136,6 +136,12 @@ fn run_with_string_type( string_type: StringArrayType, ) { let args = create_args(size, characters, trimmed, remaining_len, string_type); + let 
arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); group.bench_function( format!( "{string_type} [size={size}, len_before={len}, len_after={remaining_len}]", @@ -145,8 +151,9 @@ fn run_with_string_type( let args_cloned = args.clone(); black_box(ltrim.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }, diff --git a/datafusion/functions/benches/make_date.rs b/datafusion/functions/benches/make_date.rs index 8dd7a7a59773..8e0fe5992ffc 100644 --- a/datafusion/functions/benches/make_date.rs +++ b/datafusion/functions/benches/make_date.rs @@ -20,7 +20,7 @@ extern crate criterion; use std::sync::Arc; use arrow::array::{Array, ArrayRef, Int32Array}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use rand::rngs::ThreadRng; use rand::Rng; @@ -69,8 +69,13 @@ fn criterion_benchmark(c: &mut Criterion) { make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![years.clone(), months.clone(), days.clone()], + arg_fields: vec![ + &Field::new("a", years.data_type(), true), + &Field::new("a", months.data_type(), true), + &Field::new("a", days.data_type(), true), + ], number_rows: batch_len, - return_type: &DataType::Date32, + return_field: &Field::new("f", DataType::Date32, true), }) .expect("make_date should work on valid values"), ) @@ -90,8 +95,13 @@ fn criterion_benchmark(c: &mut Criterion) { make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![year.clone(), months.clone(), days.clone()], + arg_fields: vec![ + &Field::new("a", year.data_type(), true), + &Field::new("a", months.data_type(), true), + &Field::new("a", days.data_type(), true), + ], number_rows: batch_len, - 
return_type: &DataType::Date32, + return_field: &Field::new("f", DataType::Date32, true), }) .expect("make_date should work on valid values"), ) @@ -111,8 +121,13 @@ fn criterion_benchmark(c: &mut Criterion) { make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![year.clone(), month.clone(), days.clone()], + arg_fields: vec![ + &Field::new("a", year.data_type(), true), + &Field::new("a", month.data_type(), true), + &Field::new("a", days.data_type(), true), + ], number_rows: batch_len, - return_type: &DataType::Date32, + return_field: &Field::new("f", DataType::Date32, true), }) .expect("make_date should work on valid values"), ) @@ -129,8 +144,13 @@ fn criterion_benchmark(c: &mut Criterion) { make_date() .invoke_with_args(ScalarFunctionArgs { args: vec![year.clone(), month.clone(), day.clone()], + arg_fields: vec![ + &Field::new("a", year.data_type(), true), + &Field::new("a", month.data_type(), true), + &Field::new("a", day.data_type(), true), + ], number_rows: 1, - return_type: &DataType::Date32, + return_field: &Field::new("f", DataType::Date32, true), }) .expect("make_date should work on valid values"), ) diff --git a/datafusion/functions/benches/nullif.rs b/datafusion/functions/benches/nullif.rs index 9096c976bf31..0e04ad34a905 100644 --- a/datafusion/functions/benches/nullif.rs +++ b/datafusion/functions/benches/nullif.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_common::ScalarValue; @@ -33,14 +33,22 @@ fn criterion_benchmark(c: &mut Criterion) { ColumnarValue::Scalar(ScalarValue::Utf8(Some("abcd".to_string()))), ColumnarValue::Array(array), ]; + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = 
arg_fields_owned.iter().collect::>(); + c.bench_function(&format!("nullif scalar array: {}", size), |b| { b.iter(|| { black_box( nullif .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .unwrap(), ) diff --git a/datafusion/functions/benches/pad.rs b/datafusion/functions/benches/pad.rs index f78a53fbee19..c484a583584c 100644 --- a/datafusion/functions/benches/pad.rs +++ b/datafusion/functions/benches/pad.rs @@ -16,11 +16,12 @@ // under the License. use arrow::array::{ArrayRef, ArrowPrimitiveType, OffsetSizeTrait, PrimitiveArray}; -use arrow::datatypes::{DataType, Int64Type}; +use arrow::datatypes::{DataType, Field, Int64Type}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; +use datafusion_common::DataFusionError; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::unicode::{lpad, rpad}; use rand::distributions::{Distribution, Uniform}; @@ -95,21 +96,42 @@ fn create_args( } } +fn invoke_pad_with_args( + args: Vec, + number_rows: usize, + left_pad: bool, +) -> Result { + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + + let scalar_args = ScalarFunctionArgs { + args: args.clone(), + arg_fields, + number_rows, + return_field: &Field::new("f", DataType::Utf8, true), + }; + + if left_pad { + lpad().invoke_with_args(scalar_args) + } else { + rpad().invoke_with_args(scalar_args) + } +} + fn criterion_benchmark(c: &mut Criterion) { for size in [1024, 2048] { let mut group = c.benchmark_group("lpad function"); let args = create_args::(size, 32, false); + group.bench_function(BenchmarkId::new("utf8 type", size), |b| 
{ b.iter(|| { criterion::black_box( - lpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8, - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, true).unwrap(), ) }) }); @@ -118,13 +140,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(BenchmarkId::new("largeutf8 type", size), |b| { b.iter(|| { criterion::black_box( - lpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::LargeUtf8, - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, true).unwrap(), ) }) }); @@ -133,13 +149,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(BenchmarkId::new("stringview type", size), |b| { b.iter(|| { criterion::black_box( - lpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8, - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, true).unwrap(), ) }) }); @@ -152,13 +162,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(BenchmarkId::new("utf8 type", size), |b| { b.iter(|| { criterion::black_box( - rpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8, - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, false).unwrap(), ) }) }); @@ -167,13 +171,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(BenchmarkId::new("largeutf8 type", size), |b| { b.iter(|| { criterion::black_box( - rpad() - .invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::LargeUtf8, - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, false).unwrap(), ) }) }); @@ -183,13 +181,7 @@ fn criterion_benchmark(c: &mut Criterion) { group.bench_function(BenchmarkId::new("stringview type", size), |b| { b.iter(|| { criterion::black_box( - rpad() - .invoke_with_args(ScalarFunctionArgs { - args: 
args.clone(), - number_rows: size, - return_type: &DataType::Utf8, - }) - .unwrap(), + invoke_pad_with_args(args.clone(), size, false).unwrap(), ) }) }); diff --git a/datafusion/functions/benches/random.rs b/datafusion/functions/benches/random.rs index 78ebf23e02e0..5a80bbd1f11b 100644 --- a/datafusion/functions/benches/random.rs +++ b/datafusion/functions/benches/random.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ScalarFunctionArgs, ScalarUDFImpl}; use datafusion_functions::math::random::RandomFunc; @@ -34,8 +34,9 @@ fn criterion_benchmark(c: &mut Criterion) { random_func .invoke_with_args(ScalarFunctionArgs { args: vec![], + arg_fields: vec![], number_rows: 8192, - return_type: &DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), }) .unwrap(), ); @@ -52,8 +53,9 @@ fn criterion_benchmark(c: &mut Criterion) { random_func .invoke_with_args(ScalarFunctionArgs { args: vec![], + arg_fields: vec![], number_rows: 128, - return_type: &DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), }) .unwrap(), ); diff --git a/datafusion/functions/benches/repeat.rs b/datafusion/functions/benches/repeat.rs index 5cc6a177d9d9..a29dc494047e 100644 --- a/datafusion/functions/benches/repeat.rs +++ b/datafusion/functions/benches/repeat.rs @@ -18,11 +18,12 @@ extern crate criterion; use arrow::array::{ArrayRef, Int64Array, OffsetSizeTrait}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; use criterion::{black_box, criterion_group, criterion_main, Criterion, SamplingMode}; +use datafusion_common::DataFusionError; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::string; use std::sync::Arc; @@ -56,8 +57,26 @@ fn 
create_args( } } +fn invoke_repeat_with_args( + args: Vec, + repeat_times: i64, +) -> Result { + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + + string::repeat().invoke_with_args(ScalarFunctionArgs { + args, + arg_fields, + number_rows: repeat_times as usize, + return_field: &Field::new("f", DataType::Utf8, true), + }) +} + fn criterion_benchmark(c: &mut Criterion) { - let repeat = string::repeat(); for size in [1024, 4096] { // REPEAT 3 TIMES let repeat_times = 3; @@ -75,11 +94,7 @@ fn criterion_benchmark(c: &mut Criterion) { |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); @@ -93,11 +108,7 @@ fn criterion_benchmark(c: &mut Criterion) { |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); @@ -111,11 +122,7 @@ fn criterion_benchmark(c: &mut Criterion) { |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); @@ -138,11 +145,7 @@ fn criterion_benchmark(c: &mut Criterion) { |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); @@ 
-156,11 +159,7 @@ fn criterion_benchmark(c: &mut Criterion) { |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); @@ -174,11 +173,7 @@ fn criterion_benchmark(c: &mut Criterion) { |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); @@ -201,11 +196,7 @@ fn criterion_benchmark(c: &mut Criterion) { |b| { b.iter(|| { let args_cloned = args.clone(); - black_box(repeat.invoke_with_args(ScalarFunctionArgs { - args: args_cloned, - number_rows: repeat_times as usize, - return_type: &DataType::Utf8, - })) + black_box(invoke_repeat_with_args(args_cloned, repeat_times)) }) }, ); diff --git a/datafusion/functions/benches/reverse.rs b/datafusion/functions/benches/reverse.rs index d61f8fb80517..102a4396a62b 100644 --- a/datafusion/functions/benches/reverse.rs +++ b/datafusion/functions/benches/reverse.rs @@ -18,7 +18,7 @@ extern crate criterion; mod helper; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::ScalarFunctionArgs; use helper::gen_string_array; @@ -46,8 +46,13 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_ascii.clone(), + arg_fields: vec![&Field::new( + "a", + args_string_ascii[0].data_type(), + true, + )], number_rows: N_ROWS, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }, @@ -65,8 +70,13 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { 
black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_utf8.clone(), + arg_fields: vec![&Field::new( + "a", + args_string_utf8[0].data_type(), + true, + )], number_rows: N_ROWS, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }, @@ -86,8 +96,13 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_view_ascii.clone(), + arg_fields: vec![&Field::new( + "a", + args_string_view_ascii[0].data_type(), + true, + )], number_rows: N_ROWS, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }, @@ -105,8 +120,13 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(reverse.invoke_with_args(ScalarFunctionArgs { args: args_string_view_utf8.clone(), + arg_fields: vec![&Field::new( + "a", + args_string_view_utf8[0].data_type(), + true, + )], number_rows: N_ROWS, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }, diff --git a/datafusion/functions/benches/signum.rs b/datafusion/functions/benches/signum.rs index 01939fad5f34..0cd8b9451f39 100644 --- a/datafusion/functions/benches/signum.rs +++ b/datafusion/functions/benches/signum.rs @@ -19,7 +19,7 @@ extern crate criterion; use arrow::datatypes::DataType; use arrow::{ - datatypes::{Float32Type, Float64Type}, + datatypes::{Field, Float32Type, Float64Type}, util::bench_util::create_primitive_array, }; use criterion::{black_box, criterion_group, criterion_main, Criterion}; @@ -33,14 +33,22 @@ fn criterion_benchmark(c: &mut Criterion) { let f32_array = Arc::new(create_primitive_array::(size, 0.2)); let batch_len = f32_array.len(); let f32_args = vec![ColumnarValue::Array(f32_array)]; + let arg_fields_owned = f32_args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); 
+ c.bench_function(&format!("signum f32 array: {}", size), |b| { b.iter(|| { black_box( signum .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Float32, + return_field: &Field::new("f", DataType::Float32, true), }) .unwrap(), ) @@ -50,14 +58,22 @@ fn criterion_benchmark(c: &mut Criterion) { let batch_len = f64_array.len(); let f64_args = vec![ColumnarValue::Array(f64_array)]; + let arg_fields_owned = f64_args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + c.bench_function(&format!("signum f64 array: {}", size), |b| { b.iter(|| { black_box( signum .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), }) .unwrap(), ) diff --git a/datafusion/functions/benches/strpos.rs b/datafusion/functions/benches/strpos.rs index df57c229e0ad..894f2cdb65eb 100644 --- a/datafusion/functions/benches/strpos.rs +++ b/datafusion/functions/benches/strpos.rs @@ -18,7 +18,7 @@ extern crate criterion; use arrow::array::{StringArray, StringViewArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use rand::distributions::Alphanumeric; @@ -117,8 +117,13 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_ascii.clone(), + arg_fields: vec![&Field::new( + "a", + args_string_ascii[0].data_type(), + true, + )], number_rows: n_rows, - return_type: &DataType::Int32, + return_field: &Field::new("f", DataType::Int32, true), })) }) }, @@ -132,8 +137,13 @@ fn criterion_benchmark(c: &mut 
Criterion) { b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_utf8.clone(), + arg_fields: vec![&Field::new( + "a", + args_string_utf8[0].data_type(), + true, + )], number_rows: n_rows, - return_type: &DataType::Int32, + return_field: &Field::new("f", DataType::Int32, true), })) }) }, @@ -147,8 +157,13 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_view_ascii.clone(), + arg_fields: vec![&Field::new( + "a", + args_string_view_ascii[0].data_type(), + true, + )], number_rows: n_rows, - return_type: &DataType::Int32, + return_field: &Field::new("f", DataType::Int32, true), })) }) }, @@ -162,8 +177,13 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(strpos.invoke_with_args(ScalarFunctionArgs { args: args_string_view_utf8.clone(), + arg_fields: vec![&Field::new( + "a", + args_string_view_utf8[0].data_type(), + true, + )], number_rows: n_rows, - return_type: &DataType::Int32, + return_field: &Field::new("f", DataType::Int32, true), })) }) }, diff --git a/datafusion/functions/benches/substr.rs b/datafusion/functions/benches/substr.rs index 80ab70ef71b0..e01e1c6d5b3b 100644 --- a/datafusion/functions/benches/substr.rs +++ b/datafusion/functions/benches/substr.rs @@ -18,11 +18,12 @@ extern crate criterion; use arrow::array::{ArrayRef, Int64Array, OffsetSizeTrait}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::{ create_string_array_with_len, create_string_view_array_with_len, }; use criterion::{black_box, criterion_group, criterion_main, Criterion, SamplingMode}; +use datafusion_common::DataFusionError; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; use datafusion_functions::unicode; use std::sync::Arc; @@ -96,8 +97,26 @@ fn create_args_with_count( } } +fn invoke_substr_with_args( + args: Vec, + number_rows: usize, +) -> Result { + let arg_fields_owned = args + .iter() + 
.enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + + unicode::substr().invoke_with_args(ScalarFunctionArgs { + args: args.clone(), + arg_fields, + number_rows, + return_field: &Field::new("f", DataType::Utf8View, true), + }) +} + fn criterion_benchmark(c: &mut Criterion) { - let substr = unicode::substr(); for size in [1024, 4096] { // string_len = 12, substring_len=6 (see `create_args_without_count`) let len = 12; @@ -108,43 +127,19 @@ fn criterion_benchmark(c: &mut Criterion) { let args = create_args_without_count::(size, len, true, true); group.bench_function( format!("substr_string_view [size={}, strlen={}]", size, len), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = create_args_without_count::(size, len, false, false); group.bench_function( format!("substr_string [size={}, strlen={}]", size, len), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = create_args_without_count::(size, len, true, false); group.bench_function( format!("substr_large_string [size={}, strlen={}]", size, len), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); group.finish(); @@ -162,15 +157,7 @@ fn criterion_benchmark(c: &mut Criterion) { "substr_string_view [size={}, count={}, strlen={}]", size, count, len, ), - |b| { - b.iter(|| { - 
black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = create_args_with_count::(size, len, count, false); @@ -179,15 +166,7 @@ fn criterion_benchmark(c: &mut Criterion) { "substr_string [size={}, count={}, strlen={}]", size, count, len, ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = create_args_with_count::(size, len, count, false); @@ -196,15 +175,7 @@ fn criterion_benchmark(c: &mut Criterion) { "substr_large_string [size={}, count={}, strlen={}]", size, count, len, ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); group.finish(); @@ -222,15 +193,7 @@ fn criterion_benchmark(c: &mut Criterion) { "substr_string_view [size={}, count={}, strlen={}]", size, count, len, ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = create_args_with_count::(size, len, count, false); @@ -239,15 +202,7 @@ fn criterion_benchmark(c: &mut Criterion) { "substr_string [size={}, count={}, strlen={}]", size, count, len, ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); let args = 
create_args_with_count::(size, len, count, false); @@ -256,15 +211,7 @@ fn criterion_benchmark(c: &mut Criterion) { "substr_large_string [size={}, count={}, strlen={}]", size, count, len, ), - |b| { - b.iter(|| { - black_box(substr.invoke_with_args(ScalarFunctionArgs { - args: args.clone(), - number_rows: size, - return_type: &DataType::Utf8View, - })) - }) - }, + |b| b.iter(|| black_box(invoke_substr_with_args(args.clone(), size))), ); group.finish(); diff --git a/datafusion/functions/benches/substr_index.rs b/datafusion/functions/benches/substr_index.rs index b1c1c3c34a95..5604f7b6a914 100644 --- a/datafusion/functions/benches/substr_index.rs +++ b/datafusion/functions/benches/substr_index.rs @@ -20,7 +20,7 @@ extern crate criterion; use std::sync::Arc; use arrow::array::{ArrayRef, Int64Array, StringArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use rand::distributions::{Alphanumeric, Uniform}; use rand::prelude::Distribution; @@ -91,13 +91,21 @@ fn criterion_benchmark(c: &mut Criterion) { let counts = ColumnarValue::Array(Arc::new(counts) as ArrayRef); let args = vec![strings, delimiters, counts]; + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + b.iter(|| { black_box( substr_index() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .expect("substr_index should work on valid values"), ) diff --git a/datafusion/functions/benches/to_char.rs b/datafusion/functions/benches/to_char.rs index 6f20a20dc219..9189455fd387 100644 --- a/datafusion/functions/benches/to_char.rs +++ b/datafusion/functions/benches/to_char.rs @@ -20,7 +20,7 @@ extern crate 
criterion; use std::sync::Arc; use arrow::array::{ArrayRef, Date32Array, StringArray}; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use chrono::prelude::*; use chrono::TimeDelta; use criterion::{black_box, criterion_group, criterion_main, Criterion}; @@ -93,8 +93,12 @@ fn criterion_benchmark(c: &mut Criterion) { to_char() .invoke_with_args(ScalarFunctionArgs { args: vec![data.clone(), patterns.clone()], + arg_fields: vec![ + &Field::new("a", data.data_type(), true), + &Field::new("b", patterns.data_type(), true), + ], number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .expect("to_char should work on valid values"), ) @@ -114,8 +118,12 @@ fn criterion_benchmark(c: &mut Criterion) { to_char() .invoke_with_args(ScalarFunctionArgs { args: vec![data.clone(), patterns.clone()], + arg_fields: vec![ + &Field::new("a", data.data_type(), true), + &Field::new("b", patterns.data_type(), true), + ], number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .expect("to_char should work on valid values"), ) @@ -141,8 +149,12 @@ fn criterion_benchmark(c: &mut Criterion) { to_char() .invoke_with_args(ScalarFunctionArgs { args: vec![data.clone(), pattern.clone()], + arg_fields: vec![ + &Field::new("a", data.data_type(), true), + &Field::new("b", pattern.data_type(), true), + ], number_rows: 1, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .expect("to_char should work on valid values"), ) diff --git a/datafusion/functions/benches/to_hex.rs b/datafusion/functions/benches/to_hex.rs index a45d936c0a52..5ab49cb96c58 100644 --- a/datafusion/functions/benches/to_hex.rs +++ b/datafusion/functions/benches/to_hex.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::{DataType, Int32Type, Int64Type}; +use arrow::datatypes::{DataType, Field, Int32Type, Int64Type}; use 
arrow::util::bench_util::create_primitive_array; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; @@ -36,8 +36,9 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( hex.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: vec![&Field::new("a", DataType::Int32, false)], number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .unwrap(), ) @@ -52,8 +53,9 @@ fn criterion_benchmark(c: &mut Criterion) { black_box( hex.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: vec![&Field::new("a", DataType::Int64, false)], number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .unwrap(), ) diff --git a/datafusion/functions/benches/to_timestamp.rs b/datafusion/functions/benches/to_timestamp.rs index aec56697691f..ae38b3ae9df6 100644 --- a/datafusion/functions/benches/to_timestamp.rs +++ b/datafusion/functions/benches/to_timestamp.rs @@ -22,7 +22,7 @@ use std::sync::Arc; use arrow::array::builder::StringBuilder; use arrow::array::{Array, ArrayRef, StringArray}; use arrow::compute::cast; -use arrow::datatypes::{DataType, TimeUnit}; +use arrow::datatypes::{DataType, Field, TimeUnit}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; @@ -109,7 +109,10 @@ fn data_with_formats() -> (StringArray, StringArray, StringArray, StringArray) { ) } fn criterion_benchmark(c: &mut Criterion) { - let return_type = &DataType::Timestamp(TimeUnit::Nanosecond, None); + let return_field = + &Field::new("f", DataType::Timestamp(TimeUnit::Nanosecond, None), true); + let arg_field = Field::new("a", DataType::Utf8, false); + let arg_fields = vec![&arg_field]; c.bench_function("to_timestamp_no_formats_utf8", |b| { let arr_data = data(); let batch_len = arr_data.len(); @@ -120,8 +123,9 @@ 
fn criterion_benchmark(c: &mut Criterion) { to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: vec![string_array.clone()], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field, }) .expect("to_timestamp should work on valid values"), ) @@ -138,8 +142,9 @@ fn criterion_benchmark(c: &mut Criterion) { to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: vec![string_array.clone()], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field, }) .expect("to_timestamp should work on valid values"), ) @@ -156,8 +161,9 @@ fn criterion_benchmark(c: &mut Criterion) { to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: vec![string_array.clone()], + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field, }) .expect("to_timestamp should work on valid values"), ) @@ -174,13 +180,21 @@ fn criterion_benchmark(c: &mut Criterion) { ColumnarValue::Array(Arc::new(format2) as ArrayRef), ColumnarValue::Array(Arc::new(format3) as ArrayRef), ]; + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + b.iter(|| { black_box( to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field, }) .expect("to_timestamp should work on valid values"), ) @@ -205,13 +219,21 @@ fn criterion_benchmark(c: &mut Criterion) { Arc::new(cast(&format3, &DataType::LargeUtf8).unwrap()) as ArrayRef ), ]; + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + b.iter(|| { black_box( to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: 
batch_len, - return_type, + return_field, }) .expect("to_timestamp should work on valid values"), ) @@ -237,13 +259,21 @@ fn criterion_benchmark(c: &mut Criterion) { Arc::new(cast(&format3, &DataType::Utf8View).unwrap()) as ArrayRef ), ]; + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, arg)| Field::new(format!("arg_{idx}"), arg.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + b.iter(|| { black_box( to_timestamp() .invoke_with_args(ScalarFunctionArgs { args: args.clone(), + arg_fields: arg_fields.clone(), number_rows: batch_len, - return_type, + return_field, }) .expect("to_timestamp should work on valid values"), ) diff --git a/datafusion/functions/benches/trunc.rs b/datafusion/functions/benches/trunc.rs index 7fc93921d2e7..26e4b5f234a4 100644 --- a/datafusion/functions/benches/trunc.rs +++ b/datafusion/functions/benches/trunc.rs @@ -18,7 +18,7 @@ extern crate criterion; use arrow::{ - datatypes::{Float32Type, Float64Type}, + datatypes::{Field, Float32Type, Float64Type}, util::bench_util::create_primitive_array, }; use criterion::{black_box, criterion_group, criterion_main, Criterion}; @@ -39,8 +39,9 @@ fn criterion_benchmark(c: &mut Criterion) { trunc .invoke_with_args(ScalarFunctionArgs { args: f32_args.clone(), + arg_fields: vec![&Field::new("a", DataType::Float32, false)], number_rows: size, - return_type: &DataType::Float32, + return_field: &Field::new("f", DataType::Float32, true), }) .unwrap(), ) @@ -54,8 +55,9 @@ fn criterion_benchmark(c: &mut Criterion) { trunc .invoke_with_args(ScalarFunctionArgs { args: f64_args.clone(), + arg_fields: vec![&Field::new("a", DataType::Float64, false)], number_rows: size, - return_type: &DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), }) .unwrap(), ) diff --git a/datafusion/functions/benches/upper.rs b/datafusion/functions/benches/upper.rs index f0bee89c7d37..e218f6d0372a 100644 --- a/datafusion/functions/benches/upper.rs +++ 
b/datafusion/functions/benches/upper.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::util::bench_util::create_string_array_with_len; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs}; @@ -42,8 +42,9 @@ fn criterion_benchmark(c: &mut Criterion) { let args_cloned = args.clone(); black_box(upper.invoke_with_args(ScalarFunctionArgs { args: args_cloned, + arg_fields: vec![&Field::new("a", DataType::Utf8, true)], number_rows: size, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }); diff --git a/datafusion/functions/benches/uuid.rs b/datafusion/functions/benches/uuid.rs index 7b8d156fec21..dfed6871d6c2 100644 --- a/datafusion/functions/benches/uuid.rs +++ b/datafusion/functions/benches/uuid.rs @@ -17,7 +17,7 @@ extern crate criterion; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use datafusion_expr::ScalarFunctionArgs; use datafusion_functions::string; @@ -28,8 +28,9 @@ fn criterion_benchmark(c: &mut Criterion) { b.iter(|| { black_box(uuid.invoke_with_args(ScalarFunctionArgs { args: vec![], + arg_fields: vec![], number_rows: 1024, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })) }) }); diff --git a/datafusion/functions/src/core/arrow_cast.rs b/datafusion/functions/src/core/arrow_cast.rs index 2686dbf8be3c..0e18ec180cef 100644 --- a/datafusion/functions/src/core/arrow_cast.rs +++ b/datafusion/functions/src/core/arrow_cast.rs @@ -17,7 +17,7 @@ //! 
[`ArrowCastFunc`]: Implementation of the `arrow_cast` -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow::error::ArrowError; use datafusion_common::{ arrow_datafusion_err, exec_err, internal_err, Result, ScalarValue, @@ -29,7 +29,7 @@ use std::any::Any; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, + ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; @@ -113,11 +113,11 @@ impl ScalarUDFImpl for ArrowCastFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type_from_args should be called instead") + internal_err!("return_field_from_args should be called instead") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - let nullable = args.nullables.iter().any(|&nullable| nullable); + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + let nullable = args.arg_fields.iter().any(|f| f.is_nullable()); let [_, type_arg] = take_function_args(self.name(), args.scalar_arguments)?; @@ -131,7 +131,7 @@ impl ScalarUDFImpl for ArrowCastFunc { ) }, |casted_type| match casted_type.parse::() { - Ok(data_type) => Ok(ReturnInfo::new(data_type, nullable)), + Ok(data_type) => Ok(Field::new(self.name(), data_type, nullable)), Err(ArrowError::ParseError(e)) => Err(exec_datafusion_err!("{e}")), Err(e) => Err(arrow_datafusion_err!(e)), }, diff --git a/datafusion/functions/src/core/coalesce.rs b/datafusion/functions/src/core/coalesce.rs index ba20c23828eb..b2ca3692c1d3 100644 --- a/datafusion/functions/src/core/coalesce.rs +++ b/datafusion/functions/src/core/coalesce.rs @@ -18,11 +18,11 @@ use arrow::array::{new_null_array, BooleanArray}; use arrow::compute::kernels::zip::zip; use arrow::compute::{and, is_not_null, is_null}; -use arrow::datatypes::DataType; +use 
arrow::datatypes::{DataType, Field}; use datafusion_common::{exec_err, internal_err, Result}; use datafusion_expr::binary::try_type_union_resolution; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, + ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, }; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -79,19 +79,20 @@ impl ScalarUDFImpl for CoalesceFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type_from_args should be called instead") + internal_err!("return_field_from_args should be called instead") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { // If any the arguments in coalesce is non-null, the result is non-null - let nullable = args.nullables.iter().all(|&nullable| nullable); + let nullable = args.arg_fields.iter().all(|f| f.is_nullable()); let return_type = args - .arg_types + .arg_fields .iter() + .map(|f| f.data_type()) .find_or_first(|d| !d.is_null()) .unwrap() .clone(); - Ok(ReturnInfo::new(return_type, nullable)) + Ok(Field::new(self.name(), return_type, nullable)) } /// coalesce evaluates to the first value which is not NULL diff --git a/datafusion/functions/src/core/getfield.rs b/datafusion/functions/src/core/getfield.rs index 3ac26b98359b..97df76eaac58 100644 --- a/datafusion/functions/src/core/getfield.rs +++ b/datafusion/functions/src/core/getfield.rs @@ -20,7 +20,7 @@ use arrow::array::{ Scalar, }; use arrow::compute::SortOptions; -use arrow::datatypes::DataType; +use arrow::datatypes::{DataType, Field}; use arrow_buffer::NullBuffer; use datafusion_common::cast::{as_map_array, as_struct_array}; use datafusion_common::{ @@ -28,7 +28,7 @@ use datafusion_common::{ ScalarValue, }; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, + 
ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarFunctionArgs, }; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -130,14 +130,14 @@ impl ScalarUDFImpl for GetFieldFunc { } fn return_type(&self, _: &[DataType]) -> Result { - internal_err!("return_type_from_args should be called instead") + internal_err!("return_field_from_args should be called instead") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { // Length check handled in the signature debug_assert_eq!(args.scalar_arguments.len(), 2); - match (&args.arg_types[0], args.scalar_arguments[1].as_ref()) { + match (&args.arg_fields[0].data_type(), args.scalar_arguments[1].as_ref()) { (DataType::Map(fields, _), _) => { match fields.data_type() { DataType::Struct(fields) if fields.len() == 2 => { @@ -146,7 +146,8 @@ impl ScalarUDFImpl for GetFieldFunc { // instead, we assume that the second column is the "value" column both here and in // execution. 
let value_field = fields.get(1).expect("fields should have exactly two members"); - Ok(ReturnInfo::new_nullable(value_field.data_type().clone())) + + Ok(value_field.as_ref().clone().with_nullable(true)) }, _ => exec_err!("Map fields must contain a Struct with exactly 2 fields"), } @@ -158,10 +159,20 @@ impl ScalarUDFImpl for GetFieldFunc { |field_name| { fields.iter().find(|f| f.name() == field_name) .ok_or(plan_datafusion_err!("Field {field_name} not found in struct")) - .map(|f| ReturnInfo::new_nullable(f.data_type().to_owned())) + .map(|f| { + let mut child_field = f.as_ref().clone(); + + // If the parent is nullable, then getting the child must be nullable, + // so potentially override the return value + + if args.arg_fields[0].is_nullable() { + child_field = child_field.with_nullable(true); + } + child_field + }) }) }, - (DataType::Null, _) => Ok(ReturnInfo::new_nullable(DataType::Null)), + (DataType::Null, _) => Ok(Field::new(self.name(), DataType::Null, true)), (other, _) => exec_err!("The expression to get an indexed field is only valid for `Struct`, `Map` or `Null` types, got {other}"), } } diff --git a/datafusion/functions/src/core/named_struct.rs b/datafusion/functions/src/core/named_struct.rs index bba884d96483..5fb118a8a2fa 100644 --- a/datafusion/functions/src/core/named_struct.rs +++ b/datafusion/functions/src/core/named_struct.rs @@ -19,7 +19,7 @@ use arrow::array::StructArray; use arrow::datatypes::{DataType, Field, Fields}; use datafusion_common::{exec_err, internal_err, Result}; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs, + ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs, }; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -91,10 +91,12 @@ impl ScalarUDFImpl for NamedStructFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("named_struct: return_type called instead of return_type_from_args") 
+ internal_err!( + "named_struct: return_type called instead of return_field_from_args" + ) } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { // do not accept 0 arguments. if args.scalar_arguments.is_empty() { return exec_err!( @@ -126,7 +128,13 @@ impl ScalarUDFImpl for NamedStructFunc { ) ) .collect::>>()?; - let types = args.arg_types.iter().skip(1).step_by(2).collect::>(); + let types = args + .arg_fields + .iter() + .skip(1) + .step_by(2) + .map(|f| f.data_type()) + .collect::>(); let return_fields = names .into_iter() @@ -134,13 +142,15 @@ impl ScalarUDFImpl for NamedStructFunc { .map(|(name, data_type)| Ok(Field::new(name, data_type.to_owned(), true))) .collect::>>()?; - Ok(ReturnInfo::new_nullable(DataType::Struct(Fields::from( - return_fields, - )))) + Ok(Field::new( + self.name(), + DataType::Struct(Fields::from(return_fields)), + true, + )) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let DataType::Struct(fields) = args.return_type else { + let DataType::Struct(fields) = args.return_field.data_type() else { return internal_err!("incorrect named_struct return type"); }; diff --git a/datafusion/functions/src/core/struct.rs b/datafusion/functions/src/core/struct.rs index 8792bf1bd1b9..416ad288fc38 100644 --- a/datafusion/functions/src/core/struct.rs +++ b/datafusion/functions/src/core/struct.rs @@ -117,7 +117,7 @@ impl ScalarUDFImpl for StructFunc { } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { - let DataType::Struct(fields) = args.return_type else { + let DataType::Struct(fields) = args.return_field.data_type() else { return internal_err!("incorrect struct return type"); }; diff --git a/datafusion/functions/src/core/union_extract.rs b/datafusion/functions/src/core/union_extract.rs index 420eeed42cc3..b1544a9b357b 100644 --- a/datafusion/functions/src/core/union_extract.rs +++ b/datafusion/functions/src/core/union_extract.rs 
@@ -16,14 +16,14 @@ // under the License. use arrow::array::Array; -use arrow::datatypes::{DataType, FieldRef, UnionFields}; +use arrow::datatypes::{DataType, Field, FieldRef, UnionFields}; use datafusion_common::cast::as_union_array; use datafusion_common::utils::take_function_args; use datafusion_common::{ exec_datafusion_err, exec_err, internal_err, Result, ScalarValue, }; use datafusion_doc::Documentation; -use datafusion_expr::{ColumnarValue, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs}; +use datafusion_expr::{ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs}; use datafusion_expr::{ScalarUDFImpl, Signature, Volatility}; use datafusion_macros::user_doc; @@ -82,35 +82,35 @@ impl ScalarUDFImpl for UnionExtractFun { } fn return_type(&self, _: &[DataType]) -> Result { - // should be using return_type_from_args and not calling the default implementation + // should be using return_field_from_args and not calling the default implementation internal_err!("union_extract should return type from args") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { - if args.arg_types.len() != 2 { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { + if args.arg_fields.len() != 2 { return exec_err!( "union_extract expects 2 arguments, got {} instead", - args.arg_types.len() + args.arg_fields.len() ); } - let DataType::Union(fields, _) = &args.arg_types[0] else { + let DataType::Union(fields, _) = &args.arg_fields[0].data_type() else { return exec_err!( "union_extract first argument must be a union, got {} instead", - args.arg_types[0] + args.arg_fields[0].data_type() ); }; let Some(ScalarValue::Utf8(Some(field_name))) = &args.scalar_arguments[1] else { return exec_err!( "union_extract second argument must be a non-null string literal, got {} instead", - args.arg_types[1] + args.arg_fields[1].data_type() ); }; let field = find_field(fields, field_name)?.1; - Ok(ReturnInfo::new_nullable(field.data_type().clone())) + Ok(Field::new(self.name(), 
field.data_type().clone(), true)) } fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result { @@ -189,47 +189,67 @@ mod tests { ], ); + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Union( + None, + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true)) + .collect::>(); + let result = fun.invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::Union( - None, - fields.clone(), - UnionMode::Dense, - )), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ], + args, + arg_fields: arg_fields.iter().collect(), number_rows: 1, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })?; assert_scalar(result, ScalarValue::Utf8(None)); + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Union( + Some((3, Box::new(ScalarValue::Int32(Some(42))))), + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true)) + .collect::>(); + let result = fun.invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::Union( - Some((3, Box::new(ScalarValue::Int32(Some(42))))), - fields.clone(), - UnionMode::Dense, - )), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ], + args, + arg_fields: arg_fields.iter().collect(), number_rows: 1, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })?; assert_scalar(result, ScalarValue::Utf8(None)); + let args = vec![ + ColumnarValue::Scalar(ScalarValue::Union( + Some((1, Box::new(ScalarValue::new_utf8("42")))), + fields.clone(), + UnionMode::Dense, + )), + ColumnarValue::Scalar(ScalarValue::new_utf8("str")), + ]; + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type().clone(), true)) + 
.collect::>(); let result = fun.invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::Union( - Some((1, Box::new(ScalarValue::new_utf8("42")))), - fields.clone(), - UnionMode::Dense, - )), - ColumnarValue::Scalar(ScalarValue::new_utf8("str")), - ], + args, + arg_fields: arg_fields.iter().collect(), number_rows: 1, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), })?; assert_scalar(result, ScalarValue::new_utf8("42")); diff --git a/datafusion/functions/src/core/version.rs b/datafusion/functions/src/core/version.rs index 34038022f2dc..9e243dd0adb8 100644 --- a/datafusion/functions/src/core/version.rs +++ b/datafusion/functions/src/core/version.rs @@ -97,6 +97,7 @@ impl ScalarUDFImpl for VersionFunc { #[cfg(test)] mod test { use super::*; + use arrow::datatypes::Field; use datafusion_expr::ScalarUDF; #[tokio::test] @@ -105,8 +106,9 @@ mod test { let version = version_udf .invoke_with_args(ScalarFunctionArgs { args: vec![], + arg_fields: vec![], number_rows: 0, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }) .unwrap(); diff --git a/datafusion/functions/src/datetime/date_bin.rs b/datafusion/functions/src/datetime/date_bin.rs index 5ffae46dde48..ea9e3d091860 100644 --- a/datafusion/functions/src/datetime/date_bin.rs +++ b/datafusion/functions/src/datetime/date_bin.rs @@ -505,85 +505,85 @@ mod tests { use arrow::array::types::TimestampNanosecondType; use arrow::array::{Array, IntervalDayTimeArray, TimestampNanosecondArray}; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; - use arrow::datatypes::{DataType, TimeUnit}; + use arrow::datatypes::{DataType, Field, TimeUnit}; use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano}; - use datafusion_common::ScalarValue; + use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; use chrono::TimeDelta; + fn invoke_date_bin_with_args( + args: Vec, 
+ number_rows: usize, + return_field: &Field, + ) -> Result { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type(), true)) + .collect::>(); + + let args = datafusion_expr::ScalarFunctionArgs { + args, + arg_fields: arg_fields.iter().collect(), + number_rows, + return_field, + }; + DateBinFunc::new().invoke_with_args(args) + } + #[test] fn test_date_bin() { - let mut args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + let return_field = + &Field::new("f", DataType::Timestamp(TimeUnit::Nanosecond, None), true); + + let mut args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert!(res.is_ok()); let timestamps = Arc::new((1..6).map(Some).collect::()); let batch_len = timestamps.len(); - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Array(timestamps), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: batch_len, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 
1, + }))), + ColumnarValue::Array(timestamps), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, batch_len, return_field); assert!(res.is_ok()); - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert!(res.is_ok()); // stride supports month-day-nano - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(Some( - IntervalMonthDayNano { - months: 0, - days: 0, - nanoseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalMonthDayNano(Some( + IntervalMonthDayNano { + months: 0, + days: 0, + nanoseconds: 1, + }, + ))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert!(res.is_ok()); // @@ -591,33 +591,25 @@ mod tests { // // invalid number of arguments - args = datafusion_expr::ScalarFunctionArgs { - args: 
vec![ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - )))], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( + IntervalDayTime { + days: 0, + milliseconds: 1, + }, + )))]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN expected two or three arguments" ); // stride: invalid type - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some(1))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some(1))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN expects stride argument to be an INTERVAL but got Interval(YearMonth)" @@ -625,113 +617,83 @@ mod tests { // stride: invalid value - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 0, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; + args = vec![ + 
ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 0, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; - let res = DateBinFunc::new().invoke_with_args(args); + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN stride must be non-zero" ); // stride: overflow of day-time interval - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime::MAX, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( + IntervalDayTime::MAX, + ))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN stride argument is too large" ); // stride: overflow of month-day-nano interval - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::new_interval_mdn(0, i32::MAX, 1)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::new_interval_mdn(0, i32::MAX, 1)), + 
ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN stride argument is too large" ); // stride: month intervals - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::new_interval_mdn(1, 1, 1)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::new_interval_mdn(1, 1, 1)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "This feature is not implemented: DATE_BIN stride does not support combination of month, day and nanosecond intervals" ); // origin: invalid type - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + 
ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: DATE_BIN expects origin argument to be a TIMESTAMP with nanosecond precision but got Timestamp(Microsecond, None)" ); - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert!(res.is_ok()); // unsupported array type for stride @@ -745,16 +707,12 @@ mod tests { }) .collect::(), ); - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Array(intervals), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ], - number_rows: 1, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Array(intervals), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ]; + let res = invoke_date_bin_with_args(args, 1, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "This feature is not 
implemented: DATE_BIN only supports literal values for the stride argument, not arrays" @@ -763,21 +721,15 @@ mod tests { // unsupported array type for origin let timestamps = Arc::new((1..6).map(Some).collect::()); let batch_len = timestamps.len(); - args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some( - IntervalDayTime { - days: 0, - milliseconds: 1, - }, - ))), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), - ColumnarValue::Array(timestamps), - ], - number_rows: batch_len, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), - }; - let res = DateBinFunc::new().invoke_with_args(args); + args = vec![ + ColumnarValue::Scalar(ScalarValue::IntervalDayTime(Some(IntervalDayTime { + days: 0, + milliseconds: 1, + }))), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), + ColumnarValue::Array(timestamps), + ]; + let res = invoke_date_bin_with_args(args, batch_len, return_field); assert_eq!( res.err().unwrap().strip_backtrace(), "This feature is not implemented: DATE_BIN only supports literal values for the origin argument, not arrays" @@ -893,22 +845,22 @@ mod tests { .collect::() .with_timezone_opt(tz_opt.clone()); let batch_len = input.len(); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(ScalarValue::new_interval_dt(1, 0)), - ColumnarValue::Array(Arc::new(input)), - ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( - Some(string_to_timestamp_nanos(origin).unwrap()), - tz_opt.clone(), - )), - ], - number_rows: batch_len, - return_type: &DataType::Timestamp( - TimeUnit::Nanosecond, + let args = vec![ + ColumnarValue::Scalar(ScalarValue::new_interval_dt(1, 0)), + ColumnarValue::Array(Arc::new(input)), + ColumnarValue::Scalar(ScalarValue::TimestampNanosecond( + Some(string_to_timestamp_nanos(origin).unwrap()), tz_opt.clone(), - ), - }; - let result = DateBinFunc::new().invoke_with_args(args).unwrap(); 
+ )), + ]; + let return_field = &Field::new( + "f", + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), + true, + ); + let result = + invoke_date_bin_with_args(args, batch_len, return_field).unwrap(); + if let ColumnarValue::Array(result) = result { assert_eq!( result.data_type(), diff --git a/datafusion/functions/src/datetime/date_part.rs b/datafusion/functions/src/datetime/date_part.rs index bfd06b39d206..91f983e0acc3 100644 --- a/datafusion/functions/src/datetime/date_part.rs +++ b/datafusion/functions/src/datetime/date_part.rs @@ -26,7 +26,7 @@ use arrow::datatypes::DataType::{ Date32, Date64, Duration, Interval, Time32, Time64, Timestamp, }; use arrow::datatypes::TimeUnit::{Microsecond, Millisecond, Nanosecond, Second}; -use arrow::datatypes::{DataType, TimeUnit}; +use arrow::datatypes::{DataType, Field, TimeUnit}; use datafusion_common::types::{logical_date, NativeType}; use datafusion_common::{ @@ -42,7 +42,7 @@ use datafusion_common::{ Result, ScalarValue, }; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, Signature, + ColumnarValue, Documentation, ReturnFieldArgs, ScalarUDFImpl, Signature, TypeSignature, Volatility, }; use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; @@ -142,10 +142,10 @@ impl ScalarUDFImpl for DatePartFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type_from_args should be called instead") + internal_err!("return_field_from_args should be called instead") } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { let [field, _] = take_function_args(self.name(), args.scalar_arguments)?; field @@ -155,9 +155,9 @@ impl ScalarUDFImpl for DatePartFunc { .filter(|s| !s.is_empty()) .map(|part| { if is_epoch(part) { - ReturnInfo::new_nullable(DataType::Float64) + Field::new(self.name(), DataType::Float64, true) } else { - 
ReturnInfo::new_nullable(DataType::Int32) + Field::new(self.name(), DataType::Int32, true) } }) }) diff --git a/datafusion/functions/src/datetime/date_trunc.rs b/datafusion/functions/src/datetime/date_trunc.rs index ed3eb228bf03..331c4e093e0c 100644 --- a/datafusion/functions/src/datetime/date_trunc.rs +++ b/datafusion/functions/src/datetime/date_trunc.rs @@ -487,7 +487,7 @@ mod tests { use arrow::array::types::TimestampNanosecondType; use arrow::array::{Array, TimestampNanosecondArray}; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; - use arrow::datatypes::{DataType, TimeUnit}; + use arrow::datatypes::{DataType, Field, TimeUnit}; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -726,13 +726,22 @@ mod tests { .collect::() .with_timezone_opt(tz_opt.clone()); let batch_len = input.len(); + let arg_fields = vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", input.data_type().clone(), false), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::from("day")), ColumnarValue::Array(Arc::new(input)), ], + arg_fields: arg_fields.iter().collect(), number_rows: batch_len, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), + return_field: &Field::new( + "f", + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), + true, + ), }; let result = DateTruncFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { @@ -888,13 +897,22 @@ mod tests { .collect::() .with_timezone_opt(tz_opt.clone()); let batch_len = input.len(); + let arg_fields = vec![ + Field::new("a", DataType::Utf8, false), + Field::new("b", input.data_type().clone(), false), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::from("hour")), ColumnarValue::Array(Arc::new(input)), ], + arg_fields: arg_fields.iter().collect(), number_rows: batch_len, - return_type: 
&DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), + return_field: &Field::new( + "f", + DataType::Timestamp(TimeUnit::Nanosecond, tz_opt.clone()), + true, + ), }; let result = DateTruncFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { diff --git a/datafusion/functions/src/datetime/from_unixtime.rs b/datafusion/functions/src/datetime/from_unixtime.rs index ed8181452dbd..1afaf14d52e6 100644 --- a/datafusion/functions/src/datetime/from_unixtime.rs +++ b/datafusion/functions/src/datetime/from_unixtime.rs @@ -18,14 +18,13 @@ use std::any::Any; use std::sync::Arc; -use arrow::datatypes::DataType; use arrow::datatypes::DataType::{Int64, Timestamp, Utf8}; use arrow::datatypes::TimeUnit::Second; +use arrow::datatypes::{DataType, Field}; use datafusion_common::{exec_err, internal_err, Result, ScalarValue}; use datafusion_expr::TypeSignature::Exact; use datafusion_expr::{ - ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, Signature, - Volatility, + ColumnarValue, Documentation, ReturnFieldArgs, ScalarUDFImpl, Signature, Volatility, }; use datafusion_macros::user_doc; @@ -82,12 +81,12 @@ impl ScalarUDFImpl for FromUnixtimeFunc { &self.signature } - fn return_type_from_args(&self, args: ReturnTypeArgs) -> Result { + fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result { // Length check handled in the signature debug_assert!(matches!(args.scalar_arguments.len(), 1 | 2)); if args.scalar_arguments.len() == 1 { - Ok(ReturnInfo::new_nullable(Timestamp(Second, None))) + Ok(Field::new(self.name(), Timestamp(Second, None), true)) } else { args.scalar_arguments[1] .and_then(|sv| { @@ -95,10 +94,11 @@ impl ScalarUDFImpl for FromUnixtimeFunc { .flatten() .filter(|s| !s.is_empty()) .map(|tz| { - ReturnInfo::new_nullable(Timestamp( - Second, - Some(Arc::from(tz.to_string())), - )) + Field::new( + self.name(), + Timestamp(Second, Some(Arc::from(tz.to_string()))), + true, + ) }) }) .map_or_else( @@ -114,7 
+114,7 @@ impl ScalarUDFImpl for FromUnixtimeFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("call return_type_from_args instead") + internal_err!("call return_field_from_args instead") } fn invoke_with_args( @@ -161,8 +161,8 @@ impl ScalarUDFImpl for FromUnixtimeFunc { #[cfg(test)] mod test { use crate::datetime::from_unixtime::FromUnixtimeFunc; - use arrow::datatypes::DataType; use arrow::datatypes::TimeUnit::Second; + use arrow::datatypes::{DataType, Field}; use datafusion_common::ScalarValue; use datafusion_common::ScalarValue::Int64; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -170,10 +170,12 @@ mod test { #[test] fn test_without_timezone() { + let arg_field = Field::new("a", DataType::Int64, true); let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(Int64(Some(1729900800)))], + arg_fields: vec![&arg_field], number_rows: 1, - return_type: &DataType::Timestamp(Second, None), + return_field: &Field::new("f", DataType::Timestamp(Second, None), true), }; let result = FromUnixtimeFunc::new().invoke_with_args(args).unwrap(); @@ -187,6 +189,10 @@ mod test { #[test] fn test_with_timezone() { + let arg_fields = vec![ + Field::new("a", DataType::Int64, true), + Field::new("a", DataType::Utf8, true), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(Int64(Some(1729900800))), @@ -194,10 +200,12 @@ mod test { "America/New_York".to_string(), ))), ], + arg_fields: arg_fields.iter().collect(), number_rows: 2, - return_type: &DataType::Timestamp( - Second, - Some(Arc::from("America/New_York")), + return_field: &Field::new( + "f", + DataType::Timestamp(Second, Some(Arc::from("America/New_York"))), + true, ), }; let result = FromUnixtimeFunc::new().invoke_with_args(args).unwrap(); diff --git a/datafusion/functions/src/datetime/make_date.rs b/datafusion/functions/src/datetime/make_date.rs index 929fa601f107..ed901258cd62 100644 --- 
a/datafusion/functions/src/datetime/make_date.rs +++ b/datafusion/functions/src/datetime/make_date.rs @@ -223,25 +223,39 @@ fn make_date_inner( mod tests { use crate::datetime::make_date::MakeDateFunc; use arrow::array::{Array, Date32Array, Int32Array, Int64Array, UInt32Array}; - use arrow::datatypes::DataType; - use datafusion_common::ScalarValue; + use arrow::datatypes::{DataType, Field}; + use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; use std::sync::Arc; + fn invoke_make_date_with_args( + args: Vec, + number_rows: usize, + ) -> Result { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type(), true)) + .collect::>(); + let args = datafusion_expr::ScalarFunctionArgs { + args, + arg_fields: arg_fields.iter().collect(), + number_rows, + return_field: &Field::new("f", DataType::Date32, true), + }; + MakeDateFunc::new().invoke_with_args(args) + } + #[test] fn test_make_date() { - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::Int32(Some(2024))), ColumnarValue::Scalar(ScalarValue::Int64(Some(1))), ColumnarValue::Scalar(ScalarValue::UInt32(Some(14))), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new() - .invoke_with_args(args) - .expect("that make_date parsed values without error"); + 1, + ) + .expect("that make_date parsed values without error"); if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res { assert_eq!(19736, date.unwrap()); @@ -249,18 +263,15 @@ mod tests { panic!("Expected a scalar value") } - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::Int64(Some(2024))), ColumnarValue::Scalar(ScalarValue::UInt64(Some(1))), ColumnarValue::Scalar(ScalarValue::UInt32(Some(14))), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let 
res = MakeDateFunc::new() - .invoke_with_args(args) - .expect("that make_date parsed values without error"); + 1, + ) + .expect("that make_date parsed values without error"); if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res { assert_eq!(19736, date.unwrap()); @@ -268,18 +279,15 @@ mod tests { panic!("Expected a scalar value") } - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::Utf8(Some("2024".to_string()))), ColumnarValue::Scalar(ScalarValue::LargeUtf8(Some("1".to_string()))), ColumnarValue::Scalar(ScalarValue::Utf8(Some("14".to_string()))), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new() - .invoke_with_args(args) - .expect("that make_date parsed values without error"); + 1, + ) + .expect("that make_date parsed values without error"); if let ColumnarValue::Scalar(ScalarValue::Date32(date)) = res { assert_eq!(19736, date.unwrap()); @@ -291,18 +299,15 @@ mod tests { let months = Arc::new((1..5).map(Some).collect::()); let days = Arc::new((11..15).map(Some).collect::()); let batch_len = years.len(); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Array(years), ColumnarValue::Array(months), ColumnarValue::Array(days), ], - number_rows: batch_len, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new() - .invoke_with_args(args) - .expect("that make_date parsed values without error"); + batch_len, + ) + .unwrap(); if let ColumnarValue::Array(array) = res { assert_eq!(array.len(), 4); @@ -321,60 +326,52 @@ mod tests { // // invalid number of arguments - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new().invoke_with_args(args); + let res = invoke_make_date_with_args( + 
vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))], + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Execution error: make_date function requires 3 arguments, got 1" ); // invalid type - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::IntervalYearMonth(Some(1))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new().invoke_with_args(args); + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Arrow error: Cast error: Casting from Interval(YearMonth) to Int32 not supported" ); // overflow of month - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::Int32(Some(2023))), ColumnarValue::Scalar(ScalarValue::UInt64(Some(u64::MAX))), ColumnarValue::Scalar(ScalarValue::Int32(Some(22))), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new().invoke_with_args(args); + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Arrow error: Cast error: Can't cast value 18446744073709551615 to type Int32" ); // overflow of day - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let res = invoke_make_date_with_args( + vec![ ColumnarValue::Scalar(ScalarValue::Int32(Some(2023))), ColumnarValue::Scalar(ScalarValue::Int32(Some(22))), ColumnarValue::Scalar(ScalarValue::UInt32(Some(u32::MAX))), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let res = MakeDateFunc::new().invoke_with_args(args); + 1, + ); assert_eq!( res.err().unwrap().strip_backtrace(), "Arrow error: Cast error: Can't cast value 4294967295 to type Int32" diff --git a/datafusion/functions/src/datetime/now.rs b/datafusion/functions/src/datetime/now.rs index 
b26dc52cee4d..867442df45ad 100644 --- a/datafusion/functions/src/datetime/now.rs +++ b/datafusion/functions/src/datetime/now.rs @@ -15,16 +15,16 @@ // specific language governing permissions and limitations // under the License. -use arrow::datatypes::DataType; use arrow::datatypes::DataType::Timestamp; use arrow::datatypes::TimeUnit::Nanosecond; +use arrow::datatypes::{DataType, Field}; use std::any::Any; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::simplify::{ExprSimplifyResult, SimplifyInfo}; use datafusion_expr::{ - ColumnarValue, Documentation, Expr, ReturnInfo, ReturnTypeArgs, ScalarUDFImpl, - Signature, Volatility, + ColumnarValue, Documentation, Expr, ReturnFieldArgs, ScalarUDFImpl, Signature, + Volatility, }; use datafusion_macros::user_doc; @@ -77,15 +77,16 @@ impl ScalarUDFImpl for NowFunc { &self.signature } - fn return_type_from_args(&self, _args: ReturnTypeArgs) -> Result { - Ok(ReturnInfo::new_non_nullable(Timestamp( - Nanosecond, - Some("+00:00".into()), - ))) + fn return_field_from_args(&self, _args: ReturnFieldArgs) -> Result { + Ok(Field::new( + self.name(), + Timestamp(Nanosecond, Some("+00:00".into())), + false, + )) } fn return_type(&self, _arg_types: &[DataType]) -> Result { - internal_err!("return_type_from_args should be called instead") + internal_err!("return_field_from_args should be called instead") } fn invoke_with_args( diff --git a/datafusion/functions/src/datetime/to_char.rs b/datafusion/functions/src/datetime/to_char.rs index 8b2e5ad87471..be3917092ba9 100644 --- a/datafusion/functions/src/datetime/to_char.rs +++ b/datafusion/functions/src/datetime/to_char.rs @@ -303,7 +303,7 @@ mod tests { TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, TimestampSecondArray, }; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field, TimeUnit}; use chrono::{NaiveDateTime, Timelike}; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, 
ScalarUDFImpl}; @@ -385,10 +385,15 @@ mod tests { ]; for (value, format, expected) in scalar_data { + let arg_fields = vec![ + Field::new("a", value.data_type(), false), + Field::new("a", format.data_type(), false), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(value), ColumnarValue::Scalar(format)], + arg_fields: arg_fields.iter().collect(), number_rows: 1, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -465,13 +470,18 @@ mod tests { for (value, format, expected) in scalar_array_data { let batch_len = format.len(); + let arg_fields = vec![ + Field::new("a", value.data_type(), false), + Field::new("a", format.data_type().to_owned(), false), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(value), ColumnarValue::Array(Arc::new(format) as ArrayRef), ], + arg_fields: arg_fields.iter().collect(), number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -596,13 +606,18 @@ mod tests { for (value, format, expected) in array_scalar_data { let batch_len = value.len(); + let arg_fields = vec![ + Field::new("a", value.data_type().clone(), false), + Field::new("a", format.data_type(), false), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Array(value as ArrayRef), ColumnarValue::Scalar(format), ], + arg_fields: arg_fields.iter().collect(), number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -618,13 +633,18 @@ mod tests { for (value, format, expected) in array_array_data { let batch_len = value.len(); + let arg_fields = vec![ + Field::new("a", value.data_type().clone(), false), + Field::new("a", format.data_type().clone(), false), + 
]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Array(value), ColumnarValue::Array(Arc::new(format) as ArrayRef), ], + arg_fields: arg_fields.iter().collect(), number_rows: batch_len, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }; let result = ToCharFunc::new() .invoke_with_args(args) @@ -643,10 +663,12 @@ mod tests { // // invalid number of arguments + let arg_field = Field::new("a", DataType::Int32, true); let args = datafusion_expr::ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(ScalarValue::Int32(Some(1)))], + arg_fields: vec![&arg_field], number_rows: 1, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }; let result = ToCharFunc::new().invoke_with_args(args); assert_eq!( @@ -655,13 +677,18 @@ mod tests { ); // invalid type + let arg_fields = vec![ + Field::new("a", DataType::Utf8, true), + Field::new("a", DataType::Timestamp(TimeUnit::Nanosecond, None), true), + ]; let args = datafusion_expr::ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Int32(Some(1))), ColumnarValue::Scalar(ScalarValue::TimestampNanosecond(Some(1), None)), ], + arg_fields: arg_fields.iter().collect(), number_rows: 1, - return_type: &DataType::Utf8, + return_field: &Field::new("f", DataType::Utf8, true), }; let result = ToCharFunc::new().invoke_with_args(args); assert_eq!( diff --git a/datafusion/functions/src/datetime/to_date.rs b/datafusion/functions/src/datetime/to_date.rs index 91740b2c31c1..09635932760d 100644 --- a/datafusion/functions/src/datetime/to_date.rs +++ b/datafusion/functions/src/datetime/to_date.rs @@ -163,14 +163,32 @@ impl ScalarUDFImpl for ToDateFunc { #[cfg(test)] mod tests { use arrow::array::{Array, Date32Array, GenericStringArray, StringViewArray}; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field}; use arrow::{compute::kernels::cast_utils::Parser, datatypes::Date32Type}; - use 
datafusion_common::ScalarValue; + use datafusion_common::{DataFusionError, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; use std::sync::Arc; use super::ToDateFunc; + fn invoke_to_date_with_args( + args: Vec, + number_rows: usize, + ) -> Result { + let arg_fields = args + .iter() + .map(|arg| Field::new("a", arg.data_type(), true)) + .collect::>(); + + let args = datafusion_expr::ScalarFunctionArgs { + args, + arg_fields: arg_fields.iter().collect(), + number_rows, + return_field: &Field::new("f", DataType::Date32, true), + }; + ToDateFunc::new().invoke_with_args(args) + } + #[test] fn test_to_date_without_format() { struct TestCase { @@ -208,12 +226,8 @@ mod tests { } fn test_scalar(sv: ScalarValue, tc: &TestCase) { - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(sv)], - number_rows: 1, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + let to_date_result = + invoke_to_date_with_args(vec![ColumnarValue::Scalar(sv)], 1); match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { @@ -234,12 +248,10 @@ mod tests { { let date_array = A::from(vec![tc.date_str]); let batch_len = date_array.len(); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Array(Arc::new(date_array))], - number_rows: batch_len, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + let to_date_result = invoke_to_date_with_args( + vec![ColumnarValue::Array(Arc::new(date_array))], + batch_len, + ); match to_date_result { Ok(ColumnarValue::Array(a)) => { @@ -328,15 +340,13 @@ mod tests { fn test_scalar(sv: ScalarValue, tc: &TestCase) { let format_scalar = ScalarValue::Utf8(Some(tc.format_str.to_string())); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let to_date_result = invoke_to_date_with_args( + vec![ ColumnarValue::Scalar(sv), 
ColumnarValue::Scalar(format_scalar), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + 1, + ); match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { @@ -358,15 +368,13 @@ mod tests { let format_array = A::from(vec![tc.format_str]); let batch_len = date_array.len(); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let to_date_result = invoke_to_date_with_args( + vec![ ColumnarValue::Array(Arc::new(date_array)), ColumnarValue::Array(Arc::new(format_array)), ], - number_rows: batch_len, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + batch_len, + ); match to_date_result { Ok(ColumnarValue::Array(a)) => { @@ -398,16 +406,14 @@ mod tests { let format1_scalar = ScalarValue::Utf8(Some("%Y-%m-%d".into())); let format2_scalar = ScalarValue::Utf8(Some("%Y/%m/%d".into())); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ + let to_date_result = invoke_to_date_with_args( + vec![ ColumnarValue::Scalar(formatted_date_scalar), ColumnarValue::Scalar(format1_scalar), ColumnarValue::Scalar(format2_scalar), ], - number_rows: 1, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + 1, + ); match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { @@ -431,12 +437,10 @@ mod tests { for date_str in test_cases { let formatted_date_scalar = ScalarValue::Utf8(Some(date_str.into())); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(formatted_date_scalar)], - number_rows: 1, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + let to_date_result = invoke_to_date_with_args( + vec![ColumnarValue::Scalar(formatted_date_scalar)], + 1, + ); match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { @@ -453,12 
+457,8 @@ mod tests { let date_str = "20241231"; let date_scalar = ScalarValue::Utf8(Some(date_str.into())); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(date_scalar)], - number_rows: 1, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + let to_date_result = + invoke_to_date_with_args(vec![ColumnarValue::Scalar(date_scalar)], 1); match to_date_result { Ok(ColumnarValue::Scalar(ScalarValue::Date32(date_val))) => { @@ -478,12 +478,8 @@ mod tests { let date_str = "202412311"; let date_scalar = ScalarValue::Utf8(Some(date_str.into())); - let args = datafusion_expr::ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(date_scalar)], - number_rows: 1, - return_type: &DataType::Date32, - }; - let to_date_result = ToDateFunc::new().invoke_with_args(args); + let to_date_result = + invoke_to_date_with_args(vec![ColumnarValue::Scalar(date_scalar)], 1); if let Ok(ColumnarValue::Scalar(ScalarValue::Date32(_))) = to_date_result { panic!( diff --git a/datafusion/functions/src/datetime/to_local_time.rs b/datafusion/functions/src/datetime/to_local_time.rs index 8dbef90cdc3f..5cf9b785b503 100644 --- a/datafusion/functions/src/datetime/to_local_time.rs +++ b/datafusion/functions/src/datetime/to_local_time.rs @@ -407,9 +407,9 @@ impl ScalarUDFImpl for ToLocalTimeFunc { mod tests { use std::sync::Arc; - use arrow::array::{types::TimestampNanosecondType, TimestampNanosecondArray}; + use arrow::array::{types::TimestampNanosecondType, Array, TimestampNanosecondArray}; use arrow::compute::kernels::cast_utils::string_to_timestamp_nanos; - use arrow::datatypes::{DataType, TimeUnit}; + use arrow::datatypes::{DataType, Field, TimeUnit}; use chrono::NaiveDateTime; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; @@ -538,11 +538,13 @@ mod tests { } fn test_to_local_time_helper(input: ScalarValue, expected: ScalarValue) { + let arg_field = 
Field::new("a", input.data_type(), true); let res = ToLocalTimeFunc::new() .invoke_with_args(ScalarFunctionArgs { args: vec![ColumnarValue::Scalar(input)], + arg_fields: vec![&arg_field], number_rows: 1, - return_type: &expected.data_type(), + return_field: &Field::new("f", expected.data_type(), true), }) .unwrap(); match res { @@ -602,10 +604,16 @@ mod tests { .map(|s| Some(string_to_timestamp_nanos(s).unwrap())) .collect::(); let batch_size = input.len(); + let arg_field = Field::new("a", input.data_type().clone(), true); let args = ScalarFunctionArgs { args: vec![ColumnarValue::Array(Arc::new(input))], + arg_fields: vec![&arg_field], number_rows: batch_size, - return_type: &DataType::Timestamp(TimeUnit::Nanosecond, None), + return_field: &Field::new( + "f", + DataType::Timestamp(TimeUnit::Nanosecond, None), + true, + ), }; let result = ToLocalTimeFunc::new().invoke_with_args(args).unwrap(); if let ColumnarValue::Array(result) = result { diff --git a/datafusion/functions/src/datetime/to_timestamp.rs b/datafusion/functions/src/datetime/to_timestamp.rs index 52c86733f332..c6aab61328eb 100644 --- a/datafusion/functions/src/datetime/to_timestamp.rs +++ b/datafusion/functions/src/datetime/to_timestamp.rs @@ -639,7 +639,7 @@ mod tests { TimestampNanosecondArray, TimestampSecondArray, }; use arrow::array::{ArrayRef, Int64Array, StringBuilder}; - use arrow::datatypes::TimeUnit; + use arrow::datatypes::{Field, TimeUnit}; use chrono::Utc; use datafusion_common::{assert_contains, DataFusionError, ScalarValue}; use datafusion_expr::ScalarFunctionImplementation; @@ -1012,11 +1012,13 @@ mod tests { for udf in &udfs { for array in arrays { let rt = udf.return_type(&[array.data_type()]).unwrap(); + let arg_field = Field::new("arg", array.data_type().clone(), true); assert!(matches!(rt, Timestamp(_, Some(_)))); let args = datafusion_expr::ScalarFunctionArgs { args: vec![array.clone()], + arg_fields: vec![&arg_field], number_rows: 4, - return_type: &rt, + return_field: 
&Field::new("f", rt, true), }; let res = udf .invoke_with_args(args) @@ -1060,10 +1062,12 @@ mod tests { for array in arrays { let rt = udf.return_type(&[array.data_type()]).unwrap(); assert!(matches!(rt, Timestamp(_, None))); + let arg_field = Field::new("arg", array.data_type().clone(), true); let args = datafusion_expr::ScalarFunctionArgs { args: vec![array.clone()], + arg_fields: vec![&arg_field], number_rows: 5, - return_type: &rt, + return_field: &Field::new("f", rt, true), }; let res = udf .invoke_with_args(args) diff --git a/datafusion/functions/src/math/log.rs b/datafusion/functions/src/math/log.rs index fd135f4c5ec0..d1f40e3b1ad1 100644 --- a/datafusion/functions/src/math/log.rs +++ b/datafusion/functions/src/math/log.rs @@ -256,6 +256,7 @@ mod tests { use arrow::array::{Float32Array, Float64Array, Int64Array}; use arrow::compute::SortOptions; + use arrow::datatypes::Field; use datafusion_common::cast::{as_float32_array, as_float64_array}; use datafusion_common::DFSchema; use datafusion_expr::execution_props::ExecutionProps; @@ -264,6 +265,10 @@ mod tests { #[test] #[should_panic] fn test_log_invalid_base_type() { + let arg_fields = vec![ + Field::new("a", DataType::Float64, false), + Field::new("a", DataType::Int64, false), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float64Array::from(vec![ @@ -271,20 +276,23 @@ mod tests { ]))), // num ColumnarValue::Array(Arc::new(Int64Array::from(vec![5, 10, 15, 20]))), ], + arg_fields: arg_fields.iter().collect(), number_rows: 4, - return_type: &DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), }; let _ = LogFunc::new().invoke_with_args(args); } #[test] fn test_log_invalid_value() { + let arg_field = Field::new("a", DataType::Int64, false); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Int64Array::from(vec![10]))), // num ], + arg_fields: vec![&arg_field], number_rows: 1, - return_type: &DataType::Float64, + return_field: 
&Field::new("f", DataType::Float64, true), }; let result = LogFunc::new().invoke_with_args(args); @@ -293,12 +301,14 @@ mod tests { #[test] fn test_log_scalar_f32_unary() { + let arg_field = Field::new("a", DataType::Float32, false); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Float32(Some(10.0))), // num ], + arg_fields: vec![&arg_field], number_rows: 1, - return_type: &DataType::Float32, + return_field: &Field::new("f", DataType::Float32, true), }; let result = LogFunc::new() .invoke_with_args(args) @@ -320,12 +330,14 @@ mod tests { #[test] fn test_log_scalar_f64_unary() { + let arg_field = Field::new("a", DataType::Float64, false); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Float64(Some(10.0))), // num ], + arg_fields: vec![&arg_field], number_rows: 1, - return_type: &DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), }; let result = LogFunc::new() .invoke_with_args(args) @@ -347,13 +359,18 @@ mod tests { #[test] fn test_log_scalar_f32() { + let arg_fields = vec![ + Field::new("a", DataType::Float32, false), + Field::new("a", DataType::Float32, false), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Float32(Some(2.0))), // num ColumnarValue::Scalar(ScalarValue::Float32(Some(32.0))), // num ], + arg_fields: arg_fields.iter().collect(), number_rows: 1, - return_type: &DataType::Float32, + return_field: &Field::new("f", DataType::Float32, true), }; let result = LogFunc::new() .invoke_with_args(args) @@ -375,13 +392,18 @@ mod tests { #[test] fn test_log_scalar_f64() { + let arg_fields = vec![ + Field::new("a", DataType::Float64, false), + Field::new("a", DataType::Float64, false), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Scalar(ScalarValue::Float64(Some(2.0))), // num ColumnarValue::Scalar(ScalarValue::Float64(Some(64.0))), // num ], + arg_fields: arg_fields.iter().collect(), number_rows: 1, - return_type: 
&DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), }; let result = LogFunc::new() .invoke_with_args(args) @@ -403,14 +425,16 @@ mod tests { #[test] fn test_log_f64_unary() { + let arg_field = Field::new("a", DataType::Float64, false); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float64Array::from(vec![ 10.0, 100.0, 1000.0, 10000.0, ]))), // num ], + arg_fields: vec![&arg_field], number_rows: 4, - return_type: &DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), }; let result = LogFunc::new() .invoke_with_args(args) @@ -435,14 +459,16 @@ mod tests { #[test] fn test_log_f32_unary() { + let arg_field = Field::new("a", DataType::Float32, false); let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float32Array::from(vec![ 10.0, 100.0, 1000.0, 10000.0, ]))), // num ], + arg_fields: vec![&arg_field], number_rows: 4, - return_type: &DataType::Float32, + return_field: &Field::new("f", DataType::Float32, true), }; let result = LogFunc::new() .invoke_with_args(args) @@ -467,6 +493,10 @@ mod tests { #[test] fn test_log_f64() { + let arg_fields = vec![ + Field::new("a", DataType::Float64, false), + Field::new("a", DataType::Float64, false), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float64Array::from(vec![ @@ -476,8 +506,9 @@ mod tests { 8.0, 4.0, 81.0, 625.0, ]))), // num ], + arg_fields: arg_fields.iter().collect(), number_rows: 4, - return_type: &DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), }; let result = LogFunc::new() .invoke_with_args(args) @@ -502,6 +533,10 @@ mod tests { #[test] fn test_log_f32() { + let arg_fields = vec![ + Field::new("a", DataType::Float32, false), + Field::new("a", DataType::Float32, false), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float32Array::from(vec![ @@ -511,8 +546,9 @@ mod tests { 8.0, 4.0, 81.0, 625.0, ]))), // num ], + arg_fields: 
arg_fields.iter().collect(), number_rows: 4, - return_type: &DataType::Float32, + return_field: &Field::new("f", DataType::Float32, true), }; let result = LogFunc::new() .invoke_with_args(args) diff --git a/datafusion/functions/src/math/power.rs b/datafusion/functions/src/math/power.rs index 028ec2fef793..8876e3fe2787 100644 --- a/datafusion/functions/src/math/power.rs +++ b/datafusion/functions/src/math/power.rs @@ -187,12 +187,17 @@ fn is_log(func: &ScalarUDF) -> bool { #[cfg(test)] mod tests { use arrow::array::Float64Array; + use arrow::datatypes::Field; use datafusion_common::cast::{as_float64_array, as_int64_array}; use super::*; #[test] fn test_power_f64() { + let arg_fields = vec![ + Field::new("a", DataType::Float64, true), + Field::new("a", DataType::Float64, true), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Float64Array::from(vec![ @@ -202,8 +207,9 @@ mod tests { 3.0, 2.0, 4.0, 4.0, ]))), // exponent ], + arg_fields: arg_fields.iter().collect(), number_rows: 4, - return_type: &DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), }; let result = PowerFunc::new() .invoke_with_args(args) @@ -227,13 +233,18 @@ mod tests { #[test] fn test_power_i64() { + let arg_fields = vec![ + Field::new("a", DataType::Int64, true), + Field::new("a", DataType::Int64, true), + ]; let args = ScalarFunctionArgs { args: vec![ ColumnarValue::Array(Arc::new(Int64Array::from(vec![2, 2, 3, 5]))), // base ColumnarValue::Array(Arc::new(Int64Array::from(vec![3, 2, 4, 4]))), // exponent ], + arg_fields: arg_fields.iter().collect(), number_rows: 4, - return_type: &DataType::Int64, + return_field: &Field::new("f", DataType::Int64, true), }; let result = PowerFunc::new() .invoke_with_args(args) diff --git a/datafusion/functions/src/math/signum.rs b/datafusion/functions/src/math/signum.rs index ba5422afa768..7414e6e138ab 100644 --- a/datafusion/functions/src/math/signum.rs +++ b/datafusion/functions/src/math/signum.rs @@ -138,7 
+138,7 @@ mod test { use std::sync::Arc; use arrow::array::{ArrayRef, Float32Array, Float64Array}; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field}; use datafusion_common::cast::{as_float32_array, as_float64_array}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; @@ -157,10 +157,12 @@ mod test { f32::INFINITY, f32::NEG_INFINITY, ])); + let arg_fields = [Field::new("a", DataType::Float32, false)]; let args = ScalarFunctionArgs { args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], + arg_fields: arg_fields.iter().collect(), number_rows: array.len(), - return_type: &DataType::Float32, + return_field: &Field::new("f", DataType::Float32, true), }; let result = SignumFunc::new() .invoke_with_args(args) @@ -201,10 +203,12 @@ mod test { f64::INFINITY, f64::NEG_INFINITY, ])); + let arg_fields = [Field::new("a", DataType::Float64, false)]; let args = ScalarFunctionArgs { args: vec![ColumnarValue::Array(Arc::clone(&array) as ArrayRef)], + arg_fields: arg_fields.iter().collect(), number_rows: array.len(), - return_type: &DataType::Float64, + return_field: &Field::new("f", DataType::Float64, true), }; let result = SignumFunc::new() .invoke_with_args(args) diff --git a/datafusion/functions/src/regex/regexpcount.rs b/datafusion/functions/src/regex/regexpcount.rs index 8cb1a4ff3d60..d536d3531af4 100644 --- a/datafusion/functions/src/regex/regexpcount.rs +++ b/datafusion/functions/src/regex/regexpcount.rs @@ -619,6 +619,7 @@ fn count_matches( mod tests { use super::*; use arrow::array::{GenericStringArray, StringViewArray}; + use arrow::datatypes::Field; use datafusion_expr::ScalarFunctionArgs; #[test] @@ -647,6 +648,27 @@ mod tests { test_case_regexp_count_cache_check::>(); } + fn regexp_count_with_scalar_values(args: &[ScalarValue]) -> Result { + let args_values = args + .iter() + .map(|sv| ColumnarValue::Scalar(sv.clone())) + .collect(); + + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, a)| 
Field::new(format!("arg_{idx}"), a.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); + + RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { + args: args_values, + arg_fields, + number_rows: args.len(), + return_field: &Field::new("f", Int64, true), + }) + } + fn test_case_sensitive_regexp_count_scalar() { let values = ["", "aabca", "abcabc", "abcAbcab", "abcabcabc"]; let regex = "abc"; @@ -657,11 +679,7 @@ mod tests { let v_sv = ScalarValue::Utf8(Some(v.to_string())); let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); let expected = expected.get(pos).cloned(); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], - number_rows: 2, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -672,11 +690,7 @@ mod tests { // largeutf8 let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], - number_rows: 2, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -687,11 +701,7 @@ mod tests { // utf8view let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ColumnarValue::Scalar(v_sv), ColumnarValue::Scalar(regex_sv)], - number_rows: 2, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv]); match re { 
Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -713,15 +723,7 @@ mod tests { let regex_sv = ScalarValue::Utf8(Some(regex.to_string())); let start_sv = ScalarValue::Int64(Some(start)); let expected = expected.get(pos).cloned(); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ], - number_rows: 3, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -732,15 +734,7 @@ mod tests { // largeutf8 let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ], - number_rows: 3, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -751,15 +745,7 @@ mod tests { // utf8view let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ], - number_rows: 3, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[v_sv, regex_sv, start_sv.clone()]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -783,16 
+769,13 @@ mod tests { let start_sv = ScalarValue::Int64(Some(start)); let flags_sv = ScalarValue::Utf8(Some(flags.to_string())); let expected = expected.get(pos).cloned(); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - number_rows: 4, - return_type: &Int64, - }); + + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -804,16 +787,13 @@ mod tests { let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(Some(regex.to_string())); let flags_sv = ScalarValue::LargeUtf8(Some(flags.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - number_rows: 4, - return_type: &Int64, - }); + + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -825,16 +805,13 @@ mod tests { let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(Some(regex.to_string())); let flags_sv = ScalarValue::Utf8View(Some(flags.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - number_rows: 4, - return_type: &Int64, - }); + + let re = 
regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -907,16 +884,12 @@ mod tests { let start_sv = ScalarValue::Int64(Some(start)); let flags_sv = ScalarValue::Utf8(flags.get(pos).map(|f| f.to_string())); let expected = expected.get(pos).cloned(); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - number_rows: 4, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -928,16 +901,12 @@ mod tests { let v_sv = ScalarValue::LargeUtf8(Some(v.to_string())); let regex_sv = ScalarValue::LargeUtf8(regex.get(pos).map(|s| s.to_string())); let flags_sv = ScalarValue::LargeUtf8(flags.get(pos).map(|f| f.to_string())); - let re = RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - number_rows: 4, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); @@ -949,16 +918,12 @@ mod tests { let v_sv = ScalarValue::Utf8View(Some(v.to_string())); let regex_sv = ScalarValue::Utf8View(regex.get(pos).map(|s| s.to_string())); let flags_sv = ScalarValue::Utf8View(flags.get(pos).map(|f| f.to_string())); - let re = 
RegexpCountFunc::new().invoke_with_args(ScalarFunctionArgs { - args: vec![ - ColumnarValue::Scalar(v_sv), - ColumnarValue::Scalar(regex_sv), - ColumnarValue::Scalar(start_sv.clone()), - ColumnarValue::Scalar(flags_sv.clone()), - ], - number_rows: 4, - return_type: &Int64, - }); + let re = regexp_count_with_scalar_values(&[ + v_sv, + regex_sv, + start_sv.clone(), + flags_sv.clone(), + ]); match re { Ok(ColumnarValue::Scalar(ScalarValue::Int64(v))) => { assert_eq!(v, expected, "regexp_count scalar test failed"); diff --git a/datafusion/functions/src/string/concat.rs b/datafusion/functions/src/string/concat.rs index c47d08d579e4..fe0a5915fe20 100644 --- a/datafusion/functions/src/string/concat.rs +++ b/datafusion/functions/src/string/concat.rs @@ -376,6 +376,7 @@ mod tests { use crate::utils::test::test_function; use arrow::array::{Array, LargeStringArray, StringViewArray}; use arrow::array::{ArrayRef, StringArray}; + use arrow::datatypes::Field; use DataType::*; #[test] @@ -468,11 +469,19 @@ mod tests { None, Some("b"), ]))); + let arg_fields = vec![ + Field::new("a", Utf8, true), + Field::new("a", Utf8, true), + Field::new("a", Utf8, true), + Field::new("a", Utf8View, true), + Field::new("a", Utf8View, true), + ]; let args = ScalarFunctionArgs { args: vec![c0, c1, c2, c3, c4], + arg_fields: arg_fields.iter().collect(), number_rows: 3, - return_type: &Utf8, + return_field: &Field::new("f", Utf8, true), }; let result = ConcatFunc::new().invoke_with_args(args)?; diff --git a/datafusion/functions/src/string/concat_ws.rs b/datafusion/functions/src/string/concat_ws.rs index c2bad206db15..79a5d34fb4c4 100644 --- a/datafusion/functions/src/string/concat_ws.rs +++ b/datafusion/functions/src/string/concat_ws.rs @@ -403,10 +403,10 @@ fn is_null(expr: &Expr) -> bool { mod tests { use std::sync::Arc; + use crate::string::concat_ws::ConcatWsFunc; use arrow::array::{Array, ArrayRef, StringArray}; use arrow::datatypes::DataType::Utf8; - - use crate::string::concat_ws::ConcatWsFunc; 
+ use arrow::datatypes::Field; use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; @@ -481,10 +481,16 @@ mod tests { Some("z"), ]))); + let arg_fields = vec![ + Field::new("a", Utf8, true), + Field::new("a", Utf8, true), + Field::new("a", Utf8, true), + ]; let args = ScalarFunctionArgs { args: vec![c0, c1, c2], + arg_fields: arg_fields.iter().collect(), number_rows: 3, - return_type: &Utf8, + return_field: &Field::new("f", Utf8, true), }; let result = ConcatWsFunc::new().invoke_with_args(args)?; @@ -511,10 +517,16 @@ mod tests { Some("z"), ]))); + let arg_fields = vec![ + Field::new("a", Utf8, true), + Field::new("a", Utf8, true), + Field::new("a", Utf8, true), + ]; let args = ScalarFunctionArgs { args: vec![c0, c1, c2], + arg_fields: arg_fields.iter().collect(), number_rows: 3, - return_type: &Utf8, + return_field: &Field::new("f", Utf8, true), }; let result = ConcatWsFunc::new().invoke_with_args(args)?; diff --git a/datafusion/functions/src/string/contains.rs b/datafusion/functions/src/string/contains.rs index 05a3edf61c5a..00fd20ff3479 100644 --- a/datafusion/functions/src/string/contains.rs +++ b/datafusion/functions/src/string/contains.rs @@ -151,7 +151,7 @@ fn contains(args: &[ArrayRef]) -> Result { mod test { use super::ContainsFunc; use arrow::array::{BooleanArray, StringArray}; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field}; use datafusion_common::ScalarValue; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use std::sync::Arc; @@ -164,11 +164,16 @@ mod test { Some("yyy?()"), ]))); let scalar = ColumnarValue::Scalar(ScalarValue::Utf8(Some("x?(".to_string()))); + let arg_fields = vec![ + Field::new("a", DataType::Utf8, true), + Field::new("a", DataType::Utf8, true), + ]; let args = ScalarFunctionArgs { args: vec![array, scalar], + arg_fields: arg_fields.iter().collect(), number_rows: 2, - return_type: 
&DataType::Boolean, + return_field: &Field::new("f", DataType::Boolean, true), }; let actual = udf.invoke_with_args(args).unwrap(); diff --git a/datafusion/functions/src/string/lower.rs b/datafusion/functions/src/string/lower.rs index 226275b13999..1dc6e9d28367 100644 --- a/datafusion/functions/src/string/lower.rs +++ b/datafusion/functions/src/string/lower.rs @@ -98,15 +98,19 @@ impl ScalarUDFImpl for LowerFunc { mod tests { use super::*; use arrow::array::{Array, ArrayRef, StringArray}; + use arrow::datatypes::DataType::Utf8; + use arrow::datatypes::Field; use std::sync::Arc; fn to_lower(input: ArrayRef, expected: ArrayRef) -> Result<()> { let func = LowerFunc::new(); + let arg_fields = [Field::new("a", input.data_type().clone(), true)]; let args = ScalarFunctionArgs { number_rows: input.len(), args: vec![ColumnarValue::Array(input)], - return_type: &DataType::Utf8, + arg_fields: arg_fields.iter().collect(), + return_field: &Field::new("f", Utf8, true), }; let result = match func.invoke_with_args(args)? { diff --git a/datafusion/functions/src/string/upper.rs b/datafusion/functions/src/string/upper.rs index 2fec7305d183..06a9bd9720d6 100644 --- a/datafusion/functions/src/string/upper.rs +++ b/datafusion/functions/src/string/upper.rs @@ -97,15 +97,19 @@ impl ScalarUDFImpl for UpperFunc { mod tests { use super::*; use arrow::array::{Array, ArrayRef, StringArray}; + use arrow::datatypes::DataType::Utf8; + use arrow::datatypes::Field; use std::sync::Arc; fn to_upper(input: ArrayRef, expected: ArrayRef) -> Result<()> { let func = UpperFunc::new(); + let arg_field = Field::new("a", input.data_type().clone(), true); let args = ScalarFunctionArgs { number_rows: input.len(), args: vec![ColumnarValue::Array(input)], - return_type: &DataType::Utf8, + arg_fields: vec![&arg_field], + return_field: &Field::new("f", Utf8, true), }; let result = match func.invoke_with_args(args)? 
{ diff --git a/datafusion/functions/src/unicode/find_in_set.rs b/datafusion/functions/src/unicode/find_in_set.rs index c4a9f067e9f4..6b5df89e860f 100644 --- a/datafusion/functions/src/unicode/find_in_set.rs +++ b/datafusion/functions/src/unicode/find_in_set.rs @@ -348,7 +348,7 @@ mod tests { use crate::unicode::find_in_set::FindInSetFunc; use crate::utils::test::test_function; use arrow::array::{Array, Int32Array, StringArray}; - use arrow::datatypes::DataType::Int32; + use arrow::datatypes::{DataType::Int32, Field}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl}; use std::sync::Arc; @@ -471,10 +471,17 @@ mod tests { }) .unwrap_or(1); let return_type = fis.return_type(&type_array)?; + let arg_fields_owned = args + .iter() + .enumerate() + .map(|(idx, a)| Field::new(format!("arg_{idx}"), a.data_type(), true)) + .collect::>(); + let arg_fields = arg_fields_owned.iter().collect::>(); let result = fis.invoke_with_args(ScalarFunctionArgs { args, + arg_fields, number_rows: cardinality, - return_type: &return_type, + return_field: &Field::new("f", return_type, true), }); assert!(result.is_ok()); diff --git a/datafusion/functions/src/unicode/strpos.rs b/datafusion/functions/src/unicode/strpos.rs index b3bc73a29585..b33a1ca7713a 100644 --- a/datafusion/functions/src/unicode/strpos.rs +++ b/datafusion/functions/src/unicode/strpos.rs @@ -22,7 +22,7 @@ use crate::utils::{make_scalar_function, utf8_to_int_type}; use arrow::array::{ ArrayRef, ArrowPrimitiveType, AsArray, PrimitiveArray, StringArrayType, }; -use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type}; +use arrow::datatypes::{ArrowNativeType, DataType, Field, Int32Type, Int64Type}; use datafusion_common::types::logical_string; use datafusion_common::{exec_err, internal_err, Result}; use datafusion_expr::{ @@ -88,16 +88,22 @@ impl ScalarUDFImpl for StrposFunc { } fn return_type(&self, _arg_types: &[DataType]) -> Result { - 
internal_err!("return_type_from_args should be used instead") + internal_err!("return_field_from_args should be used instead") } - fn return_type_from_args( + fn return_field_from_args( &self, - args: datafusion_expr::ReturnTypeArgs, - ) -> Result { - utf8_to_int_type(&args.arg_types[0], "strpos/instr/position").map(|data_type| { - datafusion_expr::ReturnInfo::new(data_type, args.nullables.iter().any(|x| *x)) - }) + args: datafusion_expr::ReturnFieldArgs, + ) -> Result { + utf8_to_int_type(args.arg_fields[0].data_type(), "strpos/instr/position").map( + |data_type| { + Field::new( + self.name(), + data_type, + args.arg_fields.iter().any(|x| x.is_nullable()), + ) + }, + ) } fn invoke_with_args( @@ -228,7 +234,7 @@ mod tests { use arrow::array::{Array, Int32Array, Int64Array}; use arrow::datatypes::DataType::{Int32, Int64}; - use arrow::datatypes::DataType; + use arrow::datatypes::{DataType, Field}; use datafusion_common::{Result, ScalarValue}; use datafusion_expr::{ColumnarValue, ScalarUDFImpl}; @@ -321,15 +327,15 @@ mod tests { fn nullable_return_type() { fn get_nullable(string_array_nullable: bool, substring_nullable: bool) -> bool { let strpos = StrposFunc::new(); - let args = datafusion_expr::ReturnTypeArgs { - arg_types: &[DataType::Utf8, DataType::Utf8], - nullables: &[string_array_nullable, substring_nullable], + let args = datafusion_expr::ReturnFieldArgs { + arg_fields: &[ + Field::new("f1", DataType::Utf8, string_array_nullable), + Field::new("f2", DataType::Utf8, substring_nullable), + ], scalar_arguments: &[None::<&ScalarValue>, None::<&ScalarValue>], }; - let (_, nullable) = strpos.return_type_from_args(args).unwrap().into_parts(); - - nullable + strpos.return_field_from_args(args).unwrap().is_nullable() } assert!(!get_nullable(false, false)); diff --git a/datafusion/functions/src/utils.rs b/datafusion/functions/src/utils.rs index 47f3121ba2ce..0d6367565921 100644 --- a/datafusion/functions/src/utils.rs +++ b/datafusion/functions/src/utils.rs @@ -133,7 
+133,7 @@ pub mod test { let expected: Result> = $EXPECTED; let func = $FUNC; - let type_array = $ARGS.iter().map(|arg| arg.data_type()).collect::>(); + let data_array = $ARGS.iter().map(|arg| arg.data_type()).collect::>(); let cardinality = $ARGS .iter() .fold(Option::::None, |acc, arg| match arg { @@ -153,19 +153,28 @@ pub mod test { ColumnarValue::Array(a) => a.null_count() > 0, }).collect::>(); - let return_info = func.return_type_from_args(datafusion_expr::ReturnTypeArgs { - arg_types: &type_array, + let field_array = data_array.into_iter().zip(nullables).enumerate() + .map(|(idx, (data_type, nullable))| arrow::datatypes::Field::new(format!("field_{idx}"), data_type, nullable)) + .collect::>(); + + let return_field = func.return_field_from_args(datafusion_expr::ReturnFieldArgs { + arg_fields: &field_array, scalar_arguments: &scalar_arguments_refs, - nullables: &nullables }); + let arg_fields_owned = $ARGS.iter() + .enumerate() + .map(|(idx, arg)| arrow::datatypes::Field::new(format!("f_{idx}"), arg.data_type(), true)) + .collect::>(); match expected { Ok(expected) => { - assert_eq!(return_info.is_ok(), true); - let (return_type, _nullable) = return_info.unwrap().into_parts(); - assert_eq!(return_type, $EXPECTED_DATA_TYPE); + assert_eq!(return_field.is_ok(), true); + let return_field = return_field.unwrap(); + let return_type = return_field.data_type(); + assert_eq!(return_type, &$EXPECTED_DATA_TYPE); - let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_type: &return_type}); + let arg_fields = arg_fields_owned.iter().collect::>(); + let result = func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, arg_fields, number_rows: cardinality, return_field: &return_field}); assert_eq!(result.is_ok(), true, "function returned an error: {}", result.unwrap_err()); let result = result.unwrap().to_array(cardinality).expect("Failed to convert to array"); @@ -179,17 +188,18 @@ pub mod test { }; 
} Err(expected_error) => { - if return_info.is_err() { - match return_info { + if return_field.is_err() { + match return_field { Ok(_) => assert!(false, "expected error"), Err(error) => { datafusion_common::assert_contains!(expected_error.strip_backtrace(), error.strip_backtrace()); } } } else { - let (return_type, _nullable) = return_info.unwrap().into_parts(); + let return_field = return_field.unwrap(); + let arg_fields = arg_fields_owned.iter().collect::>(); // invoke is expected error - cannot use .expect_err() due to Debug not being implemented - match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, number_rows: cardinality, return_type: &return_type}) { + match func.invoke_with_args(datafusion_expr::ScalarFunctionArgs{args: $ARGS, arg_fields, number_rows: cardinality, return_field: &return_field}) { Ok(_) => assert!(false, "expected error"), Err(error) => { assert!(expected_error.strip_backtrace().starts_with(&error.strip_backtrace())); diff --git a/datafusion/physical-expr-common/src/physical_expr.rs b/datafusion/physical-expr-common/src/physical_expr.rs index 3bc41d2652d9..685398c352e2 100644 --- a/datafusion/physical-expr-common/src/physical_expr.rs +++ b/datafusion/physical-expr-common/src/physical_expr.rs @@ -25,7 +25,7 @@ use crate::utils::scatter; use arrow::array::BooleanArray; use arrow::compute::filter_record_batch; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; use datafusion_common::{internal_err, not_impl_err, Result, ScalarValue}; @@ -71,11 +71,23 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + DynEq + DynHash { /// downcast to a specific implementation. 
fn as_any(&self) -> &dyn Any; /// Get the data type of this expression, given the schema of the input - fn data_type(&self, input_schema: &Schema) -> Result; + fn data_type(&self, input_schema: &Schema) -> Result { + Ok(self.return_field(input_schema)?.data_type().to_owned()) + } /// Determine whether this expression is nullable, given the schema of the input - fn nullable(&self, input_schema: &Schema) -> Result; + fn nullable(&self, input_schema: &Schema) -> Result { + Ok(self.return_field(input_schema)?.is_nullable()) + } /// Evaluate an expression against a RecordBatch fn evaluate(&self, batch: &RecordBatch) -> Result; + /// The output field associated with this expression + fn return_field(&self, input_schema: &Schema) -> Result { + Ok(Field::new( + format!("{self}"), + self.data_type(input_schema)?, + self.nullable(input_schema)?, + )) + } /// Evaluate an expression against a RecordBatch after first applying a /// validity array fn evaluate_selection( @@ -453,19 +465,21 @@ where /// ``` /// # // The boiler plate needed to create a `PhysicalExpr` for the example /// # use std::any::Any; +/// use std::collections::HashMap; /// # use std::fmt::Formatter; /// # use std::sync::Arc; /// # use arrow::array::RecordBatch; -/// # use arrow::datatypes::{DataType, Schema}; +/// # use arrow::datatypes::{DataType, Field, Schema}; /// # use datafusion_common::Result; /// # use datafusion_expr_common::columnar_value::ColumnarValue; /// # use datafusion_physical_expr_common::physical_expr::{fmt_sql, DynEq, PhysicalExpr}; /// # #[derive(Debug, Hash, PartialOrd, PartialEq)] -/// # struct MyExpr {}; +/// # struct MyExpr {} /// # impl PhysicalExpr for MyExpr {fn as_any(&self) -> &dyn Any { unimplemented!() } /// # fn data_type(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn nullable(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn evaluate(&self, batch: &RecordBatch) -> Result { unimplemented!() } +/// # fn return_field(&self, 
input_schema: &Schema) -> Result { unimplemented!() } /// # fn children(&self) -> Vec<&Arc>{ unimplemented!() } /// # fn with_new_children(self: Arc, children: Vec>) -> Result> { unimplemented!() } /// # fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "CASE a > b THEN 1 ELSE 0 END") } diff --git a/datafusion/physical-expr-common/src/sort_expr.rs b/datafusion/physical-expr-common/src/sort_expr.rs index 3a54b5b40399..ee575603683a 100644 --- a/datafusion/physical-expr-common/src/sort_expr.rs +++ b/datafusion/physical-expr-common/src/sort_expr.rs @@ -37,13 +37,14 @@ use itertools::Itertools; /// Example: /// ``` /// # use std::any::Any; +/// # use std::collections::HashMap; /// # use std::fmt::{Display, Formatter}; /// # use std::hash::Hasher; /// # use std::sync::Arc; /// # use arrow::array::RecordBatch; /// # use datafusion_common::Result; /// # use arrow::compute::SortOptions; -/// # use arrow::datatypes::{DataType, Schema}; +/// # use arrow::datatypes::{DataType, Field, Schema}; /// # use datafusion_expr_common::columnar_value::ColumnarValue; /// # use datafusion_physical_expr_common::physical_expr::PhysicalExpr; /// # use datafusion_physical_expr_common::sort_expr::PhysicalSortExpr; @@ -56,6 +57,7 @@ use itertools::Itertools; /// # fn data_type(&self, input_schema: &Schema) -> Result {todo!()} /// # fn nullable(&self, input_schema: &Schema) -> Result {todo!() } /// # fn evaluate(&self, batch: &RecordBatch) -> Result {todo!() } +/// # fn return_field(&self, input_schema: &Schema) -> Result { unimplemented!() } /// # fn children(&self) -> Vec<&Arc> {todo!()} /// # fn with_new_children(self: Arc, children: Vec>) -> Result> {todo!()} /// # fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result { todo!() } diff --git a/datafusion/physical-expr/src/equivalence/properties/dependency.rs b/datafusion/physical-expr/src/equivalence/properties/dependency.rs index 9eba295e562e..8b17db04844d 100644 --- 
a/datafusion/physical-expr/src/equivalence/properties/dependency.rs +++ b/datafusion/physical-expr/src/equivalence/properties/dependency.rs @@ -1224,7 +1224,7 @@ mod tests { "concat", concat(), vec![Arc::clone(&col_a), Arc::clone(&col_b)], - DataType::Utf8, + Field::new("f", DataType::Utf8, true), )); // Assume existing ordering is [c ASC, a ASC, b ASC] @@ -1315,7 +1315,7 @@ mod tests { "concat", concat(), vec![Arc::clone(&col_a), Arc::clone(&col_b)], - DataType::Utf8, + Field::new("f", DataType::Utf8, true), )); // Assume existing ordering is [concat(a, b) ASC, a ASC, b ASC] diff --git a/datafusion/physical-expr/src/expressions/binary.rs b/datafusion/physical-expr/src/expressions/binary.rs index 6c68d11e2c94..9f14511d52dc 100644 --- a/datafusion/physical-expr/src/expressions/binary.rs +++ b/datafusion/physical-expr/src/expressions/binary.rs @@ -17,12 +17,11 @@ mod kernels; -use std::hash::Hash; -use std::{any::Any, sync::Arc}; - use crate::expressions::binary::kernels::concat_elements_utf8view; use crate::intervals::cp_solver::{propagate_arithmetic, propagate_comparison}; use crate::PhysicalExpr; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; use arrow::array::*; use arrow::compute::kernels::boolean::{and_kleene, not, or_kleene}; diff --git a/datafusion/physical-expr/src/expressions/case.rs b/datafusion/physical-expr/src/expressions/case.rs index 854c715eb0a2..1a74e78f1075 100644 --- a/datafusion/physical-expr/src/expressions/case.rs +++ b/datafusion/physical-expr/src/expressions/case.rs @@ -15,13 +15,12 @@ // specific language governing permissions and limitations // under the License. 
+use crate::expressions::try_cast; +use crate::PhysicalExpr; use std::borrow::Cow; use std::hash::Hash; use std::{any::Any, sync::Arc}; -use crate::expressions::try_cast; -use crate::PhysicalExpr; - use arrow::array::*; use arrow::compute::kernels::zip::zip; use arrow::compute::{and, and_not, is_null, not, nullif, or, prep_null_mask_filter}; @@ -603,7 +602,7 @@ mod tests { use crate::expressions::{binary, cast, col, lit, BinaryExpr}; use arrow::buffer::Buffer; use arrow::datatypes::DataType::Float64; - use arrow::datatypes::*; + use arrow::datatypes::Field; use datafusion_common::cast::{as_float64_array, as_int32_array}; use datafusion_common::plan_err; use datafusion_common::tree_node::{Transformed, TransformedResult, TreeNode}; diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index a6766687a881..88923d9c6cee 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use crate::physical_expr::PhysicalExpr; use arrow::compute::{can_cast_types, CastOptions}; -use arrow::datatypes::{DataType, DataType::*, Schema}; +use arrow::datatypes::{DataType, DataType::*, Field, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; use datafusion_common::{not_impl_err, Result}; @@ -144,6 +144,13 @@ impl PhysicalExpr for CastExpr { value.cast_to(&self.cast_type, Some(&self.cast_options)) } + fn return_field(&self, input_schema: &Schema) -> Result { + Ok(self + .expr + .return_field(input_schema)? 
+ .with_data_type(self.cast_type.clone())) + } + fn children(&self) -> Vec<&Arc> { vec![&self.expr] } diff --git a/datafusion/physical-expr/src/expressions/column.rs b/datafusion/physical-expr/src/expressions/column.rs index ab5b35984753..80af0b84c5d1 100644 --- a/datafusion/physical-expr/src/expressions/column.rs +++ b/datafusion/physical-expr/src/expressions/column.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use crate::physical_expr::PhysicalExpr; use arrow::{ - datatypes::{DataType, Schema, SchemaRef}, + datatypes::{DataType, Field, Schema, SchemaRef}, record_batch::RecordBatch, }; use datafusion_common::tree_node::{Transformed, TreeNode}; @@ -127,6 +127,10 @@ impl PhysicalExpr for Column { Ok(ColumnarValue::Array(Arc::clone(batch.column(self.index)))) } + fn return_field(&self, input_schema: &Schema) -> Result { + Ok(input_schema.field(self.index).clone()) + } + fn children(&self) -> Vec<&Arc> { vec![] } diff --git a/datafusion/physical-expr/src/expressions/is_not_null.rs b/datafusion/physical-expr/src/expressions/is_not_null.rs index 0619e7248858..1de8c17a373a 100644 --- a/datafusion/physical-expr/src/expressions/is_not_null.rs +++ b/datafusion/physical-expr/src/expressions/is_not_null.rs @@ -17,10 +17,8 @@ //! 
IS NOT NULL expression -use std::hash::Hash; -use std::{any::Any, sync::Arc}; - use crate::PhysicalExpr; +use arrow::datatypes::Field; use arrow::{ datatypes::{DataType, Schema}, record_batch::RecordBatch, @@ -28,6 +26,8 @@ use arrow::{ use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::ColumnarValue; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; /// IS NOT NULL expression #[derive(Debug, Eq)] @@ -94,6 +94,10 @@ impl PhysicalExpr for IsNotNullExpr { } } + fn return_field(&self, input_schema: &Schema) -> Result { + self.arg.return_field(input_schema) + } + fn children(&self) -> Vec<&Arc> { vec![&self.arg] } diff --git a/datafusion/physical-expr/src/expressions/is_null.rs b/datafusion/physical-expr/src/expressions/is_null.rs index 4c6081f35cad..7707075ce653 100644 --- a/datafusion/physical-expr/src/expressions/is_null.rs +++ b/datafusion/physical-expr/src/expressions/is_null.rs @@ -17,17 +17,16 @@ //! IS NULL expression -use std::hash::Hash; -use std::{any::Any, sync::Arc}; - use crate::PhysicalExpr; use arrow::{ - datatypes::{DataType, Schema}, + datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, }; use datafusion_common::Result; use datafusion_common::ScalarValue; use datafusion_expr::ColumnarValue; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; /// IS NULL expression #[derive(Debug, Eq)] @@ -93,6 +92,10 @@ impl PhysicalExpr for IsNullExpr { } } + fn return_field(&self, input_schema: &Schema) -> Result { + self.arg.return_field(input_schema) + } + fn children(&self) -> Vec<&Arc> { vec![&self.arg] } diff --git a/datafusion/physical-expr/src/expressions/like.rs b/datafusion/physical-expr/src/expressions/like.rs index ebf9882665ba..e86c778d5161 100644 --- a/datafusion/physical-expr/src/expressions/like.rs +++ b/datafusion/physical-expr/src/expressions/like.rs @@ -15,15 +15,14 @@ // specific language governing permissions and limitations // under the License. 
-use std::hash::Hash; -use std::{any::Any, sync::Arc}; - use crate::PhysicalExpr; use arrow::datatypes::{DataType, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; use datafusion_physical_expr_common::datum::apply_cmp; +use std::hash::Hash; +use std::{any::Any, sync::Arc}; // Like expression #[derive(Debug, Eq)] diff --git a/datafusion/physical-expr/src/expressions/literal.rs b/datafusion/physical-expr/src/expressions/literal.rs index 0d0c0ecc62c7..6f7caaea8d45 100644 --- a/datafusion/physical-expr/src/expressions/literal.rs +++ b/datafusion/physical-expr/src/expressions/literal.rs @@ -112,7 +112,7 @@ mod tests { use super::*; use arrow::array::Int32Array; - use arrow::datatypes::*; + use arrow::datatypes::Field; use datafusion_common::cast::as_int32_array; use datafusion_physical_expr_common::physical_expr::fmt_sql; diff --git a/datafusion/physical-expr/src/expressions/negative.rs b/datafusion/physical-expr/src/expressions/negative.rs index 33a1bae14d42..597cbf1dac9e 100644 --- a/datafusion/physical-expr/src/expressions/negative.rs +++ b/datafusion/physical-expr/src/expressions/negative.rs @@ -23,6 +23,7 @@ use std::sync::Arc; use crate::PhysicalExpr; +use arrow::datatypes::Field; use arrow::{ compute::kernels::numeric::neg_wrapping, datatypes::{DataType, Schema}, @@ -103,6 +104,10 @@ impl PhysicalExpr for NegativeExpr { } } + fn return_field(&self, input_schema: &Schema) -> Result { + self.arg.return_field(input_schema) + } + fn children(&self) -> Vec<&Arc> { vec![&self.arg] } diff --git a/datafusion/physical-expr/src/expressions/no_op.rs b/datafusion/physical-expr/src/expressions/no_op.rs index 24d2f4d9e074..94610996c6b0 100644 --- a/datafusion/physical-expr/src/expressions/no_op.rs +++ b/datafusion/physical-expr/src/expressions/no_op.rs @@ -21,12 +21,11 @@ use std::any::Any; use std::hash::Hash; use std::sync::Arc; +use crate::PhysicalExpr; use arrow::{ datatypes::{DataType, 
Schema}, record_batch::RecordBatch, }; - -use crate::PhysicalExpr; use datafusion_common::{internal_err, Result}; use datafusion_expr::ColumnarValue; diff --git a/datafusion/physical-expr/src/expressions/not.rs b/datafusion/physical-expr/src/expressions/not.rs index 8a3348b43d20..1f3ae9e25ffb 100644 --- a/datafusion/physical-expr/src/expressions/not.rs +++ b/datafusion/physical-expr/src/expressions/not.rs @@ -24,7 +24,7 @@ use std::sync::Arc; use crate::PhysicalExpr; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::{cast::as_boolean_array, internal_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; @@ -101,6 +101,10 @@ impl PhysicalExpr for NotExpr { } } + fn return_field(&self, input_schema: &Schema) -> Result { + self.arg.return_field(input_schema) + } + fn children(&self) -> Vec<&Arc> { vec![&self.arg] } diff --git a/datafusion/physical-expr/src/expressions/try_cast.rs b/datafusion/physical-expr/src/expressions/try_cast.rs index e49815cd8b64..e4fe027c7918 100644 --- a/datafusion/physical-expr/src/expressions/try_cast.rs +++ b/datafusion/physical-expr/src/expressions/try_cast.rs @@ -23,7 +23,7 @@ use std::sync::Arc; use crate::PhysicalExpr; use arrow::compute; use arrow::compute::{cast_with_options, CastOptions}; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Field, Schema}; use arrow::record_batch::RecordBatch; use compute::can_cast_types; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; @@ -110,6 +110,12 @@ impl PhysicalExpr for TryCastExpr { } } + fn return_field(&self, input_schema: &Schema) -> Result { + self.expr + .return_field(input_schema) + .map(|f| f.with_data_type(self.cast_type.clone())) + } + fn children(&self) -> Vec<&Arc> { vec![&self.expr] } diff --git a/datafusion/physical-expr/src/scalar_function.rs b/datafusion/physical-expr/src/scalar_function.rs index 
44bbcc4928c6..d6e070e38948 100644 --- a/datafusion/physical-expr/src/scalar_function.rs +++ b/datafusion/physical-expr/src/scalar_function.rs @@ -38,13 +38,13 @@ use crate::expressions::Literal; use crate::PhysicalExpr; use arrow::array::{Array, RecordBatch}; -use arrow::datatypes::{DataType, Schema}; +use arrow::datatypes::{DataType, Field, Schema}; use datafusion_common::{internal_err, Result, ScalarValue}; use datafusion_expr::interval_arithmetic::Interval; use datafusion_expr::sort_properties::ExprProperties; use datafusion_expr::type_coercion::functions::data_types_with_scalar_udf; use datafusion_expr::{ - expr_vec_fmt, ColumnarValue, ReturnTypeArgs, ScalarFunctionArgs, ScalarUDF, + expr_vec_fmt, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF, }; /// Physical expression of a scalar function @@ -53,8 +53,7 @@ pub struct ScalarFunctionExpr { fun: Arc, name: String, args: Vec>, - return_type: DataType, - nullable: bool, + return_field: Field, } impl Debug for ScalarFunctionExpr { @@ -63,7 +62,7 @@ impl Debug for ScalarFunctionExpr { .field("fun", &"") .field("name", &self.name) .field("args", &self.args) - .field("return_type", &self.return_type) + .field("return_field", &self.return_field) .finish() } } @@ -74,14 +73,13 @@ impl ScalarFunctionExpr { name: &str, fun: Arc, args: Vec>, - return_type: DataType, + return_field: Field, ) -> Self { Self { fun, name: name.to_owned(), args, - return_type, - nullable: true, + return_field, } } @@ -92,18 +90,17 @@ impl ScalarFunctionExpr { schema: &Schema, ) -> Result { let name = fun.name().to_string(); - let arg_types = args + let arg_fields = args .iter() - .map(|e| e.data_type(schema)) + .map(|e| e.return_field(schema)) .collect::>>()?; // verify that input data types is consistent with function's `TypeSignature` - data_types_with_scalar_udf(&arg_types, &fun)?; - - let nullables = args + let arg_types = arg_fields .iter() - .map(|e| e.nullable(schema)) - .collect::>>()?; + .map(|f| f.data_type().clone()) 
+ .collect::>(); + data_types_with_scalar_udf(&arg_types, &fun)?; let arguments = args .iter() @@ -113,18 +110,16 @@ impl ScalarFunctionExpr { .map(|literal| literal.value()) }) .collect::>(); - let ret_args = ReturnTypeArgs { - arg_types: &arg_types, + let ret_args = ReturnFieldArgs { + arg_fields: &arg_fields, scalar_arguments: &arguments, - nullables: &nullables, }; - let (return_type, nullable) = fun.return_type_from_args(ret_args)?.into_parts(); + let return_field = fun.return_field_from_args(ret_args)?; Ok(Self { fun, name, args, - return_type, - nullable, + return_field, }) } @@ -145,16 +140,16 @@ impl ScalarFunctionExpr { /// Data type produced by this expression pub fn return_type(&self) -> &DataType { - &self.return_type + self.return_field.data_type() } pub fn with_nullable(mut self, nullable: bool) -> Self { - self.nullable = nullable; + self.return_field = self.return_field.with_nullable(nullable); self } pub fn nullable(&self) -> bool { - self.nullable + self.return_field.is_nullable() } } @@ -171,11 +166,11 @@ impl PhysicalExpr for ScalarFunctionExpr { } fn data_type(&self, _input_schema: &Schema) -> Result { - Ok(self.return_type.clone()) + Ok(self.return_field.data_type().clone()) } fn nullable(&self, _input_schema: &Schema) -> Result { - Ok(self.nullable) + Ok(self.return_field.is_nullable()) } fn evaluate(&self, batch: &RecordBatch) -> Result { @@ -185,6 +180,13 @@ impl PhysicalExpr for ScalarFunctionExpr { .map(|e| e.evaluate(batch)) .collect::>>()?; + let arg_fields_owned = self + .args + .iter() + .map(|e| e.return_field(batch.schema_ref())) + .collect::>>()?; + let arg_fields = arg_fields_owned.iter().collect::>(); + let input_empty = args.is_empty(); let input_all_scalar = args .iter() @@ -193,8 +195,9 @@ impl PhysicalExpr for ScalarFunctionExpr { // evaluate the function let output = self.fun.invoke_with_args(ScalarFunctionArgs { args, + arg_fields, number_rows: batch.num_rows(), - return_type: &self.return_type, + return_field: 
&self.return_field, })?; if let ColumnarValue::Array(array) = &output { @@ -214,6 +217,10 @@ impl PhysicalExpr for ScalarFunctionExpr { Ok(output) } + fn return_field(&self, _input_schema: &Schema) -> Result { + Ok(self.return_field.clone()) + } + fn children(&self) -> Vec<&Arc> { self.args.iter().collect() } @@ -222,15 +229,12 @@ impl PhysicalExpr for ScalarFunctionExpr { self: Arc, children: Vec>, ) -> Result> { - Ok(Arc::new( - ScalarFunctionExpr::new( - &self.name, - Arc::clone(&self.fun), - children, - self.return_type().clone(), - ) - .with_nullable(self.nullable), - )) + Ok(Arc::new(ScalarFunctionExpr::new( + &self.name, + Arc::clone(&self.fun), + children, + self.return_field.clone(), + ))) } fn evaluate_bounds(&self, children: &[&Interval]) -> Result { diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 8906468f68db..628506a17b82 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -27,7 +27,6 @@ use crate::aggregates::{ }; use crate::execution_plan::{CardinalityEffect, EmissionType}; use crate::metrics::{ExecutionPlanMetricsSet, MetricsSet}; -use crate::projection::get_field_metadata; use crate::windows::get_ordered_partition_by_indices; use crate::{ DisplayFormatType, Distribution, ExecutionPlan, InputOrderMode, @@ -285,9 +284,7 @@ impl PhysicalGroupBy { expr.data_type(input_schema)?, group_expr_nullable || expr.nullable(input_schema)?, ) - .with_metadata( - get_field_metadata(expr, input_schema).unwrap_or_default(), - ), + .with_metadata(expr.return_field(input_schema)?.metadata().clone()), ); } if !self.is_single() { diff --git a/datafusion/physical-plan/src/projection.rs b/datafusion/physical-plan/src/projection.rs index 72934c74446e..534dbd71b40a 100644 --- a/datafusion/physical-plan/src/projection.rs +++ b/datafusion/physical-plan/src/projection.rs @@ -26,7 +26,7 @@ use std::pin::Pin; use std::sync::Arc; use 
std::task::{Context, Poll}; -use super::expressions::{CastExpr, Column, Literal}; +use super::expressions::{Column, Literal}; use super::metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}; use super::{ DisplayAs, ExecutionPlanProperties, PlanProperties, RecordBatchStream, @@ -79,14 +79,14 @@ impl ProjectionExec { let fields: Result> = expr .iter() .map(|(e, name)| { - let mut field = Field::new( + let metadata = e.return_field(&input_schema)?.metadata().clone(); + + let field = Field::new( name, e.data_type(&input_schema)?, e.nullable(&input_schema)?, - ); - field.set_metadata( - get_field_metadata(e, &input_schema).unwrap_or_default(), - ); + ) + .with_metadata(metadata); Ok(field) }) @@ -198,23 +198,11 @@ impl ExecutionPlan for ProjectionExec { &self.cache } - fn children(&self) -> Vec<&Arc> { - vec![&self.input] - } - fn maintains_input_order(&self) -> Vec { // Tell optimizer this operator doesn't reorder its input vec![true] } - fn with_new_children( - self: Arc, - mut children: Vec>, - ) -> Result> { - ProjectionExec::try_new(self.expr.clone(), children.swap_remove(0)) - .map(|p| Arc::new(p) as _) - } - fn benefits_from_input_partitioning(&self) -> Vec { let all_simple_exprs = self .expr @@ -225,6 +213,18 @@ impl ExecutionPlan for ProjectionExec { vec![!all_simple_exprs] } + fn children(&self) -> Vec<&Arc> { + vec![&self.input] + } + + fn with_new_children( + self: Arc, + mut children: Vec>, + ) -> Result> { + ProjectionExec::try_new(self.expr.clone(), children.swap_remove(0)) + .map(|p| Arc::new(p) as _) + } + fn execute( &self, partition: usize, @@ -273,24 +273,6 @@ impl ExecutionPlan for ProjectionExec { } } -/// If 'e' is a direct column reference, returns the field level -/// metadata for that field, if any. 
Otherwise returns None -pub(crate) fn get_field_metadata( - e: &Arc, - input_schema: &Schema, -) -> Option> { - if let Some(cast) = e.as_any().downcast_ref::() { - return get_field_metadata(cast.expr(), input_schema); - } - - // Look up field by index in schema (not NAME as there can be more than one - // column with the same name) - e.as_any() - .downcast_ref::() - .map(|column| input_schema.field(column.index()).metadata()) - .cloned() -} - fn stats_projection( mut stats: Statistics, exprs: impl Iterator>, @@ -1093,13 +1075,11 @@ mod tests { let task_ctx = Arc::new(TaskContext::default()); let exec = test::scan_partitioned(1); - let expected = collect(exec.execute(0, Arc::clone(&task_ctx))?) - .await - .unwrap(); + let expected = collect(exec.execute(0, Arc::clone(&task_ctx))?).await?; let projection = ProjectionExec::try_new(vec![], exec)?; let stream = projection.execute(0, Arc::clone(&task_ctx))?; - let output = collect(stream).await.unwrap(); + let output = collect(stream).await?; assert_eq!(output.len(), expected.len()); Ok(()) diff --git a/datafusion/proto/src/physical_plan/from_proto.rs b/datafusion/proto/src/physical_plan/from_proto.rs index a886fc242545..c412dfed5d03 100644 --- a/datafusion/proto/src/physical_plan/from_proto.rs +++ b/datafusion/proto/src/physical_plan/from_proto.rs @@ -20,6 +20,7 @@ use std::sync::Arc; use arrow::compute::SortOptions; +use arrow::datatypes::Field; use chrono::{TimeZone, Utc}; use datafusion_expr::dml::InsertOp; use object_store::path::Path; @@ -365,7 +366,7 @@ pub fn parse_physical_expr( e.name.as_str(), scalar_fun_def, args, - convert_required!(e.return_type)?, + Field::new("f", convert_required!(e.return_type)?, true), ) .with_nullable(e.nullable), ) diff --git a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs index be90497a6e21..c26fdaa9fca8 100644 --- a/datafusion/proto/tests/cases/roundtrip_physical_plan.rs +++ 
b/datafusion/proto/tests/cases/roundtrip_physical_plan.rs @@ -968,7 +968,7 @@ fn roundtrip_scalar_udf() -> Result<()> { "dummy", fun_def, vec![col("a", &schema)?], - DataType::Int64, + Field::new("f", DataType::Int64, true), ); let project = @@ -1096,7 +1096,7 @@ fn roundtrip_scalar_udf_extension_codec() -> Result<()> { "regex_udf", Arc::new(ScalarUDF::from(MyRegexUdf::new(".*".to_string()))), vec![col("text", &schema)?], - DataType::Int64, + Field::new("f", DataType::Int64, true), )); let filter = Arc::new(FilterExec::try_new( @@ -1198,7 +1198,7 @@ fn roundtrip_aggregate_udf_extension_codec() -> Result<()> { "regex_udf", Arc::new(ScalarUDF::from(MyRegexUdf::new(".*".to_string()))), vec![col("text", &schema)?], - DataType::Int64, + Field::new("f", DataType::Int64, true), )); let udaf = Arc::new(AggregateUDF::from(MyAggregateUDF::new( diff --git a/docs/source/library-user-guide/upgrading.md b/docs/source/library-user-guide/upgrading.md index db5078d603f2..35f0577b0c82 100644 --- a/docs/source/library-user-guide/upgrading.md +++ b/docs/source/library-user-guide/upgrading.md @@ -19,6 +19,33 @@ # Upgrade Guides +## DataFusion `48.0.0` + +### Processing `Field` instead of `DataType` for user defined functions + +In order to support metadata handling and extension types, user defined functions are +now switching to traits which use `Field` rather than a `DataType` and nullability. +This gives a single interface to both of these parameters and additionally allows +access to metadata fields, which can be used for extension types. + +To upgrade structs which implement `ScalarUDFImpl`, if you have implemented +`return_type_from_args` you need instead to implement `return_field_from_args`. +If your functions do not need to handle metadata, this should be straightforward +repackaging of the output data into a `Field`. The name you specify on the +field is not important. It will be overwritten during planning. 
`ReturnInfo` +has been removed, so you will need to remove all references to it. + +`ScalarFunctionArgs` now contains a field called `arg_fields`. You can use this +to access the metadata associated with the columnar values during invocation. + +### Physical Expression return field + +To support the changes to user defined functions processing metadata, the +`PhysicalExpr` trait now must specify a return `Field` based on the input +schema. To upgrade structs which implement `PhysicalExpr` you need to implement +the `return_field` function. There are numerous examples in the `physical-expr` +crate. + +## DataFusion `47.0.0` + +This section calls out some of the major changes in the `47.0.0` release of DataFusion.