
Commit

merged with latest upstream
kliuae committed Nov 29, 2023
2 parents e19a64c + 9a1781c commit 0f67117
Showing 28 changed files with 867 additions and 152 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -177,3 +177,7 @@ _build/
# vim swap files
*.swo
*.swp

# hip files generated by PyTorch
*.hip
*_hip*
7 changes: 4 additions & 3 deletions csrc/activation_kernels.cu
@@ -1,6 +1,7 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace vllm {
@@ -18,8 +19,8 @@ __global__ void silu_and_mul_kernel(
const int d) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
out[token_idx * d + idx] = silu(x) * y;
}
}
@@ -57,7 +58,7 @@ __global__ void activation_kernel(
const int d) {
const int64_t token_idx = blockIdx.x;
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
const scalar_t x = __ldg(&input[token_idx * d + idx]);
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
out[token_idx * d + idx] = ACT_FN(x);
}
}
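Note: the change above replaces the CUDA-only __ldg() read-only cache load with a VLLM_LDG macro pulled in through the new cuda_compat.h include. That header is not part of this diff; the snippet below is only a sketch of what such a compatibility wrapper might look like, assuming USE_ROCM is defined for HIP builds (the real header may differ).

// Hypothetical sketch of a cuda_compat.h-style wrapper, not taken from this
// commit: __ldg is a CUDA-only intrinsic, so ROCm builds fall back to a plain
// dereference and let the compiler choose the load instruction.
#pragma once

#ifndef USE_ROCM
  #define VLLM_LDG(arg) __ldg(arg)
#else
  #define VLLM_LDG(arg) *(arg)
#endif
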
48 changes: 37 additions & 11 deletions csrc/attention/attention_kernels.cu
@@ -23,7 +23,11 @@

#include <algorithm>

#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE 64
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
@@ -40,7 +44,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
// Compute the sum per warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
}

// Warp leaders store the data to shared memory.
@@ -59,11 +63,11 @@ inline __device__ float block_sum(float* red_smem, float sum) {
// Parallel reduction inside the warp.
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
}

// Broadcast to other threads.
return __shfl_sync(uint32_t(-1), sum, 0);
return VLLM_SHFL_SYNC(sum, 0);
}
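
The shuffle rewrites in block_sum follow the same compatibility pattern: CUDA's __shfl_sync/__shfl_xor_sync intrinsics take an explicit lane mask, while HIP exposes the maskless wavefront-wide variants, and the wavefront is 64 lanes wide on these GPUs (hence the WARP_SIZE change above). A plausible sketch of the corresponding cuda_compat.h macros, again not copied from the upstream header:

// Hypothetical portable warp-shuffle wrappers. CUDA uses the *_sync intrinsics
// with a full-warp mask; ROCm uses the maskless __shfl/__shfl_xor, which
// operate across the whole 64-lane wavefront.
#ifndef USE_ROCM
  #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
#else
  #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#endif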

// TODO(woosuk): Merge the last two dimensions of the grid.
@@ -223,7 +227,7 @@ __device__ void paged_attention_kernel(
// The 0-th thread of each thread group already has its max qk value.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
if (lane == 0) {
red_smem[warp_idx] = qk_max;
@@ -235,10 +239,10 @@
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
}
// Broadcast the max qk value to all threads.
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
qk_max = VLLM_SHFL_SYNC(qk_max, 0);

// Get the sum of the exp values.
float exp_sum = 0.f;
@@ -326,7 +330,7 @@ __device__ void paged_attention_kernel(
float acc = accs[i];
#pragma unroll
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
acc += VLLM_SHFL_XOR_SYNC(acc, mask);
}
accs[i] = acc;
}
@@ -492,7 +496,7 @@ __global__ void paged_attention_v2_reduce_kernel(
// Reduce within the warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
}
if (lane == 0) {
red_smem[warp_idx] = max_logit;
@@ -502,10 +506,10 @@
max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
}
// Broadcast the max value to all threads.
max_logit = __shfl_sync(uint32_t(-1), max_logit, 0);
max_logit = VLLM_SHFL_SYNC(max_logit, 0);

// Load rescaled exp sums to shared memory.
float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
@@ -538,9 +542,10 @@ __global__ void paged_attention_v2_reduce_kernel(

} // namespace vllm

#ifndef USE_ROCM
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
cudaFuncSetAttribute( \
vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
(void*)vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \
vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
<<<grid, block, shared_mem_size, stream>>>( \
@@ -557,6 +562,27 @@ __global__ void paged_attention_v2_reduce_kernel(
q_stride, \
kv_block_stride, \
kv_head_stride);
#else
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
hipFuncSetAttribute( \
(void*)vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
hipFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \
vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
<<<grid, block, shared_mem_size, stream>>>( \
out_ptr, \
query_ptr, \
key_cache_ptr, \
value_cache_ptr, \
head_mapping_ptr, \
scale, \
block_tables_ptr, \
context_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
q_stride, \
kv_block_stride, \
kv_head_stride);
#endif

// TODO(woosuk): Tune NUM_THREADS.
template<
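Both variants of the LAUNCH_PAGED_ATTENTION_V1 macro above follow the same two-step pattern: raise the kernel's dynamic shared-memory limit through the runtime (cudaFuncSetAttribute on CUDA, hipFuncSetAttribute on ROCm, with the kernel pointer cast to void*), then launch with that shared-memory size. Below is a self-contained sketch of the pattern with a placeholder kernel; demo_kernel, launch_demo and gpuStream_t are illustrative names and are not part of the commit.

#ifndef USE_ROCM
#include <cuda_runtime.h>
using gpuStream_t = cudaStream_t;
#else
#include <hip/hip_runtime.h>
using gpuStream_t = hipStream_t;
#endif

// Placeholder kernel standing in for paged_attention_v1_kernel: it only
// exists to show the dynamic shared-memory opt-in around the launch.
__global__ void demo_kernel(float* out) {
  extern __shared__ float smem[];
  smem[threadIdx.x] = static_cast<float>(threadIdx.x);
  __syncthreads();
  if (threadIdx.x == 0) out[blockIdx.x] = smem[0];
}

// Expects shared_mem_size >= blockDim.x * sizeof(float) bytes.
void launch_demo(float* out, int shared_mem_size, gpuStream_t stream) {
#ifndef USE_ROCM
  cudaFuncSetAttribute((void*)demo_kernel,
                       cudaFuncAttributeMaxDynamicSharedMemorySize,
                       shared_mem_size);
#else
  hipFuncSetAttribute((void*)demo_kernel,
                      hipFuncAttributeMaxDynamicSharedMemorySize,
                      shared_mem_size);
#endif
  demo_kernel<<<dim3(1), dim3(256), shared_mem_size, stream>>>(out);
}
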
3 changes: 2 additions & 1 deletion csrc/attention/attention_utils.cuh
@@ -17,6 +17,7 @@
*/
#pragma once

#include "../cuda_compat.h"
#include "attention_dtypes.h"

#include <float.h>
@@ -39,7 +40,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
float qk = sum(qk_vec);
#pragma unroll
for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
qk += VLLM_SHFL_XOR_SYNC(qk, mask);
}
return qk;
}
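The qk_dot_ loop above is the usual XOR-butterfly reduction: halving the mask on every iteration folds partial sums across lanes until each lane in the (power-of-two sized) group holds the group total. A minimal standalone illustration of the pattern, assuming the VLLM_SHFL_XOR_SYNC wrapper sketched earlier; it is not part of the commit.

// Butterfly sum across an aligned group of WIDTH lanes (WIDTH must be a power
// of two no larger than the warp/wavefront size). After the loop, every lane
// in the group holds the sum of val over that group.
template<int WIDTH>
__device__ float group_sum(float val) {
  #pragma unroll
  for (int mask = WIDTH / 2; mask >= 1; mask /= 2) {
    val += VLLM_SHFL_XOR_SYNC(val, mask);
  }
  return val;
}
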
19 changes: 16 additions & 3 deletions csrc/attention/dtype_bfloat16.cuh
@@ -21,8 +21,17 @@
#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#ifndef USE_ROCM
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#else
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>

typedef __hip_bfloat162 __nv_bfloat162;
typedef __hip_bfloat16 __nv_bfloat16;
#endif

#include <stdint.h>

namespace vllm {
@@ -98,7 +107,11 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
assert(false);
#else
return a + b;
#ifndef USE_ROCM
return a + b;
#else
return __hadd(a, b);
#endif
#endif
}

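The bfloat16 header takes an aliasing approach: under ROCm it pulls in the HIP bf16/fp16 headers and typedefs the __hip_bfloat16 types to the __nv_bfloat16 names, so the rest of the file compiles unchanged and only the arithmetic bodies branch. A condensed, hypothetical illustration of that pattern (bf16_add is a made-up name, not code from the commit):

#ifndef USE_ROCM
#include <cuda_bf16.h>
#else
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16;
#endif

// Sketch of a portable bf16 add in the style of the header above.
inline __device__ __nv_bfloat16 bf16_add(__nv_bfloat16 a, __nv_bfloat16 b) {
#ifndef USE_ROCM
  return a + b;         // CUDA: operator+ (requires sm_80 or newer)
#else
  return __hadd(a, b);  // ROCm: explicit intrinsic from hip_bf16.h
#endif
}
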
90 changes: 84 additions & 6 deletions csrc/attention/dtype_float16.cuh
@@ -21,6 +21,10 @@
#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#ifdef USE_ROCM
#include <hip/hip_fp16.h>
#endif

#include <stdint.h>

namespace vllm {
@@ -63,58 +67,114 @@ struct FloatVec<uint4> {

// Utility functions for type conversions.
inline __device__ uint32_t h0_h0(uint16_t a) {
#ifndef USE_ROCM
uint32_t b;
asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
return b;
#else
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
tmp.u16[0] = a;
tmp.u16[1] = a;
return tmp.u32;
#endif
}

inline __device__ float half_to_float(uint16_t h) {
#ifndef USE_ROCM
float f;
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
return f;
#else
float f;
asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
return f;
#endif
}

inline __device__ float2 half2_to_float2(uint32_t v) {
#ifndef USE_ROCM
uint16_t lo, hi;
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
return make_float2(half_to_float(lo), half_to_float(hi));
#else
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
tmp.u32 = v;
float2 ret;
ret.x = half_to_float(tmp.u16[0]);
ret.y = half_to_float(tmp.u16[1]);
return ret;
#endif
}

inline __device__ uint16_t float_to_half(float f) {
#ifndef USE_ROCM
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
return tmp.u16[0];
#else
uint16_t ret;
asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(ret) : "v"(f));
return ret;
#endif
}

inline __device__ uint32_t float2_to_half2(float2 f) {
#ifndef USE_ROCM
union {
uint32_t u32;
uint16_t u16[2];
} tmp;

#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
#else
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif
return tmp.u32;
#else
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
#endif
union {
uint32_t u32;
uint16_t u16[2];
} tmp;
tmp.u16[0] = float_to_half(f.x);
tmp.u16[1] = float_to_half(f.y);
return tmp.u32;
#endif
}

// Vector addition.
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
#ifndef USE_ROCM
uint16_t c;
asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
return c;
#else
uint16_t c;
asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
return c;
#endif
}

inline __device__ uint32_t add(uint32_t a, uint32_t b) {
#ifndef USE_ROCM
uint32_t c;
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
#else
uint32_t c;
asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
return c;
#endif
}

inline __device__ uint2 add(uint2 a, uint2 b) {
@@ -157,16 +217,28 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) {
// Vector multiplication.
template<>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
#ifndef USE_ROCM
uint16_t c;
asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
return c;
#else
uint16_t c;
asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
return c;
#endif
}

template<>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
#ifndef USE_ROCM
uint32_t c;
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
return c;
#else
uint32_t c;
asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
return c;
#endif
}

template<>
@@ -271,9 +343,15 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {

// Vector fused multiply-add.
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
#ifndef USE_ROCM
uint32_t d;
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
return d;
#else
uint32_t d;
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
return d;
#endif
}

inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) {
@@ -441,4 +519,4 @@ inline __device__ void zero(uint16_t& dst) {
dst = uint16_t(0);
}

} // namespace vllm
} // namespace vllm
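
The fp16 helpers in this file replace PTX inline assembly (cvt.*, add.f16x2, mul.f16x2, fma.rn.f16x2) with the corresponding GCN vector instructions (v_cvt_*, v_add_f16, v_pk_add_f16, v_pk_mul_f16, v_pk_fma_f16) under USE_ROCM. A cheap way to gain confidence in such an asm port is a round-trip check; the kernel below is an illustrative sketch only, assuming the helpers are reachable through the vLLM header (the include path is an assumption), and is not part of the commit.

#include <stdint.h>
#include "csrc/attention/dtype_float16.cuh"  // assumed include path

// Hypothetical sanity-check kernel: converting float -> half -> float with
// the helpers above should round-trip to within fp16 precision on both the
// PTX path and the GCN asm path.
__global__ void half_roundtrip_kernel(const float* in, float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    uint16_t h = vllm::float_to_half(in[i]);
    out[i] = vllm::half_to_float(h);
  }
}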