From 45d80669367c6bf3b9dc0cd122f0ea36072cb7ea Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Mon, 3 Mar 2025 21:25:11 -0800 Subject: [PATCH] Remove cudf.Scalar from shift/fillna (#17922) Toward https://github.com/rapidsai/cudf/issues/17843 Authors: - Matthew Roeschke (https://github.com/mroeschke) - Vyas Ramasubramani (https://github.com/vyasr) Approvers: - Vyas Ramasubramani (https://github.com/vyasr) URL: https://github.com/rapidsai/cudf/pull/17922 --- python/cudf/cudf/core/column/categorical.py | 9 +++++-- python/cudf/cudf/core/column/column.py | 23 +++++++++++------ python/cudf/cudf/core/column/datetime.py | 14 +++++++++++ python/cudf/cudf/core/column/decimal.py | 28 ++++++++++++++++++--- python/cudf/cudf/core/column/numerical.py | 13 +++++++--- python/cudf/cudf/core/column/timedelta.py | 15 +++++++++++ 6 files changed, 84 insertions(+), 18 deletions(-) diff --git a/python/cudf/cudf/core/column/categorical.py b/python/cudf/cudf/core/column/categorical.py index d41e448254c..c75d285e7de 100644 --- a/python/cudf/cudf/core/column/categorical.py +++ b/python/cudf/cudf/core/column/categorical.py @@ -20,6 +20,7 @@ from cudf.core.scalar import pa_scalar_to_plc_scalar from cudf.utils.dtypes import ( SIZE_TYPE_DTYPE, + cudf_dtype_to_pa_type, find_common_type, is_mixed_with_object_dtype, min_signed_type, @@ -1042,7 +1043,7 @@ def notnull(self) -> ColumnBase: def _validate_fillna_value( self, fill_value: ScalarLike | ColumnLike - ) -> cudf.Scalar | ColumnBase: + ) -> plc.Scalar | ColumnBase: """Align fill_value for .fillna based on column type.""" if cudf.api.types.is_scalar(fill_value): if fill_value != _DEFAULT_CATEGORICAL_VALUE: @@ -1052,7 +1053,11 @@ def _validate_fillna_value( raise ValueError( f"{fill_value=} must be in categories" ) from err - return cudf.Scalar(fill_value, dtype=self.codes.dtype) + return pa_scalar_to_plc_scalar( + pa.scalar( + fill_value, type=cudf_dtype_to_pa_type(self.codes.dtype) + ) + ) else: fill_value = column.as_column(fill_value, nan_as_null=False) if isinstance(fill_value.dtype, CategoricalDtype): diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index 61f4f7d52fb..0d36fd3855b 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -891,12 +891,11 @@ def _fill( @acquire_spill_lock() def shift(self, offset: int, fill_value: ScalarLike) -> Self: - if not isinstance(fill_value, cudf.Scalar): - fill_value = cudf.Scalar(fill_value, dtype=self.dtype) + plc_fill_value = self._scalar_to_plc_scalar(fill_value) plc_col = plc.copying.shift( self.to_pylibcudf(mode="read"), offset, - fill_value.device_value, + plc_fill_value, ) return type(self).from_pylibcudf(plc_col) # type: ignore[return-value] @@ -1188,13 +1187,21 @@ def _check_scatter_key_length( f"{num_keys}" ) + def _scalar_to_plc_scalar(self, scalar: ScalarLike) -> plc.Scalar: + """Return a pylibcudf.Scalar that matches the type of self.dtype""" + if not isinstance(scalar, pa.Scalar): + scalar = pa.scalar(scalar) + return pa_scalar_to_plc_scalar( + scalar.cast(cudf_dtype_to_pa_type(self.dtype)) + ) + def _validate_fillna_value( self, fill_value: ScalarLike | ColumnLike - ) -> cudf.Scalar | ColumnBase: + ) -> plc.Scalar | ColumnBase: """Align fill_value for .fillna based on column type.""" if is_scalar(fill_value): - return cudf.Scalar(fill_value, dtype=self.dtype) - return as_column(fill_value) + return self._scalar_to_plc_scalar(fill_value) + return as_column(fill_value).astype(self.dtype) @acquire_spill_lock() def replace( @@ -1240,8 +1247,8 @@ def fillna( if method == "ffill" else plc.replace.ReplacePolicy.FOLLOWING ) - elif is_scalar(fill_value): - plc_replace = cudf.Scalar(fill_value).device_value + elif isinstance(fill_value, plc.Scalar): + plc_replace = fill_value else: plc_replace = fill_value.to_pylibcudf(mode="read") plc_column = plc.replace.replace_nulls( diff --git a/python/cudf/cudf/core/column/datetime.py b/python/cudf/cudf/core/column/datetime.py index 213e91d7b3f..64ddcae72a7 100644 --- a/python/cudf/cudf/core/column/datetime.py +++ b/python/cudf/cudf/core/column/datetime.py @@ -45,6 +45,7 @@ from cudf._typing import ( ColumnBinaryOperand, + ColumnLike, DatetimeLikeScalar, Dtype, DtypeObj, @@ -269,6 +270,19 @@ def __contains__(self, item: ScalarLike) -> bool: "cudf.core.column.NumericalColumn", self.astype(np.dtype(np.int64)) ) + def _validate_fillna_value( + self, fill_value: ScalarLike | ColumnLike + ) -> plc.Scalar | ColumnBase: + """Align fill_value for .fillna based on column type.""" + if ( + isinstance(fill_value, np.datetime64) + and self.time_unit != np.datetime_data(fill_value)[0] + ): + fill_value = fill_value.astype(self.dtype) + elif isinstance(fill_value, str) and fill_value.lower() == "nat": + fill_value = np.datetime64(fill_value, self.time_unit) + return super()._validate_fillna_value(fill_value) + @functools.cached_property def time_unit(self) -> str: return np.datetime_data(self.dtype)[0] diff --git a/python/cudf/cudf/core/column/decimal.py b/python/cudf/cudf/core/column/decimal.py index 8db6f805bce..848faf6a9ee 100644 --- a/python/cudf/cudf/core/column/decimal.py +++ b/python/cudf/cudf/core/column/decimal.py @@ -24,7 +24,8 @@ DecimalDtype, ) from cudf.core.mixins import BinaryOperand -from cudf.utils.dtypes import CUDF_STRING_DTYPE +from cudf.core.scalar import pa_scalar_to_plc_scalar +from cudf.utils.dtypes import CUDF_STRING_DTYPE, cudf_dtype_to_pa_type from cudf.utils.utils import pa_mask_buffer_to_mask if TYPE_CHECKING: @@ -165,16 +166,35 @@ def _binaryop(self, other: ColumnBinaryOperand, op: str): return result + def _scalar_to_plc_scalar(self, scalar: ScalarLike) -> plc.Scalar: + """Return a pylibcudf.Scalar that matches the type of self.dtype""" + if not isinstance(scalar, pa.Scalar): + # e.g casting int to decimal type isn't allow, but OK in the constructor? + pa_scalar = pa.scalar( + scalar, type=cudf_dtype_to_pa_type(self.dtype) + ) + else: + pa_scalar = scalar.cast(cudf_dtype_to_pa_type(self.dtype)) + plc_scalar = pa_scalar_to_plc_scalar(pa_scalar) + if isinstance(self.dtype, (Decimal32Dtype, Decimal64Dtype)): + # pyarrow.Scalar only supports Decimal128 so conversion + # from pyarrow would only return a pylibcudf.Scalar with Decimal128 + col = ColumnBase.from_pylibcudf( + plc.Column.from_scalar(plc_scalar, 1) + ).astype(self.dtype) + return plc.copying.get_element(col.to_pylibcudf(mode="read"), 0) + return plc_scalar + def _validate_fillna_value( self, fill_value: ScalarLike | ColumnLike - ) -> cudf.Scalar | ColumnBase: + ) -> plc.Scalar | ColumnBase: """Align fill_value for .fillna based on column type.""" if isinstance(fill_value, (int, Decimal)): - return cudf.Scalar(fill_value, dtype=self.dtype) + return super()._validate_fillna_value(fill_value) elif isinstance(fill_value, ColumnBase) and ( isinstance(self.dtype, DecimalDtype) or self.dtype.kind in "iu" ): - return fill_value.astype(self.dtype) + return super()._validate_fillna_value(fill_value) raise TypeError( "Decimal columns only support using fillna with decimal and " "integer values" diff --git a/python/cudf/cudf/core/column/numerical.py b/python/cudf/cudf/core/column/numerical.py index eecb294acee..77c5a6b6caf 100644 --- a/python/cudf/cudf/core/column/numerical.py +++ b/python/cudf/cudf/core/column/numerical.py @@ -559,15 +559,20 @@ def find_and_replace( def _validate_fillna_value( self, fill_value: ScalarLike | ColumnLike - ) -> cudf.Scalar | ColumnBase: + ) -> plc.Scalar | ColumnBase: """Align fill_value for .fillna based on column type.""" if is_scalar(fill_value): - cudf_obj: cudf.Scalar | ColumnBase = cudf.Scalar(fill_value) - if not as_column(cudf_obj).can_cast_safely(self.dtype): + cudf_obj = ColumnBase.from_pylibcudf( + plc.Column.from_scalar( + pa_scalar_to_plc_scalar(pa.scalar(fill_value)), 1 + ) + ) + if not cudf_obj.can_cast_safely(self.dtype): raise TypeError( f"Cannot safely cast non-equivalent " f"{type(fill_value).__name__} to {self.dtype.name}" ) + return super()._validate_fillna_value(fill_value) else: cudf_obj = as_column(fill_value, nan_as_null=False) if not cudf_obj.can_cast_safely(self.dtype): # type: ignore[attr-defined] @@ -576,7 +581,7 @@ def _validate_fillna_value( f"{cudf_obj.dtype.type.__name__} to " f"{self.dtype.type.__name__}" ) - return cudf_obj.astype(self.dtype) + return cudf_obj.astype(self.dtype) def can_cast_safely(self, to_dtype: DtypeObj) -> bool: """ diff --git a/python/cudf/cudf/core/column/timedelta.py b/python/cudf/cudf/core/column/timedelta.py index e4d47f492c2..654d2c2b800 100644 --- a/python/cudf/cudf/core/column/timedelta.py +++ b/python/cudf/cudf/core/column/timedelta.py @@ -30,9 +30,11 @@ from cudf._typing import ( ColumnBinaryOperand, + ColumnLike, DatetimeLikeScalar, Dtype, DtypeObj, + ScalarLike, ) _unit_to_nanoseconds_conversion = { @@ -142,6 +144,19 @@ def __contains__(self, item: DatetimeLikeScalar) -> bool: "cudf.core.column.NumericalColumn", self.astype(np.dtype(np.int64)) ) + def _validate_fillna_value( + self, fill_value: ScalarLike | ColumnLike + ) -> plc.Scalar | ColumnBase: + """Align fill_value for .fillna based on column type.""" + if ( + isinstance(fill_value, np.timedelta64) + and self.time_unit != np.datetime_data(fill_value)[0] + ): + fill_value = fill_value.astype(self.dtype) + elif isinstance(fill_value, str) and fill_value.lower() == "nat": + fill_value = np.timedelta64(fill_value, self.time_unit) + return super()._validate_fillna_value(fill_value) + @property def values(self): """