Decode: varlen, paged KV, leftpad
tridao committed Nov 11, 2024
1 parent 26a6e0a commit df96486
Showing 19 changed files with 3,103 additions and 1,368 deletions.
553 changes: 323 additions & 230 deletions hopper/benchmark_attn.py

Large diffs are not rendered by default.

10 changes: 6 additions & 4 deletions hopper/epilogue_bwd_sm90_tma.hpp
@@ -78,7 +78,7 @@ struct CollectiveEpilogueBwd {
cute::array_aligned<Element, cute::cosize_v<SmemLayoutdKV>, SmemAlignmentdKV> smem_dv;
};

- using ShapedKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch)
+ using ShapedKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_k, d, head, batch)
using StridedKV = cute::Stride<int64_t, _1, int64_t, int64_t>;

using TMA_dKV = decltype(make_tma_copy(
@@ -196,8 +196,7 @@ struct CollectiveEpilogueBwd {
if (warp_idx_sync == NumEpilogueThreads / cutlass::NumThreadsPerWarp - 1) {
cutlass::arch::NamedBarrier::sync(NumEpilogueThreads + cutlass::NumThreadsPerWarp,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
- int const lane_predicate = cute::elect_one_sync();
- if (lane_predicate) {
+ if (cute::elect_one_sync()) {
cute::copy(params.tma_store_dV, tdVsdV, tdVgdV);
cute::copy(params.tma_store_dK, tdKsdK, tdKgdK);
tma_store_arrive();
@@ -319,7 +318,7 @@ struct CollectiveEpilogueBwdGQA {
cute::array_aligned<ElementAccum, cute::cosize_v<SmemLayoutdKVaccumTMA>> smem_dkv;
};

- using ShapedKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_q, d, head, batch)
+ using ShapedKV = cute::Shape<int32_t, int32_t, int32_t, int32_t>; // (seqlen_k, d, head, batch)
using StridedKV = cute::Stride<int64_t, _1, int64_t, int64_t>;

using TMA_add_dKV = decltype(make_tma_copy(
@@ -427,6 +426,9 @@ struct CollectiveEpilogueBwdGQA {
auto r2s_thr_copy_dKVaccum = r2s_tiled_copy_dKVaccum.get_thread_slice(thread_idx);
Tensor tdKVsdKVaccum = r2s_thr_copy_dKVaccum.partition_D(sdKV);

+ // Make sure all WGs have finished reading K and V, otherwise we get racy dQ
+ // because smem_q could be changed.
+ cutlass::arch::NamedBarrier::sync(NumEpilogueThreads, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
Tensor taccdKVrdV = r2s_thr_copy_dKVaccum.retile_S(tdVrdV); // ((Atom,AtomNum), MMA_M, MMA_N)
cute::copy(r2s_tiled_copy_dKVaccum, taccdKVrdV, tdKVsdKVaccum);

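The last hunk above adds a NamedBarrier::sync before dK/dV are staged through shared memory; per its own comment, the barrier keeps one warp group from overwriting a buffer that another warp group is still reading. As a minimal illustration of that pattern (a hypothetical stand-alone CUDA kernel, not code from this commit), a shared-memory buffer may only be reused once every thread that still reads the old contents has passed a barrier:

#include <cstdio>
#include <cuda_runtime.h>

// Minimal sketch (hypothetical kernel, not from this commit): a shared-memory
// buffer is filled in phase 1, read by every thread, then reused in phase 2.
// Without the barrier between the read and the overwrite, a fast thread could
// clobber a slot that a slow thread has not read yet, the same kind of race
// the NamedBarrier::sync in the hunk above guards against.
__global__ void reuse_smem(const float* in, float* out, int n) {
    __shared__ float buf[256];
    int tid = threadIdx.x;

    buf[tid] = in[tid];                  // phase 1: fill the buffer
    __syncthreads();                     // everyone sees phase-1 data
    float neighbor = buf[(tid + 1) % n]; // phase 1: read another thread's slot

    __syncthreads();                     // all reads done before reuse
    buf[tid] = neighbor * 2.0f;          // phase 2: safe to overwrite
    __syncthreads();
    out[tid] = buf[tid];
}

int main() {
    const int n = 256;
    float *in, *out;
    cudaMallocManaged(&in, n * sizeof(float));
    cudaMallocManaged(&out, n * sizeof(float));
    for (int i = 0; i < n; ++i) in[i] = float(i);
    reuse_smem<<<1, n>>>(in, out, n);
    cudaDeviceSynchronize();
    printf("out[0] = %f\n", out[0]);     // expect 2 * in[1] = 2.0
    cudaFree(in); cudaFree(out);
    return 0;
}

cutlass named barriers generalize this to a chosen subset of the block's threads, so the epilogue warps only wait for the threads that actually touch the buffer rather than for the whole block.
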
393 changes: 310 additions & 83 deletions hopper/epilogue_fwd_sm90_tma.hpp

Large diffs are not rendered by default.

48 changes: 26 additions & 22 deletions hopper/flash.h
@@ -32,14 +32,12 @@ struct Qkv_params {

// The number of heads.
int h, h_k;
// In the case of multi-query and grouped-query attention (MQA/GQA), nheads_k could be
// different from nheads (query).
int h_h_k_ratio; // precompute h / h_k,
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Flash_fwd_params : public Qkv_params {
using index_t = int64_t;

// The O matrix (output).
void * __restrict__ o_ptr;
@@ -50,37 +48,43 @@ struct Flash_fwd_params : public Qkv_params {
index_t o_row_stride;
index_t o_head_stride;

- // The pointer to the P matrix.
- void * __restrict__ p_ptr;

// The pointer to the softmax sum.
void * __restrict__ softmax_lse_ptr;
+ void * __restrict__ softmax_lseaccum_ptr;

// For FP8 scaling
- float * __restrict__ q_scale_ptr;
- float * __restrict__ k_scale_ptr;
- float * __restrict__ v_scale_ptr;
+ float * __restrict__ q_descale_ptr;
+ float * __restrict__ k_descale_ptr;
+ float * __restrict__ v_descale_ptr;

// The dimensions.
int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
int total_q, total_k;
+ int b_k; // When having KV cache and with cache_batch_idx, K & V might have larger batch size than Q

// The scaling factors for the kernel.
float scale_softmax;
float scale_softmax_log2;
uint32_t scale_softmax_log2_half2;
float softcap;

// array of length b+1 holding starting offset of each sequence.
int * __restrict__ cu_seqlens_q;
int * __restrict__ cu_seqlens_k;
+ int * __restrict__ leftpad_k;

// If provided, the actual length of each q/k sequence.
int *__restrict__ seqused_q;
int *__restrict__ seqused_k;

int *__restrict__ blockmask;
+ // The stride between rows of Oaccum.
+ index_t oaccum_split_stride;
+ index_t oaccum_batch_stride;
+ index_t oaccum_row_stride;
+ index_t oaccum_head_stride;

+ // The stride between rows of LSEaccum.
+ index_t lseaccum_split_stride;
+ index_t lseaccum_batch_stride;
+ index_t lseaccum_head_stride;

// The K_new and V_new matrices.
void * __restrict__ knew_ptr;
@@ -99,12 +103,13 @@ struct Flash_fwd_params : public Qkv_params {
void * __restrict__ rotary_sin_ptr;

// The indices to index into the KV cache.
- int * __restrict__ cache_batch_idx;
+ int * __restrict__ kv_batch_idx;

// Paged KV cache
- int * __restrict__ block_table;
- index_t block_table_batch_stride;
- int page_block_size;
+ int * __restrict__ page_table;
+ index_t page_table_batch_stride;
+ int page_size;
+ int num_pages;

// The dropout probability (probability of keeping an activation).
float p_dropout;
@@ -114,15 +119,16 @@ struct Flash_fwd_params : public Qkv_params {

// Scale factor of 1 / (1 - p_dropout).
float rp_dropout;
float scale_softmax_rp_dropout;

// Local window size
int window_size_left, window_size_right;
int sink_token_length;

// Pointer to the RNG seed (idx 0) and offset (idx 1).
uint64_t * rng_state;

bool is_bf16;
bool is_fp32;
bool is_e4m3;
bool is_causal;
bool is_local;
@@ -134,16 +140,15 @@ struct Flash_fwd_params : public Qkv_params {
bool is_rotary_interleaved;

int num_splits; // For split-KV version

- void * __restrict__ alibi_slopes_ptr;
- index_t alibi_slopes_batch_stride;
+ int pack_gqa; // 0: no packing, 1: pack GQA, -1: use heuristic to decide

int * __restrict__ tile_count_semaphore;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Flash_bwd_params : public Flash_fwd_params {
using index_t = int64_t;

// The dO and dQKV matrices.
void *__restrict__ do_ptr;
@@ -161,8 +166,6 @@ struct Flash_bwd_params : public Flash_fwd_params {
// dv_accum_ptr;

// The stride between rows of the dO, dQ, dK and dV matrices.
- // TD [2022-04-16]: We're using 32-bit indexing to save registers.
- // The code probably won't work for arrays larger than 2GB.
index_t do_batch_stride;
index_t do_row_stride;
index_t do_head_stride;
@@ -192,3 +195,4 @@ struct Flash_bwd_params : public Flash_fwd_params {

template<typename T, int Headdim> void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);
template<typename T, int Headdim> void run_mha_bwd_(Flash_bwd_params &params, cudaStream_t stream);
+ template<typename T, int Headdim> void run_mha_fwd_combine_(Flash_fwd_params &params, cudaStream_t stream);
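
The new run_mha_fwd_combine_ declaration and the Oaccum/LSEaccum strides above are the visible surface of a split-KV path: each split processes a slice of the KV sequence, writes a partial output plus its log-sum-exp, and a combine kernel merges them. The snippet below only illustrates the underlying math (host-side C++ compiled as CUDA, with made-up names such as combine_splits; it is not the repository's combine kernel): each split's partial output is weighted by exp(lse_i - lse_total).

#include <algorithm>
#include <cmath>
#include <vector>

// Illustration only: merge per-split partial attention outputs using their
// log-sum-exp (LSE) values. Split i contributes weight exp(lse_i - lse_total),
// where lse_total = log(sum_i exp(lse_i)). This is the standard math behind a
// split-KV "combine" step; the actual run_mha_fwd_combine_ kernel is not
// reproduced here.
std::vector<float> combine_splits(const std::vector<std::vector<float>>& o_partial,
                                  const std::vector<float>& lse_partial,
                                  int head_dim) {
    int num_splits = static_cast<int>(lse_partial.size());

    // lse_total = lse_max + log(sum_i exp(lse_i - lse_max)), computed stably.
    float lse_max = lse_partial[0];
    for (float l : lse_partial) lse_max = std::max(lse_max, l);
    float sum_exp = 0.f;
    for (float l : lse_partial) sum_exp += std::exp(l - lse_max);
    float lse_total = lse_max + std::log(sum_exp);

    std::vector<float> o(head_dim, 0.f);
    for (int i = 0; i < num_splits; ++i) {
        float w = std::exp(lse_partial[i] - lse_total);
        for (int d = 0; d < head_dim; ++d) o[d] += w * o_partial[i][d];
    }
    return o;
}

Carrying the per-split LSE instead of the raw softmax denominator keeps the exponentials in range, which is why LSEaccum is stored alongside Oaccum.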
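
The flash.h diff also renames the paged-KV fields (block_table to page_table, page_block_size to page_size) and adds num_pages and leftpad_k. As a rough sketch of how such a layout is typically addressed (the struct and function below are hypothetical; only the field names are borrowed from Flash_fwd_params, and the indexing is an assumption about the usual paged-cache scheme, not code from the kernels):

#include <cstdint>

// Hypothetical host-side sketch of paged-KV addressing. Each batch entry owns
// one row of page_table; logical token position `pos` of that sequence lives
// on page pos / page_size at offset pos % page_size, and with leftpad_k the
// first valid token of sequence b is shifted right by leftpad_k[b] inside the
// cache. num_pages would bound the physical page indices stored in the table.
struct PagedKvView {
    const int* page_table;            // [batch, max_pages_per_seq]
    int64_t page_table_batch_stride;  // row stride of page_table
    int page_size;                    // tokens per page
    const int* leftpad_k;             // optional per-batch left padding (may be nullptr)
};

// Returns the physical page index and in-page offset for token `pos` of batch `b`.
inline void locate_kv_token(const PagedKvView& kv, int b, int pos,
                            int* page, int* offset) {
    int shifted = pos + (kv.leftpad_k ? kv.leftpad_k[b] : 0);
    int logical_page = shifted / kv.page_size;
    *page = kv.page_table[b * kv.page_table_batch_stride + logical_page];
    *offset = shifted % kv.page_size;
}

With a mapping like this, a sequence's KV cache no longer needs to be contiguous; only its row of the page table does, which is what makes growing the cache during decoding cheap.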