added a blocked version of flash attention with caching
bashbaug committed Dec 9, 2024
1 parent a130f00 commit 53aa9fa
Showing 2 changed files with 90 additions and 64 deletions.
111 changes: 69 additions & 42 deletions samples/99_attentionexperiments/attention_kernels.cl
@@ -13,10 +13,10 @@
#define NH 12
#endif

#define D (C / NH) // head size

kernel void naive_3p_query_key(global float* preatt, global const float* q_base, global const float* k_base)
{
const int D = C / NH; // head size

// Note: Global work size is B * NH * T * T
int idx = get_global_id(0);

@@ -86,8 +86,6 @@ kernel void naive_3p_softmax(global float* att, global const float* preatt)

kernel void naive_3p_value(global float* out, global const float* att, global const float* v_base)
{
const int D = C / NH; // head size

// Note: Global work size is B * T * NH
int idx = get_global_id(0);
int nh = idx % NH;
@@ -117,7 +115,6 @@ kernel void naive_3p_value(global float* out, global const float* att, global co
// https://github.com/tspeterkim/flash-attention-minimal
kernel void flash_attention_minimal(
global const float* Q, global const float* K, global const float* V,
const int Tc, const int Tr,
const int Bc, const int Br, // Note: this implementation requires Bc == Br!!
const float softmax_scale,
global float* l, global float* m, global float* O,
@@ -126,13 +123,15 @@ kernel void flash_attention_minimal(
local float* Vc,
local float* SP) // Used for both S and P
{
const int D = C / NH; // head size
// scale the scale, so we can use exp2 instead of exp
const float adjusted_scale = softmax_scale * M_LOG2E_F;

// Note: Global Work Size is (B * Bc, NH)
// Note: Local Work Size is (Bc, 1)
// --> Group ID is (batch index, head index)
const int b = get_group_id(0);
const int nh = get_group_id(1);

// --> Local ID is row or column index
const int rc = get_local_id(0);

// Note: l, m are (B, NH, T)
const int lm_offset = (b * NH * T) + (nh * T); // offset for l and m

for (int j = 0; j < Tc; j++) {
for (int to = 0; to < T; to += Bc) {
// Load K, V to SLM
// Each work-item loads one row of Kc and Vc.
for (int d = 0; d < D; d++) {
Kc[(rc * D) + d] = K[qkv_offset + ((Bc * j + rc) * D) + d];
Vc[(rc * D) + d] = V[qkv_offset + ((Bc * j + rc) * D) + d];
Kc[(rc * D) + d] = K[qkv_offset + ((to + rc) * D) + d];
Vc[(rc * D) + d] = V[qkv_offset + ((to + rc) * D) + d];
}
barrier(CLK_LOCAL_MEM_FENCE);

for (int i = 0; i < Tr; i++) {
for (int ti = 0; ti < T; ti += Br) {
// Load Q to SLM
// Each work-item loads one row of Qc
for (int d = 0; d < D; d++) {
Qc[(rc * D) + d] = Q[qkv_offset + ((Bc * i + rc) * D) + d];
Qc[(rc * D) + d] = Q[qkv_offset + ((ti + rc) * D) + d];
}

// Compute SP = QK^T, mi_local = rowmax(SP)
@@ -164,14 +163,14 @@
float mi_local = -INFINITY;
for (int y = 0; y < Bc; y++) {
float xi = 0;
if (CAUSAL && i * Br + rc < j * Bc + y) {
if (CAUSAL && ti + rc < to + y) {
xi = -INFINITY;
}
else {
for (int d = 0; d < D; d++) {
xi += Qc[(rc * D) + d] * Kc[(y * D) + d];
}
xi *= softmax_scale;
xi *= adjusted_scale;
}
SP[(Bc * rc) + y] = xi;
mi_local = fmax(xi, mi_local);
@@ -181,52 +180,59 @@
// implement softmax with causal masking
float vm = 0;
for (int y = 0; y < Bc; y++) {
SP[(Bc * rc) + y] = native_exp(SP[(Bc * rc) + y] - mi_local);
SP[(Bc * rc) + y] = native_exp2(SP[(Bc * rc) + y] - mi_local);
vm += SP[(Bc * rc) + y];
}

// Compute new m and l
float mim1 = m[lm_offset + (Br * i) + rc];
float dim1 = l[lm_offset + (Br * i) + rc];
float mim1 = m[lm_offset + ti + rc];
float dim1 = l[lm_offset + ti + rc];

float mi = fmax(mim1, mi_local);
float di = dim1 * native_exp(mim1 - mi) + vm * native_exp(mi_local - mi);
float di = dim1 * native_exp2(mim1 - mi) + vm * native_exp2(mi_local - mi);

float om = dim1 * native_exp(mim1 - mi) / di;
vm = native_exp(mi_local - mi) / di;
float om = dim1 * native_exp2(mim1 - mi) / di;
vm = native_exp2(mi_local - mi) / di;

// Write O, l, m to HBM
for (int d = 0; d < D; d++) {
float pv = 0; // Pij * Vc
for (int y = 0; y < Bc; y++) {
//if (j * Bc + y >= T) {
//if (to + y >= T) {
// break;
//}
pv += SP[(Bc * rc) + y] * Vc[(y * D) + d];
}
// O is (B, NH, T, D)
O[qkv_offset + ((Bc * i + rc) * D) + d] =
om * O[qkv_offset + ((Bc * i + rc) * D) + d] +
vm * pv;
O[qkv_offset + ((ti + rc) * D) + d] =
om * O[qkv_offset + ((ti + rc) * D) + d] +
vm * pv;
}

m[lm_offset + (Br * i) + rc] = mi;
l[lm_offset + (Br * i) + rc] = di;
m[lm_offset + ti + rc] = mi;
l[lm_offset + ti + rc] = di;
}
barrier(CLK_LOCAL_MEM_FENCE); // otherwise, a work-item could read the wrong Kc, Vc in the inner loop
}
}
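
An aside on the "scale the scale" change above: it relies on the identity exp(s * x) = exp2(s * log2(e) * x), so folding log2(e) into the softmax scale once lets the inner loops call native_exp2, which is typically cheaper than native_exp on GPUs. Because the running maxima stored in m are tracked in the same log2 domain, the differences passed to native_exp2 stay consistent. A minimal host-side check of the identity in plain C++ (illustrative only, not part of the sample, using std::exp/std::exp2 rather than the OpenCL built-ins):

#include <cmath>
#include <cstdio>

int main() {
    const float kLog2E = 1.442695040888963f;        // same value as OpenCL's M_LOG2E_F
    const float softmax_scale = 0.125f;             // e.g. 1/sqrt(D) with D = 64
    const float adjusted_scale = softmax_scale * kLog2E;

    const float x = 3.7f;                           // an arbitrary q.k dot product
    const float a = std::exp(x * softmax_scale);    // what the old code computed via native_exp
    const float b = std::exp2(x * adjusted_scale);  // what the new code computes via native_exp2
    printf("exp: %f  exp2: %f\n", a, b);            // the two agree to within float rounding
    return 0;
}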

// This is a very basic flash attention kernel.
// For this kernel, each work-item computes one row of D elements of the output.
// There is no caching of the Q or O data.
// There is also no sharing of the K or V data.
kernel void flash_attention(
global const float* Q, global const float* K, global const float* V,
global float* O,
const float scale)
{
const int D = C / NH; // head size
// Note: all data is arranged: B x NH x T x D

// scale the scale, so we can use exp2 instead of exp
const float adjusted_scale = scale * M_LOG2E_F;

int b = get_global_id(0);
int to = get_global_id(0);
int nh = get_global_id(1);
int to = get_global_id(2);
int b = get_global_id(2);

float* o = O + b * NH * T * D + nh * T * D + to * D;

@@ -247,18 +253,18 @@ kernel void flash_attention(
for (int d = 0; d < D; d++) {
xi += q[d] * k[d];
}
xi *= scale;
xi *= adjusted_scale;
}

// Update the running maximum
float mim1 = mi;
mi = fmax(mim1, xi);

// softmax(xi)
float smxi = native_exp(xi - mi);
float smxi = native_exp2(xi - mi);

// Update di
float alpha = native_exp(mim1 - mi);
float alpha = native_exp2(mim1 - mi);
di = di * alpha + smxi;

// Update the un-scaled output from softmax(xi) and V
@@ -273,58 +279,79 @@
}
}
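
For reference, the per-row recurrence that flash_attention implements (a running maximum mi, a running denominator di, and an un-normalized output that is rescaled by alpha whenever the maximum grows, with a single division in the epilog in the FlashAttention-2 style) looks like this as scalar code. This is a hypothetical reference for one query row, with causal masking omitted for brevity; it is not part of the sample:

#include <algorithm>
#include <cmath>
#include <vector>

// One query row of o = softmax(q . K^T * scale) . V, computed in a single streaming pass.
std::vector<float> attention_row(const std::vector<float>& q,               // D
                                 const std::vector<std::vector<float>>& K,  // T x D
                                 const std::vector<std::vector<float>>& V,  // T x D
                                 float scale)
{
    const size_t D = q.size();
    std::vector<float> o(D, 0.0f);
    float mi = -INFINITY;   // running maximum of the scaled logits
    float di = 0.0f;        // running sum of exp(logit - mi)

    for (size_t t = 0; t < K.size(); t++) {
        float x = 0.0f;
        for (size_t d = 0; d < D; d++) x += q[d] * K[t][d];
        x *= scale;

        const float mim1 = mi;
        mi = std::max(mi, x);
        const float alpha = std::exp(mim1 - mi);   // rescales everything accumulated so far
        const float smxi  = std::exp(x - mi);

        di = di * alpha + smxi;
        for (size_t d = 0; d < D; d++) o[d] = o[d] * alpha + smxi * V[t][d];
    }

    // Epilog scaling (flash attention 2): normalize once at the end.
    for (size_t d = 0; d < D; d++) o[d] /= di;
    return o;
}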

kernel void flash_attention_colblock(
// This is a slightly more complicated flash attention kernel.
// For this kernel, each work-item still computes one row of D elements of the output.
// There is caching for the Q, O, K, and V data.
__attribute__((reqd_work_group_size(32, 1, 1)))
kernel void flash_attention_blocked(
global const float* Q, global const float* K, global const float* V,
global float* O,
const float scale)
{
const int D = C / NH; // head size
// scale the scale, so we can use exp2 instead of exp
const float adjusted_scale = scale * M_LOG2E_F;

int b = get_global_id(0);
int nh = get_global_id(1);
int to = get_global_id(2);
int to = get_global_id(0);
int nh = get_group_id(1);
int b = get_group_id(2);

float* o = O + b * NH * T * D + nh * T * D + to * D;

const float* q = Q + b * NH * T * D + nh * T * D + to * D;
float mi = -INFINITY;
float di = 0.0f;

float oc[D];
float qc[D];
for (int d = 0; d < D; d++) {
    qc[d] = q[d];
    oc[d] = 0.0f;   // initialize the private output accumulator; it is read below before it is first written
}

local float kc[D];
local float vc[D];

for (int ti = 0; ti < T; ti++) {
const float* k = K + b * NH * T * D + nh * T * D + ti * D;
const float* v = V + b * NH * T * D + nh * T * D + ti * D;

barrier(CLK_LOCAL_MEM_FENCE);
for (int d = get_local_id(0); d < D; d += get_local_size(0)) {
kc[d] = k[d];
vc[d] = v[d];
}
barrier(CLK_LOCAL_MEM_FENCE);

// Compute xi = QK^T
float xi = 0.0;
if (CAUSAL && to < ti) {
xi = -INFINITY;
}
else {
for (int d = 0; d < D; d++) {
xi += q[d] * k[d];
xi += qc[d] * kc[d];
}
xi *= scale;
xi *= adjusted_scale;
}

// Update the running maximum
float mim1 = mi;
mi = fmax(mim1, xi);

// softmax(xi)
float smxi = native_exp(xi - mi);
float smxi = native_exp2(xi - mi);

// Update di
float alpha = native_exp(mim1 - mi);
float alpha = native_exp2(mim1 - mi);
di = di * alpha + smxi;

// Update the un-scaled output from softmax(xi) and V
for (int d = 0; d < D; d++) {
o[d] = o[d] * alpha + smxi * v[d];
oc[d] = oc[d] * alpha + smxi * vc[d];
}
}

// Epilog scaling (flash attention 2)
for (int d = 0; d < D; d++) {
o[d] = o[d] * native_recip(di);
o[d] = oc[d] * native_recip(di);
}
}
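
A note on the caching in flash_attention_blocked: qc and oc keep each work-item's Q row and output accumulator in private memory, while kc and vc hold one K row and one V row in local memory per work-group, so every global K/V load is shared by all 32 work-items in the group. The local-memory footprint is tiny; a quick back-of-the-envelope check, assuming for illustration C = 768 with the default NH = 12 above (C's actual value comes from the host build options and is not shown in this diff):

#include <cstdio>

int main() {
    const int C  = 768;      // assumed channel count, for illustration only
    const int NH = 12;       // number of heads (default in the kernel file)
    const int D  = C / NH;   // head size = 64

    // kc[D] and vc[D] are the only local-memory buffers in flash_attention_blocked.
    const size_t localBytes = 2u * (size_t)D * sizeof(float);
    printf("local memory per work-group: %zu bytes\n", localBytes);   // 512 bytes

    // Each cached K/V row is read by all 32 work-items in the group, so the blocked
    // kernel issues roughly 1/32 of the global K/V reads of the uncached kernel.
    return 0;
}
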
43 changes: 21 additions & 22 deletions samples/99_attentionexperiments/main.cpp
@@ -701,18 +701,16 @@ static void flash_attention_minimal_forward(
flash_attention.setArg(0, q);
flash_attention.setArg(1, k);
flash_attention.setArg(2, v);
flash_attention.setArg(3, Tc);
flash_attention.setArg(4, Tr);
flash_attention.setArg(5, Bc);
flash_attention.setArg(6, Br);
flash_attention.setArg(7, softmax_scale);
flash_attention.setArg(8, l);
flash_attention.setArg(9, m);
flash_attention.setArg(10, out);
flash_attention.setArg(11, cl::Local(Q_LMSize));
flash_attention.setArg(12, cl::Local(KV_LMSize));
flash_attention.setArg(13, cl::Local(KV_LMSize));
flash_attention.setArg(14, cl::Local(S_LMSize));
flash_attention.setArg(3, Bc);
flash_attention.setArg(4, Br);
flash_attention.setArg(5, softmax_scale);
flash_attention.setArg(6, l);
flash_attention.setArg(7, m);
flash_attention.setArg(8, out);
flash_attention.setArg(9, cl::Local(Q_LMSize));
flash_attention.setArg(10, cl::Local(KV_LMSize));
flash_attention.setArg(11, cl::Local(KV_LMSize));
flash_attention.setArg(12, cl::Local(S_LMSize));

if (!skipinit) {
queue.enqueueFillBuffer(out, 0, 0, out_ref.size() * sizeof(out_ref[0]));
@@ -756,13 +754,14 @@ static void flash_attention_forward(

printf("%80s: ", label.c_str()); fflush(stdout);

cl::Kernel flash_attention{program, "flash_attention"};
cl::Kernel flash_attention{program, kernelName.c_str()};
if (flash_attention() == nullptr) {
printf("unsupported.\n");
} else {
const float softmax_scale = 1.0f / (float)std::sqrt(D);

cl::NDRange flash_attention_gws(B, NH, T);
cl::NDRange flash_attention_gws(T, NH, B);
cl::NDRange flash_attention_lws(32, 1, 1);
flash_attention.setArg(0, q);
flash_attention.setArg(1, k);
flash_attention.setArg(2, v);
float best = 999.0f;
for (int test = 0; test < testIterations; test++) {
auto start = test_clock::now();
queue.enqueueNDRangeKernel(flash_attention, cl::NullRange, flash_attention_gws);
queue.enqueueNDRangeKernel(flash_attention, cl::NullRange, flash_attention_gws, flash_attention_lws);
queue.finish();
auto end = test_clock::now();
std::chrono::duration<float> sw_time = end - start;
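
The reordered ND-range above is what makes the per-group caching legal: with a global size of (T, NH, B) and a local size of (32, 1, 1), dimension 0 enumerates query rows, so each work-group covers 32 consecutive rows of the same (batch, head) pair and can share one K/V row at a time through local memory, whereas the old (B, NH, T) ordering would have grouped work-items across batches. A sketch of the mapping, assuming the same OpenCL C++ bindings as main.cpp and that T is a multiple of 32 (the helper name is illustrative, not part of the sample):

#include <CL/opencl.hpp>   // or whichever OpenCL C++ bindings header main.cpp uses

// get_global_id(0) -> query row "to" (32 consecutive rows per work-group)
// get_group_id(1)  -> head index "nh"
// get_group_id(2)  -> batch index "b"
static void launch_blocked_attention(cl::CommandQueue& queue, cl::Kernel& kernel,
                                     int B, int NH, int T)
{
    cl::NDRange gws(T, NH, B);
    cl::NDRange lws(32, 1, 1);   // must match reqd_work_group_size(32, 1, 1) in the kernel
    queue.enqueueNDRangeKernel(kernel, cl::NullRange, gws, lws);
}
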
@@ -974,7 +973,7 @@ int main(int argc, char** argv)
}
#endif

#if 1
#if 0
{
std::vector<float> tout_vec (B * T * C);

}
#endif

#if 1
#if 0
{
std::vector<float> tout_vec (B * T * C);

Expand All @@ -1008,7 +1007,7 @@ int main(int argc, char** argv)
}
#endif

#if 1
#if 0
{
std::vector<float> tout_vec (B * T * C);

@@ -1037,20 +1036,20 @@

printf("Running tests...\n");

if (mask & 0x2) {
if (mask & 0x1) {
naive_3p_attention_forward(context, program, queue, out, preatt, att, q, k, v, wgSize, out_vec, preatt_vec, att_vec);
}

if (mask & 0x20) {
if (mask & 0x10) {
flash_attention_minimal_forward(context, program, queue, out, q, k, v, wgSize, out_vec);
}

if (mask & 0x20) {
flash_attention_forward("flash_attention", context, program, queue, out, q, k, v, wgSize, out_vec);
}

if (mask & 0x20) {
flash_attention_forward("flash_attention_colblock", context, program, queue, out, q, k, v, wgSize, out_vec);
if (mask & 0x40) {
flash_attention_forward("flash_attention_blocked", context, program, queue, out, q, k, v, wgSize, out_vec);
}

printf("Done.\n");