@@ -282,7 +282,7 @@ kernel void flash_attention(
 // This is a slightly more complicated flash attention kernel.
 // For this kernel, each work-item still computes one row of D elements of the output.
 // There is caching for the Q, O, K, and V data.
-__attribute__((reqd_work_group_size(32, 1, 1)))
+#define TT 16
 kernel void flash_attention_blocked(
     global const float* Q, global const float* K, global const float* V,
     global float* O,
@@ -301,57 +301,61 @@ kernel void flash_attention_blocked(
301
301
float mi = - INFINITY ;
302
302
float di = 0.0f ;
303
303
304
- float oc [D ];
305
- float qc [D ];
306
- for (int d = 0 ; d < D ; d ++ ) {
307
- qc [d ] = q [d ];
308
- }
309
-
310
- local float kc [D ];
311
- local float vc [D ];
312
-
313
- for (int ti = 0 ; ti < T ; ti ++ ) {
304
+ for (int ti = 0 ; ti < T ; ti += TT ) {
314
305
global const float * k = K + b * NH * T * D + nh * T * D + ti * D ;
315
306
global const float * v = V + b * NH * T * D + nh * T * D + ti * D ;
316
307
317
- barrier (CLK_LOCAL_MEM_FENCE );
318
- for (int d = get_local_id (0 ); d < D ; d += get_local_size (0 )) {
319
- kc [d ] = k [d ];
320
- vc [d ] = v [d ];
321
- }
322
- barrier (CLK_LOCAL_MEM_FENCE );
323
-
324
308
// Compute xi = QK^T
325
- float xi = 0.0f ;
326
- if ( CAUSAL && to < ti ) {
327
- xi = - INFINITY ;
309
+ float xi [ TT ] ;
310
+ for ( int tt = 0 ; tt < TT ; tt ++ ) {
311
+ xi [ tt ] = 0.0f ;
328
312
}
329
- else {
330
- for (int d = 0 ; d < D ; d ++ ) {
331
- xi += qc [d ] * kc [d ];
313
+ for (int d = 0 ; d < D ; d ++ ) {
314
+ for (int tt = 0 ; tt < TT ; tt ++ ) {
315
+ xi [tt ] += q [d ] * k [(tt * D ) + d ];
316
+ }
317
+ }
318
+ for (int tt = 0 ; tt < TT ; tt ++ ) {
319
+ xi [tt ] *= adjusted_scale ;
320
+ }
321
+
322
+ // TODO: find a better way to do this
323
+ if (CAUSAL ) {
324
+ for (int tt = 0 ; tt < TT ; tt ++ ) {
325
+ xi [tt ] = (to < ti + tt ) ? - INFINITY : xi [tt ];
332
326
}
333
- xi *= adjusted_scale ;
334
327
}
335
328
336
329
// Update the running maximum
337
330
float mim1 = mi ;
338
- mi = fmax (mim1 , xi );
331
+ for (int tt = 0 ; tt < TT ; tt ++ ) {
332
+ mi = fmax (mi , xi [tt ]);
333
+ }
339
334
340
335
// softmax(xi)
341
- float smxi = native_exp2 (xi - mi );
336
+ float smxi [TT ];
337
+ for (int tt = 0 ; tt < TT ; tt ++ ) {
338
+ smxi [tt ] = native_exp2 (xi [tt ] - mi );
339
+ }
342
340
343
341
// Update di
344
342
float alpha = native_exp2 (mim1 - mi );
345
- di = di * alpha + smxi ;
343
+ di *= alpha ;
344
+ for (int tt = 0 ; tt < TT ; tt ++ ) {
345
+ di += smxi [tt ];
346
+ }
346
347
347
348
// Update the un-scaled output from softmax(xi) and V
348
349
for (int d = 0 ; d < D ; d ++ ) {
349
- oc [d ] = oc [d ] * alpha + smxi * vc [d ];
350
+ o [d ] = o [d ] * alpha ;
351
+ for (int tt = 0 ; tt < TT ; tt ++ ) {
352
+ o [d ] += smxi [tt ] * v [(tt * D ) + d ];
353
+ }
350
354
}
351
355
}
352
356
353
357
// Epilog scaling (flash attention 2)
354
358
for (int d = 0 ; d < D ; d ++ ) {
355
- o [d ] = oc [d ] * native_recip (di );
359
+ o [d ] = o [d ] * native_recip (di );
356
360
}
357
361
}
0 commit comments