Skip to content

Commit

Permalink
use all warps to load q
Browse files Browse the repository at this point in the history
  • Loading branch information
guocuimi committed Sep 20, 2024
1 parent 998fbab commit 2cdb3cb
Showing 1 changed file with 57 additions and 56 deletions.
113 changes: 57 additions & 56 deletions src/kernels/attention/flash_infer/attention_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ __device__ __forceinline__ void init_states(float (*o_frag)[num_iters_k][8],
for (uint32_t fy = 0; fy < num_iters_k; ++fy) {
#pragma unroll
for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) {
// o_frag: [num_iters_m, num_iters_k, 8]
o_frag[fx][fy][reg_id] = 0.f;
}
}
Expand All @@ -163,6 +164,7 @@ __device__ __forceinline__ void init_states(float (*o_frag)[num_iters_k][8],
for (uint32_t fx = 0; fx < num_iters_m; ++fx) {
#pragma unroll
for (uint32_t j = 0; j < 2; ++j) {
// m/d: [num_iters_m, 2]
m[fx][j] = DTypeQKAccum(-5e4);
d[fx][j] = 1.f;
}
Expand All @@ -188,62 +190,59 @@ __device__ __forceinline__ void load_q_global_smem(
head_dim / num_elems_per_128b<DTypeQ>();
const uint32_t lane_idx = threadIdx.x;
const uint32_t warp_idx_x = get_warp_idx_x<num_warps_m, num_warps_n>();
// only let first column warps load q
// TODO: let all warps load q
if (get_warp_idx_z<num_warps_m, num_warps_n>() == 0) {
// rows to load: num_iters_m * 16
// threads layout in warp: 4 x 8
// | t0 | t1 | t2 | t3 | t4 | t5 | t6 | t7 |
// | t8 | t9 | t10 | t11 | t12 | t13 | t14 | t15 |
// | t16 | t17 | t18 | t19 | t20 | t21 | t22 | t23 |
// | t24 | t25 | t26 | t27 | t28 | t29 | t30 | t31 |
//
const uint32_t warp_idx_z = get_warp_idx_z<num_warps_m, num_warps_n>();
// let all warps load q
// rows to load: num_iters_m * 16
// threads layout in warp: 4 x 8
// | t0 | t1 | t2 | t3 | t4 | t5 | t6 | t7 |
// | t8 | t9 | t10 | t11 | t12 | t13 | t14 | t15 |
// | t16 | t17 | t18 | t19 | t20 | t21 | t22 | t23 |
// | t24 | t25 | t26 | t27 | t28 | t29 | t30 | t31 |
//

// q_smem: [num_iters_m, num_warps_m, 16, head_dim]
uint32_t q_smem_x = warp_idx_x * num_iters_m * 16 + lane_idx / 8;
uint32_t q_smem_y = lane_idx % 8;
// q_smem: [num_iters_m, num_warps_m, 16, head_dim]
uint32_t q_smem_x = warp_idx_x * num_iters_m * 16 + lane_idx / 8;

#pragma unroll
for (uint32_t fx = 0; fx < num_iters_m; ++fx) {
// each wrap loads 4 rows, loading 16 rows needs 4(16/4) iters
#pragma unroll
for (uint32_t j = 0; j < 4; ++j) {
const uint32_t packed_q_idx =
packed_offset + fx * 16 + lane_idx / 8 + j * 4;
uint32_t q, r;
group_size.divmod(packed_q_idx, q, r);
// q_idx = packed_q_idx / group_size
// h_idx = packed_q_idx % group_size
const uint32_t q_idx = q;
// q_ptr_base: [n_tokens, n_heads, head_dim]
// q_ptr for given header: [head_dim]
DTypeQ* q_ptr = q_ptr_base + q * q_stride_n + r * q_stride_h;

// load head_dim from global memory to shared memory using 8 threads
// 8 threads load 8 * 16 bytes columns once
// iters: head_dim * 2 / (8 * 16) = head_dim / 16 / 4 = num_iters_k / 4
#pragma unroll
for (uint32_t fyo = 0; fyo < num_iters_k / 4; ++fyo) {
const uint32_t q_smem_offset_w =
q_smem->get_permuted_offset<channel_size_128b_q>(q_smem_x,
q_smem_y);

// load q fragment from gmem to smem
q_smem->load_128b_async<SharedMemFillMode::kNoFill>(
q_smem_offset_w, q_ptr, q_idx < qo_upper_bound);
// move ahead by 8 int128_t
q_smem_y += 8;

// move ahead by 8 * 8 items
q_ptr += 8 * num_elems_per_128b<DTypeQ>();
}

// adjust offset for next iteration
// move row by 4
// move columns by -num_iters_k / 4 * 8 = -num_iters_k * 2
q_smem_x += 4;
q_smem_y -= num_iters_k * 2;
for (uint32_t fx = 0; fx < num_iters_m; ++fx) {
// each wrap loads 4 rows, loading 16 rows needs 4 iters
#pragma unroll
for (uint32_t j = 0; j < 4; ++j) {
const uint32_t packed_q_idx =
packed_offset + fx * 16 + lane_idx / 8 + j * 4;
uint32_t q, r;
group_size.divmod(packed_q_idx, q, r);
// q_idx = packed_q_idx / group_size
// h_idx = packed_q_idx % group_size
const uint32_t q_idx = q;
// q_ptr_base: [n_tokens, n_heads, head_dim]
// q_ptr for given header: [head_dim]
DTypeQ* q_ptr = q_ptr_base + q * q_stride_n + r * q_stride_h +
warp_idx_z * 8 * num_elems_per_128b<DTypeQ>();

// load head_dim from global memory to shared memory using num_warps_n
// warps echa warp loads 8 columns once
uint32_t q_smem_y = warp_idx_z * 8 + lane_idx % 8;
#pragma unroll
// for (uint32_t fyo = 0; fyo < num_iters_k / 4 / num_warps_n ; ++fyo) {
while(q_smem_y * num_elems_per_128b<DTypeQ>() < head_dim) {
const uint32_t q_smem_offset_w =
q_smem->template get_permuted_offset<channel_size_128b_q>(q_smem_x,
q_smem_y);

// load q fragment from gmem to smem
q_smem->load_128b_async<SharedMemFillMode::kNoFill>(
q_smem_offset_w, q_ptr, q_idx < qo_upper_bound);
// move ahead by 8 int128_t for each warp
q_smem_y += (8 * num_warps_n);

// move ahead by 8 * 16 bytes
q_ptr += (8 * num_elems_per_128b<DTypeQ>() * num_warps_n);
}

// adjust offset for next iteration
// move ahead by 4 rows
q_smem_x += 4;
}
}
}
Expand Down Expand Up @@ -876,7 +875,7 @@ __device__ __forceinline__ void write_o_reg_gmem(

} // namespace

// dim3 nblks(n_splits, 1, num_kv_heads);
// dim3 nblks(n_splits, num_kv_heads);
// dim3 nthrs(32, num_warps_m, num_warps_n);
template <LogitsPostHook logits_post_hook,
MaskMode mask_mode,
Expand Down Expand Up @@ -930,11 +929,12 @@ __launch_bounds__(num_warps_m* num_warps_n* warp_size) void attention_kernel(
return;
}

const uint32_t kv_head_idx = blockIdx.z;
const uint32_t kv_head_idx = blockIdx.y;
const uint32_t q_head_idx_base = kv_head_idx * group_size;
const uint32_t lane_idx = threadIdx.x;
const uint32_t warp_idx = get_warp_idx<num_warps_m, num_warps_n>();
const uint32_t num_kv_heads = gridDim.z;

const uint32_t num_kv_heads = gridDim.y;
const uint32_t num_qo_heads = num_kv_heads * group_size;

const uint32_t request_idx = request_indices[bx];
Expand Down Expand Up @@ -992,6 +992,7 @@ __launch_bounds__(num_warps_m* num_warps_n* warp_size) void attention_kernel(
const uint32_t q_stride_n = num_qo_heads * head_dim;
const uint32_t q_stride_h = head_dim;

// TODO: move this into load_q_global_smem
// [n_tokens, n_heads, head_dim]
DTypeQ* q_ptr_base =
q + get_elem_offset_impl(
Expand Down Expand Up @@ -1399,7 +1400,7 @@ cudaError_t mha_varlen_dispatch(DTypeQ* q,
return cudaSuccess;
}

dim3 nblks(padded_batch_size, 1, num_kv_heads);
dim3 nblks(padded_batch_size, num_kv_heads);
dim3 nthrs(32, num_warps_m, num_warps_n);

constexpr uint32_t num_iters_k = HEAD_DIM / 16;
Expand Down

0 comments on commit 2cdb3cb

Please sign in to comment.