Skip to content

Commit d862eca

Browse files
committed
NA return value
1 parent 6b26309 commit d862eca

File tree

4 files changed

+38
-34
lines changed

4 files changed

+38
-34
lines changed

pandas/core/arrays/string_.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,10 @@ class StringDtype(StorageExtensionDtype):
101101
#: StringDtype().na_value uses pandas.NA, except for "pyarrow_numpy" storage, which uses np.nan
102102
@property
103103
def na_value(self) -> libmissing.NAType:
104-
return libmissing.NA
104+
if self.storage == "pyarrow_numpy":
105+
return np.nan
106+
else:
107+
return libmissing.NA
105108

106109
_metadata = ("storage",)
107110

pandas/tests/strings/__init__.py

+14
Original file line numberDiff line numberDiff line change
@@ -1 +1,15 @@
1+
import numpy as np
2+
3+
import pandas as pd
4+
15
object_pyarrow_numpy = ("object", "string[pyarrow_numpy]")
6+
7+
8+
def _convert_na_value(ser, expected):
9+
if ser.dtype != object:
10+
if ser.dtype.storage == "pyarrow_numpy":
11+
expected = expected.fillna(np.nan)
12+
else:
13+
# GH#18463
14+
expected = expected.fillna(pd.NA)
15+
return expected

pandas/tests/strings/test_find_replace.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@
1111
Series,
1212
_testing as tm,
1313
)
14-
from pandas.tests.strings import object_pyarrow_numpy
14+
from pandas.tests.strings import (
15+
_convert_na_value,
16+
object_pyarrow_numpy,
17+
)
1518

1619
# --------------------------------------------------------------------------------------
1720
# str.contains
@@ -780,9 +783,7 @@ def test_findall(any_string_dtype):
780783
ser = Series(["fooBAD__barBAD", np.nan, "foo", "BAD"], dtype=any_string_dtype)
781784
result = ser.str.findall("BAD[_]*")
782785
expected = Series([["BAD__", "BAD"], np.nan, [], ["BAD"]])
783-
if ser.dtype != object:
784-
# GH#18463
785-
expected = expected.fillna(pd.NA)
786+
expected = _convert_na_value(ser, expected)
786787
tm.assert_series_equal(result, expected)
787788

788789

pandas/tests/strings/test_split_partition.py

+15-29
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
Series,
1313
_testing as tm,
1414
)
15+
from pandas.tests.strings import (
16+
_convert_na_value,
17+
object_pyarrow_numpy,
18+
)
1519

1620

1721
@pytest.mark.parametrize("method", ["split", "rsplit"])
@@ -20,9 +24,7 @@ def test_split(any_string_dtype, method):
2024

2125
result = getattr(values.str, method)("_")
2226
exp = Series([["a", "b", "c"], ["c", "d", "e"], np.nan, ["f", "g", "h"]])
23-
if values.dtype != object:
24-
# GH#18463
25-
exp = exp.fillna(pd.NA)
27+
exp = _convert_na_value(values, exp)
2628
tm.assert_series_equal(result, exp)
2729

2830

@@ -32,9 +34,7 @@ def test_split_more_than_one_char(any_string_dtype, method):
3234
values = Series(["a__b__c", "c__d__e", np.nan, "f__g__h"], dtype=any_string_dtype)
3335
result = getattr(values.str, method)("__")
3436
exp = Series([["a", "b", "c"], ["c", "d", "e"], np.nan, ["f", "g", "h"]])
35-
if values.dtype != object:
36-
# GH#18463
37-
exp = exp.fillna(pd.NA)
37+
exp = _convert_na_value(values, exp)
3838
tm.assert_series_equal(result, exp)
3939

4040
result = getattr(values.str, method)("__", expand=False)
@@ -46,9 +46,7 @@ def test_split_more_regex_split(any_string_dtype):
4646
values = Series(["a,b_c", "c_d,e", np.nan, "f,g,h"], dtype=any_string_dtype)
4747
result = values.str.split("[,_]")
4848
exp = Series([["a", "b", "c"], ["c", "d", "e"], np.nan, ["f", "g", "h"]])
49-
if values.dtype != object:
50-
# GH#18463
51-
exp = exp.fillna(pd.NA)
49+
exp = _convert_na_value(values, exp)
5250
tm.assert_series_equal(result, exp)
5351

5452

@@ -118,8 +116,8 @@ def test_split_object_mixed(expand, method):
118116
def test_split_n(any_string_dtype, method, n):
119117
s = Series(["a b", pd.NA, "b c"], dtype=any_string_dtype)
120118
expected = Series([["a", "b"], pd.NA, ["b", "c"]])
121-
122119
result = getattr(s.str, method)(" ", n=n)
120+
expected = _convert_na_value(s, expected)
123121
tm.assert_series_equal(result, expected)
124122

125123

@@ -128,9 +126,7 @@ def test_rsplit(any_string_dtype):
128126
values = Series(["a,b_c", "c_d,e", np.nan, "f,g,h"], dtype=any_string_dtype)
129127
result = values.str.rsplit("[,_]")
130128
exp = Series([["a,b_c"], ["c_d,e"], np.nan, ["f,g,h"]])
131-
if values.dtype != object:
132-
# GH#18463
133-
exp = exp.fillna(pd.NA)
129+
exp = _convert_na_value(values, exp)
134130
tm.assert_series_equal(result, exp)
135131

136132

@@ -139,9 +135,7 @@ def test_rsplit_max_number(any_string_dtype):
139135
values = Series(["a_b_c", "c_d_e", np.nan, "f_g_h"], dtype=any_string_dtype)
140136
result = values.str.rsplit("_", n=1)
141137
exp = Series([["a_b", "c"], ["c_d", "e"], np.nan, ["f_g", "h"]])
142-
if values.dtype != object:
143-
# GH#18463
144-
exp = exp.fillna(pd.NA)
138+
exp = _convert_na_value(values, exp)
145139
tm.assert_series_equal(result, exp)
146140

147141

@@ -390,7 +384,7 @@ def test_split_nan_expand(any_string_dtype):
390384
# check that these are actually np.nan/pd.NA and not None
391385
# TODO see GH 18463
392386
# tm.assert_frame_equal does not differentiate between np.nan, pd.NA and None
393-
if any_string_dtype == "object":
387+
if any_string_dtype in object_pyarrow_numpy:
394388
assert all(np.isnan(x) for x in result.iloc[1])
395389
else:
396390
assert all(x is pd.NA for x in result.iloc[1])
@@ -455,9 +449,7 @@ def test_partition_series_more_than_one_char(method, exp, any_string_dtype):
455449
s = Series(["a__b__c", "c__d__e", np.nan, "f__g__h", None], dtype=any_string_dtype)
456450
result = getattr(s.str, method)("__", expand=False)
457451
expected = Series(exp)
458-
if s.dtype != object:
459-
# GH#18463
460-
expected = expected.fillna(pd.NA)
452+
expected = _convert_na_value(s, expected)
461453
tm.assert_series_equal(result, expected)
462454

463455

@@ -480,9 +472,7 @@ def test_partition_series_none(any_string_dtype, method, exp):
480472
s = Series(["a b c", "c d e", np.nan, "f g h", None], dtype=any_string_dtype)
481473
result = getattr(s.str, method)(expand=False)
482474
expected = Series(exp)
483-
if s.dtype != object:
484-
# GH#18463
485-
expected = expected.fillna(pd.NA)
475+
expected = _convert_na_value(s, expected)
486476
tm.assert_series_equal(result, expected)
487477

488478

@@ -505,9 +495,7 @@ def test_partition_series_not_split(any_string_dtype, method, exp):
505495
s = Series(["abc", "cde", np.nan, "fgh", None], dtype=any_string_dtype)
506496
result = getattr(s.str, method)("_", expand=False)
507497
expected = Series(exp)
508-
if s.dtype != object:
509-
# GH#18463
510-
expected = expected.fillna(pd.NA)
498+
expected = _convert_na_value(s, expected)
511499
tm.assert_series_equal(result, expected)
512500

513501

@@ -531,9 +519,7 @@ def test_partition_series_unicode(any_string_dtype, method, exp):
531519

532520
result = getattr(s.str, method)("_", expand=False)
533521
expected = Series(exp)
534-
if s.dtype != object:
535-
# GH#18463
536-
expected = expected.fillna(pd.NA)
522+
expected = _convert_na_value(s, expected)
537523
tm.assert_series_equal(result, expected)
538524

539525

0 commit comments

Comments
 (0)