diff --git a/samples/99_attentionexperiments/attention_kernels.cl b/samples/99_attentionexperiments/attention_kernels.cl
index 0d09490..52e2548 100644
--- a/samples/99_attentionexperiments/attention_kernels.cl
+++ b/samples/99_attentionexperiments/attention_kernels.cl
@@ -282,7 +282,7 @@ kernel void flash_attention(
 // This is a slightly more complicated flash attention kernel.
 // For this kernel, each work-item still computes one row of D elements of the output.
-// There is caching for the Q, O, K, and V data.
-__attribute__((reqd_work_group_size(32, 1, 1)))
+// Each loop iteration processes a block of TT keys and values.
+#define TT 16
 kernel void flash_attention_blocked(
     global const float* Q, global const float* K, global const float* V,
     global float* O,
@@ -301,57 +301,61 @@ kernel void flash_attention_blocked(
     float mi = -INFINITY;
     float di = 0.0f;
 
-    float oc[D];
-    float qc[D];
-    for (int d = 0; d < D; d++) {
-        qc[d] = q[d];
-    }
-
-    local float kc[D];
-    local float vc[D];
-
-    for (int ti = 0; ti < T; ti++) {
+    for (int ti = 0; ti < T; ti+=TT) {
         global const float* k = K + b * NH * T * D + nh * T * D + ti * D;
         global const float* v = V + b * NH * T * D + nh * T * D + ti * D;
 
-        barrier(CLK_LOCAL_MEM_FENCE);
-        for (int d = get_local_id(0); d < D; d += get_local_size(0)) {
-            kc[d] = k[d];
-            vc[d] = v[d];
-        }
-        barrier(CLK_LOCAL_MEM_FENCE);
-
         // Compute xi = QK^T
-        float xi = 0.0f;
-        if (CAUSAL && to < ti) {
-            xi = -INFINITY;
+        float xi[TT];
+        for (int tt = 0; tt < TT; tt++) {
+            xi[tt] = 0.0f;
         }
-        else {
-            for (int d = 0; d < D; d++) {
-                xi += qc[d] * kc[d];
+        for (int d = 0; d < D; d++) {
+            for (int tt = 0; tt < TT; tt++) {
+                xi[tt] += q[d] * k[(tt * D) + d];
+            }
+        }
+        for (int tt = 0; tt < TT; tt++) {
+            xi[tt] *= adjusted_scale;
+        }
+
+        // TODO: find a better way to do this
+        if (CAUSAL) {
+            for (int tt = 0; tt < TT; tt++) {
+                xi[tt] = (to < ti + tt) ? -INFINITY : xi[tt];
             }
-            xi *= adjusted_scale;
         }
 
         // Update the running maximum
         float mim1 = mi;
-        mi = fmax(mim1, xi);
+        for (int tt = 0; tt < TT; tt++) {
+            mi = fmax(mi, xi[tt]);
+        }
 
         // softmax(xi)
-        float smxi = native_exp2(xi - mi);
+        float smxi[TT];
+        for (int tt = 0; tt < TT; tt++) {
+            smxi[tt] = native_exp2(xi[tt] - mi);
+        }
 
         // Update di
         float alpha = native_exp2(mim1 - mi);
-        di = di * alpha + smxi;
+        di *= alpha;
+        for (int tt = 0; tt < TT; tt++) {
+            di += smxi[tt];
+        }
 
         // Update the un-scaled output from softmax(xi) and V
         for (int d = 0; d < D; d++) {
-            oc[d] = oc[d] * alpha + smxi * vc[d];
+            o[d] = o[d] * alpha;
+            for (int tt = 0; tt < TT; tt++) {
+                o[d] += smxi[tt] * v[(tt * D) + d];
+            }
         }
     }
 
     // Epilog scaling (flash attention 2)
     for (int d = 0; d < D; d++) {
-        o[d] = oc[d] * native_recip(di);
+        o[d] = o[d] * native_recip(di);
     }
 }
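
Note (illustrative addition, not part of the patch): the blocked loop above strides ti by TT and reads k[(tt * D) + d] and v[(tt * D) + d] for every tt in [0, TT), so it relies on T being a multiple of TT. The sketch below shows the same TT-blocked online-softmax update with the flash attention 2 epilog scaling as a self-contained OpenCL C kernel for a single batch and head, clamping the final block so an arbitrary T stays in bounds. The kernel name, its argument list, the build-time TT and D definitions (e.g. -DTT=16 -DD=64), and the folding of log2(e) into the scale so native_exp2 can stand in for exp are assumptions made for this sketch, not details taken from the sample; causal masking is omitted.

// Illustrative sketch (not part of the patch): the same TT-blocked online-softmax
// update as flash_attention_blocked, reduced to a single batch and a single head,
// with the last block clamped so T does not have to be a multiple of TT.
// Assumptions: TT and D are supplied as build options (e.g. -DTT=16 -DD=64), the
// kernel name and argument list are hypothetical, causal masking is omitted, and
// log2(e) is folded into the scale here so native_exp2 can be used for exp().
kernel void blocked_attention_row_sketch(
    global const float* Q,      // T x D
    global const float* K,      // T x D
    global const float* V,      // T x D
    global float* O,            // T x D
    int T, float scale)
{
    const int to = (int)get_global_id(0);   // one work-item per output row (global size T)
    global const float* q = Q + to * D;
    global float* o = O + to * D;

    const float adjusted_scale = scale * M_LOG2E_F;

    float mi = -INFINITY;   // running maximum of the scaled scores
    float di = 0.0f;        // running softmax denominator

    float oc[D];            // un-scaled output accumulated in registers
    for (int d = 0; d < D; d++) {
        oc[d] = 0.0f;
    }

    for (int ti = 0; ti < T; ti += TT) {
        global const float* k = K + ti * D;
        global const float* v = V + ti * D;
        const int tcount = min(TT, T - ti);  // clamp the final block

        // xi[tt] = adjusted_scale * dot(q, k_tt) for each key row in this block
        float xi[TT];
        for (int tt = 0; tt < tcount; tt++) {
            float acc = 0.0f;
            for (int d = 0; d < D; d++) {
                acc += q[d] * k[tt * D + d];
            }
            xi[tt] = acc * adjusted_scale;
        }

        // Online softmax: update the running maximum, rescale old sums by alpha,
        // then accumulate this block's contributions.
        float mim1 = mi;
        for (int tt = 0; tt < tcount; tt++) {
            mi = fmax(mi, xi[tt]);
        }
        float alpha = native_exp2(mim1 - mi);

        di *= alpha;
        float smxi[TT];
        for (int tt = 0; tt < tcount; tt++) {
            smxi[tt] = native_exp2(xi[tt] - mi);
            di += smxi[tt];
        }

        for (int d = 0; d < D; d++) {
            float acc = oc[d] * alpha;
            for (int tt = 0; tt < tcount; tt++) {
                acc += smxi[tt] * v[tt * D + d];
            }
            oc[d] = acc;
        }
    }

    // Epilog scaling (flash attention 2): divide by the denominator once at the end.
    for (int d = 0; d < D; d++) {
        o[d] = oc[d] * native_recip(di);
    }
}

Accumulating the un-scaled output in the private oc[] array keeps the per-block rescaling in registers; the patched kernel instead rescales and accumulates o[d] in global memory on every block.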