Commit f68deae

add work-group versions (not working yet)

1 parent 0ad30b5 commit f68deae

2 files changed: +192 -57 lines changed
samples/99_attentionexperiments/attention_kernels.cl

Lines changed: 135 additions & 55 deletions
@@ -162,18 +162,18 @@ kernel void flash_attention_minimal(
         // from one row of Qc and Kc.
         float mi_local = -INFINITY;
         for (int y = 0; y < Bc; y++) {
-            float xi = 0;
+            float qk = 0;
             if (CAUSAL && ti + rc < to + y) {
-                xi = -INFINITY;
+                qk = -INFINITY;
             }
             else {
                 for (int d = 0; d < D; d++) {
-                    xi += Qc[(rc * D) + d] * Kc[(y * D) + d];
+                    qk += Qc[(rc * D) + d] * Kc[(y * D) + d];
                 }
-                xi *= adjusted_scale;
+                qk *= adjusted_scale;
             }
-            SP[(Bc * rc) + y] = xi;
-            mi_local = fmax(xi, mi_local);
+            SP[(Bc * rc) + y] = qk;
+            mi_local = fmax(qk, mi_local);
         }

         // SP = exp(SP - mi_local), vm = rowsum(SP)
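A note on the adjusted_scale used in this hunk: the kernels fold log2(e) (M_LOG2E_F) into the softmax scale so that native_exp2 can stand in for exp, relying on the identity exp(x) = exp2(x * log2(e)). A minimal host-side C++ check of that substitution is sketched below; the values and names are illustrative only and are not taken from the sample:

    // Sketch: folding log2(e) into the softmax scale so exp2 can replace exp.
    #include <cmath>
    #include <cstdio>

    int main() {
        const float softmax_scale = 1.0f / std::sqrt(64.0f);     // assumed D = 64, for illustration
        const float kLog2e = 1.4426950408889634f;                 // log2(e), i.e. M_LOG2E
        const float adjusted_scale = softmax_scale * kLog2e;

        const float score = 3.7f;                                 // an un-scaled Q.K dot product
        float a = std::exp(score * softmax_scale);                // reference form
        float b = std::exp2(score * adjusted_scale);              // the form the kernels compute
        std::printf("exp: %f  exp2: %f\n", a, b);                 // agree to float precision
    }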
@@ -221,8 +221,8 @@ kernel void flash_attention_minimal(
 // There is no caching of the Q or O data.
 // There is also no sharing of the K or V data.
 kernel void flash_attention(
-    global const float* Q, global const float* K, global const float* V,
-    global float* O,
+    global const float* restrict Q, global const float* restrict K, global const float* restrict V,
+    global float* restrict O,
     const float scale)
 {
     // Note: all data is arranged: B x NH x T x D
@@ -244,30 +244,30 @@ kernel void flash_attention(
         global const float* k = K + b * NH * T * D + nh * T * D + ti * D;
         global const float* v = V + b * NH * T * D + nh * T * D + ti * D;

-        // Compute xi = QK^T
-        float xi = 0.0f;
+        // Compute qk = QK^T
+        float qk = 0.0f;
         if (CAUSAL && to < ti) {
-            xi = -INFINITY;
+            qk = -INFINITY;
         }
         else {
             for (int d = 0; d < D; d++) {
-                xi += q[d] * k[d];
+                qk += q[d] * k[d];
             }
-            xi *= adjusted_scale;
+            qk *= adjusted_scale;
         }

         // Update the running maximum
         float mim1 = mi;
-        mi = fmax(mim1, xi);
+        mi = fmax(mim1, qk);

-        // softmax(xi)
-        float smxi = native_exp2(xi - mi);
+        // softmax(qk)
+        float smxi = native_exp2(qk - mi);

         // Update di
         float alpha = native_exp2(mim1 - mi);
         di = di * alpha + smxi;

-        // Update the un-scaled output from softmax(xi) and V
+        // Update the un-scaled output from softmax(qk) and V
         for (int d = 0; d < D; d++) {
             o[d] = o[d] * alpha + smxi * v[d];
         }
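For reference, the per-key update this kernel performs (running maximum mi, rescale factor alpha, running denominator di, and the flash attention 2 epilog division) can be written on the host as the sketch below. It is a CPU illustration of the same online-softmax recurrence, not code from the sample, and it uses plain exp rather than the exp2 form:

    // CPU sketch of the online-softmax accumulation used by flash_attention:
    // for each key/value row, rescale the partial output by alpha = exp(m_old - m_new),
    // add the newly weighted V row, and divide by the denominator once at the end.
    #include <cmath>
    #include <vector>

    std::vector<float> attend_one_query(const std::vector<float>& q,              // D
                                        const std::vector<std::vector<float>>& K, // T x D
                                        const std::vector<std::vector<float>>& V, // T x D
                                        float scale) {
        const size_t D = q.size();
        std::vector<float> o(D, 0.0f);
        float mi = -INFINITY;   // running max of the scaled scores
        float di = 0.0f;        // running softmax denominator

        for (size_t t = 0; t < K.size(); t++) {
            float qk = 0.0f;
            for (size_t d = 0; d < D; d++) qk += q[d] * K[t][d];
            qk *= scale;

            float mim1 = mi;
            mi = std::fmax(mim1, qk);
            float smqk = std::exp(qk - mi);      // softmax numerator for this key
            float alpha = std::exp(mim1 - mi);   // rescales everything accumulated so far
            di = di * alpha + smqk;
            for (size_t d = 0; d < D; d++) o[d] = o[d] * alpha + smqk * V[t][d];
        }
        for (size_t d = 0; d < D; d++) o[d] /= di;  // epilog scaling (flash attention 2)
        return o;
    }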
@@ -281,81 +281,161 @@ kernel void flash_attention(

 // This is a slightly more complicated flash attention kernel.
 // For this kernel, each work-item still computes one row of D elements of the output.
-// There is caching for the Q, O, K, and V data.
-#define TT 16
+#define BLOCK_N D
 kernel void flash_attention_blocked(
-    global const float* Q, global const float* K, global const float* V,
-    global float* O,
+    global const float* restrict Q, global const float* restrict K, global const float* restrict V,
+    global float* restrict O,
     const float scale)
 {
     // scale the scale, so we can use exp2 instead of exp
     const float adjusted_scale = scale * M_LOG2E_F;

     int to = get_global_id(0);
-    int nh = get_group_id(1);
-    int b = get_group_id(2);
+    int nh = get_global_id(1);
+    int b = get_global_id(2);

     global float* o = O + b * NH * T * D + nh * T * D + to * D;

     global const float* q = Q + b * NH * T * D + nh * T * D + to * D;
     float mi = -INFINITY;
     float di = 0.0f;

-    for (int ti = 0; ti < T; ti+=TT) {
+    for (int ti = 0; ti < T; ti+=BLOCK_N) {
         global const float* k = K + b * NH * T * D + nh * T * D + ti * D;
         global const float* v = V + b * NH * T * D + nh * T * D + ti * D;

-        // Compute xi = QK^T
-        float xi[TT];
-        for (int tt = 0; tt < TT; tt++) {
-            xi[tt] = 0.0f;
-        }
-        for (int d = 0; d < D; d++) {
-            for (int tt = 0; tt < TT; tt++) {
-                xi[tt] += q[d] * k[(tt * D) + d];
+        // Compute qk = QK^T
+        float qk[BLOCK_N];
+        for (int tt = 0; tt < BLOCK_N; tt++) {
+            qk[tt] = 0.0f;
+            for (int d = 0; d < D; d++) {
+                qk[tt] += q[d] * k[(tt * D) + d];
             }
-        }
-        for (int tt = 0; tt < TT; tt++) {
-            xi[tt] *= adjusted_scale;
-        }
-
-        // TODO: find a better way to do this
-        if (CAUSAL) {
-            for (int tt = 0; tt < TT; tt++) {
-                xi[tt] = (to < ti + tt) ? -INFINITY : xi[tt];
+            qk[tt] *= adjusted_scale;
+            if (CAUSAL) {
+                qk[tt] = (to < ti + tt) ? -INFINITY : qk[tt];
             }
         }

         // Update the running maximum
         float mim1 = mi;
-        for (int tt = 0; tt < TT; tt++) {
-            mi = fmax(mi, xi[tt]);
+        for (int tt = 0; tt < BLOCK_N; tt++) {
+            mi = fmax(mi, qk[tt]); // max reduction
         }

-        // softmax(xi)
-        float smxi[TT];
-        for (int tt = 0; tt < TT; tt++) {
-            smxi[tt] = native_exp2(xi[tt] - mi);
+        // softmax(qk)
+        float smqk[BLOCK_N];
+        float l = 0.0f;
+        for (int tt = 0; tt < BLOCK_N; tt++) {
+            smqk[tt] = native_exp2(qk[tt] - mi);
+            l = l + smqk[tt]; // add reduction
         }

         // Update di
         float alpha = native_exp2(mim1 - mi);
-        di *= alpha;
-        for (int tt = 0; tt < TT; tt++) {
-            di += smxi[tt];
-        }
+        di = di * alpha + l;

-        // Update the un-scaled output from softmax(xi) and V
+#if 1
+        // Update the un-scaled output from softmax(qk) and V
         for (int d = 0; d < D; d++) {
             o[d] = o[d] * alpha;
-            for (int tt = 0; tt < TT; tt++) {
-                o[d] += smxi[tt] * v[(tt * D) + d];
+        }
+        for (int d = 0; d < D; d++) {
+            float update = 0.0f;
+            for (int tt = 0; tt < BLOCK_N; tt++) {
+                update += smqk[tt] * v[(tt * D) + d];
             }
+            o[d] += update;
         }
+#else
+        // Update the un-scaled output from softmax(qk) and V
+        for (int d = 0; d < D; d++) {
+            float update = o[d] * alpha;
+            for (int tt = 0; tt < BLOCK_N; tt++) {
+                update += smqk[tt] * v[(tt * D) + d];
+            }
+            o[d] = update;
+        }
+#endif
     }

     // Epilog scaling (flash attention 2)
     for (int d = 0; d < D; d++) {
         o[d] = o[d] * native_recip(di);
     }
 }
+#undef BLOCK_N
+
+
+// This is a slightly more complicated flash attention kernel.
+// For this kernel, each work-group computes one row of D elements of the output.
+#define BLOCK_N D
+__attribute__((reqd_work_group_size(BLOCK_N, 1, 1)))
+kernel void flash_attention_wg(
+    global const float* restrict Q, global const float* restrict K, global const float* restrict V,
+    global float* restrict O,
+    const float scale)
+{
+    // scale the scale, so we can use exp2 instead of exp
+    const float adjusted_scale = scale * M_LOG2E_F;
+
+    int to = get_group_id(0);
+    int nh = get_group_id(1);
+    int b = get_group_id(2);
+
+    global float* o = O + b * NH * T * D + nh * T * D + to * D;
+
+    global const float* q = Q + b * NH * T * D + nh * T * D + to * D;
+    float mi = -INFINITY;
+    float di = 0.0f;
+
+    local float scratch[D];
+
+    for (int ti = 0; ti < T; ti+=BLOCK_N) {
+        global const float* k = K + b * NH * T * D + nh * T * D + ti * D;
+        global const float* v = V + b * NH * T * D + nh * T * D + ti * D;
+
+        int tt = get_local_id(0);
+
+        // Compute qk = QK^T
+        float qk = 0.0f;
+        for (int d = 0; d < D; d++) {
+            qk += q[d] * k[(tt * D) + d];
+        }
+        qk *= adjusted_scale;
+
+        // TODO: find a better way to do this
+        if (CAUSAL) {
+            qk = (to < ti + tt) ? -INFINITY : qk;
+        }
+
+        // Update the running maximum
+        float mim1 = mi;
+        barrier(CLK_LOCAL_MEM_FENCE);
+        scratch[tt] = qk;
+        barrier(CLK_LOCAL_MEM_FENCE);
+        for (int d = 0; d < D; d++) {
+            mi = fmax(mi, scratch[d]); // max reduction of qk
+        }
+
+        // softmax(qk)
+        float smqk = native_exp2(qk - mi);
+        float l = 0.0f;
+        barrier(CLK_LOCAL_MEM_FENCE);
+        scratch[tt] = smqk;
+        barrier(CLK_LOCAL_MEM_FENCE);
+        for (int d = 0; d < D; d++) {
+            l = l + scratch[d]; // add reduction of smqk
+        }
+
+        // Update di
+        float alpha = native_exp2(mim1 - mi);
+        di = di * alpha + l;
+
+        // Update the un-scaled output from softmax(qk) and V
+        o[tt] = o[tt] * alpha + smqk * v[tt];
+    }
+
+    // Epilog scaling (flash attention 2)
+    o[tt] = o[tt] * native_recip(di);
+}
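The flash_attention_wg kernel added above assigns the BLOCK_N (= D) key rows of each block to the work-items of one work-group and shares intermediate values through the local scratch array: each item stores its value between barriers, then every item reads all D entries to form the block maximum and block sum. The C++ sketch below imitates only that share-then-reduce step on the host; the sequential loops stand in for concurrent work-items, scratch stands in for local memory, and none of the names are taken from the sample:

    // Sketch: the share-then-reduce pattern flash_attention_wg uses through local memory.
    // On the device, each iteration of the tt loops runs on a different work-item and the
    // writes to scratch are separated from the reads by barrier(CLK_LOCAL_MEM_FENCE).
    #include <cmath>
    #include <vector>

    void block_max_and_sum(const std::vector<float>& scores,   // one scaled score per work-item
                           float& block_max, float& block_sum) {
        const size_t n = scores.size();
        std::vector<float> scratch(n);                          // stands in for the local scratch array

        // "scratch[tt] = qk" on every work-item, between two barriers
        for (size_t tt = 0; tt < n; tt++) scratch[tt] = scores[tt];

        // every work-item then reads all entries: a simple (non-tree) max reduction
        block_max = -INFINITY;
        for (size_t d = 0; d < n; d++) block_max = std::fmax(block_max, scratch[d]);

        // same pattern again for the softmax numerators and their sum
        block_sum = 0.0f;
        for (size_t tt = 0; tt < n; tt++) scratch[tt] = std::exp(scores[tt] - block_max);
        for (size_t d = 0; d < n; d++) block_sum += scratch[d];
    }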

samples/99_attentionexperiments/main.cpp

Lines changed: 57 additions & 2 deletions
@@ -761,7 +761,61 @@ static void flash_attention_forward(
         const float softmax_scale = 1.0f / (float)std::sqrt(D);

         cl::NDRange flash_attention_gws(T, NH, B);
-        cl::NDRange flash_attention_lws(32, 1, 1);
+        flash_attention.setArg(0, q);
+        flash_attention.setArg(1, k);
+        flash_attention.setArg(2, v);
+        flash_attention.setArg(3, out);
+        flash_attention.setArg(4, softmax_scale);
+
+        if (!skipinit) {
+            queue.enqueueFillBuffer(out, 0, 0, out_ref.size() * sizeof(out_ref[0]));
+        }
+
+        float best = 999.0f;
+        for (int test = 0; test < testIterations; test++) {
+            auto start = test_clock::now();
+            queue.enqueueNDRangeKernel(flash_attention, cl::NullRange, flash_attention_gws);
+            queue.finish();
+            auto end = test_clock::now();
+            std::chrono::duration<float> sw_time = end - start;
+            auto elapsed = sw_time.count();
+            best = std::min(best, elapsed);
+        }
+        printf("Best in %f seconds\n", best);
+
+        if (validate) {
+            printf("Checking results: out... "); fflush(stdout);
+            std::vector<float> out_check(out_ref.size());
+            queue.enqueueReadBuffer(out, CL_TRUE, 0, out_check.size() * sizeof(out_check[0]), out_check.data());
+            check_results(out_check, out_ref);
+            printf(" done!\n");
+        }
+    }
+}
+
+static void flash_attention_forward_wg(
+    const std::string& kernelName,
+    cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
+    cl::Buffer& out,
+    cl::Buffer& q, cl::Buffer& k, cl::Buffer& v,
+    size_t wgSize,
+    const std::vector<float>& out_ref)
+{
+    std::string label(__FUNCTION__);
+    label += "(";
+    label += kernelName;
+    label += ")";
+
+    printf("%80s: ", label.c_str()); fflush(stdout);
+
+    cl::Kernel flash_attention{program, kernelName.c_str()};
+    if (flash_attention() == nullptr) {
+        printf("unsupported.\n");
+    } else {
+        const float softmax_scale = 1.0f / (float)std::sqrt(D);
+
+        cl::NDRange flash_attention_gws(T * D, NH, B);
+        cl::NDRange flash_attention_lws(D, 1, 1);
         flash_attention.setArg(0, q);
         flash_attention.setArg(1, k);
         flash_attention.setArg(2, v);
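In the new flash_attention_forward_wg host path, the global size becomes (T * D, NH, B) with a local size of (D, 1, 1), so dimension 0 divides into T work-groups of exactly D work-items each, which matches the kernel's reqd_work_group_size(BLOCK_N, 1, 1) and its use of get_group_id(0) as the output row. A small C++ sketch of that mapping, with assumed sizes that are illustrative only:

    // Sketch: how the (T * D, NH, B) global size and (D, 1, 1) local size map to work-groups.
    #include <cassert>
    #include <cstdio>

    int main() {
        const size_t T = 1024, D = 64, NH = 12, B = 8;   // assumed sizes, for illustration

        const size_t gws0 = T * D, lws0 = D;
        assert(gws0 % lws0 == 0);                        // OpenCL requires an even division
        const size_t groups0 = gws0 / lws0;              // == T, one work-group per output row

        printf("work-groups: %zu x %zu x %zu, %zu work-items each\n",
               groups0, NH, B, lws0);                    // get_group_id(0) plays the role of 'to'
    }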
@@ -1046,10 +1100,11 @@ int main(int argc, char** argv)

     if (mask & 0x20) {
         flash_attention_forward("flash_attention", context, program, queue, out, q, k, v, wgSize, out_vec);
+        flash_attention_forward("flash_attention_blocked", context, program, queue, out, q, k, v, wgSize, out_vec);
     }

     if (mask & 0x40) {
-        flash_attention_forward("flash_attention_blocked", context, program, queue, out, q, k, v, wgSize, out_vec);
+        flash_attention_forward_wg("flash_attention_wg", context, program, queue, out, q, k, v, wgSize, out_vec);
     }

     printf("Done.\n");