Skip to content

Commit

Permalink
Prune more seldom used dtype utils
Browse files Browse the repository at this point in the history
  • Loading branch information
mroeschke committed Mar 3, 2025
1 parent b6a6d39 commit 0537bbd
Show file tree
Hide file tree
Showing 7 changed files with 206 additions and 220 deletions.
3 changes: 1 addition & 2 deletions python/cudf/cudf/core/column/column.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
dtype_from_pylibcudf_column,
dtype_to_pylibcudf_type,
find_common_type,
get_time_unit,
is_column_like,
is_mixed_with_object_dtype,
min_signed_type,
Expand Down Expand Up @@ -2725,7 +2724,7 @@ def as_column(
nan_as_null=nan_as_null,
)
elif arbitrary.dtype.kind in "mM":
time_unit = get_time_unit(arbitrary)
time_unit = np.datetime_data(arbitrary.dtype)[0]
if time_unit in ("D", "W", "M", "Y"):
# TODO: Raise in these cases instead of downcasting to s?
new_type = f"{arbitrary.dtype.type.__name__}[s]"
Expand Down
23 changes: 18 additions & 5 deletions python/cudf/cudf/core/column/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,13 +178,27 @@ def infer_format(element: str, **kwargs) -> str:
return fmt


def _get_time_unit(obj: ColumnBinaryOperand) -> str:
    """Return the time resolution string ("s", "ms", "us", "ns", ...) of obj.

    Datetime/timedelta columns expose the unit directly via ``time_unit``;
    for any other operand the unit is read off its NumPy dtype.
    """
    column_types = (
        cudf.core.column.datetime.DatetimeColumn,
        cudf.core.column.timedelta.TimeDeltaColumn,
    )
    if isinstance(obj, column_types):
        return obj.time_unit
    # np.datetime_data returns (unit, count); only the unit is needed here.
    return np.datetime_data(obj.dtype)[0]


def _resolve_mixed_dtypes(
    lhs: ColumnBinaryOperand, rhs: ColumnBinaryOperand, base_type: str
) -> Dtype:
    """Combine two datetime/timedelta operands into a single result dtype.

    Picks the finer (higher-resolution) of the two operands' time units and
    returns ``base_type`` parameterized with that unit, e.g. ``"datetime64[ns]"``.
    """
    # Units ordered from coarsest to finest; a larger index means finer.
    units = ["s", "ms", "us", "ns"]
    finest = max(
        units.index(_get_time_unit(lhs)),
        units.index(_get_time_unit(rhs)),
    )
    return np.dtype(f"{base_type}[{units[finest]}]")

Expand Down Expand Up @@ -537,7 +551,7 @@ def normalize_binop_value( # type: ignore[override]

if isinstance(other, np.datetime64):
if np.isnat(other):
other_time_unit = cudf.utils.dtypes.get_time_unit(other)
other_time_unit = np.datetime_data(other.dtype)[0]
if other_time_unit not in {"s", "ms", "ns", "us"}:
other_time_unit = "ns"

Expand All @@ -548,8 +562,7 @@ def normalize_binop_value( # type: ignore[override]
other = other.astype(self.dtype)
return cudf.Scalar(other)
elif isinstance(other, np.timedelta64):
other_time_unit = cudf.utils.dtypes.get_time_unit(other)

other_time_unit = np.datetime_data(other.dtype)[0]
if np.isnat(other):
return cudf.Scalar(
None,
Expand Down
34 changes: 31 additions & 3 deletions python/cudf/cudf/core/column/numerical.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@
from cudf.utils.dtypes import (
CUDF_STRING_DTYPE,
find_common_type,
min_column_type,
min_signed_type,
min_unsigned_type,
np_dtypes_to_pandas_dtypes,
)

Expand Down Expand Up @@ -469,6 +469,34 @@ def _process_values_for_isin(
def _can_return_nan(self, skipna: bool | None = None) -> bool:
return not skipna and self.has_nulls(include_nan=True)

def _min_column_type(self, expected_type: np.dtype) -> np.dtype:
    """
    Return the smallest dtype which can represent all elements of self.

    ``expected_type`` steers which family of dtypes to narrow into
    (signed int, unsigned int, or float); when neither the column nor the
    expected type is numeric-narrowable, the column's own dtype is returned.
    """
    # An all-null column carries no values to narrow on.
    if self.null_count == len(self):
        return self.dtype

    lo, hi = self.min(), self.max()
    has_inf = np.isinf(lo) or np.isinf(hi)

    if not has_inf and expected_type.kind in "iu":
        # Narrow each bound independently, then promote to cover both.
        narrow = (
            min_signed_type if expected_type.kind == "i" else min_unsigned_type
        )
        return np.promote_types(narrow(hi), narrow(lo))

    if self.dtype.kind == "f" or expected_type.kind == "f":
        float_bounds = np.promote_types(
            np.min_scalar_type(float(hi)),
            np.min_scalar_type(float(lo)),
        )
        return np.promote_types(expected_type, float_bounds)

    return self.dtype

def find_and_replace(
self,
to_replace: ColumnLike,
Expand Down Expand Up @@ -762,8 +790,8 @@ def _normalize_find_and_replace_input(
normalized_column = normalized_column.astype(input_column_dtype)
if normalized_column.can_cast_safely(input_column_dtype):
return normalized_column.astype(input_column_dtype)
col_to_normalize_dtype = min_column_type(
normalized_column, input_column_dtype
col_to_normalize_dtype = normalized_column._min_column_type( # type: ignore[attr-defined]
input_column_dtype
)
# Scalar case
if len(col_to_normalize) == 1:
Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/column/timedelta.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def normalize_binop_value(self, other) -> ColumnBinaryOperand:
other = pd.Timedelta(other).to_timedelta64()

if isinstance(other, np.timedelta64):
other_time_unit = cudf.utils.dtypes.get_time_unit(other)
other_time_unit = np.datetime_data(other.dtype)[0]
if np.isnat(other):
return cudf.Scalar(
None,
Expand Down
4 changes: 2 additions & 2 deletions python/cudf/cudf/core/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6715,8 +6715,8 @@ def _apply_cupy_method_axis_1(self, method, *args, **kwargs):
prepared._data[col] = (
prepared._data[col]
.astype(
cudf.utils.dtypes.get_min_float_dtype(
prepared._data[col]
prepared._data[col]._min_column_type(
np.dtype(np.float32)
)
if common_dtype.kind != "M"
else np.dtype(np.float64)
Expand Down
154 changes: 153 additions & 1 deletion python/cudf/cudf/core/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
from cudf.core.missing import NA, NaT
from cudf.core.mixins import BinaryOperand
from cudf.utils.dtypes import (
CUDF_STRING_DTYPE,
cudf_dtype_from_pa_type,
get_allowed_combinations_for_operator,
to_cudf_compatible_scalar,
)

Expand All @@ -36,6 +36,158 @@
from cudf._typing import Dtype, ScalarLike


# Type dispatch loops similar to what are found in `np.add.types`
# In NumPy, whether or not an op can be performed between two
# operands is determined by checking to see if NumPy has a c/c++
# loop specifically for adding those two operands built in. If
# not it will search lists like these for a loop for types that
# the operands can be safely cast to. These are those lookups,
# modified slightly for cuDF's rules
_ADD_TYPES = [
"???",
"BBB",
"HHH",
"III",
"LLL",
"bbb",
"hhh",
"iii",
"lll",
"fff",
"ddd",
"mMM",
"MmM",
"mmm",
"LMM",
"MLM",
"Lmm",
"mLm",
]
_SUB_TYPES = [
"BBB",
"HHH",
"III",
"LLL",
"bbb",
"hhh",
"iii",
"lll",
"fff",
"ddd",
"???",
"MMm",
"mmm",
"MmM",
"MLM",
"mLm",
"Lmm",
]
_MUL_TYPES = [
"???",
"BBB",
"HHH",
"III",
"LLL",
"bbb",
"hhh",
"iii",
"lll",
"fff",
"ddd",
"mLm",
"Lmm",
"mlm",
"lmm",
]
_FLOORDIV_TYPES = [
"bbb",
"BBB",
"HHH",
"III",
"LLL",
"hhh",
"iii",
"lll",
"fff",
"ddd",
"???",
"mqm",
"mdm",
"mmq",
]
_TRUEDIV_TYPES = ["fff", "ddd", "mqm", "mmd", "mLm"]
_MOD_TYPES = [
"bbb",
"BBB",
"hhh",
"HHH",
"iii",
"III",
"lll",
"LLL",
"fff",
"ddd",
"mmm",
]
_POW_TYPES = [
"bbb",
"BBB",
"hhh",
"HHH",
"iii",
"III",
"lll",
"LLL",
"fff",
"ddd",
]


def get_allowed_combinations_for_operator(
    dtype_l: np.dtype, dtype_r: np.dtype, op: str
) -> np.dtype:
    """Return the result dtype of applying ``op`` between two scalar dtypes.

    Mirrors NumPy's per-ufunc type-dispatch tables (e.g. ``np.add.types``):
    scan the op's allowed three-char signatures and return the output dtype
    of the first signature both operand dtypes can safely cast to.

    Parameters
    ----------
    dtype_l, dtype_r : np.dtype
        The operand dtypes.
    op : str
        Dunder operator name, e.g. ``"__add__"``.

    Returns
    -------
    np.dtype
        The dtype of the operation's result.

    Raises
    ------
    TypeError
        If ``op`` is not a supported operator, or no valid signature exists
        for the operand dtypes.
    """
    error = TypeError(
        f"{op} not supported between {dtype_l} and {dtype_r} scalars"
    )

    to_numpy_ops = {
        "__add__": _ADD_TYPES,
        "__radd__": _ADD_TYPES,
        "__sub__": _SUB_TYPES,
        "__rsub__": _SUB_TYPES,
        "__mul__": _MUL_TYPES,
        "__rmul__": _MUL_TYPES,
        "__floordiv__": _FLOORDIV_TYPES,
        "__rfloordiv__": _FLOORDIV_TYPES,
        "__truediv__": _TRUEDIV_TYPES,
        "__rtruediv__": _TRUEDIV_TYPES,
        "__mod__": _MOD_TYPES,
        "__rmod__": _MOD_TYPES,
        "__pow__": _POW_TYPES,
        "__rpow__": _POW_TYPES,
    }
    try:
        allowed = to_numpy_ops[op]
    except KeyError:
        # Previously this fell back to iterating the op *string* itself,
        # which failed later with a confusing unpacking ValueError. Raise
        # the informative TypeError for unknown operators instead.
        raise error from None

    # special rules for string
    if dtype_l == "object" or dtype_r == "object":
        if (dtype_l == dtype_r == "object") and op == "__add__":
            return CUDF_STRING_DTYPE
        else:
            raise error

    # Check if we can directly operate: each signature is a 3-char string
    # (left, right, out) of NumPy dtype character codes.
    for ltype, rtype, outtype in allowed:
        if np.can_cast(dtype_l.char, ltype) and np.can_cast(
            dtype_r.char, rtype
        ):
            return np.dtype(outtype)

    raise error


def _preprocess_host_value(value, dtype) -> tuple[ScalarLike, Dtype]:
"""
Preprocess a value and dtype for host-side cudf.Scalar
Expand Down
Loading

0 comments on commit 0537bbd

Please sign in to comment.