diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh index 1ed49b8aa9cae..632b579c55afa 100644 --- a/csrc/custom_all_reduce.cuh +++ b/csrc/custom_all_reduce.cuh @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -23,17 +24,23 @@ namespace vllm { -constexpr int kMaxBlocks = 64; -// note: we don't want to use atomics for signals because peer atomics are no -// supported on PCIe links +constexpr int kMaxBlocks = 36; +// Counter may overflow, but it's fine since unsigned int overflow is +// well-defined behavior. +using FlagType = uint32_t; struct Signal { - alignas(128) uint32_t start[kMaxBlocks][8]; - alignas(128) uint32_t end[kMaxBlocks][8]; + alignas(128) FlagType self_counter[kMaxBlocks][8]; + // Two sets of peer counters are needed for two syncs. The reason is that + // it's possible for peer GPU block to arrive at the second sync point while + // the current GPU block haven't passed the first sync point. Thus, peer GPU + // may write counter+1 while current GPU is busy waiting for counter. We use + // alternating counter array to avoid this possibility. + alignas(128) FlagType peer_counter[2][kMaxBlocks][8]; }; struct __align__(16) RankData { const void* __restrict__ ptrs[8]; }; -struct __align__(16) RankSignals { volatile Signal* signals[8]; }; +struct __align__(16) RankSignals { Signal* signals[8]; }; // like std::array, but aligned template @@ -123,47 +130,60 @@ DINLINE O downcast(array_t val) { } } -// This function is meant to be used as the first synchronization in the all -// reduce kernel. Thus, it doesn't need to make any visibility guarantees for -// prior memory accesses. Note: volatile writes will not be reordered against -// other volatile writes. -template -DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg, - int rank) { - if (threadIdx.x < ngpus) { - // reset flag for next time - self_sg->end[blockIdx.x][threadIdx.x] = 0; - // simultaneously write to the corresponding flag of all ranks. - // Latency = 1 p2p write - sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; - // wait until we got true from all ranks - while (!self_sg->start[blockIdx.x][threadIdx.x]); - } - __syncthreads(); +static DINLINE void st_flag_release(FlagType* flag_addr, FlagType flag) { + asm volatile("st.release.sys.global.u32 [%1], %0;" ::"r"(flag), + "l"(flag_addr)); +} + +static DINLINE FlagType ld_flag_acquire(FlagType* flag_addr) { + FlagType flag; + asm volatile("ld.acquire.sys.global.u32 %0, [%1];" + : "=r"(flag) + : "l"(flag_addr)); + return flag; +} + +static DINLINE void st_flag_volatile(FlagType* flag_addr, FlagType flag) { + asm volatile("st.volatile.global.u32 [%1], %0;" ::"r"(flag), "l"(flag_addr)); +} + +static DINLINE FlagType ld_flag_volatile(FlagType* flag_addr) { + FlagType flag; + asm volatile("ld.volatile.global.u32 %0, [%1];" + : "=r"(flag) + : "l"(flag_addr)); + return flag; } -// This function is meant to be used as the second or the final synchronization -// barrier in the all reduce kernel. If it's the final synchronization barrier, -// we don't need to make any visibility guarantees for prior memory accesses. -template -DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg, - int rank) { - __syncthreads(); - // eliminate the case that prior writes are not visible after signals become - // visible. Note that I did not managed to make this happen through a lot of - // testing. Might be the case that hardware provides stronger guarantee than - // the memory model. - if constexpr (!final_sync) __threadfence_system(); +// is_start: whether this is the very first synchronization barrier. +// need_fence: whether a memory fence is needed. If true, a release-acquire +// semantic is used to enforce memory access order before and after this +// barrier. +template +DINLINE void multi_gpu_barrier(const RankSignals& sg, Signal* self_sg, + int rank) { + if constexpr (!is_start) __syncthreads(); + static_assert( + !(is_start && need_fence)); // Start barrier shouldn't need fence. if (threadIdx.x < ngpus) { - // reset flag for next time - self_sg->start[blockIdx.x][threadIdx.x] = 0; - // simultaneously write to the corresponding flag of all ranks. - // Latency = 1 p2p write - sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; - // wait until we got true from all ranks - while (!self_sg->end[blockIdx.x][threadIdx.x]); + // Increment the counter. Technically we only need one counter, but we use + // multiple per block to eliminate the need to share the counter via smem. + auto val = self_sg->self_counter[blockIdx.x][threadIdx.x] += 1; + // Write the expected counter value to peer and wait for correct value from + // peer. + auto peer_counter_ptr = + &sg.signals[threadIdx.x]->peer_counter[val % 2][blockIdx.x][rank]; + auto self_counter_ptr = + &self_sg->peer_counter[val % 2][blockIdx.x][threadIdx.x]; + if constexpr (need_fence) { + st_flag_release(peer_counter_ptr, val); + while (ld_flag_acquire(self_counter_ptr) != val); + } else { + st_flag_volatile(peer_counter_ptr, val); + while (ld_flag_volatile(self_counter_ptr) != val); + } } - if constexpr (!final_sync) __syncthreads(); + if constexpr (is_start || need_fence) __syncthreads(); } template @@ -178,33 +198,31 @@ DINLINE P packed_reduce(const P* ptrs[], int idx) { template __global__ void __launch_bounds__(512, 1) - cross_device_reduce_1stage(RankData* _dp, RankSignals sg, - volatile Signal* self_sg, T* __restrict__ result, - int rank, int size) { + cross_device_reduce_1stage(RankData* _dp, RankSignals sg, Signal* self_sg, + T* __restrict__ result, int rank, int size) { using P = typename packed_t::P; using A = typename packed_t::A; // note: we don't reorder the address so the accumulation order is the same // for all ranks, ensuring bitwise identical results auto dp = *_dp; - start_sync(sg, self_sg, rank); + multi_gpu_barrier(sg, self_sg, rank); // do the actual reduction for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += gridDim.x * blockDim.x) { ((P*)result)[idx] = packed_reduce((const P**)&dp.ptrs[0], idx); } - end_sync(sg, self_sg, rank); + multi_gpu_barrier(sg, self_sg, rank); } template -DINLINE P* get_tmp_buf(volatile Signal* sg) { +DINLINE P* get_tmp_buf(Signal* sg) { return (P*)(((Signal*)sg) + 1); } template __global__ void __launch_bounds__(512, 1) - cross_device_reduce_2stage(RankData* _dp, RankSignals sg, - volatile Signal* self_sg, T* __restrict__ result, - int rank, int size) { + cross_device_reduce_2stage(RankData* _dp, RankSignals sg, Signal* self_sg, + T* __restrict__ result, int rank, int size) { int tid = blockIdx.x * blockDim.x + threadIdx.x; int stride = gridDim.x * blockDim.x; using P = typename packed_t::P; @@ -222,12 +240,12 @@ __global__ void __launch_bounds__(512, 1) tmps[i] = get_tmp_buf

(sg.signals[target]); } auto tmp_out = tmps[0]; - start_sync(sg, self_sg, rank); + multi_gpu_barrier(sg, self_sg, rank); // stage 1: reduce scatter for (int idx = start + tid; idx < end; idx += stride) { tmp_out[idx - start] = packed_reduce(ptrs, idx); } - end_sync(sg, self_sg, rank); + multi_gpu_barrier(sg, self_sg, rank); // stage 2: allgather. Note: it's important to match the tid between // the two stages, because visibility across devices is only guaranteed @@ -437,6 +455,8 @@ class CustomAllreduce { #define KL(ngpus, name) \ name<<>>(ptrs, sg_, self_sg_, output, \ rank_, size); + // TODO(hanzhi713): Threshold is different for A100 and H100. + // Add per device threshold. #define REDUCE_CASE(ngpus) \ case ngpus: { \ if (world_size_ == 2) { \ diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu index f7868233076cd..c8b5d0a013f63 100644 --- a/csrc/custom_all_reduce_test.cu +++ b/csrc/custom_all_reduce_test.cu @@ -1,15 +1,15 @@ /** * This is a standalone test for custom allreduce. * To compile, make sure you have MPI and NCCL installed in your system. - * export MPI_HOME=XXX + * export MPI_HOME=xxx * nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o - * custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi + * custom_all_reduce_test -lnccl -I${MPI_HOME} -lmpi * * Warning: this C++ test is not designed to be very readable and was used * during the rapid prototyping process. * * To run: - * mpirun -np 8 ./custom_all_reduce_test + * mpirun --allow-run-as-root -np 8 ./custom_all_reduce_test */ #include #include @@ -302,15 +302,19 @@ int main(int argc, char** argv) { bool performance_test = true; cudaProfilerStart(); - // for (int threads : {256, 512}) { + // Uncomment to scan through different block size configs. + // for (int threads : {256, 512, 1024}) { // for (int block_limit = 16; block_limit < 112; block_limit += 4) { - // run(myRank, nRanks, comm, threads, block_limit, 4096 * 1024); + // run(myRank, nRanks, comm, threads, block_limit, 1024 * 1024, + // performance_test); // } // } + // Scan through different sizes to test performance. for (int sz = 512; sz <= (8 << 20); sz *= 2) { run(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test); } cudaProfilerStop(); + MPICHECK(MPI_Finalize()); return EXIT_SUCCESS; }