Extend FlashAttention Prefill with KV cache #318


Open
wants to merge 21 commits into sycl-develop from minjean/extend_attention_prefill

Commits (21)
b8c4928
Add seq_len_kv_cache
min-jean-cho Apr 18, 2025
fcde843
Add KV cache and update verify kernel
min-jean-cho Apr 18, 2025
359dfa5
Update mmaQK, mmaPV to handle KV cache, new
min-jean-cho Apr 18, 2025
180574e
Correct verify kernel with KV cache
min-jean-cho Apr 19, 2025
cac41c2
Fix causal mask when KV cache
min-jean-cho Apr 19, 2025
9403e6d
Update applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp
min-jean-cho Apr 21, 2025
f413f9c
Update applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp
min-jean-cho Apr 21, 2025
9c02255
Update applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp
min-jean-cho Apr 21, 2025
cfe8ffd
Merge branch 'sycl-develop' into minjean/extend_attention_prefill
min-jean-cho Apr 21, 2025
748815c
Minor update
min-jean-cho Apr 21, 2025
72edfc0
Update flops, gbps calculation
min-jean-cho Apr 22, 2025
6735cfd
Fix verify when num_heads_kv != num_heads_q
min-jean-cho Apr 22, 2025
c9d11ea
Fixing the index for launch
mehdi-goli Apr 22, 2025
e7505c8
Adding prefetch to the extend version. Fixing the stride for the vari…
mehdi-goli Apr 22, 2025
e666342
Update applications/flash_attention_v2/kernel/tile_scheduler.hpp
min-jean-cho Apr 24, 2025
31e5580
Update applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp
min-jean-cho Apr 24, 2025
21056aa
Update applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp
min-jean-cho Apr 24, 2025
31b9aa9
Update applications/flash_attention_v2/kernel/xe_flash_attn_gemm.hpp
min-jean-cho Apr 24, 2025
83bf043
Update applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp
min-jean-cho Apr 24, 2025
5620c38
Applyig the comments
mehdi-goli Apr 25, 2025
e5fbdbb
Update xe_flash_attn_gemm.hpp
mehdi-goli Apr 25, 2025
@@ -116,7 +116,7 @@ class CollectiveEpilogueAttention<epilogue::IntelPVCEpilogue, CtaTileMNK_, Eleme
template <class ProblemShape>
static constexpr Params to_underlying_arguments(ProblemShape const &problem_shape, Arguments const &args,
[[maybe_unused]] void *workspace) {
auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo] = problem_shape;
auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape;

auto tensorO = make_tensor(make_gmem_ptr(static_cast<ElementO const*>(args.ptr_O)),
make_layout(make_shape(seq_len_qo, head_size_vo, batch * num_heads_q),
@@ -179,7 +179,7 @@ class CollectiveEpilogueAttention<epilogue::IntelPVCEpilogue, CtaTileMNK_, Eleme
}

// Indexing variables
auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo] = problem_shape;
auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape;
// Represent the full output tensor
Tensor mO_mnl = cute::get_pvc_tensor(make_shape(seq_len_qo, head_size_vo, (is_var_len ? batch : 1) * num_heads_q));

@@ -204,8 +204,8 @@ class CollectiveEpilogueAttention<epilogue::IntelPVCEpilogue, CtaTileMNK_, Eleme
if constexpr (!VarLen) {
return params;
} else {
auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo] = problem_shape;

auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape;
auto qo_cumulative_length = get<3>(problem_shape).cumulative_length;
int offset_o = num_heads_q * head_size_vo * qo_cumulative_length[l_coord];
auto store_traits = static_cast<traits_store_O const&>(params.xe_store_o);
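The epilogue change is purely about problem-shape arity: every destructuring now expects an extra seq_len_kv_cache element between seq_len_kv and head_size_qk. A minimal sketch of what that shape looks like on the host (illustrative only; the concrete ProblemShape type used by the example driver is not part of this diff, and for variable-length runs the sequence-length fields carry cumulative-length metadata):

    // Hypothetical host-side shape: one extra element, seq_len_kv_cache,
    // inserted between seq_len_kv and head_size_qk.
    auto problem_shape = cute::make_tuple(batch, num_heads_q, num_heads_kv,
                                          seq_len_qo, seq_len_kv, seq_len_kv_cache,
                                          head_size_qk, head_size_vo);
    // All consumers (epilogue, mainloop, tile scheduler) now destructure eight fields:
    auto [b, h_q, h_kv, s_qo, s_kv, s_kv_cache, d_qk, d_vo] = problem_shape;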
68 changes: 51 additions & 17 deletions applications/flash_attention_v2/collective/xe_flash_attn_mma.hpp
@@ -154,7 +154,8 @@ struct CollectiveMmaAttention<gemm::MainloopIntelPVC<Stages>, ProblemShapeType_,
using atom_load_V = Copy_Atom<traits_load_V, ElementV>;
using val_layout_load_V = decltype(make_layout(shape_div(typename traits_load_V::BlockShape{}, CopyThreadShape{})));
using XE_Copy_V = decltype(make_tiled_copy(atom_load_V{}, Layout<CopyThreadShape>{}, val_layout_load_V{}));

using TensorK = decltype(make_tensor(make_gmem_ptr(static_cast<ElementK const*>(nullptr)), make_shape(0,0,0), StrideK{})); //(m, k)
using TensorV = decltype(make_tensor(make_gmem_ptr(static_cast<ElementV const*>(nullptr)), make_shape(0,0,0), StrideV{})); //(n, k)
// Host side kernel arguments
struct Arguments {
ElementQ const *ptr_Q;
Expand All @@ -163,12 +164,18 @@ struct CollectiveMmaAttention<gemm::MainloopIntelPVC<Stages>, ProblemShapeType_,
StrideK dK;
ElementV const *ptr_V;
StrideV dV;
ElementK const* ptr_K_cache;
StrideK dK_cache;
ElementV const* ptr_V_cache;
StrideV dV_cache;
};

struct Params {
XE_Copy_Q gmem_tiled_copy_q;
XE_Copy_K gmem_tiled_copy_k;
XE_Copy_V gmem_tiled_copy_v;
XE_Copy_K gmem_tiled_copy_k_cache;
XE_Copy_V gmem_tiled_copy_v_cache;
};

//
@@ -181,25 +188,32 @@ struct CollectiveMmaAttention<gemm::MainloopIntelPVC<Stages>, ProblemShapeType_,
void *workspace) {
(void)workspace;

auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo] = problem_shape;
auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape;

auto tensorQ = make_tensor(make_gmem_ptr(args.ptr_Q), make_layout(make_shape(seq_len_qo, head_size_qk, batch * num_heads_q), args.dQ));
auto tensorK = make_tensor(make_gmem_ptr(args.ptr_K), make_layout(make_shape(seq_len_kv, head_size_qk, batch * num_heads_kv), args.dK));
auto tensorV = make_tensor(make_gmem_ptr(args.ptr_V), make_layout(make_shape(head_size_vo, seq_len_kv, batch * num_heads_kv), args.dV));
auto tensorK_cache = make_tensor(make_gmem_ptr(args.ptr_K_cache), make_layout(make_shape(seq_len_kv_cache, head_size_qk, batch * num_heads_kv), args.dK_cache));
auto tensorV_cache = make_tensor(make_gmem_ptr(args.ptr_V_cache), make_layout(make_shape(head_size_vo, seq_len_kv_cache, batch * num_heads_kv), args.dV_cache));

XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)};
XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)};
XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)};

return Params{copyQ, copyK, copyV};
XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)};
XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)};

return Params{copyQ, copyK, copyV, copyK_cache, copyV_cache};
}

template <class FragQccum, class TensorQ, class TensorK, class FragSrc>
CUTLASS_DEVICE void mmaQK(FragQccum &accum, TensorQ gQ, TensorK gK, FragSrc const &frag_src,
int const &k_tile_count, Params const &params) {
int const &k_tile_count, Params const &params, bool is_KV_cache) {

auto& gmem_tiled_copy_k = is_KV_cache ? params.gmem_tiled_copy_k_cache : params.gmem_tiled_copy_k;

int thread_idx = static_cast<int>(ThreadIdxX());
auto thr_copy_Q = params.gmem_tiled_copy_q.get_slice(thread_idx);
auto thr_copy_K = params.gmem_tiled_copy_k.get_slice(thread_idx);
auto thr_copy_K = gmem_tiled_copy_k.get_slice(thread_idx);
// Instantiate the MMA object
TiledMmaK tiled_mma_k;
TiledMmaQVO tiled_mma_q;
@@ -216,7 +230,7 @@ struct CollectiveMmaAttention<gemm::MainloopIntelPVC<Stages>, ProblemShapeType_,
// Create fragments
// TODO(Codeplay): fix this, this is probably not general
Tensor tCrQ = make_tensor<ElementQ>(make_fragment_layout(params.gmem_tiled_copy_q, take<0,3>(tCgQ.shape())));
Tensor tCrK = make_tensor<ElementK>(make_fragment_layout(params.gmem_tiled_copy_k, take<0,3>(tCgK.shape())));
Tensor tCrK = make_tensor<ElementK>(make_fragment_layout(gmem_tiled_copy_k, take<0,3>(tCgK.shape())));

// Retile registers for copies
Tensor tQrQ = thr_copy_Q.retile_D(tCrQ);
@@ -254,14 +268,16 @@ struct CollectiveMmaAttention<gemm::MainloopIntelPVC<Stages>, ProblemShapeType_,

for (int k_tile = 0; k_tile < k_tile_count; ++k_tile) {
copy(params.gmem_tiled_copy_q, tQgQ(_,_,_,k_tile), tQrQ);
copy(params.gmem_tiled_copy_k, tKgK(_,_,_,k_tile), tKrK);
copy(gmem_tiled_copy_k, tKgK(_,_,_,k_tile), tKrK);
cute::gemm(tiled_mma_q, accum, tCrQ, tCrK, frag_src);
}
}

template <class FragQccum, class FragS, class TensorV, class FragSrc>
CUTLASS_DEVICE void mmaPV(FragQccum &accum, FragS const &tSr, TensorV gV,
FragSrc const &frag_src, Params const &params) {
FragSrc const &frag_src, Params const &params, bool is_KV_cache) {

auto& gmem_tiled_copy_v = is_KV_cache ? params.gmem_tiled_copy_v_cache : params.gmem_tiled_copy_v;

int thread_idx = static_cast<int>(ThreadIdxX());
// Instantiate the MMA object
@@ -270,10 +286,10 @@ struct CollectiveMmaAttention<gemm::MainloopIntelPVC<Stages>, ProblemShapeType_,
auto first_thread_in_sg_idx = sg.get_group_id()[0] * DispatchPolicy::SubgroupSize;
auto thread_mma = tiled_mma.get_slice(first_thread_in_sg_idx);
Tensor tCgV = thread_mma.partition_B(gV);
Tensor tCrV = make_tensor<ElementV>(make_fragment_layout(params.gmem_tiled_copy_v, tCgV.shape()));
Tensor tCrV = make_tensor<ElementV>(make_fragment_layout(gmem_tiled_copy_v, tCgV.shape()));

// Partition the copying of A and B tiles across the threads
auto gmem_thr_copy_V = params.gmem_tiled_copy_v.get_slice(thread_idx);
auto gmem_thr_copy_V = gmem_tiled_copy_v.get_slice(thread_idx);
Tensor tVrV = gmem_thr_copy_V.retile_D(tCrV);
Tensor tVgV = gmem_thr_copy_V.retile_S(tCgV);

@@ -299,7 +315,7 @@ struct CollectiveMmaAttention<gemm::MainloopIntelPVC<Stages>, ProblemShapeType_,
//
// Mainloop
//
copy(params.gmem_tiled_copy_v, tVgV, tVrV);
copy(gmem_tiled_copy_v, tVgV, tVrV);
cute::gemm(tiled_mma, accum, tPr, tCrV, frag_src);
}

@@ -308,23 +324,31 @@ struct CollectiveMmaAttention<gemm::MainloopIntelPVC<Stages>, ProblemShapeType_,
if constexpr (!is_var_len) {
return params;
} else {
auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo] = problem_shape;
auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape;

auto qo_cumulative_length = get<3>(problem_shape).cumulative_length;
auto kv_cumulative_length = get<4>(problem_shape).cumulative_length;

int offset_q = num_heads_q * head_size_qk * qo_cumulative_length[l_coord];
int offset_k = num_heads_kv * head_size_qk * kv_cumulative_length[l_coord];
int offset_v = num_heads_kv * head_size_vo * kv_cumulative_length[l_coord];
int offset_k_cache = num_heads_kv * head_size_qk * seq_len_kv_cache;
int offset_v_cache = num_heads_kv * head_size_vo * seq_len_kv_cache;
Review comment (Collaborator) on lines +335 to +336:

Do we consider the cached key-value pairs to be the same across all batches? My understanding is that each batch would have its own seq_len for the cached keys and values, which would mean that seq_len_kv_cache would also be of Variable Length type (the same as seq_len_qo and seq_len_kv). This code could produce out-of-bounds accesses because it is missing either a multiplication with l_coord (if we want to keep seq_len_kv_cache fixed length) or a multiplication with kv_cache_cumulative_length[l_coord] (if we want to change the type to Variable Length).
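A minimal sketch of the variable-length alternative the reviewer describes (hypothetical: it assumes seq_len_kv_cache is switched to the same VariableLength type as seq_len_qo and seq_len_kv, so a cumulative_length array is available at index 5 of the problem shape; neither change is part of this PR):

    // Hypothetical per-batch cache offsets, mirroring the existing offset_k/offset_v
    // computation that indexes cumulative lengths with l_coord.
    auto kv_cache_cumulative_length = get<5>(problem_shape).cumulative_length;
    int offset_k_cache = num_heads_kv * head_size_qk * kv_cache_cumulative_length[l_coord];
    int offset_v_cache = num_heads_kv * head_size_vo * kv_cache_cumulative_length[l_coord];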


auto q_traits = static_cast<traits_load_Q const&>(params.gmem_tiled_copy_q);
ElementQ* q_ptr = (ElementQ*)q_traits.base_ptr;
const ElementQ* q_ptr = (const ElementQ*)q_traits.base_ptr;

auto k_traits = static_cast<traits_load_K const&>(params.gmem_tiled_copy_k);
ElementK* k_ptr = (ElementK*)k_traits.base_ptr;
const ElementK* k_ptr = (const ElementK*)k_traits.base_ptr;

auto v_traits = static_cast<traits_load_V const&>(params.gmem_tiled_copy_v);
ElementV* v_ptr = (ElementV*)v_traits.base_ptr;
const ElementV* v_ptr = (const ElementV*)v_traits.base_ptr;

auto k_traits_cache = static_cast<traits_load_K const&>(params.gmem_tiled_copy_k_cache);
const ElementK* k_cache_ptr = (const ElementK*)k_traits_cache.base_ptr;

auto v_traits_cache = static_cast<traits_load_V const&>(params.gmem_tiled_copy_v_cache);
const ElementV* v_cache_ptr = (const ElementV*)v_traits_cache.base_ptr;

auto shape_q = make_shape(static_cast<int>(seq_len_qo), head_size_qk, num_heads_q);
StrideQ stride_q = cutlass::make_cute_packed_stride(StrideQ{}, shape_q);
Expand All @@ -335,15 +359,25 @@ struct CollectiveMmaAttention<gemm::MainloopIntelPVC<Stages>, ProblemShapeType_,
auto shape_v = make_shape(head_size_vo, static_cast<int>(seq_len_kv), num_heads_kv);
StrideV stride_v = cutlass::make_cute_packed_stride(StrideV{}, shape_v);

auto shape_k_cache = make_shape(static_cast<int>(seq_len_kv_cache), head_size_qk, num_heads_kv);
StrideK stride_k_cache = cutlass::make_cute_packed_stride(StrideK{}, shape_k_cache);

auto shape_v_cache = make_shape(head_size_vo, static_cast<int>(seq_len_kv_cache), num_heads_kv);
StrideV stride_v_cache = cutlass::make_cute_packed_stride(StrideV{}, shape_v_cache);

auto tensorQ = make_tensor(make_gmem_ptr(q_ptr + offset_q), make_layout(shape_q, stride_q));
auto tensorK = make_tensor(make_gmem_ptr(k_ptr + offset_k), make_layout(shape_k, stride_k));
auto tensorV = make_tensor(make_gmem_ptr(v_ptr + offset_v), make_layout(shape_v, stride_v));
auto tensorK_cache = make_tensor(make_gmem_ptr(k_cache_ptr + offset_k_cache), make_layout(shape_k_cache, stride_k_cache));
auto tensorV_cache = make_tensor(make_gmem_ptr(v_cache_ptr + offset_v_cache), make_layout(shape_v_cache, stride_v_cache));

XE_Copy_Q copyQ{XE_Copy_Q{}.with(tensorQ)};
XE_Copy_K copyK{XE_Copy_K{}.with(tensorK)};
XE_Copy_V copyV{XE_Copy_V{}.with(tensorV)};
XE_Copy_K copyK_cache{XE_Copy_K{}.with(tensorK_cache)};
XE_Copy_V copyV_cache{XE_Copy_V{}.with(tensorV_cache)};

return Params{copyQ, copyK, copyV};
return Params{copyQ, copyK, copyV, copyK_cache, copyV_cache};
}
}
};
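To show how the extended collective is meant to be used, here is a hedged sketch of the caller side. The field order matches the Arguments struct in the diff above; CollectiveMainloop is an assumed alias for this CollectiveMmaAttention specialization, and the device-side loop is only an assumption about how xe_flash_attn_gemm.hpp drives mmaQK/mmaPV with the new is_KV_cache flag (that kernel is not fully visible in this diff), so treat it as a sketch rather than the actual implementation:

    // Host side: cached K/V get their own pointers and strides next to the
    // regular K/V tensors, in the order declared by Arguments above.
    typename CollectiveMainloop::Arguments mainloop_args{
        ptr_Q, stride_Q,
        ptr_K, stride_K,
        ptr_V, stride_V,
        ptr_K_cache, stride_K_cache,   // cached keys:   (seq_len_kv_cache, head_size_qk) per head
        ptr_V_cache, stride_V_cache};  // cached values: (head_size_vo, seq_len_kv_cache) per head

    // Device side (assumption, pseudocode): the kernel walks the cached KV blocks
    // first and then the new KV blocks, passing is_KV_cache so that mmaQK/mmaPV
    // select gmem_tiled_copy_{k,v}_cache versus gmem_tiled_copy_{k,v} from Params:
    //
    //   for each kv block:
    //     is_KV_cache = (block index < number of cached KV blocks)
    //     mmaQK(accum_S, gQ, gK_block, frag_src, k_tile_count, params, is_KV_cache);
    //     ... online softmax / causal mask over the concatenated cache + new range ...
    //     mmaPV(accum_O, tPr, gV_block, frag_src, params, is_KV_cache);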
8 changes: 4 additions & 4 deletions applications/flash_attention_v2/kernel/tile_scheduler.hpp
@@ -59,8 +59,8 @@ struct XeFlashIndividualTileScheduler {
ProblemSize const& problem_size, KernelHardwareInfo hw_info,
TileShape const& tile_shape) {
using namespace cute;
// problem_size = [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo]
dim3 grid(size(ceil_div(shape<6>(problem_size), shape<1>(tile_shape))),
// problem_size = [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo]
dim3 grid(size(ceil_div(shape<7>(problem_size), shape<1>(tile_shape))),
size(ceil_div(shape<3>(problem_size), shape<0>(tile_shape))),
size(shape<0>(problem_size) * shape<1>(problem_size)));
return Params{ grid, {shape<1>(problem_size)} };
@@ -127,8 +127,8 @@ struct XeFlashPersistentTileScheduler {
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
hw_info.sm_count = sm_count;

// problem_size = [batch, num_heads_q, numhead_kv, seq_len_qo, seq_len_kv, head_size_qk, head_size_vo]
int num_head_size_blocks = size(ceil_div(shape<6>(problem_size), shape<1>(tile_shape)));
// problem_size = [batch, num_heads_q, numhead_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo]
int num_head_size_blocks = size(ceil_div(shape<7>(problem_size), shape<1>(tile_shape)));
int num_seq_len_blocks = size(ceil_div(shape<3>(problem_size), shape<0>(tile_shape)));
int num_blocks = num_seq_len_blocks * num_head_size_blocks * size(shape<0>(problem_size) * shape<1>(problem_size));

Expand Down
Loading