@@ -162,18 +162,18 @@ kernel void flash_attention_minimal(
     // from one row of Qc and Kc.
     float mi_local = -INFINITY;
     for (int y = 0; y < Bc; y++) {
-        float xi = 0;
+        float qk = 0;
         if (CAUSAL && ti + rc < to + y) {
-            xi = -INFINITY;
+            qk = -INFINITY;
         }
         else {
             for (int d = 0; d < D; d++) {
-                xi += Qc[(rc * D) + d] * Kc[(y * D) + d];
+                qk += Qc[(rc * D) + d] * Kc[(y * D) + d];
             }
-            xi *= adjusted_scale;
+            qk *= adjusted_scale;
         }
-        SP[(Bc * rc) + y] = xi;
-        mi_local = fmax(xi, mi_local);
+        SP[(Bc * rc) + y] = qk;
+        mi_local = fmax(qk, mi_local);
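+        // mi_local is the row maximum of the qk scores; subtracting it below
+        // keeps the exponentials in a numerically safe range.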
     }

     // SP = exp(SP - mi_local), vm = rowsum(SP)
@@ -221,8 +221,8 @@ kernel void flash_attention_minimal(
 // There is no caching of the Q or O data.
 // There is also no sharing of the K or V data.
 kernel void flash_attention(
-    global const float* Q, global const float* K, global const float* V,
-    global float* O,
+    global const float* restrict Q, global const float* restrict K, global const float* restrict V,
+    global float* restrict O,
     const float scale)
 {
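+    // note: restrict asserts that Q, K, V, and O never alias one another,
+    // letting the compiler cache and reorder loads more aggressively.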
     // Note: all data is arranged: B x NH x T x D
@@ -244,30 +244,30 @@ kernel void flash_attention(
         global const float* k = K + b * NH * T * D + nh * T * D + ti * D;
         global const float* v = V + b * NH * T * D + nh * T * D + ti * D;

-        // Compute xi = QK^T
-        float xi = 0.0f;
+        // Compute qk = QK^T
+        float qk = 0.0f;
         if (CAUSAL && to < ti) {
-            xi = -INFINITY;
+            qk = -INFINITY;
         }
         else {
             for (int d = 0; d < D; d++) {
-                xi += q[d] * k[d];
+                qk += q[d] * k[d];
             }
-            xi *= adjusted_scale;
+            qk *= adjusted_scale;
         }

         // Update the running maximum
         float mim1 = mi;
-        mi = fmax(mim1, xi);
+        mi = fmax(mim1, qk);

-        // softmax(xi)
-        float smxi = native_exp2(xi - mi);
+        // softmax(qk)
+        float smxi = native_exp2(qk - mi);

         // Update di
         float alpha = native_exp2(mim1 - mi);
         di = di * alpha + smxi;
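+        // note: alpha = exp2(mim1 - mi) rescales the previous denominator (and
+        // the accumulated output below) whenever the running maximum grows.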

-        // Update the un-scaled output from softmax(xi) and V
+        // Update the un-scaled output from softmax(qk) and V
         for (int d = 0; d < D; d++) {
             o[d] = o[d] * alpha + smxi * v[d];
         }
@@ -281,81 +281,161 @@ kernel void flash_attention(

 // This is a slightly more complicated flash attention kernel.
 // For this kernel, each work-item still computes one row of D elements of the output.
-// There is caching for the Q, O, K, and V data.
-#define TT 16
+#define BLOCK_N D
 kernel void flash_attention_blocked(
-    global const float* Q, global const float* K, global const float* V,
-    global float* O,
+    global const float* restrict Q, global const float* restrict K, global const float* restrict V,
+    global float* restrict O,
     const float scale)
 {
     // scale the scale, so we can use exp2 instead of exp
     const float adjusted_scale = scale * M_LOG2E_F;
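+    // note: exp(x) == exp2(x * log2(e)), so folding M_LOG2E_F into the scale
+    // makes every exponential below a cheaper native_exp2.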

     int to = get_global_id(0);
-    int nh = get_group_id(1);
-    int b = get_group_id(2);
+    int nh = get_global_id(1);
+    int b = get_global_id(2);

     global float* o = O + b * NH * T * D + nh * T * D + to * D;

     global const float* q = Q + b * NH * T * D + nh * T * D + to * D;
     float mi = -INFINITY;
     float di = 0.0f;

-    for (int ti = 0; ti < T; ti += TT) {
+    for (int ti = 0; ti < T; ti += BLOCK_N) {
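+        // note: this loop assumes T is a multiple of BLOCK_N; there is no tail
+        // loop for a final partial block of keys.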
         global const float* k = K + b * NH * T * D + nh * T * D + ti * D;
         global const float* v = V + b * NH * T * D + nh * T * D + ti * D;

-        // Compute xi = QK^T
-        float xi[TT];
-        for (int tt = 0; tt < TT; tt++) {
-            xi[tt] = 0.0f;
-        }
-        for (int d = 0; d < D; d++) {
-            for (int tt = 0; tt < TT; tt++) {
-                xi[tt] += q[d] * k[(tt * D) + d];
+        // Compute qk = QK^T
+        float qk[BLOCK_N];
+        for (int tt = 0; tt < BLOCK_N; tt++) {
+            qk[tt] = 0.0f;
+            for (int d = 0; d < D; d++) {
+                qk[tt] += q[d] * k[(tt * D) + d];
             }
-        }
-        for (int tt = 0; tt < TT; tt++) {
-            xi[tt] *= adjusted_scale;
-        }
-
-        // TODO: find a better way to do this
-        if (CAUSAL) {
-            for (int tt = 0; tt < TT; tt++) {
-                xi[tt] = (to < ti + tt) ? -INFINITY : xi[tt];
+            qk[tt] *= adjusted_scale;
+            if (CAUSAL) {
+                qk[tt] = (to < ti + tt) ? -INFINITY : qk[tt];
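+                // causal mask: keys past the query position go to -INFINITY and
+                // become exp2(-INFINITY) == 0 after the softmax below.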
             }
         }

         // Update the running maximum
         float mim1 = mi;
-        for (int tt = 0; tt < TT; tt++) {
-            mi = fmax(mi, xi[tt]);
+        for (int tt = 0; tt < BLOCK_N; tt++) {
+            mi = fmax(mi, qk[tt]); // max reduction
         }

-        // softmax(xi)
-        float smxi[TT];
-        for (int tt = 0; tt < TT; tt++) {
-            smxi[tt] = native_exp2(xi[tt] - mi);
+        // softmax(qk)
+        float smqk[BLOCK_N];
+        float l = 0.0f;
+        for (int tt = 0; tt < BLOCK_N; tt++) {
+            smqk[tt] = native_exp2(qk[tt] - mi);
+            l = l + smqk[tt]; // add reduction
         }

         // Update di
         float alpha = native_exp2(mim1 - mi);
-        di *= alpha;
-        for (int tt = 0; tt < TT; tt++) {
-            di += smxi[tt];
-        }
+        di = di * alpha + l;
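+        // di is the running softmax denominator: rescale the old value by
+        // alpha, then add this block's partial sum l.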

-        // Update the un-scaled output from softmax(xi) and V
+#if 1
+        // Update the un-scaled output from softmax(qk) and V
         for (int d = 0; d < D; d++) {
             o[d] = o[d] * alpha;
-            for (int tt = 0; tt < TT; tt++) {
-                o[d] += smxi[tt] * v[(tt * D) + d];
+        }
+        for (int d = 0; d < D; d++) {
+            float update = 0.0f;
+            for (int tt = 0; tt < BLOCK_N; tt++) {
+                update += smqk[tt] * v[(tt * D) + d];
             }
+            o[d] += update;
         }
+#else
+        // Update the un-scaled output from softmax(qk) and V
+        for (int d = 0; d < D; d++) {
+            float update = o[d] * alpha;
+            for (int tt = 0; tt < BLOCK_N; tt++) {
+                update += smqk[tt] * v[(tt * D) + d];
+            }
+            o[d] = update;
+        }
+#endif
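+        // note: both branches compute the same result; the #if 1 path separates
+        // the alpha rescale of o from the accumulation pass over d.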
     }

     // Epilog scaling (flash attention 2)
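+    // note: dividing by di once here, instead of rescaling o on every
+    // iteration, is the flash attention 2 refinement.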
     for (int d = 0; d < D; d++) {
         o[d] = o[d] * native_recip(di);
     }
 }
+#undef BLOCK_N
+
+
+// This is a slightly more complicated flash attention kernel.
+// For this kernel, each work-group computes one row of D elements of the output.
+#define BLOCK_N D
+__attribute__((reqd_work_group_size(BLOCK_N, 1, 1)))
+kernel void flash_attention_wg(
+    global const float* restrict Q, global const float* restrict K, global const float* restrict V,
+    global float* restrict O,
+    const float scale)
+{
+    // scale the scale, so we can use exp2 instead of exp
+    const float adjusted_scale = scale * M_LOG2E_F;
+
+    int to = get_group_id(0);
+    int nh = get_group_id(1);
+    int b = get_group_id(2);
+
+    global float* o = O + b * NH * T * D + nh * T * D + to * D;
+
+    global const float* q = Q + b * NH * T * D + nh * T * D + to * D;
+    float mi = -INFINITY;
+    float di = 0.0f;
+
+    local float scratch[D];
+
+    // local id: this work-item's key within the block and its output element;
+    // declared before the loop so it is in scope for the epilog scaling below
+    int tt = get_local_id(0);
+
+    for (int ti = 0; ti < T; ti += BLOCK_N) {
+        global const float* k = K + b * NH * T * D + nh * T * D + ti * D;
+        global const float* v = V + b * NH * T * D + nh * T * D + ti * D;
+
+        // Compute qk = QK^T
+        float qk = 0.0f;
+        for (int d = 0; d < D; d++) {
+            qk += q[d] * k[(tt * D) + d];
+        }
+        qk *= adjusted_scale;
+
+        // TODO: find a better way to do this
+        if (CAUSAL) {
+            qk = (to < ti + tt) ? -INFINITY : qk;
+        }
+
+        // Update the running maximum
+        float mim1 = mi;
+        barrier(CLK_LOCAL_MEM_FENCE);
+        scratch[tt] = qk;
+        barrier(CLK_LOCAL_MEM_FENCE);
+        for (int d = 0; d < D; d++) {
+            mi = fmax(mi, scratch[d]); // max reduction of qk
+        }
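+        // every work-item scans all of scratch, so each one computes the same
+        // running maximum: a simple local-memory all-reduce across the group.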
+
+        // softmax(qk)
+        float smqk = native_exp2(qk - mi);
+        float l = 0.0f;
+        barrier(CLK_LOCAL_MEM_FENCE);
+        scratch[tt] = smqk;
+        barrier(CLK_LOCAL_MEM_FENCE);
+        for (int d = 0; d < D; d++) {
+            l = l + scratch[d]; // add reduction of smqk
+        }
+
+        // Update di
+        float alpha = native_exp2(mim1 - mi);
+        di = di * alpha + l;
+
+        // Update the un-scaled output from softmax(qk) and V: work-item tt owns
+        // output element tt and needs every key's smqk, still held in scratch
+        float acc = o[tt] * alpha;
+        for (int t2 = 0; t2 < BLOCK_N; t2++) {
+            acc += scratch[t2] * v[(t2 * D) + tt];
+        }
+        o[tt] = acc;
+    }
+
+    // Epilog scaling (flash attention 2)
+    o[tt] = o[tt] * native_recip(di);
+}