@@ -19,10 +19,10 @@ use std::any::Any;
19
19
use std:: sync:: Arc ;
20
20
21
21
use crate :: strings:: make_and_append_view;
22
- use crate :: utils:: { make_scalar_function, utf8_to_str_type } ;
22
+ use crate :: utils:: make_scalar_function;
23
23
use arrow:: array:: {
24
- Array , ArrayIter , ArrayRef , AsArray , GenericStringBuilder , Int64Array ,
25
- NullBufferBuilder , OffsetSizeTrait , StringArrayType , StringViewArray ,
24
+ Array , ArrayIter , ArrayRef , AsArray , Int64Array , NullBufferBuilder , StringArrayType ,
25
+ StringViewArray , StringViewBuilder ,
26
26
} ;
27
27
use arrow:: buffer:: ScalarBuffer ;
28
28
use arrow:: datatypes:: DataType ;
@@ -90,12 +90,9 @@ impl ScalarUDFImpl for SubstrFunc {
90
90
& self . signature
91
91
}
92
92
93
- fn return_type ( & self , arg_types : & [ DataType ] ) -> Result < DataType > {
94
- if arg_types[ 0 ] == DataType :: Utf8View {
95
- Ok ( DataType :: Utf8View )
96
- } else {
97
- utf8_to_str_type ( & arg_types[ 0 ] , "substr" )
98
- }
93
+ // `SubstrFunc` always returns `Utf8View`, whatever the input string type, for efficiency.
94
+ fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
95
+ Ok ( DataType :: Utf8View )
99
96
}
100
97
101
98
fn invoke_batch (
@@ -189,11 +186,11 @@ pub fn substr(args: &[ArrayRef]) -> Result<ArrayRef> {
189
186
match args[ 0 ] . data_type ( ) {
190
187
DataType :: Utf8 => {
191
188
let string_array = args[ 0 ] . as_string :: < i32 > ( ) ;
192
- string_substr :: < _ , i32 > ( string_array, & args[ 1 ..] )
189
+ string_substr :: < _ > ( string_array, & args[ 1 ..] )
193
190
}
194
191
DataType :: LargeUtf8 => {
195
192
let string_array = args[ 0 ] . as_string :: < i64 > ( ) ;
196
- string_substr :: < _ , i64 > ( string_array, & args[ 1 ..] )
193
+ string_substr :: < _ > ( string_array, & args[ 1 ..] )
197
194
}
198
195
DataType :: Utf8View => {
199
196
let string_array = args[ 0 ] . as_string_view ( ) ;
@@ -429,10 +426,9 @@ fn string_view_substr(
429
426
}
430
427
}
431
428
432
- fn string_substr < ' a , V , T > ( string_array : V , args : & [ ArrayRef ] ) -> Result < ArrayRef >
429
+ fn string_substr < ' a , V > ( string_array : V , args : & [ ArrayRef ] ) -> Result < ArrayRef >
433
430
where
434
431
V : StringArrayType < ' a > ,
435
- T : OffsetSizeTrait ,
436
432
{
437
433
let start_array = as_int64_array ( & args[ 0 ] ) ?;
438
434
let count_array_opt = if args. len ( ) == 2 {
@@ -447,7 +443,7 @@ where
447
443
match args. len ( ) {
448
444
1 => {
449
445
let iter = ArrayIter :: new ( string_array) ;
450
- let mut result_builder = GenericStringBuilder :: < T > :: new ( ) ;
446
+ let mut result_builder = StringViewBuilder :: new ( ) ;
451
447
for ( string, start) in iter. zip ( start_array. iter ( ) ) {
452
448
match ( string, start) {
453
449
( Some ( string) , Some ( start) ) => {
@@ -470,7 +466,7 @@ where
470
466
2 => {
471
467
let iter = ArrayIter :: new ( string_array) ;
472
468
let count_array = count_array_opt. unwrap ( ) ;
473
- let mut result_builder = GenericStringBuilder :: < T > :: new ( ) ;
469
+ let mut result_builder = StringViewBuilder :: new ( ) ;
474
470
475
471
for ( ( string, start) , count) in
476
472
iter. zip ( start_array. iter ( ) ) . zip ( count_array. iter ( ) )
@@ -512,8 +508,8 @@ where
512
508
513
509
#[ cfg( test) ]
514
510
mod tests {
515
- use arrow:: array:: { Array , StringArray , StringViewArray } ;
516
- use arrow:: datatypes:: DataType :: { Utf8 , Utf8View } ;
511
+ use arrow:: array:: { Array , StringViewArray } ;
512
+ use arrow:: datatypes:: DataType :: Utf8View ;
517
513
518
514
use datafusion_common:: { exec_err, Result , ScalarValue } ;
519
515
use datafusion_expr:: { ColumnarValue , ScalarUDFImpl } ;
@@ -623,8 +619,8 @@ mod tests {
623
619
] ,
624
620
Ok ( Some ( "alphabet" ) ) ,
625
621
& str ,
626
- Utf8 ,
627
- StringArray
622
+ Utf8View ,
623
+ StringViewArray
628
624
) ;
629
625
test_function ! (
630
626
SubstrFunc :: new( ) ,
@@ -634,8 +630,8 @@ mod tests {
634
630
] ,
635
631
Ok ( Some ( "ésoj" ) ) ,
636
632
& str ,
637
- Utf8 ,
638
- StringArray
633
+ Utf8View ,
634
+ StringViewArray
639
635
) ;
640
636
test_function ! (
641
637
SubstrFunc :: new( ) ,
@@ -645,8 +641,8 @@ mod tests {
645
641
] ,
646
642
Ok ( Some ( "joséésoj" ) ) ,
647
643
& str ,
648
- Utf8 ,
649
- StringArray
644
+ Utf8View ,
645
+ StringViewArray
650
646
) ;
651
647
test_function ! (
652
648
SubstrFunc :: new( ) ,
@@ -656,8 +652,8 @@ mod tests {
656
652
] ,
657
653
Ok ( Some ( "alphabet" ) ) ,
658
654
& str ,
659
- Utf8 ,
660
- StringArray
655
+ Utf8View ,
656
+ StringViewArray
661
657
) ;
662
658
test_function ! (
663
659
SubstrFunc :: new( ) ,
@@ -667,8 +663,8 @@ mod tests {
667
663
] ,
668
664
Ok ( Some ( "lphabet" ) ) ,
669
665
& str ,
670
- Utf8 ,
671
- StringArray
666
+ Utf8View ,
667
+ StringViewArray
672
668
) ;
673
669
test_function ! (
674
670
SubstrFunc :: new( ) ,
@@ -678,8 +674,8 @@ mod tests {
678
674
] ,
679
675
Ok ( Some ( "phabet" ) ) ,
680
676
& str ,
681
- Utf8 ,
682
- StringArray
677
+ Utf8View ,
678
+ StringViewArray
683
679
) ;
684
680
test_function ! (
685
681
SubstrFunc :: new( ) ,
@@ -689,8 +685,8 @@ mod tests {
689
685
] ,
690
686
Ok ( Some ( "alphabet" ) ) ,
691
687
& str ,
692
- Utf8 ,
693
- StringArray
688
+ Utf8View ,
689
+ StringViewArray
694
690
) ;
695
691
test_function ! (
696
692
SubstrFunc :: new( ) ,
@@ -700,8 +696,8 @@ mod tests {
700
696
] ,
701
697
Ok ( Some ( "" ) ) ,
702
698
& str ,
703
- Utf8 ,
704
- StringArray
699
+ Utf8View ,
700
+ StringViewArray
705
701
) ;
706
702
test_function ! (
707
703
SubstrFunc :: new( ) ,
@@ -711,8 +707,8 @@ mod tests {
711
707
] ,
712
708
Ok ( None ) ,
713
709
& str ,
714
- Utf8 ,
715
- StringArray
710
+ Utf8View ,
711
+ StringViewArray
716
712
) ;
717
713
test_function ! (
718
714
SubstrFunc :: new( ) ,
@@ -723,8 +719,8 @@ mod tests {
723
719
] ,
724
720
Ok ( Some ( "ph" ) ) ,
725
721
& str ,
726
- Utf8 ,
727
- StringArray
722
+ Utf8View ,
723
+ StringViewArray
728
724
) ;
729
725
test_function ! (
730
726
SubstrFunc :: new( ) ,
@@ -735,8 +731,8 @@ mod tests {
735
731
] ,
736
732
Ok ( Some ( "phabet" ) ) ,
737
733
& str ,
738
- Utf8 ,
739
- StringArray
734
+ Utf8View ,
735
+ StringViewArray
740
736
) ;
741
737
test_function ! (
742
738
SubstrFunc :: new( ) ,
@@ -747,8 +743,8 @@ mod tests {
747
743
] ,
748
744
Ok ( Some ( "alph" ) ) ,
749
745
& str ,
750
- Utf8 ,
751
- StringArray
746
+ Utf8View ,
747
+ StringViewArray
752
748
) ;
753
749
// starting from 5 (10 + -5)
754
750
test_function ! (
@@ -760,8 +756,8 @@ mod tests {
760
756
] ,
761
757
Ok ( Some ( "alph" ) ) ,
762
758
& str ,
763
- Utf8 ,
764
- StringArray
759
+ Utf8View ,
760
+ StringViewArray
765
761
) ;
766
762
// starting from -1 (4 + -5)
767
763
test_function ! (
@@ -773,8 +769,8 @@ mod tests {
773
769
] ,
774
770
Ok ( Some ( "" ) ) ,
775
771
& str ,
776
- Utf8 ,
777
- StringArray
772
+ Utf8View ,
773
+ StringViewArray
778
774
) ;
779
775
// starting from 0 (5 + -5)
780
776
test_function ! (
@@ -786,8 +782,8 @@ mod tests {
786
782
] ,
787
783
Ok ( Some ( "" ) ) ,
788
784
& str ,
789
- Utf8 ,
790
- StringArray
785
+ Utf8View ,
786
+ StringViewArray
791
787
) ;
792
788
test_function ! (
793
789
SubstrFunc :: new( ) ,
@@ -798,8 +794,8 @@ mod tests {
798
794
] ,
799
795
Ok ( None ) ,
800
796
& str ,
801
- Utf8 ,
802
- StringArray
797
+ Utf8View ,
798
+ StringViewArray
803
799
) ;
804
800
test_function ! (
805
801
SubstrFunc :: new( ) ,
@@ -810,8 +806,8 @@ mod tests {
810
806
] ,
811
807
Ok ( None ) ,
812
808
& str ,
813
- Utf8 ,
814
- StringArray
809
+ Utf8View ,
810
+ StringViewArray
815
811
) ;
816
812
test_function ! (
817
813
SubstrFunc :: new( ) ,
@@ -822,8 +818,8 @@ mod tests {
822
818
] ,
823
819
exec_err!( "negative substring length not allowed: substr(<str>, 1, -1)" ) ,
824
820
& str ,
825
- Utf8 ,
826
- StringArray
821
+ Utf8View ,
822
+ StringViewArray
827
823
) ;
828
824
test_function ! (
829
825
SubstrFunc :: new( ) ,
@@ -834,8 +830,8 @@ mod tests {
834
830
] ,
835
831
Ok ( Some ( "és" ) ) ,
836
832
& str ,
837
- Utf8 ,
838
- StringArray
833
+ Utf8View ,
834
+ StringViewArray
839
835
) ;
840
836
#[ cfg( not( feature = "unicode_expressions" ) ) ]
841
837
test_function ! (
@@ -848,8 +844,8 @@ mod tests {
848
844
"function substr requires compilation with feature flag: unicode_expressions."
849
845
) ,
850
846
& str ,
851
- Utf8 ,
852
- StringArray
847
+ Utf8View ,
848
+ StringViewArray
853
849
) ;
854
850
test_function ! (
855
851
SubstrFunc :: new( ) ,
@@ -859,8 +855,8 @@ mod tests {
859
855
] ,
860
856
Ok ( Some ( "abc" ) ) ,
861
857
& str ,
862
- Utf8 ,
863
- StringArray
858
+ Utf8View ,
859
+ StringViewArray
864
860
) ;
865
861
test_function ! (
866
862
SubstrFunc :: new( ) ,
@@ -871,8 +867,8 @@ mod tests {
871
867
] ,
872
868
exec_err!( "negative overflow when calculating skip value" ) ,
873
869
& str ,
874
- Utf8 ,
875
- StringArray
870
+ Utf8View ,
871
+ StringViewArray
876
872
) ;
877
873
878
874
Ok ( ( ) )
0 commit comments