Skip to content

Commit dc4399c

Browse files
String dtype: map builtin str alias to StringDtype (#59685)
* String dtype: map builtin str alias to StringDtype * fix tests * fix datetimelike astype and more tests * remove xfails * try fix typing * fix copy_view tests * fix remaining tests with infer_string enabled * ignore typing issue for now * move to common.py * simplify Categorical._str_get_dummies * small cleanup * fix ensure_string_array to not modify extension arrays inplace * fix ensure_string_array once more + fix is_extension_array_dtype for str * still xfail TestArrowArray::test_astype_str when not using infer_string * ensure maybe_convert_objects copies object dtype input array when inferring StringDtype * update test_1d_object_array_does_not_copy test * update constructor copy test + do not copy in maybe_convert_objects? * skip str.get_dummies test for now * use pandas_dtype() instead of registry.find * fix corner cases for calling pandas_dtype * add TODO comment in ensure_string_array
1 parent a790592 commit dc4399c

File tree

31 files changed

+183
-112
lines changed

31 files changed

+183
-112
lines changed

pandas/_libs/lib.pyx

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -755,7 +755,14 @@ cpdef ndarray[object] ensure_string_array(
755755

756756
if hasattr(arr, "to_numpy"):
757757

758-
if hasattr(arr, "dtype") and arr.dtype.kind in "mM":
758+
if (
759+
hasattr(arr, "dtype")
760+
and arr.dtype.kind in "mM"
761+
# TODO: we should add a custom ArrowExtensionArray.astype implementation
762+
# that handles astype(str) specifically, avoiding ending up here and
763+
# then we can remove the below check for `_pa_array` (for ArrowEA)
764+
and not hasattr(arr, "_pa_array")
765+
):
759766
# dtype check to exclude DataFrame
760767
# GH#41409 TODO: not a great place for this
761768
out = arr.astype(str).astype(object)

pandas/_testing/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@
112112

113113
COMPLEX_DTYPES: list[Dtype] = [complex, "complex64", "complex128"]
114114
if using_string_dtype():
115-
STRING_DTYPES: list[Dtype] = [str, "U"]
115+
STRING_DTYPES: list[Dtype] = ["U"]
116116
else:
117117
STRING_DTYPES: list[Dtype] = [str, "str", "U"] # type: ignore[no-redef]
118118
COMPLEX_FLOAT_DTYPES: list[Dtype] = [*COMPLEX_DTYPES, *FLOAT_NUMPY_DTYPES]

pandas/core/arrays/categorical.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2691,7 +2691,9 @@ def _str_get_dummies(self, sep: str = "|"):
26912691
# sep may not be in categories. Just bail on this.
26922692
from pandas.core.arrays import NumpyExtensionArray
26932693

2694-
return NumpyExtensionArray(self.astype(str))._str_get_dummies(sep)
2694+
return NumpyExtensionArray(self.to_numpy(str, na_value="NaN"))._str_get_dummies(
2695+
sep
2696+
)
26952697

26962698
# ------------------------------------------------------------------------
26972699
# GroupBy Methods

pandas/core/arrays/datetimelike.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -472,10 +472,16 @@ def astype(self, dtype, copy: bool = True):
472472

473473
return self._box_values(self.asi8.ravel()).reshape(self.shape)
474474

475+
elif is_string_dtype(dtype):
476+
if isinstance(dtype, ExtensionDtype):
477+
arr_object = self._format_native_types(na_rep=dtype.na_value) # type: ignore[arg-type]
478+
cls = dtype.construct_array_type()
479+
return cls._from_sequence(arr_object, dtype=dtype, copy=False)
480+
else:
481+
return self._format_native_types()
482+
475483
elif isinstance(dtype, ExtensionDtype):
476484
return super().astype(dtype, copy=copy)
477-
elif is_string_dtype(dtype):
478-
return self._format_native_types()
479485
elif dtype.kind in "iu":
480486
# we deliberately ignore int32 vs. int64 here.
481487
# See https://github.com/pandas-dev/pandas/issues/24381 for more.

pandas/core/dtypes/common.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
import numpy as np
1414

15+
from pandas._config import using_string_dtype
16+
1517
from pandas._libs import (
1618
Interval,
1719
Period,
@@ -1325,7 +1327,15 @@ def is_extension_array_dtype(arr_or_dtype) -> bool:
13251327
elif isinstance(dtype, np.dtype):
13261328
return False
13271329
else:
1328-
return registry.find(dtype) is not None
1330+
try:
1331+
with warnings.catch_warnings():
1332+
# pandas_dtype(..) can raise UserWarning for class input
1333+
warnings.simplefilter("ignore", UserWarning)
1334+
dtype = pandas_dtype(dtype)
1335+
except (TypeError, ValueError):
1336+
# np.dtype(..) can raise ValueError
1337+
return False
1338+
return isinstance(dtype, ExtensionDtype)
13291339

13301340

13311341
def is_ea_or_datetimelike_dtype(dtype: DtypeObj | None) -> bool:
@@ -1620,6 +1630,12 @@ def pandas_dtype(dtype) -> DtypeObj:
16201630
elif isinstance(dtype, (np.dtype, ExtensionDtype)):
16211631
return dtype
16221632

1633+
# builtin aliases
1634+
if dtype is str and using_string_dtype():
1635+
from pandas.core.arrays.string_ import StringDtype
1636+
1637+
return StringDtype(na_value=np.nan)
1638+
16231639
# registered extension types
16241640
result = registry.find(dtype)
16251641
if result is not None:

pandas/core/indexes/base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6415,7 +6415,11 @@ def _should_compare(self, other: Index) -> bool:
64156415
return False
64166416

64176417
dtype = _unpack_nested_dtype(other)
6418-
return self._is_comparable_dtype(dtype) or is_object_dtype(dtype)
6418+
return (
6419+
self._is_comparable_dtype(dtype)
6420+
or is_object_dtype(dtype)
6421+
or is_string_dtype(dtype)
6422+
)
64196423

64206424
def _is_comparable_dtype(self, dtype: DtypeObj) -> bool:
64216425
"""

pandas/core/indexes/interval.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
is_number,
5151
is_object_dtype,
5252
is_scalar,
53+
is_string_dtype,
5354
pandas_dtype,
5455
)
5556
from pandas.core.dtypes.dtypes import (
@@ -699,7 +700,7 @@ def _get_indexer(
699700
# left/right get_indexer, compare elementwise, equality -> match
700701
indexer = self._get_indexer_unique_sides(target)
701702

702-
elif not is_object_dtype(target.dtype):
703+
elif not (is_object_dtype(target.dtype) or is_string_dtype(target.dtype)):
703704
# homogeneous scalar index: use IntervalTree
704705
# we should always have self._should_partial_index(target) here
705706
target = self._maybe_convert_i8(target)

pandas/tests/arrays/floating/test_astype.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,11 +68,9 @@ def test_astype_str(using_infer_string):
6868

6969
if using_infer_string:
7070
expected = pd.array(["0.1", "0.2", None], dtype=pd.StringDtype(na_value=np.nan))
71-
tm.assert_extension_array_equal(a.astype("str"), expected)
7271

73-
# TODO(infer_string) this should also be a string array like above
74-
expected = np.array(["0.1", "0.2", "<NA>"], dtype="U32")
75-
tm.assert_numpy_array_equal(a.astype(str), expected)
72+
tm.assert_extension_array_equal(a.astype(str), expected)
73+
tm.assert_extension_array_equal(a.astype("str"), expected)
7674
else:
7775
expected = np.array(["0.1", "0.2", "<NA>"], dtype="U32")
7876

pandas/tests/arrays/integer/test_dtypes.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -283,11 +283,9 @@ def test_astype_str(using_infer_string):
283283

284284
if using_infer_string:
285285
expected = pd.array(["1", "2", None], dtype=pd.StringDtype(na_value=np.nan))
286-
tm.assert_extension_array_equal(a.astype("str"), expected)
287286

288-
# TODO(infer_string) this should also be a string array like above
289-
expected = np.array(["1", "2", "<NA>"], dtype=f"{tm.ENDIAN}U21")
290-
tm.assert_numpy_array_equal(a.astype(str), expected)
287+
tm.assert_extension_array_equal(a.astype(str), expected)
288+
tm.assert_extension_array_equal(a.astype("str"), expected)
291289
else:
292290
expected = np.array(["1", "2", "<NA>"], dtype=f"{tm.ENDIAN}U21")
293291

pandas/tests/arrays/sparse/test_astype.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ def test_astype_all(self, any_real_numpy_dtype):
8181
),
8282
(
8383
SparseArray([0, 1, 10]),
84-
str,
85-
SparseArray(["0", "1", "10"], dtype=SparseDtype(str, "0")),
84+
np.str_,
85+
SparseArray(["0", "1", "10"], dtype=SparseDtype(np.str_, "0")),
8686
),
8787
(SparseArray(["10", "20"]), float, SparseArray([10.0, 20.0])),
8888
(

0 commit comments

Comments
 (0)