diff --git a/pandas-stubs/_typing.pyi b/pandas-stubs/_typing.pyi index 1d482b508..476a76db1 100644 --- a/pandas-stubs/_typing.pyi +++ b/pandas-stubs/_typing.pyi @@ -19,6 +19,7 @@ from typing import ( import numpy as np from numpy import typing as npt +import pandas as pd from pandas.core.arrays import ExtensionArray from pandas.core.frame import DataFrame from pandas.core.generic import NDFrame @@ -75,9 +76,61 @@ class FulldatetimeDict(YearMonthDayDict, total=False): # dtypes NpDtype: TypeAlias = str | np.dtype[np.generic] | type[str | complex | bool | object] Dtype: TypeAlias = ExtensionDtype | NpDtype -AstypeArg: TypeAlias = ExtensionDtype | npt.DTypeLike -# DtypeArg specifies all allowable dtypes in a functions its dtype argument DtypeArg: TypeAlias = Dtype | dict[Any, Dtype] +BooleanDtypeArg: TypeAlias = ( + type[bool] | type[np.bool_] | pd.BooleanDtype | Literal["bool"] +) +IntDtypeArg: TypeAlias = ( + Literal["int", "int32"] + | type[int] + | pd.Int8Dtype + | pd.Int16Dtype + | pd.Int32Dtype + | pd.Int64Dtype + | type[np.int8] + | type[np.int16] + | type[np.int32] + | type[np.int64] + | type[np.uint8] + | type[np.uint16] + | type[np.uint32] + | type[np.uint64] + | type[np.intp] + | type[np.uintp] + | type[np.byte] + | type[np.ubyte] +) +StrDtypeArg: TypeAlias = type[str] | pd.StringDtype | Literal["str"] +BytesDtypeArg: TypeAlias = type[bytes] +FloatDtypeArg: TypeAlias = ( + pd.Float32Dtype + | pd.Float64Dtype + | type[np.float16] + | type[np.float32] + | type[np.float64] + | type[float] + | Literal["float"] +) +ComplexDtypeArg: TypeAlias = ( + type[np.complex64] | type[np.complex128] | type[complex] | Literal["complex"] +) +TimedeltaDtypeArg: TypeAlias = Literal["timedelta64[ns]"] +TimestampDtypeArg: TypeAlias = Literal["datetime64[ns]"] +CategoryDtypeArg: TypeAlias = Literal["category"] + +AstypeArg: TypeAlias = ( + BooleanDtypeArg + | IntDtypeArg + | StrDtypeArg + | BytesDtypeArg + | FloatDtypeArg + | ComplexDtypeArg + | TimedeltaDtypeArg + | TimestampDtypeArg + | CategoricalDtype + | ExtensionDtype +) +# DtypeArg specifies all allowable dtypes in a functions its dtype argument DtypeObj: TypeAlias = np.dtype[np.generic] | ExtensionDtype # filenames and file-like-objects diff --git a/pandas-stubs/core/frame.pyi b/pandas-stubs/core/frame.pyi index 664803b81..5cf47ede9 100644 --- a/pandas-stubs/core/frame.pyi +++ b/pandas-stubs/core/frame.pyi @@ -59,6 +59,7 @@ from pandas._typing import ( AggFuncTypeFrame, AnyArrayLike, ArrayLike, + AstypeArg, Axes, Axis, AxisType, @@ -1440,9 +1441,9 @@ class DataFrame(NDFrame, OpsMixin): ) -> DataFrame: ... def astype( self, - dtype: _str | Dtype | Mapping[HashableT, _str | Dtype] | Series, + dtype: AstypeArg | Mapping[Any, Dtype] | Series, copy: _bool = ..., - errors: _str = ..., + errors: IgnoreRaise = ..., ) -> DataFrame: ... def at_time( self, diff --git a/pandas-stubs/core/generic.pyi b/pandas-stubs/core/generic.pyi index fa7589270..557607d9f 100644 --- a/pandas-stubs/core/generic.pyi +++ b/pandas-stubs/core/generic.pyi @@ -374,12 +374,6 @@ class NDFrame(PandasObject, indexing.IndexingMixin): def values(self) -> ArrayLike: ... @property def dtypes(self): ... - def astype( - self: NDFrameT, - dtype, - copy: _bool = ..., - errors: IgnoreRaise = ..., - ) -> NDFrameT: ... def copy(self: NDFrameT, deep: _bool = ...) -> NDFrameT: ... def __copy__(self, deep: _bool = ...) -> NDFrame: ... def __deepcopy__(self, memo=...) -> NDFrame: ... diff --git a/pandas-stubs/core/series.pyi b/pandas-stubs/core/series.pyi index 3e9e46611..79a6f2a18 100644 --- a/pandas-stubs/core/series.pyi +++ b/pandas-stubs/core/series.pyi @@ -20,6 +20,12 @@ from typing import ( overload, ) +from core.api import ( + Int8Dtype as Int8Dtype, + Int16Dtype as Int16Dtype, + Int32Dtype as Int32Dtype, + Int64Dtype as Int64Dtype, +) from matplotlib.axes import ( Axes as PlotAxes, SubplotBase, @@ -80,17 +86,23 @@ from pandas._typing import ( Axes, Axis, AxisType, + BooleanDtypeArg, + BytesDtypeArg, CalculationMethod, + CategoryDtypeArg, + ComplexDtypeArg, CompressionOptions, DtypeObj, FilePath, FillnaOptions, + FloatDtypeArg, GroupByObjectNonScalar, HashableT1, HashableT2, HashableT3, IgnoreRaise, IndexingInt, + IntDtypeArg, IntervalClosedType, JoinHow, JsonSeriesOrient, @@ -106,7 +118,10 @@ from pandas._typing import ( Scalar, SeriesAxisType, SortKind, + StrDtypeArg, + TimedeltaDtypeArg, TimestampConvention, + TimestampDtypeArg, WriteBuffer, np_ndarray_anyint, np_ndarray_bool, @@ -114,6 +129,8 @@ from pandas._typing import ( num, ) +from pandas.core.dtypes.base import ExtensionDtype + from pandas.plotting import PlotAccessor from .base import IndexOpsMixin @@ -1035,9 +1052,73 @@ class Series(IndexOpsMixin, NDFrame, Generic[S1]): axis: SeriesAxisType | None = ..., ignore_index: _bool = ..., ) -> Series[S1]: ... + @overload + def astype( # type: ignore[misc] + self, + dtype: BooleanDtypeArg, + copy: _bool = ..., + errors: IgnoreRaise = ..., + ) -> Series[bool]: ... + @overload + def astype( + self, + dtype: IntDtypeArg, + copy: _bool = ..., + errors: IgnoreRaise = ..., + ) -> Series[int]: ... + @overload + def astype( + self, + dtype: StrDtypeArg, + copy: _bool = ..., + errors: IgnoreRaise = ..., + ) -> Series[_str]: ... + @overload + def astype( + self, + dtype: BytesDtypeArg, + copy: _bool = ..., + errors: IgnoreRaise = ..., + ) -> Series[bytes]: ... + @overload + def astype( + self, + dtype: FloatDtypeArg, + copy: _bool = ..., + errors: IgnoreRaise = ..., + ) -> Series[float]: ... + @overload + def astype( + self, + dtype: ComplexDtypeArg, + copy: _bool = ..., + errors: IgnoreRaise = ..., + ) -> Series[complex]: ... + @overload + def astype( + self, + dtype: TimedeltaDtypeArg, + copy: _bool = ..., + errors: IgnoreRaise = ..., + ) -> TimedeltaSeries: ... + @overload + def astype( + self, + dtype: TimestampDtypeArg, + copy: _bool = ..., + errors: IgnoreRaise = ..., + ) -> TimestampSeries: ... + @overload + def astype( + self, + dtype: CategoryDtypeArg, + copy: _bool = ..., + errors: IgnoreRaise = ..., + ) -> Series: ... + @overload def astype( self, - dtype: S1 | _str | type[Scalar], + dtype: ExtensionDtype, copy: _bool = ..., errors: IgnoreRaise = ..., ) -> Series: ... diff --git a/tests/test_frame.py b/tests/test_frame.py index 0e55e3f1b..3b907569b 100644 --- a/tests/test_frame.py +++ b/tests/test_frame.py @@ -2297,8 +2297,9 @@ def test_astype_dict() -> None: # GH 447 df = pd.DataFrame({"a": [1, 2, 3], 43: [4, 5, 6]}) columns_types = {"a": "int", 43: "float"} - df = df.astype(columns_types) - check(assert_type(df, pd.DataFrame), pd.DataFrame) + de = df.astype(columns_types) + check(assert_type(de, pd.DataFrame), pd.DataFrame) + check(assert_type(df.astype({"a": "int", 43: "float"}), pd.DataFrame), pd.DataFrame) def test_setitem_none() -> None: @@ -2423,3 +2424,19 @@ def test_insert_newvalues() -> None: assert assert_type(df.insert(loc=0, column="b", value=None), None) is None assert assert_type(ab.insert(loc=0, column="newcol", value=[99, 99]), None) is None assert assert_type(ef.insert(loc=0, column="g", value=4), None) is None + + +def test_astype() -> None: + s = pd.DataFrame({"d": [1, 2]}) + ab = pd.DataFrame({"col1": [1, 2], "col2": ["a", "b"]}) + + check(assert_type(s.astype(int), "pd.DataFrame"), pd.DataFrame) + check(assert_type(s.astype(pd.Int64Dtype()), "pd.DataFrame"), pd.DataFrame) + check(assert_type(s.astype(str), "pd.DataFrame"), pd.DataFrame) + check(assert_type(s.astype(bytes), "pd.DataFrame"), pd.DataFrame) + check(assert_type(s.astype(pd.Float64Dtype()), "pd.DataFrame"), pd.DataFrame) + check(assert_type(s.astype(complex), "pd.DataFrame"), pd.DataFrame) + check( + assert_type(ab.astype({"col1": "int32", "col2": str}), "pd.DataFrame"), + pd.DataFrame, + ) diff --git a/tests/test_series.py b/tests/test_series.py index 05fe07776..c1e5465c3 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -1,6 +1,8 @@ from __future__ import annotations import datetime + +# from decimal import Decimal from pathlib import Path import re from typing import ( @@ -25,10 +27,17 @@ ExtensionDtype, ) from pandas.core.window import ExponentialMovingWindow + +# from pandas.tests.extension.decimal import DecimalDtype import pytest -from typing_extensions import assert_type +from typing_extensions import ( + TypeAlias, + assert_type, +) import xarray as xr +from pandas._libs.tslibs.timedeltas import Timedelta +from pandas._libs.tslibs.timestamps import Timestamp from pandas._typing import ( DtypeObj, Scalar, @@ -41,6 +50,15 @@ pytest_warns_bounded, ) +if TYPE_CHECKING: + from pandas.core.series import ( + TimedeltaSeries, + TimestampSeries, + ) +else: + TimedeltaSeries: TypeAlias = pd.Series + TimestampSeries: TypeAlias = pd.Series + if TYPE_CHECKING: from pandas._typing import np_ndarray_int # noqa: F401 @@ -1425,3 +1443,93 @@ def test_change_to_dict_return_type() -> None: df = pd.DataFrame(zip(id, value), columns=["id", "value"]) fd = df.set_index("id")["value"].to_dict() check(assert_type(fd, Dict[Any, Any]), dict) + + +def test_updated_astype() -> None: + s = pd.Series([3, 4, 5]) + s1 = pd.Series(True) + + check(assert_type(s.astype(int), "pd.Series[int]"), pd.Series, np.integer) + check(assert_type(s.astype("int"), "pd.Series[int]"), pd.Series, np.integer) + check(assert_type(s.astype("int32"), "pd.Series[int]"), pd.Series, np.int32) + check(assert_type(s.astype(pd.Int8Dtype()), "pd.Series[int]"), pd.Series, np.int8) + check(assert_type(s.astype(pd.Int16Dtype()), "pd.Series[int]"), pd.Series, np.int16) + check(assert_type(s.astype(pd.Int32Dtype()), "pd.Series[int]"), pd.Series, np.int32) + check(assert_type(s.astype(pd.Int64Dtype()), "pd.Series[int]"), pd.Series, np.int64) + check(assert_type(s.astype(np.int8), "pd.Series[int]"), pd.Series, np.int8) + check(assert_type(s.astype(np.int16), "pd.Series[int]"), pd.Series, np.int16) + check(assert_type(s.astype(np.int32), "pd.Series[int]"), pd.Series, np.int32) + check(assert_type(s.astype(np.int64), "pd.Series[int]"), pd.Series, np.int64) + check(assert_type(s.astype(np.uint8), "pd.Series[int]"), pd.Series, np.uint8) + check(assert_type(s.astype(np.uint16), "pd.Series[int]"), pd.Series, np.uint16) + check(assert_type(s.astype(np.uint32), "pd.Series[int]"), pd.Series, np.uint32) + check(assert_type(s.astype(np.uint64), "pd.Series[int]"), pd.Series, np.uint64) + check(assert_type(s.astype(np.intp), "pd.Series[int]"), pd.Series, np.intp) + check(assert_type(s.astype(np.uintp), "pd.Series[int]"), pd.Series, np.uintp) + check(assert_type(s.astype(np.byte), "pd.Series[int]"), pd.Series, np.byte) + check(assert_type(s.astype(np.ubyte), "pd.Series[int]"), pd.Series, np.ubyte) + + check(assert_type(s.astype(str), "pd.Series[str]"), pd.Series, str) + check(assert_type(s.astype(pd.StringDtype()), "pd.Series[str]"), pd.Series, str) + check(assert_type(s.astype("str"), "pd.Series[str]"), pd.Series, str) + + check(assert_type(s.astype(bytes), "pd.Series[bytes]"), pd.Series, bytes) + + check( + assert_type(s.astype(pd.Float32Dtype()), "pd.Series[float]"), + pd.Series, + np.float32, + ) + check( + assert_type(s.astype(pd.Float64Dtype()), "pd.Series[float]"), + pd.Series, + np.float64, + ) + check(assert_type(s.astype(np.float16), "pd.Series[float]"), pd.Series, np.float16) + check(assert_type(s.astype(np.float32), "pd.Series[float]"), pd.Series, np.float32) + check(assert_type(s.astype(np.float64), "pd.Series[float]"), pd.Series, np.float64) + check(assert_type(s.astype(float), "pd.Series[float]"), pd.Series, float) + check(assert_type(s.astype("float"), "pd.Series[float]"), pd.Series, float) + + check( + assert_type(s.astype(np.complex64), "pd.Series[complex]"), + pd.Series, + np.complex64, + ) + check( + assert_type(s.astype(np.complex128), "pd.Series[complex]"), + pd.Series, + np.complex128, + ) + check(assert_type(s.astype(complex), "pd.Series[complex]"), pd.Series, complex) + check(assert_type(s.astype("complex"), "pd.Series[complex]"), pd.Series, complex) + + check( + assert_type(s1.astype(pd.BooleanDtype()), "pd.Series[bool]"), + pd.Series, + np.bool_, + ) + check(assert_type(s.astype("bool"), "pd.Series[bool]"), pd.Series, np.bool_) + check(assert_type(s.astype(bool), "pd.Series[bool]"), pd.Series, np.bool_) + check(assert_type(s.astype(np.bool_), "pd.Series[bool]"), pd.Series, np.bool_) + + check( + assert_type(s.astype("timedelta64[ns]"), TimedeltaSeries), + pd.Series, + Timedelta, + ) + + check( + assert_type(s.astype("datetime64[ns]"), TimestampSeries), + pd.Series, + Timestamp, + ) + + # orseries = pd.Series([Decimal(x) for x in [1, 2, 3]]) + # newtype: ExtensionDtype = DecimalDtype() + # decseries = orseries.astype(newtype) + # check( + # assert_type(decseries, pd.Series), + # pd.Series, + # Decimal, + # )