diff --git a/datafusion/functions/src/crypto/basic.rs b/datafusion/functions/src/crypto/basic.rs index 191154b8f8ff..eaa688c1c335 100644 --- a/datafusion/functions/src/crypto/basic.rs +++ b/datafusion/functions/src/crypto/basic.rs @@ -17,7 +17,10 @@ //! "crypto" DataFusion functions -use arrow::array::{Array, ArrayRef, BinaryArray, OffsetSizeTrait}; +use arrow::array::{ + Array, ArrayRef, BinaryArray, BinaryArrayType, BinaryViewArray, GenericBinaryArray, + OffsetSizeTrait, +}; use arrow::array::{AsArray, GenericStringArray, StringArray, StringViewArray}; use arrow::datatypes::DataType; use blake2::{Blake2b512, Blake2s256, Digest}; @@ -26,8 +29,8 @@ use datafusion_common::cast::as_binary_array; use arrow::compute::StringArrayType; use datafusion_common::{ - cast::as_generic_binary_array, exec_err, internal_err, plan_err, - utils::take_function_args, DataFusionError, Result, ScalarValue, + exec_err, internal_err, plan_err, utils::take_function_args, DataFusionError, Result, + ScalarValue, }; use datafusion_expr::ColumnarValue; use md5::Md5; @@ -203,6 +206,7 @@ pub fn utf8_or_binary_to_binary_type( | DataType::LargeUtf8 | DataType::Utf8 | DataType::Binary + | DataType::BinaryView | DataType::LargeBinary => DataType::Binary, DataType::Null => DataType::Null, _ => { @@ -251,27 +255,17 @@ impl DigestAlgorithm { where T: OffsetSizeTrait, { - let input_value = as_generic_binary_array::(value)?; - let array: ArrayRef = match self { - Self::Md5 => digest_to_array!(Md5, input_value), - Self::Sha224 => digest_to_array!(Sha224, input_value), - Self::Sha256 => digest_to_array!(Sha256, input_value), - Self::Sha384 => digest_to_array!(Sha384, input_value), - Self::Sha512 => digest_to_array!(Sha512, input_value), - Self::Blake2b => digest_to_array!(Blake2b512, input_value), - Self::Blake2s => digest_to_array!(Blake2s256, input_value), - Self::Blake3 => { - let binary_array: BinaryArray = input_value - .iter() - .map(|opt| { - opt.map(|x| { - let mut digest = Blake3::default(); - digest.update(x); - Blake3::finalize(&digest).as_bytes().to_vec() - }) - }) - .collect(); - Arc::new(binary_array) + let array = match value.data_type() { + DataType::Binary | DataType::LargeBinary => { + let v = value.as_binary::(); + self.digest_binary_array_impl::<&GenericBinaryArray>(v) + } + DataType::BinaryView => { + let v = value.as_binary_view(); + self.digest_binary_array_impl::<&BinaryViewArray>(v) + } + other => { + return exec_err!("unsupported type for digest_utf_array: {other:?}") } }; Ok(ColumnarValue::Array(array)) @@ -328,6 +322,37 @@ impl DigestAlgorithm { } } } + + pub fn digest_binary_array_impl<'a, BinaryArrType>( + self, + input_value: BinaryArrType, + ) -> ArrayRef + where + BinaryArrType: BinaryArrayType<'a>, + { + match self { + Self::Md5 => digest_to_array!(Md5, input_value), + Self::Sha224 => digest_to_array!(Sha224, input_value), + Self::Sha256 => digest_to_array!(Sha256, input_value), + Self::Sha384 => digest_to_array!(Sha384, input_value), + Self::Sha512 => digest_to_array!(Sha512, input_value), + Self::Blake2b => digest_to_array!(Blake2b512, input_value), + Self::Blake2s => digest_to_array!(Blake2s256, input_value), + Self::Blake3 => { + let binary_array: BinaryArray = input_value + .iter() + .map(|opt| { + opt.map(|x| { + let mut digest = Blake3::default(); + digest.update(x); + Blake3::finalize(&digest).as_bytes().to_vec() + }) + }) + .collect(); + Arc::new(binary_array) + } + } + } } pub fn digest_process( value: &ColumnarValue, @@ -342,22 +367,27 @@ pub fn digest_process( DataType::LargeBinary => { digest_algorithm.digest_binary_array::(a.as_ref()) } - other => exec_err!( - "Unsupported data type {other:?} for function {digest_algorithm}" - ), - }, - ColumnarValue::Scalar(scalar) => match scalar { - ScalarValue::Utf8View(a) - | ScalarValue::Utf8(a) - | ScalarValue::LargeUtf8(a) => { - Ok(digest_algorithm - .digest_scalar(a.as_ref().map(|s: &String| s.as_bytes()))) + DataType::BinaryView => { + digest_algorithm.digest_binary_array::(a.as_ref()) } - ScalarValue::Binary(a) | ScalarValue::LargeBinary(a) => Ok(digest_algorithm - .digest_scalar(a.as_ref().map(|v: &Vec| v.as_slice()))), other => exec_err!( "Unsupported data type {other:?} for function {digest_algorithm}" ), }, + ColumnarValue::Scalar(scalar) => { + match scalar { + ScalarValue::Utf8View(a) + | ScalarValue::Utf8(a) + | ScalarValue::LargeUtf8(a) => Ok(digest_algorithm + .digest_scalar(a.as_ref().map(|s: &String| s.as_bytes()))), + ScalarValue::Binary(a) + | ScalarValue::LargeBinary(a) + | ScalarValue::BinaryView(a) => Ok(digest_algorithm + .digest_scalar(a.as_ref().map(|v: &Vec| v.as_slice()))), + other => exec_err!( + "Unsupported data type {other:?} for function {digest_algorithm}" + ), + } + } } } diff --git a/datafusion/functions/src/crypto/digest.rs b/datafusion/functions/src/crypto/digest.rs index 4f9d4605fe07..2840006169be 100644 --- a/datafusion/functions/src/crypto/digest.rs +++ b/datafusion/functions/src/crypto/digest.rs @@ -18,11 +18,15 @@ //! "crypto" DataFusion functions use super::basic::{digest, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; -use datafusion_common::Result; +use datafusion_common::{ + types::{logical_binary, logical_string}, + Result, +}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, - TypeSignature::*, Volatility, + TypeSignature, Volatility, }; +use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; use std::any::Any; @@ -64,15 +68,17 @@ impl Default for DigestFunc { impl DigestFunc { pub fn new() -> Self { - use DataType::*; Self { signature: Signature::one_of( vec![ - Exact(vec![Utf8View, Utf8View]), - Exact(vec![Utf8, Utf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![Binary, Utf8]), - Exact(vec![LargeBinary, Utf8]), + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ]), + TypeSignature::Coercible(vec![ + Coercion::new_exact(TypeSignatureClass::Native(logical_binary())), + Coercion::new_exact(TypeSignatureClass::Native(logical_string())), + ]), ], Volatility::Immutable, ), diff --git a/datafusion/functions/src/crypto/md5.rs b/datafusion/functions/src/crypto/md5.rs index 18ad0d6a7ded..c1540450029c 100644 --- a/datafusion/functions/src/crypto/md5.rs +++ b/datafusion/functions/src/crypto/md5.rs @@ -18,11 +18,16 @@ //! "crypto" DataFusion functions use crate::crypto::basic::md5; use arrow::datatypes::DataType; -use datafusion_common::{plan_err, Result}; +use datafusion_common::{ + plan_err, + types::{logical_binary, logical_string, NativeType}, + Result, +}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, - Volatility, + TypeSignature, Volatility, }; +use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; use std::any::Any; @@ -52,11 +57,20 @@ impl Default for Md5Func { impl Md5Func { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::uniform( - 1, - vec![Utf8View, Utf8, LargeUtf8, Binary, LargeBinary], + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![Coercion::new_implicit( + TypeSignatureClass::Native(logical_binary()), + vec![TypeSignatureClass::Native(logical_string())], + NativeType::String, + )]), + TypeSignature::Coercible(vec![Coercion::new_implicit( + TypeSignatureClass::Native(logical_binary()), + vec![TypeSignatureClass::Native(logical_binary())], + NativeType::Binary, + )]), + ], Volatility::Immutable, ), } @@ -79,11 +93,11 @@ impl ScalarUDFImpl for Md5Func { use DataType::*; Ok(match &arg_types[0] { LargeUtf8 | LargeBinary => Utf8, - Utf8View | Utf8 | Binary => Utf8, + Utf8View | Utf8 | Binary | BinaryView => Utf8, Null => Null, Dictionary(_, t) => match **t { LargeUtf8 | LargeBinary => Utf8, - Utf8 | Binary => Utf8, + Utf8 | Binary | BinaryView => Utf8, Null => Null, _ => { return plan_err!( diff --git a/datafusion/functions/src/crypto/sha224.rs b/datafusion/functions/src/crypto/sha224.rs index 24fe5e119df3..a64a3ef80319 100644 --- a/datafusion/functions/src/crypto/sha224.rs +++ b/datafusion/functions/src/crypto/sha224.rs @@ -18,11 +18,15 @@ //! "crypto" DataFusion functions use super::basic::{sha224, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; -use datafusion_common::Result; +use datafusion_common::{ + types::{logical_binary, logical_string, NativeType}, + Result, +}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, - Volatility, + TypeSignature, Volatility, }; +use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; use std::any::Any; @@ -53,11 +57,20 @@ impl Default for SHA224Func { impl SHA224Func { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::uniform( - 1, - vec![Utf8View, Utf8, LargeUtf8, Binary, LargeBinary], + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![Coercion::new_implicit( + TypeSignatureClass::Native(logical_binary()), + vec![TypeSignatureClass::Native(logical_string())], + NativeType::String, + )]), + TypeSignature::Coercible(vec![Coercion::new_implicit( + TypeSignatureClass::Native(logical_binary()), + vec![TypeSignatureClass::Native(logical_binary())], + NativeType::Binary, + )]), + ], Volatility::Immutable, ), } diff --git a/datafusion/functions/src/crypto/sha256.rs b/datafusion/functions/src/crypto/sha256.rs index c48dda19cbc5..94f3ea3b49fa 100644 --- a/datafusion/functions/src/crypto/sha256.rs +++ b/datafusion/functions/src/crypto/sha256.rs @@ -18,11 +18,15 @@ //! "crypto" DataFusion functions use super::basic::{sha256, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; -use datafusion_common::Result; +use datafusion_common::{ + types::{logical_binary, logical_string, NativeType}, + Result, +}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, - Volatility, + TypeSignature, Volatility, }; +use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; use std::any::Any; @@ -52,11 +56,20 @@ impl Default for SHA256Func { impl SHA256Func { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::uniform( - 1, - vec![Utf8View, Utf8, LargeUtf8, Binary, LargeBinary], + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![Coercion::new_implicit( + TypeSignatureClass::Native(logical_binary()), + vec![TypeSignatureClass::Native(logical_string())], + NativeType::String, + )]), + TypeSignature::Coercible(vec![Coercion::new_implicit( + TypeSignatureClass::Native(logical_binary()), + vec![TypeSignatureClass::Native(logical_binary())], + NativeType::Binary, + )]), + ], Volatility::Immutable, ), } diff --git a/datafusion/functions/src/crypto/sha384.rs b/datafusion/functions/src/crypto/sha384.rs index 11d1d130e929..023730469c7b 100644 --- a/datafusion/functions/src/crypto/sha384.rs +++ b/datafusion/functions/src/crypto/sha384.rs @@ -18,11 +18,15 @@ //! "crypto" DataFusion functions use super::basic::{sha384, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; -use datafusion_common::Result; +use datafusion_common::{ + types::{logical_binary, logical_string, NativeType}, + Result, +}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, - Volatility, + TypeSignature, Volatility, }; +use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; use std::any::Any; @@ -52,11 +56,20 @@ impl Default for SHA384Func { impl SHA384Func { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::uniform( - 1, - vec![Utf8View, Utf8, LargeUtf8, Binary, LargeBinary], + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![Coercion::new_implicit( + TypeSignatureClass::Native(logical_binary()), + vec![TypeSignatureClass::Native(logical_string())], + NativeType::String, + )]), + TypeSignature::Coercible(vec![Coercion::new_implicit( + TypeSignatureClass::Native(logical_binary()), + vec![TypeSignatureClass::Native(logical_binary())], + NativeType::Binary, + )]), + ], Volatility::Immutable, ), } diff --git a/datafusion/functions/src/crypto/sha512.rs b/datafusion/functions/src/crypto/sha512.rs index 26fa85a5da3a..f48737e5751f 100644 --- a/datafusion/functions/src/crypto/sha512.rs +++ b/datafusion/functions/src/crypto/sha512.rs @@ -18,11 +18,15 @@ //! "crypto" DataFusion functions use super::basic::{sha512, utf8_or_binary_to_binary_type}; use arrow::datatypes::DataType; -use datafusion_common::Result; +use datafusion_common::{ + types::{logical_binary, logical_string, NativeType}, + Result, +}; use datafusion_expr::{ ColumnarValue, Documentation, ScalarFunctionArgs, ScalarUDFImpl, Signature, - Volatility, + TypeSignature, Volatility, }; +use datafusion_expr_common::signature::{Coercion, TypeSignatureClass}; use datafusion_macros::user_doc; use std::any::Any; @@ -52,11 +56,20 @@ impl Default for SHA512Func { impl SHA512Func { pub fn new() -> Self { - use DataType::*; Self { - signature: Signature::uniform( - 1, - vec![Utf8View, Utf8, LargeUtf8, Binary, LargeBinary], + signature: Signature::one_of( + vec![ + TypeSignature::Coercible(vec![Coercion::new_implicit( + TypeSignatureClass::Native(logical_binary()), + vec![TypeSignatureClass::Native(logical_string())], + NativeType::String, + )]), + TypeSignature::Coercible(vec![Coercion::new_implicit( + TypeSignatureClass::Native(logical_binary()), + vec![TypeSignatureClass::Native(logical_binary())], + NativeType::Binary, + )]), + ], Volatility::Immutable, ), } diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index 352064fbe5c8..cb56686b6437 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -5973,13 +5973,13 @@ true false true false false false true true false false true false true # NB that `col in (a, b, c)` is simplified to OR if there are <= 3 elements, so we make 4-element haystack lists query I -with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) select count(*) from test WHERE needle IN ('7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'); ---- 1 query TT -explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +explain with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) select count(*) from test WHERE needle IN ('7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'); ---- logical_plan @@ -6002,13 +6002,13 @@ physical_plan 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] query I -with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) select count(*) from test WHERE needle = ANY(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c']); ---- 1 query TT -explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +explain with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) select count(*) from test WHERE needle = ANY(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c']); ---- logical_plan @@ -6031,13 +6031,13 @@ physical_plan 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] query I -with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) select count(*) from test WHERE array_has(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], needle); ---- 1 query TT -explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +explain with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) select count(*) from test WHERE array_has(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], needle); ---- logical_plan @@ -6068,7 +6068,7 @@ physical_plan # FIXME: array_has with large list haystack not currently rewritten to InList query TT -explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +explain with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) select count(*) from test WHERE array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], 'LargeList(Utf8View)'), needle); ---- logical_plan @@ -6091,13 +6091,13 @@ physical_plan 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] query I -with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) select count(*) from test WHERE array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], 'FixedSizeList(4, Utf8View)'), needle); ---- 1 query TT -explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +explain with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) select count(*) from test WHERE array_has(arrow_cast(['7f4b18de3cfeb9b4ac78c381ee2ad278', 'a', 'b', 'c'], 'FixedSizeList(4, Utf8View)'), needle); ---- logical_plan @@ -6120,14 +6120,14 @@ physical_plan 09)----------------LazyMemoryExec: partitions=1, batch_generators=[generate_series: start=1, end=100000, batch_size=8192] query I -with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) select count(*) from test WHERE array_has([needle], needle); ---- 100000 # TODO: this should probably be possible to completely remove the filter as always true? query TT -explain with test AS (SELECT substr(md5(i)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) +explain with test AS (SELECT substr(md5(i::text)::text, 1, 32) as needle FROM generate_series(1, 100000) t(i)) select count(*) from test WHERE array_has([needle], needle); ---- logical_plan diff --git a/datafusion/sqllogictest/test_files/string/string_view.slt b/datafusion/sqllogictest/test_files/string/string_view.slt index 96fb2477598c..a72c8f574484 100644 --- a/datafusion/sqllogictest/test_files/string/string_view.slt +++ b/datafusion/sqllogictest/test_files/string/string_view.slt @@ -1068,7 +1068,7 @@ EXPLAIN SELECT FROM test; ---- logical_plan -01)Projection: digest(test.column1_utf8view, Utf8View("md5")) AS c +01)Projection: digest(test.column1_utf8view, Utf8("md5")) AS c 02)--TableScan: test projection=[column1_utf8view] ## Ensure no unexpected casts for string_to_array