Skip to content

Commit c6815aa

Browse files
authored
gh-372 : Fixing Series.astype() (#519)
* adding overloads to astype
* Created the table for astype
* Update table.rst
* Updated the table and added numpy dtypes
* Update table.rst
* updated np.datetime64
* Update table.rst
* added types in Timedelta
* removed not required args in Dtype
* removed np.timedelta64 in Timedelta
* Removed timedelta64
* expanding series astype
* Added type in args in dtype
* corrected the args
* adding a overload for 'category' and normal changes
* added tests
* removed unused args
* corrected tests
* Delete table.rst
* added the bool overload to top and done the required test changes
* added type_checker
* added types for check and did requested changes
* updated the check types
* added astype in dataframe and other changes
* Update test_series.py
* Update test_series.py
* added dict test for astype in datatest_frame and tests for ExtensionDtype in test_series
* commented out the decimal tests
* Update test_series.py
* updated dtype args in astype
* added any to list of args for astype
* changed dtype args
1 parent 1556bdf commit c6815aa

File tree

6 files changed

+268
-14
lines changed

6 files changed

+268
-14
lines changed

pandas-stubs/_typing.pyi

+55-2
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ from typing import (
1919

2020
import numpy as np
2121
from numpy import typing as npt
22+
import pandas as pd
2223
from pandas.core.arrays import ExtensionArray
2324
from pandas.core.frame import DataFrame
2425
from pandas.core.generic import NDFrame
@@ -75,9 +76,61 @@ class FulldatetimeDict(YearMonthDayDict, total=False):
7576
# dtypes
7677
NpDtype: TypeAlias = str | np.dtype[np.generic] | type[str | complex | bool | object]
7778
Dtype: TypeAlias = ExtensionDtype | NpDtype
78-
AstypeArg: TypeAlias = ExtensionDtype | npt.DTypeLike
79-
# DtypeArg specifies all allowable dtypes in a function's dtype argument
8079
DtypeArg: TypeAlias = Dtype | dict[Any, Dtype]
80+
BooleanDtypeArg: TypeAlias = (
81+
type[bool] | type[np.bool_] | pd.BooleanDtype | Literal["bool"]
82+
)
83+
IntDtypeArg: TypeAlias = (
84+
Literal["int", "int32"]
85+
| type[int]
86+
| pd.Int8Dtype
87+
| pd.Int16Dtype
88+
| pd.Int32Dtype
89+
| pd.Int64Dtype
90+
| type[np.int8]
91+
| type[np.int16]
92+
| type[np.int32]
93+
| type[np.int64]
94+
| type[np.uint8]
95+
| type[np.uint16]
96+
| type[np.uint32]
97+
| type[np.uint64]
98+
| type[np.intp]
99+
| type[np.uintp]
100+
| type[np.byte]
101+
| type[np.ubyte]
102+
)
103+
StrDtypeArg: TypeAlias = type[str] | pd.StringDtype | Literal["str"]
104+
BytesDtypeArg: TypeAlias = type[bytes]
105+
FloatDtypeArg: TypeAlias = (
106+
pd.Float32Dtype
107+
| pd.Float64Dtype
108+
| type[np.float16]
109+
| type[np.float32]
110+
| type[np.float64]
111+
| type[float]
112+
| Literal["float"]
113+
)
114+
ComplexDtypeArg: TypeAlias = (
115+
type[np.complex64] | type[np.complex128] | type[complex] | Literal["complex"]
116+
)
117+
TimedeltaDtypeArg: TypeAlias = Literal["timedelta64[ns]"]
118+
TimestampDtypeArg: TypeAlias = Literal["datetime64[ns]"]
119+
CategoryDtypeArg: TypeAlias = Literal["category"]
120+
121+
AstypeArg: TypeAlias = (
122+
BooleanDtypeArg
123+
| IntDtypeArg
124+
| StrDtypeArg
125+
| BytesDtypeArg
126+
| FloatDtypeArg
127+
| ComplexDtypeArg
128+
| TimedeltaDtypeArg
129+
| TimestampDtypeArg
130+
| CategoricalDtype
131+
| ExtensionDtype
132+
)
133+
# DtypeArg specifies all allowable dtypes in a function's dtype argument
81134
DtypeObj: TypeAlias = np.dtype[np.generic] | ExtensionDtype
82135

83136
# filenames and file-like-objects

pandas-stubs/core/frame.pyi

+3-2
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ from pandas._typing import (
5959
AggFuncTypeFrame,
6060
AnyArrayLike,
6161
ArrayLike,
62+
AstypeArg,
6263
Axes,
6364
Axis,
6465
AxisType,
@@ -1440,9 +1441,9 @@ class DataFrame(NDFrame, OpsMixin):
14401441
) -> DataFrame: ...
14411442
def astype(
14421443
self,
1443-
dtype: _str | Dtype | Mapping[HashableT, _str | Dtype] | Series,
1444+
dtype: AstypeArg | Mapping[Any, Dtype] | Series,
14441445
copy: _bool = ...,
1445-
errors: _str = ...,
1446+
errors: IgnoreRaise = ...,
14461447
) -> DataFrame: ...
14471448
def at_time(
14481449
self,

pandas-stubs/core/generic.pyi

-6
Original file line numberDiff line numberDiff line change
@@ -374,12 +374,6 @@ class NDFrame(PandasObject, indexing.IndexingMixin):
374374
def values(self) -> ArrayLike: ...
375375
@property
376376
def dtypes(self): ...
377-
def astype(
378-
self: NDFrameT,
379-
dtype,
380-
copy: _bool = ...,
381-
errors: IgnoreRaise = ...,
382-
) -> NDFrameT: ...
383377
def copy(self: NDFrameT, deep: _bool = ...) -> NDFrameT: ...
384378
def __copy__(self, deep: _bool = ...) -> NDFrame: ...
385379
def __deepcopy__(self, memo=...) -> NDFrame: ...

pandas-stubs/core/series.pyi

+82-1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ from typing import (
2020
overload,
2121
)
2222

23+
from core.api import (
24+
Int8Dtype as Int8Dtype,
25+
Int16Dtype as Int16Dtype,
26+
Int32Dtype as Int32Dtype,
27+
Int64Dtype as Int64Dtype,
28+
)
2329
from matplotlib.axes import (
2430
Axes as PlotAxes,
2531
SubplotBase,
@@ -80,17 +86,23 @@ from pandas._typing import (
8086
Axes,
8187
Axis,
8288
AxisType,
89+
BooleanDtypeArg,
90+
BytesDtypeArg,
8391
CalculationMethod,
92+
CategoryDtypeArg,
93+
ComplexDtypeArg,
8494
CompressionOptions,
8595
DtypeObj,
8696
FilePath,
8797
FillnaOptions,
98+
FloatDtypeArg,
8899
GroupByObjectNonScalar,
89100
HashableT1,
90101
HashableT2,
91102
HashableT3,
92103
IgnoreRaise,
93104
IndexingInt,
105+
IntDtypeArg,
94106
IntervalClosedType,
95107
JoinHow,
96108
JsonSeriesOrient,
@@ -106,14 +118,19 @@ from pandas._typing import (
106118
Scalar,
107119
SeriesAxisType,
108120
SortKind,
121+
StrDtypeArg,
122+
TimedeltaDtypeArg,
109123
TimestampConvention,
124+
TimestampDtypeArg,
110125
WriteBuffer,
111126
np_ndarray_anyint,
112127
np_ndarray_bool,
113128
npt,
114129
num,
115130
)
116131

132+
from pandas.core.dtypes.base import ExtensionDtype
133+
117134
from pandas.plotting import PlotAccessor
118135

119136
from .base import IndexOpsMixin
@@ -1035,9 +1052,73 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]):
10351052
axis: SeriesAxisType | None = ...,
10361053
ignore_index: _bool = ...,
10371054
) -> Series[S1]: ...
1055+
@overload
1056+
def astype( # type: ignore[misc]
1057+
self,
1058+
dtype: BooleanDtypeArg,
1059+
copy: _bool = ...,
1060+
errors: IgnoreRaise = ...,
1061+
) -> Series[bool]: ...
1062+
@overload
1063+
def astype(
1064+
self,
1065+
dtype: IntDtypeArg,
1066+
copy: _bool = ...,
1067+
errors: IgnoreRaise = ...,
1068+
) -> Series[int]: ...
1069+
@overload
1070+
def astype(
1071+
self,
1072+
dtype: StrDtypeArg,
1073+
copy: _bool = ...,
1074+
errors: IgnoreRaise = ...,
1075+
) -> Series[_str]: ...
1076+
@overload
1077+
def astype(
1078+
self,
1079+
dtype: BytesDtypeArg,
1080+
copy: _bool = ...,
1081+
errors: IgnoreRaise = ...,
1082+
) -> Series[bytes]: ...
1083+
@overload
1084+
def astype(
1085+
self,
1086+
dtype: FloatDtypeArg,
1087+
copy: _bool = ...,
1088+
errors: IgnoreRaise = ...,
1089+
) -> Series[float]: ...
1090+
@overload
1091+
def astype(
1092+
self,
1093+
dtype: ComplexDtypeArg,
1094+
copy: _bool = ...,
1095+
errors: IgnoreRaise = ...,
1096+
) -> Series[complex]: ...
1097+
@overload
1098+
def astype(
1099+
self,
1100+
dtype: TimedeltaDtypeArg,
1101+
copy: _bool = ...,
1102+
errors: IgnoreRaise = ...,
1103+
) -> TimedeltaSeries: ...
1104+
@overload
1105+
def astype(
1106+
self,
1107+
dtype: TimestampDtypeArg,
1108+
copy: _bool = ...,
1109+
errors: IgnoreRaise = ...,
1110+
) -> TimestampSeries: ...
1111+
@overload
1112+
def astype(
1113+
self,
1114+
dtype: CategoryDtypeArg,
1115+
copy: _bool = ...,
1116+
errors: IgnoreRaise = ...,
1117+
) -> Series: ...
1118+
@overload
10381119
def astype(
10391120
self,
1040-
dtype: S1 | _str | type[Scalar],
1121+
dtype: ExtensionDtype,
10411122
copy: _bool = ...,
10421123
errors: IgnoreRaise = ...,
10431124
) -> Series: ...

tests/test_frame.py

+19-2
Original file line numberDiff line numberDiff line change
@@ -2299,8 +2299,9 @@ def test_astype_dict() -> None:
22992299
# GH 447
23002300
df = pd.DataFrame({"a": [1, 2, 3], 43: [4, 5, 6]})
23012301
columns_types = {"a": "int", 43: "float"}
2302-
df = df.astype(columns_types)
2303-
check(assert_type(df, pd.DataFrame), pd.DataFrame)
2302+
de = df.astype(columns_types)
2303+
check(assert_type(de, pd.DataFrame), pd.DataFrame)
2304+
check(assert_type(df.astype({"a": "int", 43: "float"}), pd.DataFrame), pd.DataFrame)
23042305

23052306

23062307
def test_setitem_none() -> None:
@@ -2425,3 +2426,19 @@ def test_insert_newvalues() -> None:
24252426
assert assert_type(df.insert(loc=0, column="b", value=None), None) is None
24262427
assert assert_type(ab.insert(loc=0, column="newcol", value=[99, 99]), None) is None
24272428
assert assert_type(ef.insert(loc=0, column="g", value=4), None) is None
2429+
2430+
2431+
def test_astype() -> None:
2432+
s = pd.DataFrame({"d": [1, 2]})
2433+
ab = pd.DataFrame({"col1": [1, 2], "col2": ["a", "b"]})
2434+
2435+
check(assert_type(s.astype(int), "pd.DataFrame"), pd.DataFrame)
2436+
check(assert_type(s.astype(pd.Int64Dtype()), "pd.DataFrame"), pd.DataFrame)
2437+
check(assert_type(s.astype(str), "pd.DataFrame"), pd.DataFrame)
2438+
check(assert_type(s.astype(bytes), "pd.DataFrame"), pd.DataFrame)
2439+
check(assert_type(s.astype(pd.Float64Dtype()), "pd.DataFrame"), pd.DataFrame)
2440+
check(assert_type(s.astype(complex), "pd.DataFrame"), pd.DataFrame)
2441+
check(
2442+
assert_type(ab.astype({"col1": "int32", "col2": str}), "pd.DataFrame"),
2443+
pd.DataFrame,
2444+
)

0 commit comments

Comments
 (0)