Skip to content

Commit

Permalink
Simplify DecimalDtype and DecimalColumn operations (#18111)
Browse files Browse the repository at this point in the history
Broken off (the non-breaking parts) from #18035 as that PR will probably not move forward since it would require a pyarrow minimum version bump to 19

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #18111
  • Loading branch information
mroeschke authored Feb 27, 2025
1 parent aa7f436 commit 7713bc1
Show file tree
Hide file tree
Showing 5 changed files with 37 additions and 47 deletions.
1 change: 1 addition & 0 deletions docs/cudf/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,7 @@ def on_missing_reference(app, env, node, contnode):
("py:class", "pd.DataFrame"),
("py:class", "pandas.core.indexes.frozen.FrozenList"),
("py:class", "pa.Array"),
("py:class", "pa.Decimal128Type"),
("py:class", "ScalarLike"),
("py:class", "ParentType"),
("py:class", "pyarrow.lib.DataType"),
Expand Down
30 changes: 8 additions & 22 deletions python/cudf/cudf/core/column/decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import pylibcudf as plc

import cudf
from cudf.api.types import is_scalar
from cudf.core._internals import binaryop
from cudf.core.buffer import acquire_spill_lock, as_buffer
from cudf.core.column.column import ColumnBase
Expand Down Expand Up @@ -73,11 +72,8 @@ def __cuda_array_interface__(self):
def as_decimal_column(
self,
dtype: Dtype,
) -> "DecimalBaseColumn":
if (
isinstance(dtype, cudf.core.dtypes.DecimalDtype)
and dtype.scale < self.dtype.scale
):
) -> DecimalBaseColumn:
if isinstance(dtype, DecimalDtype) and dtype.scale < self.dtype.scale:
warnings.warn(
"cuDF truncates when downcasting decimals to a lower scale. "
"To round, use Series.round() or DataFrame.round()."
Expand Down Expand Up @@ -204,22 +200,17 @@ def normalize_binop_value(self, other) -> Self | cudf.Scalar:
other = other.astype(self.dtype)
return other
if isinstance(other, cudf.Scalar) and isinstance(
# TODO: Should it be possible to cast scalars of other numerical
# types to decimal?
other.dtype,
cudf.core.dtypes.DecimalDtype,
DecimalDtype,
):
# TODO: Should it be possible to cast scalars of other numerical
# types to decimal?
if _same_precision_and_scale(self.dtype, other.dtype):
other = other.astype(self.dtype)
return other
elif is_scalar(other) and isinstance(other, (int, Decimal)):
other = Decimal(other)
metadata = other.as_tuple()
precision = max(len(metadata.digits), metadata.exponent)
scale = -cast(int, metadata.exponent)
return cudf.Scalar(
other, dtype=self.dtype.__class__(precision, scale)
)
elif isinstance(other, (int, Decimal)):
dtype = self.dtype._from_decimal(Decimal(other))
return cudf.Scalar(other, dtype=dtype)
return NotImplemented

def as_numerical_column(
Expand Down Expand Up @@ -373,11 +364,6 @@ def __init__(
children=children,
)

def __setitem__(self, key, value):
if isinstance(value, np.integer):
value = int(value)
super().__setitem__(key, value)

@classmethod
def from_arrow(cls, data: pa.Array):
dtype = Decimal64Dtype.from_arrow(data.type)
Expand Down
4 changes: 3 additions & 1 deletion python/cudf/cudf/core/column/timedelta.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,9 @@ def total_seconds(self) -> ColumnBase:
# https://github.com/rapidsai/cudf/issues/17664
return (
(self.astype(np.dtype(np.int64)) * conversion)
.astype(cudf.Decimal128Dtype(38, 9))
.astype(
cudf.Decimal128Dtype(cudf.Decimal128Dtype.MAX_PRECISION, 9)
)
.round(decimals=abs(int(math.log10(conversion))))
.astype(np.dtype(np.float64))
)
Expand Down
43 changes: 22 additions & 21 deletions python/cudf/cudf/core/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -776,35 +776,36 @@ def _recursively_replace_fields(self, result: dict) -> dict:
class DecimalDtype(_BaseDtype):
_metadata = ("precision", "scale")

def __init__(self, precision, scale=0):
def __init__(self, precision: int, scale: int = 0) -> None:
self._validate(precision, scale)
self._typ = pa.decimal128(precision, scale)
self._precision = precision
self._scale = scale

@property
def str(self):
def str(self) -> str:
return f"{self.name!s}({self.precision}, {self.scale})"

@property
def precision(self):
def precision(self) -> int:
"""
The decimal precision, in number of decimal digits (an integer).
"""
return self._typ.precision
return self._precision

@precision.setter
def precision(self, value):
def precision(self, value: int) -> None:
self._validate(value, self.scale)
self._typ = pa.decimal128(precision=value, scale=self.scale)
self._precision = value

@property
def scale(self):
def scale(self) -> int:
"""
The decimal scale (an integer).
"""
return self._typ.scale
return self._scale

@property
def itemsize(self):
def itemsize(self) -> int:
"""
Length of one column element in bytes.
"""
Expand All @@ -815,14 +816,14 @@ def type(self):
# might need to account for precision and scale here
return decimal.Decimal

def to_arrow(self):
def to_arrow(self) -> pa.Decimal128Type:
"""
Return the equivalent ``pyarrow`` dtype.
"""
return self._typ
return pa.decimal128(self.precision, self.scale)

@classmethod
def from_arrow(cls, typ):
def from_arrow(cls, typ: pa.Decimal128Type) -> Self:
"""
Construct a cudf decimal dtype from a ``pyarrow`` dtype
Expand Down Expand Up @@ -856,23 +857,23 @@ def __repr__(self):
)

@classmethod
def _validate(cls, precision, scale=0):
def _validate(cls, precision: int, scale: int) -> None:
if precision > cls.MAX_PRECISION:
raise ValueError(
f"Cannot construct a {cls.__name__}"
f" with precision > {cls.MAX_PRECISION}"
)
if abs(scale) > precision:
raise ValueError(f"scale={scale} exceeds precision={precision}")
raise ValueError(f"{scale=} cannot exceed {precision=}")

@classmethod
def _from_decimal(cls, decimal):
def _from_decimal(cls, decimal: decimal.Decimal) -> Self:
"""
Create a cudf.DecimalDtype from a decimal.Decimal object
"""
metadata = decimal.as_tuple()
precision = max(len(metadata.digits), -metadata.exponent)
return cls(precision, -metadata.exponent)
precision = max(len(metadata.digits), -metadata.exponent) # type: ignore[operator]
return cls(precision, -metadata.exponent) # type: ignore[operator]

def serialize(self) -> tuple[dict, list]:
return (
Expand All @@ -885,7 +886,7 @@ def serialize(self) -> tuple[dict, list]:
)

@classmethod
def deserialize(cls, header: dict, frames: list):
def deserialize(cls, header: dict, frames: list) -> Self:
_check_type(cls, header, frames, is_valid_class=issubclass)
return cls(header["precision"], header["scale"])

Expand All @@ -896,8 +897,8 @@ def __eq__(self, other: Dtype) -> bool:
return False
return self.precision == other.precision and self.scale == other.scale

def __hash__(self):
return hash(self._typ)
def __hash__(self) -> int:
return hash(self.to_arrow())


@doc_apply(
Expand Down
6 changes: 3 additions & 3 deletions python/cudf/cudf/core/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ def _preprocess_host_value(value, dtype) -> tuple[ScalarLike, Dtype]:
return value.as_py(), dtype

if isinstance(dtype, cudf.core.dtypes.DecimalDtype):
value = pa.scalar(
value, type=pa.decimal128(dtype.precision, dtype.scale)
).as_py()
if isinstance(value, np.integer):
value = int(value)
value = pa.scalar(value, type=dtype.to_arrow()).as_py()
if isinstance(value, decimal.Decimal) and dtype is None:
dtype = cudf.Decimal128Dtype._from_decimal(value)

Expand Down

0 comments on commit 7713bc1

Please sign in to comment.