start to unroll the sequence length loop
bashbaug committed Dec 13, 2024
1 parent 332e630 commit 0ad30b5
Showing 1 changed file with 34 additions and 30 deletions: samples/99_attentionexperiments/attention_kernels.cl
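
In outline, the change strip-mines the kernel's loop over sequence positions by a fixed factor TT = 16: each work-item now scores TT keys per trip through the loop, so every load of q[d] is shared across TT dot products, and the running maximum, the softmax denominator, and the output row are updated once per block of TT positions. The private Q/O copies and the local-memory K/V tiles of the previous version are dropped in favor of direct global reads, with the output accumulated in the O buffer itself. Below is a minimal plain-C sketch of the same loop transformation; it is illustrative only, the function names are not part of the sample, and it indexes K from its base pointer where the kernel instead offsets k by ti * D and reads k[(tt * D) + d].

    #define TT 16

    /* Before: one dot product (one attention score) per trip through the loop. */
    static void score_rows(const float* q, const float* k, float* score, int T, int D)
    {
        for (int ti = 0; ti < T; ti++) {
            float x = 0.0f;
            for (int d = 0; d < D; d++) {
                x += q[d] * k[ti * D + d];
            }
            score[ti] = x;
        }
    }

    /* After: the loop advances by TT and keeps TT partial sums, so each pass
     * over d (and each load of q[d]) is reused by TT keys. Assumes T is a
     * multiple of TT, as the kernel's "ti += TT" loop does. */
    static void score_rows_unrolled(const float* q, const float* k, float* score, int T, int D)
    {
        for (int ti = 0; ti < T; ti += TT) {
            float x[TT] = {0.0f};
            for (int d = 0; d < D; d++) {
                for (int tt = 0; tt < TT; tt++) {
                    x[tt] += q[d] * k[(ti + tt) * D + d];
                }
            }
            for (int tt = 0; tt < TT; tt++) {
                score[ti + tt] = x[tt];
            }
        }
    }

The unrolled form trades TT extra accumulator registers per work-item for fewer passes over K and fewer redundant loads of the query row.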
@@ -282,7 +282,7 @@ kernel void flash_attention(
 // This is a slightly more complicated flash attention kernel.
 // For this kernel, each work-item still computes one row of D elements of the output.
 // There is caching for the Q, O, K, and V data.
-__attribute__((reqd_work_group_size(32, 1, 1)))
+#define TT 16
 kernel void flash_attention_blocked(
     global const float* Q, global const float* K, global const float* V,
     global float* O,
@@ -301,57 +301,61 @@ kernel void flash_attention_blocked(
     float mi = -INFINITY;
     float di = 0.0f;
 
-    float oc[D];
-    float qc[D];
-    for (int d = 0; d < D; d++) {
-        qc[d] = q[d];
-    }
-
-    local float kc[D];
-    local float vc[D];
-
-    for (int ti = 0; ti < T; ti++) {
+    for (int ti = 0; ti < T; ti+=TT) {
         global const float* k = K + b * NH * T * D + nh * T * D + ti * D;
         global const float* v = V + b * NH * T * D + nh * T * D + ti * D;
 
-        barrier(CLK_LOCAL_MEM_FENCE);
-        for (int d = get_local_id(0); d < D; d += get_local_size(0)) {
-            kc[d] = k[d];
-            vc[d] = v[d];
-        }
-        barrier(CLK_LOCAL_MEM_FENCE);
-
         // Compute xi = QK^T
-        float xi = 0.0f;
-        if (CAUSAL && to < ti) {
-            xi = -INFINITY;
+        float xi[TT];
+        for (int tt = 0; tt < TT; tt++) {
+            xi[tt] = 0.0f;
         }
-        else {
-            for (int d = 0; d < D; d++) {
-                xi += qc[d] * kc[d];
+        for (int d = 0; d < D; d++) {
+            for (int tt = 0; tt < TT; tt++) {
+                xi[tt] += q[d] * k[(tt * D) + d];
+            }
+        }
+        for (int tt = 0; tt < TT; tt++) {
+            xi[tt] *= adjusted_scale;
+        }
+
+        // TODO: find a better way to do this
+        if (CAUSAL) {
+            for (int tt = 0; tt < TT; tt++) {
+                xi[tt] = (to < ti + tt) ? -INFINITY : xi[tt];
             }
-            xi *= adjusted_scale;
         }
 
         // Update the running maximum
         float mim1 = mi;
-        mi = fmax(mim1, xi);
+        for (int tt = 0; tt < TT; tt++) {
+            mi = fmax(mi, xi[tt]);
+        }
 
         // softmax(xi)
-        float smxi = native_exp2(xi - mi);
+        float smxi[TT];
+        for (int tt = 0; tt < TT; tt++) {
+            smxi[tt] = native_exp2(xi[tt] - mi);
+        }
 
        // Update di
         float alpha = native_exp2(mim1 - mi);
-        di = di * alpha + smxi;
+        di *= alpha;
+        for (int tt = 0; tt < TT; tt++) {
+            di += smxi[tt];
+        }
 
         // Update the un-scaled output from softmax(xi) and V
         for (int d = 0; d < D; d++) {
-            oc[d] = oc[d] * alpha + smxi * vc[d];
+            o[d] = o[d] * alpha;
+            for (int tt = 0; tt < TT; tt++) {
+                o[d] += smxi[tt] * v[(tt * D) + d];
+            }
         }
     }
 
     // Epilog scaling (flash attention 2)
     for (int d = 0; d < D; d++) {
-        o[d] = oc[d] * native_recip(di);
+        o[d] = o[d] * native_recip(di);
     }
 }
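
For reference, the bookkeeping that the unrolled loop preserves is the usual flash-attention-2 online-softmax recurrence, written here only as a reading aid and assuming adjusted_scale folds the 1/sqrt(D) softmax scaling together with the log2(e) factor needed by the base-2 exponentials (native_exp2). With running maximum m (mi in the kernel), running denominator d (di), and accumulated output row o, a block of TT scaled scores x_t and value rows v_t updates:

\[
\begin{aligned}
m' &= \max\Bigl(m,\ \max_{0 \le t < TT} x_t\Bigr), \qquad \alpha = 2^{\,m - m'},\\
d' &= d\,\alpha + \sum_{t=0}^{TT-1} 2^{\,x_t - m'}, \qquad
o' = o\,\alpha + \sum_{t=0}^{TT-1} 2^{\,x_t - m'}\, v_t,
\end{aligned}
\]

with the epilog scaling o <- o' / d' applied once after the final block. Causally masked positions carry x_t = -INFINITY, so they contribute 2^(-inf) = 0 to both sums, matching the behavior of the previous scalar loop.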
