Skip to content

Commit

Permalink
feat(python): Add DataType.is_integer and other dtype groups (#12200)
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego authored Nov 14, 2023
1 parent 13f2a7b commit 8af2048
Show file tree
Hide file tree
Showing 15 changed files with 251 additions and 186 deletions.
7 changes: 3 additions & 4 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
from polars.dataframe._html import NotebookFormatter
from polars.dataframe.group_by import DynamicGroupBy, GroupBy, RollingGroupBy
from polars.datatypes import (
FLOAT_DTYPES,
INTEGER_DTYPES,
N_INFER_DEFAULT,
NUMERIC_DTYPES,
Expand Down Expand Up @@ -1419,13 +1418,13 @@ def _div(self, other: Any, *, floordiv: bool) -> DataFrame:
df = (
df
if not floordiv
else df.with_columns([s.floor() for s in df if s.dtype in FLOAT_DTYPES])
else df.with_columns([s.floor() for s in df if s.dtype.is_float()])
)
if floordiv:
int_casts = [
col(column).cast(tp)
for i, (column, tp) in enumerate(self.schema.items())
if tp in INTEGER_DTYPES and orig_dtypes[i] in INTEGER_DTYPES
if tp.is_integer() and orig_dtypes[i].is_integer()
]
if int_casts:
return df.with_columns(int_casts)
Expand Down Expand Up @@ -1711,7 +1710,7 @@ def __getitem__(
dtype = item.dtype
if dtype == Utf8:
return self._from_pydf(self._df.select(item))
elif dtype in INTEGER_DTYPES:
elif dtype.is_integer():
return self._take_with_series(item._pos_idxs(self.shape[0]))

# if no data has been returned, the operation is not supported
Expand Down
4 changes: 0 additions & 4 deletions py-polars/polars/datatypes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@
Field,
Float32,
Float64,
FractionalType,
Int8,
Int16,
Int32,
Int64,
IntegerType,
List,
Null,
NumericType,
Object,
Struct,
TemporalType,
Expand Down Expand Up @@ -89,15 +87,13 @@
"Field",
"Float32",
"Float64",
"FractionalType",
"Int16",
"Int32",
"Int64",
"Int8",
"IntegerType",
"List",
"Null",
"NumericType",
"Object",
"Struct",
"TemporalType",
Expand Down
84 changes: 71 additions & 13 deletions py-polars/polars/datatypes/classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,30 @@ def is_not(cls, other: PolarsDataType) -> bool: # noqa: D102
def is_nested(self) -> bool: # noqa: D102
...

@classmethod
def is_numeric(cls) -> bool: # noqa: D102
...

@classmethod
def is_integer(cls) -> bool: # noqa: D102
...

@classmethod
def is_signed_integer(cls) -> bool: # noqa: D102
...

@classmethod
def is_unsigned_integer(cls) -> bool: # noqa: D102
...

@classmethod
def is_float(cls) -> bool: # noqa: D102
...

@classmethod
def is_temporal(cls) -> bool: # noqa: D102
...


class DataType(metaclass=DataTypeClass):
"""Base class for all Polars data types."""
Expand Down Expand Up @@ -161,6 +185,36 @@ def is_nested(self) -> bool:
issue_deprecation_warning(message, version="0.19.10")
return False

@classmethod
def is_numeric(cls) -> bool:
"""Check whether the data type is a numeric type."""
return issubclass(cls, NumericType)

@classmethod
def is_integer(cls) -> bool:
"""Check whether the data type is an integer type."""
return issubclass(cls, IntegerType)

@classmethod
def is_signed_integer(cls) -> bool:
"""Check whether the data type is a signed integer type."""
return issubclass(cls, SignedIntegerType)

@classmethod
def is_unsigned_integer(cls) -> bool:
"""Check whether the data type is an unsigned integer type."""
return issubclass(cls, UnsignedIntegerType)

@classmethod
def is_float(cls) -> bool:
"""Check whether the data type is a temporal type."""
return issubclass(cls, FloatType)

@classmethod
def is_temporal(cls) -> bool:
"""Check whether the data type is a temporal type."""
return issubclass(cls, TemporalType)


def _custom_reconstruct(
cls: type[Any], base: type[Any], state: Any
Expand Down Expand Up @@ -214,14 +268,18 @@ class NumericType(DataType):


class IntegerType(NumericType):
"""Base class for integral data types."""
"""Base class for integer data types."""


class SignedIntegerType(IntegerType):
"""Base class for signed integer data types."""


class FractionalType(NumericType):
"""Base class for fractional data types."""
class UnsignedIntegerType(IntegerType):
"""Base class for unsigned integer data types."""


class FloatType(FractionalType):
class FloatType(NumericType):
"""Base class for float data types."""


Expand Down Expand Up @@ -252,35 +310,35 @@ def is_nested(self) -> bool:
return True


class Int8(IntegerType):
class Int8(SignedIntegerType):
"""8-bit signed integer type."""


class Int16(IntegerType):
class Int16(SignedIntegerType):
"""16-bit signed integer type."""


class Int32(IntegerType):
class Int32(SignedIntegerType):
"""32-bit signed integer type."""


class Int64(IntegerType):
class Int64(SignedIntegerType):
"""64-bit signed integer type."""


class UInt8(IntegerType):
class UInt8(UnsignedIntegerType):
"""8-bit unsigned integer type."""


class UInt16(IntegerType):
class UInt16(UnsignedIntegerType):
"""16-bit unsigned integer type."""


class UInt32(IntegerType):
class UInt32(UnsignedIntegerType):
"""32-bit unsigned integer type."""


class UInt64(IntegerType):
class UInt64(UnsignedIntegerType):
"""64-bit unsigned integer type."""


Expand All @@ -292,7 +350,7 @@ class Float64(FloatType):
"""64-bit floating point type."""


class Decimal(FractionalType):
class Decimal(NumericType):
"""
Decimal 128-bit type with an optional precision and non-negative scale.
Expand Down
6 changes: 2 additions & 4 deletions py-polars/polars/expr/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@
import polars._reexport as pl
from polars import functions as F
from polars.datatypes import (
FLOAT_DTYPES,
INTEGER_DTYPES,
Categorical,
Null,
Struct,
Expand Down Expand Up @@ -9196,8 +9194,8 @@ def _remap_key_or_value_series(
# Values Series has same dtype as keys Series.
dtype = s.dtype
elif (
(s.dtype in INTEGER_DTYPES and dtype_keys in INTEGER_DTYPES)
or (s.dtype in FLOAT_DTYPES and dtype_keys in FLOAT_DTYPES)
(s.dtype.is_integer() and dtype_keys.is_integer())
or (s.dtype.is_float() and dtype_keys.is_float())
or (s.dtype == Utf8 and dtype_keys == Categorical)
):
# Values Series and keys Series are of similar dtypes,
Expand Down
7 changes: 2 additions & 5 deletions py-polars/polars/io/spreadsheet/_write_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from polars.datatypes import (
FLOAT_DTYPES,
INTEGER_DTYPES,
NUMERIC_DTYPES,
Date,
Datetime,
Float64,
Expand Down Expand Up @@ -371,9 +370,7 @@ def _map_str(s: Series) -> Series:
if not row_totals:
row_total_funcs = {}
else:
numeric_cols = {
col for col, tp in df.schema.items() if tp.base_type() in NUMERIC_DTYPES
}
numeric_cols = {col for col, tp in df.schema.items() if tp.is_numeric()}
if not isinstance(row_totals, dict):
sum_cols = (
numeric_cols
Expand Down Expand Up @@ -450,7 +447,7 @@ def _map_str(s: Series) -> Series:
if base_type in dtype_formats:
fmt = dtype_formats.get(tp, dtype_formats[base_type])
column_formats.setdefault(col, fmt)
if base_type in NUMERIC_DTYPES:
if base_type.is_numeric():
if column_totals is True:
column_total_funcs.setdefault(col, "sum")
elif isinstance(column_totals, str):
Expand Down
Loading

0 comments on commit 8af2048

Please sign in to comment.