diff --git a/Cargo.toml b/Cargo.toml index 023dc6c6fc4f..47638abe0515 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,7 @@ members = [ "datafusion/core", "datafusion/expr", "datafusion/execution", + "datafusion/functions", "datafusion/optimizer", "datafusion/physical-expr", "datafusion/physical-plan", @@ -62,6 +63,7 @@ ctor = "0.2.0" datafusion = { path = "datafusion/core", version = "34.0.0" } datafusion-common = { path = "datafusion/common", version = "34.0.0" } datafusion-expr = { path = "datafusion/expr", version = "34.0.0" } +datafusion-functions = { path = "datafusion/functions", version = "34.0.0" } datafusion-sql = { path = "datafusion/sql", version = "34.0.0" } datafusion-optimizer = { path = "datafusion/optimizer", version = "34.0.0" } datafusion-physical-expr = { path = "datafusion/physical-expr", version = "34.0.0" } diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 19ad6709362d..6475accc8d9f 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -384,7 +384,7 @@ checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -1069,12 +1069,12 @@ dependencies = [ [[package]] name = "ctor" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "37e366bff8cd32dd8754b0991fb66b279dc48f598c3a18914852a6673deef583" +checksum = "30d2b3721e861707777e3195b0158f950ae6dc4a27e4d02ff9f67e3eb3de199e" dependencies = [ "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -1114,6 +1114,7 @@ dependencies = [ "datafusion-common", "datafusion-execution", "datafusion-expr", + "datafusion-functions", "datafusion-optimizer", "datafusion-physical-expr", "datafusion-physical-plan", @@ -1222,6 +1223,19 @@ dependencies = [ "strum_macros", ] +[[package]] +name = "datafusion-functions" +version = "34.0.0" +dependencies = [ + "arrow", + "base64", + "datafusion-common", + 
"datafusion-execution", + "datafusion-expr", + "hex", + "log", +] + [[package]] name = "datafusion-optimizer" version = "34.0.0" @@ -1576,7 +1590,7 @@ checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -2496,7 +2510,7 @@ checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -3020,7 +3034,7 @@ checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -3196,7 +3210,7 @@ dependencies = [ "proc-macro2", "quote", "rustversion", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -3218,9 +3232,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.40" +version = "2.0.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13fa70a4ee923979ffb522cacce59d34421ebdea5625e1073c4326ef9d2dd42e" +checksum = "44c8b28c477cc3bf0e7966561e3460130e1255f7a1cf71931075f1c5e7a7e269" dependencies = [ "proc-macro2", "quote", @@ -3299,7 +3313,7 @@ checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -3391,7 +3405,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -3488,7 +3502,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -3533,7 +3547,7 @@ checksum = "f03ca4cb38206e2bef0700092660bb74d696f808514dae47fa1467cbfe26e96e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] @@ -3687,7 +3701,7 @@ dependencies = [ "once_cell", "proc-macro2", 
"quote", - "syn 2.0.40", + "syn 2.0.41", "wasm-bindgen-shared", ] @@ -3721,7 +3735,7 @@ checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -3970,22 +3984,22 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.7.30" +version = "0.7.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "306dca4455518f1f31635ec308b6b3e4eb1b11758cefafc782827d0aa7acb5c7" +checksum = "1c4061bedbb353041c12f413700357bec76df2c7e2ca8e4df8bac24c6bf68e3d" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.30" +version = "0.7.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be912bf68235a88fbefd1b73415cb218405958d1655b2ece9035a19920bdf6ba" +checksum = "b3c129550b3e6de3fd0ba67ba5c81818f9805e58b8d7fee80a3a59d2c9fc601a" dependencies = [ "proc-macro2", "quote", - "syn 2.0.40", + "syn 2.0.41", ] [[package]] diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 0ee83e756745..c2205fcff880 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -40,7 +40,7 @@ backtrace = ["datafusion-common/backtrace"] compression = ["xz2", "bzip2", "flate2", "zstd", "async-compression"] crypto_expressions = ["datafusion-physical-expr/crypto_expressions", "datafusion-optimizer/crypto_expressions"] default = ["crypto_expressions", "encoding_expressions", "regex_expressions", "unicode_expressions", "compression", "parquet"] -encoding_expressions = ["datafusion-physical-expr/encoding_expressions"] +encoding_expressions = ["datafusion-functions/encoding_expressions"] # Used for testing ONLY: causes all values to hash to the same value (test for collisions) force_hash_collisions = [] parquet = ["datafusion-common/parquet", "dep:parquet"] @@ -65,6 +65,7 @@ dashmap = { workspace = true } datafusion-common = { path = 
"../common", version = "34.0.0", features = ["object_store"], default-features = false } datafusion-execution = { workspace = true } datafusion-expr = { workspace = true } +datafusion-functions = { path = "../functions", version = "34.0.0"} datafusion-optimizer = { path = "../optimizer", version = "34.0.0", default-features = false } datafusion-physical-expr = { path = "../physical-expr", version = "34.0.0", default-features = false } datafusion-physical-plan = { workspace = true } diff --git a/datafusion/core/src/dataframe/mod.rs b/datafusion/core/src/dataframe/mod.rs index 4b8a9c5b7d79..94cd50123295 100644 --- a/datafusion/core/src/dataframe/mod.rs +++ b/datafusion/core/src/dataframe/mod.rs @@ -40,7 +40,6 @@ use crate::physical_plan::{ collect, collect_partitioned, execute_stream, execute_stream_partitioned, ExecutionPlan, SendableRecordBatchStream, }; -use crate::prelude::SessionContext; use arrow::array::{Array, ArrayRef, Int64Array, StringArray}; use arrow::compute::{cast, concat}; @@ -59,6 +58,7 @@ use datafusion_expr::{ TableProviderFilterPushDown, UNNAMED_TABLE, }; +use crate::prelude::SessionContext; use async_trait::async_trait; /// Contains options that control how data is diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 58a4f08341d6..de3a294ce3df 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1343,7 +1343,7 @@ impl SessionState { ); } - SessionState { + let mut new_self = SessionState { session_id, analyzer: Analyzer::new(), optimizer: Optimizer::new(), @@ -1359,7 +1359,13 @@ impl SessionState { execution_props: ExecutionProps::new(), runtime_env: runtime, table_factories, - } + }; + + // register built in functions + datafusion_functions::register_all(&mut new_self) + .expect("can not register built in functions"); + + new_self } /// Returns new [`SessionState`] using the provided /// [`SessionConfig`] and [`RuntimeEnv`]. 
@@ -1968,6 +1974,10 @@ impl FunctionRegistry for SessionState { plan_datafusion_err!("There is no UDWF named \"{name}\" in the registry") }) } + + fn register_udf(&mut self, udf: Arc) -> Result>> { + Ok(self.scalar_functions.insert(udf.name().into(), udf)) + } } impl OptimizerConfig for SessionState { diff --git a/datafusion/core/src/prelude.rs b/datafusion/core/src/prelude.rs index 5cd8b3870f81..69c33355402b 100644 --- a/datafusion/core/src/prelude.rs +++ b/datafusion/core/src/prelude.rs @@ -38,6 +38,7 @@ pub use datafusion_expr::{ logical_plan::{JoinType, Partitioning}, Expr, }; +pub use datafusion_functions::expr_fn::*; pub use std::ops::Not; pub use std::ops::{Add, Div, Mul, Neg, Rem, Sub}; diff --git a/datafusion/execution/src/registry.rs b/datafusion/execution/src/registry.rs index 9ba487e715b3..ef7486f1870c 100644 --- a/datafusion/execution/src/registry.rs +++ b/datafusion/execution/src/registry.rs @@ -17,7 +17,7 @@ //! FunctionRegistry trait -use datafusion_common::Result; +use datafusion_common::{not_impl_err, DataFusionError, Result}; use datafusion_expr::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF}; use std::{collections::HashSet, sync::Arc}; @@ -34,6 +34,17 @@ pub trait FunctionRegistry { /// Returns a reference to the udwf named `name`. fn udwf(&self, name: &str) -> Result>; + + /// Registers a new `ScalarUDF`, returning any previously registered + /// implementation. + /// + /// Returns an error (default) if the function can not be registered, for + /// example because the registry doesn't support new functions + fn register_udf(&mut self, _udf: Arc) -> Result>> { + not_impl_err!("Registering ScalarUDF") + } + + // TODO add register_udaf and register_udwf } /// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode]. 
diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index fd899289ac82..f3d2ffd6da34 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -69,14 +69,10 @@ pub enum BuiltinScalarFunction { Cos, /// cos Cosh, - /// Decode - Decode, /// degrees Degrees, /// Digest Digest, - /// Encode - Encode, /// exp Exp, /// factorial @@ -373,9 +369,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Coalesce => Volatility::Immutable, BuiltinScalarFunction::Cos => Volatility::Immutable, BuiltinScalarFunction::Cosh => Volatility::Immutable, - BuiltinScalarFunction::Decode => Volatility::Immutable, BuiltinScalarFunction::Degrees => Volatility::Immutable, - BuiltinScalarFunction::Encode => Volatility::Immutable, BuiltinScalarFunction::Exp => Volatility::Immutable, BuiltinScalarFunction::Factorial => Volatility::Immutable, BuiltinScalarFunction::Floor => Volatility::Immutable, @@ -746,30 +740,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Digest => { utf8_or_binary_to_binary_type(&input_expr_types[0], "digest") } - BuiltinScalarFunction::Encode => Ok(match input_expr_types[0] { - Utf8 => Utf8, - LargeUtf8 => LargeUtf8, - Binary => Utf8, - LargeBinary => LargeUtf8, - Null => Null, - _ => { - return plan_err!( - "The encode function can only accept utf8 or binary." - ); - } - }), - BuiltinScalarFunction::Decode => Ok(match input_expr_types[0] { - Utf8 => Binary, - LargeUtf8 => LargeBinary, - Binary => Binary, - LargeBinary => LargeBinary, - Null => Null, - _ => { - return plan_err!( - "The decode function can only accept utf8 or binary." 
- ); - } - }), BuiltinScalarFunction::SplitPart => { utf8_to_str_type(&input_expr_types[0], "split_part") } @@ -1100,24 +1070,6 @@ impl BuiltinScalarFunction { ], self.volatility(), ), - BuiltinScalarFunction::Encode => Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![Binary, Utf8]), - Exact(vec![LargeBinary, Utf8]), - ], - self.volatility(), - ), - BuiltinScalarFunction::Decode => Signature::one_of( - vec![ - Exact(vec![Utf8, Utf8]), - Exact(vec![LargeUtf8, Utf8]), - Exact(vec![Binary, Utf8]), - Exact(vec![LargeBinary, Utf8]), - ], - self.volatility(), - ), BuiltinScalarFunction::DateTrunc => Signature::one_of( vec![ Exact(vec![Utf8, Timestamp(Nanosecond, None)]), @@ -1552,10 +1504,6 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::SHA384 => &["sha384"], BuiltinScalarFunction::SHA512 => &["sha512"], - // encode/decode - BuiltinScalarFunction::Encode => &["encode"], - BuiltinScalarFunction::Decode => &["decode"], - // other functions BuiltinScalarFunction::ArrowTypeof => &["arrow_typeof"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index cedf1d845137..f04965702f52 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -785,8 +785,6 @@ scalar_expr!( "converts the Unicode code point to a UTF8 character" ); scalar_expr!(Digest, digest, input algorithm, "compute the binary hash of `input`, using the `algorithm`"); -scalar_expr!(Encode, encode, input encoding, "encode the `input`, using the `encoding`. encoding can be base64 or hex"); -scalar_expr!(Decode, decode, input encoding, "decode the`input`, using the `encoding`. 
encoding can be base64 or hex"); scalar_expr!(InitCap, initcap, string, "converts the first letter of each word in `string` in uppercase and the remaining characters in lowercase"); scalar_expr!(Left, left, string n, "returns the first `n` characters in the `string`"); scalar_expr!(Lower, lower, string, "convert the string to lower case"); @@ -1135,8 +1133,6 @@ mod test { test_scalar_expr!(CharacterLength, character_length, string); test_scalar_expr!(Chr, chr, string); test_scalar_expr!(Digest, digest, string, algorithm); - test_scalar_expr!(Encode, encode, string, encoding); - test_scalar_expr!(Decode, decode, string, encoding); test_scalar_expr!(Gcd, gcd, arg_1, arg_2); test_scalar_expr!(Lcm, lcm, arg_1, arg_2); test_scalar_expr!(InitCap, initcap, string); @@ -1249,34 +1245,4 @@ mod test { unreachable!(); } } - - #[test] - fn encode_function_definitions() { - if let Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(fun), - args, - }) = encode(col("tableA.a"), lit("base64")) - { - let name = BuiltinScalarFunction::Encode; - assert_eq!(name, fun); - assert_eq!(2, args.len()); - } else { - unreachable!(); - } - } - - #[test] - fn decode_function_definitions() { - if let Expr::ScalarFunction(ScalarFunction { - func_def: ScalarFunctionDefinition::BuiltIn(fun), - args, - }) = decode(col("tableA.a"), lit("hex")) - { - let name = BuiltinScalarFunction::Decode; - assert_eq!(name, fun); - assert_eq!(2, args.len()); - } else { - unreachable!(); - } - } } diff --git a/datafusion/expr/src/lib.rs b/datafusion/expr/src/lib.rs index 48532e13dcd7..0a15fcaea9aa 100644 --- a/datafusion/expr/src/lib.rs +++ b/datafusion/expr/src/lib.rs @@ -80,7 +80,7 @@ pub use signature::{ }; pub use table_source::{TableProviderFilterPushDown, TableSource, TableType}; pub use udaf::AggregateUDF; -pub use udf::ScalarUDF; +pub use udf::{FunctionImplementation, ScalarUDF}; pub use udwf::WindowUDF; pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits}; 
pub use window_function::{BuiltInWindowFunction, WindowFunction}; diff --git a/datafusion/expr/src/udf.rs b/datafusion/expr/src/udf.rs index 3a18ca2d25e8..390821afdf23 100644 --- a/datafusion/expr/src/udf.rs +++ b/datafusion/expr/src/udf.rs @@ -17,7 +17,9 @@ //! [`ScalarUDF`]: Scalar User Defined Functions -use crate::{Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature}; +use crate::{ + ColumnarValue, Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature, +}; use arrow::datatypes::DataType; use datafusion_common::Result; use std::fmt; @@ -95,6 +97,24 @@ impl ScalarUDF { } } + /// Create a new `ScalarUDF` from a [`FunctionImplementation`] + pub fn new_from_impl( + fun: impl FunctionImplementation + Send + Sync + 'static, + ) -> ScalarUDF { + let arc_fun = Arc::new(fun); + let captured_self = arc_fun.clone(); + let return_type: ReturnTypeFunction = Arc::new(move |arg_types| { + let return_type = captured_self.return_type(arg_types)?; + Ok(Arc::new(return_type)) + }); + + let captured_self = arc_fun.clone(); + let func: ScalarFunctionImplementation = + Arc::new(move |args| captured_self.invoke(args)); + + ScalarUDF::new(arc_fun.name(), arc_fun.signature(), &return_type, &func) + } + /// Adds additional names that can be used to invoke this function, in addition to `name` pub fn with_aliases( mut self, @@ -140,6 +160,19 @@ impl ScalarUDF { pub fn fun(&self) -> ScalarFunctionImplementation { self.fun.clone() } +} + +/// Convenience trait for implementing ScalarUDF. See [`ScalarUDF::new_from_impl()`] +pub trait FunctionImplementation { + /// Returns this function's name + fn name(&self) -> &str; + + /// Returns this function's signature + fn signature(&self) -> &Signature; + + /// Returns the return type of this function given the types of the arguments + fn return_type(&self, arg_types: &[DataType]) -> Result; - // TODO maybe add an invoke() method that runs the actual function?
+ /// Invoke the function on `args`, returning the appropriate result + fn invoke(&self, args: &[ColumnarValue]) -> Result; } diff --git a/datafusion/functions/Cargo.toml b/datafusion/functions/Cargo.toml new file mode 100644 index 000000000000..af7831c54d58 --- /dev/null +++ b/datafusion/functions/Cargo.toml @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +[package] +name = "datafusion-functions" +description = "Function packages for the datafusion query engine" +keywords = ["datafusion", "logical", "plan", "expressions"] +readme = "README.md" +version = { workspace = true } +edition = { workspace = true } +homepage = { workspace = true } +repository = { workspace = true } +license = { workspace = true } +authors = { workspace = true } +rust-version = { workspace = true } + +[features] +# enable the encode/decode functions +encoding_expressions = ["base64", "hex"] + + +[lib] +name = "datafusion_functions" +path = "src/lib.rs" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +arrow = { workspace = true } +base64 = { version = "0.21", optional = true } +datafusion-common = { workspace = true } +datafusion-execution = { workspace = true } +datafusion-expr = { workspace = true } +hex = { version = "0.4", optional = true } +log = "0.4.20" diff --git a/datafusion/functions/README.md b/datafusion/functions/README.md new file mode 100644 index 000000000000..925769be18f8 --- /dev/null +++ b/datafusion/functions/README.md @@ -0,0 +1,26 @@ + + +# DataFusion Function Library + +[DataFusion][df] is an extensible query execution framework, written in Rust, that uses Apache Arrow as its in-memory format. + +This crate contains several "built in" function packages that can be used with DataFusion. 
+ +[df]: https://crates.io/crates/datafusion diff --git a/datafusion/physical-expr/src/encoding_expressions.rs b/datafusion/functions/src/encoding/inner.rs similarity index 82% rename from datafusion/physical-expr/src/encoding_expressions.rs rename to datafusion/functions/src/encoding/inner.rs index b74310485fb7..e8e15903d881 100644 --- a/datafusion/physical-expr/src/encoding_expressions.rs +++ b/datafusion/functions/src/encoding/inner.rs @@ -32,12 +32,108 @@ use datafusion_expr::ColumnarValue; use std::sync::Arc; use std::{fmt, str::FromStr}; +use datafusion_expr::TypeSignature::*; +use datafusion_expr::{FunctionImplementation, Signature, Volatility}; +use std::sync::OnceLock; + +#[derive(Default, Debug)] +pub(super) struct EncodeFunc {} + +static ENCODE_SIGNATURE: OnceLock = OnceLock::new(); + +impl FunctionImplementation for EncodeFunc { + fn name(&self) -> &str { + "encode" + } + + fn signature(&self) -> &Signature { + use DataType::*; + ENCODE_SIGNATURE.get_or_init(|| { + Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![Binary, Utf8]), + Exact(vec![LargeBinary, Utf8]), + ], + Volatility::Immutable, + ) + }) + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + use DataType::*; + + Ok(match arg_types[0] { + Utf8 => Utf8, + LargeUtf8 => LargeUtf8, + Binary => Utf8, + LargeBinary => LargeUtf8, + Null => Null, + _ => { + return plan_err!("The encode function can only accept utf8 or binary."); + } + }) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + // Put a feature flag here to make sure this is only compiled when the feature is activated + super::inner::encode(args) + } +} + +#[derive(Default, Debug)] +pub(super) struct DecodeFunc {} + +static DECODE_SIGNATURE: OnceLock = OnceLock::new(); + +impl FunctionImplementation for DecodeFunc { + fn name(&self) -> &str { + "decode" + } + + fn signature(&self) -> &Signature { + use DataType::*; + + DECODE_SIGNATURE.get_or_init(|| { + 
Signature::one_of( + vec![ + Exact(vec![Utf8, Utf8]), + Exact(vec![LargeUtf8, Utf8]), + Exact(vec![Binary, Utf8]), + Exact(vec![LargeBinary, Utf8]), + ], + Volatility::Immutable, + ) + }) + } + + fn return_type(&self, arg_types: &[DataType]) -> Result { + use DataType::*; + + Ok(match arg_types[0] { + Utf8 => Binary, + LargeUtf8 => LargeBinary, + Binary => Binary, + LargeBinary => LargeBinary, + Null => Null, + _ => { + return plan_err!("The decode function can only accept utf8 or binary."); + } + }) + } + + fn invoke(&self, args: &[ColumnarValue]) -> Result { + // Put a feature flag here to make sure this is only compiled when the feature is activated + super::inner::decode(args) + } +} + #[derive(Debug, Copy, Clone)] enum Encoding { Base64, Hex, } - fn encode_process(value: &ColumnarValue, encoding: Encoding) -> Result { match value { ColumnarValue::Array(a) => match a.data_type() { diff --git a/datafusion/functions/src/encoding/mod.rs b/datafusion/functions/src/encoding/mod.rs new file mode 100644 index 000000000000..00ef24adbc57 --- /dev/null +++ b/datafusion/functions/src/encoding/mod.rs @@ -0,0 +1,28 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +mod inner; + +use datafusion_expr::{Expr, ScalarUDF}; +use std::sync::{Arc, OnceLock}; + +make_function!(inner::EncodeFunc, ENCODE, encode); +make_function!(inner::DecodeFunc, DECODE, decode); + +// Export the functions out of this package, both as expr_fn as well as in functions +export_functions!(encode, decode); + diff --git a/datafusion/functions/src/lib.rs b/datafusion/functions/src/lib.rs new file mode 100644 index 000000000000..d4099b629bbd --- /dev/null +++ b/datafusion/functions/src/lib.rs @@ -0,0 +1,67 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Built in optional function packages for DataFusion +//! +//! Each module should implement a "function package" that should have a function with the signature: +//! +//! ``` +//! # use std::sync::Arc; +//! # use datafusion_expr::FunctionImplementation; +//! // return a list of functions or stubs +//! fn functions() -> Vec> { +//! todo!() +//! } +//! ``` +//! +//! Which returns: +//! +//! 1. The list of actual function implementations when the relevant +//! feature is activated, +//! +//! 2. A list of stub functions, when the feature is not activated, that produce +//! a runtime error (and explain what feature flag is needed to activate them). +//! +//!
The rationale for providing stub functions is to help users to configure datafusion +//! properly (so they get an error telling them why a function is not available) +//! instead of getting a cryptic "no function found" message at runtime. +use datafusion_common::Result; +use datafusion_execution::FunctionRegistry; +use log::debug; + +pub mod stub; +#[macro_use] +mod macros; + +make_package!(encoding, "encoding_expressions"); + +/// reexports of all expr_fn APIs +pub mod expr_fn { + #[cfg(feature = "encoding_expressions")] + pub use super::encoding::expr_fn::*; +} + +pub fn register_all(registry: &mut dyn FunctionRegistry) -> Result<()> { + encoding::functions().into_iter().try_for_each(|udf| { + let existing_udf = registry.register_udf(udf)?; + if let Some(existing_udf) = existing_udf { + debug!("Overwrite existing UDF: {}", existing_udf.name()); + } + Ok(()) as Result<()> + })?; + Ok(()) +} diff --git a/datafusion/functions/src/macros.rs b/datafusion/functions/src/macros.rs new file mode 100644 index 000000000000..eca3d3355579 --- /dev/null +++ b/datafusion/functions/src/macros.rs @@ -0,0 +1,96 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +/// macro that exports listed names as individual functions in an `expr_fn` module +/// +/// Equivalent to +/// ```text +/// pub mod expr_fn { +/// use super::*; +/// /// Return encode(arg) +/// pub fn encode(args: Vec) -> Expr { +/// super::encode().call(args) +/// } +/// ... +/// /// Return a list of all functions in this package +/// pub(crate) fn functions() -> Vec> { +/// vec![ +/// encode(), +/// decode() +/// ] +/// } +/// ``` +macro_rules! export_functions { + ($($name:ident),*) => { + pub mod expr_fn { + use super::*; + $( + /// Return $name(arg) + pub fn $name(args: Vec) -> Expr { + super::$name().call(args) + } + )* + } + + /// Return a list of all functions in this package + pub(crate) fn functions() -> Vec> { + vec![ + $( + $name(), + )* + ] + } + }; +} + +/// Create a singleton instance of the function named $GNAME and return it from the +/// function named $NAME +macro_rules! make_function { + ($UDF:ty, $GNAME:ident, $NAME:ident) => { + /// Singleton instance of the function + static $GNAME: OnceLock> = OnceLock::new(); + + /// Return the function implementation + fn $NAME() -> Arc { + $GNAME + .get_or_init(|| Arc::new(ScalarUDF::new_from_impl(<$UDF>::default()))) + .clone() + } + }; +} + +// Macro creates the named module if the feature is enabled +// otherwise creates a stub +macro_rules! 
make_package { + ($name:ident, $feature:literal) => { + #[cfg(feature = $feature)] + pub mod $name; + + #[cfg(not(feature = $feature))] + /// Stub module when feature is not enabled + mod $name { + use datafusion_expr::ScalarUDF; + use log::debug; + use std::sync::Arc; + + pub(crate) fn functions() -> Vec> { + debug!("{} functions disabled", stringify!($name)); + vec![] + } + } + }; +} diff --git a/datafusion/functions/src/stub.rs b/datafusion/functions/src/stub.rs new file mode 100644 index 000000000000..f8cb4155a0f1 --- /dev/null +++ b/datafusion/functions/src/stub.rs @@ -0,0 +1,54 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::datatypes::DataType; +use datafusion_common::{exec_err, plan_err, DataFusionError, Result}; +use datafusion_expr::{ColumnarValue, FunctionImplementation, Signature, Volatility}; +use std::sync::OnceLock; + +/// A scalar function that always errors with a hint. This is used to stub out +/// functions that are not enabled with the current set of crate features. 
+pub struct StubFunc { + name: &'static str, + hint: &'static str, +} + +impl StubFunc { + /// Create a new stub function + pub fn new(name: &'static str, hint: &'static str) -> Self { + Self { name, hint } + } +} + +static STUB_SIGNATURE: OnceLock = OnceLock::new(); + +impl FunctionImplementation for StubFunc { + fn name(&self) -> &str { + self.name + } + + fn signature(&self) -> &Signature { + STUB_SIGNATURE.get_or_init(|| Signature::variadic_any(Volatility::Volatile)) + } + + fn return_type(&self, _args: &[DataType]) -> Result { + plan_err!("function {} not available. {}", self.name, self.hint) + } + fn invoke(&self, _args: &[ColumnarValue]) -> Result { + exec_err!("function {} not available. {}", self.name, self.hint) + } +} diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 53de85843919..8b08e6f92793 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -83,26 +83,6 @@ pub fn create_physical_expr( ))) } -#[cfg(feature = "encoding_expressions")] -macro_rules! invoke_if_encoding_expressions_feature_flag { - ($FUNC:ident, $NAME:expr) => {{ - use crate::encoding_expressions; - encoding_expressions::$FUNC - }}; -} - -#[cfg(not(feature = "encoding_expressions"))] -macro_rules! invoke_if_encoding_expressions_feature_flag { - ($FUNC:ident, $NAME:expr) => { - |_: &[ColumnarValue]| -> Result { - internal_err!( - "function {} requires compilation with feature flag: encoding_expressions.", - $NAME - ) - } - }; -} - #[cfg(feature = "crypto_expressions")] macro_rules! 
invoke_if_crypto_expressions_feature_flag { ($FUNC:ident, $NAME:expr) => {{ @@ -578,12 +558,6 @@ pub fn create_physical_fun( BuiltinScalarFunction::Digest => { Arc::new(invoke_if_crypto_expressions_feature_flag!(digest, "digest")) } - BuiltinScalarFunction::Decode => Arc::new( - invoke_if_encoding_expressions_feature_flag!(decode, "decode"), - ), - BuiltinScalarFunction::Encode => Arc::new( - invoke_if_encoding_expressions_feature_flag!(encode, "encode"), - ), BuiltinScalarFunction::NullIf => Arc::new(nullif_func), BuiltinScalarFunction::OctetLength => Arc::new(|args| match &args[0] { ColumnarValue::Array(v) => Ok(ColumnarValue::Array(length(v.as_ref())?)), diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index fffa8f602d87..208a57486ea3 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -22,8 +22,6 @@ pub mod conditional_expressions; #[cfg(feature = "crypto_expressions")] pub mod crypto_expressions; pub mod datetime_expressions; -#[cfg(feature = "encoding_expressions")] -pub mod encoding_expressions; pub mod equivalence; pub mod execution_props; pub mod expressions; diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index bd8053c817e7..aa1ce1e5bcf0 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -617,8 +617,6 @@ enum ScalarFunction { Cardinality = 98; ArrayElement = 99; ArraySlice = 100; - Encode = 101; - Decode = 102; Cot = 103; ArrayHas = 104; ArrayHasAny = 105; diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 88310be0318a..2d74e031c278 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -21090,8 +21090,6 @@ impl serde::Serialize for ScalarFunction { Self::Cardinality => "Cardinality", Self::ArrayElement => "ArrayElement", Self::ArraySlice => "ArraySlice", - Self::Encode => 
"Encode", - Self::Decode => "Decode", Self::Cot => "Cot", Self::ArrayHas => "ArrayHas", Self::ArrayHasAny => "ArrayHasAny", @@ -21231,8 +21229,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Cardinality", "ArrayElement", "ArraySlice", - "Encode", - "Decode", "Cot", "ArrayHas", "ArrayHasAny", @@ -21401,8 +21397,6 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "Cardinality" => Ok(ScalarFunction::Cardinality), "ArrayElement" => Ok(ScalarFunction::ArrayElement), "ArraySlice" => Ok(ScalarFunction::ArraySlice), - "Encode" => Ok(ScalarFunction::Encode), - "Decode" => Ok(ScalarFunction::Decode), "Cot" => Ok(ScalarFunction::Cot), "ArrayHas" => Ok(ScalarFunction::ArrayHas), "ArrayHasAny" => Ok(ScalarFunction::ArrayHasAny), diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 3dfd3938615f..ed223b7f2603 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2590,8 +2590,6 @@ pub enum ScalarFunction { Cardinality = 98, ArrayElement = 99, ArraySlice = 100, - Encode = 101, - Decode = 102, Cot = 103, ArrayHas = 104, ArrayHasAny = 105, @@ -2728,8 +2726,6 @@ impl ScalarFunction { ScalarFunction::Cardinality => "Cardinality", ScalarFunction::ArrayElement => "ArrayElement", ScalarFunction::ArraySlice => "ArraySlice", - ScalarFunction::Encode => "Encode", - ScalarFunction::Decode => "Decode", ScalarFunction::Cot => "Cot", ScalarFunction::ArrayHas => "ArrayHas", ScalarFunction::ArrayHasAny => "ArrayHasAny", @@ -2863,8 +2859,6 @@ impl ScalarFunction { "Cardinality" => Some(Self::Cardinality), "ArrayElement" => Some(Self::ArrayElement), "ArraySlice" => Some(Self::ArraySlice), - "Encode" => Some(Self::Encode), - "Decode" => Some(Self::Decode), "Cot" => Some(Self::Cot), "ArrayHas" => Some(Self::ArrayHas), "ArrayHasAny" => Some(Self::ArrayHasAny), diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index 
193e0947d6d9..5ced98aabb8c 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -49,7 +49,7 @@ use datafusion_expr::{ array_to_string, arrow_typeof, ascii, asin, asinh, atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin, date_part, - date_trunc, decode, degrees, digest, encode, exp, + date_trunc, degrees, digest, exp, expr::{self, InList, Sort, WindowFunction}, factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero, lcm, left, levenshtein, ln, log, log10, log2, @@ -518,8 +518,6 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::Sha384 => Self::SHA384, ScalarFunction::Sha512 => Self::SHA512, ScalarFunction::Digest => Self::Digest, - ScalarFunction::Encode => Self::Encode, - ScalarFunction::Decode => Self::Decode, ScalarFunction::ToTimestampMillis => Self::ToTimestampMillis, ScalarFunction::Log2 => Self::Log2, ScalarFunction::Signum => Self::Signum, @@ -1548,14 +1546,6 @@ pub fn parse_expr( ScalarFunction::Sha384 => Ok(sha384(parse_expr(&args[0], registry)?)), ScalarFunction::Sha512 => Ok(sha512(parse_expr(&args[0], registry)?)), ScalarFunction::Md5 => Ok(md5(parse_expr(&args[0], registry)?)), - ScalarFunction::Encode => Ok(encode( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - )), - ScalarFunction::Decode => Ok(decode( - parse_expr(&args[0], registry)?, - parse_expr(&args[1], registry)?, - )), ScalarFunction::NullIf => Ok(nullif( parse_expr(&args[0], registry)?, parse_expr(&args[1], registry)?, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 2997d147424d..3ac793c859cd 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1546,8 +1546,6 @@ impl 
TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::SHA384 => Self::Sha384, BuiltinScalarFunction::SHA512 => Self::Sha512, BuiltinScalarFunction::Digest => Self::Digest, - BuiltinScalarFunction::Decode => Self::Decode, - BuiltinScalarFunction::Encode => Self::Encode, BuiltinScalarFunction::ToTimestampMillis => Self::ToTimestampMillis, BuiltinScalarFunction::Log2 => Self::Log2, BuiltinScalarFunction::Signum => Self::Signum, diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index 8e15b5d0d480..1edc9f830b07 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -32,7 +32,9 @@ use datafusion::datasource::TableProvider; use datafusion::execution::context::SessionState; use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv}; use datafusion::physical_plan::functions::make_scalar_function; -use datafusion::prelude::{create_udf, CsvReadOptions, SessionConfig, SessionContext}; +use datafusion::prelude::{ + create_udf, decode, encode, CsvReadOptions, SessionConfig, SessionContext, +}; use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::Result; use datafusion_common::{internal_err, not_impl_err, plan_err}; @@ -45,8 +47,9 @@ use datafusion_expr::logical_plan::{Extension, UserDefinedLogicalNodeCore}; use datafusion_expr::{ col, create_udaf, lit, Accumulator, AggregateFunction, BuiltinScalarFunction::{Sqrt, Substr}, - Expr, LogicalPlan, Operator, PartitionEvaluator, Signature, TryCast, Volatility, - WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, WindowUDF, + Expr, ExprSchemable, LogicalPlan, Operator, PartitionEvaluator, Signature, TryCast, + Volatility, WindowFrame, WindowFrameBound, WindowFrameUnits, WindowFunction, + WindowUDF, }; use datafusion_proto::bytes::{ logical_plan_from_bytes, 
logical_plan_from_bytes_with_extension_codec, @@ -365,6 +368,30 @@ async fn roundtrip_logical_plan_with_extension() -> Result<()> { Ok(()) } +#[tokio::test] +async fn roundtrip_expr_api() -> Result<()> { + let ctx = SessionContext::new(); + ctx.register_csv("t1", "tests/testdata/test.csv", CsvReadOptions::default()) + .await?; + let table = ctx.table("t1").await?; + let schema = table.schema().clone(); + + // ensure expressions created with the expr api can be round tripped + let plan = table + .select(vec![ + encode(vec![ + col("a").cast_to(&DataType::Utf8, &schema)?, + lit("hex"), + ]), + decode(vec![lit("1234"), lit("hex")]), + ])? + .into_optimized_plan()?; + let bytes = logical_plan_to_bytes(&plan)?; + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); + Ok(()) +} + #[tokio::test] async fn roundtrip_logical_plan_with_view_scan() -> Result<()> { let ctx = SessionContext::new();