Skip to content

Commit 8730466

Browse files
authored
Move concat, concat_ws, ends_with, initcap to datafusion-functions (#10089)
1 parent 1395adf commit 8730466

File tree

31 files changed: +1409 -1271 lines changed

datafusion/core/src/execution/context/mod.rs

+4 -1
Original file line number | Diff line number | Diff line change
@@ -58,7 +58,10 @@ use crate::{
5858
physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner},
5959
variable::{VarProvider, VarType},
6060
};
61-
use crate::{functions, functions_aggregate, functions_array};
61+
62+
#[cfg(feature = "array_expressions")]
63+
use crate::functions_array;
64+
use crate::{functions, functions_aggregate};
6265

6366
use arrow::datatypes::{DataType, SchemaRef};
6467
use arrow::record_batch::RecordBatch;

datafusion/core/tests/dataframe/dataframe_functions.rs

+1 -1
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@ use datafusion::assert_batches_eq;
3737
use datafusion_common::DFSchema;
3838
use datafusion_expr::expr::Alias;
3939
use datafusion_expr::{approx_median, cast, ExprSchemable};
40-
use datafusion_functions::unicode::expr_fn::character_length;
40+
use datafusion_functions_array::expr_fn::array_to_string;
4141

4242
fn test_schema() -> SchemaRef {
4343
Arc::new(Schema::new(vec![

datafusion/core/tests/optimizer_integration.rs

+27 -1
Original file line number | Diff line number | Diff line change
@@ -83,6 +83,30 @@ fn timestamp_nano_ts_utc_predicates() {
8383
assert_eq!(expected, format!("{plan:?}"));
8484
}
8585

86+
#[test]
87+
fn concat_literals() -> Result<()> {
88+
let sql = "SELECT concat(true, col_int32, false, null, 'hello', col_utf8, 12, 3.4) \
89+
AS col
90+
FROM test";
91+
let expected =
92+
"Projection: concat(Utf8(\"true\"), CAST(test.col_int32 AS Utf8), Utf8(\"falsehello\"), test.col_utf8, Utf8(\"123.4\")) AS col\
93+
\n TableScan: test projection=[col_int32, col_utf8]";
94+
quick_test(sql, expected);
95+
Ok(())
96+
}
97+
98+
#[test]
99+
fn concat_ws_literals() -> Result<()> {
100+
let sql = "SELECT concat_ws('-', true, col_int32, false, null, 'hello', col_utf8, 12, '', 3.4) \
101+
AS col
102+
FROM test";
103+
let expected =
104+
"Projection: concat_ws(Utf8(\"-\"), Utf8(\"true\"), CAST(test.col_int32 AS Utf8), Utf8(\"false-hello\"), test.col_utf8, Utf8(\"12--3.4\")) AS col\
105+
\n TableScan: test projection=[col_int32, col_utf8]";
106+
quick_test(sql, expected);
107+
Ok(())
108+
}
109+
86110
fn quick_test(sql: &str, expected_plan: &str) {
87111
let plan = test_sql(sql).unwrap();
88112
assert_eq!(expected_plan, format!("{:?}", plan));
@@ -97,7 +121,9 @@ fn test_sql(sql: &str) -> Result<LogicalPlan> {
97121
// create a logical query plan
98122
let context_provider = MyContextProvider::default()
99123
.with_udf(datetime::now())
100-
.with_udf(datafusion_functions::core::arrow_cast());
124+
.with_udf(datafusion_functions::core::arrow_cast())
125+
.with_udf(datafusion_functions::string::concat())
126+
.with_udf(datafusion_functions::string::concat_ws());
101127
let sql_to_rel = SqlToRel::new(&context_provider);
102128
let plan = sql_to_rel.sql_statement_to_plan(statement.clone()).unwrap();
103129

datafusion/core/tests/simplification.rs

+94 -5
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@ use datafusion_expr::{
3131
expr, table_scan, Cast, ColumnarValue, Expr, ExprSchemable, LogicalPlan,
3232
LogicalPlanBuilder, ScalarUDF, Volatility,
3333
};
34-
use datafusion_functions::math;
34+
use datafusion_functions::{math, string};
3535
use datafusion_optimizer::optimizer::Optimizer;
3636
use datafusion_optimizer::simplify_expressions::{ExprSimplifier, SimplifyExpressions};
3737
use datafusion_optimizer::{OptimizerContext, OptimizerRule};
@@ -217,7 +217,7 @@ fn fold_and_simplify() {
217217
let info: MyInfo = schema().into();
218218

219219
// What will it do with the expression `concat('foo', 'bar') == 'foobar')`?
220-
let expr = concat(&[lit("foo"), lit("bar")]).eq(lit("foobar"));
220+
let expr = concat(vec![lit("foo"), lit("bar")]).eq(lit("foobar"));
221221

222222
// Since datafusion applies both simplification *and* rewriting
223223
// some expressions can be entirely simplified
@@ -364,13 +364,13 @@ fn test_const_evaluator() {
364364
#[test]
365365
fn test_const_evaluator_scalar_functions() {
366366
// concat("foo", "bar") --> "foobar"
367-
let expr = call_fn("concat", vec![lit("foo"), lit("bar")]).unwrap();
367+
let expr = string::expr_fn::concat(vec![lit("foo"), lit("bar")]);
368368
test_evaluate(expr, lit("foobar"));
369369

370370
// ensure arguments are also constant folded
371371
// concat("foo", concat("bar", "baz")) --> "foobarbaz"
372-
let concat1 = call_fn("concat", vec![lit("bar"), lit("baz")]).unwrap();
373-
let expr = call_fn("concat", vec![lit("foo"), concat1]).unwrap();
372+
let concat1 = string::expr_fn::concat(vec![lit("bar"), lit("baz")]);
373+
let expr = string::expr_fn::concat(vec![lit("foo"), concat1]);
374374
test_evaluate(expr, lit("foobarbaz"));
375375

376376
// Check non string arguments
@@ -569,3 +569,92 @@ fn test_simplify_power() {
569569
test_simplify(expr, expected)
570570
}
571571
}
572+
573+
#[test]
574+
fn test_simplify_concat_ws() {
575+
let null = lit(ScalarValue::Utf8(None));
576+
// the delimiter is not a literal
577+
{
578+
let expr = concat_ws(col("c"), vec![lit("a"), null.clone(), lit("b")]);
579+
let expected = concat_ws(col("c"), vec![lit("a"), lit("b")]);
580+
test_simplify(expr, expected);
581+
}
582+
583+
// the delimiter is an empty string
584+
{
585+
let expr = concat_ws(lit(""), vec![col("a"), lit("c"), lit("b")]);
586+
let expected = concat(vec![col("a"), lit("cb")]);
587+
test_simplify(expr, expected);
588+
}
589+
590+
// the delimiter is a not-empty string
591+
{
592+
let expr = concat_ws(
593+
lit("-"),
594+
vec![
595+
null.clone(),
596+
col("c0"),
597+
lit("hello"),
598+
null.clone(),
599+
lit("rust"),
600+
col("c1"),
601+
lit(""),
602+
lit(""),
603+
null,
604+
],
605+
);
606+
let expected = concat_ws(
607+
lit("-"),
608+
vec![col("c0"), lit("hello-rust"), col("c1"), lit("-")],
609+
);
610+
test_simplify(expr, expected)
611+
}
612+
}
613+
614+
#[test]
615+
fn test_simplify_concat_ws_with_null() {
616+
let null = lit(ScalarValue::Utf8(None));
617+
// null delimiter -> null
618+
{
619+
let expr = concat_ws(null.clone(), vec![col("c1"), col("c2")]);
620+
test_simplify(expr, null.clone());
621+
}
622+
623+
// filter out null args
624+
{
625+
let expr = concat_ws(lit("|"), vec![col("c1"), null.clone(), col("c2")]);
626+
let expected = concat_ws(lit("|"), vec![col("c1"), col("c2")]);
627+
test_simplify(expr, expected);
628+
}
629+
630+
// nested test
631+
{
632+
let sub_expr = concat_ws(null.clone(), vec![col("c1"), col("c2")]);
633+
let expr = concat_ws(lit("|"), vec![sub_expr, col("c3")]);
634+
test_simplify(expr, concat_ws(lit("|"), vec![col("c3")]));
635+
}
636+
637+
// null delimiter (nested)
638+
{
639+
let sub_expr = concat_ws(null.clone(), vec![col("c1"), col("c2")]);
640+
let expr = concat_ws(sub_expr, vec![col("c3"), col("c4")]);
641+
test_simplify(expr, null);
642+
}
643+
}
644+
645+
#[test]
646+
fn test_simplify_concat() {
647+
let null = lit(ScalarValue::Utf8(None));
648+
let expr = concat(vec![
649+
null.clone(),
650+
col("c0"),
651+
lit("hello "),
652+
null.clone(),
653+
lit("rust"),
654+
col("c1"),
655+
lit(""),
656+
null,
657+
]);
658+
let expected = concat(vec![col("c0"), lit("hello rust"), col("c1")]);
659+
test_simplify(expr, expected)
660+
}

datafusion/expr/src/built_in_function.rs

+1 -89
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@ use std::str::FromStr;
2323
use std::sync::OnceLock;
2424

2525
use crate::type_coercion::functions::data_types;
26-
use crate::{FuncMonotonicity, Signature, TypeSignature, Volatility};
26+
use crate::{FuncMonotonicity, Signature, Volatility};
2727

2828
use arrow::datatypes::DataType;
2929
use datafusion_common::{plan_err, DataFusionError, Result};
@@ -39,15 +39,6 @@ pub enum BuiltinScalarFunction {
3939
// math functions
4040
/// coalesce
4141
Coalesce,
42-
// string functions
43-
/// concat
44-
Concat,
45-
/// concat_ws
46-
ConcatWithSeparator,
47-
/// ends_with
48-
EndsWith,
49-
/// initcap
50-
InitCap,
5142
}
5243

5344
/// Maps the sql function name to `BuiltinScalarFunction`
@@ -101,10 +92,6 @@ impl BuiltinScalarFunction {
10192
match self {
10293
// Immutable scalar builtins
10394
BuiltinScalarFunction::Coalesce => Volatility::Immutable,
104-
BuiltinScalarFunction::Concat => Volatility::Immutable,
105-
BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable,
106-
BuiltinScalarFunction::EndsWith => Volatility::Immutable,
107-
BuiltinScalarFunction::InitCap => Volatility::Immutable,
10895
}
10996
}
11097

@@ -117,8 +104,6 @@ impl BuiltinScalarFunction {
117104
/// 1. Perform additional checks on `input_expr_types` that are beyond the scope of `TypeSignature` validation.
118105
/// 2. Deduce the output `DataType` based on the provided `input_expr_types`.
119106
pub fn return_type(self, input_expr_types: &[DataType]) -> Result<DataType> {
120-
use DataType::*;
121-
122107
// Note that this function *must* return the same type that the respective physical expression returns
123108
// or the execution panics.
124109

@@ -130,43 +115,18 @@ impl BuiltinScalarFunction {
130115
let coerced_types = data_types(input_expr_types, &self.signature());
131116
coerced_types.map(|types| types[0].clone())
132117
}
133-
BuiltinScalarFunction::Concat => Ok(Utf8),
134-
BuiltinScalarFunction::ConcatWithSeparator => Ok(Utf8),
135-
BuiltinScalarFunction::InitCap => {
136-
utf8_to_str_type(&input_expr_types[0], "initcap")
137-
}
138-
BuiltinScalarFunction::EndsWith => Ok(Boolean),
139118
}
140119
}
141120

142121
/// Return the argument [`Signature`] supported by this function
143122
pub fn signature(&self) -> Signature {
144-
use DataType::*;
145-
use TypeSignature::*;
146123
// note: the physical expression must accept the type returned by this function or the execution panics.
147124

148125
// for now, the list is small, as we do not have many built-in functions.
149126
match self {
150-
BuiltinScalarFunction::Concat
151-
| BuiltinScalarFunction::ConcatWithSeparator => {
152-
Signature::variadic(vec![Utf8], self.volatility())
153-
}
154127
BuiltinScalarFunction::Coalesce => {
155128
Signature::variadic_equal(self.volatility())
156129
}
157-
BuiltinScalarFunction::InitCap => {
158-
Signature::uniform(1, vec![Utf8, LargeUtf8], self.volatility())
159-
}
160-
161-
BuiltinScalarFunction::EndsWith => Signature::one_of(
162-
vec![
163-
Exact(vec![Utf8, Utf8]),
164-
Exact(vec![Utf8, LargeUtf8]),
165-
Exact(vec![LargeUtf8, Utf8]),
166-
Exact(vec![LargeUtf8, LargeUtf8]),
167-
],
168-
self.volatility(),
169-
),
170130
}
171131
}
172132

@@ -182,11 +142,6 @@ impl BuiltinScalarFunction {
182142
match self {
183143
// conditional functions
184144
BuiltinScalarFunction::Coalesce => &["coalesce"],
185-
186-
BuiltinScalarFunction::Concat => &["concat"],
187-
BuiltinScalarFunction::ConcatWithSeparator => &["concat_ws"],
188-
BuiltinScalarFunction::EndsWith => &["ends_with"],
189-
BuiltinScalarFunction::InitCap => &["initcap"],
190145
}
191146
}
192147
}
@@ -208,49 +163,6 @@ impl FromStr for BuiltinScalarFunction {
208163
}
209164
}
210165

211-
/// Creates a function to identify the optimal return type of a string function given
212-
/// the type of its first argument.
213-
///
214-
/// If the input type is `LargeUtf8` or `LargeBinary` the return type is
215-
/// `$largeUtf8Type`,
216-
///
217-
/// If the input type is `Utf8` or `Binary` the return type is `$utf8Type`,
218-
macro_rules! get_optimal_return_type {
219-
($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => {
220-
fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType> {
221-
Ok(match arg_type {
222-
// LargeBinary inputs are automatically coerced to Utf8
223-
DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type,
224-
// Binary inputs are automatically coerced to Utf8
225-
DataType::Utf8 | DataType::Binary => $utf8Type,
226-
DataType::Null => DataType::Null,
227-
DataType::Dictionary(_, value_type) => match **value_type {
228-
DataType::LargeUtf8 | DataType::LargeBinary => $largeUtf8Type,
229-
DataType::Utf8 | DataType::Binary => $utf8Type,
230-
DataType::Null => DataType::Null,
231-
_ => {
232-
return plan_err!(
233-
"The {} function can only accept strings, but got {:?}.",
234-
name.to_uppercase(),
235-
**value_type
236-
);
237-
}
238-
},
239-
data_type => {
240-
return plan_err!(
241-
"The {} function can only accept strings, but got {:?}.",
242-
name.to_uppercase(),
243-
data_type
244-
);
245-
}
246-
})
247-
}
248-
};
249-
}
250-
251-
// `utf8_to_str_type`: returns either a Utf8 or LargeUtf8 based on the input type size.
252-
get_optimal_return_type!(utf8_to_str_type, DataType::LargeUtf8, DataType::Utf8);
253-
254166
#[cfg(test)]
255167
mod tests {
256168
use super::*;

0 commit comments

Comments
 (0)