9
9
import numpy as np
10
10
import pytest
11
11
12
- from pandas ._config import using_string_dtype
13
-
14
- from pandas .compat import (
15
- HAS_PYARROW ,
16
- WASM ,
17
- )
12
+ from pandas .compat import WASM
18
13
from pandas .compat .numpy import np_version_gte1p24
19
14
from pandas .errors import IndexingError
20
15
32
27
NaT ,
33
28
Period ,
34
29
Series ,
30
+ StringDtype ,
35
31
Timedelta ,
36
32
Timestamp ,
37
33
array ,
@@ -535,14 +531,16 @@ def test_append_timedelta_does_not_cast(self, td, using_infer_string, request):
535
531
tm .assert_series_equal (ser , expected )
536
532
assert isinstance (ser ["td" ], Timedelta )
537
533
538
- @pytest .mark .xfail (using_string_dtype (), reason = "TODO(infer_string)" )
539
534
def test_setitem_with_expansion_type_promotion (self ):
540
535
# GH#12599
541
536
ser = Series (dtype = object )
542
537
ser ["a" ] = Timestamp ("2016-01-01" )
543
538
ser ["b" ] = 3.0
544
539
ser ["c" ] = "foo"
545
- expected = Series ([Timestamp ("2016-01-01" ), 3.0 , "foo" ], index = ["a" , "b" , "c" ])
540
+ expected = Series (
541
+ [Timestamp ("2016-01-01" ), 3.0 , "foo" ],
542
+ index = Index (["a" , "b" , "c" ], dtype = object ),
543
+ )
546
544
tm .assert_series_equal (ser , expected )
547
545
548
546
def test_setitem_not_contained (self , string_series ):
@@ -826,11 +824,6 @@ def test_mask_key(self, obj, key, expected, raises, val, indexer_sli):
826
824
else :
827
825
indexer_sli (obj )[mask ] = val
828
826
829
- @pytest .mark .xfail (
830
- using_string_dtype () and not HAS_PYARROW ,
831
- reason = "TODO(infer_string)" ,
832
- strict = False ,
833
- )
834
827
def test_series_where (self , obj , key , expected , raises , val , is_inplace ):
835
828
mask = np .zeros (obj .shape , dtype = bool )
836
829
mask [key ] = True
@@ -846,6 +839,11 @@ def test_series_where(self, obj, key, expected, raises, val, is_inplace):
846
839
obj = obj .copy ()
847
840
arr = obj ._values
848
841
842
+ if raises and obj .dtype == "string" :
843
+ with pytest .raises (TypeError , match = "Invalid value" ):
844
+ obj .where (~ mask , val )
845
+ return
846
+
849
847
res = obj .where (~ mask , val )
850
848
851
849
if val is NA and res .dtype == object :
@@ -858,25 +856,23 @@ def test_series_where(self, obj, key, expected, raises, val, is_inplace):
858
856
859
857
self ._check_inplace (is_inplace , orig , arr , obj )
860
858
861
- @pytest .mark .xfail (using_string_dtype (), reason = "TODO(infer_string)" , strict = False )
862
- def test_index_where (self , obj , key , expected , raises , val , using_infer_string ):
859
+ def test_index_where (self , obj , key , expected , raises , val ):
863
860
mask = np .zeros (obj .shape , dtype = bool )
864
861
mask [key ] = True
865
862
866
- if using_infer_string and obj .dtype == object :
863
+ if raises and obj .dtype == "string" :
867
864
with pytest .raises (TypeError , match = "Invalid value" ):
868
865
Index (obj ).where (~ mask , val )
869
866
else :
870
867
res = Index (obj ).where (~ mask , val )
871
868
expected_idx = Index (expected , dtype = expected .dtype )
872
869
tm .assert_index_equal (res , expected_idx )
873
870
874
- @pytest .mark .xfail (using_string_dtype (), reason = "TODO(infer_string)" , strict = False )
875
- def test_index_putmask (self , obj , key , expected , raises , val , using_infer_string ):
871
+ def test_index_putmask (self , obj , key , expected , raises , val ):
876
872
mask = np .zeros (obj .shape , dtype = bool )
877
873
mask [key ] = True
878
874
879
- if using_infer_string and obj .dtype == object :
875
+ if raises and obj .dtype == "string" :
880
876
with pytest .raises (TypeError , match = "Invalid value" ):
881
877
Index (obj ).putmask (mask , val )
882
878
else :
@@ -1372,6 +1368,19 @@ def raises(self):
1372
1368
return False
1373
1369
1374
1370
1371
+ @pytest .mark .parametrize (
1372
+ "val,exp_dtype,raises" ,
1373
+ [
1374
+ (1 , object , True ),
1375
+ ("e" , StringDtype (na_value = np .nan ), False ),
1376
+ ],
1377
+ )
1378
+ class TestCoercionString (CoercionTest ):
1379
+ @pytest .fixture
1380
+ def obj (self ):
1381
+ return Series (["a" , "b" , "c" , "d" ], dtype = StringDtype (na_value = np .nan ))
1382
+
1383
+
1375
1384
@pytest .mark .parametrize (
1376
1385
"val,exp_dtype,raises" ,
1377
1386
[
0 commit comments