@@ -567,45 +567,86 @@ def test_array_function_obj_tests(stmt, py_expr):
567
567
assert a == b
568
568
569
569
570
- @pytest .mark .parametrize ("function, expected_result" , [
571
- (f .ascii (column ("a" )), pa .array ([72 , 87 , 33 ], type = pa .int32 ())), # H = 72; W = 87; ! = 33
572
- (f .bit_length (column ("a" )), pa .array ([40 , 40 , 8 ], type = pa .int32 ())),
573
- (f .btrim (literal (" World " )), pa .array (["World" , "World" , "World" ])),
574
- (f .character_length (column ("a" )), pa .array ([5 , 5 , 1 ], type = pa .int32 ())),
575
- (f .chr (literal (68 )), pa .array (["D" , "D" , "D" ])),
576
- (f .concat_ws ("-" , column ("a" ), literal ("test" )), pa .array (["Hello-test" , "World-test" , "!-test" ])),
577
- (f .concat (column ("a" ), literal ("?" )), pa .array (["Hello?" , "World?" , "!?" ])),
578
- (f .initcap (column ("c" )), pa .array (["Hello " , " World " , " !" ])),
579
- (f .left (column ("a" ), literal (3 )), pa .array (["Hel" , "Wor" , "!" ])),
580
- (f .length (column ("c" )), pa .array ([6 , 7 , 2 ], type = pa .int32 ())),
581
- (f .lower (column ("a" )), pa .array (["hello" , "world" , "!" ])),
582
- (f .lpad (column ("a" ), literal (7 )), pa .array ([" Hello" , " World" , " !" ])),
583
- (f .ltrim (column ("c" )), pa .array (["hello " , "world " , "!" ])),
584
- (f .md5 (column ("a" )), pa .array ([
585
- "8b1a9953c4611296a827abf8c47804d7" ,
586
- "f5a7924e621e84c9280a9a27e1bcb7f6" ,
587
- "9033e0e305f247c0c3c80d0c7848c8b3" ,
588
- ])),
589
- (f .octet_length (column ("a" )), pa .array ([5 , 5 , 1 ], type = pa .int32 ())),
590
- (f .repeat (column ("a" ), literal (2 )), pa .array (["HelloHello" , "WorldWorld" , "!!" ])),
591
- (f .replace (column ("a" ), literal ("l" ), literal ("?" )), pa .array (["He??o" , "Wor?d" , "!" ])),
592
- (f .reverse (column ("a" )), pa .array (["olleH" , "dlroW" , "!" ])),
593
- (f .right (column ("a" ), literal (4 )), pa .array (["ello" , "orld" , "!" ])),
594
- (f .rpad (column ("a" ), literal (8 )), pa .array (["Hello " , "World " , "! " ])),
595
- (f .rtrim (column ("c" )), pa .array (["hello" , " world" , " !" ])),
596
- (f .split_part (column ("a" ), literal ("l" ), literal (1 )), pa .array (["He" , "Wor" , "!" ])),
597
- (f .starts_with (column ("a" ), literal ("Wor" )), pa .array ([False , True , False ])),
598
- (f .strpos (column ("a" ), literal ("o" )), pa .array ([5 , 2 , 0 ], type = pa .int32 ())),
599
- (f .substr (column ("a" ), literal (3 )), pa .array (["llo" , "rld" , "" ])),
600
- (f .translate (column ("a" ), literal ("or" ), literal ("ld" )), pa .array (["Helll" , "Wldld" , "!" ])),
601
- (f .trim (column ("c" )), pa .array (["hello" , "world" , "!" ])),
602
- (f .upper (column ("c" )), pa .array (["HELLO " , " WORLD " , " !" ])),
603
- (f .ends_with (column ("a" ), literal ("llo" )), pa .array ([True , False , False ])),
604
- (f .overlay (column ("a" ), literal ("--" ), literal (2 )), pa .array (["H--lo" , "W--ld" , "--" ])),
605
- (f .regexp_like (column ("a" ), literal ("(ell|orl)" )), pa .array ([True , True , False ])),
606
- (f .regexp_match (column ("a" ), literal ("(ell|orl)" )), pa .array ([["ell" ], ["orl" ], None ])),
607
- (f .regexp_replace (column ("a" ), literal ("(ell|orl)" ), literal ("-" )), pa .array (["H-o" , "W-d" , "!" ])),
608
- ])
570
+ @pytest .mark .parametrize (
571
+ "function, expected_result" ,
572
+ [
573
+ (
574
+ f .ascii (column ("a" )),
575
+ pa .array ([72 , 87 , 33 ], type = pa .int32 ()),
576
+ ), # H = 72; W = 87; ! = 33
577
+ (f .bit_length (column ("a" )), pa .array ([40 , 40 , 8 ], type = pa .int32 ())),
578
+ (f .btrim (literal (" World " )), pa .array (["World" , "World" , "World" ])),
579
+ (f .character_length (column ("a" )), pa .array ([5 , 5 , 1 ], type = pa .int32 ())),
580
+ (f .chr (literal (68 )), pa .array (["D" , "D" , "D" ])),
581
+ (
582
+ f .concat_ws ("-" , column ("a" ), literal ("test" )),
583
+ pa .array (["Hello-test" , "World-test" , "!-test" ]),
584
+ ),
585
+ (f .concat (column ("a" ), literal ("?" )), pa .array (["Hello?" , "World?" , "!?" ])),
586
+ (f .initcap (column ("c" )), pa .array (["Hello " , " World " , " !" ])),
587
+ (f .left (column ("a" ), literal (3 )), pa .array (["Hel" , "Wor" , "!" ])),
588
+ (f .length (column ("c" )), pa .array ([6 , 7 , 2 ], type = pa .int32 ())),
589
+ (f .lower (column ("a" )), pa .array (["hello" , "world" , "!" ])),
590
+ (f .lpad (column ("a" ), literal (7 )), pa .array ([" Hello" , " World" , " !" ])),
591
+ (f .ltrim (column ("c" )), pa .array (["hello " , "world " , "!" ])),
592
+ (
593
+ f .md5 (column ("a" )),
594
+ pa .array (
595
+ [
596
+ "8b1a9953c4611296a827abf8c47804d7" ,
597
+ "f5a7924e621e84c9280a9a27e1bcb7f6" ,
598
+ "9033e0e305f247c0c3c80d0c7848c8b3" ,
599
+ ]
600
+ ),
601
+ ),
602
+ (f .octet_length (column ("a" )), pa .array ([5 , 5 , 1 ], type = pa .int32 ())),
603
+ (
604
+ f .repeat (column ("a" ), literal (2 )),
605
+ pa .array (["HelloHello" , "WorldWorld" , "!!" ]),
606
+ ),
607
+ (
608
+ f .replace (column ("a" ), literal ("l" ), literal ("?" )),
609
+ pa .array (["He??o" , "Wor?d" , "!" ]),
610
+ ),
611
+ (f .reverse (column ("a" )), pa .array (["olleH" , "dlroW" , "!" ])),
612
+ (f .right (column ("a" ), literal (4 )), pa .array (["ello" , "orld" , "!" ])),
613
+ (
614
+ f .rpad (column ("a" ), literal (8 )),
615
+ pa .array (["Hello " , "World " , "! " ]),
616
+ ),
617
+ (f .rtrim (column ("c" )), pa .array (["hello" , " world" , " !" ])),
618
+ (
619
+ f .split_part (column ("a" ), literal ("l" ), literal (1 )),
620
+ pa .array (["He" , "Wor" , "!" ]),
621
+ ),
622
+ (f .starts_with (column ("a" ), literal ("Wor" )), pa .array ([False , True , False ])),
623
+ (f .strpos (column ("a" ), literal ("o" )), pa .array ([5 , 2 , 0 ], type = pa .int32 ())),
624
+ (f .substr (column ("a" ), literal (3 )), pa .array (["llo" , "rld" , "" ])),
625
+ (
626
+ f .translate (column ("a" ), literal ("or" ), literal ("ld" )),
627
+ pa .array (["Helll" , "Wldld" , "!" ]),
628
+ ),
629
+ (f .trim (column ("c" )), pa .array (["hello" , "world" , "!" ])),
630
+ (f .upper (column ("c" )), pa .array (["HELLO " , " WORLD " , " !" ])),
631
+ (f .ends_with (column ("a" ), literal ("llo" )), pa .array ([True , False , False ])),
632
+ (
633
+ f .overlay (column ("a" ), literal ("--" ), literal (2 )),
634
+ pa .array (["H--lo" , "W--ld" , "--" ]),
635
+ ),
636
+ (
637
+ f .regexp_like (column ("a" ), literal ("(ell|orl)" )),
638
+ pa .array ([True , True , False ]),
639
+ ),
640
+ (
641
+ f .regexp_match (column ("a" ), literal ("(ell|orl)" )),
642
+ pa .array ([["ell" ], ["orl" ], None ]),
643
+ ),
644
+ (
645
+ f .regexp_replace (column ("a" ), literal ("(ell|orl)" ), literal ("-" )),
646
+ pa .array (["H-o" , "W-d" , "!" ]),
647
+ ),
648
+ ],
649
+ )
609
650
def test_string_functions (df , function , expected_result ):
610
651
df = df .select (function )
611
652
result = df .collect ()
@@ -849,27 +890,30 @@ def test_regr_funcs_sql_2():
849
890
assert result_sql [0 ].column (8 ) == pa .array ([4 ], type = pa .float64 ())
850
891
851
892
852
- @pytest .mark .parametrize ("func, expected" , [
853
- pytest .param (f .regr_slope , pa .array ([2 ], type = pa .float64 ()), id = "regr_slope" ),
854
- pytest .param (f .regr_intercept , pa .array ([0 ], type = pa .float64 ()), id = "regr_intercept" ),
855
- pytest .param (f .regr_count , pa .array ([3 ], type = pa .uint64 ()), id = "regr_count" ),
856
- pytest .param (f .regr_r2 , pa .array ([1 ], type = pa .float64 ()), id = "regr_r2" ),
857
- pytest .param (f .regr_avgx , pa .array ([2 ], type = pa .float64 ()), id = "regr_avgx" ),
858
- pytest .param (f .regr_avgy , pa .array ([4 ], type = pa .float64 ()), id = "regr_avgy" ),
859
- pytest .param (f .regr_sxx , pa .array ([2 ], type = pa .float64 ()), id = "regr_sxx" ),
860
- pytest .param (f .regr_syy , pa .array ([8 ], type = pa .float64 ()), id = "regr_syy" ),
861
- pytest .param (f .regr_sxy , pa .array ([4 ], type = pa .float64 ()), id = "regr_sxy" )
862
- ])
893
+ @pytest .mark .parametrize (
894
+ "func, expected" ,
895
+ [
896
+ pytest .param (f .regr_slope , pa .array ([2 ], type = pa .float64 ()), id = "regr_slope" ),
897
+ pytest .param (
898
+ f .regr_intercept , pa .array ([0 ], type = pa .float64 ()), id = "regr_intercept"
899
+ ),
900
+ pytest .param (f .regr_count , pa .array ([3 ], type = pa .uint64 ()), id = "regr_count" ),
901
+ pytest .param (f .regr_r2 , pa .array ([1 ], type = pa .float64 ()), id = "regr_r2" ),
902
+ pytest .param (f .regr_avgx , pa .array ([2 ], type = pa .float64 ()), id = "regr_avgx" ),
903
+ pytest .param (f .regr_avgy , pa .array ([4 ], type = pa .float64 ()), id = "regr_avgy" ),
904
+ pytest .param (f .regr_sxx , pa .array ([2 ], type = pa .float64 ()), id = "regr_sxx" ),
905
+ pytest .param (f .regr_syy , pa .array ([8 ], type = pa .float64 ()), id = "regr_syy" ),
906
+ pytest .param (f .regr_sxy , pa .array ([4 ], type = pa .float64 ()), id = "regr_sxy" ),
907
+ ],
908
+ )
863
909
def test_regr_funcs_df (func , expected ):
864
-
865
910
# test case based on `regr_*() basic tests
866
911
# https://github.com/apache/datafusion/blob/d1361d56b9a9e0c165d3d71a8df6795d2a5f51dd/datafusion/core/tests/sqllogictests/test_files/aggregate.slt#L2358C1-L2374C1
867
912
868
-
869
913
ctx = SessionContext ()
870
914
871
915
# Create a DataFrame
872
- data = {' column1' : [1 , 2 , 3 ], ' column2' : [2 , 4 , 6 ]}
916
+ data = {" column1" : [1 , 2 , 3 ], " column2" : [2 , 4 , 6 ]}
873
917
df = ctx .from_pydict (data , name = "test_table" )
874
918
875
919
# Perform the regression function using DataFrame API
@@ -900,6 +944,8 @@ def test_first_last_value(df):
900
944
assert result .column (3 ) == pa .array (["!" ])
901
945
assert result .column (4 ) == pa .array ([6 ])
902
946
assert result .column (5 ) == pa .array ([datetime (2020 , 7 , 2 )])
947
+ df .show ()
948
+ assert False
903
949
904
950
905
951
def test_binary_string_functions (df ):
0 commit comments