More avoid cudf.dtype internally in favor of pre-defined, supported types (#17918)

Continuation of #17839

Authors:
  - Matthew Roeschke (https://github.com/mroeschke)

Approvers:
  - Vyas Ramasubramani (https://github.com/vyasr)

URL: #17918
mroeschke authored Feb 8, 2025
1 parent 61e47bb commit 428dc18
Showing 7 changed files with 72 additions and 68 deletions.
5 changes: 2 additions & 3 deletions python/cudf/cudf/core/buffer/buffer.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2024, NVIDIA CORPORATION.
+# Copyright (c) 2020-2025, NVIDIA CORPORATION.

from __future__ import annotations

@@ -13,7 +13,6 @@
import pylibcudf
import rmm

-import cudf
from cudf.core.abc import Serializable
from cudf.utils.string import format_bytes

@@ -504,7 +503,7 @@ def get_ptr_and_size(array_interface: Mapping) -> tuple[int, int]:

shape = array_interface["shape"] or (1,)
strides = array_interface["strides"]
-itemsize = cudf.dtype(array_interface["typestr"]).itemsize
+itemsize = numpy.dtype(array_interface["typestr"]).itemsize
if strides is None or pylibcudf.column.is_c_contiguous(
shape, strides, itemsize
):
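Why the numpy.dtype swap is safe here: the array interface's "typestr" is always a NumPy-format string, so the heavier cudf.dtype lookup added nothing. A minimal sketch (the "<f8" typestr is illustrative, not taken from the diff):

    import numpy

    # "typestr" follows the NumPy array-interface spec, e.g. "<f8" for a
    # little-endian float64; numpy.dtype resolves it directly.
    itemsize = numpy.dtype("<f8").itemsize
    assert itemsize == 8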
5 changes: 0 additions & 5 deletions python/cudf/cudf/core/dtypes.py
@@ -64,11 +64,6 @@ def dtype(arbitrary):
raise TypeError(f"Unsupported type {np_dtype}")
return np_dtype

-if isinstance(arbitrary, str) and arbitrary in {"hex", "hex32", "hex64"}:
-# read_csv only accepts "hex"
-# e.g. test_csv_reader_hexadecimals, test_csv_reader_hexadecimal_overflow
-return arbitrary

# use `pandas_dtype` to try and interpret
# `arbitrary` as a Pandas extension type.
# Return the corresponding NumPy/cuDF type.
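With this branch gone, the "hex" aliases are understood only by the CSV reader (via the _CSV_HEX_TYPE_MAP handling in python/cudf/cudf/io/csv.py below). A hedged sketch of the expected behavior after this change:

    import cudf

    # Passing a hex alias to cudf.dtype should now fail, since the
    # pass-through branch above was removed and pandas_dtype rejects it.
    try:
        cudf.dtype("hex")
    except TypeError:
        pass  # "hex" is resolved by read_csv, not by cudf.dtype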
8 changes: 4 additions & 4 deletions python/cudf/cudf/core/scalar.py
@@ -476,16 +476,16 @@ def __repr__(self) -> str:
# https://github.com/numpy/numpy/issues/17552
return f"{self.__class__.__name__}({self.value!s}, dtype={self.dtype})"

-def _binop_result_dtype_or_error(self, other, op):
+def _binop_result_dtype_or_error(self, other, op) -> np.dtype:
if op in {"__eq__", "__ne__", "__lt__", "__gt__", "__le__", "__ge__"}:
-return np.bool_
+return np.dtype(np.bool_)

out_dtype = get_allowed_combinations_for_operator(
self.dtype, other.dtype, op
)

# datetime handling
-if out_dtype in {"M", "m"}:
+if out_dtype.kind in {"M", "m"}:
if self.dtype.char in {"M", "m"} and other.dtype.char not in {
"M",
"m",
@@ -505,7 +505,7 @@ def _binop_result_dtype_or_error(self, other, op):
return np.dtype(f"m8[{res}]")
return np.result_type(self.dtype, other.dtype)

-return cudf.dtype(out_dtype)
+return out_dtype

def _binaryop(self, other, op: str):
if is_scalar(other):
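Two details make this hunk work: np.bool_ is a scalar type rather than a dtype instance, and out_dtype used to be a one-character type code (which is why membership in {"M", "m"} matched) but is now a real np.dtype, so the kind attribute carries the same information. A small illustration:

    import numpy as np

    # Wrapping the scalar type yields a proper dtype instance, matching the
    # new "-> np.dtype" annotation.
    assert isinstance(np.dtype(np.bool_), np.dtype)
    # .kind distinguishes datetime64 ("M") from timedelta64 ("m") dtypes.
    assert np.dtype("datetime64[ns]").kind == "M"
    assert np.dtype("timedelta64[s]").kind == "m"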
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/tools/numeric.py
@@ -174,7 +174,7 @@ def to_numeric(
type_set = list(np.typecodes["UnsignedInteger"])

for t in type_set:
-downcast_dtype = cudf.dtype(t)
+downcast_dtype = np.dtype(t)
if downcast_dtype.itemsize <= col.dtype.itemsize:
if col.can_cast_safely(downcast_dtype):
col = col.cast(downcast_dtype)
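np.typecodes maps a category name to a string of one-character codes, each of which np.dtype accepts directly, so cudf.dtype was never needed here. For example:

    import numpy as np

    # Walk the unsigned-integer codes and resolve each to a concrete dtype.
    for t in np.typecodes["UnsignedInteger"]:
        print(t, np.dtype(t).itemsize)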
99 changes: 53 additions & 46 deletions python/cudf/cudf/io/csv.py
@@ -7,7 +7,7 @@
import warnings
from collections import abc
from io import BytesIO, StringIO
-from typing import cast
+from typing import TYPE_CHECKING, cast

import numpy as np
import pandas as pd
@@ -16,7 +16,7 @@

import cudf
from cudf._lib.column import Column
-from cudf.api.types import is_hashable, is_scalar
+from cudf.api.types import is_scalar
from cudf.core.buffer import acquire_spill_lock
from cudf.core.column_accessor import ColumnAccessor
from cudf.utils import ioutils
@@ -26,6 +26,10 @@
)
from cudf.utils.performance_tracking import _performance_tracking

+if TYPE_CHECKING:
+from cudf._typing import DtypeObj


_CSV_HEX_TYPE_MAP = {
"hex": np.dtype("int64"),
"hex64": np.dtype("int64"),
@@ -158,33 +162,49 @@ def read_csv(
header = 0

hex_cols: list[abc.Hashable] = []
-new_dtypes: list[plc.DataType] | dict[abc.Hashable, plc.DataType] = []
+cudf_dtypes: list[DtypeObj] | dict[abc.Hashable, DtypeObj] | DtypeObj = []
+plc_dtypes: list[plc.DataType] | dict[abc.Hashable, plc.DataType] = []
if dtype is not None:
if isinstance(dtype, abc.Mapping):
-new_dtypes = {}
+plc_dtypes = {}
+cudf_dtypes = {}
for k, col_type in dtype.items():
-if is_hashable(col_type) and col_type in _CSV_HEX_TYPE_MAP:
+if isinstance(col_type, str) and col_type in _CSV_HEX_TYPE_MAP:
col_type = _CSV_HEX_TYPE_MAP[col_type]
hex_cols.append(str(k))

-new_dtypes[k] = _get_plc_data_type_from_dtype(
-cudf.dtype(col_type)
-)
-elif cudf.api.types.is_scalar(dtype) or isinstance(
-dtype, (np.dtype, pd.api.extensions.ExtensionDtype, type)
+cudf_dtype = cudf.dtype(col_type)
+cudf_dtypes[k] = cudf_dtype
+plc_dtypes[k] = _get_plc_data_type_from_dtype(cudf_dtype)
+elif isinstance(
+dtype,
+(
+str,
+np.dtype,
+pd.api.extensions.ExtensionDtype,
+cudf.core.dtypes._BaseDtype,
+type,
+),
):
-if is_hashable(dtype) and dtype in _CSV_HEX_TYPE_MAP:
+if isinstance(dtype, str) and dtype in _CSV_HEX_TYPE_MAP:
dtype = _CSV_HEX_TYPE_MAP[dtype]
hex_cols.append(0)

-cast(list, new_dtypes).append(_get_plc_data_type_from_dtype(dtype))
+else:
+dtype = cudf.dtype(dtype)
+cudf_dtypes = dtype
+cast(list, plc_dtypes).append(_get_plc_data_type_from_dtype(dtype))
elif isinstance(dtype, abc.Collection):
for index, col_dtype in enumerate(dtype):
-if is_hashable(col_dtype) and col_dtype in _CSV_HEX_TYPE_MAP:
+if (
+isinstance(col_dtype, str)
+and col_dtype in _CSV_HEX_TYPE_MAP
+):
col_dtype = _CSV_HEX_TYPE_MAP[col_dtype]
hex_cols.append(index)

-new_dtypes.append(_get_plc_data_type_from_dtype(col_dtype))
+else:
+col_dtype = cudf.dtype(col_dtype)
+cudf_dtypes.append(col_dtype)
+plc_dtypes.append(_get_plc_data_type_from_dtype(col_dtype))
else:
raise ValueError(
"dtype should be a scalar/str/list-like/dict-like"
@@ -243,7 +263,7 @@ def read_csv(
if hex_cols is not None:
options.set_parse_hex(list(hex_cols))

-options.set_dtypes(new_dtypes)
+options.set_dtypes(plc_dtypes)

if true_values is not None:
options.set_true_values([str(val) for val in true_values])
@@ -266,15 +286,21 @@ def read_csv(
ca = ColumnAccessor(data, rangeindex=len(data) == 0)
df = cudf.DataFrame._from_data(ca)

-if isinstance(dtype, abc.Mapping):
-for k, v in dtype.items():
-if isinstance(cudf.dtype(v), cudf.CategoricalDtype):
-df._data[str(k)] = df._data[str(k)].astype(v)
-elif dtype == "category" or isinstance(dtype, cudf.CategoricalDtype):
+# Cast result to categorical if specified in dtype=
+# since categorical is not handled in pylibcudf
+if isinstance(cudf_dtypes, dict):
+to_category = {
+k: v
+for k, v in cudf_dtypes.items()
+if isinstance(v, cudf.CategoricalDtype)
+}
+if to_category:
+df = df.astype(to_category)
+elif isinstance(cudf_dtypes, cudf.CategoricalDtype):
df = df.astype(dtype)
-elif isinstance(dtype, abc.Collection) and not is_scalar(dtype):
-for index, col_dtype in enumerate(dtype):
-if isinstance(cudf.dtype(col_dtype), cudf.CategoricalDtype):
+elif isinstance(cudf_dtypes, list):
+for index, col_dtype in enumerate(cudf_dtypes):
+if isinstance(col_dtype, cudf.CategoricalDtype):
col_name = df._column_names[index]
df._data[col_name] = df._data[col_name].astype(col_dtype)
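
The effect of the split bookkeeping: plc_dtypes drives the libcudf parse, while cudf_dtypes drives the post-read categorical cast that pylibcudf cannot perform itself. A hedged usage sketch (the file name and column names are hypothetical):

    import cudf

    # "id" is parsed from hexadecimal text via the hex column machinery;
    # "label" is read as strings, then cast to categorical from cudf_dtypes.
    df = cudf.read_csv(
        "data.csv",
        dtype={"id": "hex", "label": "category"},
    )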

@@ -527,30 +553,11 @@ def _validate_args(
)


-def _get_plc_data_type_from_dtype(dtype) -> plc.DataType:
+def _get_plc_data_type_from_dtype(dtype: DtypeObj) -> plc.DataType:
# TODO: Remove this work-around Dictionary types
# in libcudf are fully mapped to categorical columns:
# https://github.com/rapidsai/cudf/issues/3960
if isinstance(dtype, cudf.CategoricalDtype):
# TODO: should we do this generally in dtype_to_pylibcudf_type?
dtype = dtype.categories.dtype
-elif dtype == "category":
-dtype = "str"

-if isinstance(dtype, str):
-if dtype == "date32":
-return plc.DataType(plc.types.TypeId.TIMESTAMP_DAYS)
-elif dtype in ("date", "date64"):
-return plc.DataType(plc.types.TypeId.TIMESTAMP_MILLISECONDS)
-elif dtype == "timestamp":
-return plc.DataType(plc.types.TypeId.TIMESTAMP_MILLISECONDS)
-elif dtype == "timestamp[us]":
-return plc.DataType(plc.types.TypeId.TIMESTAMP_MICROSECONDS)
-elif dtype == "timestamp[s]":
-return plc.DataType(plc.types.TypeId.TIMESTAMP_SECONDS)
-elif dtype == "timestamp[ms]":
-return plc.DataType(plc.types.TypeId.TIMESTAMP_MILLISECONDS)
-elif dtype == "timestamp[ns]":
-return plc.DataType(plc.types.TypeId.TIMESTAMP_NANOSECONDS)

-dtype = cudf.dtype(dtype)
return dtype_to_pylibcudf_type(dtype)
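
The removed string branches collapse into the normal dtype path: callers now hand the helper a resolved dtype object, and dtype_to_pylibcudf_type (imported at the top of this module) produces the pylibcudf type. A sketch of the equivalence for one removed alias:

    import cudf
    import pylibcudf as plc

    # "timestamp[ms]" used to be special-cased to TIMESTAMP_MILLISECONDS;
    # the same target type is now reached through a resolved dtype.
    target = plc.DataType(plc.types.TypeId.TIMESTAMP_MILLISECONDS)
    resolved = cudf.dtype("datetime64[ms]")  # what callers pass instead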
6 changes: 3 additions & 3 deletions python/cudf/cudf/io/parquet.py
@@ -527,7 +527,7 @@ def write_to_dataset(
return metadata


-def _parse_metadata(meta) -> tuple[bool, Any, Any]:
+def _parse_metadata(meta) -> tuple[bool, Any, None | np.dtype]:
file_is_range_index = False
file_index_cols = None
file_column_dtype = None
@@ -541,7 +541,7 @@ def _parse_metadata(meta) -> tuple[bool, Any, Any]:
):
file_is_range_index = True
if "column_indexes" in meta and len(meta["column_indexes"]) == 1:
-file_column_dtype = meta["column_indexes"][0]["numpy_type"]
+file_column_dtype = np.dtype(meta["column_indexes"][0]["numpy_type"])
return file_is_range_index, file_index_cols, file_column_dtype


@@ -2368,6 +2368,6 @@ def _process_metadata(
df.index.names = index_col

if df._num_columns == 0 and column_index_type is not None:
-df._data.label_dtype = cudf.dtype(column_index_type)
+df._data.label_dtype = column_index_type

return df
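
Converting the metadata string eagerly means _parse_metadata hands back a real np.dtype, which is what lets the cudf.dtype call at the assignment site disappear. A sketch with a hypothetical metadata fragment:

    import numpy as np

    # pandas writes the column-index dtype as a string in the Parquet
    # metadata; np.dtype resolves it to a dtype object up front.
    meta = {"column_indexes": [{"numpy_type": "object"}]}
    file_column_dtype = np.dtype(meta["column_indexes"][0]["numpy_type"])
    assert file_column_dtype == np.dtype("O")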
15 changes: 9 additions & 6 deletions python/cudf/cudf/utils/dtypes.py
@@ -430,7 +430,9 @@ def _get_nan_for_dtype(dtype: DtypeObj) -> DtypeObj:
return np.float64("nan")


-def get_allowed_combinations_for_operator(dtype_l, dtype_r, op):
+def get_allowed_combinations_for_operator(
+dtype_l: np.dtype, dtype_r: np.dtype, op: str
+) -> np.dtype:
error = TypeError(
f"{op} not supported between {dtype_l} and {dtype_r} scalars"
)
@@ -456,18 +458,19 @@ def get_allowed_combinations_for_operator(dtype_l, dtype_r, op):
# special rules for string
if dtype_l == "object" or dtype_r == "object":
if (dtype_l == dtype_r == "object") and op == "__add__":
return "str"
return CUDF_STRING_DTYPE
else:
raise error

# Check if we can directly operate

for valid_combo in allowed:
-ltype, rtype, outtype = valid_combo
-if np.can_cast(dtype_l.char, ltype) and np.can_cast(
-dtype_r.char, rtype
+ltype, rtype, outtype = valid_combo  # type: ignore[misc]
+if np.can_cast(dtype_l.char, ltype) and np.can_cast(  # type: ignore[has-type]
+dtype_r.char,
+rtype,  # type: ignore[has-type]
):
-return outtype
+return np.dtype(outtype)  # type: ignore[has-type]

raise error

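Each allowed combination is a triple of one-character NumPy type codes; np.can_cast accepts the codes directly, and np.dtype turns the winning output code into a dtype instance (previously the bare code leaked out, which is what scalar.py's old membership check against {"M", "m"} relied on). A minimal sketch with illustrative codes:

    import numpy as np

    dtype_l, dtype_r = np.dtype("int32"), np.dtype("int32")
    ltype, rtype, outtype = "i", "i", "i"  # illustrative combo, not the real table
    if np.can_cast(dtype_l.char, ltype) and np.can_cast(dtype_r.char, rtype):
        result = np.dtype(outtype)  # dtype('int32'), not the bare code "i"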
