work-group kernel is now working
bashbaug committed Dec 27, 2024
1 parent f68deae commit a21a942
Showing 1 changed file with 15 additions and 14 deletions.
samples/99_attentionexperiments/attention_kernels.cl: 29 changes (15 additions, 14 deletions)
@@ -324,8 +324,8 @@ kernel void flash_attention_blocked(
}

// softmax(qk)
float smqk[BLOCK_N];
float l = 0.0f;
float smqk[BLOCK_N];
for (int tt = 0; tt < BLOCK_N; tt++) {
smqk[tt] = native_exp2(qk[tt] - mi);
l = l + smqk[tt]; // add reduction
@@ -369,8 +369,7 @@ kernel void flash_attention_blocked(

// This is a slightly more complicated flash attention kernel.
// For this kernel, each work-group computes one row of D elements of the output.
#define BLOCK_N D
__attribute__((reqd_work_group_size(BLOCK_N, 1, 1)))
__attribute__((reqd_work_group_size(D, 1, 1)))
kernel void flash_attention_wg(
global const float* restrict Q, global const float* restrict K, global const float* restrict V,
global float* restrict O,
@@ -379,32 +378,31 @@ kernel void flash_attention_wg(
// scale the scale, so we can use exp2 instead of exp
const float adjusted_scale = scale * M_LOG2E_F;

int to = get_group_id(0);
int nh = get_group_id(1);
int b = get_group_id(2);
const int tt = get_local_id(0);

const int to = get_group_id(0);
const int nh = get_group_id(1);
const int b = get_group_id(2);

global float* o = O + b * NH * T * D + nh * T * D + to * D;
float acc = o[tt];

global const float* q = Q + b * NH * T * D + nh * T * D + to * D;
float mi = -INFINITY;
float di = 0.0f;

local float scratch[D];

for (int ti = 0; ti < T; ti+=BLOCK_N) {
for (int ti = 0; ti < T; ti+=D) {
global const float* k = K + b * NH * T * D + nh * T * D + ti * D;
global const float* v = V + b * NH * T * D + nh * T * D + ti * D;

int tt = get_local_id(0);

// Compute qk = QK^T
float qk = 0.0f;
for (int d = 0; d < D; d++) {
qk += q[d] * k[(tt * D) + d];
}
qk *= adjusted_scale;

// TODO: find a better way to do this
if (CAUSAL) {
qk = (to < ti + tt) ? -INFINITY : qk;
}
@@ -419,8 +417,8 @@ kernel void flash_attention_wg(
}

// softmax(qk)
float smqk = native_exp2(qk - mi);
float l = 0.0f;
float smqk = native_exp2(qk - mi);
barrier(CLK_LOCAL_MEM_FENCE);
scratch[tt] = smqk;
barrier(CLK_LOCAL_MEM_FENCE);
@@ -433,9 +431,12 @@ kernel void flash_attention_wg(
di = di * alpha + l;

// Update the un-scaled output from softmax(qk) and V
o[tt] = o[tt] * alpha + smqk * v[tt];
acc = acc * alpha;
for (int d = 0; d < D; d++) {
acc += scratch[d] * v[(d * D) + tt];
}
}

// Epilog scaling (flash attention 2)
o[tt] = o[tt] * native_recip(di);
o[tt] = acc * native_recip(di);
}
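
For context on how this kernel is meant to be launched (not part of this commit): with reqd_work_group_size(D, 1, 1) and the group ids mapping to (to, nh, b), the host presumably enqueues one work-group per output row, i.e. a global size of (T * D, NH, B) with a local size of (D, 1, 1). A minimal host-side sketch in C, assuming the kernel arguments have already been set elsewhere in the sample; the helper name is hypothetical:

// Hypothetical host-side launch for flash_attention_wg (not from the sample).
// Assumes B, NH, T, D match the kernel's compile-time definitions and that
// Q, K, V, O, scale, ... were already bound with clSetKernelArg.
#include <CL/cl.h>

static cl_int enqueue_flash_attention_wg(
    cl_command_queue queue, cl_kernel kernel,
    size_t B, size_t NH, size_t T, size_t D)
{
    // One work-group per output row: group id 0 -> token "to", 1 -> head "nh",
    // 2 -> batch "b"; reqd_work_group_size(D, 1, 1) fixes the local size, so
    // dimension 0 must cover T groups of D work-items each.
    const size_t global[3] = { T * D, NH, B };
    const size_t local[3]  = { D, 1, 1 };
    return clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global, local,
                                  0, NULL, NULL);
}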
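
The mi/di/alpha bookkeeping in the hunks above is the streaming ("online") softmax from flash attention, with log2(e) folded into the scale so native_exp2 can replace exp, and the division by the denominator deferred to the epilog (flash attention 2). Below is a plain-C reference for a single output element, processing one key/value row at a time; the function name, the row-major T x D layout of K and V, and the log2(e) literal are assumptions for illustration, not code from the sample:

// Illustrative scalar reference for the online-softmax accumulation in
// flash_attention_wg. Computes output element d_out of one query row.
#include <math.h>

static float attention_one_element(
    const float* q, const float* K, const float* V,
    int T, int D, int d_out, float scale)
{
    // exp(x) == exp2(x * log2(e)), so fold log2(e) into the scale once,
    // mirroring "adjusted_scale = scale * M_LOG2E_F" in the kernel.
    const float adjusted_scale = scale * 1.4426950408889634f;  // log2(e)

    float mi  = -INFINITY;  // running maximum of the scaled scores
    float di  = 0.0f;       // running softmax denominator
    float acc = 0.0f;       // running un-normalized output element

    for (int ti = 0; ti < T; ti++) {
        // qk = q . K[ti], scaled for the exp2 domain
        float qk = 0.0f;
        for (int d = 0; d < D; d++) {
            qk += q[d] * K[ti * D + d];
        }
        qk *= adjusted_scale;

        // When a new maximum appears, rescale the old denominator and output.
        float mi_new = fmaxf(mi, qk);
        float alpha  = exp2f(mi - mi_new);
        float p      = exp2f(qk - mi_new);

        di  = di * alpha + p;
        acc = acc * alpha + p * V[ti * D + d_out];
        mi  = mi_new;
    }

    // Epilog scaling (flash attention 2): normalize once at the end.
    return acc / di;
}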
