Skip to content

Commit 91c0975

Browse files
authored
Always use StringViewArray as output of substr (#14498)
* change substr return type to utf8view (bin/sqllogicaltests.rs to fix)
* fix sqllogictest/string_view
* fix clippy
* fix tpch benchmark result
1 parent 7ccc6d7 commit 91c0975

File tree

3 files changed

+64
-68
lines changed

3 files changed

+64
-68
lines changed

datafusion/functions/src/unicode/substr.rs

+57-61
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,10 @@ use std::any::Any;
1919
use std::sync::Arc;
2020

2121
use crate::strings::make_and_append_view;
22-
use crate::utils::{make_scalar_function, utf8_to_str_type};
22+
use crate::utils::make_scalar_function;
2323
use arrow::array::{
24-
Array, ArrayIter, ArrayRef, AsArray, GenericStringBuilder, Int64Array,
25-
NullBufferBuilder, OffsetSizeTrait, StringArrayType, StringViewArray,
24+
Array, ArrayIter, ArrayRef, AsArray, Int64Array, NullBufferBuilder, StringArrayType,
25+
StringViewArray, StringViewBuilder,
2626
};
2727
use arrow::buffer::ScalarBuffer;
2828
use arrow::datatypes::DataType;
@@ -90,12 +90,9 @@ impl ScalarUDFImpl for SubstrFunc {
9090
&self.signature
9191
}
9292

93-
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
94-
if arg_types[0] == DataType::Utf8View {
95-
Ok(DataType::Utf8View)
96-
} else {
97-
utf8_to_str_type(&arg_types[0], "substr")
98-
}
93+
// `SubstrFunc` always generates `Utf8View` output for its efficiency.
94+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
95+
Ok(DataType::Utf8View)
9996
}
10097

10198
fn invoke_batch(
@@ -189,11 +186,11 @@ pub fn substr(args: &[ArrayRef]) -> Result<ArrayRef> {
189186
match args[0].data_type() {
190187
DataType::Utf8 => {
191188
let string_array = args[0].as_string::<i32>();
192-
string_substr::<_, i32>(string_array, &args[1..])
189+
string_substr::<_>(string_array, &args[1..])
193190
}
194191
DataType::LargeUtf8 => {
195192
let string_array = args[0].as_string::<i64>();
196-
string_substr::<_, i64>(string_array, &args[1..])
193+
string_substr::<_>(string_array, &args[1..])
197194
}
198195
DataType::Utf8View => {
199196
let string_array = args[0].as_string_view();
@@ -429,10 +426,9 @@ fn string_view_substr(
429426
}
430427
}
431428

432-
fn string_substr<'a, V, T>(string_array: V, args: &[ArrayRef]) -> Result<ArrayRef>
429+
fn string_substr<'a, V>(string_array: V, args: &[ArrayRef]) -> Result<ArrayRef>
433430
where
434431
V: StringArrayType<'a>,
435-
T: OffsetSizeTrait,
436432
{
437433
let start_array = as_int64_array(&args[0])?;
438434
let count_array_opt = if args.len() == 2 {
@@ -447,7 +443,7 @@ where
447443
match args.len() {
448444
1 => {
449445
let iter = ArrayIter::new(string_array);
450-
let mut result_builder = GenericStringBuilder::<T>::new();
446+
let mut result_builder = StringViewBuilder::new();
451447
for (string, start) in iter.zip(start_array.iter()) {
452448
match (string, start) {
453449
(Some(string), Some(start)) => {
@@ -470,7 +466,7 @@ where
470466
2 => {
471467
let iter = ArrayIter::new(string_array);
472468
let count_array = count_array_opt.unwrap();
473-
let mut result_builder = GenericStringBuilder::<T>::new();
469+
let mut result_builder = StringViewBuilder::new();
474470

475471
for ((string, start), count) in
476472
iter.zip(start_array.iter()).zip(count_array.iter())
@@ -512,8 +508,8 @@ where
512508

513509
#[cfg(test)]
514510
mod tests {
515-
use arrow::array::{Array, StringArray, StringViewArray};
516-
use arrow::datatypes::DataType::{Utf8, Utf8View};
511+
use arrow::array::{Array, StringViewArray};
512+
use arrow::datatypes::DataType::Utf8View;
517513

518514
use datafusion_common::{exec_err, Result, ScalarValue};
519515
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
@@ -623,8 +619,8 @@ mod tests {
623619
],
624620
Ok(Some("alphabet")),
625621
&str,
626-
Utf8,
627-
StringArray
622+
Utf8View,
623+
StringViewArray
628624
);
629625
test_function!(
630626
SubstrFunc::new(),
@@ -634,8 +630,8 @@ mod tests {
634630
],
635631
Ok(Some("ésoj")),
636632
&str,
637-
Utf8,
638-
StringArray
633+
Utf8View,
634+
StringViewArray
639635
);
640636
test_function!(
641637
SubstrFunc::new(),
@@ -645,8 +641,8 @@ mod tests {
645641
],
646642
Ok(Some("joséésoj")),
647643
&str,
648-
Utf8,
649-
StringArray
644+
Utf8View,
645+
StringViewArray
650646
);
651647
test_function!(
652648
SubstrFunc::new(),
@@ -656,8 +652,8 @@ mod tests {
656652
],
657653
Ok(Some("alphabet")),
658654
&str,
659-
Utf8,
660-
StringArray
655+
Utf8View,
656+
StringViewArray
661657
);
662658
test_function!(
663659
SubstrFunc::new(),
@@ -667,8 +663,8 @@ mod tests {
667663
],
668664
Ok(Some("lphabet")),
669665
&str,
670-
Utf8,
671-
StringArray
666+
Utf8View,
667+
StringViewArray
672668
);
673669
test_function!(
674670
SubstrFunc::new(),
@@ -678,8 +674,8 @@ mod tests {
678674
],
679675
Ok(Some("phabet")),
680676
&str,
681-
Utf8,
682-
StringArray
677+
Utf8View,
678+
StringViewArray
683679
);
684680
test_function!(
685681
SubstrFunc::new(),
@@ -689,8 +685,8 @@ mod tests {
689685
],
690686
Ok(Some("alphabet")),
691687
&str,
692-
Utf8,
693-
StringArray
688+
Utf8View,
689+
StringViewArray
694690
);
695691
test_function!(
696692
SubstrFunc::new(),
@@ -700,8 +696,8 @@ mod tests {
700696
],
701697
Ok(Some("")),
702698
&str,
703-
Utf8,
704-
StringArray
699+
Utf8View,
700+
StringViewArray
705701
);
706702
test_function!(
707703
SubstrFunc::new(),
@@ -711,8 +707,8 @@ mod tests {
711707
],
712708
Ok(None),
713709
&str,
714-
Utf8,
715-
StringArray
710+
Utf8View,
711+
StringViewArray
716712
);
717713
test_function!(
718714
SubstrFunc::new(),
@@ -723,8 +719,8 @@ mod tests {
723719
],
724720
Ok(Some("ph")),
725721
&str,
726-
Utf8,
727-
StringArray
722+
Utf8View,
723+
StringViewArray
728724
);
729725
test_function!(
730726
SubstrFunc::new(),
@@ -735,8 +731,8 @@ mod tests {
735731
],
736732
Ok(Some("phabet")),
737733
&str,
738-
Utf8,
739-
StringArray
734+
Utf8View,
735+
StringViewArray
740736
);
741737
test_function!(
742738
SubstrFunc::new(),
@@ -747,8 +743,8 @@ mod tests {
747743
],
748744
Ok(Some("alph")),
749745
&str,
750-
Utf8,
751-
StringArray
746+
Utf8View,
747+
StringViewArray
752748
);
753749
// starting from 5 (10 + -5)
754750
test_function!(
@@ -760,8 +756,8 @@ mod tests {
760756
],
761757
Ok(Some("alph")),
762758
&str,
763-
Utf8,
764-
StringArray
759+
Utf8View,
760+
StringViewArray
765761
);
766762
// starting from -1 (4 + -5)
767763
test_function!(
@@ -773,8 +769,8 @@ mod tests {
773769
],
774770
Ok(Some("")),
775771
&str,
776-
Utf8,
777-
StringArray
772+
Utf8View,
773+
StringViewArray
778774
);
779775
// starting from 0 (5 + -5)
780776
test_function!(
@@ -786,8 +782,8 @@ mod tests {
786782
],
787783
Ok(Some("")),
788784
&str,
789-
Utf8,
790-
StringArray
785+
Utf8View,
786+
StringViewArray
791787
);
792788
test_function!(
793789
SubstrFunc::new(),
@@ -798,8 +794,8 @@ mod tests {
798794
],
799795
Ok(None),
800796
&str,
801-
Utf8,
802-
StringArray
797+
Utf8View,
798+
StringViewArray
803799
);
804800
test_function!(
805801
SubstrFunc::new(),
@@ -810,8 +806,8 @@ mod tests {
810806
],
811807
Ok(None),
812808
&str,
813-
Utf8,
814-
StringArray
809+
Utf8View,
810+
StringViewArray
815811
);
816812
test_function!(
817813
SubstrFunc::new(),
@@ -822,8 +818,8 @@ mod tests {
822818
],
823819
exec_err!("negative substring length not allowed: substr(<str>, 1, -1)"),
824820
&str,
825-
Utf8,
826-
StringArray
821+
Utf8View,
822+
StringViewArray
827823
);
828824
test_function!(
829825
SubstrFunc::new(),
@@ -834,8 +830,8 @@ mod tests {
834830
],
835831
Ok(Some("és")),
836832
&str,
837-
Utf8,
838-
StringArray
833+
Utf8View,
834+
StringViewArray
839835
);
840836
#[cfg(not(feature = "unicode_expressions"))]
841837
test_function!(
@@ -848,8 +844,8 @@ mod tests {
848844
"function substr requires compilation with feature flag: unicode_expressions."
849845
),
850846
&str,
851-
Utf8,
852-
StringArray
847+
Utf8View,
848+
StringViewArray
853849
);
854850
test_function!(
855851
SubstrFunc::new(),
@@ -859,8 +855,8 @@ mod tests {
859855
],
860856
Ok(Some("abc")),
861857
&str,
862-
Utf8,
863-
StringArray
858+
Utf8View,
859+
StringViewArray
864860
);
865861
test_function!(
866862
SubstrFunc::new(),
@@ -871,8 +867,8 @@ mod tests {
871867
],
872868
exec_err!("negative overflow when calculating skip value"),
873869
&str,
874-
Utf8,
875-
StringArray
870+
Utf8View,
871+
StringViewArray
876872
);
877873

878874
Ok(())

datafusion/sqllogictest/test_files/string/string_view.slt

+1-1
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,7 @@ EXPLAIN SELECT
382382
FROM test;
383383
----
384384
logical_plan
385-
01)Projection: starts_with(test.column1_utf8, substr(test.column1_utf8, Int64(1), Int64(2))) AS c1, starts_with(test.column1_large_utf8, substr(test.column1_large_utf8, Int64(1), Int64(2))) AS c2, starts_with(test.column1_utf8view, substr(test.column1_utf8view, Int64(1), Int64(2))) AS c3
385+
01)Projection: starts_with(CAST(test.column1_utf8 AS Utf8View), substr(test.column1_utf8, Int64(1), Int64(2))) AS c1, starts_with(CAST(test.column1_large_utf8 AS Utf8View), substr(test.column1_large_utf8, Int64(1), Int64(2))) AS c2, starts_with(test.column1_utf8view, substr(test.column1_utf8view, Int64(1), Int64(2))) AS c3
386386
02)--TableScan: test projection=[column1_utf8, column1_large_utf8, column1_utf8view]
387387

388388
query BBB

datafusion/sqllogictest/test_files/tpch/plans/q22.slt.part

+6-6
Original file line numberDiff line numberDiff line change
@@ -64,15 +64,15 @@ logical_plan
6464
06)----------Inner Join: Filter: CAST(customer.c_acctbal AS Decimal128(19, 6)) > __scalar_sq_2.avg(customer.c_acctbal)
6565
07)------------Projection: customer.c_phone, customer.c_acctbal
6666
08)--------------LeftAnti Join: customer.c_custkey = __correlated_sq_1.o_custkey
67-
09)----------------Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
68-
10)------------------TableScan: customer projection=[c_custkey, c_phone, c_acctbal], partial_filters=[substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])]
67+
09)----------------Filter: substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])
68+
10)------------------TableScan: customer projection=[c_custkey, c_phone, c_acctbal], partial_filters=[substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])]
6969
11)----------------SubqueryAlias: __correlated_sq_1
7070
12)------------------TableScan: orders projection=[o_custkey]
7171
13)------------SubqueryAlias: __scalar_sq_2
7272
14)--------------Aggregate: groupBy=[[]], aggr=[[avg(customer.c_acctbal)]]
7373
15)----------------Projection: customer.c_acctbal
74-
16)------------------Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])
75-
17)--------------------TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8("13"), Utf8("31"), Utf8("23"), Utf8("29"), Utf8("30"), Utf8("18"), Utf8("17")])]
74+
16)------------------Filter: customer.c_acctbal > Decimal128(Some(0),15,2) AND substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])
75+
17)--------------------TableScan: customer projection=[c_phone, c_acctbal], partial_filters=[customer.c_acctbal > Decimal128(Some(0),15,2), substr(customer.c_phone, Int64(1), Int64(2)) IN ([Utf8View("13"), Utf8View("31"), Utf8View("23"), Utf8View("29"), Utf8View("30"), Utf8View("18"), Utf8View("17")])]
7676
physical_plan
7777
01)SortPreservingMergeExec: [cntrycode@0 ASC NULLS LAST]
7878
02)--SortExec: expr=[cntrycode@0 ASC NULLS LAST], preserve_partitioning=[true]
@@ -90,7 +90,7 @@ physical_plan
9090
14)--------------------------CoalesceBatchesExec: target_batch_size=8192
9191
15)----------------------------RepartitionExec: partitioning=Hash([c_custkey@0], 4), input_partitions=4
9292
16)------------------------------CoalesceBatchesExec: target_batch_size=8192
93-
17)--------------------------------FilterExec: Use substr(c_phone@1, 1, 2) IN (SET) ([Literal { value: Utf8("13") }, Literal { value: Utf8("31") }, Literal { value: Utf8("23") }, Literal { value: Utf8("29") }, Literal { value: Utf8("30") }, Literal { value: Utf8("18") }, Literal { value: Utf8("17") }])
93+
17)--------------------------------FilterExec: substr(c_phone@1, 1, 2) IN ([Literal { value: Utf8View("13") }, Literal { value: Utf8View("31") }, Literal { value: Utf8View("23") }, Literal { value: Utf8View("29") }, Literal { value: Utf8View("30") }, Literal { value: Utf8View("18") }, Literal { value: Utf8View("17") }])
9494
18)----------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
9595
19)------------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_custkey, c_phone, c_acctbal], file_type=csv, has_header=false
9696
20)--------------------------CoalesceBatchesExec: target_batch_size=8192
@@ -100,6 +100,6 @@ physical_plan
100100
24)----------------------CoalescePartitionsExec
101101
25)------------------------AggregateExec: mode=Partial, gby=[], aggr=[avg(customer.c_acctbal)]
102102
26)--------------------------CoalesceBatchesExec: target_batch_size=8192
103-
27)----------------------------FilterExec: c_acctbal@1 > Some(0),15,2 AND Use substr(c_phone@0, 1, 2) IN (SET) ([Literal { value: Utf8("13") }, Literal { value: Utf8("31") }, Literal { value: Utf8("23") }, Literal { value: Utf8("29") }, Literal { value: Utf8("30") }, Literal { value: Utf8("18") }, Literal { value: Utf8("17") }]), projection=[c_acctbal@1]
103+
27)----------------------------FilterExec: c_acctbal@1 > Some(0),15,2 AND substr(c_phone@0, 1, 2) IN ([Literal { value: Utf8View("13") }, Literal { value: Utf8View("31") }, Literal { value: Utf8View("23") }, Literal { value: Utf8View("29") }, Literal { value: Utf8View("30") }, Literal { value: Utf8View("18") }, Literal { value: Utf8View("17") }]), projection=[c_acctbal@1]
104104
28)------------------------------RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1
105105
29)--------------------------------DataSourceExec: file_groups={1 group: [[WORKSPACE_ROOT/datafusion/sqllogictest/test_files/tpch/data/customer.tbl]]}, projection=[c_phone, c_acctbal], file_type=csv, has_header=false

0 commit comments

Comments
 (0)