
Commit e1dc96e

min-jean-cho, mehdi-goli, and aacostadiaz
authored
(Take2) Extend FlashAttention Prefill with KV cache (#331)
Original: #318

This extends FlashAttention prefill with cached KV in addition to the current KV (the blue box in the figure below). Both causal and non-causal are supported.

[Figure: extend prefill — https://github.com/user-attachments/assets/d27e5dc6-5700-447e-b8c8-33e8073d7891]

Co-authored-by: Mehdi Goli <[email protected]>
Co-authored-by: Alejandro Acosta <[email protected]>
1 parent da9c63f commit e1dc96e
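For reference, a minimal standalone sketch of the shape convention this change introduces. The sizes below are illustrative placeholders, not values from the commit; the tuple order matches the problem-shape destructuring used in the diff, and the total KV extent is simply the cached length plus the current length.

#include <tuple>

int main() {
  // Prefill problem shape as destructured in the new collective code:
  // [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache,
  //  head_size_qk, head_size_vo]. The concrete values are illustrative only.
  auto problem_shape = std::make_tuple(2, 16, 16, 512, 512, 1024, 128, 128);

  auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv,
        seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape;

  // Each query row attends over the cached KV followed by the current KV,
  // so the effective KV extent for the prefill is their sum.
  int seq_len_kv_total = seq_len_kv_cache + seq_len_kv;
  return seq_len_kv_total > 0 ? 0 : 1;
}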

File tree

8 files changed

+2114
-0
lines changed

@@ -0,0 +1,234 @@
/***************************************************************************************************
 * Copyright (c) 2024 - 2025 Codeplay Software Ltd. All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *
 * 1. Redistributions of source code must retain the above copyright notice, this
 * list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 * this list of conditions and the following disclaimer in the documentation
 * and/or other materials provided with the distribution.
 *
 * 3. Neither the name of the copyright holder nor the names of its
 * contributors may be used to endorse or promote products derived from
 * this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
 * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
 * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/*! \file
    \brief Functor performing elementwise operations used by epilogues.
*/
#pragma once

#include <sycl/sycl.hpp>
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_epilogue.hpp"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cutlass/detail/layout.hpp"

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace flash_attention {
namespace collective {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <class DispatchPolicy, class... Args> class CollectiveEpilogueAttention {
  static_assert(cutlass::detail::dependent_false<DispatchPolicy>, "Could not find an epilogue specialization.");
};

template <class CtaTileMNK_, class ElementO_, class StrideO_, class ElementLSE_, class CopyOpO_>
class CollectiveEpilogueAttention<epilogue::IntelPVCEpilogue, CtaTileMNK_, ElementO_, StrideO_, ElementLSE_, CopyOpO_> {
public:
  //
  // Type Aliases
  //
  using DispatchPolicy = epilogue::IntelPVCEpilogue;
  using CtaTileMNK = CtaTileMNK_;
  using ElementO = ElementO_;
  using ElementAccumulator = ElementO_;
  using StrideO = StrideO_;
  using ElementLSE = ElementLSE_;
  using CopyOpO = CopyOpO_;

  using GmemTiledCopyO = CopyOpO;
  using ElementOutput = ElementO_;
  using ElementCompute = ElementO_;

  static constexpr int SubgroupSize = DispatchPolicy::SubgroupSize;

  static_assert(cute::rank(CtaTileMNK{}) == 4, "CtaTileMNK must be rank-4: [CTA_M_Q, CTA_N_V, CTA_N_QK, CTA_K_QK]");
  static_assert(cute::rank(StrideO{}) == 3, "StrideO must be rank-3: [seq_len_qo, head_size_vo, batch * num_heads]");

  using CopyThreadShape = Shape<_1, Int<SubgroupSize>>;

  using traits_store_O = Copy_Traits<GmemTiledCopyO, StrideO>;
  using atom_load_O = Copy_Atom<traits_store_O, ElementO>;
  using val_layout_load_O = decltype(make_layout(shape_div(typename traits_store_O::BlockShape{}, CopyThreadShape{})));
  using XE_Copy_O = decltype(make_tiled_copy(atom_load_O{}, Layout<CopyThreadShape>{}, val_layout_load_O{}));

private:
  constexpr static bool is_destination_supported = not cute::is_void_v<ElementO>;

public:
  using EmptyType = cute::tuple<>;

  struct TensorStorageImpl : cute::tuple<EmptyType, EmptyType> {};

  struct SharedStorage {
    using TensorStorage = TensorStorageImpl;

    TensorStorage tensors;
  };
  using TensorStorage = typename SharedStorage::TensorStorage;

  // Host side epilogue arguments
  struct Arguments {
    ElementO const *ptr_O;
    StrideO dO;
  };

  // Device side epilogue params
  struct Params {
    XE_Copy_O xe_store_o;
  };

  //
  // Methods
  //

  template <class ProblemShape>
  static constexpr Params to_underlying_arguments(ProblemShape const &problem_shape, Arguments const &args,
                                                  [[maybe_unused]] void *workspace) {
    auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape;

    auto tensorO = make_tensor(make_gmem_ptr(static_cast<ElementO const*>(args.ptr_O)),
                               make_layout(make_shape(seq_len_qo, head_size_vo, batch * num_heads_q),
                                           args.dO));
    XE_Copy_O xe_store_o{XE_Copy_O{}.with(tensorO)};
    return {
        xe_store_o,
    };
  }

  template <class ProblemShape>
  static size_t get_workspace_size(ProblemShape const &problem_shape, Arguments const &args) {
    return 0;
  }

  template <class ProblemShape>
  static cutlass::Status initialize_workspace(ProblemShape const &problem_shape, Arguments const &args, void *workspace,
                                              cudaStream_t stream, CudaHostAdapter *cuda_adapter = nullptr) {
    return Status::kSuccess;
  }

  template <class ProblemShape>
  CUTLASS_HOST_DEVICE static bool can_implement(ProblemShape const &problem_shape,
                                                [[maybe_unused]] Arguments const &args) {
    return true;
  }

  CUTLASS_HOST_DEVICE
  CollectiveEpilogueAttention(Params const &params_, TensorStorage const &) : params(params_) {}

  template <class ProblemShape, class TileCoord, class FragOut, class FragMax, class FragSum, class TiledMma>
  CUTLASS_DEVICE void operator()(ProblemShape problem_shape, TileCoord tile_coord, FragOut &out, FragMax const &max,
                                 FragSum &sum, TiledMma tiled_mma, ElementCompute const &softmax_scale) {

    using namespace cute;

    using MmaAtomShape = typename TiledMma::AtomShape_MNK;
    using SubgroupTileShape = decltype(cute::shape_div(take<0, 3>(CtaTileMNK{}), take<1, 4>(typename TiledMma::ThrLayoutVMNK{}.shape())));
    using FragsShape = decltype(cute::shape_div(take<0, 2>(SubgroupTileShape{}), take<0, 2>(MmaAtomShape())));
    static constexpr bool is_var_len = cutlass::fmha::collective::is_variable_length_v<tuple_element_t<2, ProblemShape>>;

    static constexpr int FragsM = get<0>(FragsShape{}); // A frags per sub_group
    static constexpr int FragsN = get<1>(FragsShape{}); // B frags per sub_group
    static constexpr int Vec = (get<0>(MmaAtomShape()) * get<1>(MmaAtomShape())) / SubgroupSize;

    auto g = syclcompat::get_nd_item<1>().get_sub_group();

    // Finalize the softmax: reduce the running row sums across the sub-group and
    // rescale the accumulator by 1 / row_sum, guarding against zero and NaN sums.
    CUTLASS_PRAGMA_UNROLL
    for (int y = 0; y < FragsM; y++) {
      CUTLASS_PRAGMA_UNROLL
      for (int x = 0; x < Vec; x++) {
        int indx = y * Vec + x;
        auto cur_sum = reduce_over_group(g, sum(indx), sycl::plus<>());
        auto cur_scale = (cur_sum == 0.f || cur_sum != cur_sum) ? 1.f : sycl::native::recip(cur_sum);
        CUTLASS_PRAGMA_UNROLL
        for (int z = 0; z < FragsN; z++) {
          out(x, y, z) *= cur_scale;
        }
      }
    }

    // Indexing variables
    auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape;
    // Represent the full output tensor
    Tensor mO_mnl = cute::get_pvc_tensor(make_shape(seq_len_qo, head_size_vo, (is_var_len ? batch : 1) * num_heads_q));

    auto [m_coord, n_coord, k_coord, l_coord] = tile_coord;
    // Tile the output tensor per WG
    Tensor g_wg_O = local_tile(mO_mnl, select<0,1>(CtaTileMNK{}), make_coord(m_coord,n_coord,l_coord)); // (BLK_M,BLK_N,m,n,l)
    static constexpr auto ATOM_N = get<2>(typename TiledMma::ThrLayoutVMNK{}.shape());
    static constexpr auto ATOM_K = get<3>(typename TiledMma::ThrLayoutVMNK{}.shape());
    auto m_sg = get_sub_group_id() / ATOM_N;
    auto n_sg = get_sub_group_id() % ATOM_N;
    // Tile the output tensor per SG
    Tensor gO = local_tile(g_wg_O, SubgroupTileShape{}, make_coord(m_sg,n_sg,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)

    auto thread_xe_store_o = params.xe_store_o.get_thread_slice(ThreadIdxX());
    Tensor tOgO = thread_xe_store_o.partition_D(gO);

    copy(params.xe_store_o, out, tOgO);
  }

  template <bool VarLen, class ProblemShapeType>
  CUTLASS_DEVICE static constexpr Params get_updated_copies(Params const& params, ProblemShapeType const& problem_shape, int const& l_coord) {
    if constexpr (!VarLen) {
      return params;
    } else {
      // Variable-length sequences: rebase the output pointer onto this batch's
      // cumulative query offset and rebuild the block-store copy accordingly.
      auto [batch, num_heads_q, num_heads_kv, seq_len_qo, seq_len_kv, seq_len_kv_cache, head_size_qk, head_size_vo] = problem_shape;

      auto qo_cumulative_length = get<3>(problem_shape).cumulative_length;
      int offset_o = num_heads_q * head_size_vo * qo_cumulative_length[l_coord];
      auto store_traits = static_cast<traits_store_O const&>(params.xe_store_o);

      ElementO* base_ptr = (ElementO*)store_traits.base_ptr;
      auto shape_o = make_shape(static_cast<int>(seq_len_qo), head_size_vo, num_heads_q);
      StrideO stride_o = cutlass::make_cute_packed_stride(StrideO{}, shape_o);

      auto tensorO = make_tensor(make_gmem_ptr(base_ptr + offset_o), make_layout(shape_o, stride_o));
      XE_Copy_O xe_store_o{XE_Copy_O{}.with(tensorO)};

      return Params{xe_store_o};
    }
  }

private:
  Params const &params;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace collective
} // namespace flash_attention
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
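
As a side note, a minimal plain-C++ sketch of the zero/NaN guard applied in operator() above before the output tile is rescaled. safe_recip, the sample values, and the use of plain division in place of sycl::native::recip are illustrative only, not part of this commit.

#include <cmath>
#include <cstdio>

// Rows whose accumulated softmax sum is zero or NaN keep a scale of 1.0f,
// mirroring the cur_scale expression in the epilogue, so the store never
// writes inf/NaN produced by dividing by a degenerate sum.
static float safe_recip(float cur_sum) {
  return (cur_sum == 0.f || cur_sum != cur_sum) ? 1.f : 1.f / cur_sum;
}

int main() {
  const float sums[] = {4.f, 0.f, NAN};
  for (float s : sums) {
    std::printf("sum=%8.3f  scale=%8.3f\n", s, safe_recip(s));
  }
  return 0;
}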
