From 612f47572125ced9f134be764b3f722f47c1e254 Mon Sep 17 00:00:00 2001 From: Bernhard Manfred Gruber Date: Tue, 4 Feb 2025 19:02:07 +0100 Subject: [PATCH] Add b200 tunings for radix_sort.keys (#3611) (#3655) --- .../dispatch/tuning/tuning_radix_sort.cuh | 127 +++++++++++++++++- 1 file changed, 125 insertions(+), 2 deletions(-) diff --git a/cub/cub/device/dispatch/tuning/tuning_radix_sort.cuh b/cub/cub/device/dispatch/tuning/tuning_radix_sort.cuh index 6080622274d..debf41db4a1 100644 --- a/cub/cub/device/dispatch/tuning/tuning_radix_sort.cuh +++ b/cub/cub/device/dispatch/tuning/tuning_radix_sort.cuh @@ -50,7 +50,7 @@ namespace detail { namespace radix { -// default +// sm90 default template struct sm90_small_key_tuning { @@ -1069,7 +1069,130 @@ struct policy_hub SEGMENTED_RADIX_BITS - 1>; }; - using MaxPolicy = Policy900; + // todo(@gonidelis): refactor this so as not to duplicate SM90. + struct Policy1000 : ChainedPolicy<1000, Policy1000, Policy900> + { + static constexpr bool ONESWEEP = true; + static constexpr int ONESWEEP_RADIX_BITS = 8; + + using HistogramPolicy = AgentRadixSortHistogramPolicy<128, 16, 1, KeyT, ONESWEEP_RADIX_BITS>; + using ExclusiveSumPolicy = AgentRadixSortExclusiveSumPolicy<256, ONESWEEP_RADIX_BITS>; + + private: + static constexpr int PRIMARY_RADIX_BITS = (sizeof(KeyT) > 1) ? 7 : 5; + static constexpr int SINGLE_TILE_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5; + static constexpr int SEGMENTED_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5; + static constexpr int OFFSET_64BIT = sizeof(OffsetT) == 8 ? 1 : 0; + static constexpr int FLOAT_KEYS = ::cuda::std::is_same::value ? 1 : 0; + + using OnesweepPolicyKey32 = AgentRadixSortOnesweepPolicy< + 384, + KEYS_ONLY ? 20 - OFFSET_64BIT - FLOAT_KEYS + : (sizeof(ValueT) < 8 ? (OFFSET_64BIT ? 17 : 23) : (OFFSET_64BIT ? 
29 : 30)), + DominantT, + 1, + RADIX_RANK_MATCH_EARLY_COUNTS_ANY, + BLOCK_SCAN_RAKING_MEMOIZE, + RADIX_SORT_STORE_DIRECT, + ONESWEEP_RADIX_BITS>; + + using OnesweepPolicyKey64 = AgentRadixSortOnesweepPolicy< + 384, + sizeof(ValueT) < 8 ? 30 : 24, + DominantT, + 1, + RADIX_RANK_MATCH_EARLY_COUNTS_ANY, + BLOCK_SCAN_RAKING_MEMOIZE, + RADIX_SORT_STORE_DIRECT, + ONESWEEP_RADIX_BITS>; + + using OnesweepLargeKeyPolicy = ::cuda::std::_If; + + using OnesweepSmallKeyPolicySizes = + sm100_small_key_tuning; + + using OnesweepSmallKeyPolicy = AgentRadixSortOnesweepPolicy< + OnesweepSmallKeyPolicySizes::threads, + OnesweepSmallKeyPolicySizes::items, + DominantT, + 1, + RADIX_RANK_MATCH_EARLY_COUNTS_ANY, + BLOCK_SCAN_RAKING_MEMOIZE, + RADIX_SORT_STORE_DIRECT, + 8>; + + public: + using OnesweepPolicy = ::cuda::std::_If; + + // The Scan, Downsweep and Upsweep policies are never run on SM100, but we have to include them to prevent a + // compilation error: When we compile e.g. for SM70 **and** SM100, the host compiler will reach calls to those + // kernels, and instantiate them for MaxPolicy (which is Policy1000) on the host, which will reach into the policies + // below to set the launch bounds. The device compiler pass will also compile all kernels for SM70 **and** SM100, + // even though only the Onesweep kernel is used on SM100. + using ScanPolicy = + AgentScanPolicy<512, + 23, + OffsetT, + BLOCK_LOAD_WARP_TRANSPOSE, + LOAD_DEFAULT, + BLOCK_STORE_WARP_TRANSPOSE, + BLOCK_SCAN_RAKING_MEMOIZE>; + + using DownsweepPolicy = AgentRadixSortDownsweepPolicy< + 512, + 23, + DominantT, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MATCH, + BLOCK_SCAN_WARP_SCANS, + PRIMARY_RADIX_BITS>; + + using AltDownsweepPolicy = AgentRadixSortDownsweepPolicy< + (sizeof(KeyT) > 1) ? 
256 : 128, + 47, + DominantT, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS, + PRIMARY_RADIX_BITS - 1>; + + using UpsweepPolicy = AgentRadixSortUpsweepPolicy<256, 23, DominantT, LOAD_DEFAULT, PRIMARY_RADIX_BITS>; + using AltUpsweepPolicy = AgentRadixSortUpsweepPolicy<256, 47, DominantT, LOAD_DEFAULT, PRIMARY_RADIX_BITS - 1>; + + using SingleTilePolicy = AgentRadixSortDownsweepPolicy< + 256, + 19, + DominantT, + BLOCK_LOAD_DIRECT, + LOAD_LDG, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS, + SINGLE_TILE_RADIX_BITS>; + + using SegmentedPolicy = AgentRadixSortDownsweepPolicy< + 192, + 39, + DominantT, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS, + SEGMENTED_RADIX_BITS>; + + using AltSegmentedPolicy = AgentRadixSortDownsweepPolicy< + 384, + 11, + DominantT, + BLOCK_LOAD_TRANSPOSE, + LOAD_DEFAULT, + RADIX_RANK_MEMOIZE, + BLOCK_SCAN_WARP_SCANS, + SEGMENTED_RADIX_BITS - 1>; + }; + + using MaxPolicy = Policy1000; }; } // namespace radix