Refactor math_ops.cu dispatcher logic #18006

Open · wants to merge 4 commits into base: branch-25.04

Changes from 2 commits
279 changes: 90 additions & 189 deletions cpp/src/unary/math_ops.cu
@@ -199,19 +199,14 @@ struct DeviceAbs {
}
};

// round float to int

struct DeviceRInt {
template <typename T>
std::enable_if_t<std::is_floating_point_v<T>, T> __device__ operator()(T data)
__device__ T operator()(T data)
{
return std::rint(data);
}

// Dummy to handle other types, will never be executed
template <typename T>
std::enable_if_t<!std::is_floating_point_v<T>, T> __device__ operator()(T data)
{
return data;
}
};
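As context, a minimal host-side sketch (not cudf code; struct names are illustrative) of why the dummy overload existed and why it can be dropped: cudf::type_dispatcher instantiates a functor's operator() for every dispatchable type, so an unconstrained functor had to compile, even though it never ran, for non-float types. Once dispatch itself is constrained by a supported-type trait (FloatOnlyOps below), only floating-point instantiations are generated.

#include <cmath>
#include <type_traits>

struct RIntWithFallback {  // pre-refactor shape: SFINAE selects the real or dummy path
  template <typename T>
  std::enable_if_t<std::is_floating_point_v<T>, T> operator()(T data) const
  {
    return std::rint(data);
  }

  // dummy path: never executed, exists only so non-float instantiations compile
  template <typename T>
  std::enable_if_t<!std::is_floating_point_v<T>, T> operator()(T data) const
  {
    return data;
  }
};

struct RIntConstrained {  // post-refactor shape: the dispatcher guarantees T is float/double
  template <typename T>
  T operator()(T data) const
  {
    return std::rint(data);
  }
};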

// bitwise op
@@ -350,14 +345,26 @@ std::unique_ptr<cudf::column> transform_fn(InputIterator begin,
null_count,
stream,
mr);
if (size == 0) return output;

auto output_view = output->mutable_view();
thrust::transform(rmm::exec_policy(stream), begin, end, output_view.begin<OutputType>(), UFN{});
output->set_null_count(null_count);
return output;
}

template <typename T, typename UFN>
std::unique_ptr<cudf::column> transform_fn(cudf::column_view const& input,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
return transform_fn<T, UFN>(input.begin<T>(),
input.end<T>(),
detail::copy_bitmask(input, stream, mr),
input.null_count(),
stream,
mr);
}

template <typename T, typename UFN>
std::unique_ptr<cudf::column> transform_fn(cudf::dictionary_column_view const& input,
rmm::cuda_stream_view stream,
@@ -377,136 +384,52 @@ std::unique_ptr<cudf::column> transform_fn(cudf::dictionary_column_view const& i
output->view(), dictionary::detail::get_indices_type_for_size(output->size()), stream, mr);
}

template <typename UFN>
struct MathOpDispatcher {
template <typename T, std::enable_if_t<std::is_arithmetic_v<T>>* = nullptr>
std::unique_ptr<cudf::column> operator()(cudf::column_view const& input,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
return transform_fn<T, UFN>(input.begin<T>(),
input.end<T>(),
cudf::detail::copy_bitmask(input, stream, mr),
input.null_count(),
stream,
mr);
}

struct dictionary_dispatch {
template <typename T, std::enable_if_t<std::is_arithmetic_v<T>>* = nullptr>
std::unique_ptr<cudf::column> operator()(cudf::dictionary_column_view const& input,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
return transform_fn<T, UFN>(input, stream, mr);
}

template <typename T, typename... Args>
std::enable_if_t<!std::is_arithmetic_v<T>, std::unique_ptr<cudf::column>> operator()(Args&&...)
{
CUDF_FAIL("dictionary keys must be numeric for this operation");
}
};

template <
typename T,
std::enable_if_t<!std::is_arithmetic_v<T> and std::is_same_v<T, dictionary32>>* = nullptr>
std::unique_ptr<cudf::column> operator()(cudf::column_view const& input,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
if (input.is_empty()) return empty_like(input);
auto dictionary_col = dictionary_column_view(input);
return type_dispatcher(
dictionary_col.keys().type(), dictionary_dispatch{}, dictionary_col, stream, mr);
}

template <typename T, typename... Args>
std::enable_if_t<!std::is_arithmetic_v<T> and !std::is_same_v<T, dictionary32>,
std::unique_ptr<cudf::column>>
operator()(Args&&...)
{
CUDF_FAIL("Unsupported data type for operation");
}
template <typename T>
struct ArithmeticOps {
static constexpr bool is_supported() { return std::is_arithmetic_v<T>; }
};

template <typename UFN>
struct NegateOpDispatcher {
template <typename T>
static constexpr bool is_supported()
{
return std::is_signed_v<T> || cudf::is_duration<T>();
}

template <typename T, std::enable_if_t<is_supported<T>()>* = nullptr>
std::unique_ptr<cudf::column> operator()(cudf::column_view const& input,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
return transform_fn<T, UFN>(input.begin<T>(),
input.end<T>(),
cudf::detail::copy_bitmask(input, stream, mr),
input.null_count(),
stream,
mr);
}

template <typename T, typename... Args>
std::enable_if_t<!is_supported<T>(), std::unique_ptr<cudf::column>> operator()(Args&&...)
{
CUDF_FAIL("Unsupported data type for negate operation");
}
template <typename T>
struct NegateOps {
static constexpr bool is_supported() { return std::is_signed_v<T> || cudf::is_duration<T>(); }
};

template <typename UFN>
struct BitwiseOpDispatcher {
template <typename T, std::enable_if_t<std::is_integral_v<T>>* = nullptr>
std::unique_ptr<cudf::column> operator()(cudf::column_view const& input,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
return transform_fn<T, UFN>(input.begin<T>(),
input.end<T>(),
cudf::detail::copy_bitmask(input, stream, mr),
input.null_count(),
stream,
mr);
}

struct dictionary_dispatch {
template <typename T, std::enable_if_t<std::is_integral_v<T>>* = nullptr>
std::unique_ptr<cudf::column> operator()(cudf::dictionary_column_view const& input,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
return transform_fn<T, UFN>(input, stream, mr);
}
template <typename T>
struct BitWiseOps {
static constexpr bool is_supported() { return std::is_integral_v<T>; }
};

template <typename T, typename... Args>
std::enable_if_t<!std::is_integral_v<T>, std::unique_ptr<cudf::column>> operator()(Args&&...)
{
CUDF_FAIL("dictionary keys type not supported for this operation");
}
};
template <typename T>
struct FloatOnlyOps {
static constexpr bool is_supported() { return std::is_floating_point_v<T>; }
};

template <typename T,
std::enable_if_t<!std::is_integral_v<T> and std::is_same_v<T, dictionary32>>* = nullptr>
/**
 * @brief Generic math-ops dispatcher
 *
 * Performs a transform on the input data using the operator defined by UFN.
 * The Supported trait determines which types are allowed for the operator.
 *
 * @tparam UFN The operator to apply to the input data
 * @tparam Supported Trait providing the 'is_supported()' check for a given type
 */
template <typename UFN, template <typename> typename Supported>
struct MathOpDispatcher {
template <typename T, std::enable_if_t<Supported<T>::is_supported()>* = nullptr>
std::unique_ptr<cudf::column> operator()(cudf::column_view const& input,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
if (input.is_empty()) return empty_like(input);
auto dictionary_col = dictionary_column_view(input);
return type_dispatcher(
dictionary_col.keys().type(), dictionary_dispatch{}, dictionary_col, stream, mr);
return (input.type().id() == type_id::DICTIONARY32)
? transform_fn<T, UFN>(cudf::dictionary_column_view(input), stream, mr)
: transform_fn<T, UFN>(input, stream, mr);
}

template <typename T, typename... Args>
std::enable_if_t<!std::is_integral_v<T> and !std::is_same_v<T, dictionary32>,
std::unique_ptr<cudf::column>>
operator()(Args&&...)
std::enable_if_t<!Supported<T>::is_supported(), std::unique_ptr<cudf::column>> operator()(
Args&&...)
{
CUDF_FAIL("Unsupported datatype for operation");
CUDF_FAIL("Unsupported data type for this operation");
}
};
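To illustrate the pattern this refactor introduces, a hypothetical sketch (not part of this diff; DeviceLog2 and the LOG2 enum value are invented for illustration): adding a new unary operator now takes a device functor plus an existing or new supported-type trait, and one line in unary_operation's switch.

struct DeviceLog2 {  // hypothetical functor, mirroring the style of DeviceSqrt/DeviceRInt
  template <typename T>
  __device__ T operator()(T data)
  {
    return std::log2(data);
  }
};

// ...and in unary_operation's switch (assuming a cudf::unary_operator::LOG2 value existed):
//   case cudf::unary_operator::LOG2:
//     return cudf::type_dispatcher(
//       dispatch_type, MathOpDispatcher<DeviceLog2, FloatOnlyOps>{}, input, stream, mr);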

@@ -525,54 +448,26 @@ struct LogicalOpDispatcher {
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
return transform_fn<bool, UFN>(input.begin<T>(),
input.end<T>(),
cudf::detail::copy_bitmask(input, stream, mr),
input.null_count(),

stream,
mr);
}

struct dictionary_dispatch {
template <typename T, std::enable_if_t<is_supported<T>()>* = nullptr>
std::unique_ptr<cudf::column> operator()(cudf::dictionary_column_view const& input,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
auto dictionary_view = cudf::column_device_view::create(input.parent(), stream);
if (input.type().id() == type_id::DICTIONARY32) {
auto dictionary_view = cudf::column_device_view::create(input, stream);
auto dictionary_itr = dictionary::detail::make_dictionary_iterator<T>(*dictionary_view);
return transform_fn<bool, UFN>(dictionary_itr,
dictionary_itr + input.size(),
cudf::detail::copy_bitmask(input.parent(), stream, mr),
cudf::detail::copy_bitmask(input, stream, mr),
input.null_count(),
stream,
mr);
}

template <typename T, typename... Args>
std::enable_if_t<!is_supported<T>(), std::unique_ptr<cudf::column>> operator()(Args&&...)
{
CUDF_FAIL("dictionary keys type not supported for this operation");
}
};

template <typename T,
std::enable_if_t<!is_supported<T>() and std::is_same_v<T, dictionary32>>* = nullptr>
std::unique_ptr<cudf::column> operator()(cudf::column_view const& input,
rmm::cuda_stream_view stream,
rmm::device_async_resource_ref mr)
{
if (input.is_empty()) return make_empty_column(cudf::data_type{cudf::type_id::BOOL8});
auto dictionary_col = dictionary_column_view(input);
return type_dispatcher(
dictionary_col.keys().type(), dictionary_dispatch{}, dictionary_col, stream, mr);
return transform_fn<bool, UFN>(input.begin<T>(),
input.end<T>(),
cudf::detail::copy_bitmask(input, stream, mr),
input.null_count(),
stream,
mr);
}

template <typename T, typename... Args>
std::enable_if_t<!is_supported<T>() and !std::is_same_v<T, dictionary32>,
std::unique_ptr<cudf::column>>
operator()(Args&&...)
std::enable_if_t<!is_supported<T>(), std::unique_ptr<cudf::column>> operator()(Args&&...)
{
CUDF_FAIL("Unsupported datatype for operation");
}
@@ -614,79 +509,85 @@ std::unique_ptr<cudf::column> unary_operation(cudf::column_view const& input,
if (cudf::is_fixed_point(input.type()))
return type_dispatcher(input.type(), detail::FixedPointOpDispatcher{}, input, op, stream, mr);

if (input.is_empty()) {
return op == cudf::unary_operator::NOT ? make_empty_column(type_id::BOOL8) : empty_like(input);
}

// dispatch on the keys type when the input is a dictionary; this saves a 2nd dispatch later
auto dispatch_type = input.type().id() == type_id::DICTIONARY32
? dictionary_column_view(input).keys().type()
: input.type();

switch (op) {
case cudf::unary_operator::SIN:
return cudf::type_dispatcher(
input.type(), detail::MathOpDispatcher<detail::DeviceSin>{}, input, stream, mr);
dispatch_type, MathOpDispatcher<DeviceSin, ArithmeticOps>{}, input, stream, mr);
case cudf::unary_operator::COS:
return cudf::type_dispatcher(
input.type(), detail::MathOpDispatcher<detail::DeviceCos>{}, input, stream, mr);
dispatch_type, MathOpDispatcher<DeviceCos, ArithmeticOps>{}, input, stream, mr);
case cudf::unary_operator::TAN:
return cudf::type_dispatcher(
input.type(), detail::MathOpDispatcher<detail::DeviceTan>{}, input, stream, mr);
dispatch_type, MathOpDispatcher<DeviceTan, ArithmeticOps>{}, input, stream, mr);
case cudf::unary_operator::ARCSIN:
return cudf::type_dispatcher(
input.type(), detail::MathOpDispatcher<detail::DeviceArcSin>{}, input, stream, mr);
dispatch_type, MathOpDispatcher<DeviceArcSin, ArithmeticOps>{}, input, stream, mr);
case cudf::unary_operator::ARCCOS:
return cudf::type_dispatcher(
input.type(), detail::MathOpDispatcher<detail::DeviceArcCos>{}, input, stream, mr);
dispatch_type, MathOpDispatcher<DeviceArcCos, ArithmeticOps>{}, input, stream, mr);
case cudf::unary_operator::ARCTAN:
return cudf::type_dispatcher(
input.type(), detail::MathOpDispatcher<detail::DeviceArcTan>{}, input, stream, mr);
dispatch_type, MathOpDispatcher<DeviceArcTan, ArithmeticOps>{}, input, stream, mr);
case cudf::unary_operator::SINH:
return cudf::type_dispatcher(
input.type(), detail::MathOpDispatcher<detail::DeviceSinH>{}, input, stream, mr);
dispatch_type, MathOpDispatcher<DeviceSinH, ArithmeticOps>{}, input, stream, mr);
case cudf::unary_operator::COSH:
return cudf::type_dispatcher(
input.type(), detail::MathOpDispatcher<detail::DeviceCosH>{}, input, stream, mr);
dispatch_type, MathOpDispatcher<DeviceCosH, ArithmeticOps>{}, input, stream, mr);
case cudf::unary_operator::TANH:
return cudf::type_dispatcher(
input.type(), detail::MathOpDispatcher<detail::DeviceTanH>{}, input, stream, mr);
dispatch_type, MathOpDispatcher<DeviceTanH, ArithmeticOps>{}, input, stream, mr);
case cudf::unary_operator::ARCSINH:
return cudf::type_dispatcher(
input.type(), detail::MathOpDispatcher<detail::DeviceArcSinH>{}, input, stream, mr);
dispatch_type, MathOpDispatcher<DeviceArcSinH, ArithmeticOps>{}, input, stream, mr);
case cudf::unary_operator::ARCCOSH:
return cudf::type_dispatcher(
input.type(), detail::MathOpDispatcher<detail::DeviceArcCosH>{}, input, stream, mr);
dispatch_type, MathOpDispatcher<DeviceArcCosH, ArithmeticOps>{}, input, stream, mr);
case cudf::unary_operator::ARCTANH:
return cudf::type_dispatcher(
input.type(), detail::MathOpDispatcher<detail::DeviceArcTanH>{}, input, stream, mr);
dispatch_type, MathOpDispatcher<DeviceArcTanH, ArithmeticOps>{}, input, stream, mr);
case cudf::unary_operator::EXP:
return cudf::type_dispatcher(
input.type(), detail::MathOpDispatcher<detail::DeviceExp>{}, input, stream, mr);
dispatch_type, MathOpDispatcher<DeviceExp, ArithmeticOps>{}, input, stream, mr);
case cudf::unary_operator::LOG:
return cudf::type_dispatcher(
input.type(), detail::MathOpDispatcher<detail::DeviceLog>{}, input, stream, mr);
dispatch_type, MathOpDispatcher<DeviceLog, ArithmeticOps>{}, input, stream, mr);
case cudf::unary_operator::SQRT:
return cudf::type_dispatcher(
input.type(), detail::MathOpDispatcher<detail::DeviceSqrt>{}, input, stream, mr);
dispatch_type, MathOpDispatcher<DeviceSqrt, ArithmeticOps>{}, input, stream, mr);
case cudf::unary_operator::CBRT:
return cudf::type_dispatcher(
input.type(), detail::MathOpDispatcher<detail::DeviceCbrt>{}, input, stream, mr);
dispatch_type, MathOpDispatcher<DeviceCbrt, ArithmeticOps>{}, input, stream, mr);
case cudf::unary_operator::CEIL:
return cudf::type_dispatcher(
input.type(), detail::MathOpDispatcher<detail::DeviceCeil>{}, input, stream, mr);
dispatch_type, MathOpDispatcher<DeviceCeil, ArithmeticOps>{}, input, stream, mr);
case cudf::unary_operator::FLOOR:
return cudf::type_dispatcher(
input.type(), detail::MathOpDispatcher<detail::DeviceFloor>{}, input, stream, mr);
dispatch_type, MathOpDispatcher<DeviceFloor, ArithmeticOps>{}, input, stream, mr);
case cudf::unary_operator::ABS:
return cudf::type_dispatcher(
input.type(), detail::MathOpDispatcher<detail::DeviceAbs>{}, input, stream, mr);
dispatch_type, MathOpDispatcher<DeviceAbs, ArithmeticOps>{}, input, stream, mr);
case cudf::unary_operator::RINT:
CUDF_EXPECTS(
(input.type().id() == type_id::FLOAT32) or (input.type().id() == type_id::FLOAT64),
"rint expects floating point values");
return cudf::type_dispatcher(
input.type(), detail::MathOpDispatcher<detail::DeviceRInt>{}, input, stream, mr);
dispatch_type, MathOpDispatcher<DeviceRInt, FloatOnlyOps>{}, input, stream, mr);
case cudf::unary_operator::BIT_INVERT:
return cudf::type_dispatcher(
input.type(), detail::BitwiseOpDispatcher<detail::DeviceInvert>{}, input, stream, mr);
dispatch_type, MathOpDispatcher<DeviceInvert, BitWiseOps>{}, input, stream, mr);
case cudf::unary_operator::NOT:
return cudf::type_dispatcher(
input.type(), detail::LogicalOpDispatcher<detail::DeviceNot>{}, input, stream, mr);
dispatch_type, detail::LogicalOpDispatcher<DeviceNot>{}, input, stream, mr);
case cudf::unary_operator::NEGATE:
return cudf::type_dispatcher(
input.type(), detail::NegateOpDispatcher<detail::DeviceNegate>{}, input, stream, mr);
dispatch_type, MathOpDispatcher<DeviceNegate, NegateOps>{}, input, stream, mr);
default: CUDF_FAIL("Undefined unary operation");
}
}
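A usage sketch of the single-dispatch behavior above, assuming the public cudf::unary_operation and cudf::dictionary::encode APIs and the cudf::test column wrappers (the values are illustrative, not from this PR): for a dictionary column, dispatch happens once on the keys type, and MathOpDispatcher selects the dictionary branch at runtime via input.type().id().

#include <cudf/dictionary/encode.hpp>
#include <cudf/unary.hpp>
#include <cudf_test/column_wrapper.hpp>

void unary_dispatch_example()
{
  // Plain numeric column: dispatch_type is FLOAT64 and the non-dictionary path runs.
  cudf::test::fixed_width_column_wrapper<double> col{-1.5, 0.25, 2.0};
  auto sqrt_result = cudf::unary_operation(col, cudf::unary_operator::SQRT);

  // Dictionary-encoded column: dispatch_type is the keys type (FLOAT64 here), so a
  // single type_dispatcher call suffices; MathOpDispatcher takes the dictionary branch.
  auto dict       = cudf::dictionary::encode(col);
  auto abs_result = cudf::unary_operation(dict->view(), cudf::unary_operator::ABS);
}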