From 0b81bae650c277655be9553cb87a1b4949a2942a Mon Sep 17 00:00:00 2001 From: fbusato Date: Tue, 28 Jan 2025 17:29:39 -0800 Subject: [PATCH] fix cases with bad SASS --- cub/cub/thread/thread_operators.cuh | 23 +++++++++ cub/cub/thread/thread_reduce.cuh | 48 +++++++++---------- .../catch2_test_thread_reduce.cu | 25 ++++++---- 3 files changed, 62 insertions(+), 34 deletions(-) diff --git a/cub/cub/thread/thread_operators.cuh b/cub/cub/thread/thread_operators.cuh index c61903b74ed..aa1b116e722 100644 --- a/cub/cub/thread/thread_operators.cuh +++ b/cub/cub/thread/thread_operators.cuh @@ -630,6 +630,29 @@ inline constexpr bool is_predefined_operator_v = is_predefined_sum_mul_v || // is_predefined_bitwise_v; +template +struct is_simd_operator : ::cuda::std::false_type +{}; + +template +struct is_simd_operator> : ::cuda::std::true_type +{}; + +template +struct is_simd_operator> : ::cuda::std::true_type +{}; + +template +struct is_simd_operator> : ::cuda::std::true_type +{}; + +template +struct is_simd_operator> : ::cuda::std::true_type +{}; + +template +inline constexpr bool is_simd_operator_v = is_simd_operator::value; + //---------------------------------------------------------------------------------------------------------------------- // Predefined CUDA operators to SIMD diff --git a/cub/cub/thread/thread_reduce.cuh b/cub/cub/thread/thread_reduce.cuh index 8282fb48870..d05f2f1e90c 100644 --- a/cub/cub/thread/thread_reduce.cuh +++ b/cub/cub/thread/thread_reduce.cuh @@ -295,9 +295,7 @@ _CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE constexpr bool enable_simd_reduct template inline constexpr bool enable_ternary_reduction_sm90_v = - is_one_of_v - && (is_predefined_min_max_v || is_predefined_bitwise_v - || cub::detail::is_one_of_v, ::cuda::std::plus>); + is_one_of_v && is_predefined_min_max_v; # if defined(_CCCL_HAS_NVFP16) @@ -318,10 +316,10 @@ inline constexpr bool enable_ternary_reduction_sm90_v<__nv_bfloat162, ReductionO template inline constexpr bool enable_ternary_reduction_sm50_v = - is_one_of_v + ::cuda::std::is_integral_v && sizeof(T) <= 4 && (is_one_of_v, ::cuda::std::plus> || is_predefined_bitwise_v); -template +template _CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE constexpr bool enable_ternary_reduction() { constexpr auto length = cub::detail::static_size_v; @@ -331,13 +329,14 @@ _CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE constexpr bool enable_ternary_red } else { + // apply SM90 min/max ternary reduction only if the input is natively int32/uint32 using T = detail::random_access_range_elem_t; // clang-format off NV_DISPATCH_TARGET( NV_PROVIDES_SM_90, - (return enable_ternary_reduction_sm90_v || enable_ternary_reduction_sm50_v;), + (return enable_ternary_reduction_sm90_v || enable_ternary_reduction_sm50_v;), NV_PROVIDES_SM_50, - (return enable_ternary_reduction_sm50_v;), + (return enable_ternary_reduction_sm50_v;), NV_ANY_TARGET, (return false;) ); @@ -345,14 +344,6 @@ _CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE constexpr bool enable_ternary_red } } -/*********************************************************************************************************************** - * Enable Promotion (Trait) - **********************************************************************************************************************/ - -template > -inline constexpr bool enable_promotion_v = - ::cuda::std::is_integral_v && sizeof(T) <= 2 && is_predefined_operator_v; - /*********************************************************************************************************************** * Internal Reduction Algorithms: Sequential, Binary, Ternary **********************************************************************************************************************/ @@ -453,6 +444,10 @@ _CCCL_DEVICE _CCCL_FORCEINLINE auto ThreadReduceSimd(const Input& input, Reducti return unsafe_bitcast(result)[0]; } +template +inline constexpr bool enable_promotion_v = + is_predefined_min_max_v && ::cuda::std::is_integral_v && sizeof(T) <= 2; + } // namespace internal /*********************************************************************************************************************** @@ -472,27 +467,28 @@ _CCCL_NODISCARD _CCCL_DEVICE _CCCL_FORCEINLINE AccumT ThreadReduce(const Input& } using cub::detail::is_one_of_v; using namespace cub::internal; - using PromT = ::cuda::std::_If, int, AccumT>; - // sizeof(ValueT) >= 8 requires too many registers - if constexpr ((!is_predefined_operator_v - && !is_one_of_v, SimdMax>) + using PromT = ::cuda::std::_If, int, AccumT>; + // TODO: should be part of the tuning policy + if constexpr ((!is_predefined_operator_v && !is_simd_operator_v) || sizeof(ValueT) >= 8) { return cub::internal::ThreadReduceSequential(input, reduction_op); } - else if constexpr (is_one_of_v, ::cuda::std::plus> - && is_one_of_v) - { - NV_IF_TARGET(NV_PROVIDES_SM_90, // - (return cub::internal::ThreadReduceSequential(input, reduction_op);), - (return cub::internal::ThreadReduceTernaryTree(input, reduction_op);)); - } else if constexpr (enable_simd_reduction()) { return cub::internal::ThreadReduceSimd(input, reduction_op); } else if constexpr (enable_ternary_reduction()) { + // with the current tuning policies, SM90/int32/+ uses too many registers (TODO: fix tuning policy) + if constexpr ((is_one_of_v, ::cuda::std::plus> + && is_one_of_v) + // the compiler generates bad code for int8/uint8 and min/max for SM90 + || (is_predefined_min_max_v && is_one_of_v) ) + { + NV_IF_TARGET(NV_PROVIDES_SM_90, // + (return cub::internal::ThreadReduceSequential(input, reduction_op);)); + } return cub::internal::ThreadReduceTernaryTree(input, reduction_op); } else diff --git a/cub/test/thread_reduce/catch2_test_thread_reduce.cu b/cub/test/thread_reduce/catch2_test_thread_reduce.cu index ba7342db9a5..a88db400874 100644 --- a/cub/test/thread_reduce/catch2_test_thread_reduce.cu +++ b/cub/test/thread_reduce/catch2_test_thread_reduce.cu @@ -54,7 +54,7 @@ * Thread Reduce Wrapper Kernels **********************************************************************************************************************/ -// #define CCCL_CHECK_SASS // instantiate only the kernels useful for SASS inspection +#define CCCL_CHECK_SASS // instantiate only the kernels useful for SASS inspection template __global__ void thread_reduce_kernel(const T* __restrict__ d_in, T* __restrict__ d_out, ReduceOperator reduce_operator) @@ -255,17 +255,24 @@ using narrow_precision_type_list = c2h::type_list< #endif >; -using fp_type_list = - c2h::type_list; +#if defined(CCCL_CHECK_SASS) + +using fp_type_list = c2h::type_list; + +using integral_type_list = c2h::type_list<::cuda::std::int8_t, ::cuda::std::int16_t, ::cuda::std::int32_t>; + +using cub_operator_integral_list = + c2h::type_list, cuda::std::multiplies<>, cuda::std::bit_xor<>, cuda::minimum<>>; + +using cub_operator_fp_list = c2h::type_list, cuda::minimum<>>; + +#else // !defined(CCCL_CHECK_SASS) using integral_type_list = c2h:: type_list<::cuda::std::int8_t, ::cuda::std::int16_t, ::cuda::std::uint16_t, ::cuda::std::int32_t, ::cuda::std::int64_t>; +using fp_type_list = c2h::type_list; + using cub_operator_integral_list = c2h::type_list, cuda::std::multiplies<>, @@ -278,6 +285,8 @@ using cub_operator_integral_list = using cub_operator_fp_list = c2h::type_list, cuda::std::multiplies<>, cuda::minimum<>, cuda::maximum<>>; +#endif // defined(CCCL_CHECK_SASS) + /*********************************************************************************************************************** * Verify results and kernel launch **********************************************************************************************************************/