Enable custom all-reduce for ROCm. #34

Closed
wants to merge 2 commits
5 changes: 5 additions & 0 deletions CMakeLists.txt
@@ -187,6 +187,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/custom_all_reduce.cu")
endif()

if(VLLM_GPU_LANG STREQUAL "HIP")
list(APPEND VLLM_EXT_SRC
"csrc/custom_all_reduce.cu")
endif()

define_gpu_extension_target(
_C
DESTINATION vllm
9 changes: 8 additions & 1 deletion csrc/custom_all_reduce.cu
@@ -84,13 +84,20 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
out.numel());
break;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) && !defined(USE_ROCM)
case at::ScalarType::BFloat16: {
fa->allreduce<nv_bfloat16>(
stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()),
reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel());
break;
}
#elif defined(USE_ROCM)
case at::ScalarType::BFloat16: {
fa->allreduce<__hip_bfloat16>(
stream, reinterpret_cast<__hip_bfloat16 *>(inp.data_ptr()),
reinterpret_cast<__hip_bfloat16 *>(out.data_ptr()), out.numel());
break;
}
#endif
default:
throw std::runtime_error(
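
Note on the BFloat16 dispatch above (reviewer sketch, not part of the diff): the two case blocks differ only in the vendor bf16 type, which comes from cuda_bf16.h on CUDA and from the HIP bf16 header on ROCm. A per-platform type alias would keep a single case label; the alias name bf16_t below is invented for illustration, and the commented call simply mirrors the fa->allreduce call shown in the hunk.

#ifndef USE_ROCM
#include <cuda_bf16.h>
using bf16_t = nv_bfloat16;     // CUDA bfloat16 type
#else
#include <hip/amd_detail/amd_hip_bf16.h>
using bf16_t = __hip_bfloat16;  // HIP bfloat16 type, same header the PR includes
#endif

// Single case body, assuming the same fa/stream/inp/out variables as above:
//   case at::ScalarType::BFloat16:
//     fa->allreduce<bf16_t>(
//         stream, reinterpret_cast<bf16_t *>(inp.data_ptr()),
//         reinterpret_cast<bf16_t *>(out.data_ptr()), out.numel());
//     break;
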
84 changes: 82 additions & 2 deletions csrc/custom_all_reduce.cuh
@@ -1,7 +1,11 @@
#pragma once

#include <cuda.h>
#ifndef USE_ROCM
#include <cuda_bf16.h>
#else
#include <hip/amd_detail/amd_hip_bf16.h>
#endif
#include <cuda_fp16.h>
#include <cuda_runtime.h>

@@ -31,9 +35,17 @@ struct Signal {
alignas(128) uint32_t end[kMaxBlocks][8];
};

#ifndef USE_ROCM
struct __align__(16) RankData { const void *__restrict__ ptrs[8]; };
#else
struct __align__(16) RankData { const void * ptrs[8]; };
#endif

#ifndef USE_ROCM
struct __align__(16) RankSignals { volatile Signal *signals[8]; };
#else
struct __align__(16) RankSignals { Signal *signals[8]; };
#endif

// like std::array, but aligned
template <typename T, int sz>
@@ -74,6 +86,7 @@ DINLINE half &assign_add(half &a, half b) {
}
DINLINE float &assign_add(float &a, float b) { return a += b; }

#ifndef USE_ROCM
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
template <>
@@ -85,6 +98,17 @@ DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) {
return a;
}
#endif
#else
DINLINE float upcast_s(__hip_bfloat16 val) { return __bfloat162float(val); }
template <>
DINLINE __hip_bfloat16 downcast_s(float val) {
return __float2bfloat16(val);
}
DINLINE __hip_bfloat16 &assign_add(__hip_bfloat16 &a, __hip_bfloat16 b) {
a = __hadd(a, b);
return a;
}
#endif

template <typename T, int N>
DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) {
@@ -128,16 +152,30 @@ DINLINE O downcast(array_t<float, O::size> val) {
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
template <int ngpus>
#ifndef USE_ROCM
DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
int rank) {
#else
DINLINE void start_sync(const RankSignals &sg, Signal *self_sg, int rank) {
#endif
if (threadIdx.x < ngpus) {
// reset flag for next time
#ifndef USE_ROCM
self_sg->end[blockIdx.x][threadIdx.x] = 0;
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
// wait until we got true from all ranks
while (!self_sg->start[blockIdx.x][threadIdx.x])
#else
__atomic_store_n(&self_sg->end[blockIdx.x][threadIdx.x], 0, __ATOMIC_RELAXED);
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], 1, __ATOMIC_RELAXED);
__atomic_thread_fence(__ATOMIC_ACQ_REL);
// wait until we got true from all ranks
while (!__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED))
#endif
;
}
__syncthreads();
@@ -147,13 +185,18 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses.
template <int ngpus, bool final_sync = false>
#ifndef USE_ROCM
DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
int rank) {
#else
DINLINE void end_sync(const RankSignals &sg, Signal *self_sg, int rank) {
#endif
__syncthreads();
// eliminate the case that prior writes are not visible after signals become
// visible. Note that I did not manage to make this happen through a lot of
// testing. It might be the case that the hardware provides a stronger
// guarantee than the memory model.
#ifndef USE_ROCM
if constexpr (!final_sync) __threadfence_system();
if (threadIdx.x < ngpus) {
// reset flag for next time
@@ -164,6 +207,18 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
// wait until we got true from all ranks
while (!self_sg->end[blockIdx.x][threadIdx.x])
;
#else
if (threadIdx.x < ngpus) {
// reset flag for next time
__atomic_store_n(&self_sg->start[blockIdx.x][threadIdx.x], 0, __ATOMIC_RELAXED);
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], 1, __ATOMIC_RELAXED);
__atomic_thread_fence(__ATOMIC_ACQ_REL);
// wait until we got true from all ranks
while (!__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED))
;
#endif
}
if constexpr (!final_sync) __syncthreads();
}
@@ -179,10 +234,16 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) {
}

template <typename T, int ngpus>
#ifndef USE_ROCM
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
volatile Signal *self_sg, T *__restrict__ result,
int rank, int size) {
#else
__global__ void __launch_bounds__(1024, 1)
cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
Signal *self_sg, T *__restrict__ result, int rank, int size) {
#endif
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
// note: we don't reorder the address so the accumulation order is the same
@@ -199,15 +260,26 @@ __global__ void __launch_bounds__(512, 1)
}

template <typename P>
#ifndef USE_ROCM
DINLINE P *get_tmp_buf(volatile Signal *sg) {
#else
DINLINE P *get_tmp_buf(Signal *sg) {
#endif
return (P *)(((Signal *)sg) + 1);
}

template <typename T, int ngpus>
#ifndef USE_ROCM
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
volatile Signal *self_sg, T *__restrict__ result,
int rank, int size) {
#else
__global__ void __launch_bounds__(1024, 1)
cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
Signal *self_sg, T *__restrict__ result,
int rank, int size) {
#endif
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
using P = typename packed_t<T>::P;
@@ -327,8 +399,12 @@ class CustomAllreduce {
// note: must share the base address of each allocation, or we get wrong
// address
if (cuPointerGetAttribute(&base_ptr,
#ifndef USE_ROCM
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#else
HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#endif
(CUdeviceptr)ptr) != CUDA_SUCCESS)
throw std::runtime_error("failed to get pointer attr");
CUDACHECK(cudaIpcGetMemHandle(
(cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr));
@@ -406,7 +482,11 @@ class CustomAllreduce {
*/
template <typename T>
void allreduce(cudaStream_t stream, T *input, T *output, int size,
#ifndef USE_ROCM
int threads = 512, int block_limit = 36) {
#else
int threads = 1024, int block_limit = 36) {
#endif
auto d = packed_t<T>::P::size;
if (size % d != 0)
throw std::runtime_error(
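
Note on the synchronization changes above (reviewer sketch, not part of the diff): on the ROCm path the volatile reads and writes of the Signal flags are replaced by GCC __atomic_* builtins — relaxed stores to reset and raise the flags, an acquire/release fence to publish them, and a relaxed-load spin to wait for the peers. The host-side C++ program below imitates that handshake with std::atomic and two threads standing in for two ranks; the names kRanks, start_flag, and end_flag are invented for illustration, and this is not the kernel code.

#include <atomic>
#include <cstdio>
#include <thread>

constexpr int kRanks = 2;
// flag[owner][writer]: each "rank" spins on its own row and writes into its column.
std::atomic<int> start_flag[kRanks][kRanks];
std::atomic<int> end_flag[kRanks][kRanks];

void start_sync(int rank) {
  for (int peer = 0; peer < kRanks; ++peer) {
    // reset own end flag for the next round (like __atomic_store_n(..., 0, RELAXED))
    end_flag[rank][peer].store(0, std::memory_order_relaxed);
    // raise this rank's start flag on every peer (like the relaxed p2p store)
    start_flag[peer][rank].store(1, std::memory_order_relaxed);
  }
  // publish the stores, like __atomic_thread_fence(__ATOMIC_ACQ_REL)
  std::atomic_thread_fence(std::memory_order_acq_rel);
  // spin until every peer's start flag is visible (relaxed load, like the kernel)
  for (int peer = 0; peer < kRanks; ++peer)
    while (!start_flag[rank][peer].load(std::memory_order_relaxed)) {
    }
}

int main() {
  std::thread t0(start_sync, 0), t1(start_sync, 1);
  t0.join();
  t1.join();
  std::puts("both ranks passed start_sync");
  return 0;
}
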
30 changes: 30 additions & 0 deletions csrc/custom_all_reduce_test.cu
@@ -4,6 +4,10 @@
* export MPI_HOME=XXX
* nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
* custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi
* to hipify and compile
* export MPI_HOME=XXX
* hipify-perl custom_all_reduce_test.cu > custom_all_reduce_test.hip
* hipcc -O2 -std=c++17 custom_all_reduce_test.hip -o custom_all_reduce_test -lrccl -I${MPI_HOME}/include -L${MPI_HOME}/lib -lmpi -DUSE_ROCM=1
*
* Warning: this C++ test is not designed to be very readable and was used
* during the rapid prototyping process.
@@ -12,17 +16,29 @@
* mpirun -np 8 ./custom_all_reduce_test
*/
#include <cuda.h>
#ifndef USE_ROCM
#include <curand_kernel.h>
#else
#include <hiprand/hiprand_kernel.h>
#endif
#include <stdio.h>
#include <stdlib.h>

#include <limits>
#include <vector>

#include "cuda_profiler_api.h"
#ifndef USE_ROCM
#include "custom_all_reduce.cuh"
#else
#include "custom_all_reduce_hip.cuh"
#endif
#include "mpi.h"
#ifndef USE_ROCM
#include "nccl.h"
#else
#include <rccl/rccl.h>
#endif

#define MPICHECK(cmd) \
do { \
@@ -44,7 +60,12 @@
} while (0)

__global__ void dummy_kernel() {
#ifndef USE_ROCM
for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms
#else
#pragma unroll
for (int i = 0; i < 100; i++) __builtin_amdgcn_s_sleep(127);
#endif
}

template <typename T>
@@ -164,7 +185,11 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
ncclDataType_t ncclDtype;
if (std::is_same<T, half>::value) {
ncclDtype = ncclFloat16;
#ifndef USE_ROCM
} else if (std::is_same<T, nv_bfloat16>::value) {
#else
} else if (std::is_same<T, __hip_bfloat16>::value) {
#endif
ncclDtype = ncclBfloat16;
} else {
ncclDtype = ncclFloat;
@@ -308,9 +333,14 @@ int main(int argc, char **argv) {
// }
// }
for (int sz = 512; sz <= (8 << 20); sz *= 2) {
#ifndef USE_ROCM
run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test);
#else
run<half>(myRank, nRanks, comm, 1024, 16, sz + 8 * 47, performance_test);
#endif
}

cudaProfilerStop();
MPICHECK(MPI_Finalize());
return EXIT_SUCCESS;
}
4 changes: 1 addition & 3 deletions csrc/ops.h
@@ -139,7 +139,6 @@ void moe_align_block_size(
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);

#ifndef USE_ROCM
using fptr_t = uint64_t;
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
const std::vector<std::string> &handles,
@@ -158,7 +157,6 @@ void register_buffer(fptr_t _fa, torch::Tensor &t,
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
const std::vector<std::vector<int64_t>> &offsets);
#endif

void convert_fp8(
torch::Tensor& src_cache,
@@ -180,4 +178,4 @@ torch::Tensor fp8_gemm_16(
torch::Tensor& scaleA,
torch::Tensor& scaleB,
int algo_idx
);
2 changes: 0 additions & 2 deletions csrc/pybind.cpp
@@ -112,7 +112,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
&get_max_shared_memory_per_block_device_attribute,
"Gets the maximum shared memory per block device attribute.");

#ifndef USE_ROCM
// Custom all-reduce kernels
pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce");
custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar");
@@ -126,6 +125,5 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"get_graph_buffer_ipc_meta");
custom_ar.def("register_graph_buffers", &register_graph_buffers,
"register_graph_buffers");
#endif

}
5 changes: 2 additions & 3 deletions vllm/config.py
@@ -526,10 +526,9 @@ def _verify_args(self) -> None:
"Pipeline parallelism is not supported yet.")
if not self.disable_custom_all_reduce and self.world_size > 1:
if is_hip():
self.disable_custom_all_reduce = False
logger.info(
"Enable the custom all-reduce kernel on AMD GPUs.")
elif self.pipeline_parallel_size > 1:
self.disable_custom_all_reduce = True
logger.info(
8 changes: 5 additions & 3 deletions vllm/model_executor/parallel_utils/custom_all_reduce.py
@@ -9,12 +9,14 @@
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)

try:
from vllm._C import custom_ar
except ImportError:
# For AMD GPUs
custom_ar = None

try:
import pynvml
except ImportError:
# For AMD GPUs
pynvml = None

logger = init_logger(__name__)