From 0d3476a954d9201425f455c66803ed3080ccf71d Mon Sep 17 00:00:00 2001 From: Nader Al Awar Date: Fri, 24 Jan 2025 18:37:54 +0000 Subject: [PATCH 01/11] Add `dependent_launch` parameter to `TripleChevronFactory` --- cub/cub/detail/launcher/cuda_runtime.cuh | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cub/cub/detail/launcher/cuda_runtime.cuh b/cub/cub/detail/launcher/cuda_runtime.cuh index 81ef450f424..f59c26d7fbb 100644 --- a/cub/cub/detail/launcher/cuda_runtime.cuh +++ b/cub/cub/detail/launcher/cuda_runtime.cuh @@ -21,10 +21,10 @@ namespace detail struct TripleChevronFactory { - CUB_RUNTIME_FUNCTION THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron - operator()(dim3 grid, dim3 block, _CUDA_VSTD::size_t shared_mem, cudaStream_t stream) const + CUB_RUNTIME_FUNCTION THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron operator()( + dim3 grid, dim3 block, _CUDA_VSTD::size_t shared_mem, cudaStream_t stream, bool dependent_launch = false) const { - return THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(grid, block, shared_mem, stream); + return THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(grid, block, shared_mem, stream, dependent_launch); } CUB_RUNTIME_FUNCTION cudaError_t PtxVersion(int& version) From b8f18e6977a053518817311fc8f47ad10bf090a1 Mon Sep 17 00:00:00 2001 From: Nader Al Awar Date: Fri, 24 Jan 2025 18:39:06 +0000 Subject: [PATCH 02/11] Add `ItemsPerTile()` method to `PolicyWrapper` --- cub/cub/util_device.cuh | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/cub/cub/util_device.cuh b/cub/cub/util_device.cuh index 498f17c1259..0655ef62013 100644 --- a/cub/cub/util_device.cuh +++ b/cub/cub/util_device.cuh @@ -553,6 +553,11 @@ struct PolicyWrapper< { return StaticPolicyT::ITEMS_PER_THREAD; } + + CUB_RUNTIME_FUNCTION static constexpr int ItemsPerTile() + { + return StaticPolicyT::ITEMS_PER_TILE; + } }; template From 1b464261e9698290f12edb64e03d25d3fa751064 Mon Sep 17 00:00:00 2001 From: Nader Al Awar Date: Fri, 24 Jan 2025 18:41:40 +0000 Subject: [PATCH 03/11] Add `MergeSortPolicyWrapper` --- .../dispatch/tuning/tuning_merge_sort.cuh | 32 ++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/cub/cub/device/dispatch/tuning/tuning_merge_sort.cuh b/cub/cub/device/dispatch/tuning/tuning_merge_sort.cuh index 94d54b08509..34a7d525acd 100644 --- a/cub/cub/device/dispatch/tuning/tuning_merge_sort.cuh +++ b/cub/cub/device/dispatch/tuning/tuning_merge_sort.cuh @@ -42,10 +42,34 @@ CUB_NAMESPACE_BEGIN -namespace detail +namespace detail::merge_sort { -namespace merge_sort + +template +struct MergeSortPolicyWrapper : PolicyT +{ + CUB_RUNTIME_FUNCTION MergeSortPolicyWrapper(PolicyT base) + : PolicyT(base) + {} +}; + +template +struct MergeSortPolicyWrapper> + : StaticPolicyT +{ + CUB_RUNTIME_FUNCTION MergeSortPolicyWrapper(StaticPolicyT base) + : StaticPolicyT(base) + {} + + CUB_DEFINE_SUB_POLICY_GETTER(MergeSort); +}; + +template +CUB_RUNTIME_FUNCTION MergeSortPolicyWrapper MakeMergeSortPolicyWrapper(PolicyT policy) { + return MergeSortPolicyWrapper{policy}; +} + template struct policy_hub { @@ -88,8 +112,8 @@ struct policy_hub using MaxPolicy = Policy600; }; -} // namespace merge_sort -} // namespace detail + +} // namespace detail::merge_sort template using DeviceMergeSortPolicy CCCL_DEPRECATED_BECAUSE("This class is considered an implementation detail and it will be " From c690075577d9bf418a700b2a25b4092f4409424b Mon Sep 17 00:00:00 2001 From: Nader Al Awar Date: Fri, 24 Jan 2025 18:43:42 +0000 Subject: [PATCH 04/11] Add `KernelSource` and use `launcher_factory` to launch `merge_sort` kernels --- .../device/dispatch/dispatch_merge_sort.cuh | 162 +++++++++++------- 1 file changed, 103 insertions(+), 59 deletions(-) diff --git a/cub/cub/device/dispatch/dispatch_merge_sort.cuh b/cub/cub/device/dispatch/dispatch_merge_sort.cuh index 056522e162d..95ec287c80a 100644 --- a/cub/cub/device/dispatch/dispatch_merge_sort.cuh +++ b/cub/cub/device/dispatch/dispatch_merge_sort.cuh @@ -46,7 +46,6 @@ #include #include -#include #include #include @@ -54,6 +53,51 @@ CUB_NAMESPACE_BEGIN +namespace detail::merge_sort +{ +template +struct DeviceMergeSortKernelSource +{ + using KeyT = cub::detail::value_t; + using ValueT = cub::detail::value_t; + + CUB_DEFINE_KERNEL_GETTER( + MergeSortBlockSortKernel, + DeviceMergeSortBlockSortKernel< + MaxPolicyT, + KeyInputIteratorT, + ValueInputIteratorT, + KeyIteratorT, + ValueIteratorT, + OffsetT, + CompareOpT, + KeyT, + ValueT>); + + CUB_DEFINE_KERNEL_GETTER(MergeSortPartitionKernel, + DeviceMergeSortPartitionKernel); + + CUB_DEFINE_KERNEL_GETTER( + MergeSortMergeKernel, + DeviceMergeSortMergeKernel); +}; + +} // namespace detail::merge_sort + /******************************************************************************* * Policy ******************************************************************************/ @@ -64,14 +108,23 @@ template > + typename PolicyHub = detail::merge_sort::policy_hub, + typename KernelSource = detail::merge_sort::DeviceMergeSortKernelSource< + typename PolicyHub::MaxPolicy, + KeyInputIteratorT, + ValueInputIteratorT, + KeyIteratorT, + ValueIteratorT, + OffsetT, + CompareOpT>, + typename KernelLauncherFactory = detail::TripleChevronFactory> struct DispatchMergeSort { using KeyT = cub::detail::value_t; using ValueT = cub::detail::value_t; /// Whether or not there are values to be trucked along with keys - static constexpr bool KEYS_ONLY = std::is_same::value; + static constexpr bool KEYS_ONLY = ::cuda::std::is_same_v; // Problem state @@ -106,6 +159,10 @@ struct DispatchMergeSort int ptx_version; + KernelSource kernel_source; + + KernelLauncherFactory launcher_factory; + // Constructor CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE DispatchMergeSort( void* d_temp_storage, @@ -117,7 +174,9 @@ struct DispatchMergeSort OffsetT num_items, CompareOpT compare_op, cudaStream_t stream, - int ptx_version) + int ptx_version, + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {}) : d_temp_storage(d_temp_storage) , temp_storage_bytes(temp_storage_bytes) , d_input_keys(d_input_keys) @@ -128,11 +187,13 @@ struct DispatchMergeSort , compare_op(compare_op) , stream(stream) , ptx_version(ptx_version) + , kernel_source(kernel_source) + , launcher_factory(launcher_factory) {} // Invocation template - CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t Invoke() + CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t Invoke(ActivePolicyT policy = {}) { using MergePolicyT = typename ActivePolicyT::MergeSortPolicy; @@ -163,8 +224,9 @@ struct DispatchMergeSort do { - constexpr auto tile_size = merge_sort_helper_t::policy_t::ITEMS_PER_TILE; - const auto num_tiles = ::cuda::ceil_div(num_items, tile_size); + auto wrapped_policy = detail::merge_sort::MakeMergeSortPolicyWrapper(policy); + const auto tile_size = wrapped_policy.MergeSort().ItemsPerTile(); + const auto num_tiles = ::cuda::ceil_div(num_items, tile_size); const auto merge_partitions_size = static_cast(1 + num_tiles) * sizeof(OffsetT); const auto temporary_keys_storage_size = static_cast(num_items * sizeof(KeyT)); @@ -214,29 +276,18 @@ struct DispatchMergeSort auto items_buffer = static_cast(allocations[2]); // Invoke DeviceMergeSortBlockSortKernel - THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron( - static_cast(num_tiles), merge_sort_helper_t::policy_t::BLOCK_THREADS, 0, stream, true) - .doit( - detail::merge_sort::DeviceMergeSortBlockSortKernel< - typename PolicyHub::MaxPolicy, - KeyInputIteratorT, - ValueInputIteratorT, - KeyIteratorT, - ValueIteratorT, - OffsetT, - CompareOpT, - KeyT, - ValueT>, - ping, - d_input_keys, - d_input_items, - d_output_keys, - d_output_items, - num_items, - keys_buffer, - items_buffer, - compare_op, - cub::detail::vsmem_t{allocations[3]}); + launcher_factory(static_cast(num_tiles), wrapped_policy.MergeSort().BlockThreads(), 0, stream, true) + .doit(kernel_source.MergeSortBlockSortKernel(), + ping, + d_input_keys, + d_input_items, + d_output_keys, + d_output_items, + num_items, + keys_buffer, + items_buffer, + compare_op, + cub::detail::vsmem_t{allocations[3]}); error = CubDebug(detail::DebugSyncStream(stream)); if (cudaSuccess != error) @@ -273,9 +324,8 @@ struct DispatchMergeSort const OffsetT target_merged_tiles_number = OffsetT(2) << pass; // Partition - THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron( - partition_grid_size, threads_per_partition_block, 0, stream, true) - .doit(detail::merge_sort::DeviceMergeSortPartitionKernel, + launcher_factory(partition_grid_size, threads_per_partition_block, 0, stream, true) + .doit(kernel_source.MergeSortPartitionKernel(), ping, d_output_keys, keys_buffer, @@ -300,29 +350,19 @@ struct DispatchMergeSort } // Merge - THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron( - static_cast(num_tiles), static_cast(merge_sort_helper_t::policy_t::BLOCK_THREADS), 0, stream, true) - .doit( - detail::merge_sort::DeviceMergeSortMergeKernel< - typename PolicyHub::MaxPolicy, - KeyInputIteratorT, - ValueInputIteratorT, - KeyIteratorT, - ValueIteratorT, - OffsetT, - CompareOpT, - KeyT, - ValueT>, - ping, - d_output_keys, - d_output_items, - num_items, - keys_buffer, - items_buffer, - compare_op, - merge_partitions, - target_merged_tiles_number, - cub::detail::vsmem_t{allocations[3]}); + launcher_factory( + static_cast(num_tiles), static_cast(wrapped_policy.MergeSort().BlockThreads()), 0, stream, true) + .doit(kernel_source.MergeSortMergeKernel(), + ping, + d_output_keys, + d_output_items, + num_items, + keys_buffer, + items_buffer, + compare_op, + merge_partitions, + target_merged_tiles_number, + cub::detail::vsmem_t{allocations[3]}); error = CubDebug(detail::DebugSyncStream(stream)); if (cudaSuccess != error) @@ -342,6 +382,7 @@ struct DispatchMergeSort return error; } + template CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t Dispatch( void* d_temp_storage, std::size_t& temp_storage_bytes, @@ -351,7 +392,9 @@ struct DispatchMergeSort ValueIteratorT d_output_items, OffsetT num_items, CompareOpT compare_op, - cudaStream_t stream) + cudaStream_t stream, + KernelSource kernel_source = {}, + MaxPolicyT max_policy = {}) { cudaError error = cudaSuccess; do @@ -375,10 +418,11 @@ struct DispatchMergeSort num_items, compare_op, stream, - ptx_version); + ptx_version, + kernel_source); // Dispatch to chained policy - error = CubDebug(PolicyHub::MaxPolicy::Invoke(ptx_version, dispatch)); + error = CubDebug(max_policy.Invoke(ptx_version, dispatch)); if (cudaSuccess != error) { break; From 7cced5e0fc7b7f9a21f0b2118c135fb23859d2c6 Mon Sep 17 00:00:00 2001 From: Nader Al Awar Date: Tue, 28 Jan 2025 20:12:57 +0000 Subject: [PATCH 05/11] Add extra layer of indirection to `merge_sort_vsmem_helper_t` to select constexpr `BLOCK_THREADS` if available --- .../device/dispatch/dispatch_merge_sort.cuh | 9 ++++++-- .../device/dispatch/kernels/merge_sort.cuh | 21 ++++++++++++++++++- 2 files changed, 27 insertions(+), 3 deletions(-) diff --git a/cub/cub/device/dispatch/dispatch_merge_sort.cuh b/cub/cub/device/dispatch/dispatch_merge_sort.cuh index 95ec287c80a..132f4a271ea 100644 --- a/cub/cub/device/dispatch/dispatch_merge_sort.cuh +++ b/cub/cub/device/dispatch/dispatch_merge_sort.cuh @@ -276,7 +276,8 @@ struct DispatchMergeSort auto items_buffer = static_cast(allocations[2]); // Invoke DeviceMergeSortBlockSortKernel - launcher_factory(static_cast(num_tiles), wrapped_policy.MergeSort().BlockThreads(), 0, stream, true) + launcher_factory( + static_cast(num_tiles), merge_sort_helper_t::block_threads_helper_t::BlockThreads(), 0, stream, true) .doit(kernel_source.MergeSortBlockSortKernel(), ping, d_input_keys, @@ -351,7 +352,11 @@ struct DispatchMergeSort // Merge launcher_factory( - static_cast(num_tiles), static_cast(wrapped_policy.MergeSort().BlockThreads()), 0, stream, true) + static_cast(num_tiles), + static_cast(merge_sort_helper_t::block_threads_helper_t::BlockThreads()), + 0, + stream, + true) .doit(kernel_source.MergeSortMergeKernel(), ping, d_output_keys, diff --git a/cub/cub/device/dispatch/kernels/merge_sort.cuh b/cub/cub/device/dispatch/kernels/merge_sort.cuh index 1065313c20d..29d93f9c1ed 100644 --- a/cub/cub/device/dispatch/kernels/merge_sort.cuh +++ b/cub/cub/device/dispatch/kernels/merge_sort.cuh @@ -48,6 +48,24 @@ struct dual_policy_agent_helper_t static constexpr auto fallback_size = sizeof(typename fallback_agent_t::TempStorage); }; +template +struct merge_sort_vsmem_block_threads_t +{ + CUB_RUNTIME_FUNCTION static int BlockThreads() + { + return PolicyT::BlockThreads(); + } +}; + +template +struct merge_sort_vsmem_block_threads_t> +{ + CUB_RUNTIME_FUNCTION static constexpr int BlockThreads() + { + return StaticPolicyT::BLOCK_THREADS; + } +}; + /** * @brief Helper class template for merge sort-specific virtual shared memory handling. The merge sort algorithm in its * current implementation relies on the fact that both the sorting as well as the merging kernels use the same tile @@ -115,7 +133,8 @@ public: using policy_t = ::cuda::std::_If; using block_sort_agent_t = ::cuda::std::_If; - using merge_agent_t = ::cuda::std::_If; + using merge_agent_t = ::cuda::std::_If; + using block_threads_helper_t = merge_sort_vsmem_block_threads_t; }; template Date: Tue, 28 Jan 2025 21:30:03 +0000 Subject: [PATCH 06/11] Revert "Add extra layer of indirection to `merge_sort_vsmem_helper_t` to select constexpr `BLOCK_THREADS` if available" This reverts commit 7cced5e0fc7b7f9a21f0b2118c135fb23859d2c6. --- .../device/dispatch/dispatch_merge_sort.cuh | 9 ++------ .../device/dispatch/kernels/merge_sort.cuh | 21 +------------------ 2 files changed, 3 insertions(+), 27 deletions(-) diff --git a/cub/cub/device/dispatch/dispatch_merge_sort.cuh b/cub/cub/device/dispatch/dispatch_merge_sort.cuh index 132f4a271ea..95ec287c80a 100644 --- a/cub/cub/device/dispatch/dispatch_merge_sort.cuh +++ b/cub/cub/device/dispatch/dispatch_merge_sort.cuh @@ -276,8 +276,7 @@ struct DispatchMergeSort auto items_buffer = static_cast(allocations[2]); // Invoke DeviceMergeSortBlockSortKernel - launcher_factory( - static_cast(num_tiles), merge_sort_helper_t::block_threads_helper_t::BlockThreads(), 0, stream, true) + launcher_factory(static_cast(num_tiles), wrapped_policy.MergeSort().BlockThreads(), 0, stream, true) .doit(kernel_source.MergeSortBlockSortKernel(), ping, d_input_keys, @@ -352,11 +351,7 @@ struct DispatchMergeSort // Merge launcher_factory( - static_cast(num_tiles), - static_cast(merge_sort_helper_t::block_threads_helper_t::BlockThreads()), - 0, - stream, - true) + static_cast(num_tiles), static_cast(wrapped_policy.MergeSort().BlockThreads()), 0, stream, true) .doit(kernel_source.MergeSortMergeKernel(), ping, d_output_keys, diff --git a/cub/cub/device/dispatch/kernels/merge_sort.cuh b/cub/cub/device/dispatch/kernels/merge_sort.cuh index 29d93f9c1ed..1065313c20d 100644 --- a/cub/cub/device/dispatch/kernels/merge_sort.cuh +++ b/cub/cub/device/dispatch/kernels/merge_sort.cuh @@ -48,24 +48,6 @@ struct dual_policy_agent_helper_t static constexpr auto fallback_size = sizeof(typename fallback_agent_t::TempStorage); }; -template -struct merge_sort_vsmem_block_threads_t -{ - CUB_RUNTIME_FUNCTION static int BlockThreads() - { - return PolicyT::BlockThreads(); - } -}; - -template -struct merge_sort_vsmem_block_threads_t> -{ - CUB_RUNTIME_FUNCTION static constexpr int BlockThreads() - { - return StaticPolicyT::BLOCK_THREADS; - } -}; - /** * @brief Helper class template for merge sort-specific virtual shared memory handling. The merge sort algorithm in its * current implementation relies on the fact that both the sorting as well as the merging kernels use the same tile @@ -133,8 +115,7 @@ public: using policy_t = ::cuda::std::_If; using block_sort_agent_t = ::cuda::std::_If; - using merge_agent_t = ::cuda::std::_If; - using block_threads_helper_t = merge_sort_vsmem_block_threads_t; + using merge_agent_t = ::cuda::std::_If; }; template Date: Tue, 28 Jan 2025 22:04:56 +0000 Subject: [PATCH 07/11] Move the vsmem_helper to kernel source and read `BlockThreads` from there instead of the policy directly --- .../device/dispatch/dispatch_merge_sort.cuh | 47 +++++++++++-------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/cub/cub/device/dispatch/dispatch_merge_sort.cuh b/cub/cub/device/dispatch/dispatch_merge_sort.cuh index 95ec287c80a..c0d0bf2b3ff 100644 --- a/cub/cub/device/dispatch/dispatch_merge_sort.cuh +++ b/cub/cub/device/dispatch/dispatch_merge_sort.cuh @@ -67,6 +67,28 @@ struct DeviceMergeSortKernelSource using KeyT = cub::detail::value_t; using ValueT = cub::detail::value_t; + using merge_sort_helper_t = detail::merge_sort::merge_sort_vsmem_helper_t< + typename MaxPolicyT::MergeSortPolicy, + KeyInputIteratorT, + ValueInputIteratorT, + KeyIteratorT, + ValueIteratorT, + OffsetT, + CompareOpT, + KeyT, + ValueT>; + + using BlockSortVSmemHelperT = detail::vsmem_helper_impl; + using MergeAgentVSmemHelperT = detail::vsmem_helper_impl; + + static constexpr std::size_t block_sort_vsmem_per_block = BlockSortVSmemHelperT::vsmem_per_block; + static constexpr std::size_t merge_vsmem_per_block = MergeAgentVSmemHelperT::vsmem_per_block; + + CUB_RUNTIME_FUNCTION static constexpr int BlockThreads(typename MaxPolicyT::MergeSortPolicy /*policy*/) + { + return merge_sort_helper_t::policy_t::BLOCK_THREADS; + } + CUB_DEFINE_KERNEL_GETTER( MergeSortBlockSortKernel, DeviceMergeSortBlockSortKernel< @@ -195,22 +217,6 @@ struct DispatchMergeSort template CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t Invoke(ActivePolicyT policy = {}) { - using MergePolicyT = typename ActivePolicyT::MergeSortPolicy; - - using merge_sort_helper_t = detail::merge_sort::merge_sort_vsmem_helper_t< - MergePolicyT, - KeyInputIteratorT, - ValueInputIteratorT, - KeyIteratorT, - ValueIteratorT, - OffsetT, - CompareOpT, - KeyT, - ValueT>; - - using BlockSortVSmemHelperT = detail::vsmem_helper_impl; - using MergeAgentVSmemHelperT = detail::vsmem_helper_impl; - cudaError error = cudaSuccess; if (num_items == 0) @@ -236,8 +242,8 @@ struct DispatchMergeSort * Merge sort supports large types, which can lead to excessive shared memory size requirements. In these cases, * merge sort allocates virtual shared memory that resides in global memory. */ - const std::size_t block_sort_smem_size = num_tiles * BlockSortVSmemHelperT::vsmem_per_block; - const std::size_t merge_smem_size = num_tiles * MergeAgentVSmemHelperT::vsmem_per_block; + const std::size_t block_sort_smem_size = num_tiles * KernelSource::block_sort_vsmem_per_block; + const std::size_t merge_smem_size = num_tiles * KernelSource::merge_vsmem_per_block; const std::size_t virtual_shared_memory_size = (::cuda::std::max)(block_sort_smem_size, merge_smem_size); void* allocations[4] = {nullptr, nullptr, nullptr, nullptr}; @@ -276,7 +282,8 @@ struct DispatchMergeSort auto items_buffer = static_cast(allocations[2]); // Invoke DeviceMergeSortBlockSortKernel - launcher_factory(static_cast(num_tiles), wrapped_policy.MergeSort().BlockThreads(), 0, stream, true) + launcher_factory( + static_cast(num_tiles), kernel_source.BlockThreads(wrapped_policy.MergeSort()), 0, stream, true) .doit(kernel_source.MergeSortBlockSortKernel(), ping, d_input_keys, @@ -351,7 +358,7 @@ struct DispatchMergeSort // Merge launcher_factory( - static_cast(num_tiles), static_cast(wrapped_policy.MergeSort().BlockThreads()), 0, stream, true) + static_cast(num_tiles), kernel_source.BlockThreads(wrapped_policy.MergeSort()), 0, stream, true) .doit(kernel_source.MergeSortMergeKernel(), ping, d_output_keys, From 39589d460225757498f0eb0e0a91824175568bf2 Mon Sep 17 00:00:00 2001 From: Nader Al Awar Date: Wed, 29 Jan 2025 00:19:09 +0000 Subject: [PATCH 08/11] Make `BlockThreads` templated on the policy type --- cub/cub/device/dispatch/dispatch_merge_sort.cuh | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cub/cub/device/dispatch/dispatch_merge_sort.cuh b/cub/cub/device/dispatch/dispatch_merge_sort.cuh index c0d0bf2b3ff..702752c55b2 100644 --- a/cub/cub/device/dispatch/dispatch_merge_sort.cuh +++ b/cub/cub/device/dispatch/dispatch_merge_sort.cuh @@ -84,7 +84,8 @@ struct DeviceMergeSortKernelSource static constexpr std::size_t block_sort_vsmem_per_block = BlockSortVSmemHelperT::vsmem_per_block; static constexpr std::size_t merge_vsmem_per_block = MergeAgentVSmemHelperT::vsmem_per_block; - CUB_RUNTIME_FUNCTION static constexpr int BlockThreads(typename MaxPolicyT::MergeSortPolicy /*policy*/) + template + CUB_RUNTIME_FUNCTION static constexpr int BlockThreads(PolicyT /*policy*/) { return merge_sort_helper_t::policy_t::BLOCK_THREADS; } From d702dbb7bf41033ded5eed97aef7c640e1ded7ee Mon Sep 17 00:00:00 2001 From: Nader Al Awar Date: Wed, 29 Jan 2025 16:35:40 +0000 Subject: [PATCH 09/11] Obtain `ItemsPerTile` from the kernel source through vsmem helper --- cub/cub/device/dispatch/dispatch_merge_sort.cuh | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/cub/cub/device/dispatch/dispatch_merge_sort.cuh b/cub/cub/device/dispatch/dispatch_merge_sort.cuh index 702752c55b2..9e8219f666f 100644 --- a/cub/cub/device/dispatch/dispatch_merge_sort.cuh +++ b/cub/cub/device/dispatch/dispatch_merge_sort.cuh @@ -90,6 +90,12 @@ struct DeviceMergeSortKernelSource return merge_sort_helper_t::policy_t::BLOCK_THREADS; } + template + CUB_RUNTIME_FUNCTION static constexpr int ItemsPerTile(PolicyT /*policy*/) + { + return merge_sort_helper_t::policy_t::ITEMS_PER_TILE; + } + CUB_DEFINE_KERNEL_GETTER( MergeSortBlockSortKernel, DeviceMergeSortBlockSortKernel< @@ -232,7 +238,7 @@ struct DispatchMergeSort do { auto wrapped_policy = detail::merge_sort::MakeMergeSortPolicyWrapper(policy); - const auto tile_size = wrapped_policy.MergeSort().ItemsPerTile(); + const auto tile_size = kernel_source.ItemsPerTile(wrapped_policy.MergeSort()); const auto num_tiles = ::cuda::ceil_div(num_items, tile_size); const auto merge_partitions_size = static_cast(1 + num_tiles) * sizeof(OffsetT); From 4cbee12e5bc0af8ee1037ac544c5070b504f9ac0 Mon Sep 17 00:00:00 2001 From: Nader Al Awar Date: Wed, 29 Jan 2025 20:30:04 +0000 Subject: [PATCH 10/11] Change vsmem indirection so that it is its own template parameter passed to `DispatchMergeSort` --- .../device/dispatch/dispatch_merge_sort.cuh | 98 +++++++++---------- .../device/dispatch/kernels/merge_sort.cuh | 22 +++++ 2 files changed, 66 insertions(+), 54 deletions(-) diff --git a/cub/cub/device/dispatch/dispatch_merge_sort.cuh b/cub/cub/device/dispatch/dispatch_merge_sort.cuh index 9e8219f666f..98a4b40e8f8 100644 --- a/cub/cub/device/dispatch/dispatch_merge_sort.cuh +++ b/cub/cub/device/dispatch/dispatch_merge_sort.cuh @@ -67,35 +67,6 @@ struct DeviceMergeSortKernelSource using KeyT = cub::detail::value_t; using ValueT = cub::detail::value_t; - using merge_sort_helper_t = detail::merge_sort::merge_sort_vsmem_helper_t< - typename MaxPolicyT::MergeSortPolicy, - KeyInputIteratorT, - ValueInputIteratorT, - KeyIteratorT, - ValueIteratorT, - OffsetT, - CompareOpT, - KeyT, - ValueT>; - - using BlockSortVSmemHelperT = detail::vsmem_helper_impl; - using MergeAgentVSmemHelperT = detail::vsmem_helper_impl; - - static constexpr std::size_t block_sort_vsmem_per_block = BlockSortVSmemHelperT::vsmem_per_block; - static constexpr std::size_t merge_vsmem_per_block = MergeAgentVSmemHelperT::vsmem_per_block; - - template - CUB_RUNTIME_FUNCTION static constexpr int BlockThreads(PolicyT /*policy*/) - { - return merge_sort_helper_t::policy_t::BLOCK_THREADS; - } - - template - CUB_RUNTIME_FUNCTION static constexpr int ItemsPerTile(PolicyT /*policy*/) - { - return merge_sort_helper_t::policy_t::ITEMS_PER_TILE; - } - CUB_DEFINE_KERNEL_GETTER( MergeSortBlockSortKernel, DeviceMergeSortBlockSortKernel< @@ -131,22 +102,33 @@ struct DeviceMergeSortKernelSource * Policy ******************************************************************************/ -template , - typename KernelSource = detail::merge_sort::DeviceMergeSortKernelSource< - typename PolicyHub::MaxPolicy, - KeyInputIteratorT, - ValueInputIteratorT, - KeyIteratorT, - ValueIteratorT, - OffsetT, - CompareOpT>, - typename KernelLauncherFactory = detail::TripleChevronFactory> +template < + typename KeyInputIteratorT, + typename ValueInputIteratorT, + typename KeyIteratorT, + typename ValueIteratorT, + typename OffsetT, + typename CompareOpT, + typename PolicyHub = detail::merge_sort::policy_hub, + typename KernelSource = detail::merge_sort::DeviceMergeSortKernelSource< + typename PolicyHub::MaxPolicy, + KeyInputIteratorT, + ValueInputIteratorT, + KeyIteratorT, + ValueIteratorT, + OffsetT, + CompareOpT>, + typename KernelLauncherFactory = detail::TripleChevronFactory, + typename VSMemHelperPolicyT = detail::merge_sort::merge_sort_vsmem_helper_t< + typename PolicyHub::MaxPolicy::MergeSortPolicy, + KeyInputIteratorT, + ValueInputIteratorT, + KeyIteratorT, + ValueIteratorT, + OffsetT, + CompareOpT, + cub::detail::value_t, + cub::detail::value_t>> struct DispatchMergeSort { using KeyT = cub::detail::value_t; @@ -192,6 +174,8 @@ struct DispatchMergeSort KernelLauncherFactory launcher_factory; + VSMemHelperPolicyT vsmem_helper; + // Constructor CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE DispatchMergeSort( void* d_temp_storage, @@ -205,7 +189,8 @@ struct DispatchMergeSort cudaStream_t stream, int ptx_version, KernelSource kernel_source = {}, - KernelLauncherFactory launcher_factory = {}) + KernelLauncherFactory launcher_factory = {}, + VSMemHelperPolicyT vsmem_helper = {}) : d_temp_storage(d_temp_storage) , temp_storage_bytes(temp_storage_bytes) , d_input_keys(d_input_keys) @@ -218,6 +203,7 @@ struct DispatchMergeSort , ptx_version(ptx_version) , kernel_source(kernel_source) , launcher_factory(launcher_factory) + , vsmem_helper(vsmem_helper) {} // Invocation @@ -238,7 +224,7 @@ struct DispatchMergeSort do { auto wrapped_policy = detail::merge_sort::MakeMergeSortPolicyWrapper(policy); - const auto tile_size = kernel_source.ItemsPerTile(wrapped_policy.MergeSort()); + const auto tile_size = vsmem_helper.ItemsPerTile(wrapped_policy.MergeSort()); const auto num_tiles = ::cuda::ceil_div(num_items, tile_size); const auto merge_partitions_size = static_cast(1 + num_tiles) * sizeof(OffsetT); @@ -249,8 +235,8 @@ struct DispatchMergeSort * Merge sort supports large types, which can lead to excessive shared memory size requirements. In these cases, * merge sort allocates virtual shared memory that resides in global memory. */ - const std::size_t block_sort_smem_size = num_tiles * KernelSource::block_sort_vsmem_per_block; - const std::size_t merge_smem_size = num_tiles * KernelSource::merge_vsmem_per_block; + const std::size_t block_sort_smem_size = num_tiles * vsmem_helper.block_sort_vsmem_per_block(); + const std::size_t merge_smem_size = num_tiles * vsmem_helper.merge_vsmem_per_block(); const std::size_t virtual_shared_memory_size = (::cuda::std::max)(block_sort_smem_size, merge_smem_size); void* allocations[4] = {nullptr, nullptr, nullptr, nullptr}; @@ -290,7 +276,7 @@ struct DispatchMergeSort // Invoke DeviceMergeSortBlockSortKernel launcher_factory( - static_cast(num_tiles), kernel_source.BlockThreads(wrapped_policy.MergeSort()), 0, stream, true) + static_cast(num_tiles), vsmem_helper.BlockThreads(wrapped_policy.MergeSort()), 0, stream, true) .doit(kernel_source.MergeSortBlockSortKernel(), ping, d_input_keys, @@ -365,7 +351,7 @@ struct DispatchMergeSort // Merge launcher_factory( - static_cast(num_tiles), kernel_source.BlockThreads(wrapped_policy.MergeSort()), 0, stream, true) + static_cast(num_tiles), vsmem_helper.BlockThreads(wrapped_policy.MergeSort()), 0, stream, true) .doit(kernel_source.MergeSortMergeKernel(), ping, d_output_keys, @@ -407,8 +393,10 @@ struct DispatchMergeSort OffsetT num_items, CompareOpT compare_op, cudaStream_t stream, - KernelSource kernel_source = {}, - MaxPolicyT max_policy = {}) + KernelSource kernel_source = {}, + KernelLauncherFactory launcher_factory = {}, + MaxPolicyT max_policy = {}, + VSMemHelperPolicyT vsmem_helper = {}) { cudaError error = cudaSuccess; do @@ -433,7 +421,9 @@ struct DispatchMergeSort compare_op, stream, ptx_version, - kernel_source); + kernel_source, + launcher_factory, + vsmem_helper); // Dispatch to chained policy error = CubDebug(max_policy.Invoke(ptx_version, dispatch)); diff --git a/cub/cub/device/dispatch/kernels/merge_sort.cuh b/cub/cub/device/dispatch/kernels/merge_sort.cuh index 1065313c20d..3e20b304d11 100644 --- a/cub/cub/device/dispatch/kernels/merge_sort.cuh +++ b/cub/cub/device/dispatch/kernels/merge_sort.cuh @@ -116,6 +116,28 @@ public: using block_sort_agent_t = ::cuda::std::_If; using merge_agent_t = ::cuda::std::_If; + + CUB_RUNTIME_FUNCTION static constexpr std::size_t block_sort_vsmem_per_block() + { + return detail::vsmem_helper_impl::vsmem_per_block; + } + + CUB_RUNTIME_FUNCTION static constexpr std::size_t merge_vsmem_per_block() + { + return detail::vsmem_helper_impl::vsmem_per_block; + } + + template + CUB_RUNTIME_FUNCTION static constexpr int BlockThreads(PolicyT /*policy*/) + { + return policy_t::BLOCK_THREADS; + } + + template + CUB_RUNTIME_FUNCTION static constexpr int ItemsPerTile(PolicyT /*policy*/) + { + return policy_t::ITEMS_PER_TILE; + } }; template Date: Thu, 30 Jan 2025 02:32:13 +0000 Subject: [PATCH 11/11] Use `_CCCL_HOST_DEVICE` for RTC --- cub/cub/device/dispatch/kernels/merge_sort.cuh | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/cub/cub/device/dispatch/kernels/merge_sort.cuh b/cub/cub/device/dispatch/kernels/merge_sort.cuh index 3e20b304d11..c9a8a61395a 100644 --- a/cub/cub/device/dispatch/kernels/merge_sort.cuh +++ b/cub/cub/device/dispatch/kernels/merge_sort.cuh @@ -117,24 +117,24 @@ public: ::cuda::std::_If; using merge_agent_t = ::cuda::std::_If; - CUB_RUNTIME_FUNCTION static constexpr std::size_t block_sort_vsmem_per_block() + _CCCL_HOST_DEVICE static constexpr ::cuda::std::size_t block_sort_vsmem_per_block() { return detail::vsmem_helper_impl::vsmem_per_block; } - CUB_RUNTIME_FUNCTION static constexpr std::size_t merge_vsmem_per_block() + _CCCL_HOST_DEVICE static constexpr ::cuda::std::size_t merge_vsmem_per_block() { return detail::vsmem_helper_impl::vsmem_per_block; } template - CUB_RUNTIME_FUNCTION static constexpr int BlockThreads(PolicyT /*policy*/) + _CCCL_HOST_DEVICE static constexpr int BlockThreads(PolicyT /*policy*/) { return policy_t::BLOCK_THREADS; } template - CUB_RUNTIME_FUNCTION static constexpr int ItemsPerTile(PolicyT /*policy*/) + _CCCL_HOST_DEVICE static constexpr int ItemsPerTile(PolicyT /*policy*/) { return policy_t::ITEMS_PER_TILE; }