From 91f7c711177ac1e9dea436ff967ef49233e9bac5 Mon Sep 17 00:00:00 2001
From: Michael Mi
Date: Fri, 20 Sep 2024 19:50:26 -0400
Subject: [PATCH] use all warps to write o

---
 .../attention/flash_infer/attention_kernel.h | 79 ++++++++++---------
 1 file changed, 43 insertions(+), 36 deletions(-)

diff --git a/src/kernels/attention/flash_infer/attention_kernel.h b/src/kernels/attention/flash_infer/attention_kernel.h
index 2e152af1..2a90c7da 100644
--- a/src/kernels/attention/flash_infer/attention_kernel.h
+++ b/src/kernels/attention/flash_infer/attention_kernel.h
@@ -200,7 +200,7 @@ __device__ __forceinline__ void load_q_global_smem(
   // | t24 | t25 | t26 | t27 | t28 | t29 | t30 | t31 |
   //
-  // q_smem: [num_iters_m, num_warps_m, 16, head_dim]
+  // q_smem: [num_warps_m, num_iters_m, 16, head_dim]
   uint32_t q_smem_x = warp_idx_x * num_iters_m * 16 + lane_idx / 8;
 
 #pragma unroll
@@ -212,9 +212,9 @@ __device__ __forceinline__ void load_q_global_smem(
             packed_offset + fx * 16 + lane_idx / 8 + j * 4;
         uint32_t q, r;
         group_size.divmod(packed_q_idx, q, r);
-        // q_idx = packed_q_idx / group_size
-        // h_idx = packed_q_idx % group_size
-        const uint32_t q_idx = q;
+        if (q >= qo_upper_bound) {
+          continue;
+        }
         // q_ptr_base: [n_tokens, n_heads, head_dim]
         // q_ptr for a given head: [head_dim]
         DTypeQ* q_ptr = q_ptr_base + q * q_stride_n + r * q_stride_h +
@@ -223,24 +223,18 @@ __device__ __forceinline__ void load_q_global_smem(
        // load head_dim from global memory to shared memory using num_warps_n
        // warps; each warp loads 8 columns at a time
        uint32_t q_smem_y = warp_idx_z * 8 + lane_idx % 8;
-#pragma unroll
-       // for (uint32_t fyo = 0; fyo < num_iters_k / 4 / num_warps_n ; ++fyo) {
-       while(q_smem_y * num_elems_per_128b() < head_dim) {
+       while (q_smem_y * num_elems_per_128b() < head_dim) {
          const uint32_t q_smem_offset_w =
              q_smem->template get_permuted_offset(q_smem_x, q_smem_y);
          // load q fragment from gmem to smem
-         q_smem->load_128b_async(
-             q_smem_offset_w, q_ptr, q_idx < qo_upper_bound);
-         // move ahead by 8 int128_t for each warp
+         q_smem->load_128b_async(q_smem_offset_w, q_ptr);
+         // move ahead by 8 * int128_t for each warp
          q_smem_y += (8 * num_warps_n);
-
-         // move ahead by 8 * 16 bytes
          q_ptr += (8 * num_elems_per_128b() * num_warps_n);
        }
-       // adjust offset for next iteration
        // move ahead by 4 rows
        q_smem_x += 4;
      }
@@ -810,9 +804,13 @@ __device__ __forceinline__ void write_o_reg_gmem(
   constexpr uint32_t channel_size_128b_out =
       head_dim / num_elems_per_128b();
   const uint32_t warp_idx_x = get_warp_idx_x();
+  const uint32_t warp_idx_z = get_warp_idx_z();
   const uint32_t lane_idx = threadIdx.x;
 
+  // o_frag: [num_iters_m, num_iters_k, 8]
   if (get_warp_idx_z() == 0) {
+    // write o from register to shared memory
+    // why not have every thread write to shared memory?
 #pragma unroll
     for (uint32_t fx = 0; fx < num_iters_m; ++fx) {
 #pragma unroll
@@ -832,6 +830,7 @@ __device__ __forceinline__ void write_o_reg_gmem(
                 (warp_idx_x * num_iters_m + fx) * 16 + lane_idx / 4, fy * 2);
         ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % 4] =
             o_frag_f16[0];
+        // TODO: avoid manipulating permuted offset directly
         ((uint32_t*)(o_smem->base + o_smem_offset_w +
                      8 * channel_size_128b_out))[lane_idx % 4] = o_frag_f16[1];
         ((uint32_t*)(o_smem->base + (o_smem_offset_w ^ 0x1)))[lane_idx % 4] =
@@ -841,34 +840,42 @@ __device__ __forceinline__ void write_o_reg_gmem(
 #endif
       }
     }
+  }
 
-  uint32_t o_smem_offset_w =
-      o_smem->get_permuted_offset(
-          warp_idx_x * num_iters_m * 16 + lane_idx / 8, lane_idx % 8);
+  // write o from shared memory to global memory
+  // o_smem: [num_warps_m, num_iters_m, 16, head_dim]
+  uint32_t o_smem_x = warp_idx_x * num_iters_m * 16 + lane_idx / 8;
 #pragma unroll
-  for (uint32_t fx = 0; fx < num_iters_m; ++fx) {
-#pragma unroll
-    for (uint32_t j = 0; j < 4; ++j) {
-      uint32_t q, r;
-      group_size.divmod(
-          o_packed_idx_base + lane_idx / 8 + fx * 16 + j * 4, q, r);
-      const uint32_t o_idx = q;
-      DTypeOut* o_ptr = o_ptr_base + q * o_stride_n + r * o_stride_h;
+  for (uint32_t fx = 0; fx < num_iters_m; ++fx) {
+    // each warp writes 4 rows; 16 rows take 4 (16/4) iterations
 #pragma unroll
-      for (uint32_t fyo = 0; fyo < num_iters_k / 4; ++fyo) {
-        if (o_idx < qo_upper_bound) {
-          o_smem->store_128b(o_smem_offset_w, o_ptr);
-        }
-        o_ptr += 8 * num_elems_per_128b();
-        o_smem_offset_w = o_smem->template advance_offset_by_column<8>(
-            o_smem_offset_w, fyo);
-      }
-      o_smem_offset_w =
-          o_smem->template advance_offset_by_row<4, channel_size_128b_out>(
-              o_smem_offset_w) -
-          2 * num_iters_k;
+    for (uint32_t j = 0; j < 4; ++j) {
+      const uint32_t packed_o_idx =
+          o_packed_idx_base + lane_idx / 8 + fx * 16 + j * 4;
+      uint32_t q, r;
+      group_size.divmod(packed_o_idx, q, r);
+      // skip if out of bounds
+      if (q >= qo_upper_bound) {
+        continue;
+      }
+
+      DTypeOut* o_ptr = o_ptr_base + q * o_stride_n + r * o_stride_h +
+                        warp_idx_z * 8 * num_elems_per_128b();
+      uint32_t o_smem_y = warp_idx_z * 8 + lane_idx % 8;
+      // write head_dim from shared memory to global memory
+      while (o_smem_y * num_elems_per_128b() < head_dim) {
+        const uint32_t o_smem_offset_w =
+            o_smem->template get_permuted_offset(o_smem_x, o_smem_y);
+        o_smem->store_128b(o_smem_offset_w, o_ptr);
+
+        // move ahead by 8 * int128_t for each warp
+        o_smem_y += (8 * num_warps_n);
+        o_ptr += 8 * num_elems_per_128b() * num_warps_n;
      }
+      // move down by 4 rows
+      o_smem_x += 4;
    }
  }
 }
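
Note (not part of the patch): both hunks rely on group_size.divmod(packed_idx, q, r) to split a packed row index into a token index q and a head index r within the query group, and the patch now skips out-of-bound rows with an early continue instead of predicating load_128b_async/store_128b on q_idx < qo_upper_bound. The small host-side sketch below mirrors that arithmetic with plain division and modulo; the concrete values for group_size and qo_upper_bound are assumptions for illustration and nothing here is code from the kernel.

// illustration only: plain division/modulo stand-in for group_size.divmod()
#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t group_size = 4;      // assumed: query heads per kv head
  const uint32_t qo_upper_bound = 3;  // assumed: number of valid query rows

  for (uint32_t packed_idx = 0; packed_idx < 16; ++packed_idx) {
    const uint32_t q = packed_idx / group_size;  // q_idx: token index
    const uint32_t r = packed_idx % group_size;  // h_idx: head index in group
    if (q >= qo_upper_bound) {
      continue;  // same early skip the patch adds for out-of-bound rows
    }
    std::printf("packed %2u -> token %u, head %u\n", packed_idx, q, r);
  }
  return 0;
}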
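Note (not part of the patch): the core change is that the shared-to-global write of o now uses the same all-warp column striding as the q load. Warp warp_idx_z starts at column chunk warp_idx_z * 8 + lane_idx % 8 and strides by 8 * num_warps_n 128-bit chunks until head_dim is covered, while lane_idx / 8 selects one of four rows per step. The kernel below is a simplified standalone sketch of that indexing over a plain row-major tile; the kernel name, template parameters, and launch configuration are assumptions, and the real kernel goes through the permuted (swizzled) shared-memory offset and store_128b/load_128b_async rather than direct array indexing.

// illustration only: all warps cooperate on one tile, mirroring the diff's indexing
#include <cstdint>
#include <cuda_runtime.h>

template <uint32_t kNumWarpsN, uint32_t kNumRows, uint32_t kHeadDim>
__global__ void copy_tile_all_warps(const int4* __restrict__ src,
                                    int4* __restrict__ dst) {
  constexpr uint32_t kElemsPer128b = 8;  // 16-bit elements per 128-bit chunk
  constexpr uint32_t kChunksPerRow = kHeadDim / kElemsPer128b;
  const uint32_t lane_idx = threadIdx.x % 32;
  const uint32_t warp_idx_z = threadIdx.x / 32;  // which warp along head_dim

  // lane_idx / 8 picks one of the 4 rows handled per step,
  // lane_idx % 8 picks the 128-bit chunk inside this warp's 8-chunk slice
  for (uint32_t x = lane_idx / 8; x < kNumRows; x += 4) {
    uint32_t y = warp_idx_z * 8 + lane_idx % 8;
    while (y * kElemsPer128b < kHeadDim) {
      dst[x * kChunksPerRow + y] = src[x * kChunksPerRow + y];
      // together the warps cover 8 * kNumWarpsN chunks per step
      y += 8 * kNumWarpsN;
    }
  }
}

// example launch: 2 cooperating warps copy a 16 x 128 tile of 16-bit elements
// copy_tile_all_warps<2, 16, 128><<<1, 64>>>(src, dst);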