
Commit 332e630

added a blocked version of flash attention with caching
1 parent a130f00 commit 332e630

File tree

2 files changed: +100, -74 lines


samples/99_attentionexperiments/attention_kernels.cl

Lines changed: 79 additions & 52 deletions
@@ -13,10 +13,10 @@
 #define NH 12
 #endif
 
+#define D (C / NH) // head size
+
 kernel void naive_3p_query_key(global float* preatt, global const float* q_base, global const float* k_base)
 {
-    const int D = C / NH; // head size
-
     // Note: Global work size is B * NH * T * T
     int idx = get_global_id(0);
 
@@ -86,8 +86,6 @@ kernel void naive_3p_softmax(global float* att, global const float* preatt)
 
 kernel void naive_3p_value(global float* out, global const float* att, global const float* v_base)
 {
-    const int D = C / NH; // head size
-
     // Note: Global work size is B * T * NH
     int idx = get_global_id(0);
     int nh = idx % NH;
@@ -117,7 +115,6 @@ kernel void naive_3p_value(global float* out, global const float* att, global const float* v_base)
 // https://github.com/tspeterkim/flash-attention-minimal
 kernel void flash_attention_minimal(
     global const float* Q, global const float* K, global const float* V,
-    const int Tc, const int Tr,
     const int Bc, const int Br, // Note: this implementation requires Bc == Br!!
     const float softmax_scale,
     global float* l, global float* m, global float* O,
@@ -126,13 +123,15 @@ kernel void flash_attention_minimal(
     local float* Vc,
     local float* SP) // Used for both S and P
 {
-    const int D = C / NH; // head size
+    // scale the scale, so we can use exp2 instead of exp
+    const float adjusted_scale = softmax_scale * M_LOG2E_F;
 
     // Note: Global Work Size is (B * Bc, NH)
     // Note: Local Work Size is (Bc, 1)
     // --> Group ID is (batch index, head index)
     const int b = get_group_id(0);
     const int nh = get_group_id(1);
+
     // --> Local ID is row or column index
     const int rc = get_local_id(0);
 
@@ -142,20 +141,20 @@ kernel void flash_attention_minimal(
     // Note: l, m are (B, NH, T)
     const int lm_offset = (b * NH * T) + (nh * T); // offset for l and m
 
-    for (int j = 0; j < Tc; j++) {
+    for (int to = 0; to < T; to += Bc) {
         // Load K, V to SLM
         // Each work-item loads one row of Kc and Vc.
         for (int d = 0; d < D; d++) {
-            Kc[(rc * D) + d] = K[qkv_offset + ((Bc * j + rc) * D) + d];
-            Vc[(rc * D) + d] = V[qkv_offset + ((Bc * j + rc) * D) + d];
+            Kc[(rc * D) + d] = K[qkv_offset + ((to + rc) * D) + d];
+            Vc[(rc * D) + d] = V[qkv_offset + ((to + rc) * D) + d];
         }
         barrier(CLK_LOCAL_MEM_FENCE);
 
-        for (int i = 0; i < Tr; i++) {
+        for (int ti = 0; ti < T; ti += Br) {
             // Load Q to SLM
             // Each work-item loads one row of Qc
             for (int d = 0; d < D; d++) {
-                Qc[(rc * D) + d] = Q[qkv_offset + ((Bc * i + rc) * D) + d];
+                Qc[(rc * D) + d] = Q[qkv_offset + ((ti + rc) * D) + d];
             }
 
             // Compute SP = QK^T, mi_local = rowmax(SP)
@@ -164,14 +163,14 @@
             float mi_local = -INFINITY;
             for (int y = 0; y < Bc; y++) {
                 float xi = 0;
-                if (CAUSAL && i * Br + rc < j * Bc + y) {
+                if (CAUSAL && ti + rc < to + y) {
                     xi = -INFINITY;
                 }
                 else {
                     for (int d = 0; d < D; d++) {
                         xi += Qc[(rc * D) + d] * Kc[(y * D) + d];
                     }
-                    xi *= softmax_scale;
+                    xi *= adjusted_scale;
                 }
                 SP[(Bc * rc) + y] = xi;
                 mi_local = fmax(xi, mi_local);
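Note on the adjusted_scale change in this and the previous hunk: it relies on the identity exp(x) = exp2(x * log2(e)). Because every score fed to the exponential, and therefore every running maximum derived from those scores, is pre-multiplied by M_LOG2E_F through adjusted_scale, each native_exp(a - b) in the old code becomes native_exp2() of the pre-scaled difference, which typically maps to a single hardware instruction. A minimal scalar sketch of the identity (plain C, file and variable names hypothetical, not part of the commit):

    /* exp2_identity.c - scalar check of the exp -> exp2 rewrite.
     * Illustrative only; compile with: cc exp2_identity.c -lm */
    #include <math.h>
    #include <stdio.h>

    int main(void)
    {
        const float scale = 0.125f;           /* e.g. 1/sqrt(D) for D = 64 */
        const float log2e = 1.44269504f;      /* value of OpenCL's M_LOG2E_F */
        float dot = 3.7f, row_max_dot = 5.2f; /* arbitrary example inputs */

        /* Old form: the score and the running max live in the natural-log domain. */
        float a = dot * scale, m = row_max_dot * scale;
        float old_way = expf(a - m);

        /* New form: fold log2(e) into the scale once ("adjusted_scale"), so the
         * score and the running max are both pre-scaled and exp becomes exp2. */
        float a2 = dot * (scale * log2e), m2 = row_max_dot * (scale * log2e);
        float new_way = exp2f(a2 - m2);

        printf("exp: %f  exp2: %f\n", old_way, new_way); /* agree up to rounding */
        return 0;
    }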
@@ -181,84 +180,91 @@
             // implement softmax with causal masking
             float vm = 0;
             for (int y = 0; y < Bc; y++) {
-                SP[(Bc * rc) + y] = native_exp(SP[(Bc * rc) + y] - mi_local);
+                SP[(Bc * rc) + y] = native_exp2(SP[(Bc * rc) + y] - mi_local);
                 vm += SP[(Bc * rc) + y];
             }
 
             // Compute new m and l
-            float mim1 = m[lm_offset + (Br * i) + rc];
-            float dim1 = l[lm_offset + (Br * i) + rc];
+            float mim1 = m[lm_offset + ti + rc];
+            float dim1 = l[lm_offset + ti + rc];
 
             float mi = fmax(mim1, mi_local);
-            float di = dim1 * native_exp(mim1 - mi) + vm * native_exp(mi_local - mi);
+            float di = dim1 * native_exp2(mim1 - mi) + vm * native_exp2(mi_local - mi);
 
-            float om = dim1 * native_exp(mim1 - mi) / di;
-            vm = native_exp(mi_local - mi) / di;
+            float om = dim1 * native_exp2(mim1 - mi) / di;
+            vm = native_exp2(mi_local - mi) / di;
 
             // Write O, l, m to HBM
             for (int d = 0; d < D; d++) {
                 float pv = 0; // Pij * Vc
                 for (int y = 0; y < Bc; y++) {
-                    //if (j * Bc + y >= T) {
+                    //if (to * Bc + y >= T) {
                     //    break;
                     //}
                     pv += SP[(Bc * rc) + y] * Vc[(y * D) + d];
                 }
                 // O is (B, NH, T, D)
-                O[qkv_offset + ((Bc * i + rc) * D) + d] =
-                    om * O[qkv_offset + ((Bc * i + rc) * D) + d] +
-                    vm * pv;
+                O[qkv_offset + ((ti + rc) * D) + d] =
+                    om * O[qkv_offset + ((ti + rc) * D) + d] +
+                    vm * pv;
             }
 
-            m[lm_offset + (Br * i) + rc] = mi;
-            l[lm_offset + (Br * i) + rc] = di;
+            m[lm_offset + ti + rc] = mi;
+            l[lm_offset + ti + rc] = di;
         }
         barrier(CLK_LOCAL_MEM_FENCE); // otherwise, thread can use the wrong Kc, Vc in inner loop
     }
 }
 
+// This is a very basic flash attention kernel.
+// For this kernel, each work-item computes one row of D elements of the output.
+// There is no caching of the Q or O data.
+// There is also no sharing of the K or V data.
 kernel void flash_attention(
     global const float* Q, global const float* K, global const float* V,
     global float* O,
     const float scale)
 {
-    const int D = C / NH; // head size
+    // Note: all data is arranged: B x NH x T x D
+
+    // scale the scale, so we can use exp2 instead of exp
+    const float adjusted_scale = scale * M_LOG2E_F;
 
-    int b = get_global_id(0);
+    int to = get_global_id(0);
     int nh = get_global_id(1);
-    int to = get_global_id(2);
+    int b = get_global_id(2);
 
-    float* o = O + b * NH * T * D + nh * T * D + to * D;
+    global float* o = O + b * NH * T * D + nh * T * D + to * D;
 
-    const float* q = Q + b * NH * T * D + nh * T * D + to * D;
+    global const float* q = Q + b * NH * T * D + nh * T * D + to * D;
     float mi = -INFINITY;
     float di = 0.0f;
 
     for (int ti = 0; ti < T; ti++) {
-        const float* k = K + b * NH * T * D + nh * T * D + ti * D;
-        const float* v = V + b * NH * T * D + nh * T * D + ti * D;
+        global const float* k = K + b * NH * T * D + nh * T * D + ti * D;
+        global const float* v = V + b * NH * T * D + nh * T * D + ti * D;
 
         // Compute xi = QK^T
-        float xi = 0.0;
+        float xi = 0.0f;
         if (CAUSAL && to < ti) {
             xi = -INFINITY;
         }
         else {
             for (int d = 0; d < D; d++) {
                 xi += q[d] * k[d];
             }
-            xi *= scale;
+            xi *= adjusted_scale;
         }
 
         // Update the running maximum
         float mim1 = mi;
         mi = fmax(mim1, xi);
 
         // softmax(xi)
-        float smxi = native_exp(xi - mi);
+        float smxi = native_exp2(xi - mi);
 
         // Update di
-        float alpha = native_exp(mim1 - mi);
+        float alpha = native_exp2(mim1 - mi);
         di = di * alpha + smxi;
 
         // Update the un-scaled output from softmax(xi) and V
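The flash_attention kernel in this hunk implements the streaming ("online") softmax: it keeps a running maximum mi and a running denominator di, rescales everything accumulated so far by alpha = exp(m_prev - m) whenever the maximum grows, and defers the division by di to a single epilog pass (the "flash attention 2" epilog at the end of the kernel). A scalar model of that recurrence, written as plain C for illustration and not taken from the repository:

    /* online_softmax.c - scalar model of the running-max / running-denominator
     * update used by the flash_attention kernel. Illustrative only;
     * compile with: cc online_softmax.c -lm */
    #include <math.h>
    #include <stdio.h>

    int main(void)
    {
        /* Example scores (already scaled) and values; D = 1 for brevity. */
        float x[4] = { 0.5f, 2.0f, -1.0f, 1.5f };
        float v[4] = { 1.0f, 2.0f,  3.0f, 4.0f };

        float m = -INFINITY; /* running maximum (mi in the kernel)     */
        float d = 0.0f;      /* running denominator (di in the kernel) */
        float o = 0.0f;      /* un-normalized output accumulator       */

        for (int i = 0; i < 4; i++) {
            float m_prev = m;
            m = fmaxf(m_prev, x[i]);
            float alpha = expf(m_prev - m); /* rescales what was accumulated so far */
            float p = expf(x[i] - m);       /* softmax numerator for this element   */
            d = d * alpha + p;
            o = o * alpha + p * v[i];
        }
        o /= d; /* epilog scaling, as in the "flash attention 2" comment */

        /* Reference: ordinary two-pass softmax-weighted sum. */
        float mx = x[0], den = 0.0f, ref = 0.0f;
        for (int i = 1; i < 4; i++) mx = fmaxf(mx, x[i]);
        for (int i = 0; i < 4; i++) den += expf(x[i] - mx);
        for (int i = 0; i < 4; i++) ref += (expf(x[i] - mx) / den) * v[i];

        printf("online: %f  reference: %f\n", o, ref); /* should match */
        return 0;
    }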
@@ -273,58 +279,79 @@ kernel void flash_attention(
     }
 }
 
-kernel void flash_attention_colblock(
+// This is a slightly more complicated flash attention kernel.
+// For this kernel, each work-item still computes one row of D elements of the output.
+// There is caching for the Q, O, K, and V data.
+__attribute__((reqd_work_group_size(32, 1, 1)))
+kernel void flash_attention_blocked(
     global const float* Q, global const float* K, global const float* V,
     global float* O,
     const float scale)
 {
-    const int D = C / NH; // head size
+    // scale the scale, so we can use exp2 instead of exp
+    const float adjusted_scale = scale * M_LOG2E_F;
 
-    int b = get_global_id(0);
-    int nh = get_global_id(1);
-    int to = get_global_id(2);
+    int to = get_global_id(0);
+    int nh = get_group_id(1);
+    int b = get_group_id(2);
 
-    float* o = O + b * NH * T * D + nh * T * D + to * D;
+    global float* o = O + b * NH * T * D + nh * T * D + to * D;
 
-    const float* q = Q + b * NH * T * D + nh * T * D + to * D;
+    global const float* q = Q + b * NH * T * D + nh * T * D + to * D;
     float mi = -INFINITY;
     float di = 0.0f;
 
+    float oc[D];
+    float qc[D];
+    for (int d = 0; d < D; d++) {
+        qc[d] = q[d];
+    }
+
+    local float kc[D];
+    local float vc[D];
+
     for (int ti = 0; ti < T; ti++) {
-        const float* k = K + b * NH * T * D + nh * T * D + ti * D;
-        const float* v = V + b * NH * T * D + nh * T * D + ti * D;
+        global const float* k = K + b * NH * T * D + nh * T * D + ti * D;
+        global const float* v = V + b * NH * T * D + nh * T * D + ti * D;
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+        for (int d = get_local_id(0); d < D; d += get_local_size(0)) {
+            kc[d] = k[d];
+            vc[d] = v[d];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
 
         // Compute xi = QK^T
-        float xi = 0.0;
+        float xi = 0.0f;
         if (CAUSAL && to < ti) {
             xi = -INFINITY;
         }
         else {
             for (int d = 0; d < D; d++) {
-                xi += q[d] * k[d];
+                xi += qc[d] * kc[d];
             }
-            xi *= scale;
+            xi *= adjusted_scale;
         }
 
         // Update the running maximum
         float mim1 = mi;
         mi = fmax(mim1, xi);
 
         // softmax(xi)
-        float smxi = native_exp(xi - mi);
+        float smxi = native_exp2(xi - mi);
 
         // Update di
-        float alpha = native_exp(mim1 - mi);
+        float alpha = native_exp2(mim1 - mi);
         di = di * alpha + smxi;
 
         // Update the un-scaled output from softmax(xi) and V
         for (int d = 0; d < D; d++) {
-            o[d] = o[d] * alpha + smxi * v[d];
+            oc[d] = oc[d] * alpha + smxi * vc[d];
         }
     }
 
     // Epilog scaling (flash attention 2)
     for (int d = 0; d < D; d++) {
-        o[d] = o[d] * native_recip(di);
+        o[d] = oc[d] * native_recip(di);
     }
 }
