Skip to content

Commit

Permalink
Remove cudf.Scalar from shift/fillna (#17922)
Browse files Browse the repository at this point in the history
Toward #17843

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)
  - Vyas Ramasubramani (https://github.com/vyasr)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #17922
  • Loading branch information
mroeschke authored Mar 4, 2025
1 parent 45bd05d commit 45d8066
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 18 deletions.
9 changes: 7 additions & 2 deletions python/cudf/cudf/core/column/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from cudf.core.scalar import pa_scalar_to_plc_scalar
from cudf.utils.dtypes import (
SIZE_TYPE_DTYPE,
cudf_dtype_to_pa_type,
find_common_type,
is_mixed_with_object_dtype,
min_signed_type,
Expand Down Expand Up @@ -1042,7 +1043,7 @@ def notnull(self) -> ColumnBase:

def _validate_fillna_value(
self, fill_value: ScalarLike | ColumnLike
) -> cudf.Scalar | ColumnBase:
) -> plc.Scalar | ColumnBase:
"""Align fill_value for .fillna based on column type."""
if cudf.api.types.is_scalar(fill_value):
if fill_value != _DEFAULT_CATEGORICAL_VALUE:
Expand All @@ -1052,7 +1053,11 @@ def _validate_fillna_value(
raise ValueError(
f"{fill_value=} must be in categories"
) from err
return cudf.Scalar(fill_value, dtype=self.codes.dtype)
return pa_scalar_to_plc_scalar(
pa.scalar(
fill_value, type=cudf_dtype_to_pa_type(self.codes.dtype)
)
)
else:
fill_value = column.as_column(fill_value, nan_as_null=False)
if isinstance(fill_value.dtype, CategoricalDtype):
Expand Down
23 changes: 15 additions & 8 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -891,12 +891,11 @@ def _fill(

@acquire_spill_lock()
def shift(self, offset: int, fill_value: ScalarLike) -> Self:
    """Return a new column with the values shifted by ``offset`` positions.

    Parameters
    ----------
    offset
        Number of positions to shift by; forwarded unchanged to
        ``plc.copying.shift``.
    fill_value
        Host scalar handed to ``plc.copying.shift`` as its fill argument.
        It is first converted to a ``pylibcudf.Scalar`` via
        ``self._scalar_to_plc_scalar`` so that its type matches
        ``self.dtype``.

    Returns
    -------
    Self
        A column of the same Python type as ``self`` wrapping the shifted
        pylibcudf column.
    """
    # Convert once, up front: no intermediate cudf.Scalar is created and
    # exactly one fill argument is passed to pylibcudf.
    plc_fill_value = self._scalar_to_plc_scalar(fill_value)
    plc_col = plc.copying.shift(
        self.to_pylibcudf(mode="read"),
        offset,
        plc_fill_value,
    )
    return type(self).from_pylibcudf(plc_col)  # type: ignore[return-value]

Expand Down Expand Up @@ -1188,13 +1187,21 @@ def _check_scatter_key_length(
f"{num_keys}"
)

def _scalar_to_plc_scalar(self, scalar: ScalarLike) -> plc.Scalar:
    """Convert ``scalar`` to a pylibcudf.Scalar typed like ``self.dtype``.

    Anything that is not already a ``pyarrow.Scalar`` is first wrapped with
    ``pa.scalar`` (letting pyarrow infer a type), and the result is then
    cast to the pyarrow type corresponding to ``self.dtype`` before being
    handed to ``pa_scalar_to_plc_scalar``.
    """
    pa_scalar = scalar if isinstance(scalar, pa.Scalar) else pa.scalar(scalar)
    target_type = cudf_dtype_to_pa_type(self.dtype)
    return pa_scalar_to_plc_scalar(pa_scalar.cast(target_type))

def _validate_fillna_value(
    self, fill_value: ScalarLike | ColumnLike
) -> plc.Scalar | ColumnBase:
    """Align fill_value for .fillna based on column type.

    Parameters
    ----------
    fill_value
        Either a scalar or a column-like replacement for nulls.

    Returns
    -------
    plc.Scalar or ColumnBase
        For scalar input, a ``pylibcudf.Scalar`` produced by
        ``self._scalar_to_plc_scalar``. Otherwise the input converted with
        ``as_column`` and cast to ``self.dtype``.
    """
    if is_scalar(fill_value):
        return self._scalar_to_plc_scalar(fill_value)
    # Cast column-like input so the replacement matches this column's dtype.
    return as_column(fill_value).astype(self.dtype)

@acquire_spill_lock()
def replace(
Expand Down Expand Up @@ -1240,8 +1247,8 @@ def fillna(
if method == "ffill"
else plc.replace.ReplacePolicy.FOLLOWING
)
elif is_scalar(fill_value):
plc_replace = cudf.Scalar(fill_value).device_value
elif isinstance(fill_value, plc.Scalar):
plc_replace = fill_value
else:
plc_replace = fill_value.to_pylibcudf(mode="read")
plc_column = plc.replace.replace_nulls(
Expand Down
14 changes: 14 additions & 0 deletions python/cudf/cudf/core/column/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@

from cudf._typing import (
ColumnBinaryOperand,
ColumnLike,
DatetimeLikeScalar,
Dtype,
DtypeObj,
Expand Down Expand Up @@ -269,6 +270,19 @@ def __contains__(self, item: ScalarLike) -> bool:
"cudf.core.column.NumericalColumn", self.astype(np.dtype(np.int64))
)

def _validate_fillna_value(
    self, fill_value: ScalarLike | ColumnLike
) -> plc.Scalar | ColumnBase:
    """Normalize datetime-like scalars, then defer to the parent class.

    * A ``np.datetime64`` whose unit differs from ``self.time_unit`` is
      converted with ``astype(self.dtype)``.
    * The string ``"nat"`` (compared case-insensitively) becomes a
      ``np.datetime64`` built with this column's time unit.

    Every other input is passed through to the parent implementation
    unchanged.
    """
    if isinstance(fill_value, np.datetime64):
        if self.time_unit != np.datetime_data(fill_value)[0]:
            fill_value = fill_value.astype(self.dtype)
    elif isinstance(fill_value, str) and fill_value.lower() == "nat":
        fill_value = np.datetime64(fill_value, self.time_unit)
    return super()._validate_fillna_value(fill_value)

@functools.cached_property
def time_unit(self) -> str:
    """Unit code of ``self.dtype`` as reported by ``np.datetime_data``."""
    unit, _ = np.datetime_data(self.dtype)
    return unit
Expand Down
28 changes: 24 additions & 4 deletions python/cudf/cudf/core/column/decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@
DecimalDtype,
)
from cudf.core.mixins import BinaryOperand
from cudf.utils.dtypes import CUDF_STRING_DTYPE
from cudf.core.scalar import pa_scalar_to_plc_scalar
from cudf.utils.dtypes import CUDF_STRING_DTYPE, cudf_dtype_to_pa_type
from cudf.utils.utils import pa_mask_buffer_to_mask

if TYPE_CHECKING:
Expand Down Expand Up @@ -165,16 +166,35 @@ def _binaryop(self, other: ColumnBinaryOperand, op: str):

return result

def _scalar_to_plc_scalar(self, scalar: ScalarLike) -> plc.Scalar:
    """Convert ``scalar`` to a pylibcudf.Scalar typed like ``self.dtype``.

    An existing ``pyarrow.Scalar`` is cast to the pyarrow type matching
    ``self.dtype``; any other value is passed to ``pa.scalar`` together
    with that target type.
    """
    target_type = cudf_dtype_to_pa_type(self.dtype)
    if isinstance(scalar, pa.Scalar):
        pa_scalar = scalar.cast(target_type)
    else:
        # Per the original author's note (not verified here): casting e.g.
        # an int scalar to a decimal type isn't allowed by pyarrow, while
        # constructing the scalar directly with the target type is.
        pa_scalar = pa.scalar(scalar, type=target_type)

    plc_scalar = pa_scalar_to_plc_scalar(pa_scalar)
    if not isinstance(self.dtype, (Decimal32Dtype, Decimal64Dtype)):
        return plc_scalar

    # pyarrow.Scalar only supports Decimal128, so the conversion above can
    # only yield a Decimal128 pylibcudf.Scalar. Round-trip through a
    # one-row column to cast it to the 32/64-bit dtype and read it back.
    one_row = plc.Column.from_scalar(plc_scalar, 1)
    col = ColumnBase.from_pylibcudf(one_row).astype(self.dtype)
    return plc.copying.get_element(col.to_pylibcudf(mode="read"), 0)

def _validate_fillna_value(
self, fill_value: ScalarLike | ColumnLike
) -> cudf.Scalar | ColumnBase:
) -> plc.Scalar | ColumnBase:
"""Align fill_value for .fillna based on column type."""
if isinstance(fill_value, (int, Decimal)):
return cudf.Scalar(fill_value, dtype=self.dtype)
return super()._validate_fillna_value(fill_value)
elif isinstance(fill_value, ColumnBase) and (
isinstance(self.dtype, DecimalDtype) or self.dtype.kind in "iu"
):
return fill_value.astype(self.dtype)
return super()._validate_fillna_value(fill_value)
raise TypeError(
"Decimal columns only support using fillna with decimal and "
"integer values"
Expand Down
13 changes: 9 additions & 4 deletions python/cudf/cudf/core/column/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,15 +559,20 @@ def find_and_replace(

def _validate_fillna_value(
self, fill_value: ScalarLike | ColumnLike
) -> cudf.Scalar | ColumnBase:
) -> plc.Scalar | ColumnBase:
"""Align fill_value for .fillna based on column type."""
if is_scalar(fill_value):
cudf_obj: cudf.Scalar | ColumnBase = cudf.Scalar(fill_value)
if not as_column(cudf_obj).can_cast_safely(self.dtype):
cudf_obj = ColumnBase.from_pylibcudf(
plc.Column.from_scalar(
pa_scalar_to_plc_scalar(pa.scalar(fill_value)), 1
)
)
if not cudf_obj.can_cast_safely(self.dtype):
raise TypeError(
f"Cannot safely cast non-equivalent "
f"{type(fill_value).__name__} to {self.dtype.name}"
)
return super()._validate_fillna_value(fill_value)
else:
cudf_obj = as_column(fill_value, nan_as_null=False)
if not cudf_obj.can_cast_safely(self.dtype): # type: ignore[attr-defined]
Expand All @@ -576,7 +581,7 @@ def _validate_fillna_value(
f"{cudf_obj.dtype.type.__name__} to "
f"{self.dtype.type.__name__}"
)
return cudf_obj.astype(self.dtype)
return cudf_obj.astype(self.dtype)

def can_cast_safely(self, to_dtype: DtypeObj) -> bool:
"""
Expand Down
15 changes: 15 additions & 0 deletions python/cudf/cudf/core/column/timedelta.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@

from cudf._typing import (
ColumnBinaryOperand,
ColumnLike,
DatetimeLikeScalar,
Dtype,
DtypeObj,
ScalarLike,
)

_unit_to_nanoseconds_conversion = {
Expand Down Expand Up @@ -142,6 +144,19 @@ def __contains__(self, item: DatetimeLikeScalar) -> bool:
"cudf.core.column.NumericalColumn", self.astype(np.dtype(np.int64))
)

def _validate_fillna_value(
    self, fill_value: ScalarLike | ColumnLike
) -> plc.Scalar | ColumnBase:
    """Normalize timedelta-like scalars, then defer to the parent class.

    * A ``np.timedelta64`` whose unit differs from ``self.time_unit`` is
      converted with ``astype(self.dtype)``.
    * The string ``"nat"`` (compared case-insensitively) becomes a
      ``np.timedelta64`` built with this column's time unit.

    Every other input is passed through to the parent implementation
    unchanged.
    """
    if isinstance(fill_value, np.timedelta64):
        if self.time_unit != np.datetime_data(fill_value)[0]:
            fill_value = fill_value.astype(self.dtype)
    elif isinstance(fill_value, str) and fill_value.lower() == "nat":
        fill_value = np.timedelta64(fill_value, self.time_unit)
    return super()._validate_fillna_value(fill_value)

@property
def values(self):
"""
Expand Down

0 comments on commit 45d8066

Please sign in to comment.