Skip to content

Commit

Permalink
fix cases with bad SASS
Browse files Browse the repository at this point in the history
  • Loading branch information
fbusato committed Jan 29, 2025
1 parent f2956dc commit 0b81bae
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 34 deletions.
23 changes: 23 additions & 0 deletions cub/cub/thread/thread_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -630,6 +630,29 @@ inline constexpr bool is_predefined_operator_v =
is_predefined_sum_mul_v<ReductionOp, T> || //
is_predefined_bitwise_v<ReductionOp, T>;

// Type trait: is `ReductionOp` one of the Simd* operator templates (SimdSum, SimdMul, SimdMin, SimdMax)?
// The primary template handles every operator that is not listed explicitly below and derives from `false_type`.
template <typename ReductionOp>
struct is_simd_operator : ::cuda::std::false_type
{};

// SimdSum<ElemT> is a SIMD operator for any element type.
template <typename ElemT>
struct is_simd_operator<SimdSum<ElemT>> : ::cuda::std::true_type
{};

// SimdMul<ElemT> is a SIMD operator for any element type.
template <typename ElemT>
struct is_simd_operator<SimdMul<ElemT>> : ::cuda::std::true_type
{};

// SimdMin<ElemT> is a SIMD operator for any element type.
template <typename ElemT>
struct is_simd_operator<SimdMin<ElemT>> : ::cuda::std::true_type
{};

// SimdMax<ElemT> is a SIMD operator for any element type.
template <typename ElemT>
struct is_simd_operator<SimdMax<ElemT>> : ::cuda::std::true_type
{};

// Variable-template shorthand for `is_simd_operator<ReductionOp>::value`.
template <typename ReductionOp>
inline constexpr bool is_simd_operator_v = is_simd_operator<ReductionOp>::value;

//----------------------------------------------------------------------------------------------------------------------
// Predefined CUDA operators to SIMD

Expand Down
48 changes: 22 additions & 26 deletions cub/cub/thread/thread_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -295,9 +295,7 @@ _CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE constexpr bool enable_simd_reduct

// Trait consulted on the SM90 path of enable_ternary_reduction(): true when the ternary reduction should be used for
// element type `T` with operator `ReductionOp`.
// NOTE(review): two initializer expressions appear back to back below -- the three-line form (int32/uint32 combined
// with min/max, bitwise, or plus) and the one-line form (int32/uint32 combined with min/max only). Only one of them
// can be the live definition; presumably the one-line form is the current one -- confirm against the checked-in header.
template <typename T, typename ReductionOp>
inline constexpr bool enable_ternary_reduction_sm90_v =
is_one_of_v<T, int32_t, uint32_t>
&& (is_predefined_min_max_v<ReductionOp, T> || is_predefined_bitwise_v<ReductionOp, T>
|| cub::detail::is_one_of_v<ReductionOp, ::cuda::std::plus<>, ::cuda::std::plus<T>>);
is_one_of_v<T, int32_t, uint32_t> && is_predefined_min_max_v<ReductionOp, T>;

# if defined(_CCCL_HAS_NVFP16)

Expand All @@ -318,10 +316,10 @@ inline constexpr bool enable_ternary_reduction_sm90_v<__nv_bfloat162, ReductionO

// Trait consulted on the SM50+ path of enable_ternary_reduction(): the operator must be `::cuda::std::plus` (either
// transparent or for `T`) or a predefined bitwise operator, and `T` must satisfy the type condition.
// NOTE(review): two type conditions appear back to back below -- `is_one_of_v<T, int32_t, uint32_t>` and
// `::cuda::std::is_integral_v<T> && sizeof(T) <= 4`. Presumably the latter (any integral type of at most 4 bytes) is
// the current one -- confirm against the checked-in header.
template <typename T, typename ReductionOp>
inline constexpr bool enable_ternary_reduction_sm50_v =
is_one_of_v<T, int32_t, uint32_t>
::cuda::std::is_integral_v<T> && sizeof(T) <= 4
&& (is_one_of_v<ReductionOp, ::cuda::std::plus<>, ::cuda::std::plus<T>> || is_predefined_bitwise_v<ReductionOp, T>);

template <typename Input, typename ReductionOp, typename AccumT>
template <typename Input, typename ReductionOp, typename Prom>
_CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE constexpr bool enable_ternary_reduction()
{
constexpr auto length = cub::detail::static_size_v<Input>;
Expand All @@ -331,28 +329,21 @@ _CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE constexpr bool enable_ternary_red
}
else
{
// apply SM90 min/max ternary reduction only if the input is natively int32/uint32
using T = detail::random_access_range_elem_t<Input>;
// clang-format off
NV_DISPATCH_TARGET(
NV_PROVIDES_SM_90,
(return enable_ternary_reduction_sm90_v<T, ReductionOp> || enable_ternary_reduction_sm50_v<T, ReductionOp>;),
(return enable_ternary_reduction_sm90_v<T, ReductionOp> || enable_ternary_reduction_sm50_v<Prom, ReductionOp>;),
NV_PROVIDES_SM_50,
(return enable_ternary_reduction_sm50_v<T, ReductionOp>;),
(return enable_ternary_reduction_sm50_v<Prom, ReductionOp>;),
NV_ANY_TARGET,
(return false;)
);
// clang-format on
}
}

/***********************************************************************************************************************
* Enable Promotion (Trait)
**********************************************************************************************************************/

template <typename Input, typename ReductionOp, typename AccumT, typename T = detail::random_access_range_elem_t<Input>>
inline constexpr bool enable_promotion_v =
::cuda::std::is_integral_v<T> && sizeof(T) <= 2 && is_predefined_operator_v<ReductionOp, T>;

/***********************************************************************************************************************
* Internal Reduction Algorithms: Sequential, Binary, Ternary
**********************************************************************************************************************/
Expand Down Expand Up @@ -453,6 +444,10 @@ _CCCL_DEVICE _CCCL_FORCEINLINE auto ThreadReduceSimd(const Input& input, Reducti
return unsafe_bitcast<UnpackedType>(result)[0];
}

template <typename ReductionOp, typename T>
inline constexpr bool enable_promotion_v =
is_predefined_min_max_v<ReductionOp, T> && ::cuda::std::is_integral_v<T> && sizeof(T) <= 2;

} // namespace internal

/***********************************************************************************************************************
Expand All @@ -472,27 +467,28 @@ _CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE AccumT ThreadReduce(const Input&
}
using cub::detail::is_one_of_v;
using namespace cub::internal;
using PromT = ::cuda::std::_If<enable_promotion_v<Input, ReductionOp, AccumT>, int, AccumT>;
// sizeof(ValueT) >= 8 requires too many registers
if constexpr ((!is_predefined_operator_v<ReductionOp, ValueT>
&& !is_one_of_v<ReductionOp, SimdMin<ValueT>, SimdMax<ValueT>>)
using PromT = ::cuda::std::_If<enable_promotion_v<ReductionOp, ValueT>, int, AccumT>;
// TODO: should be part of the tuning policy
if constexpr ((!is_predefined_operator_v<ReductionOp, ValueT> && !is_simd_operator_v<ReductionOp>)
|| sizeof(ValueT) >= 8)
{
return cub::internal::ThreadReduceSequential<AccumT>(input, reduction_op);
}
else if constexpr (is_one_of_v<ReductionOp, ::cuda::std::plus<>, ::cuda::std::plus<ValueT>>
&& is_one_of_v<ValueT, int32_t, uint32_t>)
{
NV_IF_TARGET(NV_PROVIDES_SM_90, //
(return cub::internal::ThreadReduceSequential<AccumT>(input, reduction_op);),
(return cub::internal::ThreadReduceTernaryTree<PromT>(input, reduction_op);));
}
else if constexpr (enable_simd_reduction<Input, ReductionOp, AccumT>())
{
return cub::internal::ThreadReduceSimd(input, reduction_op);
}
else if constexpr (enable_ternary_reduction<Input, ReductionOp, PromT>())
{
// with the current tuning policies, SM90/int32/+ uses too many registers (TODO: fix tuning policy)
if constexpr ((is_one_of_v<ReductionOp, ::cuda::std::plus<>, ::cuda::std::plus<PromT>>
&& is_one_of_v<PromT, int32_t, uint32_t>)
// the compiler generates bad code for int8/uint8 and min/max for SM90
|| (is_predefined_min_max_v<ReductionOp, ValueT> && is_one_of_v<PromT, int8_t, uint8_t>) )
{
NV_IF_TARGET(NV_PROVIDES_SM_90, //
(return cub::internal::ThreadReduceSequential<PromT>(input, reduction_op);));
}
return cub::internal::ThreadReduceTernaryTree<PromT>(input, reduction_op);
}
else
Expand Down
25 changes: 17 additions & 8 deletions cub/test/thread_reduce/catch2_test_thread_reduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
* Thread Reduce Wrapper Kernels
**********************************************************************************************************************/

// Define CCCL_CHECK_SASS only locally while inspecting SASS: it selects the reduced type/operator lists further below,
// so most of the test matrix (e.g. double, int64_t, uint16_t, maximum) is not instantiated.
// It must stay commented out in committed code so that the full test matrix is compiled and run.
// #define CCCL_CHECK_SASS // instantiate only the kernels useful for SASS inspection

template <int NUM_ITEMS, typename T, typename ReduceOperator>
__global__ void thread_reduce_kernel(const T* __restrict__ d_in, T* __restrict__ d_out, ReduceOperator reduce_operator)
Expand Down Expand Up @@ -255,17 +255,24 @@ using narrow_precision_type_list = c2h::type_list<
#endif
>;

using fp_type_list =
c2h::type_list<float
#if !defined(CCCL_CHECK_SASS)
,
double
#endif
>;
#if defined(CCCL_CHECK_SASS)

using fp_type_list = c2h::type_list<float>;

using integral_type_list = c2h::type_list<::cuda::std::int8_t, ::cuda::std::int16_t, ::cuda::std::int32_t>;

using cub_operator_integral_list =
c2h::type_list<cuda::std::plus<>, cuda::std::multiplies<>, cuda::std::bit_xor<>, cuda::minimum<>>;

using cub_operator_fp_list = c2h::type_list<cuda::std::plus<>, cuda::minimum<>>;

#else // !defined(CCCL_CHECK_SASS)

using integral_type_list = c2h::
type_list<::cuda::std::int8_t, ::cuda::std::int16_t, ::cuda::std::uint16_t, ::cuda::std::int32_t, ::cuda::std::int64_t>;

using fp_type_list = c2h::type_list<float, double>;

using cub_operator_integral_list =
c2h::type_list<cuda::std::plus<>,
cuda::std::multiplies<>,
Expand All @@ -278,6 +285,8 @@ using cub_operator_integral_list =
using cub_operator_fp_list =
c2h::type_list<cuda::std::plus<>, cuda::std::multiplies<>, cuda::minimum<>, cuda::maximum<>>;

#endif // defined(CCCL_CHECK_SASS)

/***********************************************************************************************************************
* Verify results and kernel launch
**********************************************************************************************************************/
Expand Down

0 comments on commit 0b81bae

Please sign in to comment.