diff --git a/.gitignore b/.gitignore index b5195629e5cf3..b1513ef0ddb0c 100644 --- a/.gitignore +++ b/.gitignore @@ -181,6 +181,7 @@ _build/ # hip files generated by PyTorch *.hip *_hip* +hip_compat.h # Benchmark dataset *.json diff --git a/Dockerfile.rocm b/Dockerfile.rocm index a45265d79a6ac..32ab7cf96fc59 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -3,8 +3,6 @@ ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" FROM $BASE_IMAGE -ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" - RUN echo "Base image is $BASE_IMAGE" # BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" @@ -26,22 +24,12 @@ ARG BUILD_FA="1" # whether to build cupy on rocm ARG BUILD_CUPY="1" -# Install some basic utilities -RUN apt-get update && apt-get install python3 python3-pip -y +# whether to build triton on rocm +ARG BUILD_TRITON="1" # Install some basic utilities RUN apt-get update && apt-get install -y \ - curl \ - ca-certificates \ - sudo \ - git \ - bzip2 \ - libx11-6 \ - build-essential \ - wget \ - unzip \ - nvidia-cuda-toolkit \ - tmux \ + sqlite3 libsqlite3-dev libfmt-dev \ && rm -rf /var/lib/apt/lists/* ### Mount Point ### @@ -95,6 +83,17 @@ RUN if [ "$BUILD_CUPY" = "1" ]; then \ && cd ..; \ fi +# build triton +RUN if [ "$BUILD_TRITON" = "1" ]; then \ + mkdir -p libs \ + && cd libs \ + && pip uninstall -y triton \ + && git clone https://github.com/ROCm/triton.git \ + && cd triton/python \ + && pip3 install . \ + && cd ../..; \ + fi + COPY ./ /app/vllm RUN python3 -m pip install --upgrade pip @@ -104,12 +103,13 @@ RUN cd /app \ && cd vllm \ && pip install -U -r requirements-rocm.txt \ && if [ "$BUILD_FA" = "1" ]; then \ - bash patch_xformers.rocm.sh; fi \ - && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \ + bash patch_xformers.rocm.sh; fi \ + && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \ + patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch; fi \ && python3 setup.py install \ && cd .. RUN python3 -m pip install --upgrade pip -RUN python3 -m pip install --no-cache-dir ray[all] +RUN python3 -m pip install --no-cache-dir ray[all]==2.9.3 CMD ["/bin/bash"] diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 2fdc08c5c26df..4563d4c45def2 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -16,18 +16,17 @@ def main(args: argparse.Namespace): # NOTE(woosuk): If the request cannot be processed in a single batch, # the engine will automatically process the request in multiple batches. 
- llm = LLM( - model=args.model, - tokenizer=args.tokenizer, - quantization=args.quantization, - tensor_parallel_size=args.tensor_parallel_size, - trust_remote_code=args.trust_remote_code, - dtype=args.dtype, - enforce_eager=args.enforce_eager, - kv_cache_dtype=args.kv_cache_dtype, - device=args.device, - ray_workers_use_nsight=args.ray_workers_use_nsight, - ) + llm = LLM(model=args.model, + tokenizer=args.tokenizer, + quantization=args.quantization, + tensor_parallel_size=args.tensor_parallel_size, + trust_remote_code=args.trust_remote_code, + dtype=args.dtype, + enforce_eager=args.enforce_eager, + kv_cache_dtype=args.kv_cache_dtype, + device=args.device, + ray_workers_use_nsight=args.ray_workers_use_nsight, + worker_use_ray=args.worker_use_ray) sampling_params = SamplingParams( n=args.n, @@ -151,5 +150,10 @@ def run_to_completion(profile_dir: Optional[str] = None): action='store_true', help="If specified, use nsight to profile ray workers", ) + parser.add_argument('--worker-use-ray', + action='store_true', + help='use Ray for distributed serving, will be ' + 'automatically set when using more than 1 GPU ' + 'unless on ROCm where the default is torchrun') args = parser.parse_args() main(args) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 72bdc4b3b4540..8cd05c4863a28 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -75,21 +75,25 @@ def run_vllm( device: str, enable_prefix_caching: bool, gpu_memory_utilization: float = 0.9, + worker_use_ray: bool = False, ) -> float: from vllm import LLM, SamplingParams - llm = LLM(model=model, - tokenizer=tokenizer, - quantization=quantization, - tensor_parallel_size=tensor_parallel_size, - seed=seed, - trust_remote_code=trust_remote_code, - dtype=dtype, - max_model_len=max_model_len, - gpu_memory_utilization=gpu_memory_utilization, - enforce_eager=enforce_eager, - kv_cache_dtype=kv_cache_dtype, - device=device, - enable_prefix_caching=enable_prefix_caching) + llm = LLM( + model=model, + tokenizer=tokenizer, + quantization=quantization, + tensor_parallel_size=tensor_parallel_size, + seed=seed, + trust_remote_code=trust_remote_code, + dtype=dtype, + max_model_len=max_model_len, + gpu_memory_utilization=gpu_memory_utilization, + enforce_eager=enforce_eager, + kv_cache_dtype=kv_cache_dtype, + device=device, + enable_prefix_caching=enable_prefix_caching, + worker_use_ray=worker_use_ray, + ) # Add the requests to the engine. 
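Both benchmark scripts gain a --worker-use-ray flag that is forwarded straight into the LLM constructor; on ROCm the multi-GPU default becomes torchrun (see the vllm/config.py hunk further down), so the flag only matters when the Ray path is wanted explicitly. A minimal usage sketch, with a placeholder model name, assuming the engine accepts the keyword exactly as shown above:

    from vllm import LLM

    # Placeholder model; on ROCm a multi-GPU run would be launched via torchrun,
    # while worker_use_ray=True (or --worker-use-ray on the CLI) keeps the Ray path.
    llm = LLM(
        model="facebook/opt-125m",
        tensor_parallel_size=1,
        worker_use_ray=False,
    )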
for prompt, _, output_len in requests: @@ -213,7 +217,8 @@ def main(args: argparse.Namespace): args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, args.trust_remote_code, args.dtype, args.max_model_len, args.enforce_eager, args.kv_cache_dtype, args.device, - args.enable_prefix_caching, args.gpu_memory_utilization) + args.enable_prefix_caching, args.gpu_memory_utilization, + args.worker_use_ray) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, @@ -314,6 +319,11 @@ def main(args: argparse.Namespace): "--enable-prefix-caching", action='store_true', help="enable automatic prefix caching for vLLM backend.") + parser.add_argument('--worker-use-ray', + action='store_true', + help='use Ray for distributed serving, will be ' + 'automatically set when using more than 1 GPU ' + 'unless on ROCm where the default is torchrun') args = parser.parse_args() if args.tokenizer is None: args.tokenizer = args.model diff --git a/csrc/attention/attention_kernels.cu b/csrc/attention/attention_kernels.cu index 5e61668d5cc1a..1b92bf6791d9f 100644 --- a/csrc/attention/attention_kernels.cu +++ b/csrc/attention/attention_kernels.cu @@ -602,7 +602,11 @@ template< typename CACHE_T, int BLOCK_SIZE, bool IS_FP8_E5M2_KV_CACHE, +#ifdef USE_ROCM + int NUM_THREADS = 1024> +#else int NUM_THREADS = 128> +#endif void paged_attention_v1_launcher( torch::Tensor& out, torch::Tensor& query, @@ -779,8 +783,13 @@ template< typename CACHE_T, int BLOCK_SIZE, bool IS_FP8_E5M2_KV_CACHE, +#ifdef USE_ROCM + int NUM_THREADS = 1024, + int PARTITION_SIZE = 1024> +#else int NUM_THREADS = 128, int PARTITION_SIZE = 512> +#endif void paged_attention_v2_launcher( torch::Tensor& out, torch::Tensor& exp_sums, diff --git a/csrc/custom/custom.cpp b/csrc/custom/custom.cpp new file mode 100644 index 0000000000000..aeff9cc5e6ae7 --- /dev/null +++ b/csrc/custom/custom.cpp @@ -0,0 +1,74 @@ +#include +#include +#include +#include + +namespace py = pybind11; + +// declare templates for front (cpp) and back (cuda) sides of function: +//template + +void LLGemm_Silu(void *in_a, void *in_b, void *out_c, const int M, const int K, cudaStream_t stream, const int rows_per_block); +void LLMM_Silu(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, const int rows_per_block) { + int M = in_a.size(0); + int K = in_a.size(1); + LLGemm_Silu(in_a.data_ptr(), in_b.data_ptr(), + out_c.data_ptr(), M, K, at::cuda::getCurrentCUDAStream(),rows_per_block); +} + +void LLGemm1(void *in_a, void *in_b, void *out_c, const int M, const int K, cudaStream_t stream,const int rows_per_block); + +//template +void LLMM1(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, const int rows_per_block=4) { + int M = in_a.size(0); + int K = in_a.size(1); + //if (N != in_b.numel()) + // throw std::invalid_argument("Size mismatch A.numel(): " + std::to_string(in_a.numel()) + // + ", B.numel(): " + std::to_string(in_b.numel())); + + //out_c.resize_({N}); + + // call the kernel function... 
+ LLGemm1(in_a.data_ptr(), in_b.data_ptr(), + out_c.data_ptr(), M, K, at::cuda::getCurrentCUDAStream(),rows_per_block); +} + +void LLGemmZZ(void *in_a, void *in_b, void *out_c, const int M, const int K, cudaStream_t stream, const int solidx); + +void LLZZ(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, const int solidx=0) { + int M = in_a.size(0); + int K = in_a.size(1); + + LLGemmZZ(in_a.data_ptr(), in_b.data_ptr(), + out_c.data_ptr(), M, K, at::cuda::getCurrentCUDAStream(),solidx); +} +// instantiate the CPP template for T=float: +//template void AddGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c); + + +void MMGPUKernel(float *in_a, float *in_b, float *out_c, + int numARows, int numAColumns, + int numBRows, int numBColumns, + int numCRows, int numCColumns, + cudaStream_t stream); + + +void MMCustomGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c) { + auto matA_sizes { in_a.sizes() }; + auto matB_sizes { in_b.sizes() }; + auto matO_sizes { out_c.sizes() }; + MMGPUKernel(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), + matA_sizes[0], matA_sizes[1], + matB_sizes[0], matB_sizes[1], + matO_sizes[0], matO_sizes[1], + at::cuda::getCurrentCUDAStream()); +} + +// declare the extension module with the AddGPU function: +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){ + m.doc() = "pybind11 example plugin"; + m.def("LLMM1", &LLMM1); + m.def("LLMM_Silu", &LLMM_Silu); + m.def("LLZZ", &LLZZ); +//m.def("MMCustomGPU", &MMCustomGPU); +} diff --git a/csrc/custom/custom_kernels.cu b/csrc/custom/custom_kernels.cu new file mode 100644 index 0000000000000..b5ab0dbe8317c --- /dev/null +++ b/csrc/custom/custom_kernels.cu @@ -0,0 +1,367 @@ +#include +#include +#include +#include + +constexpr int WARP_SIZE = 64; + +template +__device__ __forceinline__ T loadnt(T* addr) { + return __builtin_nontemporal_load(addr); +} + +__device__ __forceinline__ float4 load_ntmprl(const float4* addr) { + auto addr_alias = reinterpret_cast(addr); + auto dat0 = loadnt(addr_alias); + auto dat1 = loadnt(addr_alias + 1); + auto dat2 = loadnt(addr_alias + 2); + auto dat3 = loadnt(addr_alias + 3); + //auto dat0 = *(addr_alias); + //auto dat1 = *(addr_alias+1); + //auto dat2 = *(addr_alias+2); + //auto dat3 = *(addr_alias+3); + return make_float4(dat0,dat1,dat2,dat3); +} + +//TBlock fetches entire rows of A, and entire col of B (K dimension); assume N=1 for time being +//grid is M/A_NUM_ROWS blocks +template +__global__ void LLGemm1_kernel(float4 *af4, __half2 *bf4, __half2 *c) { + __shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE]; + const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK * blockDim.x; + //int row_addr_1 = row_addr + CUDA_NUM_THREADS; + //int row_addr_2 = row_addr_1 + CUDA_NUM_THREADS; + //int row_addr_3 = row_addr_2 + CUDA_NUM_THREADS; + const int threadid = threadIdx.x; + const int warp = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + const int num_warps = blockDim.x / WARP_SIZE; + const int qwarpid = threadid/16; + const int qthreadid = threadid%16; + float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK]; + //float4 colB_elem4; + __half2 colB_elem4x,colB_elem4y,colB_elem4z,colB_elem4w; + float4 sum4; //[NUM_A_ROWS_PER_BLOCK]; + float acc[NUM_A_ROWS_PER_BLOCK]; //= 0.0; + __half2 acch2; + __half2 oval; + + //rowA_elem4 = af4[row_addr + threadid]; + //__syncthreads(); + //rowA_elem4_1 = af4[row_addr_1 + threadid]; + //rowA_elem4_2 = af4[row_addr_2 + threadid]; + //rowA_elem4_3 = af4[row_addr_3 + threadid]; + #pragma unroll + for (int i=0; i(&colB_elem4); + //auto Bf2x = *Bh2ptr; + //auto 
Bf2y = *(Bh2ptr+1); + //auto Bf2z = *(Bh2ptr+2); + //auto Bf2w = *(Bh2ptr+3); + auto Ah2ptr = reinterpret_cast<__half2 *>(&rowA_elem4); + __half2 *ah2lptr; + #pragma unroll + for (int i=0; i= 1; mask /= 2) { + #pragma unroll + for (int i=0; i= 1; mask /= 2) { + //#pragma unroll + //for (int i=0; i8) { + // #pragma unroll + // for (int j=0; j<8; j++) { + // acc[2*threadid] += red_smem[2*threadid][j]; + // acc[2*threadid+1] += red_smem[2*threadid+1][j]; + // } + // } + // #pragma unroll + // for (int j=0; j +void LLGemm1(void *in_a, void *in_b, void *out_c, const int M, const int K, cudaStream_t stream, const int rows_per_block=4) { + float4 *af4 = reinterpret_cast(in_a); + auto *bf4 = reinterpret_cast<__half2*>(in_b); + auto *c = reinterpret_cast<__half2*>(out_c); + //constexpr int A_ROWS_PER_BLOCK = 8; + const int NUM_THREADS = K*2/16; + int NUM_BLOCKS = M/rows_per_block; + if (rows_per_block==2) { + LLGemm1_kernel<2><<>>(af4, bf4, c); + } + else if (rows_per_block==4) { + LLGemm1_kernel<4><<>>(af4, bf4, c); + } + else if (rows_per_block==8) { + LLGemm1_kernel<8><<>>(af4, bf4, c); + } + else if (rows_per_block==16) { + LLGemm1_kernel<16><<>>(af4, bf4, c); + } + else { + NUM_BLOCKS = M/4; + LLGemm1_kernel<4><<>>(af4, bf4, c); + } + + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); +} + +// instantiate the kernel template for T=float: +//template void AddGPUKernel(float *in_a, float *in_b, float *out_c, const int M, const int K, cudaStream_t stream); + +const unsigned int TILE_WIDTH = 32; + +// Compute C = A * B +__global__ void matrixMultiplyShared(float *A, float *B, float *C, + int numARows, int numAColumns, + int numBRows, int numBColumns, + int numCRows, int numCColumns) { + __shared__ float sA[TILE_WIDTH][TILE_WIDTH]; // Tile size of 32x32 + __shared__ float sB[TILE_WIDTH][TILE_WIDTH]; + + int Row = blockDim.y * blockIdx.y + threadIdx.y; + int Col = blockDim.x * blockIdx.x + threadIdx.x; + float Cvalue = 0.0; + sA[threadIdx.y][threadIdx.x] = 0.0; + sB[threadIdx.y][threadIdx.x] = 0.0; + + for (int ph = 0; ph < (((numAColumns - 1) / TILE_WIDTH) + 1); ph++) { + if ((Row < numARows) && (threadIdx.x + (ph * TILE_WIDTH)) < numAColumns) { + sA[threadIdx.y][threadIdx.x] = A[(Row * numAColumns) + threadIdx.x + (ph * TILE_WIDTH)]; + } else { + sA[threadIdx.y][threadIdx.x] = 0.0; + } + if (Col < numBColumns && (threadIdx.y + ph * TILE_WIDTH) < numBRows) { + sB[threadIdx.y][threadIdx.x] = B[(threadIdx.y + ph * TILE_WIDTH) * numBColumns + Col]; + } else { + sB[threadIdx.y][threadIdx.x] = 0.0; + } + __syncthreads(); + for (int j = 0; j < TILE_WIDTH; ++j) { + Cvalue += sA[threadIdx.y][j] * sB[j][threadIdx.x]; + } + } + if (Row < numCRows && Col < numCColumns) { + C[Row * numCColumns + Col] = Cvalue; + } +} + + +void MMGPUKernel(float *in_a, float *in_b, float *out_c, + int numARows, int numAColumns, + int numBRows, int numBColumns, + int numCRows, int numCColumns, + cudaStream_t stream) { + + // Initialize the grid and block dimensions + dim3 dimBlock(TILE_WIDTH, TILE_WIDTH, 1); + dim3 dimGrid((numCColumns / TILE_WIDTH) + 1, (numCRows / TILE_WIDTH) + 1, 1); + //@@ Launch the GPU Kernel here + matrixMultiplyShared <<>> + (in_a, in_b, out_c, numARows, numAColumns, numBRows, numBColumns, numCRows, numCColumns); + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); +} + + + +template +__global__ +__launch_bounds__(512) +void 
HGEMV_WFPerRow(int m, int n, const _Float16 *A, int lda, const _Float16 *x, _Float16 *y) +{ + int num_row_per_block = CTA / nThreads_per_row; + int row_id = (blockIdx.x*num_row_per_block+threadIdx.y)*MT0; + int inc = (gridDim.x * num_row_per_block)*MT0; + + while (row_id < m) { + float2 sum2[MT0]; + +#pragma unroll + for (int i = 0; i < MT0; ++i) + { + sum2[i] = {0.0,0.0}; + } + + for (int j = threadIdx.x; j < n; j += (nThreads_per_row*MT1)){ + bool is_active = j < n; + if (is_active) { + float2 x2[MT1>>1]; +#pragma unroll + for(int offset = 0; offset < MT1; offset += 2) + { + x2[offset>>1] = {x[j+nThreads_per_row*offset], x[j+nThreads_per_row*(offset+1)]}; + } + float2 a2[MT0][MT1>>1]; +#pragma unroll + for (int i = 0; i < MT0; i++) + { +#pragma unroll + for (int offset = 0; offset < MT1; offset += 2) + { + a2[i][offset>>1] = {A[(row_id+i)*n+j+nThreads_per_row*offset], A[(row_id+i)*n+j+nThreads_per_row*(offset+1)]}; + } + } + +#pragma unroll + for (int i = 0; i < MT0; i++) + { +#pragma unroll + for (int offset = 0; offset < (MT1>>1); offset++) + { + sum2[i] += a2[i][offset]*x2[offset]; + } + } + + } + } + float sum[MT0]; +#pragma unroll + for (int i = 0; i < MT0; i++) + { + sum[i] = sum2[i].x+sum2[i].y; + } + +#pragma unroll + for (int i = 0; i < MT0; i++) + { +#pragma unroll + for (int offset = nThreads_per_row >> 1; offset >= 1; offset = offset >> 1) { + sum[i] += __shfl_down(sum[i], offset, nThreads_per_row); + } + } + if (threadIdx.x == 0) + { +#pragma unroll + for (int i = 0; i < MT0; i++) + { + y[row_id+i] = sum[i]; + } + } + row_id += inc; + } +} + +void LLGemmZZ(void *in_a, void *in_b, void *out_c, const int M, const int K, cudaStream_t stream, const int solidx=0) { + //m -> M, n-> K + dim3 grid(1024); + dim3 block(64, 8); + if (solidx==0) { + HGEMV_WFPerRow<64, 512, 4, 8><<>>(M, K, reinterpret_cast(in_a), K, + reinterpret_cast(in_b),reinterpret_cast<_Float16*>(out_c)); + } + else if (solidx==1) { + HGEMV_WFPerRow<64, 512, 2, 8><<>>(M, K, reinterpret_cast(in_a), K, + reinterpret_cast(in_b),reinterpret_cast<_Float16*>(out_c)); + } + else if (solidx==2) { + HGEMV_WFPerRow<64, 512, 1, 8><<>>(M, K, reinterpret_cast(in_a), K, + reinterpret_cast(in_b),reinterpret_cast<_Float16*>(out_c)); + } + else { + HGEMV_WFPerRow<64, 512, 4, 8><<>>(M, K, reinterpret_cast(in_a), K, + reinterpret_cast(in_b),reinterpret_cast<_Float16*>(out_c)); + } + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); +} diff --git a/csrc/custom/fused_kernels.cu b/csrc/custom/fused_kernels.cu new file mode 100644 index 0000000000000..5a4a11f914eb9 --- /dev/null +++ b/csrc/custom/fused_kernels.cu @@ -0,0 +1,192 @@ +#include +#include +#include +#include + +constexpr int WARP_SIZE = 64; + +template +__device__ __forceinline__ T silu(const T& x) { + // x * sigmoid(x) + return (T) (((float) x) / (1.0f + expf((float) -x))); +} + +template +__device__ __forceinline__ T loadnt(T* addr) { + return __builtin_nontemporal_load(addr); +} + +__device__ __forceinline__ float4 load_ntmprl(const float4* addr) { + auto addr_alias = reinterpret_cast(addr); + auto dat0 = loadnt(addr_alias); + auto dat1 = loadnt(addr_alias + 1); + auto dat2 = loadnt(addr_alias + 2); + auto dat3 = loadnt(addr_alias + 3); + //auto dat0 = *(addr_alias); + //auto dat1 = *(addr_alias+1); + //auto dat2 = *(addr_alias+2); + //auto dat3 = *(addr_alias+3); + return make_float4(dat0,dat1,dat2,dat3); +} + +//TBlock fetches entire rows of A, and entire col of B (K dimension); 
assume N=1 for time being +//grid is M/A_NUM_ROWS blocks +template +__global__ void LLGemm_Silu_kernel(float4 *af4, __half2 *bf4, _Float16 *c, const int d) { + __shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE]; + const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK/2 * blockDim.x; + const int row_addr_d = row_addr + d * blockDim.x; + //int row_addr_1 = row_addr + CUDA_NUM_THREADS; + //int row_addr_2 = row_addr_1 + CUDA_NUM_THREADS; + //int row_addr_3 = row_addr_2 + CUDA_NUM_THREADS; + const int threadid = threadIdx.x; + const int warp = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + const int num_warps = blockDim.x / WARP_SIZE; + const int qwarpid = threadid/16; + const int qthreadid = threadid%16; + float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK]; + //float4 colB_elem4; + __half2 colB_elem4x,colB_elem4y,colB_elem4z,colB_elem4w; + float4 sum4; //[NUM_A_ROWS_PER_BLOCK]; + float acc[NUM_A_ROWS_PER_BLOCK]; //= 0.0; + __half2 acch2; + __half2 oval; + + //rowA_elem4 = af4[row_addr + threadid]; + //__syncthreads(); + //rowA_elem4_1 = af4[row_addr_1 + threadid]; + //rowA_elem4_2 = af4[row_addr_2 + threadid]; + //rowA_elem4_3 = af4[row_addr_3 + threadid]; + #pragma unroll + for (int i=0; i(&colB_elem4); + //auto Bf2x = *Bh2ptr; + //auto Bf2y = *(Bh2ptr+1); + //auto Bf2z = *(Bh2ptr+2); + //auto Bf2w = *(Bh2ptr+3); + auto Ah2ptr = reinterpret_cast<__half2 *>(&rowA_elem4); + __half2 *ah2lptr; + #pragma unroll + for (int i=0; i= 1; mask /= 2) { + #pragma unroll + for (int i=0; i= 1; mask /= 2) { + //#pragma unroll + //for (int i=0; i +void LLGemm_Silu(void *in_a, void *in_b, void *out_c, const int M, const int K, cudaStream_t stream, const int rows_per_block=4) { + float4 *af4 = reinterpret_cast(in_a); + auto *bf4 = reinterpret_cast<__half2*>(in_b); + auto *c = reinterpret_cast<_Float16*>(out_c); + const int d = M/2; + const int NUM_THREADS = K*2/16; + int NUM_BLOCKS = M/rows_per_block; + if (rows_per_block==2) { + LLGemm_Silu_kernel<2><<>>(af4, bf4, c, d); + } + else if (rows_per_block==4) { + LLGemm_Silu_kernel<4><<>>(af4, bf4, c, d); + } + else if (rows_per_block==8) { + LLGemm_Silu_kernel<8><<>>(af4, bf4, c, d); + } + else if (rows_per_block==16) { + LLGemm_Silu_kernel<16><<>>(af4, bf4, c, d); + } + else { + NUM_BLOCKS = M/4; + LLGemm_Silu_kernel<4><<>>(af4, bf4, c, d); + } + + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); +} + diff --git a/requirements-rocm.txt b/requirements-rocm.txt index 53bd11de7c9de..c7ef913460816 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -2,7 +2,7 @@ ninja # For faster builds. typing-extensions>=4.8.0 starlette psutil -ray >= 2.9 +ray == 2.9.3 sentencepiece # Required for LLaMA tokenizer. numpy tokenizers>=0.15.0 diff --git a/setup.py b/setup.py index accf6bb400310..7f5e82f85b35c 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,9 @@ # Supported NVIDIA GPU architectures. 
NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"} -ROCM_SUPPORTED_ARCHS = {"gfx908", "gfx90a", "gfx942", "gfx1100"} +ROCM_SUPPORTED_ARCHS = { + "gfx90a", "gfx908", "gfx906", "gfx942", "gfx1030", "gfx1100" +} # SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS) @@ -369,6 +371,19 @@ def get_torch_arch_list() -> Set[str]: ) ext_modules.append(vllm_extension) +custom_extension = CUDAExtension( + name="vllm.custom_ops", + sources=[ + "csrc/custom/custom.cpp", "csrc/custom/custom_kernels.cu", + "csrc/custom/fused_kernels.cu" + ], + extra_compile_args={ + "cxx": CXX_FLAGS, + "nvcc": NVCC_FLAGS + }, +) +ext_modules.append(custom_extension) + def get_path(*filepath) -> str: return os.path.join(ROOT_DIR, *filepath) diff --git a/vllm/config.py b/vllm/config.py index de687395a0001..444a979d20305 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -428,6 +428,7 @@ def __init__( else: self.tensor_parallel_size = tensor_parallel_size self.worker_use_ray = worker_use_ray + self.worker_use_torchrun = False self.max_parallel_loading_workers = max_parallel_loading_workers self.disable_custom_all_reduce = disable_custom_all_reduce self.ray_workers_use_nsight = ray_workers_use_nsight @@ -436,7 +437,16 @@ def __init__( self.world_size = pipeline_parallel_size * self.tensor_parallel_size # Ray worker is not supported for Neuron backend. if self.world_size > 1 and not is_neuron(): - self.worker_use_ray = True + if is_hip() and not self.worker_use_ray: + logger.info("Using torchrun for multi-GPU on " + "ROCM platform. Use --worker-use-ray " + "to override") + if not os.environ.get("RANK"): + raise RuntimeError("Needs to be run in torchrun: " + "torchrun --standalone --nproc_per_node= ...") + self.worker_use_torchrun = True + else: + self.worker_use_ray = True self._verify_args() def _verify_args(self) -> None: diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index c3dccdd5bb50b..d9c56e4daaeee 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -152,7 +152,8 @@ def add_cli_args( parser.add_argument('--worker-use-ray', action='store_true', help='use Ray for distributed serving, will be ' - 'automatically set when using more than 1 GPU') + 'automatically set when using more than 1 GPU ' + 'unless on ROCm where the default is torchrun') parser.add_argument('--pipeline-parallel-size', '-pp', type=int, @@ -314,12 +315,10 @@ def create_engine_configs( self.swap_space, self.kv_cache_dtype, model_config.get_sliding_window(), self.enable_prefix_caching) - parallel_config = ParallelConfig(self.pipeline_parallel_size, - self.tensor_parallel_size, - self.worker_use_ray, - self.max_parallel_loading_workers, - self.disable_custom_all_reduce, - self.ray_workers_use_nsight) + parallel_config = ParallelConfig( + self.pipeline_parallel_size, self.tensor_parallel_size, + self.worker_use_ray, self.max_parallel_loading_workers, + self.disable_custom_all_reduce, self.ray_workers_use_nsight) scheduler_config = SchedulerConfig(self.max_num_batched_tokens, self.max_num_seqs, model_config.max_model_len, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 4cdad4180aa14..e8272968d190d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -126,6 +126,9 @@ def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine": initialize_ray_cluster(parallel_config) from vllm.executor.ray_gpu_executor import RayGPUExecutor executor_class = RayGPUExecutor + elif parallel_config.worker_use_torchrun: + from 
vllm.executor.torchrun_gpu_executor import TorchrunGPUExecutor + executor_class = TorchrunGPUExecutor else: assert parallel_config.world_size == 1, ( "Ray is required if parallel_config.world_size > 1.") diff --git a/vllm/executor/torchrun_gpu_executor.py b/vllm/executor/torchrun_gpu_executor.py new file mode 100644 index 0000000000000..b30d0cb0b9246 --- /dev/null +++ b/vllm/executor/torchrun_gpu_executor.py @@ -0,0 +1,104 @@ +import os +from typing import Dict, List, Optional + +from vllm.executor.gpu_executor import GPUExecutor +from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, + ParallelConfig, SchedulerConfig, LoRAConfig) +from vllm.executor.executor_base import ExecutorAsyncBase +from vllm.logger import init_logger +from vllm.model_executor.parallel_utils.communication_op import ( + broadcast_object_list) +from vllm.sequence import SamplerOutput, SequenceGroupMetadata +from vllm.utils import (get_ip, get_open_port, get_distributed_init_method, + make_async) + +logger = init_logger(__name__) + +# A map between the device type (in device config) to its worker module. +DEVICE_TO_WORKER_MODULE_MAP = { + "cuda": "vllm.worker.worker", + "neuron": "vllm.worker.neuron_worker", +} + + +class TorchrunGPUExecutor(GPUExecutor): + + def __init__( + self, + model_config: ModelConfig, + cache_config: CacheConfig, + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + ) -> None: + self.local_rank = int(os.getenv("LOCAL_RANK", "0")) + self.is_driver_worker = self.local_rank == 0 + super().__init__(model_config, cache_config, parallel_config, + scheduler_config, device_config, lora_config) + + def _init_worker(self): + # Lazy import the Worker to avoid importing torch.cuda/xformers + # before CUDA_VISIBLE_DEVICES is set in the Worker + Worker = self._dispatch_worker() + + assert self.parallel_config.world_size > 1, ( + "TorchrunGPUExecutor only supports multiple GPUs.") + + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + self.driver_worker = Worker( + self.model_config, + self.parallel_config, + self.scheduler_config, + self.device_config, + local_rank=self.local_rank, + rank=self.local_rank, + distributed_init_method=distributed_init_method, + lora_config=self.lora_config, + kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=self.is_driver_worker, + ) + self.driver_worker.init_model() + self.driver_worker.load_model() + + def execute_model(self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]]) -> SamplerOutput: + output = self.driver_worker.execute_model( + seq_group_metadata_list=seq_group_metadata_list, + blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy, + ) + if self.is_driver_worker: + broadcast_object_list([output], src=0) + else: + res = [None] + broadcast_object_list(res, src=0) + output = res[0] + return output + + +class TorchrunGPUExecutorAsync(TorchrunGPUExecutor, ExecutorAsyncBase): + + async def execute_model_async( + self, + seq_group_metadata_list: List[SequenceGroupMetadata], + blocks_to_swap_in: Dict[int, int], + blocks_to_swap_out: Dict[int, int], + blocks_to_copy: Dict[int, List[int]], + ) -> SamplerOutput: + output = await make_async(self.driver_worker.execute_model)( + seq_group_metadata_list=seq_group_metadata_list, + 
blocks_to_swap_in=blocks_to_swap_in, + blocks_to_swap_out=blocks_to_swap_out, + blocks_to_copy=blocks_to_copy) + return output + + async def check_health_async(self) -> None: + # TorchrunGPUExecutor will always be healthy as long as + # it's running. + return diff --git a/vllm/model_executor/layers/attention/attention.py b/vllm/model_executor/layers/attention/attention.py index 4b63b9eaf59a7..60f456d451f19 100644 --- a/vllm/model_executor/layers/attention/attention.py +++ b/vllm/model_executor/layers/attention/attention.py @@ -8,6 +8,7 @@ from vllm.logger import init_logger from vllm.model_executor.input_metadata import InputMetadata from vllm.utils import is_hip +import os logger = init_logger(__name__) @@ -34,11 +35,12 @@ def __init__( sliding_window: Optional[int] = None, ) -> None: super().__init__() - if _use_flash_attn(): + if use_triton := _use_flash_attn(): from vllm.model_executor.layers.attention.backends.flash_attn import FlashAttentionBackend # noqa: E501 self.backend = FlashAttentionBackend(num_heads, head_size, scale, num_kv_heads, alibi_slopes, - sliding_window) + sliding_window, + use_triton == 2) else: from vllm.model_executor.layers.attention.backends.xformers import XFormersBackend # noqa: E501 self.backend = XFormersBackend(num_heads, head_size, scale, @@ -59,26 +61,42 @@ def forward( @lru_cache(maxsize=1) -def _use_flash_attn() -> bool: - try: - import flash_attn # noqa: F401 - except ImportError: - logger.info("flash_attn is not found. Using xformers backend.") - return False - - if is_hip(): - # AMD GPUs. - return False - if torch.cuda.get_device_capability()[0] < 8: +def _use_flash_attn() -> int: + """Returns if and which flash attention to use. + + Returns: + int: 0 for none, + 1 for default implementation, + 2 for triton implementation. + """ + use_flash_attn_triton = os.environ.get('VLLM_USE_FLASH_ATTN_TRITON', "True").lower() in ("true", "1") + if not ( use_flash_attn_triton and is_hip()): + # AMD GPUs can use flash_attn package or triton impl. + try: + import flash_attn # noqa: F401 + except ImportError: + logger.info("flash_attn is not found. Using xformers backend.") + return 0 + + if (not is_hip()) and torch.cuda.get_device_capability()[0] < 8: # Volta and Turing NVIDIA GPUs. logger.info("flash_attn is not supported on Turing or older GPUs. " "Using xformers backend.") - return False + return 0 + + if is_hip() and torch.cuda.get_device_capability()[0] != 9: + # not Instinct series GPUs. + logger.info("flash_atten is not supported on NAVI GPUs. " + "Using xformers backend.") + return 0 + if torch.get_default_dtype() not in (torch.float16, torch.bfloat16): logger.info( "flash_attn only supports torch.float16 or torch.bfloat16. 
" "Using xformers backend.") - return False + return 0 - logger.info("Using flash_attn backend.") - return True + logger.info(f"""Using {'Triton' + if use_flash_attn_triton + else ''} flash_attn backend.""") + return 2 if use_flash_attn_triton else 1 diff --git a/vllm/model_executor/layers/attention/backends/flash_attn.py b/vllm/model_executor/layers/attention/backends/flash_attn.py index 58ccd461b993e..15fcc56922586 100644 --- a/vllm/model_executor/layers/attention/backends/flash_attn.py +++ b/vllm/model_executor/layers/attention/backends/flash_attn.py @@ -1,12 +1,15 @@ """Attention layer with Flash and PagedAttention.""" from typing import List, Optional +from vllm.utils import is_hip from flash_attn import flash_attn_func import torch from vllm.model_executor.input_metadata import InputMetadata from vllm.model_executor.layers.attention.ops.paged_attn import ( PagedAttentionImpl) +from vllm.model_executor.layers.attention.ops.flash_attention_triton import ( + triton_attention) class FlashAttentionBackend: @@ -19,6 +22,7 @@ def __init__( num_kv_heads: Optional[int] = None, alibi_slopes: Optional[List[float]] = None, sliding_window: Optional[int] = None, + use_triton: Optional[bool] = False, ) -> None: self.num_heads = num_heads self.head_size = head_size @@ -28,6 +32,7 @@ def __init__( if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes + self.use_triton = use_triton assert self.num_heads % self.num_kv_heads == 0 self.num_queries_per_kv = self.num_heads // self.num_kv_heads @@ -40,6 +45,15 @@ def __init__( self.sliding_window = ((self.sliding_window, self.sliding_window) if self.sliding_window is not None else (-1, -1)) + def repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" + tokens, n_kv_heads, head_dim = x.shape + return ( + x[:, :, None, :] + .expand(tokens, n_kv_heads, n_rep, head_dim) + .reshape(tokens, n_kv_heads * n_rep, head_dim) + ) + def forward( self, query: torch.Tensor, @@ -81,19 +95,45 @@ def forward( # Prompt run. if (key_cache is None or value_cache is None or input_metadata.block_tables.numel() == 0): + if self.use_triton and (self.num_kv_heads != self.num_heads): + # Interleave for MQA workaround. 
+ key = self.repeat_kv(key, self.num_queries_per_kv) + value = self.repeat_kv(value, self.num_queries_per_kv) + # normal attention query = query.unflatten(0, (batch_size, seq_len)) key = key.unflatten(0, (batch_size, seq_len)) value = value.unflatten(0, (batch_size, seq_len)) - output = flash_attn_func( - query, - key, - value, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - ) + if self.use_triton: + output, _ = triton_attention( + query, + key, + value, + None, + input_metadata, + True, + self.scale, + ) + else: + if is_hip(): + #XXX: window_size and alibi_slopes not supported + output = flash_attn_func( + query, + key, + value, + softmax_scale=self.scale, + causal=True, + ) + else: + output = flash_attn_func( + query, + key, + value, + softmax_scale=self.scale, + causal=True, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + ) else: # prefix-enabled attention output = PagedAttentionImpl.forward_prefix( diff --git a/vllm/model_executor/layers/attention/ops/flash_attention_triton.py b/vllm/model_executor/layers/attention/ops/flash_attention_triton.py new file mode 100644 index 0000000000000..80962e4cf9d9a --- /dev/null +++ b/vllm/model_executor/layers/attention/ops/flash_attention_triton.py @@ -0,0 +1,538 @@ +#!/usr/bin/env python +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) +Credits: OpenAI kernel team, AMD ML Frameworks Triton team + +Features supported: + +1) Fwd with causal masking +2) Any sequence lengths without padding (currently fwd kernel only) +3) Support for different sequence lengths for q and k +4) Nested tensor API currently does not support dropout or bias. 
+ +Not currently supported: + +1) Non power of two head dims + +""" + +import torch +import triton +import triton.language as tl + +torch_dtype:tl.constexpr = torch.float16 + +TORCH_HAS_FP8E5 = hasattr(torch, 'float8_e5m2fnuz') +if TORCH_HAS_FP8E5: + torch_dtype:tl.constexpr = torch.float8_e5m2fnuz + +@triton.jit +def cdiv_fn(x,y): + return (x + y - 1) // y + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + +@triton.jit +def dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride): + ms = tl.arange(0, m) + ns = tl.arange(0, n) + return philox_offset + ms[:, None] * stride + ns[None, :] + +@triton.jit +def dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_offsets = dropout_offsets(philox_seed, philox_offset, dropout_p, m, n, stride).to(tl.uint32) + # TODO: use tl.randint for better performance + return tl.rand(philox_seed, rng_offsets) + +@triton.jit +def dropout_mask(philox_seed, philox_offset, dropout_p, m, n, stride): + rng_output = dropout_rng(philox_seed, philox_offset, dropout_p, m, n, stride) + rng_keep = rng_output > dropout_p + return rng_keep + +@triton.jit +def load_fn(block_ptr, first, second, pad): + if first and second: + tensor = tl.load(block_ptr, boundary_check=(0,1), padding_option=pad) + elif first: + tensor = tl.load(block_ptr, boundary_check=(0,), padding_option=pad) + elif second: + tensor = tl.load(block_ptr, boundary_check=(1,), padding_option=pad) + else: + tensor = tl.load(block_ptr) + return tensor + +@triton.jit +def _attn_fwd_inner( + acc, l_i, m_i, q, + K_block_ptr, V_block_ptr, + start_m, + actual_seqlen_k, + dropout_p, + philox_seed, + batch_philox_offset, + encoded_softmax_block_ptr, + block_min, block_max, + offs_n_causal, + masked_blocks, + n_extra_tokens, + bias_ptr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + OFFS_M: tl.constexpr, + OFFS_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + MASK_STEPS: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, + RETURN_ENCODED_SOFTMAX: tl.constexpr, + PADDED_HEAD: tl.constexpr +): + # loop over k, v, and update accumulator + for start_n in range (block_min, block_max, BLOCK_N): + # For padded blocks, we will overrun the tensor size if + # we load all BLOCK_N. For others, the blocks are all within range. + k = load_fn(K_block_ptr, PADDED_HEAD, MASK_STEPS and (n_extra_tokens != 0), "zero") + if PRE_LOAD_V: + v = load_fn(V_block_ptr, MASK_STEPS and (n_extra_tokens != 0), PADDED_HEAD, "zero") + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + # We start from end of seqlen_k so only the first iteration would need + # to be checked for padding if it is not a multiple of block_n + # TODO: This can be optimized to only be true for the padded block. + if MASK_STEPS: + # If this is the last block / iteration, we want to + # mask if the sequence length is not a multiple of block size + # a solution is to always do BLOCK_M // BLOCK_N + 1 steps if not is_modulo_mn. + # last step might get wasted but that is okay. check if this masking works For + # that case. 
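In other words, only the final K/V iteration can step past the end of the key sequence, and only when seqlen_k is not a multiple of BLOCK_N; n_extra_tokens is computed in attn_fwd further down and passed in here. A toy version of that bookkeeping, with illustrative sizes:

    seqlen_k, BLOCK_N = 150, 64
    if seqlen_k < BLOCK_N:
        n_extra_tokens = BLOCK_N - seqlen_k
    else:
        n_extra_tokens = seqlen_k % BLOCK_N
    # 150 tokens = 2 full blocks of 64 plus 22 leftovers, so only the last of the
    # three K/V blocks needs the size_n < seqlen_k mask applied just below.
    assert n_extra_tokens == 22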
+ if (start_n + BLOCK_N == block_max) and (n_extra_tokens != 0): + boundary_m = tl.full([BLOCK_M], actual_seqlen_k, dtype=tl.int32) + size_n = start_n + OFFS_N[None,:] + mask = size_n < boundary_m[:,None] + qk = tl.where(mask, qk, float("-inf")) + if IS_CAUSAL: + causal_boundary = start_n + offs_n_causal + causal_mask = OFFS_M[:, None] >= causal_boundary[None, :] + qk = tl.where(causal_mask, qk, float("-inf")) + # -- compute qk ---- + qk += tl.dot(q, k) + if bias_ptr is not None: + bias = load_fn(bias_ptr, False, MASK_STEPS and (n_extra_tokens != 0), "zero") + # While bias is added after multiplying qk with sm_scale, + # our optimization to use 2^x instead of e^x results in an additional + # scale factor of log2(e) which we must also multiply the bias with. + qk += (bias * 1.44269504089) + m_ij = tl.maximum(m_i, tl.max(qk,1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + + # CAVEAT: Must update l_ij before applying dropout + l_ij = tl.sum(p, 1) + if ENABLE_DROPOUT: + philox_offset = batch_philox_offset + start_m * BLOCK_M * actual_seqlen_k + start_n - BLOCK_N + keep = dropout_mask(philox_seed, philox_offset, dropout_p, BLOCK_M, BLOCK_N, actual_seqlen_k) + if RETURN_ENCODED_SOFTMAX: + tl.store(encoded_softmax_block_ptr, tl.where(keep, p, -p).to(encoded_softmax_block_ptr.type.element_ty)) + p = tl.where(keep, p, 0.0) + elif RETURN_ENCODED_SOFTMAX: + tl.store(encoded_softmax_block_ptr, p.to(encoded_softmax_block_ptr.type.element_ty)) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not PRE_LOAD_V: + v = load_fn(V_block_ptr, MASK_STEPS and (n_extra_tokens != 0), PADDED_HEAD, "zero") + # -- update m_i and l_i + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + acc += tl.dot(p.to(V_block_ptr.type.element_ty), v) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, (0, BLOCK_N)) + return acc, l_i, m_i + +@triton.autotune( + configs=[ + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 64, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 256, 'BLOCK_N': 128, 'waves_per_eu': 2, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': True}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'waves_per_eu': 3, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + triton.Config({'BLOCK_M': 32, 'BLOCK_N': 32, 'waves_per_eu': 4, 'PRE_LOAD_V': False}, num_stages=1, num_warps=8), + # TODO: This config fails with head_size not pow2 with data mismatches. Check why. 
+ # triton.Config({'BLOCK_M': 32, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + triton.Config({'BLOCK_M': 16, 'BLOCK_N': 16, 'waves_per_eu': 1, 'PRE_LOAD_V': False}, num_stages=1, num_warps=4), + ], + key=['hq', 'hk', 'IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], +) +@triton.jit +def attn_fwd( + Q, K, V, bias, sm_scale, L, Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, + stride_bz, stride_bh, stride_bm, stride_bn, + cu_seqlens_q, cu_seqlens_k, + dropout_p, philox_seed, philox_offset_base, encoded_softmax, + hq, hk, + ACTUAL_BLOCK_DMODEL:tl.constexpr, + MAX_SEQLENS_Q:tl.constexpr, MAX_SEQLENS_K:tl.constexpr, + VARLEN: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + PRE_LOAD_V: tl.constexpr, + BIAS_TYPE: tl.constexpr, + ENABLE_DROPOUT: tl.constexpr, RETURN_ENCODED_SOFTMAX: tl.constexpr +): + start_m = tl.program_id(0) + off_h_q = tl.program_id(1) + off_z = tl.program_id(2) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + if VARLEN: + cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z) + cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1) + seqlen_q = cu_seqlens_q_end - cu_seqlens_q_start + # We have a one-size-fits-all grid in id(0). Some seqlens might be too + # small for all start_m so for those we return early. + if start_m * BLOCK_M > seqlen_q: + return + cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z) + cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1) + seqlen_k = cu_seqlens_k_end - cu_seqlens_k_start + else: + cu_seqlens_q_start = 0 + cu_seqlens_k_start = 0 + seqlen_q = MAX_SEQLENS_Q + seqlen_k = MAX_SEQLENS_K + + # Now we compute whether we need to exit early due to causal masking. + # This is because for seqlen_q > seqlen_k, M rows of the attn scores + # are completely masked, resulting in 0s written to the output, and + # inf written to LSE. We don't need to do any GEMMs in this case. + # This block of code determines what N is, and if this WG is operating + # on those M rows. + n_blocks = cdiv_fn(seqlen_k, BLOCK_N) + if (IS_CAUSAL): + # If seqlen_q == seqlen_k, the attn scores are a square matrix. + # If seqlen_q != seqlen_k, attn scores are rectangular which means + # the causal mask boundary is bottom right aligned, and ends at either + # the top edge (seqlen_q < seqlen_k) or left edge. + # This captures the decrease in n_blocks if we have a rectangular attn matrix + n_blocks_seqlen = cdiv_fn( + (start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, + BLOCK_N + ) + # This is what adjusts the block_max for the current WG, only + # if IS_CAUSAL. Otherwise we want to always iterate through all n_blocks + n_blocks = min(n_blocks, n_blocks_seqlen) + # If we have no blocks after adjusting for seqlen deltas, this WG is part of + # the blocks that are all 0. We exit early. 
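Concretely, with the bottom-right-aligned causal mask a query block whose rows all lie above the causal boundary ends up with n_blocks <= 0 and can skip the GEMMs entirely. A toy re-derivation of the block count outside the kernel, with illustrative sizes:

    def cdiv(x, y):
        return (x + y - 1) // y

    seqlen_q, seqlen_k, BLOCK_M, BLOCK_N = 256, 64, 64, 64
    n_blocks_per_wg = [
        min(cdiv(seqlen_k, BLOCK_N),
            cdiv((start_m + 1) * BLOCK_M + seqlen_k - seqlen_q, BLOCK_N))
        for start_m in range(cdiv(seqlen_q, BLOCK_M))
    ]
    # Query blocks 0-2 sit entirely above the causal boundary and exit early;
    # only the last block has any unmasked scores.
    assert n_blocks_per_wg == [-2, -1, 0, 1]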
+ if n_blocks <= 0: + o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=Out.type.element_ty) + # We still need to write 0s to the result + #tl.store(O_block_ptr, acc.to(Out.type.element_ty), boundary_check=(0,1)) + #l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # We store inf to LSE, not -inf because in the bwd pass, we subtract this + # from qk which makes it -inf, such that exp(qk - inf) = 0 for these masked blocks. + #l = tl.full([BLOCK_M], value=float("inf"), dtype=tl.float32) + #tl.store(l_ptrs, l) + # TODO: Should dropout and return encoded softmax be handled here too? + return + + is_mqa = hq != hk + off_h_k = off_h_q % hk if is_mqa else off_h_q + need_padding = False + n_extra_tokens = 0 + if seqlen_k < BLOCK_N: + need_padding = True + n_extra_tokens = BLOCK_N - seqlen_k + elif seqlen_k % BLOCK_N: + need_padding = True + n_extra_tokens = seqlen_k % BLOCK_N + padded_head = (ACTUAL_BLOCK_DMODEL != BLOCK_DMODEL) + + # Compute pointers for all the tensors used in this kernel. + q_offset = off_z * stride_qz + off_h_q * stride_qh + cu_seqlens_q_start * stride_qm + Q_block_ptr = tl.make_block_ptr( + base=Q + q_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + k_offset = off_z * stride_kz + off_h_k * stride_kh + cu_seqlens_k_start * stride_kn + K_block_ptr = tl.make_block_ptr( + base=K + k_offset, + shape=(ACTUAL_BLOCK_DMODEL, seqlen_k), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1) + ) + v_offset = off_z * stride_vz + off_h_k * stride_vh + cu_seqlens_k_start * stride_vk + V_block_ptr = tl.make_block_ptr( + base=V + v_offset, + shape=(seqlen_k, ACTUAL_BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0) + ) + if BIAS_TYPE != 0: + bias_ptr = tl.make_block_ptr( + base=bias + off_h_q * stride_bh, + shape=(seqlen_q, seqlen_k), + strides=(stride_bm, stride_bn), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0), + ) + else: + bias_ptr = None + if ENABLE_DROPOUT: + batch_philox_offset = philox_offset_base + off_hz * seqlen_q * seqlen_k + else: + batch_philox_offset = 0 + # We can ask to return the dropout mask without actually doing any dropout. In + # this case, we return an invalid pointer so indicate the mask is not valid. + # TODO: Fix encoded softmax. It currently uses just h_q in the base offset. + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.make_block_ptr( + base=encoded_softmax + off_h_q * seqlen_q * seqlen_k, + shape=(seqlen_q, seqlen_k), + strides=(seqlen_k, 1), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_N), + order=(1, 0) + ) + else: + encoded_softmax_block_ptr = 0 + # initialize pointer to m and l + m_i = tl.full([BLOCK_M], float("-inf"), dtype=tl.float32) + l_i = tl.full([BLOCK_M], 1.0, dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use 2^x in the loop as we do not + # have native e^x support in HW. 
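This is the identity the whole kernel leans on: exp(x) = 2**(x * log2(e)), so folding log2(e) into the softmax scale once lets every inner-loop exponential use exp2. The constant baked into qk_scale just below is log2(e) to eleven decimal places; a quick numeric check with toy values:

    import math

    x, sm_scale = 0.7, 0.125
    assert math.isclose(math.exp(x * sm_scale),
                        2.0 ** (x * sm_scale * 1.44269504089), rel_tol=1e-9)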
+ qk_scale = sm_scale * 1.44269504089 + # Q is loaded once at the beginning and shared by all N blocks. + q = load_fn(Q_block_ptr, True, padded_head, "zero") + q = (q * qk_scale).to(Q_block_ptr.type.element_ty) + + # Here we compute how many full and masked blocks we have. + padded_block_k = n_extra_tokens != 0 + is_modulo_mn = not padded_block_k and (seqlen_q % BLOCK_M == 0) + if IS_CAUSAL: + # There are always at least BLOCK_M // BLOCK_N masked blocks. + # Additionally there might be one more due to dissimilar seqlens. + masked_blocks = BLOCK_M // BLOCK_N + (not is_modulo_mn) + else: + # Padding on Q does not need to be masked in the FA loop. + masked_blocks = padded_block_k + # if IS_CAUSAL, not is_modulo_mn does not always result in an additional block. + # In this case we might exceed n_blocks so pick the min. + masked_blocks = min(masked_blocks, n_blocks) + n_full_blocks = n_blocks - masked_blocks + block_min = 0 + block_max = n_blocks * BLOCK_N + # Compute for full blocks. Here we set causal to false regardless of its actual + # value because there is no masking. Similarly we do not need padding. + if n_full_blocks > 0: + block_max = (n_blocks - masked_blocks) * BLOCK_N + acc, l_i, m_i = _attn_fwd_inner( + acc, l_i, m_i, q, K_block_ptr, V_block_ptr, + start_m, seqlen_k, + dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, + # _, _, offs_n_causal, masked_blocks, n_extra_tokens, _ + block_min, block_max, 0, 0, 0, bias_ptr, + # IS_CAUSAL, .... + False, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, False, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, padded_head + ) + block_min = block_max + block_max = n_blocks * BLOCK_N + + tl.debug_barrier() + # Remaining blocks, if any, are full / not masked. + if (masked_blocks > 0): + if IS_CAUSAL: + offs_n_causal = offs_n + (seqlen_q - seqlen_k) + else: + offs_n_causal = 0 + K_block_ptr = tl.advance(K_block_ptr, (0, n_full_blocks*BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (n_full_blocks*BLOCK_N, 0)) + if bias_ptr is not None: + bias_ptr = tl.advance(bias_ptr, (0, n_full_blocks*BLOCK_N)) + if RETURN_ENCODED_SOFTMAX: + encoded_softmax_block_ptr = tl.advance(encoded_softmax_block_ptr, + (0, n_full_blocks)) + acc, l_i, m_i = _attn_fwd_inner( + acc, l_i, m_i, q, K_block_ptr, V_block_ptr, + start_m, seqlen_k, + dropout_p, philox_seed, batch_philox_offset, encoded_softmax_block_ptr, + block_min, block_max, offs_n_causal, masked_blocks, n_extra_tokens, bias_ptr, + IS_CAUSAL, BLOCK_M, BLOCK_DMODEL, BLOCK_N, offs_m, offs_n, + # _, MASK_STEPS, ... + PRE_LOAD_V, True, ENABLE_DROPOUT, RETURN_ENCODED_SOFTMAX, padded_head + ) + # epilogue + acc = acc / l_i[:, None] + if ENABLE_DROPOUT: + acc = acc / (1 - dropout_p) + # If seqlen_q > seqlen_k but the delta is not a multiple of BLOCK_M, + # then we have one block with a row of all NaNs which come from computing + # softmax over a row of all -infs (-inf - inf = NaN). We check for that here + # and store 0s where there are NaNs as these rows should've been zeroed out. 
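The NaNs arise because a fully masked row never raises the running max: m_i and m_ij both stay -inf, so the kernel computes (-inf) - (-inf) = nan before exponentiating, and that nan propagates into acc and l_i for the row. A two-line illustration of the float behaviour:

    import math

    print(float("-inf") - float("-inf"))             # nan
    print(math.exp(float("-inf") - float("-inf")))   # nan, so the whole row must be zeroed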
+ end_m_idx = (start_m + 1) * BLOCK_M + start_m_idx = start_m * BLOCK_M + causal_start_idx = seqlen_q - seqlen_k + acc = acc.to(Out.type.element_ty) + if IS_CAUSAL: + if causal_start_idx > start_m_idx and causal_start_idx < end_m_idx: + out_mask_boundary = tl.full((BLOCK_DMODEL,), causal_start_idx, dtype=tl.int32) + mask_m_offsets = start_m_idx + tl.arange(0, BLOCK_M) + out_ptrs_mask = mask_m_offsets[:, None] >= out_mask_boundary[None, :] + z = 0.0 + acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) + # write back LSE + #l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # If seqlen_q not multiple of BLOCK_M, we need to mask out the last few rows. + # This is only true for the last M block. For others, overflow_size will be -ve + #overflow_size = end_m_idx - seqlen_q + #if overflow_size > 0: + # boundary = tl.full((BLOCK_M,), BLOCK_M - overflow_size, dtype=tl.int32) + # # This is a > check because mask being 0 blocks the store. + # l_ptrs_mask = boundary > tl.arange(0, BLOCK_M) + # tl.store(l_ptrs, m_i + tl.math.log2(l_i), mask=l_ptrs_mask) + #else: + # tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + + # write back O + o_offset = off_z * stride_oz + cu_seqlens_q_start * stride_om + off_h_q * stride_oh + O_block_ptr = tl.make_block_ptr( + base=Out + o_offset, + shape=(seqlen_q, ACTUAL_BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + # Need boundary check on this to make sure the padding from the + # Q and KV tensors in both dims are not part of what we store back. + # TODO: Do the boundary check optionally. + tl.store(O_block_ptr, acc, boundary_check=(0,1)) + +def check_args(q, k, v, o, max_seqlens): + assert q.dim() == k.dim() and q.dim() == v.dim() + assert q.dim() == 4 + batch, nheads_q, seqlen_q, head_size = q.shape + _, nheads_k, seqlen_k, _ = k.shape + assert max_seqlens > 0 + assert k.shape == v.shape + assert q.shape[-1] == k.shape[-1] and q.shape[-1] == v.shape[-1] + # TODO: Change assert if we support qkl f8 and v f16 + assert q.dtype == k.dtype and q.dtype == v.dtype + # TODO: Fix assert to check head size <=256 once supported + assert head_size <= 128 + assert o.shape == q.shape + assert (nheads_q % nheads_k) == 0 + +class _attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, o, metadata, causal=False, sm_scale=1.0, bias=None): + if o is None: + o = torch.empty_like(q, dtype=v.dtype) + check_args(q, k, v, o, metadata.max_seq_len) + + batch, seqlen_q, nheads_q, head_size = q.shape + _, seqlen_k, nheads_k, _ = k.shape + q_strides = (q.stride(0), q.stride(2), q.stride(1), q.stride(3)) + k_strides = (k.stride(0), k.stride(2), k.stride(1), k.stride(3)) + v_strides = (v.stride(0), v.stride(2), v.stride(1), v.stride(3)) + o_strides = (o.stride(0), o.stride(2), o.stride(1), o.stride(3)) + + # Get closest power of 2 over or equal to 32. + unpadded_head_dims = {32, 64, 128} + if head_size not in unpadded_head_dims: + padded_d_model = None + for i in unpadded_head_dims: + if i > head_size: + padded_d_model = i + break + assert padded_d_model is not None + else: + padded_d_model = head_size + + + grid = lambda META: ( + triton.cdiv(metadata.max_seq_len, META['BLOCK_M']), + nheads_q, + batch + ) + + encoded_softmax = None + + # Seed the RNG so we get reproducible results for testing. 
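Dropout is hard-wired off on this path (dropout_p=0.0 below), so the fixed philox constants only keep any future dropout or encoded-softmax runs repeatable. Separately, a few lines up the head size is rounded to the next supported width and the kernel masks the padded columns (PADDED_HEAD); a toy version of that rounding, mirroring unpadded_head_dims with illustrative values:

    SUPPORTED_HEAD_DIMS = (32, 64, 128)   # mirrors unpadded_head_dims above

    def pad_head_dim(head_size: int) -> int:
        if head_size in SUPPORTED_HEAD_DIMS:
            return head_size
        return next(d for d in SUPPORTED_HEAD_DIMS if d > head_size)

    assert pad_head_dim(80) == 128 and pad_head_dim(64) == 64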
+ philox_seed = 0x1BF52 + philox_offset = 0x1D4B42 + + if bias is not None: + bias_strides = (bias.stride(0), bias.stride(1), + bias.stride(2), bias.stride(3)) + else: + bias_strides = (0,0,0,0) + + attn_fwd[grid]( + q, k, v, bias, sm_scale, None, o, + *q_strides, *k_strides, *v_strides, *o_strides, *bias_strides, + None, None, + dropout_p=0.0, + philox_seed=philox_seed, + philox_offset_base=philox_offset, + encoded_softmax=encoded_softmax, + hq=nheads_q, hk=nheads_k, + ACTUAL_BLOCK_DMODEL=head_size, + MAX_SEQLENS_Q=metadata.max_seq_len, + MAX_SEQLENS_K=metadata.max_seq_len, + IS_CAUSAL=causal, + VARLEN=False, + BLOCK_DMODEL=padded_d_model, + BIAS_TYPE=0 if bias is None else 1, + ENABLE_DROPOUT=False, + RETURN_ENCODED_SOFTMAX=False + ) + + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = head_size + ctx.causal = causal + ctx.dropout_p = 0.0 + ctx.philox_seed = philox_seed + ctx.philox_offset = philox_offset + ctx.encoded_softmax = encoded_softmax + ctx.return_encoded_softmax = False + return o, encoded_softmax + +triton_attention = _attention.apply diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 40e681df48f86..6e66072627391 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -13,6 +13,8 @@ divide, split_tensor_along_last_dim) from vllm.model_executor.utils import set_weight_attrs from vllm.logger import init_logger +from vllm.utils import is_hip +from vllm import custom_ops logger = init_logger(__name__) @@ -72,6 +74,30 @@ def apply_weights(self, x: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: weight = weights["weight"] + if is_hip() and x.view(-1, x.size(-1)).shape[0] == 1: + batched = False + if x.dim() == 3: + inp = x.view(-1, x.size(-1)) + batched = True + else: + inp = x + m, k = weight.shape[0], inp.shape[1] + out = torch.empty(inp.shape[0], + weight.shape[0], + dtype=inp.dtype, + device='cuda') + if (k == 8192 and + (m == 1280 or m == 7168)) or (k == 3584 and m == 8192): + custom_ops.LLMM1(weight, inp, out, 8) + elif k <= 8192 and k % 8 == 0 and m % 4 == 0: + custom_ops.LLMM1(weight, inp, out, 4) + else: + out = F.linear(inp, weight) + if batched: + out = out.view(x.shape[0], x.shape[1], weight.shape[0]) + if bias is not None: + out = out + bias + return out if self.separate_bias_add: if bias is not None: return F.linear(x, weight) + bias diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 4c163dfdab537..30bfc11f22fe7 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -46,6 +46,7 @@ from vllm.model_executor.weight_utils import (default_weight_loader, hf_model_weights_iterator) from vllm.sequence import SamplerOutput +from vllm import custom_ops KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -74,8 +75,17 @@ def __init__( self.act_fn = SiluAndMul() def forward(self, x): - gate_up, _ = self.gate_up_proj(x) - x = self.act_fn(gate_up) + if x.shape[0] == 1 and x.shape[1] == 1: + out = torch.empty(x.shape[0], + self.gate_up_proj.weight.shape[0] // 2, + dtype=x.dtype, + device=x.device) + custom_ops.LLMM_Silu(self.gate_up_proj.weight, + x.view(-1, x.size(-1)), out, 8) + x = out.view(x.shape[0], x.shape[1], out.shape[1]) + else: + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) x, _ = self.down_proj(x) return x diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index 0dcd4018afa5f..6f446db6c1cda 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ 
-268,12 +268,19 @@ def init_distributed_environment( "distributed_init_method must be set if torch.distributed " "is not already initialized") else: - torch.distributed.init_process_group( - backend="nccl", - world_size=parallel_config.world_size, - rank=rank, - init_method=distributed_init_method, - ) + if parallel_config.worker_use_torchrun: + torch.distributed.init_process_group( + backend="nccl", + world_size=parallel_config.world_size, + init_method="env://", + ) + else: + torch.distributed.init_process_group( + backend="nccl", + world_size=parallel_config.world_size, + rank=rank, + init_method=distributed_init_method, + ) if cupy_utils.is_initialized(): cupy_world_size = cupy_utils.get_world_size()
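Taken together with the ParallelConfig and TorchrunGPUExecutor changes above, this closes the loop for Ray-free multi-GPU on ROCm: each rank started by torchrun builds its own driver worker and joins the process group from the environment rather than from an explicit rank and rendezvous address. A minimal sketch of that env:// initialization, assuming the usual torchrun-provided variables (RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT):

    import os
    import torch.distributed as dist

    # Under `torchrun --standalone --nproc_per_node=N`, RANK and WORLD_SIZE are already
    # exported, so only the backend and init method need to be named explicitly.
    dist.init_process_group(backend="nccl",
                            init_method="env://",
                            world_size=int(os.environ["WORLD_SIZE"]))
    print(f"rank {dist.get_rank()} of {dist.get_world_size()} initialized")

Launching with --nproc_per_node equal to the tensor-parallel size then matches the check added in vllm/config.py, which raises rather than falling back silently when RANK is missing.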