Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Add DataType.is_integer and other dtype groups #12200

Merged
merged 11 commits into from
Nov 14, 2023
Merged
7 changes: 3 additions & 4 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from polars.dataframe._html import NotebookFormatter
from polars.dataframe.group_by import DynamicGroupBy, GroupBy, RollingGroupBy
from polars.datatypes import (
FLOAT_DTYPES,
INTEGER_DTYPES,
N_INFER_DEFAULT,
NUMERIC_DTYPES,
Expand Down Expand Up @@ -1419,13 +1418,13 @@ def _div(self, other: Any, *, floordiv: bool) -> DataFrame:
df = (
df
if not floordiv
else df.with_columns([s.floor() for s in df if s.dtype in FLOAT_DTYPES])
else df.with_columns([s.floor() for s in df if s.dtype.is_float()])
)
if floordiv:
int_casts = [
col(column).cast(tp)
for i, (column, tp) in enumerate(self.schema.items())
if tp in INTEGER_DTYPES and orig_dtypes[i] in INTEGER_DTYPES
if tp.is_integer() and orig_dtypes[i].is_integer()
]
if int_casts:
return df.with_columns(int_casts)
Expand Down Expand Up @@ -1711,7 +1710,7 @@ def __getitem__(
dtype = item.dtype
if dtype == Utf8:
return self._from_pydf(self._df.select(item))
elif dtype in INTEGER_DTYPES:
elif dtype.is_integer():
return self._take_with_series(item._pos_idxs(self.shape[0]))

# if no data has been returned, the operation is not supported
Expand Down
4 changes: 0 additions & 4 deletions py-polars/polars/datatypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@
Field,
Float32,
Float64,
FractionalType,
Int8,
Int16,
Int32,
Int64,
IntegerType,
List,
Null,
NumericType,
Object,
Struct,
TemporalType,
Expand Down Expand Up @@ -89,15 +87,13 @@
"Field",
"Float32",
"Float64",
"FractionalType",
"Int16",
"Int32",
"Int64",
"Int8",
"IntegerType",
"List",
"Null",
"NumericType",
"Object",
"Struct",
"TemporalType",
Expand Down
84 changes: 71 additions & 13 deletions py-polars/polars/datatypes/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,30 @@ def is_not(cls, other: PolarsDataType) -> bool: # noqa: D102
def is_nested(self) -> bool: # noqa: D102
...

@classmethod
def is_numeric(cls) -> bool: # noqa: D102
...

@classmethod
def is_integer(cls) -> bool: # noqa: D102
...

@classmethod
def is_signed_integer(cls) -> bool: # noqa: D102
...

@classmethod
def is_unsigned_integer(cls) -> bool: # noqa: D102
...

@classmethod
def is_float(cls) -> bool: # noqa: D102
...

@classmethod
def is_temporal(cls) -> bool: # noqa: D102
...


class DataType(metaclass=DataTypeClass):
"""Base class for all Polars data types."""
Expand Down Expand Up @@ -161,6 +185,36 @@ def is_nested(self) -> bool:
issue_deprecation_warning(message, version="0.19.10")
return False

@classmethod
def is_numeric(cls) -> bool:
"""Check whether the data type is a numeric type."""
return issubclass(cls, NumericType)

@classmethod
def is_integer(cls) -> bool:
"""Check whether the data type is an integer type."""
return issubclass(cls, IntegerType)

@classmethod
def is_signed_integer(cls) -> bool:
"""Check whether the data type is a signed integer type."""
return issubclass(cls, SignedIntegerType)

@classmethod
def is_unsigned_integer(cls) -> bool:
"""Check whether the data type is an unsigned integer type."""
return issubclass(cls, UnsignedIntegerType)

@classmethod
def is_float(cls) -> bool:
"""Check whether the data type is a temporal type."""
return issubclass(cls, FloatType)

@classmethod
def is_temporal(cls) -> bool:
"""Check whether the data type is a temporal type."""
return issubclass(cls, TemporalType)


def _custom_reconstruct(
cls: type[Any], base: type[Any], state: Any
Expand Down Expand Up @@ -214,14 +268,18 @@ class NumericType(DataType):


class IntegerType(NumericType):
"""Base class for integral data types."""
"""Base class for integer data types."""


class SignedIntegerType(IntegerType):
"""Base class for signed integer data types."""


class FractionalType(NumericType):
"""Base class for fractional data types."""
class UnsignedIntegerType(IntegerType):
"""Base class for unsigned integer data types."""


class FloatType(FractionalType):
class FloatType(NumericType):
"""Base class for float data types."""


Expand Down Expand Up @@ -252,35 +310,35 @@ def is_nested(self) -> bool:
return True


class Int8(IntegerType):
class Int8(SignedIntegerType):
"""8-bit signed integer type."""


class Int16(IntegerType):
class Int16(SignedIntegerType):
"""16-bit signed integer type."""


class Int32(IntegerType):
class Int32(SignedIntegerType):
"""32-bit signed integer type."""


class Int64(IntegerType):
class Int64(SignedIntegerType):
"""64-bit signed integer type."""


class UInt8(IntegerType):
class UInt8(UnsignedIntegerType):
"""8-bit unsigned integer type."""


class UInt16(IntegerType):
class UInt16(UnsignedIntegerType):
"""16-bit unsigned integer type."""


class UInt32(IntegerType):
class UInt32(UnsignedIntegerType):
"""32-bit unsigned integer type."""


class UInt64(IntegerType):
class UInt64(UnsignedIntegerType):
"""64-bit unsigned integer type."""


Expand All @@ -292,7 +350,7 @@ class Float64(FloatType):
"""64-bit floating point type."""


class Decimal(FractionalType):
class Decimal(NumericType):
"""
Decimal 128-bit type with an optional precision and non-negative scale.

Expand Down
6 changes: 2 additions & 4 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
import polars._reexport as pl
from polars import functions as F
from polars.datatypes import (
FLOAT_DTYPES,
INTEGER_DTYPES,
Categorical,
Null,
Struct,
Expand Down Expand Up @@ -9196,8 +9194,8 @@ def _remap_key_or_value_series(
# Values Series has same dtype as keys Series.
dtype = s.dtype
elif (
(s.dtype in INTEGER_DTYPES and dtype_keys in INTEGER_DTYPES)
or (s.dtype in FLOAT_DTYPES and dtype_keys in FLOAT_DTYPES)
(s.dtype.is_integer() and dtype_keys.is_integer())
or (s.dtype.is_float() and dtype_keys.is_float())
or (s.dtype == Utf8 and dtype_keys == Categorical)
):
# Values Series and keys Series are of similar dtypes,
Expand Down
7 changes: 2 additions & 5 deletions py-polars/polars/io/spreadsheet/_write_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from polars.datatypes import (
FLOAT_DTYPES,
INTEGER_DTYPES,
NUMERIC_DTYPES,
Date,
Datetime,
Float64,
Expand Down Expand Up @@ -371,9 +370,7 @@ def _map_str(s: Series) -> Series:
if not row_totals:
row_total_funcs = {}
else:
numeric_cols = {
col for col, tp in df.schema.items() if tp.base_type() in NUMERIC_DTYPES
}
numeric_cols = {col for col, tp in df.schema.items() if tp.is_numeric()}
if not isinstance(row_totals, dict):
sum_cols = (
numeric_cols
Expand Down Expand Up @@ -450,7 +447,7 @@ def _map_str(s: Series) -> Series:
if base_type in dtype_formats:
fmt = dtype_formats.get(tp, dtype_formats[base_type])
column_formats.setdefault(col, fmt)
if base_type in NUMERIC_DTYPES:
if base_type.is_numeric():
if column_totals is True:
column_total_funcs.setdefault(col, "sum")
elif isinstance(column_totals, str):
Expand Down
Loading