Skip to content

Commit 0ad30b5

Browse files
committed
start to unroll the sequence length loop
1 parent 332e630 commit 0ad30b5

File tree

1 file changed

+34
-30
lines changed

1 file changed

+34
-30
lines changed

samples/99_attentionexperiments/attention_kernels.cl

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ kernel void flash_attention(
282282
// This is a slightly more complicated flash attention kernel.
283283
// For this kernel, each work-item still computes one row of D elements of the output.
284284
// There is caching for the Q, O, K, and V data.
285-
__attribute__((reqd_work_group_size(32, 1, 1)))
285+
#define TT 16
286286
kernel void flash_attention_blocked(
287287
global const float* Q, global const float* K, global const float* V,
288288
global float* O,
@@ -301,57 +301,61 @@ kernel void flash_attention_blocked(
301301
float mi = -INFINITY;
302302
float di = 0.0f;
303303

304-
float oc[D];
305-
float qc[D];
306-
for (int d = 0; d < D; d++) {
307-
qc[d] = q[d];
308-
}
309-
310-
local float kc[D];
311-
local float vc[D];
312-
313-
for (int ti = 0; ti < T; ti++) {
304+
for (int ti = 0; ti < T; ti+=TT) {
314305
global const float* k = K + b * NH * T * D + nh * T * D + ti * D;
315306
global const float* v = V + b * NH * T * D + nh * T * D + ti * D;
316307

317-
barrier(CLK_LOCAL_MEM_FENCE);
318-
for (int d = get_local_id(0); d < D; d += get_local_size(0)) {
319-
kc[d] = k[d];
320-
vc[d] = v[d];
321-
}
322-
barrier(CLK_LOCAL_MEM_FENCE);
323-
324308
// Compute xi = QK^T
325-
float xi = 0.0f;
326-
if (CAUSAL && to < ti) {
327-
xi = -INFINITY;
309+
float xi[TT];
310+
for (int tt = 0; tt < TT; tt++) {
311+
xi[tt] = 0.0f;
328312
}
329-
else {
330-
for (int d = 0; d < D; d++) {
331-
xi += qc[d] * kc[d];
313+
for (int d = 0; d < D; d++) {
314+
for (int tt = 0; tt < TT; tt++) {
315+
xi[tt] += q[d] * k[(tt * D) + d];
316+
}
317+
}
318+
for (int tt = 0; tt < TT; tt++) {
319+
xi[tt] *= adjusted_scale;
320+
}
321+
322+
// TODO: find a better way to do this
323+
if (CAUSAL) {
324+
for (int tt = 0; tt < TT; tt++) {
325+
xi[tt] = (to < ti + tt) ? -INFINITY : xi[tt];
332326
}
333-
xi *= adjusted_scale;
334327
}
335328

336329
// Update the running maximum
337330
float mim1 = mi;
338-
mi = fmax(mim1, xi);
331+
for (int tt = 0; tt < TT; tt++) {
332+
mi = fmax(mi, xi[tt]);
333+
}
339334

340335
// softmax(xi)
341-
float smxi = native_exp2(xi - mi);
336+
float smxi[TT];
337+
for (int tt = 0; tt < TT; tt++) {
338+
smxi[tt] = native_exp2(xi[tt] - mi);
339+
}
342340

343341
// Update di
344342
float alpha = native_exp2(mim1 - mi);
345-
di = di * alpha + smxi;
343+
di *= alpha;
344+
for (int tt = 0; tt < TT; tt++) {
345+
di += smxi[tt];
346+
}
346347

347348
// Update the un-scaled output from softmax(xi) and V
348349
for (int d = 0; d < D; d++) {
349-
oc[d] = oc[d] * alpha + smxi * vc[d];
350+
o[d] = o[d] * alpha;
351+
for (int tt = 0; tt < TT; tt++) {
352+
o[d] += smxi[tt] * v[(tt * D) + d];
353+
}
350354
}
351355
}
352356

353357
// Epilog scaling (flash attention 2)
354358
for (int d = 0; d < D; d++) {
355-
o[d] = oc[d] * native_recip(di);
359+
o[d] = o[d] * native_recip(di);
356360
}
357361
}

0 commit comments

Comments
 (0)