diff --git a/src/kernels/attention/flash_infer/attention_kernel.h b/src/kernels/attention/flash_infer/attention_kernel.h index be2f07d8..bc9fdc4b 100644 --- a/src/kernels/attention/flash_infer/attention_kernel.h +++ b/src/kernels/attention/flash_infer/attention_kernel.h @@ -9,7 +9,6 @@ #include #include -#include #include #include #include @@ -24,6 +23,7 @@ #include #include "kv_cache.h" +#include "state_merge_kernel.h" namespace flashinfer { @@ -39,50 +39,50 @@ namespace { template -constexpr bool is_invalid_configuration(uint32_t num_frags_x, - uint32_t num_frags_y, - uint32_t num_frags_z, - uint32_t num_warps_x, - uint32_t num_warps_z) { - return ((num_frags_y < 4) || (num_frags_y == 4 && num_frags_z % 2 == 1) || - (num_frags_y > 4 && num_frags_y % (2 * num_warps_x) != 0) || - (num_frags_x * - (8 * num_frags_y + 2 * sizeof(DTypeQKAccum) * num_frags_z) >= +constexpr bool is_invalid_configuration(uint32_t num_iters_m, + uint32_t num_iters_k, + uint32_t num_iters_n, + uint32_t num_warps_m, + uint32_t num_warps_n) { + return ((num_iters_k < 4) || (num_iters_k == 4 && num_iters_n % 2 == 1) || + (num_iters_k > 4 && num_iters_k % (2 * num_warps_m) != 0) || + (num_iters_m * + (8 * num_iters_k + 2 * sizeof(DTypeQKAccum) * num_iters_n) >= 256) || - (sizeof(DTypeKV) == 1 && num_frags_z * 2 % num_warps_x != 0) || + (sizeof(DTypeKV) == 1 && num_iters_n * 2 % num_warps_m != 0) || (sizeof(DTypeKV) == 1 && pos_encoding_mode == PosEncodingMode::kRoPELlama)); } -template +template __device__ __forceinline__ uint32_t get_warp_idx_x() { - if constexpr (num_warps_x == 1) { + if constexpr (num_warps_m == 1) { return 0; } else { return threadIdx.y; } } -template +template __device__ __forceinline__ uint32_t get_warp_idx_z() { - if constexpr (num_warps_z == 1) { + if constexpr (num_warps_n == 1) { return 0; } else { return threadIdx.z; } } -template +template __device__ __forceinline__ uint32_t get_warp_idx() { - return get_warp_idx_z() * num_warps_x + - get_warp_idx_x(); + return get_warp_idx_z() * num_warps_m + + get_warp_idx_x(); } template @@ -97,23 +97,23 @@ __device__ __forceinline__ void page_produce_kv( // moment constexpr SharedMemFillMode fill_mode = produce_v ? SharedMemFillMode::kFillZero : SharedMemFillMode::kNoFill; - constexpr uint32_t head_dim = num_frags_y * 16; - constexpr uint32_t num_warps = num_warps_x * num_warps_z; + constexpr uint32_t head_dim = num_iters_k * 16; + constexpr uint32_t num_warps = num_warps_m * num_warps_n; constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); - const uint32_t warp_idx = get_warp_idx(), + const uint32_t warp_idx = get_warp_idx(), lane_idx = threadIdx.x; if constexpr (swizzle_mode == SwizzleMode::k128B) { uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; - // NOTE(Zihao): num_frags_z * 4 / num_warps_x = num_warps_z * num_frags_z * + // NOTE(Zihao): num_iters_n * 4 / num_warps_m = num_warps_n * num_iters_n * // 4 / num_warps - static_assert(num_frags_z * 4 % num_warps_x == 0); + static_assert(num_iters_n * 4 % num_warps_m == 0); #pragma unroll - for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) { + for (uint32_t i = 0; i < num_iters_n * 4 / num_warps_m; ++i) { DType* gptr = produce_v ? 
paged_kv.v_data(kv_offset[i]) : paged_kv.k_data(kv_offset[i]); #pragma unroll - for (uint32_t j = 0; j < num_frags_y / (8 / sizeof(DType)); ++j) { + for (uint32_t j = 0; j < num_iters_k / (8 / sizeof(DType)); ++j) { smem.load_128b_async(*smem_offset, gptr, kv_idx < kv_len); *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j); @@ -123,16 +123,16 @@ __device__ __forceinline__ void page_produce_kv( *smem_offset = smem.template advance_offset_by_row( *smem_offset) - - sizeof(DType) * num_frags_y; + sizeof(DType) * num_iters_k; } - *smem_offset -= num_warps_z * num_frags_z * 16 * channel_size_128b_kv; + *smem_offset -= num_warps_n * num_iters_n * 16 * channel_size_128b_kv; } else { uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4; - // NOTE(Zihao): num_frags_z * 2 / num_warps_x = num_warps_z * num_frags_z * + // NOTE(Zihao): num_iters_n * 2 / num_warps_m = num_warps_n * num_iters_n * // 2 / num_warps - static_assert(num_frags_z * 2 % num_warps_x == 0); + static_assert(num_iters_n * 2 % num_warps_m == 0); #pragma unroll - for (uint32_t i = 0; i < num_frags_z * 2 / num_warps_x; ++i) { + for (uint32_t i = 0; i < num_iters_n * 2 / num_warps_m; ++i) { DType* gptr = produce_v ? paged_kv.v_data(kv_offset[i]) : paged_kv.k_data(kv_offset[i]); smem.load_128b_async(*smem_offset, gptr, kv_idx < kv_len); @@ -141,38 +141,40 @@ __device__ __forceinline__ void page_produce_kv( channel_size_128b_kv>( *smem_offset); } - *smem_offset -= num_warps_z * num_frags_z * 16 * channel_size_128b_kv; + *smem_offset -= num_warps_n * num_iters_n * 16 * channel_size_128b_kv; } } -template -__device__ __forceinline__ void init_states(float (*o_frag)[num_frags_y][8], +template +__device__ __forceinline__ void init_states(float (*o_frag)[num_iters_k][8], DTypeQKAccum (*m)[2], float (*d)[2]) { #pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fx = 0; fx < num_iters_m; ++fx) { #pragma unroll - for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + for (uint32_t fy = 0; fy < num_iters_k; ++fy) { #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { + // o_frag: [num_iters_m, num_iters_k, 8] o_frag[fx][fy][reg_id] = 0.f; } } } #pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fx = 0; fx < num_iters_m; ++fx) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { + // m/d: [num_iters_m, 2] m[fx][j] = DTypeQKAccum(-5e4); d[fx][j] = 1.f; } } } -template __device__ __forceinline__ void load_q_global_smem( @@ -183,59 +185,79 @@ __device__ __forceinline__ void load_q_global_smem( const uint32_t q_stride_h, const uint_fastdiv group_size, smem_t* q_smem) { - constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t head_dim = num_iters_k * 16; constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); - const uint32_t lane_idx = threadIdx.x, - warp_idx_x = get_warp_idx_x(); - - if (get_warp_idx_z() == 0) { - uint32_t q_smem_offset_w = q_smem->get_permuted_offset( - warp_idx_x * num_frags_x * 16 + lane_idx / 8, lane_idx % 8); - -#pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { -#pragma unroll - for (uint32_t j = 0; j < 4; ++j) { - uint32_t q, r; - group_size.divmod(packed_offset + lane_idx / 8 + fx * 16 + j * 4, q, r); - const uint32_t q_idx = q; - DTypeQ* q_ptr = q_ptr_base + q * q_stride_n + r * q_stride_h; -#pragma unroll - for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { - // load q fragment from gmem to smem - q_smem->load_128b_async( - q_smem_offset_w, q_ptr, q_idx < qo_upper_bound); - 
q_smem_offset_w = q_smem->template advance_offset_by_column<8>( - q_smem_offset_w, fyo); - q_ptr += 8 * num_elems_per_128b(); - } - q_smem_offset_w = - q_smem->template advance_offset_by_row<4, channel_size_128b_q>( - q_smem_offset_w) - - 2 * num_frags_y; + const uint32_t lane_idx = threadIdx.x; + const uint32_t warp_idx_x = get_warp_idx_x(); + const uint32_t warp_idx_z = get_warp_idx_z(); + // let all warps load q + // rows to load: num_iters_m * 16 + // threads layout in warp: 4 x 8 + // | t0 | t1 | t2 | t3 | t4 | t5 | t6 | t7 | + // | t8 | t9 | t10 | t11 | t12 | t13 | t14 | t15 | + // | t16 | t17 | t18 | t19 | t20 | t21 | t22 | t23 | + // | t24 | t25 | t26 | t27 | t28 | t29 | t30 | t31 | + // + + // q_smem: [num_warps_m, num_iters_m, 16, head_dim] + uint32_t q_smem_x = warp_idx_x * num_iters_m * 16 + lane_idx / 8; + +#pragma unroll + for (uint32_t fx = 0; fx < num_iters_m; ++fx) { + // each wrap loads 4 rows, loading 16 rows needs 4 iters +#pragma unroll + for (uint32_t j = 0; j < 4; ++j) { + const uint32_t packed_q_idx = + packed_offset + fx * 16 + lane_idx / 8 + j * 4; + uint32_t q, r; + group_size.divmod(packed_q_idx, q, r); + if (q >= qo_upper_bound) { + continue; + } + // q_ptr_base: [n_tokens, n_heads, head_dim] + // q_ptr for given header: [head_dim] + DTypeQ* q_ptr = q_ptr_base + q * q_stride_n + r * q_stride_h + + warp_idx_z * 8 * num_elems_per_128b(); + + // load head_dim from global memory to shared memory using num_warps_n + // warps echa warp loads 8 columns once + uint32_t q_smem_y = warp_idx_z * 8 + lane_idx % 8; + while (q_smem_y * num_elems_per_128b() < head_dim) { + const uint32_t q_smem_offset_w = + q_smem->template get_permuted_offset(q_smem_x, + q_smem_y); + + // load q fragment from gmem to smem + q_smem->load_128b_async(q_smem_offset_w, q_ptr); + // move ahead by 8*int128_t for each warp + q_smem_y += (8 * num_warps_n); + q_ptr += (8 * num_elems_per_128b() * num_warps_n); } + + // move ahead by 4 rows + q_smem_x += 4; } } } -template __device__ __forceinline__ void q_smem_inplace_multiply_sm_scale( smem_t* q_smem, const float sm_scale) { - const uint32_t warp_idx = get_warp_idx(), + const uint32_t warp_idx = get_warp_idx(), lane_idx = threadIdx.x; - constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t head_dim = num_iters_k * 16; constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); - constexpr uint32_t num_warps = num_warps_x * num_warps_z; + constexpr uint32_t num_warps = num_warps_m * num_warps_n; #pragma unroll - for (uint32_t i = 0; i < num_frags_x * head_dim / (num_warps_z * 16); ++i) { + for (uint32_t i = 0; i < num_iters_m * head_dim / (num_warps_n * 16); ++i) { vec_t tmp; tmp.load((DTypeQ*)(q_smem->base) + (i * num_warps + warp_idx) * 256 + lane_idx * 8); @@ -249,9 +271,9 @@ __device__ __forceinline__ void q_smem_inplace_multiply_sm_scale( } template * k_smem, uint32_t* k_smem_offset_r, - DTypeQKAccum (*s_frag)[num_frags_z][8], + DTypeQKAccum (*s_frag)[num_iters_n][8], const float soft_cap) { - constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t head_dim = num_iters_k * 16; constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); - uint32_t a_frag[num_frags_x][4], b_frag[4]; + uint32_t a_frag[num_iters_m][4], b_frag[4]; // compute q*k^T #pragma unroll - for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + for (uint32_t fy = 0; fy < num_iters_k; ++fy) { #pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { 
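// [Editor's note: hedged sketch, not part of this patch.] load_q_global_smem
// above maps a "packed" row index -- token * group_size + head_in_group --
// back to a (token, head) pair with uint_fastdiv::divmod before computing the
// global-memory address. A plain divide/modulo version of that mapping, with
// illustrative names only (PackedQRow, unpack_qo_row, qo_row_ptr), looks like:
#include <cstdint>

struct PackedQRow {
  uint32_t token;          // query token index, must stay below qo_upper_bound
  uint32_t head_in_group;  // query head offset within its KV head group
};

inline PackedQRow unpack_qo_row(uint32_t packed_idx, uint32_t group_size) {
  return {packed_idx / group_size, packed_idx % group_size};
}

// Address of the row this thread reads, assuming q is laid out as
// [n_tokens, n_qo_heads, head_dim] with q_stride_n = n_qo_heads * head_dim and
// q_stride_h = head_dim, matching the strides used in the kernel above.
template <typename DTypeQ>
inline DTypeQ* qo_row_ptr(DTypeQ* base, PackedQRow row,
                          uint32_t q_stride_n, uint32_t q_stride_h) {
  return base + row.token * q_stride_n + row.head_in_group * q_stride_h;
}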
+ for (uint32_t fx = 0; fx < num_iters_m; ++fx) { q_smem->ldmatrix_m8n8x4(*q_smem_offset_r, a_frag[fx]); *q_smem_offset_r = q_smem->template advance_offset_by_row<16, channel_size_128b_q>( @@ -283,10 +305,10 @@ __device__ __forceinline__ void compute_qk( *q_smem_offset_r = q_smem->template advance_offset_by_column<2>(*q_smem_offset_r, fy) - - num_frags_x * 16 * channel_size_128b_q; + num_iters_m * 16 * channel_size_128b_q; #pragma unroll - for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + for (uint32_t fz = 0; fz < num_iters_n; ++fz) { if constexpr (sizeof(DTypeKV) == 1) { uint32_t b_frag_f8[2]; if (fy % 2 == 0) { @@ -306,7 +328,7 @@ __device__ __forceinline__ void compute_qk( *k_smem_offset_r); #pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fx = 0; fx < num_iters_m; ++fx) { if constexpr (std::is_same_v) { if (fy == 0) { mma::mma_sync_m16n16k16_row_col_f16f16f32( @@ -331,21 +353,21 @@ __device__ __forceinline__ void compute_qk( *k_smem_offset_r = k_smem->template advance_offset_by_column<2>( *k_smem_offset_r, fy / 2); } - *k_smem_offset_r -= num_frags_z * 16 * channel_size_128b_kv; + *k_smem_offset_r -= num_iters_n * 16 * channel_size_128b_kv; } else { *k_smem_offset_r = k_smem->template advance_offset_by_column<2>(*k_smem_offset_r, fy) - - num_frags_z * 16 * channel_size_128b_kv; + num_iters_n * 16 * channel_size_128b_kv; } } - *q_smem_offset_r -= num_frags_y * 2; - *k_smem_offset_r -= num_frags_y * sizeof(DTypeKV); + *q_smem_offset_r -= num_iters_k * 2; + *k_smem_offset_r -= num_iters_k * sizeof(DTypeKV); if constexpr (std::is_same::value) { #pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fx = 0; fx < num_iters_m; ++fx) { #pragma unroll - for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + for (uint32_t fz = 0; fz < num_iters_n; ++fz) { #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { s_frag[fx][fz][reg_id] = apply_logits_post_hook( @@ -356,9 +378,9 @@ __device__ __forceinline__ void compute_qk( } else { static_assert(std::is_same::value); #pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fx = 0; fx < num_iters_m; ++fx) { #pragma unroll - for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + for (uint32_t fz = 0; fz < num_iters_n; ++fz) { #pragma unroll for (uint32_t reg_id = 0; reg_id < 4; ++reg_id) { *(half2*)(&s_frag[fx][fz][reg_id * 2]) = @@ -371,19 +393,19 @@ __device__ __forceinline__ void compute_qk( } // TODO: move it to a separate file -template +template __device__ __forceinline__ void apply_alibi_bias( const uint32_t qo_packed_idx_base, const uint32_t kv_idx_base, const int32_t q_offset, const uint_fastdiv group_size, float (*alibi_slope)[2], - T (*s_frag)[num_frags_z][8]) { + T (*s_frag)[num_iters_n][8]) { const int32_t lane_idx = threadIdx.x; #pragma unroll - for (int32_t fx = 0; fx < num_frags_x; ++fx) { + for (int32_t fx = 0; fx < num_iters_m; ++fx) { #pragma unroll - for (int32_t fz = 0; fz < num_frags_z; ++fz) { + for (int32_t fz = 0; fz < num_iters_n; ++fz) { #pragma unroll for (int32_t reg_id = 0; reg_id < 8; ++reg_id) { const int32_t q_idx = (qo_packed_idx_base + fx * 16 + lane_idx / 4 + @@ -399,9 +421,9 @@ __device__ __forceinline__ void apply_alibi_bias( } template __device__ __forceinline__ void mask_s(const uint32_t qo_packed_idx_base, const uint32_t kv_idx_base, @@ -411,12 +433,12 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_packed_idx_base, const uint32_t chunk_end, const uint_fastdiv group_size, uint8_t* custom_mask, - DTypeQKAccum 
(*s_frag)[num_frags_z][8]) { + DTypeQKAccum (*s_frag)[num_iters_n][8]) { const uint32_t lane_idx = threadIdx.x; #pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fx = 0; fx < num_iters_m; ++fx) { #pragma unroll - for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + for (uint32_t fz = 0; fz < num_iters_n; ++fz) { #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { const uint32_t q_idx = (qo_packed_idx_base + fx * 16 + lane_idx / 4 + @@ -443,23 +465,23 @@ __device__ __forceinline__ void mask_s(const uint32_t qo_packed_idx_base, } } -template __device__ __forceinline__ void update_mdo_states( - DTypeQKAccum (*s_frag)[num_frags_z][8], - float (*o_frag)[num_frags_y][8], + DTypeQKAccum (*s_frag)[num_iters_n][8], + float (*o_frag)[num_iters_k][8], DTypeQKAccum (*m)[2], float (*d)[2]) { if constexpr (std::is_same_v) { #pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fx = 0; fx < num_iters_m; ++fx) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { float m_prev = m[fx][j]; #pragma unroll - for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + for (uint32_t fz = 0; fz < num_iters_n; ++fz) { float m_local = max(max(s_frag[fx][fz][j * 2 + 0], s_frag[fx][fz][j * 2 + 1]), max(s_frag[fx][fz][j * 2 + 4], s_frag[fx][fz][j * 2 + 5])); @@ -471,14 +493,14 @@ __device__ __forceinline__ void update_mdo_states( float o_scale = math::ptx_exp2(m_prev - m[fx][j]); d[fx][j] *= o_scale; #pragma unroll - for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + for (uint32_t fy = 0; fy < num_iters_k; ++fy) { o_frag[fx][fy][j * 2 + 0] *= o_scale; o_frag[fx][fy][j * 2 + 1] *= o_scale; o_frag[fx][fy][j * 2 + 4] *= o_scale; o_frag[fx][fy][j * 2 + 5] *= o_scale; } #pragma unroll - for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + for (uint32_t fz = 0; fz < num_iters_n; ++fz) { s_frag[fx][fz][j * 2 + 0] = math::ptx_exp2(s_frag[fx][fz][j * 2 + 0] - m[fx][j]); s_frag[fx][fz][j * 2 + 1] = @@ -492,13 +514,13 @@ __device__ __forceinline__ void update_mdo_states( } } else if constexpr (std::is_same_v) { #pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fx = 0; fx < num_iters_m; ++fx) { half m_prev[2]; #pragma unroll for (uint32_t j = 0; j < 2; ++j) { m_prev[j] = m[fx][j]; #pragma unroll - for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + for (uint32_t fz = 0; fz < num_iters_n; ++fz) { half2 m_local = __hmax2(*(half2*)&s_frag[fx][fz][j * 2], *(half2*)&s_frag[fx][fz][j * 2 + 4]); m[fx][j] = __hmax(m[fx][j], __hmax(m_local.x, m_local.y)); @@ -513,7 +535,7 @@ __device__ __forceinline__ void update_mdo_states( float o_scale = math::ptx_exp2(float(m_prev[j] - m[fx][j])); d[fx][j] *= o_scale; #pragma unroll - for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + for (uint32_t fy = 0; fy < num_iters_k; ++fy) { o_frag[fx][fy][j * 2 + 0] *= o_scale; o_frag[fx][fy][j * 2 + 1] *= o_scale; o_frag[fx][fy][j * 2 + 4] *= o_scale; @@ -521,7 +543,7 @@ __device__ __forceinline__ void update_mdo_states( } half2 m2 = make_half2(m[fx][j], m[fx][j]); #pragma unroll - for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + for (uint32_t fz = 0; fz < num_iters_n; ++fz) { *(half2*)&s_frag[fx][fz][j * 2] = math::ptx_exp2(*(half2*)&s_frag[fx][fz][j * 2] - m2); *(half2*)&s_frag[fx][fz][j * 2 + 4] = @@ -532,9 +554,9 @@ __device__ __forceinline__ void update_mdo_states( } } -template * v_smem, uint32_t* v_smem_offset_r, - DTypeQKAccum (*s_frag)[num_frags_z][8], - float (*o_frag)[num_frags_y][8], + DTypeQKAccum (*s_frag)[num_iters_n][8], + float (*o_frag)[num_iters_k][8], float 
(*d)[2]) { - constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t head_dim = num_iters_k * 16; constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); - DTypeQ s_frag_f16[num_frags_x][num_frags_z][8]; + DTypeQ s_frag_f16[num_iters_m][num_iters_n][8]; if constexpr (std::is_same_v) { #pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fx = 0; fx < num_iters_m; ++fx) { #pragma unroll - for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + for (uint32_t fz = 0; fz < num_iters_n; ++fz) { vec_cast::cast<8>(s_frag_f16[fx][fz], s_frag[fx][fz]); } } } #pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fx = 0; fx < num_iters_m; ++fx) { #pragma unroll - for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + for (uint32_t fz = 0; fz < num_iters_n; ++fz) { if constexpr (std::is_same::value) { mma::rowsum_f16f16f32(d[fx], s_frag_f16[fx][fz]); } else { @@ -573,9 +595,9 @@ __device__ __forceinline__ void compute_sfm_v( } #pragma unroll - for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + for (uint32_t fz = 0; fz < num_iters_n; ++fz) { #pragma unroll - for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + for (uint32_t fy = 0; fy < num_iters_k; ++fy) { uint32_t b_frag[4]; if constexpr (sizeof(DTypeKV) == 1) { uint32_t b_frag_f8[2]; @@ -593,7 +615,7 @@ __device__ __forceinline__ void compute_sfm_v( v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag); } #pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fx = 0; fx < num_iters_m; ++fx) { if constexpr (std::is_same::value) { mma::mma_sync_m16n16k16_row_col_f16f16f32( o_frag[fx][fy], (uint32_t*)(s_frag_f16[fx][fz]), b_frag); @@ -615,19 +637,19 @@ __device__ __forceinline__ void compute_sfm_v( *v_smem_offset_r = v_smem->template advance_offset_by_row<16, channel_size_128b_kv>( *v_smem_offset_r) - - sizeof(DTypeKV) * num_frags_y; + sizeof(DTypeKV) * num_iters_k; } - *v_smem_offset_r -= 16 * num_frags_z * channel_size_128b_kv; + *v_smem_offset_r -= 16 * num_iters_n * channel_size_128b_kv; } -template -__device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8], +template +__device__ __forceinline__ void normalize_d(float (*o_frag)[num_iters_k][8], DTypeQKAccum (*m)[2], float (*d)[2]) { - float d_rcp[num_frags_x][2]; + float d_rcp[num_iters_m][2]; // compute reciprocal of d #pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fx = 0; fx < num_iters_m; ++fx) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { d_rcp[fx][j] = @@ -636,9 +658,9 @@ __device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8], } #pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fx = 0; fx < num_iters_m; ++fx) { #pragma unroll - for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + for (uint32_t fy = 0; fy < num_iters_k; ++fy) { #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { o_frag[fx][fy][reg_id] = @@ -652,32 +674,32 @@ __device__ __forceinline__ void normalize_d(float (*o_frag)[num_frags_y][8], * \brief Synchronize the states of the MDO kernel across the threadblock along * threadIdx.z. 
*/ -template __device__ __forceinline__ void threadblock_sync_mdo_states( - float (*o_frag)[num_frags_y][8], + float (*o_frag)[num_iters_k][8], float* smem_workspace, DTypeQKAccum (*m)[2], float (*d)[2], const uint32_t warp_idx, const uint32_t lane_idx) { // only necessary when blockDim.z > 1 - if constexpr (num_warps_z > 1) { + if constexpr (num_warps_n > 1) { float2* smem_md = - (float2*)(smem_workspace + num_frags_x * num_frags_y * num_warps_x * - num_warps_z * warp_size * 8); - // o: [num_warps, num_frags_x, num_frags_y, warp_size(32), 8] - // md: [num_warps, num_frags_x, 2, warp_size(32), 2 (m/d)] + (float2*)(smem_workspace + num_iters_m * num_iters_k * num_warps_m * + num_warps_n * warp_size * 8); + // o: [num_warps, num_iters_m, num_iters_k, warp_size(32), 8] + // md: [num_warps, num_iters_m, 2, warp_size(32), 2 (m/d)] #pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fx = 0; fx < num_iters_m; ++fx) { #pragma unroll - for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + for (uint32_t fy = 0; fy < num_iters_k; ++fy) { vec_t::memcpy( smem_workspace + - (((warp_idx * num_frags_x + fx) * num_frags_y + fy) * + (((warp_idx * num_iters_m + fx) * num_iters_k + fy) * warp_size + lane_idx) * 8, @@ -686,10 +708,10 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( } #pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fx = 0; fx < num_iters_m; ++fx) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { - smem_md[((warp_idx * num_frags_x + fx) * 2 + j) * warp_size + + smem_md[((warp_idx * num_iters_m + fx) * 2 + j) * warp_size + lane_idx] = make_float2(float(m[fx][j]), d[fx][j]); } } @@ -697,16 +719,16 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( // synchronize m,d first __syncthreads(); #pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { - float o_scale[2][num_warps_z]; + for (uint32_t fx = 0; fx < num_iters_m; ++fx) { + float o_scale[2][num_warps_n]; #pragma unroll for (uint32_t j = 0; j < 2; ++j) { float m_new = -5e4, d_new = 1.f; #pragma unroll - for (uint32_t i = 0; i < num_warps_z; ++i) { - float2 md = smem_md[(((i * num_warps_x + - get_warp_idx_x()) * - num_frags_x + + for (uint32_t i = 0; i < num_warps_n; ++i) { + float2 md = smem_md[(((i * num_warps_m + + get_warp_idx_x()) * + num_iters_m + fx) * 2 + j) * @@ -719,10 +741,10 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( } #pragma unroll - for (uint32_t i = 0; i < num_warps_z; ++i) { - float2 md = smem_md[(((i * num_warps_x + - get_warp_idx_x()) * - num_frags_x + + for (uint32_t i = 0; i < num_warps_n; ++i) { + float2 md = smem_md[(((i * num_warps_m + + get_warp_idx_x()) * + num_iters_m + fx) * 2 + j) * @@ -736,18 +758,18 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( } #pragma unroll - for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + for (uint32_t fy = 0; fy < num_iters_k; ++fy) { vec_t o_new; o_new.fill(0.f); #pragma unroll - for (uint32_t i = 0; i < num_warps_z; ++i) { + for (uint32_t i = 0; i < num_warps_n; ++i) { vec_t oi; oi.load(smem_workspace + - ((((i * num_warps_x + - get_warp_idx_x()) * - num_frags_x + + ((((i * num_warps_m + + get_warp_idx_x()) * + num_iters_m + fx) * - num_frags_y + + num_iters_k + fy) * warp_size + lane_idx) * @@ -763,47 +785,46 @@ __device__ __forceinline__ void threadblock_sync_mdo_states( } } -template -__device__ __forceinline__ void write_o_reg_gmem( - float (*o_frag)[num_frags_y][8], - smem_t* o_smem, - DTypeOut* o_ptr_base, - const uint32_t o_packed_idx_base, - const 
uint32_t qo_upper_bound, - const uint32_t o_stride_n, - const uint32_t o_stride_h, - const uint_fastdiv group_size) { - constexpr uint32_t head_dim = num_frags_y * 16; +template +__device__ __forceinline__ void write_o_reg_smem( + float (*o_frag)[num_iters_k][8], + smem_t* o_smem) { + constexpr uint32_t head_dim = num_iters_k * 16; constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); - const uint32_t warp_idx_x = get_warp_idx_x(); + const uint32_t warp_idx_x = get_warp_idx_x(); + const uint32_t warp_idx_z = get_warp_idx_z(); const uint32_t lane_idx = threadIdx.x; - if (get_warp_idx_z() == 0) { + // o_frag: [num_iters_m, num_iters_k, 8] + if (warp_idx_z == 0) { + // write o from register to shared memory + // why not every thread writes to shared memory? #pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fx = 0; fx < num_iters_m; ++fx) { #pragma unroll - for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + for (uint32_t fy = 0; fy < num_iters_k; ++fy) { uint32_t o_frag_f16[4]; vec_cast::cast<8>((DTypeOut*)o_frag_f16, o_frag[fx][fy]); #ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED uint32_t o_smem_offset_w = o_smem->get_permuted_offset( - (warp_idx_x * num_frags_x + fx) * 16 + lane_idx % 16, + (warp_idx_x * num_iters_m + fx) * 16 + lane_idx % 16, fy * 2 + lane_idx / 16); o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16); #else uint32_t o_smem_offset_w = o_smem->get_permuted_offset( - (warp_idx_x * num_frags_x + fx) * 16 + lane_idx / 4, fy * 2); + (warp_idx_x * num_iters_m + fx) * 16 + lane_idx / 4, fy * 2); ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0]; + // TODO: avoid manipulating permuted offset directly ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * channel_size_128b_out))[lane_idx % 4] = o_frag_f16[1]; ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[lane_idx % 4] = @@ -813,55 +834,87 @@ __device__ __forceinline__ void write_o_reg_gmem( #endif } } + } +} + +template +__device__ __forceinline__ void write_o_smem_gmem( + smem_t* o_smem, + DTypeOut* o_ptr_base, + const uint32_t o_packed_idx_base, + const uint32_t qo_upper_bound, + const uint32_t o_stride_n, + const uint32_t o_stride_h, + const uint_fastdiv group_size) { + constexpr uint32_t head_dim = num_iters_k * 16; + constexpr uint32_t channel_size_128b_out = + head_dim / num_elems_per_128b(); + const uint32_t warp_idx_x = get_warp_idx_x(); + const uint32_t warp_idx_z = get_warp_idx_z(); + const uint32_t lane_idx = threadIdx.x; - uint32_t o_smem_offset_w = - o_smem->get_permuted_offset( - warp_idx_x * num_frags_x * 16 + lane_idx / 8, lane_idx % 8); + // write o from shared memory to global memory + // o_smem: [num_warps_m, num_iters_m, 16, head_dim] + uint32_t o_smem_x = warp_idx_x * num_iters_m * 16 + lane_idx / 8; #pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { -#pragma unroll - for (uint32_t j = 0; j < 4; ++j) { - uint32_t q, r; - group_size.divmod( - o_packed_idx_base + lane_idx / 8 + fx * 16 + j * 4, q, r); - const uint32_t o_idx = q; - DTypeOut* o_ptr = o_ptr_base + q * o_stride_n + r * o_stride_h; + for (uint32_t fx = 0; fx < num_iters_m; ++fx) { + // each wrap writes 4 rows, 16 rows needs 4(16/4) iters #pragma unroll - for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { - if (o_idx < qo_upper_bound) { - o_smem->store_128b(o_smem_offset_w, o_ptr); - } - o_ptr += 8 * num_elems_per_128b(); - o_smem_offset_w = o_smem->template advance_offset_by_column<8>( - o_smem_offset_w, fyo); - } - o_smem_offset_w = - o_smem->template 
advance_offset_by_row<4, channel_size_128b_out>( - o_smem_offset_w) - - 2 * num_frags_y; + for (uint32_t j = 0; j < 4; ++j) { + const uint32_t packed_o_idx = + o_packed_idx_base + lane_idx / 8 + fx * 16 + j * 4; + uint32_t q, r; + group_size.divmod(packed_o_idx, q, r); + // skip if out of boundary + if (q >= qo_upper_bound) { + continue; } + + DTypeOut* o_ptr = o_ptr_base + q * o_stride_n + r * o_stride_h + + warp_idx_z * 8 * num_elems_per_128b(); + uint32_t o_smem_y = warp_idx_z * 8 + lane_idx % 8; + // write head_dim from shared memory to global memory + while (o_smem_y * num_elems_per_128b() < head_dim) { + const uint32_t o_smem_offset_w = + o_smem->template get_permuted_offset( + o_smem_x, o_smem_y); + o_smem->store_128b(o_smem_offset_w, o_ptr); + + // move ahead by 8 * int128_t for each warp + o_smem_y += (8 * num_warps_n); + o_ptr += 8 * num_elems_per_128b() * num_warps_n; + } + // move row by 4 + o_smem_x += 4; } } } } // namespace +// dim3 nblks(n_splits, num_kv_heads); +// dim3 nthrs(32, num_warps_m, num_warps_n); template __global__ -__launch_bounds__(num_warps_x* num_warps_z* warp_size) void attention_kernel( +__launch_bounds__(num_warps_m* num_warps_n* warp_size) void attention_kernel( IdType* __restrict__ request_indices, IdType* __restrict__ q_tile_indices, IdType* __restrict__ kv_tile_indices, @@ -884,41 +937,54 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void attention_kernel( float* __restrict__ alibi_slopes) { static_assert(sizeof(DTypeQ) == 2); static_assert(sizeof(DTypeOut) == 2); + + // instead of using loge for softmax, we use log2 for better performance + // exp(x - max) == exp2(x * log_2(e) - max * log_2(e)) sm_scale *= (logits_post_hook == LogitsPostHook::kNone ? math::log2e : math::ptx_rcp(logits_soft_cap)); auto block = cg::this_thread_block(); const uint32_t kv_chunk_size = *kv_chunk_size_ptr; - const uint32_t bx = blockIdx.x, lane_idx = threadIdx.x, - warp_idx = get_warp_idx(), - kv_head_idx = blockIdx.z; + const uint32_t bx = blockIdx.x; + if (block_valid_mask && !block_valid_mask[bx]) { return; } - const uint32_t num_kv_heads = gridDim.z, - num_qo_heads = num_kv_heads * group_size; - float alibi_slopes_frag[num_frags_x][2]; - - const uint32_t request_idx = request_indices[bx], - qo_tile_idx = q_tile_indices[bx], - kv_tile_idx = kv_tile_indices[bx]; - constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps_x * 16; + + const uint32_t kv_head_idx = blockIdx.y; + const uint32_t q_head_idx_base = kv_head_idx * group_size; + const uint32_t lane_idx = threadIdx.x; + const uint32_t warp_idx = get_warp_idx(); + + const uint32_t num_kv_heads = gridDim.y; + const uint32_t num_qo_heads = num_kv_heads * group_size; + + const uint32_t request_idx = request_indices[bx]; + const uint32_t qo_tile_idx = q_tile_indices[bx]; + const uint32_t kv_tile_idx = kv_tile_indices[bx]; + + constexpr uint32_t num_rows_per_cta = num_iters_m * num_warps_m * 16; const uint32_t qo_len = q_indptr[request_idx + 1] - q_indptr[request_idx]; const uint32_t kv_len = kv_indptr[request_idx + 1] - kv_indptr[request_idx]; + // could kv_len be 0? const uint32_t kv_len_safe = kv_len > 0 ? kv_len : 1; const uint32_t window_left = (maybe_window_left >= 0) ? maybe_window_left : kv_len; - const uint32_t max_chunk_size = partition_kv ? kv_chunk_size : kv_len; - const uint32_t chunk_start = partition_kv ? kv_tile_idx * max_chunk_size : 0; + + // kv idx range for this chunk + const uint32_t chunk_start = partition_kv ? 
kv_tile_idx * kv_chunk_size : 0; const uint32_t chunk_end = - partition_kv ? min((kv_tile_idx + 1) * max_chunk_size, kv_len) : kv_len; + partition_kv ? min((kv_tile_idx + 1) * kv_chunk_size, kv_len) : kv_len; const uint32_t chunk_size = chunk_end - chunk_start; + // heads in query are flattened to first dimension, so we need to div + // group_size to get original request_idx const uint32_t qo_upper_bound = min(qo_len, ceil_div((qo_tile_idx + 1) * num_rows_per_cta, group_size)); - constexpr uint32_t head_dim = num_frags_y * 16; + constexpr uint32_t head_dim = num_iters_k * 16; + // TODO: static_assert even division constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); constexpr uint32_t channel_size_128b_kv = @@ -926,47 +992,48 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void attention_kernel( constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); + // define shared memory extern __shared__ uint8_t smem[]; - - DTypeQKAccum s_frag[num_frags_x][num_frags_z][8]; - float o_frag[num_frags_x][num_frags_y][8]; - DTypeQKAccum m[num_frags_x][2]; - float d[num_frags_x][2]; - - init_states(o_frag, m, d); - - const uint32_t qo_packed_idx_base = - (qo_tile_idx * num_warps_x + get_warp_idx_x()) * - num_frags_x * 16; - const uint32_t q_stride_n = num_qo_heads * head_dim, q_stride_h = head_dim; + // swizzle_mode_q: 128B for 16-bit Q, + // 64B for 8-bit Q is not supported yet constexpr SwizzleMode swizzle_mode_q = SwizzleMode::k128B; + // [num_iters_m, num_warps_m, MMA_M, head_dim] smem_t qo_smem(smem); + + constexpr SwizzleMode swizzle_mode_kv = + (sizeof(DTypeKV) == 1 && head_dim == 64) ? SwizzleMode::k64B + : SwizzleMode::k128B; + // [num_iters_n, num_warps_n, MMA_M, head_dim] + smem_t k_smem( + smem + (num_warps_m * num_iters_m * sizeof(DTypeQ)) * 16 * head_dim); + // [num_iters_n, num_warps_n, MMA_M, head_dim] + smem_t v_smem(smem + + (num_warps_m * num_iters_m * sizeof(DTypeQ) + + num_warps_n * num_iters_n * sizeof(DTypeKV)) * + 16 * head_dim); + + const uint32_t q_stride_n = num_qo_heads * head_dim; + const uint32_t q_stride_h = head_dim; + + // TODO: move this into load_q_global_smem + // [n_tokens, n_heads, head_dim] DTypeQ* q_ptr_base = - q + get_elem_offset_impl(q_indptr[request_idx], - kv_head_idx * group_size, - (lane_idx % 8) * num_elems_per_128b(), - q_stride_n, - q_stride_h); - DTypeOut* o_ptr_base = - partition_kv ? 
o + kv_tile_idx * num_qo_heads * head_dim + - get_elem_offset_impl( - o_indptr[request_idx], - kv_head_idx * group_size, - (lane_idx % 8) * num_elems_per_128b(), - num_qo_heads * head_dim, - head_dim) - : o + get_elem_offset_impl( - o_indptr[request_idx], - kv_head_idx * group_size, - (lane_idx % 8) * num_elems_per_128b(), - num_qo_heads * head_dim, - head_dim); - uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( - get_warp_idx_x() * num_frags_x * 16 + - lane_idx % 16, - lane_idx / 16); + q + get_elem_offset_impl( + /*elem_idx=*/q_indptr[request_idx], + /*head_idx=*/q_head_idx_base, + /*feat_idx=*/(lane_idx % 8) * num_elems_per_128b(), + q_stride_n, + q_stride_h); + + // base index of flattened qo for cur warp, + // qo_packed: [n_tokens*n_heads, head_dim] + // per warp rows: num_iters_m * 16 + const uint32_t qo_packed_idx_base = + (qo_tile_idx * num_warps_m + get_warp_idx_x()) * + num_iters_m * 16; - load_q_global_smem( + // load q to shared memory once for current wrap + load_q_global_smem( qo_packed_idx_base, qo_upper_bound, q_ptr_base, @@ -974,70 +1041,47 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void attention_kernel( q_stride_h, group_size, &qo_smem); - + // [] => [q] cp_async::commit_group(); - cp_async::wait_group<0>(); - block.sync(); - - q_smem_inplace_multiply_sm_scale(&qo_smem, sm_scale); - if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { -#pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { -#pragma unroll - for (uint32_t j = 0; j < 2; ++j) { - const uint32_t qo_head_idx = - kv_head_idx * group_size + - (qo_packed_idx_base + lane_idx / 4 + j * 8 + fx * 16) % group_size; - alibi_slopes_frag[fx][j] = alibi_slopes[qo_head_idx] * math::log2e; - } - } - } + // load first k/v chunk to shared memory + // calculate q/k/v offsets to reand and write + uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( + get_warp_idx_x() * num_iters_m * 16 + + lane_idx % 16, + lane_idx / 16); - constexpr SwizzleMode swizzle_mode_kv = - (sizeof(DTypeKV) == 1 && head_dim == 64) ? SwizzleMode::k64B - : SwizzleMode::k128B; constexpr uint32_t kv_frag_rows = swizzle_mode_kv == SwizzleMode::k128B ? 4 : 8; constexpr uint32_t kv_frag_cols = swizzle_mode_kv == SwizzleMode::k128B ? 8 : 4; - smem_t k_smem( - smem + (num_warps_x * num_frags_x * sizeof(DTypeQ)) * 16 * head_dim), - v_smem(smem + (num_warps_x * num_frags_x * sizeof(DTypeQ) + - num_warps_z * num_frags_z * sizeof(DTypeKV)) * - 16 * head_dim); - size_t kv_offset[num_frags_z * + size_t kv_offset[num_iters_n * (swizzle_mode_kv == SwizzleMode::k128B ? 
4 : 2) / - num_warps_x]; + num_warps_m]; uint32_t k_smem_offset_r = k_smem.get_permuted_offset( - get_warp_idx_z() * num_frags_z * 16 + - 8 * (lane_idx / 16) + lane_idx % 8, - (lane_idx % 16) / 8), - v_smem_offset_r = v_smem.get_permuted_offset( - get_warp_idx_z() * num_frags_z * 16 + - lane_idx % 16, - lane_idx / 16), - kv_smem_offset_w = k_smem.get_permuted_offset( - warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, - lane_idx % kv_frag_cols); + get_warp_idx_z() * num_iters_n * 16 + + 8 * (lane_idx / 16) + lane_idx % 8, + (lane_idx % 16) / 8); + uint32_t v_smem_offset_r = v_smem.get_permuted_offset( + get_warp_idx_z() * num_iters_n * 16 + + lane_idx % 16, + lane_idx / 16); + uint32_t kv_smem_offset_w = k_smem.get_permuted_offset( + warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, + lane_idx % kv_frag_cols); // kv_idx of current sequence uint32_t kv_idx_base = chunk_start; #pragma unroll for (uint32_t i = 0; - i < num_frags_z * (swizzle_mode_kv == SwizzleMode::k128B ? 4 : 2) / - num_warps_x; + i < num_iters_n * (swizzle_mode_kv == SwizzleMode::k128B ? 4 : 2) / + num_warps_m; ++i) { const uint32_t kv_idx = kv_idx_base + warp_idx * kv_frag_rows + lane_idx / kv_frag_cols + - kv_frag_rows * num_warps_x * num_warps_z * i; + kv_frag_rows * num_warps_m * num_warps_n * i; const uint32_t feat_idx = (lane_idx % kv_frag_cols) * num_elems_per_128b(); kv_offset[i] = @@ -1046,13 +1090,26 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void attention_kernel( : 0; } - page_produce_kv( + page_produce_kv( k_smem, &kv_smem_offset_w, paged_kv, 0, kv_offset, chunk_size); cp_async::commit_group(); - page_produce_kv( + page_produce_kv( v_smem, &kv_smem_offset_w, paged_kv, 0, kv_offset, chunk_size); + // [q] => [q, k, v] cp_async::commit_group(); + // wait for q to be loaded: [q, k, v] => [k, v] + cp_async::wait_group<2>(); + block.sync(); + + // TODO: can we do this in register? + q_smem_inplace_multiply_sm_scale(&qo_smem, sm_scale); + const uint32_t num_iterations = ceil_div( (mask_mode == MaskMode::kCausal ? min(chunk_size, @@ -1061,12 +1118,12 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void attention_kernel( ((qo_tile_idx + 1) * num_rows_per_cta) / group_size, chunk_start)) : chunk_size), - 16 * num_warps_z * num_frags_z); + 16 * num_warps_n * num_iters_n); const uint32_t window_iteration = ceil_div(sub_if_greater_or_zero(kv_len + (bx + 1) * num_rows_per_cta, qo_len + window_left + chunk_start), - (16 * num_warps_z * num_frags_z)); + (16 * num_warps_n * num_iters_n)); const uint32_t mask_iteration = (mask_mode == MaskMode::kCausal @@ -1076,19 +1133,48 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void attention_kernel( qo_len, chunk_start)) : chunk_size) / - (16 * num_warps_z * num_frags_z); + (16 * num_warps_n * num_iters_n); + + // alibi slopes for 2 rows 0, 8 + float alibi_slopes_frag[num_iters_m][2]; + if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { +#pragma unroll + for (uint32_t fx = 0; fx < num_iters_m; ++fx) { +#pragma unroll + for (uint32_t j = 0; j < 2; ++j) { + const uint32_t qo_head_idx = + q_head_idx_base + + (qo_packed_idx_base + lane_idx / 4 + j * 8 + fx * 16) % group_size; + alibi_slopes_frag[fx][j] = alibi_slopes[qo_head_idx] * math::log2e; + } + } + } + // define fragments for each thread + DTypeQKAccum s_frag[num_iters_m][num_iters_n][8]; + // regesters hold a whole data in register? necessary? 
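// [Editor's note: hedged sketch, not part of this patch.] The fragments
// declared below (s_frag, o_frag, m, d) carry the flash-attention online
// softmax state that init_states, update_mdo_states, compute_sfm_v and
// normalize_d maintain per row. Scores are pre-scaled by log2(e) through
// sm_scale, so the kernel uses exp2 instead of exp. A scalar reference of the
// same recurrence (OnlineSoftmaxRow/step/finalize are illustrative names):
#include <algorithm>
#include <cmath>
#include <vector>

struct OnlineSoftmaxRow {
  float m = -5e4f;       // running max, matches init_states' m = -5e4
  float d = 1.f;         // running denominator, matches init_states' d = 1
  std::vector<float> o;  // un-normalized output accumulator

  explicit OnlineSoftmaxRow(size_t head_dim) : o(head_dim, 0.f) {}

  // Consume one kv chunk: s[j] are the (already log2(e)-scaled) scores and
  // v[j] the matching value rows.
  void step(const std::vector<float>& s,
            const std::vector<std::vector<float>>& v) {
    float m_new = m;
    for (float sj : s) m_new = std::max(m_new, sj);
    const float o_scale = std::exp2(m - m_new);  // rescale previous state
    d *= o_scale;
    for (float& ok : o) ok *= o_scale;
    for (size_t j = 0; j < s.size(); ++j) {
      const float p = std::exp2(s[j] - m_new);   // softmax numerator
      d += p;                                    // rowsum in compute_sfm_v
      for (size_t k = 0; k < o.size(); ++k) o[k] += p * v[j][k];
    }
    m = m_new;
  }

  // normalize_d: divide the accumulator by the final denominator.
  void finalize() {
    for (float& ok : o) ok /= d;
  }
};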
+ float o_frag[num_iters_m][num_iters_k][8]; + + // max and sum for 2 rows 0, 8 + DTypeQKAccum m[num_iters_m][2]; + float d[num_iters_m][2]; + // initialize o = 0, m = -5e4, d = 1 + init_states(o_frag, m, d); + + // iterate over kv chunks #pragma unroll 1 for (uint32_t iter = 0; iter < num_iterations; ++iter) { - kv_idx_base += 16 * num_warps_z * num_frags_z; + kv_idx_base += 16 * num_warps_n * num_iters_n; + + // calculate kv offsets to read before waiting for k ready #pragma unroll for (uint32_t i = 0; - i < num_frags_z * (swizzle_mode_kv == SwizzleMode::k128B ? 4 : 2) / - num_warps_x; + i < num_iters_n * (swizzle_mode_kv == SwizzleMode::k128B ? 4 : 2) / + num_warps_m; ++i) { const uint32_t kv_idx = kv_idx_base + warp_idx * kv_frag_rows + lane_idx / kv_frag_cols + - kv_frag_rows * num_warps_x * num_warps_z * i; + kv_frag_rows * num_warps_m * num_warps_n * i; const uint32_t feat_idx = (lane_idx % kv_frag_cols) * num_elems_per_128b(); kv_offset[i] = kv_idx < kv_len @@ -1096,14 +1182,16 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void attention_kernel( request_idx, kv_idx, kv_head_idx, feat_idx) : 0; } + + // wait for k ready: [k, v] => [v] cp_async::wait_group<1>(); block.sync(); // compute attention score compute_qk( + apply_alibi_bias( qo_packed_idx_base, - chunk_start + (iter * num_warps_z + - get_warp_idx_z()) * - num_frags_z * 16, + chunk_start + (iter * num_warps_n + + get_warp_idx_z()) * + num_iters_n * 16, int(kv_len) - int(qo_len), group_size, alibi_slopes_frag, @@ -1127,11 +1215,11 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void attention_kernel( } // apply mask if constexpr (mask_mode == MaskMode::kCustom) { - mask_s( + mask_s( qo_packed_idx_base, - chunk_start + (iter * num_warps_z + - get_warp_idx_z()) * - num_frags_z * 16, + chunk_start + (iter * num_warps_n + + get_warp_idx_z()) * + num_iters_n * 16, qo_len, kv_len, window_left, @@ -1141,11 +1229,11 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void attention_kernel( s_frag); } else { if (iter >= mask_iteration || iter < window_iteration) { - mask_s( + mask_s( qo_packed_idx_base, - chunk_start + (iter * num_warps_z + - get_warp_idx_z()) * - num_frags_z * 16, + chunk_start + (iter * num_warps_n + + get_warp_idx_z()) * + num_iters_n * 16, qo_len, kv_len, window_left, @@ -1157,59 +1245,94 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void attention_kernel( } // compute m,d states in online softmax - update_mdo_states( + update_mdo_states( s_frag, o_frag, m, d); - block.sync(); - page_produce_kv( + + // produce k for next iteration + page_produce_kv( k_smem, &kv_smem_offset_w, paged_kv, - (iter + 1) * 16 * num_warps_z * num_frags_z, + (iter + 1) * 16 * num_warps_n * num_iters_n, kv_offset, chunk_size); + // [v] => [v, k] cp_async::commit_group(); + + // wait for v ready, [v, k] => [k] cp_async::wait_group<1>(); block.sync(); // compute sfm*v - compute_sfm_v(&v_smem, &v_smem_offset_r, s_frag, o_frag, d); - block.sync(); - page_produce_kv( + + // produce v for next iteration + page_produce_kv( v_smem, &kv_smem_offset_w, paged_kv, - (iter + 1) * 16 * num_warps_z * num_frags_z, + (iter + 1) * 16 * num_warps_n * num_iters_n, kv_offset, chunk_size); + + // [k] => [k, v] cp_async::commit_group(); } + + // wait for all data ready cp_async::wait_group<0>(); block.sync(); // threadblock synchronization - threadblock_sync_mdo_states( o_frag, (float*)smem, m, d, warp_idx, lane_idx); // normalize d - normalize_d(o_frag, m, d); + normalize_d(o_frag, m, d); + + // write o from register to 
shared memory + write_o_reg_smem(o_frag, &qo_smem); + + block.sync(); + // write o from shared memory to global memory const uint32_t num_kv_chunks = (kv_len_safe + kv_chunk_size - 1) / kv_chunk_size; - // write_back - write_o_reg_gmem( - o_frag, + // partition_kv: [n_tokens*n_kv_tiles, n_heads, head_dim] + // !partition_kv: [n_tokens*1, n_heads, head_dim] + DTypeOut* o_ptr_base = partition_kv + ? o + kv_tile_idx * num_qo_heads * head_dim + + get_elem_offset_impl( + /*elem_idx=*/o_indptr[request_idx], + /*head_idx=*/q_head_idx_base, + /*feat_idx=*/(lane_idx % 8) * + num_elems_per_128b(), + num_qo_heads * head_dim, + head_dim) + : o + get_elem_offset_impl( + /*elem_idx=*/o_indptr[request_idx], + /*head_idx=*/q_head_idx_base, + /*feat_idx=*/(lane_idx % 8) * + num_elems_per_128b(), + num_qo_heads * head_dim, + head_dim); + write_o_smem_gmem( &qo_smem, o_ptr_base, qo_packed_idx_base, @@ -1222,23 +1345,27 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void attention_kernel( // write lse if (lse != nullptr) { - if (get_warp_idx_z() == 0) { + if (get_warp_idx_z() == 0) { #pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fx = 0; fx < num_iters_m; ++fx) { #pragma unroll for (uint32_t j = 0; j < 2; ++j) { uint32_t q, r; group_size.divmod( qo_packed_idx_base + lane_idx / 4 + j * 8 + fx * 16, q, r); - const uint32_t qo_head_idx = kv_head_idx * group_size + r; + const uint32_t qo_head_idx = q_head_idx_base + r; const uint32_t qo_idx = q; if (qo_idx < qo_upper_bound) { if (partition_kv) { + // [n_tokens, n_kv_tiles, n_heads] + // lse = ((o_indptr[request_idx] + qo_idx * num_kv_chunks + + // kv_tile_idx) * num_qo_heads + qo_head_idx) lse[(o_indptr[request_idx] + qo_idx * num_kv_chunks + kv_tile_idx) * num_qo_heads + qo_head_idx] = math::ptx_log2(d[fx][j]) + float(m[fx][j]); } else { + // [n_tokens, n_heads] lse[(o_indptr[request_idx] + qo_idx) * num_qo_heads + qo_head_idx] = math::ptx_log2(d[fx][j]) + float(m[fx][j]); } @@ -1291,9 +1418,9 @@ cudaError_t mha_varlen_dispatch(DTypeQ* q, } #endif - constexpr uint32_t num_frags_x = get_num_frags_x(); - constexpr uint32_t num_warps_x = get_num_warps_x(); - constexpr uint32_t num_warps_z = get_num_warps_z(); + constexpr uint32_t num_iters_m = get_num_frags_x(); + constexpr uint32_t num_warps_m = get_num_warps_x(); + constexpr uint32_t num_warps_n = get_num_warps_z(); const uint32_t group_size = num_qo_heads / num_kv_heads; const uint_fastdiv group_size_fastdiv(group_size); @@ -1304,10 +1431,10 @@ cudaError_t mha_varlen_dispatch(DTypeQ* q, return cudaSuccess; } - dim3 nblks(padded_batch_size, 1, num_kv_heads); - dim3 nthrs(32, num_warps_x, num_warps_z); + dim3 nblks(padded_batch_size, num_kv_heads); + dim3 nthrs(32, num_warps_m, num_warps_n); - constexpr uint32_t num_frags_y = HEAD_DIM / 16; + constexpr uint32_t num_iters_k = HEAD_DIM / 16; using DTypeQKAccum = std::conditional_t, half, @@ -1324,35 +1451,35 @@ cudaError_t mha_varlen_dispatch(DTypeQ* q, max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeQ) * 16) ? 2 : 1; const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; - const uint32_t max_num_frags_z_reg = - (HEAD_DIM >= 128 && num_frags_x == 2 && + const uint32_t max_num_iters_n_reg = + (HEAD_DIM >= 128 && num_iters_m == 2 && pos_encoding_mode == PosEncodingMode::kRoPELlama && !ALLOW_FP16_QK_REDUCTION) ? 
2 - : (8 / num_frags_x); + : (8 / num_iters_m); // TODO(Zihao): fix the following computation - const uint32_t max_num_frags_z_smem = + const uint32_t max_num_iters_n_smem = (max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeQ)) - - num_frags_x * num_warps_x) / - (2 * num_warps_z); + num_iters_m * num_warps_m) / + (2 * num_warps_n); DISPATCH_NUM_FRAGS_Z( - min(max_num_frags_z_smem, max_num_frags_z_reg), num_frags_z, { + min(max_num_iters_n_smem, max_num_iters_n_reg), num_iters_n, { if constexpr (is_invalid_configuration(num_frags_x, - num_frags_y, - num_frags_z, - num_warps_x, - num_warps_z)) { + DTypeQKAccum>(num_iters_m, + num_iters_k, + num_iters_n, + num_warps_m, + num_warps_n)) { // Invalid configuration, skip std::ostringstream err_msg; err_msg << "FlashInfer Internal Error: Invalid configuration : " - "num_frags_x=" - << num_frags_x << " num_frags_y=" << num_frags_y - << " num_frags_z=" << num_frags_z - << " num_warps_x=" << num_warps_x - << " num_warps_z=" << num_warps_z + "num_iters_m=" + << num_iters_m << " num_iters_k=" << num_iters_k + << " num_iters_n=" << num_iters_n + << " num_warps_m=" << num_warps_m + << " num_warps_n=" << num_warps_n << " please create an issue " "(https://github.com/flashinfer-ai/flashinfer/issues)" " and report the issue to the developers."; @@ -1360,17 +1487,17 @@ cudaError_t mha_varlen_dispatch(DTypeQ* q, } else { // TODO(Zihao): fix the following computation uint32_t smem_size = - (num_frags_x * num_warps_x * sizeof(DTypeQ) + - num_frags_z * num_warps_z * 2 * sizeof(DTypeQ)) * + (num_iters_m * num_warps_m * sizeof(DTypeQ) + + num_iters_n * num_warps_n * 2 * sizeof(DTypeQ)) * 16 * HEAD_DIM; auto kernel = attention_kernelGetTempV(); - tmp_s = handler->GetTempS(); - request_indices = handler->GetRequestIndices(); - qo_tile_indices = handler->GetQOTileIndices(); - kv_tile_indices = handler->GetKVTileIndices(); - block_valid_mask = handler->GetBlockValidMask(); - o_indptr = handler->GetOIndptr(); - merge_indptr = handler->GetMergeIndptr(); - kv_chunk_size_ptr = handler->GetKVChunkSizePtr(); - warp_layout = handler->GetWarpLayout(); - padded_batch_size = handler->GetPaddedBatchSize(); - total_num_rows = handler->GetTotalNumRows(); + DTypeOut* tmp_v = handler->GetTempV(); + float* tmp_s = handler->GetTempS(); + IdType* request_indices = handler->GetRequestIndices(); + IdType* qo_tile_indices = handler->GetQOTileIndices(); + IdType* kv_tile_indices = handler->GetKVTileIndices(); + bool* block_valid_mask = handler->GetBlockValidMask(); + IdType* o_indptr = handler->GetOIndptr(); + IdType* merge_indptr = handler->GetMergeIndptr(); + IdType* kv_chunk_size_ptr = handler->GetKVChunkSizePtr(); + WarpLayout warp_layout = handler->GetWarpLayout(); + uint32_t padded_batch_size = handler->GetPaddedBatchSize(); + uint32_t total_num_rows = handler->GetTotalNumRows(); DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, { return mha_varlen_dispatch -#include - #include "handler.h" namespace flashinfer { @@ -12,7 +10,7 @@ namespace flashinfer { class BatchPrefillWrapper { public: BatchPrefillWrapper(bool enable_cuda_graph) - : handler_(std::make_shared( + : handler_(std::make_unique( enable_cuda_graph)) {} void Plan(torch::Tensor float_workspace_buffer, @@ -24,7 +22,8 @@ class BatchPrefillWrapper { unsigned int num_kv_heads, unsigned int head_dim, unsigned page_size, - torch::Tensor empty_q_data); + torch::Tensor empty_q_data, + int32_t num_sm); bool IsCUDAGraphEnabled() const { return handler_->IsCUDAGraphEnabled(); } @@ -43,7 +42,7 @@ class BatchPrefillWrapper { 
std::optional alibi_slopes); private: - std::shared_ptr handler_; + std::unique_ptr handler_; }; } // namespace flashinfer \ No newline at end of file diff --git a/src/kernels/attention/flash_infer/handler.h b/src/kernels/attention/flash_infer/handler.h index dcb602ac..934f38a7 100644 --- a/src/kernels/attention/flash_infer/handler.h +++ b/src/kernels/attention/flash_infer/handler.h @@ -9,7 +9,7 @@ #include #include #include -#include +#include #include @@ -24,153 +24,166 @@ namespace flashinfer { -inline std::tuple PrefillBinarySearchKVChunkSize( +// binary search to find the smallest kv_chunk_size that can fit into the grid +// returns kv_chunk_size +inline uint32_t search_kv_chunk_size( const uint32_t max_grid_size, const uint32_t num_kv_heads, - const std::vector& packed_qo_len_arr, - const std::vector& kv_len_arr, - const uint32_t qo_chunk_size, - const uint32_t min_kv_chunk_size = 1) { - int64_t low = min_kv_chunk_size, high = 0; - int64_t batch_size = packed_qo_len_arr.size(); - int64_t max_kv_len = 0; - for (const int64_t& kv_len : kv_len_arr) { - max_kv_len = std::max(max_kv_len, kv_len); - } - high = max_kv_len; - int64_t new_batch_size; + const uint32_t min_kv_chunk_size, + const uint32_t max_kv_chunk_size, + const std::function& cal_batch_size) { + int64_t low = min_kv_chunk_size; + int64_t high = max_kv_chunk_size; while (low < high) { - int64_t mid = (low + high) / 2; - new_batch_size = 0; - for (uint32_t i = 0; i < batch_size; ++i) { - new_batch_size += ceil_div(packed_qo_len_arr[i], qo_chunk_size) * - ceil_div(kv_len_arr[i], mid); - } - if (new_batch_size * num_kv_heads > max_grid_size) { + const int64_t mid = (low + high) / 2; + const int64_t batch_size = cal_batch_size(mid); + if (batch_size * num_kv_heads > max_grid_size) { low = mid + 1; } else { high = mid; } } - new_batch_size = 0; - for (uint32_t i = 0; i < batch_size; ++i) { - new_batch_size += ceil_div(packed_qo_len_arr[i], qo_chunk_size) * - ceil_div(std::max(int(kv_len_arr[i]), 1), low); - } - return {low < max_kv_len, low, new_batch_size}; + // low holds the smallest kv_chunk_size that can fit into the grid + return low; } template -cudaError_t PrefillSplitQOKVIndptr(bool& split_kv, - uint32_t& split_max_batch_size, - uint32_t& total_num_tiles_q, - uint32_t& new_batch_size, - WarpLayout& warp_layout, - uint32_t& kv_chunk_size, - uint32_t& total_num_rows, - std::vector& request_indices, - std::vector& qo_tile_indices, - std::vector& kv_tile_indices, - std::vector& merge_indptr, - std::vector& o_indptr, - IdType* qo_indptr_h, - IdType* paged_kv_indptr_h, - uint32_t batch_size, - uint32_t num_qo_heads, - uint32_t num_kv_heads, - uint32_t head_dim, - uint32_t page_size) { - request_indices.clear(); - qo_tile_indices.clear(); - kv_tile_indices.clear(); - merge_indptr.clear(); - o_indptr.clear(); - merge_indptr.push_back(0); - o_indptr.push_back(0); +struct SplitParams { + // whether to split kv + bool split_kv; + // the max batch size? + uint32_t split_max_batch_size; + // total number of tiles in qo + uint32_t total_num_tiles_q; + // total number of partitions + uint32_t new_batch_size; + // warp layout + WarpLayout warp_layout; + // kv_chunk_size that can fit into the grid + uint32_t kv_chunk_size; + // total number of rows in qo + uint32_t total_num_rows; + // request idx for each cta + std::vector request_indices; + // qo_tile_idx for each cta + std::vector qo_tile_indices; + // kv_tile_idx for each cta + std::vector kv_tile_indices; + // kv_tile indptr for each row in qo? 
+ std::vector merge_indptr; + // kv_tile indptr for each request + std::vector o_indptr; +}; + +template +SplitParams split_input(IdType* qo_indptr_h, + IdType* paged_kv_indptr_h, + uint32_t batch_size, + uint32_t num_qo_heads, + uint32_t num_kv_heads, + uint32_t head_dim, + uint32_t page_size, + int32_t num_sm) { + SplitParams split_params; const uint32_t gqa_group_size = num_qo_heads / num_kv_heads; - total_num_rows = qo_indptr_h[batch_size]; - - // step 0: get the number of SMs - int num_sm = 0; - int dev_id = 0; - FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id)); - FLASHINFER_CUDA_CALL( - cudaDeviceGetAttribute(&num_sm, cudaDevAttrMultiProcessorCount, dev_id)); + split_params.total_num_rows = qo_indptr_h[batch_size]; + int num_blocks_per_sm = 2; int max_grid_size = num_blocks_per_sm * num_sm; - split_max_batch_size = max_grid_size / num_kv_heads; + split_params.split_max_batch_size = max_grid_size / num_kv_heads; // step 1: compute qo_chunk_size - std::vector packed_qo_len_arr(batch_size), kv_len_arr(batch_size); + std::vector packed_qo_len_arr(batch_size); + std::vector kv_chunk_len_arr(batch_size); + int64_t max_kv_chunk_len = 0; int64_t sum_packed_qo_len = 0; for (uint32_t i = 0; i < batch_size; ++i) { packed_qo_len_arr[i] = int64_t(qo_indptr_h[i + 1] - qo_indptr_h[i]) * int64_t(gqa_group_size); - kv_len_arr[i] = int64_t(paged_kv_indptr_h[i + 1] - paged_kv_indptr_h[i]); + auto kv_chunk_len = + int64_t(paged_kv_indptr_h[i + 1] - paged_kv_indptr_h[i]); + kv_chunk_len_arr[i] = std::max(kv_chunk_len, 1); + max_kv_chunk_len = std::max(max_kv_chunk_len, kv_chunk_len); sum_packed_qo_len += packed_qo_len_arr[i]; } int64_t avg_packed_qo_len = sum_packed_qo_len / batch_size; + // WarpLayout: (num_warps_x, num_warps_z, num_frags_x) if (avg_packed_qo_len > 64 && head_dim < 256) { - warp_layout = WarpLayout::k4x1x2; // (num_warps_x = 4, num_warps_z = 1, - // num_frags_x = 2) + split_params.warp_layout = WarpLayout::k4x1x2; } else { auto compute_capacity = GetCudaComputeCapability(); if (compute_capacity.first >= 8) { // Ampere or newer if (avg_packed_qo_len > 16) { - warp_layout = WarpLayout::k4x1x1; // (num_warps_x = 4, num_warps_z = 1, - // num_frags_x = 1) + split_params.warp_layout = WarpLayout::k4x1x1; } else { // avg_packed_qo_len <= 16 - warp_layout = WarpLayout::k1x4x1; // (num_warps_x = 1, num_warps_z = 4, - // num_frags_x = 1) + split_params.warp_layout = WarpLayout::k1x4x1; } } else { // NOTE(Zihao): not enough shared memory on Turing for 1x4x1 layout - warp_layout = WarpLayout::k4x1x1; + split_params.warp_layout = WarpLayout::k4x1x1; } } - const uint32_t qo_chunk_size = get_num_rows_per_cta(warp_layout); + const uint32_t qo_chunk_size = get_num_rows_per_cta(split_params.warp_layout); + + // lambda to calculate batch_size given kv_chunk_size + auto cal_batch_size = [&](int64_t kv_chunk_size) -> int64_t { + int64_t batch_size = 0; + for (size_t i = 0; i < packed_qo_len_arr.size(); ++i) { + batch_size += ceil_div(packed_qo_len_arr[i], qo_chunk_size) * + ceil_div(kv_chunk_len_arr[i], kv_chunk_size); + } + return batch_size; + }; + const uint32_t min_kv_chunk_size = std::max((128 / page_size), 1U); // step 2: determine kv_chunk_size - std::tie(split_kv, kv_chunk_size, new_batch_size) = - PrefillBinarySearchKVChunkSize( - max_grid_size, - num_kv_heads, - packed_qo_len_arr, - kv_len_arr, - qo_chunk_size, - /*min_kv_chunk_size=*/std::max((128 / page_size), 1U)); + auto kv_chunk_size = search_kv_chunk_size(max_grid_size, + num_kv_heads, + min_kv_chunk_size, + max_kv_chunk_len, + cal_batch_size); 
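// [Editor's note: hedged sketch, not part of this patch.] search_kv_chunk_size
// above binary-searches the smallest kv chunk (in pages) for which
// cal_batch_size(chunk) * num_kv_heads still fits into the grid budget;
// smaller chunks mean more CTAs and better load balance. A self-contained toy
// version with made-up names (toy_search, kv_len_pages) follows:
#include <cstdint>
#include <vector>

static int64_t ceil_div_i64(int64_t a, int64_t b) { return (a + b - 1) / b; }

int64_t toy_search(int64_t max_grid_size, int64_t num_kv_heads,
                   int64_t qo_chunk_size,
                   const std::vector<int64_t>& packed_qo_len,
                   const std::vector<int64_t>& kv_len_pages,
                   int64_t min_chunk, int64_t max_chunk) {
  // number of (request, q_tile, kv_tile) CTAs produced by a given chunk size
  auto batch_size = [&](int64_t chunk) {
    int64_t n = 0;
    for (size_t i = 0; i < packed_qo_len.size(); ++i) {
      n += ceil_div_i64(packed_qo_len[i], qo_chunk_size) *
           ceil_div_i64(kv_len_pages[i], chunk);
    }
    return n;
  };
  int64_t lo = min_chunk, hi = max_chunk;
  while (lo < hi) {
    const int64_t mid = (lo + hi) / 2;
    if (batch_size(mid) * num_kv_heads > max_grid_size) {
      lo = mid + 1;  // too many CTAs for the grid, grow the chunk
    } else {
      hi = mid;      // fits, try an even smaller chunk
    }
  }
  return lo;  // smallest chunk size that fits, mirroring the helper above
}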
+
+  split_params.split_kv = kv_chunk_size < max_kv_chunk_len;
+  split_params.new_batch_size = cal_batch_size(kv_chunk_size);
 
   // step 3: split qo_indptr and kv_indptr
-  total_num_tiles_q = 0;
+  split_params.merge_indptr.push_back(0);
+  split_params.o_indptr.push_back(0);
+  split_params.total_num_tiles_q = 0;
   for (uint32_t request_idx = 0; request_idx < batch_size; ++request_idx) {
-    int64_t packed_qo_len = packed_qo_len_arr[request_idx],
-            kv_len = std::max(int(kv_len_arr[request_idx]), 1);
-    int64_t num_tiles_q = ceil_div(packed_qo_len, qo_chunk_size),
-            num_tiles_kv = ceil_div(kv_len, kv_chunk_size);
-    total_num_tiles_q += num_tiles_q;
+    int64_t packed_qo_len = packed_qo_len_arr[request_idx];
+    int64_t kv_len = std::max(int(kv_chunk_len_arr[request_idx]), 1);
+    int64_t num_tiles_q = ceil_div(packed_qo_len, qo_chunk_size);
+    int64_t num_tiles_kv = ceil_div(kv_len, kv_chunk_size);
+    split_params.total_num_tiles_q += num_tiles_q;
+
     for (uint32_t q_tile_idx = 0; q_tile_idx < num_tiles_q; ++q_tile_idx) {
       for (uint32_t kv_tile_idx = 0; kv_tile_idx < num_tiles_kv;
            ++kv_tile_idx) {
-        request_indices.push_back(request_idx);
-        qo_tile_indices.push_back(q_tile_idx);
-        kv_tile_indices.push_back(kv_tile_idx);
+        split_params.request_indices.push_back(request_idx);
+        split_params.qo_tile_indices.push_back(q_tile_idx);
+        split_params.kv_tile_indices.push_back(kv_tile_idx);
       }
     }
     int64_t qo_len = packed_qo_len / gqa_group_size;
     for (uint32_t row = 0; row < qo_len; ++row) {
-      merge_indptr.push_back(merge_indptr.back() + num_tiles_kv);
+      // start index of flattened kv for each token
+      split_params.merge_indptr.push_back(split_params.merge_indptr.back() +
+                                          num_tiles_kv);
     }
-    o_indptr.push_back(o_indptr.back() + qo_len * num_tiles_kv);
+    // start index of flattened kv for each sequence
+    split_params.o_indptr.push_back(split_params.o_indptr.back() +
+                                    qo_len * num_tiles_kv);
   }
 
-  // step 4: multiply kv_chunk_size by page_size
-  kv_chunk_size *= page_size;
+  // step 4: multiply kv_chunk_size by page_size to get kv length per chunk
+  split_params.kv_chunk_size = kv_chunk_size * page_size;
 
-  return cudaSuccess;
+  return split_params;
 }
 
 class BatchPrefillHandler {
@@ -236,7 +249,8 @@ class BatchPrefillHandler {
                    uint32_t num_qo_heads,
                    uint32_t num_kv_heads,
                    uint32_t head_dim,
-                   uint32_t page_size) {
+                   uint32_t page_size,
+                   int32_t num_sm) {
     Clear();
     if (num_qo_heads % num_kv_heads != 0) {
       std::ostringstream err_msg;
@@ -244,34 +258,30 @@ class BatchPrefillHandler {
              << " should be divisible by num_kv_heads " << num_kv_heads;
       throw std::invalid_argument(err_msg.str());
     }
-    bool split_kv;
-    uint32_t split_max_batch_size, new_batch_size, total_num_tiles_q,
-        kv_chunk_size;
-    std::vector<IdType> request_indices_vec, qo_tile_indices_vec,
-        kv_tile_indices_vec, merge_indptr_vec, o_indptr_vec;
-    FLASHINFER_CUDA_CALL(PrefillSplitQOKVIndptr(split_kv,
-                                                split_max_batch_size,
-                                                total_num_tiles_q,
-                                                new_batch_size,
-                                                warp_layout_,
-                                                kv_chunk_size,
-                                                total_num_rows_,
-                                                request_indices_vec,
-                                                qo_tile_indices_vec,
-                                                kv_tile_indices_vec,
-                                                merge_indptr_vec,
-                                                o_indptr_vec,
-                                                qo_indptr_h,
-                                                paged_kv_indptr_h,
-                                                batch_size,
-                                                num_qo_heads,
-                                                num_kv_heads,
-                                                head_dim,
-                                                page_size));
+
+    if (num_sm <= 0) {
+      int dev_id = 0;
+      FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
+      FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(
+          &num_sm, cudaDevAttrMultiProcessorCount, dev_id));
+    }
+
+    const auto split_params = split_input(qo_indptr_h,
+                                          paged_kv_indptr_h,
+                                          batch_size,
+                                          num_qo_heads,
+                                          num_kv_heads,
+                                          head_dim,
+                                          page_size,
+                                          num_sm);
+    warp_layout_ = split_params.warp_layout;
+    total_num_rows_ = split_params.total_num_rows;
+
     const uint32_t qo_tile_size = get_num_rows_per_cta(warp_layout_);
 
     if (IsCUDAGraphEnabled()) {
-      padded_batch_size_ = std::max(split_max_batch_size, total_num_tiles_q);
+      padded_batch_size_ = std::max(split_params.split_max_batch_size,
+                                    split_params.total_num_tiles_q);
       AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes);
       request_indices_ =
           int_allocator.aligned_alloc<void*>(sizeof(IdType) * padded_batch_size_,
@@ -301,8 +311,8 @@ class BatchPrefillHandler {
       void* kv_chunk_size_ptr_h_ =
           (char*)page_locked_buffer_ +
          ((char*)kv_chunk_size_ptr_ - (char*)request_indices_);
-      *(IdType*)kv_chunk_size_ptr_h_ = kv_chunk_size;
-      if (total_num_tiles_q < split_max_batch_size) {
+      *(IdType*)kv_chunk_size_ptr_h_ = split_params.kv_chunk_size;
+      if (split_params.total_num_tiles_q < split_params.split_max_batch_size) {
         // need merge_indptr
         merge_indptr_ = int_allocator.aligned_alloc<void*>(
             sizeof(IdType) * (total_num_rows_ + 1),
@@ -311,8 +321,8 @@
         void* merge_indptr_h_ =
             (char*)page_locked_buffer_ +
            ((char*)merge_indptr_ - (char*)request_indices_);
-        std::copy(merge_indptr_vec.begin(),
-                  merge_indptr_vec.end(),
+        std::copy(split_params.merge_indptr.begin(),
+                  split_params.merge_indptr.end(),
                   (IdType*)merge_indptr_h_);
         block_valid_mask_ =
             int_allocator.aligned_alloc<bool*>(sizeof(bool) * padded_batch_size_,
@@ -322,7 +332,7 @@
             (bool*)page_locked_buffer_ +
            ((bool*)block_valid_mask_ - (bool*)request_indices_);
         for (uint32_t i = 0; i < padded_batch_size_; ++i) {
-          block_valid_mask_h_[i] = i < new_batch_size;
+          block_valid_mask_h_[i] = i < split_params.new_batch_size;
         }
       } else {
         // total_num_tiles_q >= split_max_batch_size, we don't need to perform
@@ -330,16 +340,18 @@
         merge_indptr_ = nullptr;
         block_valid_mask_ = nullptr;
       }
-      std::copy(request_indices_vec.begin(),
-                request_indices_vec.end(),
+      std::copy(split_params.request_indices.begin(),
+                split_params.request_indices.end(),
                 (IdType*)request_indices_h_);
-      std::copy(qo_tile_indices_vec.begin(),
-                qo_tile_indices_vec.end(),
+      std::copy(split_params.qo_tile_indices.begin(),
+                split_params.qo_tile_indices.end(),
                 (IdType*)qo_tile_indices_h_);
-      std::copy(kv_tile_indices_vec.begin(),
-                kv_tile_indices_vec.end(),
+      std::copy(split_params.kv_tile_indices.begin(),
+                split_params.kv_tile_indices.end(),
                 (IdType*)kv_tile_indices_h_);
-      std::copy(o_indptr_vec.begin(), o_indptr_vec.end(), (IdType*)o_indptr_h_);
+      std::copy(split_params.o_indptr.begin(),
+                split_params.o_indptr.end(),
+                (IdType*)o_indptr_h_);
 
       size_t num_bytes_to_copy =
           (char*)int_allocator.ptr - (char*)request_indices_;
@@ -349,16 +361,17 @@
                                            cudaMemcpyHostToDevice,
                                            stream_))
 
-      if (total_num_tiles_q < split_max_batch_size) {
+      if (split_params.total_num_tiles_q < split_params.split_max_batch_size) {
         AlignedAllocator float_allocator(float_buffer,
                                          float_workspace_size_in_bytes);
         tmp_v_ = float_allocator.aligned_alloc<void*>(
-            num_qo_heads * split_max_batch_size * qo_tile_size * head_dim *
-                sizeof(DTypeOut),
+            num_qo_heads * split_params.split_max_batch_size * qo_tile_size *
+                head_dim * sizeof(DTypeOut),
             16,
             "batch_prefill_tmp_v");
         tmp_s_ = float_allocator.aligned_alloc<float*>(
-            num_qo_heads * split_max_batch_size * qo_tile_size * sizeof(float),
+            num_qo_heads * split_params.split_max_batch_size * qo_tile_size *
+                sizeof(float),
             16,
             "batch_prefill_tmp_s");
       } else {
@@ -366,42 +379,44 @@
         tmp_s_ = nullptr;
       }
     } else {
-      padded_batch_size_ = new_batch_size;
+      padded_batch_size_ = split_params.new_batch_size;
       AlignedAllocator int_allocator(int_buffer, int_workspace_size_in_bytes);
       request_indices_ = int_allocator.aligned_alloc<void*>(
-          sizeof(IdType) * request_indices_vec.size(),
+          sizeof(IdType) * split_params.request_indices.size(),
           16,
           "batch_prefill_request_indices");
       void* request_indices_h_ = page_locked_buffer_;
       qo_tile_indices_ = int_allocator.aligned_alloc<void*>(
-          sizeof(IdType) * qo_tile_indices_vec.size(),
          16,
          "batch_prefill_qo_tile_indices");
+          sizeof(IdType) * split_params.qo_tile_indices.size(),
+          16,
+          "batch_prefill_qo_tile_indices");
       void* qo_tile_indices_h_ =
           (char*)page_locked_buffer_ +
          ((char*)qo_tile_indices_ - (char*)request_indices_);
       kv_tile_indices_ = int_allocator.aligned_alloc<void*>(
-          sizeof(IdType) * kv_tile_indices_vec.size(),
+          sizeof(IdType) * split_params.kv_tile_indices.size(),
           16,
           "batch_prefill_kv_tile_indices");
       void* kv_tile_indices_h_ =
           (char*)page_locked_buffer_ +
          ((char*)kv_tile_indices_ - (char*)request_indices_);
-      if (split_kv) {
+      if (split_params.split_kv) {
         // need merge_indptr when split_kv is true
         merge_indptr_ = int_allocator.aligned_alloc<void*>(
-            sizeof(IdType) * merge_indptr_vec.size(),
+            sizeof(IdType) * split_params.merge_indptr.size(),
             16,
             "batch_prefill_merge_indptr");
         void* merge_indptr_h_ =
             (char*)page_locked_buffer_ +
            ((char*)merge_indptr_ - (char*)request_indices_);
-        std::copy(merge_indptr_vec.begin(),
-                  merge_indptr_vec.end(),
+        std::copy(split_params.merge_indptr.begin(),
+                  split_params.merge_indptr.end(),
                   (IdType*)merge_indptr_h_);
       }
       o_indptr_ = int_allocator.aligned_alloc<void*>(
-          sizeof(IdType) * o_indptr_vec.size(), 16, "batch_prefill_o_indptr");
+          sizeof(IdType) * split_params.o_indptr.size(),
+          16,
+          "batch_prefill_o_indptr");
       void* o_indptr_h_ =
           (char*)page_locked_buffer_ +
          ((char*)o_indptr_ - (char*)request_indices_);
       kv_chunk_size_ptr_ = int_allocator.aligned_alloc<void*>(
@@ -409,17 +424,19 @@
       void* kv_chunk_size_ptr_h_ =
           (char*)page_locked_buffer_ +
          ((char*)kv_chunk_size_ptr_ - (char*)request_indices_);
-      *(IdType*)kv_chunk_size_ptr_h_ = kv_chunk_size;
-      std::copy(request_indices_vec.begin(),
-                request_indices_vec.end(),
+      *(IdType*)kv_chunk_size_ptr_h_ = split_params.kv_chunk_size;
+      std::copy(split_params.request_indices.begin(),
+                split_params.request_indices.end(),
                 (IdType*)request_indices_h_);
-      std::copy(qo_tile_indices_vec.begin(),
-                qo_tile_indices_vec.end(),
+      std::copy(split_params.qo_tile_indices.begin(),
+                split_params.qo_tile_indices.end(),
                 (IdType*)qo_tile_indices_h_);
-      std::copy(kv_tile_indices_vec.begin(),
-                kv_tile_indices_vec.end(),
+      std::copy(split_params.kv_tile_indices.begin(),
+                split_params.kv_tile_indices.end(),
                 (IdType*)kv_tile_indices_h_);
-      std::copy(o_indptr_vec.begin(), o_indptr_vec.end(), (IdType*)o_indptr_h_);
+      std::copy(split_params.o_indptr.begin(),
+                split_params.o_indptr.end(),
+                (IdType*)o_indptr_h_);
 
       size_t num_bytes_to_copy =
           (char*)int_allocator.ptr - (char*)request_indices_;
@@ -429,16 +446,23 @@
                                           cudaMemcpyHostToDevice,
                                           stream_))
 
-      if (split_kv) {
+      if (split_params.split_kv) {
        AlignedAllocator float_allocator(float_buffer,
                                         float_workspace_size_in_bytes);
+        // [n_kv_tiles, n_tokens, n_heads, head_dim]
+        // new_batch_size = n_kv_tiles * n_q_tiles
+        // n_tokens = n_q_tiles * qo_tile_size
+        // n_kv_tiles * n_tokens = new_batch_size * qo_tile_size
        tmp_v_ = float_allocator.aligned_alloc<void*>(
-            num_qo_heads * new_batch_size * qo_tile_size * head_dim *
-                sizeof(DTypeOut),
+            split_params.new_batch_size * qo_tile_size * num_qo_heads *
+                head_dim * sizeof(DTypeOut),
            16,
            "batch_prefill_tmp_v");
+
+        // [n_kv_tiles, n_tokens, n_heads]
        tmp_s_ = float_allocator.aligned_alloc<float*>(
-            num_qo_heads * new_batch_size * qo_tile_size * sizeof(float),
+            split_params.new_batch_size * qo_tile_size * num_qo_heads *
+                sizeof(float),
            16,
            "batch_prefill_tmp_s");
      } else {
@@ -472,41 +496,29 @@ class BatchPrefillHandler {
 
   bool IsCUDAGraphEnabled() const { return enable_cuda_graph_; }
 
-  BatchPrefillHandler(bool enable_cuda_graph = false)
-      : request_indices_(nullptr),
-        qo_tile_indices_(nullptr),
-        kv_tile_indices_(nullptr),
-        merge_indptr_(nullptr),
-        o_indptr_(nullptr),
-        kv_chunk_size_ptr_(nullptr),
-        tmp_v_(nullptr),
-        tmp_s_(nullptr),
-        block_valid_mask_(nullptr),
-        total_num_rows_(0U),
-        padded_batch_size_(0U),
-        warp_layout_(WarpLayout::k4x1x2),
-        enable_cuda_graph_(enable_cuda_graph),
-        stream_(nullptr) {
+  BatchPrefillHandler(bool enable_cuda_graph)
+      : enable_cuda_graph_(enable_cuda_graph) {
     cudaMallocHost(&page_locked_buffer_, 8 * 1024 * 1024);
   }
+
   ~BatchPrefillHandler() { cudaFreeHost(page_locked_buffer_); }
 
  protected:
-  void* page_locked_buffer_;
-  void* request_indices_;
-  void* qo_tile_indices_;
-  void* kv_tile_indices_;
-  void* merge_indptr_;
-  void* o_indptr_;
-  void* kv_chunk_size_ptr_;
-  void* tmp_v_;
-  float* tmp_s_;
-  bool* block_valid_mask_;
-  uint32_t total_num_rows_;
-  uint32_t padded_batch_size_;
-  WarpLayout warp_layout_;
-  bool enable_cuda_graph_;
-  cudaStream_t stream_;
+  void* page_locked_buffer_ = nullptr;
+  void* request_indices_ = nullptr;
+  void* qo_tile_indices_ = nullptr;
+  void* kv_tile_indices_ = nullptr;
+  void* merge_indptr_ = nullptr;
+  void* o_indptr_ = nullptr;
+  void* kv_chunk_size_ptr_ = nullptr;
+  void* tmp_v_ = nullptr;
+  float* tmp_s_ = nullptr;
+  bool* block_valid_mask_ = nullptr;
+  uint32_t total_num_rows_ = 0;
+  uint32_t padded_batch_size_ = 0;
+  WarpLayout warp_layout_ = WarpLayout::k4x1x2;
+  bool enable_cuda_graph_ = false;
+  cudaStream_t stream_ = nullptr;
 };
 
 }  // namespace flashinfer
diff --git a/src/kernels/attention/flash_infer/state_merge_kernel.h b/src/kernels/attention/flash_infer/state_merge_kernel.h
new file mode 100644
index 00000000..e7148583
--- /dev/null
+++ b/src/kernels/attention/flash_infer/state_merge_kernel.h
@@ -0,0 +1,221 @@
+// Adapted from https://github.com/flashinfer-ai/flashinfer/
+#pragma once
+
+#include
+
+#include
+#include
+#include
+#include
+
+namespace flashinfer {
+
+template <uint32_t bdx, uint32_t bdy, uint32_t vec_size, typename DTypeIn>
+__device__ __forceinline__ void threadblock_sync_state(state_t<vec_size>& st,
+                                                       DTypeIn* v_smem,
+                                                       float* s_smem) {
+  const uint32_t tx = threadIdx.x, ty = threadIdx.y;
+  constexpr uint32_t head_dim = vec_size * bdx;
+  // [bdy, head_dim]
+  st.o.cast_store(v_smem + ty * head_dim + tx * vec_size);
+  // [bdy]
+  s_smem[ty] = st.get_lse();
+  st.init();
+  __syncthreads();
+
+#pragma unroll
+  for (uint32_t iter = 0; iter < bdy; ++iter) {
+    float s = s_smem[iter];
+    vec_t<float, vec_size> v;
+    v.cast_load(v_smem + iter * head_dim + tx * vec_size);
+    st.merge(v, s, 1);
+  }
+}
+
+template <uint32_t bdx,
+          uint32_t bdy,
+          uint32_t num_smem_stages,
+          uint32_t vec_size,
+          typename DTypeIn,
+          typename DTypeOut,
+          typename IdType>
+__global__ void PersistentVariableLengthMergeStatesKernel(
+    DTypeIn* __restrict__ V,
+    float* __restrict__ S,
+    IdType* indptr,
+    DTypeOut* __restrict__ v_merged,
+    float* __restrict__ s_merged,
+    uint32_t seq_len,
+    uint32_t num_heads) {
+  using cp_async::PrefetchMode;
+  using cp_async::SharedMemFillMode;
+
+  uint32_t tx = threadIdx.x, ty = threadIdx.y;
+  uint32_t cta_id = blockIdx.x;
+  uint32_t num_ctas = gridDim.x;
+  uint32_t num_iters = ceil_div(seq_len * num_heads, num_ctas);
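+  // NOTE: the loop below is persistent: each CTA walks the flattened
+  // (token, head) pairs i = blockIdx.x, blockIdx.x + gridDim.x, ... For each
+  // pair (pos, head_idx) it merges the num_index_sets partial outputs
+  // V[indptr[pos]:indptr[pos + 1]] together with their log-sum-exp values in
+  // S (presumably one entry per kv-split tile of the prefill pass) into a
+  // single attention state via st.merge(), double-buffering the loads through
+  // num_smem_stages shared-memory stages with cp_async.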
+  constexpr uint32_t vec_bits = sizeof(DTypeIn) * vec_size * 8;
+  constexpr uint32_t head_dim = vec_size * bdx;
+  extern __shared__ uint8_t smem[];
+  // [n_stages, bdy, head_dim]
+  DTypeIn* v_smem = (DTypeIn*)smem;
+  // [n_stages, bdy]
+  float* s_smem =
+      (float*)(smem + num_smem_stages * bdy * head_dim * sizeof(DTypeIn));
+
+#pragma unroll 1
+  for (uint32_t i = cta_id; i < seq_len * num_heads; i += num_ctas) {
+    // token position
+    uint32_t pos = i / num_heads;
+    uint32_t head_idx = i % num_heads;
+    state_t<vec_size> st;
+    // is it possible that num_index_sets == 0?
+    const uint32_t num_index_sets = indptr[pos + 1] - indptr[pos];
+
+    if (num_index_sets == 0) {
+      // fill with zeros
+      vec_t<DTypeOut, vec_size> v;
+      v.fill(DTypeOut(0.f));
+      v.store(v_merged + (pos * num_heads + head_idx) * head_dim +
+              tx * vec_size);
+      if (s_merged != nullptr) {
+        s_merged[pos * num_heads + head_idx] = -5e4;
+      }
+      continue;
+    }
+
+    if (num_index_sets == 1) {
+      // copy over without merging
+      vec_t<DTypeOut, vec_size> v;
+      v.cast_load(V + (indptr[pos] * num_heads + head_idx) * head_dim +
+                  tx * vec_size);
+      v.store(v_merged + (pos * num_heads + head_idx) * head_dim +
+              tx * vec_size);
+      if (s_merged != nullptr) {
+        s_merged[pos * num_heads + head_idx] =
+            S[indptr[pos] * num_heads + head_idx];
+      }
+      continue;
+    }
+
+#pragma unroll
+    for (uint32_t iter = 0; iter < num_smem_stages; ++iter) {
+      cp_async::pred_load<vec_bits,
+                          PrefetchMode::kPrefetch,
+                          SharedMemFillMode::kNoFill>(
+          v_smem + (iter * bdy + ty) * head_dim + tx * vec_size,
+          V +
+              ((indptr[pos] + (iter * bdy + ty)) * num_heads + head_idx) *
+                  head_dim +
+              tx * vec_size,
+          (iter * bdy + ty) < num_index_sets);
+      cp_async::commit_group();
+    }
+#pragma unroll 4
+    for (uint32_t iter = 0; iter < ceil_div(num_index_sets, bdy); ++iter) {
+      if (iter % bdx == 0) {
+        s_smem[ty * bdx + tx] =
+            iter * bdy + (ty * bdx + tx) < num_index_sets
+                ? S[(indptr[pos] + (iter * bdy + ty * bdx + tx)) * num_heads +
+                    head_idx]
+                : 0.f;
+        __syncthreads();
+      }
+      cp_async::wait_group<num_smem_stages - 1>();
+      __syncthreads();
+
+      vec_t<float, vec_size> v;
+      v.cast_load(v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim +
+                  tx * vec_size);
+      if (iter * bdy + ty < num_index_sets) {
+        float s = s_smem[(iter % bdx) * bdy + ty];
+        st.merge(v, s, 1);
+      }
+
+      // wait for all threads to finish before prefetching the next stage
+      __syncthreads();
+
+      cp_async::pred_load<vec_bits,
+                          PrefetchMode::kPrefetch,
+                          SharedMemFillMode::kNoFill>(
+          v_smem + ((iter % num_smem_stages) * bdy + ty) * head_dim +
+              tx * vec_size,
+          V +
+              ((indptr[pos] + ((iter + num_smem_stages) * bdy + ty)) *
+                   num_heads +
+               head_idx) *
+                  head_dim +
+              tx * vec_size,
+          (iter + num_smem_stages) * bdy + ty < num_index_sets);
+      cp_async::commit_group();
+    }
+    cp_async::wait_group<0>();
+    __syncthreads();
+
+    st.normalize();
+    // synchronize st within the threadblock by reusing the shared memory
+    threadblock_sync_state<bdx, bdy, vec_size>(st, v_smem, s_smem);
+    st.normalize();
+
+    // write back the merged state and lse if needed
+    // v_merged: [n_tokens, n_heads, head_dim]
+    st.o.cast_store(v_merged + (pos * num_heads + head_idx) * head_dim +
+                    tx * vec_size);
+    if (s_merged != nullptr) {
+      // s_merged: [n_tokens, n_heads]
+      s_merged[pos * num_heads + head_idx] = st.get_lse();
+    }
+  }
+}
+
+template <typename DTypeIn, typename DTypeOut, typename IdType>
+cudaError_t VariableLengthMergeStates(DTypeIn* v,
+                                      float* s,
+                                      IdType* indptr,
+                                      DTypeOut* v_merged,
+                                      float* s_merged,
+                                      uint32_t seq_len,
+                                      uint32_t num_heads,
+                                      uint32_t head_dim,
+                                      cudaStream_t stream = nullptr) {
+  int dev_id = 0;
+  int num_sms = 0;
+  int num_blocks_per_sm = 0;
+  FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
+  FLASHINFER_CUDA_CALL(
+      cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));
+
+  DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
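+    // Launch geometry (assumed rationale): vec_size is chosen so that each
+    // thread moves at least 16 bytes per vector access while
+    // bdx = HEAD_DIM / vec_size stays <= 32, i.e. a single warp can span the
+    // head dimension; bdy = 128 / bdx index sets are merged per iteration,
+    // and the dynamic shared memory holds num_smem_stages staging buffers of
+    // bdy rows plus 128 floats for the per-row S values.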
+    constexpr uint32_t vec_size =
+        std::max<uint32_t>(16U / sizeof(DTypeIn), HEAD_DIM / 32U);
+    constexpr uint32_t bdx = HEAD_DIM / vec_size;
+    constexpr uint32_t num_threads = 128;
+    constexpr uint32_t bdy = num_threads / bdx;
+    constexpr uint32_t num_smem_stages = 4;
+    uint32_t smem_size = num_smem_stages * bdy * head_dim * sizeof(DTypeIn) +
+                         num_threads * sizeof(float);
+    auto kernel = PersistentVariableLengthMergeStatesKernel<bdx,
+                                                            bdy,
+                                                            num_smem_stages,
+                                                            vec_size,
+                                                            DTypeIn,
+                                                            DTypeOut,
+                                                            IdType>;
+    FLASHINFER_CUDA_CALL(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+        &num_blocks_per_sm, kernel, num_threads, smem_size));
+    num_blocks_per_sm =
+        min(num_blocks_per_sm, ceil_div(seq_len * num_heads, num_sms));
+
+    dim3 nblks(num_sms * num_blocks_per_sm);
+    dim3 nthrs(bdx, bdy);
+    void* args[] = {
+        &v, &s, &indptr, &v_merged, &s_merged, &seq_len, &num_heads};
+    FLASHINFER_CUDA_CALL(cudaFuncSetAttribute(
+        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
+    FLASHINFER_CUDA_CALL(
+        cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream));
+  });
+  return cudaSuccess;
+}
+
+}  // namespace flashinfer
diff --git a/tests/kernels/attention/flash_infer_kv_fp8_test.py b/tests/kernels/attention/flash_infer_kv_fp8_test.py
index b64e6bb5..25d0a7d0 100644
--- a/tests/kernels/attention/flash_infer_kv_fp8_test.py
+++ b/tests/kernels/attention/flash_infer_kv_fp8_test.py
@@ -84,6 +84,7 @@ def test_flashinfer_varlen_masked_self_attention_fp8_kv(
 
     empty_q_data = torch.empty(0, dtype=dtype)
 
+    num_sm = -1
     wrapper.plan(
         float_workspace_buffer,
         int_workspace_buffer,
@@ -95,6 +96,7 @@ def test_flashinfer_varlen_masked_self_attention_fp8_kv(
         head_size,
         block_size,
         empty_q_data,
+        num_sm,
     )
 
     alibi_slopes = torch.randn(n_heads, dtype=torch.float32) if alibi else None
diff --git a/tests/kernels/attention/flash_infer_test.py b/tests/kernels/attention/flash_infer_test.py
index d57495b3..2a27042e 100644
--- a/tests/kernels/attention/flash_infer_test.py
+++ b/tests/kernels/attention/flash_infer_test.py
@@ -78,6 +78,7 @@ def test_flashinfer_varlen_masked_self_attention(
 
     empty_q_data = torch.empty(0, dtype=dtype)
 
+    num_sm = -1
     wrapper.plan(
         float_workspace_buffer,
         int_workspace_buffer,
@@ -89,6 +90,7 @@ def test_flashinfer_varlen_masked_self_attention(
         head_size,
         block_size,
         empty_q_data,
+        num_sm,
    )
 
     alibi_slopes = torch.randn(n_heads, dtype=torch.float32) if alibi else None
@@ -120,7 +122,7 @@ def test_flashinfer_varlen_masked_self_attention(
         alibi_slopes=alibi_slopes,
     )
 
-    if alibi and dtype == torch.bfloat16:
+    if alibi or dtype == torch.bfloat16:
         torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2)
     else:
         torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2)
diff --git a/tests/kernels/attention/ref_attention.py b/tests/kernels/attention/ref_attention.py
index 2dd1c829..7848d82f 100644
--- a/tests/kernels/attention/ref_attention.py
+++ b/tests/kernels/attention/ref_attention.py
@@ -27,7 +27,7 @@ def masked_self_attention(
         scores += alibi_bias
 
     # apply mask
-    scores.masked_fill_(mask == 0, float("-inf"))
+    scores.masked_fill_(mask == 0, -5e4)
 
     # softmax => [n_heads, q_len, kv_len]
     scores = torch.softmax(scores, dim=-1)