diff --git a/samples/99_attentionexperiments/attention_kernels.cl b/samples/99_attentionexperiments/attention_kernels.cl
index d050107..402e9f2 100644
--- a/samples/99_attentionexperiments/attention_kernels.cl
+++ b/samples/99_attentionexperiments/attention_kernels.cl
@@ -13,10 +13,10 @@
 #define NH 12
 #endif
 
+#define D (C / NH) // head size
+
 kernel void naive_3p_query_key(global float* preatt, global const float* q_base, global const float* k_base)
 {
-    const int D = C / NH; // head size
-
     // Note: Global work size is B * NH * T * T
     int idx = get_global_id(0);
 
@@ -86,8 +86,6 @@ kernel void naive_3p_softmax(global float* att, global const float* preatt)
 
 kernel void naive_3p_value(global float* out, global const float* att, global const float* v_base)
 {
-    const int D = C / NH; // head size
-
     // Note: Global work size is B * T * NH
     int idx = get_global_id(0);
    int nh = idx % NH;
@@ -117,7 +115,6 @@ kernel void naive_3p_value(global float* out, global const float* att, global co
 // https://github.com/tspeterkim/flash-attention-minimal
 kernel void flash_attention_minimal(
     global const float* Q, global const float* K, global const float* V,
-    const int Tc, const int Tr,
     const int Bc, const int Br, // Note: this implementation requires Bc == Br!!
     const float softmax_scale,
     global float* l, global float* m, global float* O,
@@ -126,13 +123,15 @@ kernel void flash_attention_minimal(
     local float* Vc,
     local float* SP) // Used for both S and P
 {
-    const int D = C / NH; // head size
+    // scale the scale, so we can use exp2 instead of exp
+    const float adjusted_scale = softmax_scale * M_LOG2E_F;
 
     // Note: Global Work Size is (B * Bc, NH)
     // Note: Local Work Size is (Bc, 1)
     // --> Group ID is (batch index, head index)
     const int b = get_group_id(0);
     const int nh = get_group_id(1);
+
     // --> Local ID is row or column index
     const int rc = get_local_id(0);
 
@@ -142,20 +141,20 @@ kernel void flash_attention_minimal(
     // Note: l, m are (B, NH, T)
     const int lm_offset = (b * NH * T) + (nh * T); // offset for l and m
 
-    for (int j = 0; j < Tc; j++) {
+    for (int to = 0; to < T; to += Bc) {
         // Load K, V to SLM
         // Each work-item loads one row of Kc and Vc.
         for (int d = 0; d < D; d++) {
-            Kc[(rc * D) + d] = K[qkv_offset + ((Bc * j + rc) * D) + d];
-            Vc[(rc * D) + d] = V[qkv_offset + ((Bc * j + rc) * D) + d];
+            Kc[(rc * D) + d] = K[qkv_offset + ((to + rc) * D) + d];
+            Vc[(rc * D) + d] = V[qkv_offset + ((to + rc) * D) + d];
         }
         barrier(CLK_LOCAL_MEM_FENCE);
 
-        for (int i = 0; i < Tr; i++) {
+        for (int ti = 0; ti < T; ti += Br) {
            // Load Q to SLM
            // Each work-item loads one row of Qc
            for (int d = 0; d < D; d++) {
-                Qc[(rc * D) + d] = Q[qkv_offset + ((Bc * i + rc) * D) + d];
+                Qc[(rc * D) + d] = Q[qkv_offset + ((ti + rc) * D) + d];
            }
 
            // Compute SP = QK^T, mi_local = rowmax(SP)
@@ -164,14 +163,14 @@
            float mi_local = -INFINITY;
            for (int y = 0; y < Bc; y++) {
                float xi = 0;
-                if (CAUSAL && i * Br + rc < j * Bc + y) {
+                if (CAUSAL && ti + rc < to + y) {
                    xi = -INFINITY;
                }
                else {
                    for (int d = 0; d < D; d++) {
                        xi += Qc[(rc * D) + d] * Kc[(y * D) + d];
                    }
-                    xi *= softmax_scale;
+                    xi *= adjusted_scale;
                }
                SP[(Bc * rc) + y] = xi;
                mi_local = fmax(xi, mi_local);
@@ -181,52 +180,59 @@
            // implement softmax with causal masking
            float vm = 0;
            for (int y = 0; y < Bc; y++) {
-                SP[(Bc * rc) + y] = native_exp(SP[(Bc * rc) + y] - mi_local);
+                SP[(Bc * rc) + y] = native_exp2(SP[(Bc * rc) + y] - mi_local);
                vm += SP[(Bc * rc) + y];
            }
 
            // Compute new m and l
-            float mim1 = m[lm_offset + (Br * i) + rc];
-            float dim1 = l[lm_offset + (Br * i) + rc];
+            float mim1 = m[lm_offset + ti + rc];
+            float dim1 = l[lm_offset + ti + rc];
 
            float mi = fmax(mim1, mi_local);
-            float di = dim1 * native_exp(mim1 - mi) + vm * native_exp(mi_local - mi);
+            float di = dim1 * native_exp2(mim1 - mi) + vm * native_exp2(mi_local - mi);
 
-            float om = dim1 * native_exp(mim1 - mi) / di;
-            vm = native_exp(mi_local - mi) / di;
+            float om = dim1 * native_exp2(mim1 - mi) / di;
+            vm = native_exp2(mi_local - mi) / di;
 
            // Write O, l, m to HBM
            for (int d = 0; d < D; d++) {
                float pv = 0; // Pij * Vc
                for (int y = 0; y < Bc; y++) {
-                    //if (j * Bc + y >= T) {
+                    //if (to * Bc + y >= T) {
                    //    break;
                    //}
                    pv += SP[(Bc * rc) + y] * Vc[(y * D) + d];
                }
 
                // O is (B, NH, T, D)
-                O[qkv_offset + ((Bc * i + rc) * D) + d] =
-                    om * O[qkv_offset + ((Bc * i + rc) * D) + d] +
-                    vm * pv;
+                O[qkv_offset + ((ti + rc) * D) + d] =
+                    om * O[qkv_offset + ((ti + rc) * D) + d] +
+                    vm * pv;
            }
 
-            m[lm_offset + (Br * i) + rc] = mi;
-            l[lm_offset + (Br * i) + rc] = di;
        }
        barrier(CLK_LOCAL_MEM_FENCE); // otherwise, thread can use the wrong Kc, Vc in inner loop
    }
 }
 
+            m[lm_offset + ti + rc] = mi;
+            l[lm_offset + ti + rc] = di;
+// This is a very basic flash attention kernel.
+// For this kernel, each work-item computes one row of D elements of the output.
+// There is no caching of the Q or O data.
+// There is also no sharing of the K or V data.
 kernel void flash_attention(
     global const float* Q, global const float* K, global const float* V,
     global float* O,
     const float scale)
 {
-    const int D = C / NH; // head size
+    // Note: all data is arranged: B x NH x T x D
+
+    // scale the scale, so we can use exp2 instead of exp
+    const float adjusted_scale = scale * M_LOG2E_F;
 
-    int b = get_global_id(0);
+    int to = get_global_id(0);
     int nh = get_global_id(1);
-    int to = get_global_id(2);
+    int b = get_global_id(2);
 
     float* o = O + b * NH * T * D + nh * T * D + to * D;
 
@@ -247,7 +253,7 @@ kernel void flash_attention(
            for (int d = 0; d < D; d++) {
                xi += q[d] * k[d];
            }
-            xi *= scale;
+            xi *= adjusted_scale;
        }
 
        // Update the running maximum
@@ -255,10 +261,10 @@
        mi = fmax(mim1, xi);
 
        // softmax(xi)
-        float smxi = native_exp(xi - mi);
+        float smxi = native_exp2(xi - mi);
 
        // Update di
-        float alpha = native_exp(mim1 - mi);
+        float alpha = native_exp2(mim1 - mi);
        di = di * alpha + smxi;
 
        // Update the un-scaled output from softmax(xi) and V
@@ -273,16 +279,21 @@
    }
 }
 
-kernel void flash_attention_colblock(
+// This is a slightly more complicated flash attention kernel.
+// For this kernel, each work-item still computes one row of D elements of the output.
+// There is caching for the Q, O, K, and V data.
+__attribute__((reqd_work_group_size(32, 1, 1)))
+kernel void flash_attention_blocked(
     global const float* Q, global const float* K, global const float* V,
     global float* O,
     const float scale)
 {
-    const int D = C / NH; // head size
+    // scale the scale, so we can use exp2 instead of exp
+    const float adjusted_scale = scale * M_LOG2E_F;
 
-    int b = get_global_id(0);
-    int nh = get_global_id(1);
-    int to = get_global_id(2);
+    int to = get_global_id(0);
+    int nh = get_group_id(1);
+    int b = get_group_id(2);
 
     float* o = O + b * NH * T * D + nh * T * D + to * D;
 
@@ -290,10 +301,26 @@ kernel void flash_attention_colblock(
     float mi = -INFINITY;
     float di = 0.0f;
 
+    float oc[D];
+    float qc[D];
+    for (int d = 0; d < D; d++) {
+        qc[d] = q[d];
+    }
+
+    local float kc[D];
+    local float vc[D];
+
     for (int ti = 0; ti < T; ti++) {
        const float* k = K + b * NH * T * D + nh * T * D + ti * D;
        const float* v = V + b * NH * T * D + nh * T * D + ti * D;
 
+        barrier(CLK_LOCAL_MEM_FENCE);
+        for (int d = get_local_id(0); d < D; d += get_local_size(0)) {
+            kc[d] = k[d];
+            vc[d] = v[d];
+        }
+        barrier(CLK_LOCAL_MEM_FENCE);
+
        // Compute xi = QK^T
        float xi = 0.0;
        if (CAUSAL && to < ti) {
@@ -301,9 +328,9 @@
        }
        else {
            for (int d = 0; d < D; d++) {
-                xi += q[d] * k[d];
+                xi += qc[d] * kc[d];
            }
-            xi *= scale;
+            xi *= adjusted_scale;
        }
 
        // Update the running maximum
@@ -311,20 +338,20 @@
        mi = fmax(mim1, xi);
 
        // softmax(xi)
-        float smxi = native_exp(xi - mi);
+        float smxi = native_exp2(xi - mi);
 
        // Update di
-        float alpha = native_exp(mim1 - mi);
+        float alpha = native_exp2(mim1 - mi);
        di = di * alpha + smxi;
 
        // Update the un-scaled output from softmax(xi) and V
        for (int d = 0; d < D; d++) {
-            o[d] = o[d] * alpha + smxi * v[d];
+            oc[d] = oc[d] * alpha + smxi * vc[d];
        }
    }
 
    // Epilog scaling (flash attention 2)
    for (int d = 0; d < D; d++) {
-        o[d] = o[d] * native_recip(di);
+        o[d] = oc[d] * native_recip(di);
    }
 }
diff --git a/samples/99_attentionexperiments/main.cpp b/samples/99_attentionexperiments/main.cpp
index 5884a87..047b5df 100644
--- a/samples/99_attentionexperiments/main.cpp
+++ b/samples/99_attentionexperiments/main.cpp
@@ -701,18 +701,16 @@ static void flash_attention_minimal_forward(
     flash_attention.setArg(0, q);
     flash_attention.setArg(1, k);
     flash_attention.setArg(2, v);
-    flash_attention.setArg(3, Tc);
-    flash_attention.setArg(4, Tr);
-    flash_attention.setArg(5, Bc);
-    flash_attention.setArg(6, Br);
-    flash_attention.setArg(7, softmax_scale);
-    flash_attention.setArg(8, l);
-    flash_attention.setArg(9, m);
-    flash_attention.setArg(10, out);
-    flash_attention.setArg(11, cl::Local(Q_LMSize));
-    flash_attention.setArg(12, cl::Local(KV_LMSize));
-    flash_attention.setArg(13, cl::Local(KV_LMSize));
-    flash_attention.setArg(14, cl::Local(S_LMSize));
+    flash_attention.setArg(3, Bc);
+    flash_attention.setArg(4, Br);
+    flash_attention.setArg(5, softmax_scale);
+    flash_attention.setArg(6, l);
+    flash_attention.setArg(7, m);
+    flash_attention.setArg(8, out);
+    flash_attention.setArg(9, cl::Local(Q_LMSize));
+    flash_attention.setArg(10, cl::Local(KV_LMSize));
+    flash_attention.setArg(11, cl::Local(KV_LMSize));
+    flash_attention.setArg(12, cl::Local(S_LMSize));
 
     if (!skipinit) {
        queue.enqueueFillBuffer(out, 0, 0, out_ref.size() * sizeof(out_ref[0]));
@@ -756,13 +754,14 @@ static void flash_attention_forward(
 
     printf("%80s: ", label.c_str()); fflush(stdout);
 
-    cl::Kernel flash_attention{program, "flash_attention"};
+    cl::Kernel flash_attention{program, kernelName.c_str()};
     if (flash_attention() == nullptr) {
        printf("unsupported.\n");
     } else {
        const float softmax_scale = 1.0f / (float)std::sqrt(D);
 
-        cl::NDRange flash_attention_gws(B, NH, T);
+        cl::NDRange flash_attention_gws(T, NH, B);
+        cl::NDRange flash_attention_lws(32, 1, 1);
        flash_attention.setArg(0, q);
        flash_attention.setArg(1, k);
        flash_attention.setArg(2, v);
@@ -776,7 +775,7 @@ static void flash_attention_forward(
        float best = 999.0f;
        for (int test = 0; test < testIterations; test++) {
            auto start = test_clock::now();
-            queue.enqueueNDRangeKernel(flash_attention, cl::NullRange, flash_attention_gws);
+            queue.enqueueNDRangeKernel(flash_attention, cl::NullRange, flash_attention_gws, flash_attention_lws);
            queue.finish();
            auto end = test_clock::now();
            std::chrono::duration<float> sw_time = end - start;
@@ -974,7 +973,7 @@ int main(int argc, char** argv)
     }
 #endif
 
-#if 1
+#if 0
     {
        std::vector<float> tout_vec(B * T * C);
 
@@ -991,7 +990,7 @@ int main(int argc, char** argv)
     }
 #endif
 
-#if 1
+#if 0
     {
        std::vector<float> tout_vec(B * T * C);
 
@@ -1008,7 +1007,7 @@ int main(int argc, char** argv)
     }
 #endif
 
-#if 1
+#if 0
     {
        std::vector<float> tout_vec(B * T * C);
 
@@ -1037,11 +1036,11 @@ int main(int argc, char** argv)
 
     printf("Running tests...\n");
 
-    if (mask & 0x2) {
+    if (mask & 0x1) {
        naive_3p_attention_forward(context, program, queue, out, preatt, att, q, k, v, wgSize, out_vec, preatt_vec, att_vec);
     }
 
-    if (mask & 0x20) {
+    if (mask & 0x10) {
        flash_attention_minimal_forward(context, program, queue, out, q, k, v, wgSize, out_vec);
     }
 
@@ -1049,8 +1048,8 @@ int main(int argc, char** argv)
        flash_attention_forward("flash_attention", context, program, queue, out, q, k, v, wgSize, out_vec);
     }
 
-    if (mask & 0x20) {
-        flash_attention_forward("flash_attention_colblock", context, program, queue, out, q, k, v, wgSize, out_vec);
+    if (mask & 0x40) {
+        flash_attention_forward("flash_attention_blocked", context, program, queue, out, q, k, v, wgSize, out_vec);
     }
 
     printf("Done.\n");
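
Note on the exp2 rewrite that runs through this patch: each kernel folds M_LOG2E_F into the softmax scale once (adjusted_scale = scale * M_LOG2E_F) so that every later exponential can call native_exp2 instead of native_exp, relying on the identity exp(x) = exp2(x * log2(e)). The running-max subtraction and the l/di normalizer are computed in the same base, so the resulting softmax weights are unchanged. Below is a minimal host-side sketch of that identity; it is standalone C++ and not part of the sample, and the head size, test values, and the explicit log2(e) constant are illustrative assumptions.

// Standalone sketch: exp(x * s) == exp2(x * (s * log2(e))), the identity behind
// adjusted_scale = softmax_scale * M_LOG2E_F in the kernels above.
#include <cmath>
#include <cstdio>

int main() {
    const float kLog2e = 1.4426950408889634f;             // corresponds to OpenCL's M_LOG2E_F
    const float softmax_scale = 1.0f / std::sqrt(64.0f);  // assumes head size D = 64 (illustrative)
    const float adjusted_scale = softmax_scale * kLog2e;

    for (float x : {-8.0f, -1.0f, 0.0f, 0.5f, 3.0f}) {
        const float ref = std::exp(x * softmax_scale);    // what native_exp computed before
        const float alt = std::exp2(x * adjusted_scale);  // what native_exp2 computes now
        std::printf("x=%6.2f  exp=%.7f  exp2=%.7f  diff=%g\n", x, ref, alt, ref - alt);
    }
    return 0;
}

Any difference should be at the level of float rounding; the motivation for the change is that a base-2 exponential generally maps more directly to GPU hardware than a base-e one.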