#define NH 12
#endif

+ #define D (C / NH)  // head size
+
kernel void naive_3p_query_key(global float* preatt, global const float* q_base, global const float* k_base)
{
-     const int D = C / NH;  // head size
-
    // Note: Global work size is B * NH * T * T
    int idx = get_global_id(0);
@@ -86,8 +86,6 @@ kernel void naive_3p_softmax(global float* att, global const float* preatt)

kernel void naive_3p_value(global float* out, global const float* att, global const float* v_base)
{
-     const int D = C / NH;  // head size
-
    // Note: Global work size is B * T * NH
    int idx = get_global_id(0);
    int nh = idx % NH;
@@ -117,7 +115,6 @@ kernel void naive_3p_value(global float* out, global const float* att, global const float* v_base)
// https://github.com/tspeterkim/flash-attention-minimal
kernel void flash_attention_minimal(
    global const float* Q, global const float* K, global const float* V,
-     const int Tc, const int Tr,
    const int Bc, const int Br,  // Note: this implementation requires Bc == Br!!
    const float softmax_scale,
    global float* l, global float* m, global float* O,
@@ -126,13 +123,15 @@ kernel void flash_attention_minimal(
    local float* Vc,
    local float* SP)  // Used for both S and P
{
-     const int D = C / NH;  // head size
+     // scale the scale, so we can use exp2 instead of exp
+     const float adjusted_scale = softmax_scale * M_LOG2E_F;

    // Note: Global Work Size is (B * Bc, NH)
    // Note: Local Work Size is (Bc, 1)
    // --> Group ID is (batch index, head index)
    const int b = get_group_id(0);
    const int nh = get_group_id(1);
+
    // --> Local ID is row or column index
    const int rc = get_local_id(0);

@@ -142,20 +141,20 @@ kernel void flash_attention_minimal(
    // Note: l, m are (B, NH, T)
    const int lm_offset = (b * NH * T) + (nh * T);  // offset for l and m

-     for (int j = 0; j < Tc; j++) {
+     for (int to = 0; to < T; to += Bc) {
        // Load K, V to SLM
        // Each work-item loads one row of Kc and Vc.
        for (int d = 0; d < D; d++) {
-             Kc[(rc * D) + d] = K[qkv_offset + ((Bc * j + rc) * D) + d];
-             Vc[(rc * D) + d] = V[qkv_offset + ((Bc * j + rc) * D) + d];
+             Kc[(rc * D) + d] = K[qkv_offset + ((to + rc) * D) + d];
+             Vc[(rc * D) + d] = V[qkv_offset + ((to + rc) * D) + d];
        }
        barrier(CLK_LOCAL_MEM_FENCE);

-         for (int i = 0; i < Tr; i++) {
+         for (int ti = 0; ti < T; ti += Br) {
            // Load Q to SLM
            // Each work-item loads one row of Qc
            for (int d = 0; d < D; d++) {
-                 Qc[(rc * D) + d] = Q[qkv_offset + ((Bc * i + rc) * D) + d];
+                 Qc[(rc * D) + d] = Q[qkv_offset + ((ti + rc) * D) + d];
            }

            // Compute SP = QK^T, mi_local = rowmax(SP)
@@ -164,14 +163,14 @@ kernel void flash_attention_minimal(
            float mi_local = -INFINITY;
            for (int y = 0; y < Bc; y++) {
                float xi = 0;
-                 if (CAUSAL && i * Br + rc < j * Bc + y) {
+                 if (CAUSAL && ti + rc < to + y) {
                    xi = -INFINITY;
                }
                else {
                    for (int d = 0; d < D; d++) {
                        xi += Qc[(rc * D) + d] * Kc[(y * D) + d];
                    }
-                     xi *= softmax_scale;
+                     xi *= adjusted_scale;
                }
                SP[(Bc * rc) + y] = xi;
                mi_local = fmax(xi, mi_local);
@@ -181,84 +180,91 @@ kernel void flash_attention_minimal(
            // implement softmax with causal masking
            float vm = 0;
            for (int y = 0; y < Bc; y++) {
-                 SP[(Bc * rc) + y] = native_exp(SP[(Bc * rc) + y] - mi_local);
+                 SP[(Bc * rc) + y] = native_exp2(SP[(Bc * rc) + y] - mi_local);
                vm += SP[(Bc * rc) + y];
            }

            // Compute new m and l
-             float mim1 = m[lm_offset + (Br * i) + rc];
-             float dim1 = l[lm_offset + (Br * i) + rc];
+             float mim1 = m[lm_offset + ti + rc];
+             float dim1 = l[lm_offset + ti + rc];

            float mi = fmax(mim1, mi_local);
-             float di = dim1 * native_exp(mim1 - mi) + vm * native_exp(mi_local - mi);
+             float di = dim1 * native_exp2(mim1 - mi) + vm * native_exp2(mi_local - mi);

-             float om = dim1 * native_exp(mim1 - mi) / di;
-             vm = native_exp(mi_local - mi) / di;
+             float om = dim1 * native_exp2(mim1 - mi) / di;
+             vm = native_exp2(mi_local - mi) / di;

            // Write O, l, m to HBM
            for (int d = 0; d < D; d++) {
                float pv = 0;  // Pij * Vc
                for (int y = 0; y < Bc; y++) {
-                     //if (j * Bc + y >= T) {
+                     //if (to * Bc + y >= T) {
                    //    break;
                    //}
                    pv += SP[(Bc * rc) + y] * Vc[(y * D) + d];
                }
                // O is (B, NH, T, D)
-                 O[qkv_offset + ((Bc * i + rc) * D) + d] =
-                     om * O[qkv_offset + ((Bc * i + rc) * D) + d] +
-                     vm * pv;
+                 O[qkv_offset + ((ti + rc) * D) + d] =
+                     om * O[qkv_offset + ((ti + rc) * D) + d] +
+                     vm * pv;
            }

-             m[lm_offset + (Br * i) + rc] = mi;
-             l[lm_offset + (Br * i) + rc] = di;
+             m[lm_offset + ti + rc] = mi;
+             l[lm_offset + ti + rc] = di;
        }
        barrier(CLK_LOCAL_MEM_FENCE);  // otherwise, thread can use the wrong Kc, Vc in inner loop
    }
}
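[Note, not part of the diff] The exp -> exp2 rewrite above relies on the identity exp(x) = exp2(x * log2(e)); M_LOG2E_F is folded into the scale once so the hot loops can call native_exp2, which GPUs generally implement more directly than native_exp. A minimal host-side C sketch of that identity (assuming a math.h that defines M_LOG2E):

    #include <math.h>
    #include <stdio.h>

    int main(void) {
        const float softmax_scale = 0.125f;               // e.g. 1/sqrt(D) for D = 64
        const float adjusted_scale = softmax_scale * (float)M_LOG2E;

        float xi = 3.7f;                                   // an arbitrary pre-softmax score
        printf("exp: %f  exp2: %f\n",
               expf(xi * softmax_scale),                   // what the old kernels computed
               exp2f(xi * adjusted_scale));                // what the new kernels compute
        return 0;
    }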

+ // This is a very basic flash attention kernel.
+ // For this kernel, each work-item computes one row of D elements of the output.
+ // There is no caching of the Q or O data.
+ // There is also no sharing of the K or V data.
kernel void flash_attention(
    global const float* Q, global const float* K, global const float* V,
    global float* O,
    const float scale)
{
-     const int D = C / NH;  // head size
+     // Note: all data is arranged: B x NH x T x D
+
+     // scale the scale, so we can use exp2 instead of exp
+     const float adjusted_scale = scale * M_LOG2E_F;

-     int b = get_global_id(0);
+     int to = get_global_id(0);
    int nh = get_global_id(1);
-     int to = get_global_id(2);
+     int b = get_global_id(2);

-     float* o = O + b * NH * T * D + nh * T * D + to * D;
+     global float* o = O + b * NH * T * D + nh * T * D + to * D;

-     const float* q = Q + b * NH * T * D + nh * T * D + to * D;
+     global const float* q = Q + b * NH * T * D + nh * T * D + to * D;
    float mi = -INFINITY;
    float di = 0.0f;

    for (int ti = 0; ti < T; ti++) {
-         const float* k = K + b * NH * T * D + nh * T * D + ti * D;
-         const float* v = V + b * NH * T * D + nh * T * D + ti * D;
+         global const float* k = K + b * NH * T * D + nh * T * D + ti * D;
+         global const float* v = V + b * NH * T * D + nh * T * D + ti * D;

        // Compute xi = QK^T
-         float xi = 0.0;
+         float xi = 0.0f;
        if (CAUSAL && to < ti) {
            xi = -INFINITY;
        }
        else {
            for (int d = 0; d < D; d++) {
                xi += q[d] * k[d];
            }
-             xi *= scale;
+             xi *= adjusted_scale;
        }

        // Update the running maximum
        float mim1 = mi;
        mi = fmax(mim1, xi);

        // softmax(xi)
-         float smxi = native_exp(xi - mi);
+         float smxi = native_exp2(xi - mi);

        // Update di
-         float alpha = native_exp(mim1 - mi);
+         float alpha = native_exp2(mim1 - mi);
        di = di * alpha + smxi;

        // Update the un-scaled output from softmax(xi) and V
@@ -273,58 +279,79 @@ kernel void flash_attention(
    }
}
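[Note, not part of the diff] flash_attention (and flash_attention_blocked below) keep a running maximum mi and running denominator di, rescale everything accumulated so far by alpha = exp(m_old - m_new) whenever the maximum grows, and divide by di only once in the epilog (the flash attention 2 arrangement). A small single-threaded C sketch of that recurrence, checked against a plain two-pass softmax (illustrative only):

    #include <math.h>
    #include <stdio.h>

    // Streaming softmax-weighted sum over scores x[] and values v[],
    // mirroring the mi / di / alpha updates in the kernels above.
    static float online_softmax_dot(const float* x, const float* v, int n) {
        float mi = -INFINITY, di = 0.0f, o = 0.0f;
        for (int i = 0; i < n; i++) {
            float mim1 = mi;
            mi = fmaxf(mim1, x[i]);
            float smxi = expf(x[i] - mi);
            float alpha = expf(mim1 - mi);   // rescales the running sums
            di = di * alpha + smxi;
            o = o * alpha + smxi * v[i];
        }
        return o / di;                        // epilog scaling
    }

    int main(void) {
        float x[4] = {0.3f, 2.1f, -1.0f, 0.7f};
        float v[4] = {1.0f, 2.0f, 3.0f, 4.0f};

        // Reference: two-pass softmax, then the weighted sum.
        float m = -INFINITY, d = 0.0f, ref = 0.0f;
        for (int i = 0; i < 4; i++) m = fmaxf(m, x[i]);
        for (int i = 0; i < 4; i++) d += expf(x[i] - m);
        for (int i = 0; i < 4; i++) ref += expf(x[i] - m) / d * v[i];

        printf("online: %f  reference: %f\n", online_softmax_dot(x, v, 4), ref);
        return 0;
    }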

- kernel void flash_attention_colblock(
+ // This is a slightly more complicated flash attention kernel.
+ // For this kernel, each work-item still computes one row of D elements of the output.
+ // There is caching for the Q, O, K, and V data.
+ __attribute__((reqd_work_group_size(32, 1, 1)))
+ kernel void flash_attention_blocked(
    global const float* Q, global const float* K, global const float* V,
    global float* O,
    const float scale)
{
-     const int D = C / NH;  // head size
+     // scale the scale, so we can use exp2 instead of exp
+     const float adjusted_scale = scale * M_LOG2E_F;

-     int b = get_global_id(0);
-     int nh = get_global_id(1);
-     int to = get_global_id(2);
+     int to = get_global_id(0);
+     int nh = get_group_id(1);
+     int b = get_group_id(2);

-     float* o = O + b * NH * T * D + nh * T * D + to * D;
+     global float* o = O + b * NH * T * D + nh * T * D + to * D;

-     const float* q = Q + b * NH * T * D + nh * T * D + to * D;
+     global const float* q = Q + b * NH * T * D + nh * T * D + to * D;
    float mi = -INFINITY;
    float di = 0.0f;

+     float oc[D];
+     float qc[D];
+     for (int d = 0; d < D; d++) {
+         qc[d] = q[d];
+     }
+
+     local float kc[D];
+     local float vc[D];
+
    for (int ti = 0; ti < T; ti++) {
-         const float* k = K + b * NH * T * D + nh * T * D + ti * D;
-         const float* v = V + b * NH * T * D + nh * T * D + ti * D;
+         global const float* k = K + b * NH * T * D + nh * T * D + ti * D;
+         global const float* v = V + b * NH * T * D + nh * T * D + ti * D;
+
+         barrier(CLK_LOCAL_MEM_FENCE);
+         for (int d = get_local_id(0); d < D; d += get_local_size(0)) {
+             kc[d] = k[d];
+             vc[d] = v[d];
+         }
+         barrier(CLK_LOCAL_MEM_FENCE);

        // Compute xi = QK^T
-         float xi = 0.0;
+         float xi = 0.0f;
        if (CAUSAL && to < ti) {
            xi = -INFINITY;
        }
        else {
            for (int d = 0; d < D; d++) {
-                 xi += q[d] * k[d];
+                 xi += qc[d] * kc[d];
            }
-             xi *= scale;
+             xi *= adjusted_scale;
        }

        // Update the running maximum
        float mim1 = mi;
        mi = fmax(mim1, xi);

        // softmax(xi)
-         float smxi = native_exp(xi - mi);
+         float smxi = native_exp2(xi - mi);

        // Update di
-         float alpha = native_exp(mim1 - mi);
+         float alpha = native_exp2(mim1 - mi);
        di = di * alpha + smxi;

        // Update the un-scaled output from softmax(xi) and V
        for (int d = 0; d < D; d++) {
-             o[d] = o[d] * alpha + smxi * v[d];
+             oc[d] = oc[d] * alpha + smxi * vc[d];
        }
    }

    // Epilog scaling (flash attention 2)
    for (int d = 0; d < D; d++) {
-         o[d] = o[d] * native_recip(di);
+         o[d] = oc[d] * native_recip(di);
    }
}
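[Note, not part of the diff] Both per-row kernels take the time step from get_global_id(0), the head from dimension 1, and the batch from dimension 2, so a natural launch is a 3D NDRange of (T, NH, B); flash_attention_blocked additionally needs a (32, 1, 1) work-group to satisfy its reqd_work_group_size attribute, which implies T should be a multiple of 32. A hypothetical host-side helper, assuming only the standard OpenCL C API (the helper name and "blocked" flag are illustrative, not from the repository):

    #include <CL/cl.h>

    // Sketch of a launch helper for the two per-row kernels above.
    static cl_int enqueue_flash_attention(cl_command_queue queue, cl_kernel kernel,
                                          size_t B, size_t T, size_t NH, int blocked)
    {
        size_t global[3] = { T, NH, B };   // dim 0 = to, dim 1 = nh, dim 2 = b
        size_t local[3]  = { 32, 1, 1 };   // required by flash_attention_blocked
        return clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global,
                                      blocked ? local : NULL, 0, NULL, NULL);
    }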