Skip to content

Commit dc4399c

Browse files
String dtype: map builtin str alias to StringDtype (#59685)
* String dtype: map builtin str alias to StringDtype
* fix tests
* fix datetimelike astype and more tests
* remove xfails
* try fix typing
* fix copy_view tests
* fix remaining tests with infer_string enabled
* ignore typing issue for now
* move to common.py
* simplify Categorical._str_get_dummies
* small cleanup
* fix ensure_string_array to not modify extension arrays inplace
* fix ensure_string_array once more + fix is_extension_array_dtype for str
* still xfail TestArrowArray::test_astype_str when not using infer_string
* ensure maybe_convert_objects copies object dtype input array when inferring StringDtype
* update test_1d_object_array_does_not_copy test
* update constructor copy test + do not copy in maybe_convert_objects?
* skip str.get_dummies test for now
* use pandas_dtype() instead of registry.find
* fix corner cases for calling pandas_dtype
* add TODO comment in ensure_string_array
1 parent a790592 commit dc4399c

31 files changed: +183 -112 lines changed

pandas/_libs/lib.pyx

+8-1
Original file line numberDiff line numberDiff line change
@@ -755,7 +755,14 @@ cpdef ndarray[object] ensure_string_array(
755755

756756
if hasattr(arr, "to_numpy"):
757757

758-
if hasattr(arr, "dtype") and arr.dtype.kind in "mM":
758+
if (
759+
hasattr(arr, "dtype")
760+
and arr.dtype.kind in "mM"
761+
# TODO: we should add a custom ArrowExtensionArray.astype implementation
762+
# that handles astype(str) specifically, avoiding ending up here and
763+
# then we can remove the below check for `_pa_array` (for ArrowEA)
764+
and not hasattr(arr, "_pa_array")
765+
):
759766
# dtype check to exclude DataFrame
760767
# GH#41409 TODO: not a great place for this
761768
out = arr.astype(str).astype(object)

pandas/_testing/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@
112112

113113
COMPLEX_DTYPES: list[Dtype] = [complex, "complex64", "complex128"]
114114
if using_string_dtype():
115-
STRING_DTYPES: list[Dtype] = [str, "U"]
115+
STRING_DTYPES: list[Dtype] = ["U"]
116116
else:
117117
STRING_DTYPES: list[Dtype] = [str, "str", "U"] # type: ignore[no-redef]
118118
COMPLEX_FLOAT_DTYPES: list[Dtype] = [*COMPLEX_DTYPES, *FLOAT_NUMPY_DTYPES]

pandas/core/arrays/categorical.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -2691,7 +2691,9 @@ def _str_get_dummies(self, sep: str = "|"):
26912691
# sep may not be in categories. Just bail on this.
26922692
from pandas.core.arrays import NumpyExtensionArray
26932693

2694-
return NumpyExtensionArray(self.astype(str))._str_get_dummies(sep)
2694+
return NumpyExtensionArray(self.to_numpy(str, na_value="NaN"))._str_get_dummies(
2695+
sep
2696+
)
26952697

26962698
# ------------------------------------------------------------------------
26972699
# GroupBy Methods

pandas/core/arrays/datetimelike.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -472,10 +472,16 @@ def astype(self, dtype, copy: bool = True):
472472

473473
return self._box_values(self.asi8.ravel()).reshape(self.shape)
474474

475+
elif is_string_dtype(dtype):
476+
if isinstance(dtype, ExtensionDtype):
477+
arr_object = self._format_native_types(na_rep=dtype.na_value) # type: ignore[arg-type]
478+
cls = dtype.construct_array_type()
479+
return cls._from_sequence(arr_object, dtype=dtype, copy=False)
480+
else:
481+
return self._format_native_types()
482+
475483
elif isinstance(dtype, ExtensionDtype):
476484
return super().astype(dtype, copy=copy)
477-
elif is_string_dtype(dtype):
478-
return self._format_native_types()
479485
elif dtype.kind in "iu":
480486
# we deliberately ignore int32 vs. int64 here.
481487
# See https://github.com/pandas-dev/pandas/issues/24381 for more.

pandas/core/dtypes/common.py

+17-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
import numpy as np
1414

15+
from pandas._config import using_string_dtype
16+
1517
from pandas._libs import (
1618
Interval,
1719
Period,
@@ -1325,7 +1327,15 @@ def is_extension_array_dtype(arr_or_dtype) -> bool:
13251327
elif isinstance(dtype, np.dtype):
13261328
return False
13271329
else:
1328-
return registry.find(dtype) is not None
1330+
try:
1331+
with warnings.catch_warnings():
1332+
# pandas_dtype(..) can raise UserWarning for class input
1333+
warnings.simplefilter("ignore", UserWarning)
1334+
dtype = pandas_dtype(dtype)
1335+
except (TypeError, ValueError):
1336+
# np.dtype(..) can raise ValueError
1337+
return False
1338+
return isinstance(dtype, ExtensionDtype)
13291339

13301340

13311341
def is_ea_or_datetimelike_dtype(dtype: DtypeObj | None) -> bool:
@@ -1620,6 +1630,12 @@ def pandas_dtype(dtype) -> DtypeObj:
16201630
elif isinstance(dtype, (np.dtype, ExtensionDtype)):
16211631
return dtype
16221632

1633+
# builtin aliases
1634+
if dtype is str and using_string_dtype():
1635+
from pandas.core.arrays.string_ import StringDtype
1636+
1637+
return StringDtype(na_value=np.nan)
1638+
16231639
# registered extension types
16241640
result = registry.find(dtype)
16251641
if result is not None:

pandas/core/indexes/base.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -6415,7 +6415,11 @@ def _should_compare(self, other: Index) -> bool:
64156415
return False
64166416

64176417
dtype = _unpack_nested_dtype(other)
6418-
return self._is_comparable_dtype(dtype) or is_object_dtype(dtype)
6418+
return (
6419+
self._is_comparable_dtype(dtype)
6420+
or is_object_dtype(dtype)
6421+
or is_string_dtype(dtype)
6422+
)
64196423

64206424
def _is_comparable_dtype(self, dtype: DtypeObj) -> bool:
64216425
"""

pandas/core/indexes/interval.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
is_number,
5151
is_object_dtype,
5252
is_scalar,
53+
is_string_dtype,
5354
pandas_dtype,
5455
)
5556
from pandas.core.dtypes.dtypes import (
@@ -699,7 +700,7 @@ def _get_indexer(
699700
# left/right get_indexer, compare elementwise, equality -> match
700701
indexer = self._get_indexer_unique_sides(target)
701702

702-
elif not is_object_dtype(target.dtype):
703+
elif not (is_object_dtype(target.dtype) or is_string_dtype(target.dtype)):
703704
# homogeneous scalar index: use IntervalTree
704705
# we should always have self._should_partial_index(target) here
705706
target = self._maybe_convert_i8(target)

pandas/tests/arrays/floating/test_astype.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,9 @@ def test_astype_str(using_infer_string):
6868

6969
if using_infer_string:
7070
expected = pd.array(["0.1", "0.2", None], dtype=pd.StringDtype(na_value=np.nan))
71-
tm.assert_extension_array_equal(a.astype("str"), expected)
7271

73-
# TODO(infer_string) this should also be a string array like above
74-
expected = np.array(["0.1", "0.2", "<NA>"], dtype="U32")
75-
tm.assert_numpy_array_equal(a.astype(str), expected)
72+
tm.assert_extension_array_equal(a.astype(str), expected)
73+
tm.assert_extension_array_equal(a.astype("str"), expected)
7674
else:
7775
expected = np.array(["0.1", "0.2", "<NA>"], dtype="U32")
7876

pandas/tests/arrays/integer/test_dtypes.py

+2-4
Original file line numberDiff line numberDiff line change
@@ -283,11 +283,9 @@ def test_astype_str(using_infer_string):
283283

284284
if using_infer_string:
285285
expected = pd.array(["1", "2", None], dtype=pd.StringDtype(na_value=np.nan))
286-
tm.assert_extension_array_equal(a.astype("str"), expected)
287286

288-
# TODO(infer_string) this should also be a string array like above
289-
expected = np.array(["1", "2", "<NA>"], dtype=f"{tm.ENDIAN}U21")
290-
tm.assert_numpy_array_equal(a.astype(str), expected)
287+
tm.assert_extension_array_equal(a.astype(str), expected)
288+
tm.assert_extension_array_equal(a.astype("str"), expected)
291289
else:
292290
expected = np.array(["1", "2", "<NA>"], dtype=f"{tm.ENDIAN}U21")
293291

pandas/tests/arrays/sparse/test_astype.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ def test_astype_all(self, any_real_numpy_dtype):
8181
),
8282
(
8383
SparseArray([0, 1, 10]),
84-
str,
85-
SparseArray(["0", "1", "10"], dtype=SparseDtype(str, "0")),
84+
np.str_,
85+
SparseArray(["0", "1", "10"], dtype=SparseDtype(np.str_, "0")),
8686
),
8787
(SparseArray(["10", "20"]), float, SparseArray([10.0, 20.0])),
8888
(

pandas/tests/arrays/sparse/test_dtype.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ def test_construct_from_string_fill_value_raises(string):
177177
[
178178
(SparseDtype(int, 0), float, SparseDtype(float, 0.0)),
179179
(SparseDtype(int, 1), float, SparseDtype(float, 1.0)),
180-
(SparseDtype(int, 1), str, SparseDtype(object, "1")),
180+
(SparseDtype(int, 1), np.str_, SparseDtype(object, "1")),
181181
(SparseDtype(float, 1.5), int, SparseDtype(int, 1)),
182182
],
183183
)

pandas/tests/dtypes/test_common.py

+12
Original file line numberDiff line numberDiff line change
@@ -810,11 +810,23 @@ def test_pandas_dtype_string_dtypes(string_storage):
810810
"pyarrow" if HAS_PYARROW else "python", na_value=np.nan
811811
)
812812

813+
with pd.option_context("future.infer_string", True):
814+
# with the default string_storage setting
815+
result = pandas_dtype(str)
816+
assert result == pd.StringDtype(
817+
"pyarrow" if HAS_PYARROW else "python", na_value=np.nan
818+
)
819+
813820
with pd.option_context("future.infer_string", True):
814821
with pd.option_context("string_storage", string_storage):
815822
result = pandas_dtype("str")
816823
assert result == pd.StringDtype(string_storage, na_value=np.nan)
817824

825+
with pd.option_context("future.infer_string", True):
826+
with pd.option_context("string_storage", string_storage):
827+
result = pandas_dtype(str)
828+
assert result == pd.StringDtype(string_storage, na_value=np.nan)
829+
818830
with pd.option_context("future.infer_string", False):
819831
with pd.option_context("string_storage", string_storage):
820832
result = pandas_dtype("str")

pandas/tests/extension/base/casting.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ def test_tolist(self, data):
4343
assert result == expected
4444

4545
def test_astype_str(self, data):
46-
result = pd.Series(data[:5]).astype(str)
47-
expected = pd.Series([str(x) for x in data[:5]], dtype=str)
46+
result = pd.Series(data[:2]).astype(str)
47+
expected = pd.Series([str(x) for x in data[:2]], dtype=str)
4848
tm.assert_series_equal(result, expected)
4949

5050
@pytest.mark.parametrize(

pandas/tests/extension/json/array.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -207,9 +207,8 @@ def astype(self, dtype, copy=True):
207207
return self.copy()
208208
return self
209209
elif isinstance(dtype, StringDtype):
210-
value = self.astype(str) # numpy doesn't like nested dicts
211210
arr_cls = dtype.construct_array_type()
212-
return arr_cls._from_sequence(value, dtype=dtype, copy=False)
211+
return arr_cls._from_sequence(self, dtype=dtype, copy=False)
213212
elif not copy:
214213
return np.asarray([dict(x) for x in self], dtype=dtype)
215214
else:

pandas/tests/extension/test_arrow.py

+5-24
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
pa_version_under13p0,
4242
pa_version_under14p0,
4343
)
44-
import pandas.util._test_decorators as td
4544

4645
from pandas.core.dtypes.dtypes import (
4746
ArrowDtype,
@@ -286,43 +285,25 @@ def test_map(self, data_missing, na_action):
286285
expected = data_missing.to_numpy()
287286
tm.assert_numpy_array_equal(result, expected)
288287

289-
def test_astype_str(self, data, request):
288+
def test_astype_str(self, data, request, using_infer_string):
290289
pa_dtype = data.dtype.pyarrow_dtype
291290
if pa.types.is_binary(pa_dtype):
292291
request.applymarker(
293292
pytest.mark.xfail(
294293
reason=f"For {pa_dtype} .astype(str) decodes.",
295294
)
296295
)
297-
elif (
298-
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
299-
) or pa.types.is_duration(pa_dtype):
296+
elif not using_infer_string and (
297+
(pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None)
298+
or pa.types.is_duration(pa_dtype)
299+
):
300300
request.applymarker(
301301
pytest.mark.xfail(
302302
reason="pd.Timestamp/pd.Timedelta repr different from numpy repr",
303303
)
304304
)
305305
super().test_astype_str(data)
306306

307-
@pytest.mark.parametrize(
308-
"nullable_string_dtype",
309-
[
310-
"string[python]",
311-
pytest.param("string[pyarrow]", marks=td.skip_if_no("pyarrow")),
312-
],
313-
)
314-
def test_astype_string(self, data, nullable_string_dtype, request):
315-
pa_dtype = data.dtype.pyarrow_dtype
316-
if (
317-
pa.types.is_timestamp(pa_dtype) and pa_dtype.tz is None
318-
) or pa.types.is_duration(pa_dtype):
319-
request.applymarker(
320-
pytest.mark.xfail(
321-
reason="pd.Timestamp/pd.Timedelta repr different from numpy repr",
322-
)
323-
)
324-
super().test_astype_string(data, nullable_string_dtype)
325-
326307
def test_from_dtype(self, data, request):
327308
pa_dtype = data.dtype.pyarrow_dtype
328309
if pa.types.is_string(pa_dtype) or pa.types.is_decimal(pa_dtype):

pandas/tests/frame/methods/test_astype.py

+9-8
Original file line numberDiff line numberDiff line change
@@ -169,21 +169,21 @@ def test_astype_str(self):
169169
"d": list(map(str, d._values)),
170170
"e": list(map(str, e._values)),
171171
},
172-
dtype="object",
172+
dtype="str",
173173
)
174174

175175
tm.assert_frame_equal(result, expected)
176176

177-
def test_astype_str_float(self):
177+
def test_astype_str_float(self, using_infer_string):
178178
# see GH#11302
179179
result = DataFrame([np.nan]).astype(str)
180-
expected = DataFrame(["nan"], dtype="object")
180+
expected = DataFrame([np.nan if using_infer_string else "nan"], dtype="str")
181181

182182
tm.assert_frame_equal(result, expected)
183183
result = DataFrame([1.12345678901234567890]).astype(str)
184184

185185
val = "1.1234567890123457"
186-
expected = DataFrame([val], dtype="object")
186+
expected = DataFrame([val], dtype="str")
187187
tm.assert_frame_equal(result, expected)
188188

189189
@pytest.mark.parametrize("dtype_class", [dict, Series])
@@ -285,7 +285,7 @@ def test_astype_duplicate_col_series_arg(self):
285285
result = df.astype(dtypes)
286286
expected = DataFrame(
287287
{
288-
0: Series(vals[:, 0].astype(str), dtype=object),
288+
0: Series(vals[:, 0].astype(str), dtype="str"),
289289
1: vals[:, 1],
290290
2: pd.array(vals[:, 2], dtype="Float64"),
291291
3: vals[:, 3],
@@ -666,25 +666,26 @@ def test_astype_dt64tz(self, timezone_frame):
666666
# dt64tz->dt64 deprecated
667667
timezone_frame.astype("datetime64[ns]")
668668

669-
def test_astype_dt64tz_to_str(self, timezone_frame):
669+
def test_astype_dt64tz_to_str(self, timezone_frame, using_infer_string):
670670
# str formatting
671671
result = timezone_frame.astype(str)
672+
na_value = np.nan if using_infer_string else "NaT"
672673
expected = DataFrame(
673674
[
674675
[
675676
"2013-01-01",
676677
"2013-01-01 00:00:00-05:00",
677678
"2013-01-01 00:00:00+01:00",
678679
],
679-
["2013-01-02", "NaT", "NaT"],
680+
["2013-01-02", na_value, na_value],
680681
[
681682
"2013-01-03",
682683
"2013-01-03 00:00:00-05:00",
683684
"2013-01-03 00:00:00+01:00",
684685
],
685686
],
686687
columns=timezone_frame.columns,
687-
dtype="object",
688+
dtype="str",
688689
)
689690
tm.assert_frame_equal(result, expected)
690691

pandas/tests/frame/methods/test_select_dtypes.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,9 @@ def test_select_dtypes_include_using_list_like(self, using_infer_string):
9999
ei = df[["a"]]
100100
tm.assert_frame_equal(ri, ei)
101101

102+
ri = df.select_dtypes(include=[str])
103+
tm.assert_frame_equal(ri, ei)
104+
102105
def test_select_dtypes_exclude_using_list_like(self):
103106
df = DataFrame(
104107
{
@@ -358,7 +361,7 @@ def test_select_dtypes_datetime_with_tz(self):
358361
@pytest.mark.parametrize("dtype", [str, "str", np.bytes_, "S1", np.str_, "U1"])
359362
@pytest.mark.parametrize("arg", ["include", "exclude"])
360363
def test_select_dtypes_str_raises(self, dtype, arg, using_infer_string):
361-
if using_infer_string and dtype == "str":
364+
if using_infer_string and (dtype == "str" or dtype is str):
362365
# this is tested below
363366
pytest.skip("Selecting string columns works with future strings")
364367
df = DataFrame(

0 commit comments

Comments
 (0)