@@ -34,9 +34,9 @@ def df():
34
34
# create a RecordBatch and a new DataFrame from it
35
35
batch = pa .RecordBatch .from_arrays (
36
36
[
37
- pa .array (["Hello" , "World" , "!" ]),
37
+ pa .array (["Hello" , "World" , "!" ], type = pa . string_view () ),
38
38
pa .array ([4 , 5 , 6 ]),
39
- pa .array (["hello " , " world " , " !" ]),
39
+ pa .array (["hello " , " world " , " !" ], type = pa . string_view () ),
40
40
pa .array (
41
41
[
42
42
datetime (2022 , 12 , 31 ),
@@ -88,16 +88,18 @@ def test_literal(df):
88
88
assert len (result ) == 1
89
89
result = result [0 ]
90
90
assert result .column (0 ) == pa .array ([1 ] * 3 )
91
- assert result .column (1 ) == pa .array (["1" ] * 3 )
92
- assert result .column (2 ) == pa .array (["OK" ] * 3 )
91
+ assert result .column (1 ) == pa .array (["1" ] * 3 , type = pa . string_view () )
92
+ assert result .column (2 ) == pa .array (["OK" ] * 3 , type = pa . string_view () )
93
93
assert result .column (3 ) == pa .array ([3.14 ] * 3 )
94
94
assert result .column (4 ) == pa .array ([True ] * 3 )
95
95
assert result .column (5 ) == pa .array ([b"hello world" ] * 3 )
96
96
97
97
98
98
def test_lit_arith (df ):
99
99
"""Test literals with arithmetic operations"""
100
- df = df .select (literal (1 ) + column ("b" ), f .concat (column ("a" ), literal ("!" )))
100
+ df = df .select (
101
+ literal (1 ) + column ("b" ), f .concat (column ("a" ).cast (pa .string ()), literal ("!" ))
102
+ )
101
103
result = df .collect ()
102
104
assert len (result ) == 1
103
105
result = result [0 ]
@@ -600,21 +602,33 @@ def test_array_function_obj_tests(stmt, py_expr):
600
602
f .ascii (column ("a" )),
601
603
pa .array ([72 , 87 , 33 ], type = pa .int32 ()),
602
604
), # H = 72; W = 87; ! = 33
603
- (f .bit_length (column ("a" )), pa .array ([40 , 40 , 8 ], type = pa .int32 ())),
604
- (f .btrim (literal (" World " )), pa .array (["World" , "World" , "World" ])),
605
+ (
606
+ f .bit_length (column ("a" ).cast (pa .string ())),
607
+ pa .array ([40 , 40 , 8 ], type = pa .int32 ()),
608
+ ),
609
+ (
610
+ f .btrim (literal (" World " )),
611
+ pa .array (["World" , "World" , "World" ], type = pa .string_view ()),
612
+ ),
605
613
(f .character_length (column ("a" )), pa .array ([5 , 5 , 1 ], type = pa .int32 ())),
606
614
(f .chr (literal (68 )), pa .array (["D" , "D" , "D" ])),
607
615
(
608
616
f .concat_ws ("-" , column ("a" ), literal ("test" )),
609
617
pa .array (["Hello-test" , "World-test" , "!-test" ]),
610
618
),
611
- (f .concat (column ("a" ), literal ("?" )), pa .array (["Hello?" , "World?" , "!?" ])),
619
+ (
620
+ f .concat (column ("a" ).cast (pa .string ()), literal ("?" )),
621
+ pa .array (["Hello?" , "World?" , "!?" ]),
622
+ ),
612
623
(f .initcap (column ("c" )), pa .array (["Hello " , " World " , " !" ])),
613
624
(f .left (column ("a" ), literal (3 )), pa .array (["Hel" , "Wor" , "!" ])),
614
625
(f .length (column ("c" )), pa .array ([6 , 7 , 2 ], type = pa .int32 ())),
615
626
(f .lower (column ("a" )), pa .array (["hello" , "world" , "!" ])),
616
627
(f .lpad (column ("a" ), literal (7 )), pa .array ([" Hello" , " World" , " !" ])),
617
- (f .ltrim (column ("c" )), pa .array (["hello " , "world " , "!" ])),
628
+ (
629
+ f .ltrim (column ("c" )),
630
+ pa .array (["hello " , "world " , "!" ], type = pa .string_view ()),
631
+ ),
618
632
(
619
633
f .md5 (column ("a" )),
620
634
pa .array (
@@ -640,19 +654,25 @@ def test_array_function_obj_tests(stmt, py_expr):
640
654
f .rpad (column ("a" ), literal (8 )),
641
655
pa .array (["Hello " , "World " , "! " ]),
642
656
),
643
- (f .rtrim (column ("c" )), pa .array (["hello" , " world" , " !" ])),
657
+ (
658
+ f .rtrim (column ("c" )),
659
+ pa .array (["hello" , " world" , " !" ], type = pa .string_view ()),
660
+ ),
644
661
(
645
662
f .split_part (column ("a" ), literal ("l" ), literal (1 )),
646
663
pa .array (["He" , "Wor" , "!" ]),
647
664
),
648
665
(f .starts_with (column ("a" ), literal ("Wor" )), pa .array ([False , True , False ])),
649
666
(f .strpos (column ("a" ), literal ("o" )), pa .array ([5 , 2 , 0 ], type = pa .int32 ())),
650
- (f .substr (column ("a" ), literal (3 )), pa .array (["llo" , "rld" , "" ])),
667
+ (
668
+ f .substr (column ("a" ), literal (3 )),
669
+ pa .array (["llo" , "rld" , "" ], type = pa .string_view ()),
670
+ ),
651
671
(
652
672
f .translate (column ("a" ), literal ("or" ), literal ("ld" )),
653
673
pa .array (["Helll" , "Wldld" , "!" ]),
654
674
),
655
- (f .trim (column ("c" )), pa .array (["hello" , "world" , "!" ])),
675
+ (f .trim (column ("c" )), pa .array (["hello" , "world" , "!" ], type = pa . string_view () )),
656
676
(f .upper (column ("c" )), pa .array (["HELLO " , " WORLD " , " !" ])),
657
677
(f .ends_with (column ("a" ), literal ("llo" )), pa .array ([True , False , False ])),
658
678
(
@@ -794,9 +814,9 @@ def test_temporal_functions(df):
794
814
f .date_trunc (literal ("month" ), column ("d" )),
795
815
f .datetrunc (literal ("day" ), column ("d" )),
796
816
f .date_bin (
797
- literal ("15 minutes" ),
817
+ literal ("15 minutes" ). cast ( pa . string ()) ,
798
818
column ("d" ),
799
- literal ("2001-01-01 00:02:30" ),
819
+ literal ("2001-01-01 00:02:30" ). cast ( pa . string ()) ,
800
820
),
801
821
f .from_unixtime (literal (1673383974 )),
802
822
f .to_timestamp (literal ("2023-09-07 05:06:14.523952" )),
@@ -858,8 +878,8 @@ def test_case(df):
858
878
result = df .collect ()
859
879
result = result [0 ]
860
880
assert result .column (0 ) == pa .array ([10 , 8 , 8 ])
861
- assert result .column (1 ) == pa .array (["Hola" , "Mundo" , "!!" ])
862
- assert result .column (2 ) == pa .array (["Hola" , "Mundo" , None ])
881
+ assert result .column (1 ) == pa .array (["Hola" , "Mundo" , "!!" ], type = pa . string_view () )
882
+ assert result .column (2 ) == pa .array (["Hola" , "Mundo" , None ], type = pa . string_view () )
863
883
864
884
865
885
def test_when_with_no_base (df ):
@@ -877,8 +897,10 @@ def test_when_with_no_base(df):
877
897
result = df .collect ()
878
898
result = result [0 ]
879
899
assert result .column (0 ) == pa .array ([4 , 5 , 6 ])
880
- assert result .column (1 ) == pa .array (["too small" , "just right" , "too big" ])
881
- assert result .column (2 ) == pa .array (["Hello" , None , None ])
900
+ assert result .column (1 ) == pa .array (
901
+ ["too small" , "just right" , "too big" ], type = pa .string_view ()
902
+ )
903
+ assert result .column (2 ) == pa .array (["Hello" , None , None ], type = pa .string_view ())
882
904
883
905
884
906
def test_regr_funcs_sql (df ):
@@ -1021,8 +1043,13 @@ def test_regr_funcs_df(func, expected):
1021
1043
1022
1044
def test_binary_string_functions (df ):
1023
1045
df = df .select (
1024
- f .encode (column ("a" ), literal ("base64" )),
1025
- f .decode (f .encode (column ("a" ), literal ("base64" )), literal ("base64" )),
1046
+ f .encode (column ("a" ).cast (pa .string ()), literal ("base64" ).cast (pa .string ())),
1047
+ f .decode (
1048
+ f .encode (
1049
+ column ("a" ).cast (pa .string ()), literal ("base64" ).cast (pa .string ())
1050
+ ),
1051
+ literal ("base64" ).cast (pa .string ()),
1052
+ ),
1026
1053
)
1027
1054
result = df .collect ()
1028
1055
assert len (result ) == 1
0 commit comments