Add dynamic CUB dispatch for merge_sort (#3525)
* Add `dependent_launch` parameter to `TripleChevronFactory`
* Add `ItemsPerTile()` method to `PolicyWrapper`
* Add `MergeSortPolicyWrapper`
* Add `KernelSource` and use `launcher_factory` to launch `merge_sort` kernels
* Move the vsmem_helper to kernel source and read `BlockThreads` from there instead of the policy directly
* Make `BlockThreads` templated on the policy type
* Obtain `ItemsPerTile` from the kernel source through vsmem helper
* Change vsmem indirection so that it is its own template parameter passed to `DispatchMergeSort`
* Use `_CCCL_HOST_DEVICE` for RTC
NaderAlAwar authored Jan 30, 2025
1 parent 3e888d8 commit 15a0116
Showing 5 changed files with 189 additions and 90 deletions.
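Taken together, the changes make the merge sort dispatch layer pluggable: kernels are obtained through a `KernelSource`, launches go through a `KernelLauncherFactory`, and virtual shared memory sizing goes through a vsmem helper object, so callers such as NVRTC-based toolchains can substitute their own pieces. As a sketch of the launcher-factory seam, here is a hypothetical wrapper that forwards to the default `TripleChevronFactory` while logging each configuration (`LoggingLauncherFactory` and the logging are illustrative only, not part of this commit; it assumes only the interface visible in the diffs below):

#include <cub/detail/launcher/cuda_runtime.cuh>

#include <cstdio>

// Forwards to the default factory; the only assumed interface is
// operator()(grid, block, smem, stream, dependent_launch) and PtxVersion(int&),
// both shown in the first diff below.
struct LoggingLauncherFactory
{
  cub::detail::TripleChevronFactory base;

  CUB_RUNTIME_FUNCTION THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron operator()(
    dim3 grid, dim3 block, _CUDA_VSTD::size_t shared_mem, cudaStream_t stream, bool dependent_launch = false) const
  {
    // Log only on the host; device-side launches skip the logging.
    NV_IF_TARGET(NV_IS_HOST,
                 (std::printf("launch: grid=%u block=%u smem=%zu pdl=%d\n",
                              grid.x,
                              block.x,
                              static_cast<std::size_t>(shared_mem),
                              static_cast<int>(dependent_launch));));
    return base(grid, block, shared_mem, stream, dependent_launch);
  }

  CUB_RUNTIME_FUNCTION cudaError_t PtxVersion(int& version)
  {
    return base.PtxVersion(version);
  }
};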
6 changes: 3 additions & 3 deletions cub/cub/detail/launcher/cuda_runtime.cuh
@@ -21,10 +21,10 @@ namespace detail

struct TripleChevronFactory
{
- CUB_RUNTIME_FUNCTION THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron
- operator()(dim3 grid, dim3 block, _CUDA_VSTD::size_t shared_mem, cudaStream_t stream) const
+ CUB_RUNTIME_FUNCTION THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron operator()(
+ dim3 grid, dim3 block, _CUDA_VSTD::size_t shared_mem, cudaStream_t stream, bool dependent_launch = false) const
{
- return THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(grid, block, shared_mem, stream);
+ return THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(grid, block, shared_mem, stream, dependent_launch);
}

CUB_RUNTIME_FUNCTION cudaError_t PtxVersion(int& version)
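The new `dependent_launch` flag feeds Thrust's `triple_chevron`, which uses it to opt a kernel into programmatic dependent launch (PDL) on architectures that support it, letting a dependent kernel begin launching before its predecessor finishes. A sketch of what this amounts to at the CUDA runtime level (an illustration of the mechanism; the actual plumbing lives inside `triple_chevron`, not in this diff):

#include <cstddef>

#include <cuda_runtime.h>

// Launch `kernel` so that it may overlap with the preceding kernel on `stream`.
// The dependent kernel itself decides when predecessor results must be visible
// by calling cudaGridDependencySynchronize() before consuming them.
template <typename... Args>
cudaError_t launch_dependent(
  dim3 grid, dim3 block, std::size_t shared_mem, cudaStream_t stream, void (*kernel)(Args...), Args... args)
{
  cudaLaunchAttribute attribute[1];
  attribute[0].id                                         = cudaLaunchAttributeProgrammaticStreamSerialization;
  attribute[0].val.programmaticStreamSerializationAllowed = 1;

  cudaLaunchConfig_t config{};
  config.gridDim          = grid;
  config.blockDim         = block;
  config.dynamicSmemBytes = shared_mem;
  config.stream           = stream;
  config.attrs            = attribute;
  config.numAttrs         = 1;

  return cudaLaunchKernelEx(&config, kernel, args...);
}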
214 changes: 131 additions & 83 deletions cub/cub/device/dispatch/dispatch_merge_sort.cuh
@@ -46,32 +46,96 @@
#include <cub/util_vsmem.cuh>

#include <thrust/detail/integer_math.h>
#include <thrust/system/cuda/detail/core/triple_chevron_launch.h>

#include <cuda/std/__algorithm/max.h>
#include <cuda/std/__algorithm/min.h>
#include <cuda/std/type_traits>

CUB_NAMESPACE_BEGIN

- /*******************************************************************************
-  * Policy
-  ******************************************************************************/
-
- template <typename KeyInputIteratorT,
+ namespace detail::merge_sort
+ {
+ template <typename MaxPolicyT,
+ typename KeyInputIteratorT,
typename KeyInputIteratorT,
typename ValueInputIteratorT,
typename KeyIteratorT,
typename ValueIteratorT,
typename OffsetT,
- typename CompareOpT,
- typename PolicyHub = detail::merge_sort::policy_hub<KeyIteratorT>>
+ typename CompareOpT>
struct DeviceMergeSortKernelSource
{
using KeyT = cub::detail::value_t<KeyIteratorT>;
using ValueT = cub::detail::value_t<ValueIteratorT>;

CUB_DEFINE_KERNEL_GETTER(
MergeSortBlockSortKernel,
DeviceMergeSortBlockSortKernel<
MaxPolicyT,
KeyInputIteratorT,
ValueInputIteratorT,
KeyIteratorT,
ValueIteratorT,
OffsetT,
CompareOpT,
KeyT,
ValueT>);

CUB_DEFINE_KERNEL_GETTER(MergeSortPartitionKernel,
DeviceMergeSortPartitionKernel<KeyIteratorT, OffsetT, CompareOpT, KeyT>);

CUB_DEFINE_KERNEL_GETTER(
MergeSortMergeKernel,
DeviceMergeSortMergeKernel<MaxPolicyT,
KeyInputIteratorT,
ValueInputIteratorT,
KeyIteratorT,
ValueIteratorT,
OffsetT,
CompareOpT,
KeyT,
ValueT>);
};

} // namespace detail::merge_sort
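// The getters above centralize which kernel entry points the dispatcher uses.
// CUB_DEFINE_KERNEL_GETTER presumably expands to a thin static getter along
// these lines (an assumed shape for illustration, not the actual definition
// from the CUB detail headers): it returns the address of the instantiated
// kernel, so a drop-in replacement KernelSource can hand back kernels obtained
// elsewhere, e.g. compiled at run time through NVRTC.
//
// #define CUB_DEFINE_KERNEL_GETTER(name, ...)                           \
//   CUB_RUNTIME_FUNCTION static constexpr decltype(&__VA_ARGS__) name() \
//   {                                                                   \
//     return &__VA_ARGS__;                                              \
//   }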

/*******************************************************************************
* Policy
******************************************************************************/

template <
typename KeyInputIteratorT,
typename ValueInputIteratorT,
typename KeyIteratorT,
typename ValueIteratorT,
typename OffsetT,
typename CompareOpT,
typename PolicyHub = detail::merge_sort::policy_hub<KeyIteratorT>,
typename KernelSource = detail::merge_sort::DeviceMergeSortKernelSource<
typename PolicyHub::MaxPolicy,
KeyInputIteratorT,
ValueInputIteratorT,
KeyIteratorT,
ValueIteratorT,
OffsetT,
CompareOpT>,
typename KernelLauncherFactory = detail::TripleChevronFactory,
typename VSMemHelperPolicyT = detail::merge_sort::merge_sort_vsmem_helper_t<
typename PolicyHub::MaxPolicy::MergeSortPolicy,
KeyInputIteratorT,
ValueInputIteratorT,
KeyIteratorT,
ValueIteratorT,
OffsetT,
CompareOpT,
cub::detail::value_t<KeyIteratorT>,
cub::detail::value_t<ValueIteratorT>>>
struct DispatchMergeSort
{
using KeyT = cub::detail::value_t<KeyIteratorT>;
using ValueT = cub::detail::value_t<ValueIteratorT>;

/// Whether or not there are values to be trucked along with keys
- static constexpr bool KEYS_ONLY = std::is_same<ValueT, NullType>::value;
+ static constexpr bool KEYS_ONLY = ::cuda::std::is_same_v<ValueT, NullType>;

// Problem state

@@ -106,6 +170,12 @@ struct DispatchMergeSort

int ptx_version;

KernelSource kernel_source;

KernelLauncherFactory launcher_factory;

VSMemHelperPolicyT vsmem_helper;

// Constructor
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE DispatchMergeSort(
void* d_temp_storage,
@@ -117,7 +187,10 @@
OffsetT num_items,
CompareOpT compare_op,
cudaStream_t stream,
- int ptx_version)
+ int ptx_version,
+ KernelSource kernel_source = {},
+ KernelLauncherFactory launcher_factory = {},
+ VSMemHelperPolicyT vsmem_helper = {})
: d_temp_storage(d_temp_storage)
, temp_storage_bytes(temp_storage_bytes)
, d_input_keys(d_input_keys)
@@ -128,28 +201,15 @@
, compare_op(compare_op)
, stream(stream)
, ptx_version(ptx_version)
, kernel_source(kernel_source)
, launcher_factory(launcher_factory)
, vsmem_helper(vsmem_helper)
{}

// Invocation
template <typename ActivePolicyT>
- CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t Invoke()
+ CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE cudaError_t Invoke(ActivePolicyT policy = {})
{
- using MergePolicyT = typename ActivePolicyT::MergeSortPolicy;
-
- using merge_sort_helper_t = detail::merge_sort::merge_sort_vsmem_helper_t<
- MergePolicyT,
- KeyInputIteratorT,
- ValueInputIteratorT,
- KeyIteratorT,
- ValueIteratorT,
- OffsetT,
- CompareOpT,
- KeyT,
- ValueT>;
-
- using BlockSortVSmemHelperT = detail::vsmem_helper_impl<typename merge_sort_helper_t::block_sort_agent_t>;
- using MergeAgentVSmemHelperT = detail::vsmem_helper_impl<typename merge_sort_helper_t::merge_agent_t>;

cudaError error = cudaSuccess;

if (num_items == 0)
@@ -163,8 +223,9 @@

do
{
- constexpr auto tile_size = merge_sort_helper_t::policy_t::ITEMS_PER_TILE;
- const auto num_tiles = ::cuda::ceil_div(num_items, tile_size);
+ auto wrapped_policy = detail::merge_sort::MakeMergeSortPolicyWrapper(policy);
+ const auto tile_size = vsmem_helper.ItemsPerTile(wrapped_policy.MergeSort());
+ const auto num_tiles = ::cuda::ceil_div(num_items, tile_size);

const auto merge_partitions_size = static_cast<std::size_t>(1 + num_tiles) * sizeof(OffsetT);
const auto temporary_keys_storage_size = static_cast<std::size_t>(num_items * sizeof(KeyT));
@@ -174,8 +235,8 @@
* Merge sort supports large types, which can lead to excessive shared memory size requirements. In these cases,
* merge sort allocates virtual shared memory that resides in global memory.
*/
- const std::size_t block_sort_smem_size = num_tiles * BlockSortVSmemHelperT::vsmem_per_block;
- const std::size_t merge_smem_size = num_tiles * MergeAgentVSmemHelperT::vsmem_per_block;
+ const std::size_t block_sort_smem_size = num_tiles * vsmem_helper.block_sort_vsmem_per_block();
+ const std::size_t merge_smem_size = num_tiles * vsmem_helper.merge_vsmem_per_block();
const std::size_t virtual_shared_memory_size = (::cuda::std::max)(block_sort_smem_size, merge_smem_size);
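// Illustrative sizing, with assumed numbers: for num_items = 1,000,000 int
// keys, OffsetT = long long (8 bytes), and tile_size = 2048:
//   num_tiles                   = ceil_div(1,000,000, 2048)  = 489
//   merge_partitions_size       = (1 + 489) * 8              = 3,920 bytes
//   temporary_keys_storage_size = 1,000,000 * 4              = 4,000,000 bytes
// virtual_shared_memory_size is 0 whenever both agents fit in real shared
// memory, and num_tiles * vsmem_per_block otherwise.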

void* allocations[4] = {nullptr, nullptr, nullptr, nullptr};
@@ -214,29 +275,19 @@
auto items_buffer = static_cast<ValueT*>(allocations[2]);

// Invoke DeviceMergeSortBlockSortKernel
- THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(
- static_cast<int>(num_tiles), merge_sort_helper_t::policy_t::BLOCK_THREADS, 0, stream, true)
- .doit(
- detail::merge_sort::DeviceMergeSortBlockSortKernel<
- typename PolicyHub::MaxPolicy,
- KeyInputIteratorT,
- ValueInputIteratorT,
- KeyIteratorT,
- ValueIteratorT,
- OffsetT,
- CompareOpT,
- KeyT,
- ValueT>,
- ping,
- d_input_keys,
- d_input_items,
- d_output_keys,
- d_output_items,
- num_items,
- keys_buffer,
- items_buffer,
- compare_op,
- cub::detail::vsmem_t{allocations[3]});
+ launcher_factory(
+ static_cast<int>(num_tiles), vsmem_helper.BlockThreads(wrapped_policy.MergeSort()), 0, stream, true)
+ .doit(kernel_source.MergeSortBlockSortKernel(),
+ ping,
+ d_input_keys,
+ d_input_items,
+ d_output_keys,
+ d_output_items,
+ num_items,
+ keys_buffer,
+ items_buffer,
+ compare_op,
+ cub::detail::vsmem_t{allocations[3]});

error = CubDebug(detail::DebugSyncStream(stream));
if (cudaSuccess != error)
@@ -273,9 +324,8 @@
const OffsetT target_merged_tiles_number = OffsetT(2) << pass;
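// Pass structure (summary): the block sort leaves sorted runs of tile_size
// items; each merge pass doubles the sorted-run width, so pass p merges
// neighboring runs into runs of target_merged_tiles_number = 2 << p tiles,
// and roughly ceil(log2(num_tiles)) passes are needed. `ping` tracks which
// of the two buffers holds the live data so the final pass lands in the
// user-provided output.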

// Partition
- THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(
- partition_grid_size, threads_per_partition_block, 0, stream, true)
- .doit(detail::merge_sort::DeviceMergeSortPartitionKernel<KeyIteratorT, OffsetT, CompareOpT, KeyT>,
+ launcher_factory(partition_grid_size, threads_per_partition_block, 0, stream, true)
+ .doit(kernel_source.MergeSortPartitionKernel(),
ping,
d_output_keys,
keys_buffer,
@@ -300,29 +350,19 @@
}

// Merge
- THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(
- static_cast<int>(num_tiles), static_cast<int>(merge_sort_helper_t::policy_t::BLOCK_THREADS), 0, stream, true)
- .doit(
- detail::merge_sort::DeviceMergeSortMergeKernel<
- typename PolicyHub::MaxPolicy,
- KeyInputIteratorT,
- ValueInputIteratorT,
- KeyIteratorT,
- ValueIteratorT,
- OffsetT,
- CompareOpT,
- KeyT,
- ValueT>,
- ping,
- d_output_keys,
- d_output_items,
- num_items,
- keys_buffer,
- items_buffer,
- compare_op,
- merge_partitions,
- target_merged_tiles_number,
- cub::detail::vsmem_t{allocations[3]});
+ launcher_factory(
+ static_cast<int>(num_tiles), vsmem_helper.BlockThreads(wrapped_policy.MergeSort()), 0, stream, true)
+ .doit(kernel_source.MergeSortMergeKernel(),
+ ping,
+ d_output_keys,
+ d_output_items,
+ num_items,
+ keys_buffer,
+ items_buffer,
+ compare_op,
+ merge_partitions,
+ target_merged_tiles_number,
+ cub::detail::vsmem_t{allocations[3]});

error = CubDebug(detail::DebugSyncStream(stream));
if (cudaSuccess != error)
@@ -342,6 +382,7 @@
return error;
}

+ template <typename MaxPolicyT = typename PolicyHub::MaxPolicy>
CUB_RUNTIME_FUNCTION _CCCL_FORCEINLINE static cudaError_t Dispatch(
void* d_temp_storage,
std::size_t& temp_storage_bytes,
@@ -351,7 +392,11 @@
ValueIteratorT d_output_items,
OffsetT num_items,
CompareOpT compare_op,
- cudaStream_t stream)
+ cudaStream_t stream,
+ KernelSource kernel_source = {},
+ KernelLauncherFactory launcher_factory = {},
+ MaxPolicyT max_policy = {},
+ VSMemHelperPolicyT vsmem_helper = {})
{
cudaError error = cudaSuccess;
do
@@ -375,10 +420,13 @@
num_items,
compare_op,
stream,
- ptx_version);
+ ptx_version,
+ kernel_source,
+ launcher_factory,
+ vsmem_helper);

// Dispatch to chained policy
- error = CubDebug(PolicyHub::MaxPolicy::Invoke(ptx_version, dispatch));
+ error = CubDebug(max_policy.Invoke(ptx_version, dispatch));
if (cudaSuccess != error)
{
break;
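Because every new parameter is defaulted, existing call sites compile unchanged, while a caller that wants dynamic dispatch names replacement types in the template argument list and passes instances at the end. A sketch with concrete iterator types; `MyKernelSource` and `MyLauncherFactory` are hypothetical stand-ins for user-provided implementations of the interfaces shown above:

#include <cub/device/dispatch/dispatch_merge_sort.cuh>

struct LessOp
{
  __host__ __device__ bool operator()(int a, int b) const
  {
    return a < b;
  }
};

// Default path: identical behavior to before this commit.
using dispatch_t = cub::DispatchMergeSort<const int*, const float*, int*, float*, int, LessOp>;
// dispatch_t::Dispatch(d_temp_storage, temp_storage_bytes, d_keys_in, d_vals_in,
//                      d_keys_out, d_vals_out, num_items, LessOp{}, stream);

// Dynamic path: substitute the kernel source and launcher factory while
// keeping the default policy hub.
using dynamic_dispatch_t = cub::DispatchMergeSort<
  const int*, const float*, int*, float*, int, LessOp,
  cub::detail::merge_sort::policy_hub<int*>,
  MyKernelSource,
  MyLauncherFactory>;
// dynamic_dispatch_t::Dispatch(d_temp_storage, temp_storage_bytes, d_keys_in, d_vals_in,
//                              d_keys_out, d_vals_out, num_items, LessOp{}, stream,
//                              MyKernelSource{}, MyLauncherFactory{});

A replacement `KernelSource` must expose the three kernel getters used in `Invoke`, and the vsmem helper can likewise be swapped through the trailing `VSMemHelperPolicyT` parameter.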
22 changes: 22 additions & 0 deletions cub/cub/device/dispatch/kernels/merge_sort.cuh
@@ -116,6 +116,28 @@ public:
using block_sort_agent_t =
::cuda::std::_If<uses_fallback_policy, fallback_block_sort_agent_t, default_block_sort_agent_t>;
using merge_agent_t = ::cuda::std::_If<uses_fallback_policy, fallback_merge_agent_t, default_merge_agent_t>;

_CCCL_HOST_DEVICE static constexpr ::cuda::std::size_t block_sort_vsmem_per_block()
{
return detail::vsmem_helper_impl<block_sort_agent_t>::vsmem_per_block;
}

_CCCL_HOST_DEVICE static constexpr ::cuda::std::size_t merge_vsmem_per_block()
{
return detail::vsmem_helper_impl<merge_agent_t>::vsmem_per_block;
}

template <typename PolicyT>
_CCCL_HOST_DEVICE static constexpr int BlockThreads(PolicyT /*policy*/)
{
return policy_t::BLOCK_THREADS;
}

template <typename PolicyT>
_CCCL_HOST_DEVICE static constexpr int ItemsPerTile(PolicyT /*policy*/)
{
return policy_t::ITEMS_PER_TILE;
}
};
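// Why BlockThreads/ItemsPerTile take a (wrapped) policy argument: for the
// static policies above the parameter is unused and the constants come from
// policy_t, but keeping the accessors templated presumably lets a runtime
// policy (e.g. one supplied through MakeMergeSortPolicyWrapper in RTC
// scenarios) answer with values it carries instead. Illustrative call site
// from the dispatch layer:
//
//   auto wrapped_policy      = detail::merge_sort::MakeMergeSortPolicyWrapper(policy);
//   const auto tile_size     = vsmem_helper.ItemsPerTile(wrapped_policy.MergeSort());
//   const auto block_threads = vsmem_helper.BlockThreads(wrapped_policy.MergeSort());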
template <typename ChainedPolicyT,
typename KeyInputIteratorT,