add work-group versions (not working yet)
bashbaug committed Dec 22, 2024
1 parent 0ad30b5 commit f68deae
Showing 2 changed files with 192 additions and 57 deletions.
190 changes: 135 additions & 55 deletions samples/99_attentionexperiments/attention_kernels.cl
@@ -162,18 +162,18 @@ kernel void flash_attention_minimal(
         // from one row of Qc and Kc.
         float mi_local = -INFINITY;
         for (int y = 0; y < Bc; y++) {
-            float xi = 0;
+            float qk = 0;
             if (CAUSAL && ti + rc < to + y) {
-                xi = -INFINITY;
+                qk = -INFINITY;
             }
             else {
                 for (int d = 0; d < D; d++) {
-                    xi += Qc[(rc * D) + d] * Kc[(y * D) + d];
+                    qk += Qc[(rc * D) + d] * Kc[(y * D) + d];
                 }
-                xi *= adjusted_scale;
+                qk *= adjusted_scale;
             }
-            SP[(Bc * rc) + y] = xi;
-            mi_local = fmax(xi, mi_local);
+            SP[(Bc * rc) + y] = qk;
+            mi_local = fmax(qk, mi_local);
         }
 
         // SP = exp(SP - mi_local), vm = rowsum(SP)
@@ -221,8 +221,8 @@ kernel void flash_attention_minimal(
 // There is no caching of the Q or O data.
 // There is also no sharing of the K or V data.
 kernel void flash_attention(
-    global const float* Q, global const float* K, global const float* V,
-    global float* O,
+    global const float* restrict Q, global const float* restrict K, global const float* restrict V,
+    global float* restrict O,
     const float scale)
 {
     // Note: all data is arranged: B x NH x T x D
@@ -244,30 +244,30 @@ kernel void flash_attention(
         global const float* k = K + b * NH * T * D + nh * T * D + ti * D;
         global const float* v = V + b * NH * T * D + nh * T * D + ti * D;
 
-        // Compute xi = QK^T
-        float xi = 0.0f;
+        // Compute qk = QK^T
+        float qk = 0.0f;
         if (CAUSAL && to < ti) {
-            xi = -INFINITY;
+            qk = -INFINITY;
         }
         else {
             for (int d = 0; d < D; d++) {
-                xi += q[d] * k[d];
+                qk += q[d] * k[d];
             }
-            xi *= adjusted_scale;
+            qk *= adjusted_scale;
         }
 
         // Update the running maximum
         float mim1 = mi;
-        mi = fmax(mim1, xi);
+        mi = fmax(mim1, qk);
 
-        // softmax(xi)
-        float smxi = native_exp2(xi - mi);
+        // softmax(qk)
+        float smxi = native_exp2(qk - mi);
 
         // Update di
         float alpha = native_exp2(mim1 - mi);
         di = di * alpha + smxi;
 
-        // Update the un-scaled output from softmax(xi) and V
+        // Update the un-scaled output from softmax(qk) and V
         for (int d = 0; d < D; d++) {
             o[d] = o[d] * alpha + smxi * v[d];
         }
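
For reference, all of the kernels in this file implement the same online-softmax recurrence, and the restrict qualifiers added throughout simply assert that the buffers do not alias. The "scale the scale" trick works because exp(s * x) == exp2(s * x * log2(e)), so pre-multiplying by M_LOG2E_F lets the kernels use the cheaper exp2. A minimal scalar sketch of the recurrence in host-side C++ (illustrative only; attention_row_reference and its signature are hypothetical, not part of this commit):

#include <cmath>

// Computes one output row o[0..D) for query row q[0..D) against keys k[T*D]
// and values v[T*D], matching the kernels' math (causal masking omitted).
void attention_row_reference(
    const float* q, const float* k, const float* v,
    float* o, int T, int D, float scale)
{
    // "scale the scale" so exp2 can replace exp: exp(x) == exp2(x * log2(e))
    const float adjusted_scale = scale * 1.4426950408889634f; // log2(e)
    float mi = -INFINITY;   // running maximum of the scaled scores
    float di = 0.0f;        // running softmax denominator
    for (int d = 0; d < D; d++) o[d] = 0.0f;
    for (int ti = 0; ti < T; ti++) {
        float qk = 0.0f;
        for (int d = 0; d < D; d++) qk += q[d] * k[ti * D + d];
        qk *= adjusted_scale;
        float mim1 = mi;
        mi = std::fmax(mim1, qk);
        float smqk  = std::exp2(qk - mi);    // softmax numerator vs. the new max
        float alpha = std::exp2(mim1 - mi);  // rescales all earlier contributions
        di = di * alpha + smqk;
        for (int d = 0; d < D; d++) o[d] = o[d] * alpha + smqk * v[ti * D + d];
    }
    for (int d = 0; d < D; d++) o[d] /= di;  // epilog scaling (flash attention 2)
}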
@@ -281,81 +281,161 @@ kernel void flash_attention(

 // This is a slightly more complicated flash attention kernel.
 // For this kernel, each work-item still computes one row of D elements of the output.
 // There is caching for the Q, O, K, and V data.
-#define TT 16
+#define BLOCK_N D
 kernel void flash_attention_blocked(
-    global const float* Q, global const float* K, global const float* V,
-    global float* O,
+    global const float* restrict Q, global const float* restrict K, global const float* restrict V,
+    global float* restrict O,
     const float scale)
 {
     // scale the scale, so we can use exp2 instead of exp
     const float adjusted_scale = scale * M_LOG2E_F;
 
     int to = get_global_id(0);
-    int nh = get_group_id(1);
-    int b = get_group_id(2);
+    int nh = get_global_id(1);
+    int b = get_global_id(2);
 
     global float* o = O + b * NH * T * D + nh * T * D + to * D;
 
     global const float* q = Q + b * NH * T * D + nh * T * D + to * D;
     float mi = -INFINITY;
     float di = 0.0f;
 
-    for (int ti = 0; ti < T; ti+=TT) {
+    for (int ti = 0; ti < T; ti+=BLOCK_N) {
         global const float* k = K + b * NH * T * D + nh * T * D + ti * D;
         global const float* v = V + b * NH * T * D + nh * T * D + ti * D;
 
-        // Compute xi = QK^T
-        float xi[TT];
-        for (int tt = 0; tt < TT; tt++) {
-            xi[tt] = 0.0f;
-        }
-        for (int d = 0; d < D; d++) {
-            for (int tt = 0; tt < TT; tt++) {
-                xi[tt] += q[d] * k[(tt * D) + d];
-            }
-        }
-        for (int tt = 0; tt < TT; tt++) {
-            xi[tt] *= adjusted_scale;
-        }
-
-        // TODO: find a better way to do this
-        if (CAUSAL) {
-            for (int tt = 0; tt < TT; tt++) {
-                xi[tt] = (to < ti + tt) ? -INFINITY : xi[tt];
+        // Compute qk = QK^T
+        float qk[BLOCK_N];
+        for (int tt = 0; tt < BLOCK_N; tt++) {
+            qk[tt] = 0.0f;
+            for (int d = 0; d < D; d++) {
+                qk[tt] += q[d] * k[(tt * D) + d];
+            }
+        }
+        for (int tt = 0; tt < BLOCK_N; tt++) {
+            qk[tt] *= adjusted_scale;
+            if (CAUSAL) {
+                qk[tt] = (to < ti + tt) ? -INFINITY : qk[tt];
             }
         }
 
         // Update the running maximum
         float mim1 = mi;
-        for (int tt = 0; tt < TT; tt++) {
-            mi = fmax(mi, xi[tt]);
+        for (int tt = 0; tt < BLOCK_N; tt++) {
+            mi = fmax(mi, qk[tt]); // max reduction
        }
 
-        // softmax(xi)
-        float smxi[TT];
-        for (int tt = 0; tt < TT; tt++) {
-            smxi[tt] = native_exp2(xi[tt] - mi);
+        // softmax(qk)
+        float smqk[BLOCK_N];
+        float l = 0.0f;
+        for (int tt = 0; tt < BLOCK_N; tt++) {
+            smqk[tt] = native_exp2(qk[tt] - mi);
+            l = l + smqk[tt]; // add reduction
         }
 
         // Update di
         float alpha = native_exp2(mim1 - mi);
-        di *= alpha;
-        for (int tt = 0; tt < TT; tt++) {
-            di += smxi[tt];
-        }
+        di = di * alpha + l;
 
-        // Update the un-scaled output from softmax(xi) and V
+#if 1
+        // Update the un-scaled output from softmax(qk) and V
         for (int d = 0; d < D; d++) {
             o[d] = o[d] * alpha;
-            for (int tt = 0; tt < TT; tt++) {
-                o[d] += smxi[tt] * v[(tt * D) + d];
-            }
         }
+        for (int d = 0; d < D; d++) {
+            float update = 0.0f;
+            for (int tt = 0; tt < BLOCK_N; tt++) {
+                update += smqk[tt] * v[(tt * D) + d];
+            }
+            o[d] += update;
+        }
+#else
+        // Update the un-scaled output from softmax(qk) and V
+        for (int d = 0; d < D; d++) {
+            float update = o[d] * alpha;
+            for (int tt = 0; tt < BLOCK_N; tt++) {
+                update += smqk[tt] * v[(tt * D) + d];
+            }
+            o[d] = update;
+        }
+#endif
     }
 
     // Epilog scaling (flash attention 2)
     for (int d = 0; d < D; d++) {
         o[d] = o[d] * native_recip(di);
     }
 }
+#undef BLOCK_N
+
+
+// This is a slightly more complicated flash attention kernel.
+// For this kernel, each work-group computes one row of D elements of the output.
+#define BLOCK_N D
+__attribute__((reqd_work_group_size(BLOCK_N, 1, 1)))
+kernel void flash_attention_wg(
+    global const float* restrict Q, global const float* restrict K, global const float* restrict V,
+    global float* restrict O,
+    const float scale)
+{
+    // scale the scale, so we can use exp2 instead of exp
+    const float adjusted_scale = scale * M_LOG2E_F;
+
+    int to = get_group_id(0);
+    int nh = get_group_id(1);
+    int b = get_group_id(2);
+
+    global float* o = O + b * NH * T * D + nh * T * D + to * D;
+
+    global const float* q = Q + b * NH * T * D + nh * T * D + to * D;
+    float mi = -INFINITY;
+    float di = 0.0f;
+
+    local float scratch[D];
+
+    int tt = get_local_id(0); // used in the loop below and in the epilog
+
+    for (int ti = 0; ti < T; ti+=BLOCK_N) {
+        global const float* k = K + b * NH * T * D + nh * T * D + ti * D;
+        global const float* v = V + b * NH * T * D + nh * T * D + ti * D;
+
+        // Compute qk = QK^T
+        float qk = 0.0f;
+        for (int d = 0; d < D; d++) {
+            qk += q[d] * k[(tt * D) + d];
+        }
+        qk *= adjusted_scale;
+
+        // TODO: find a better way to do this
+        if (CAUSAL) {
+            qk = (to < ti + tt) ? -INFINITY : qk;
+        }
+
+        // Update the running maximum
+        float mim1 = mi;
+        barrier(CLK_LOCAL_MEM_FENCE);
+        scratch[tt] = qk;
+        barrier(CLK_LOCAL_MEM_FENCE);
+        for (int d = 0; d < D; d++) {
+            mi = fmax(mi, scratch[d]); // max reduction of qk
+        }
+
+        // softmax(qk)
+        float smqk = native_exp2(qk - mi);
+        float l = 0.0f;
+        barrier(CLK_LOCAL_MEM_FENCE);
+        scratch[tt] = smqk;
+        barrier(CLK_LOCAL_MEM_FENCE);
+        for (int d = 0; d < D; d++) {
+            l = l + scratch[d]; // add reduction of smqk
+        }
+
+        // Update di
+        float alpha = native_exp2(mim1 - mi);
+        di = di * alpha + l;
+
+        // Update the un-scaled output from softmax(qk) and V
+        o[tt] = o[tt] * alpha + smqk * v[tt];
+    }
+
+    // Epilog scaling (flash attention 2)
+    o[tt] = o[tt] * native_recip(di);
+}
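
The new flash_attention_wg kernel assigns one work-group per output row: work-item tt scores key column ti + tt, and the group reduces the per-column maxima and softmax sums through the local scratch buffer, with barriers separating the writes from the reads. As committed, the per-work-item output update o[tt] = o[tt] * alpha + smqk * v[tt] appears to fold in only that work-item's own column, which is consistent with the "not working yet" in the commit title. Below is a speculative sketch of a complete update (the kernel name is hypothetical, not part of this commit; it reuses the fact that scratch[] still holds every column's softmax weight after the add reduction; D, T, NH, and CAUSAL are build-time defines, as for the other kernels in this file):

#define BLOCK_N D
__attribute__((reqd_work_group_size(BLOCK_N, 1, 1)))
kernel void flash_attention_wg_sketch(
    global const float* restrict Q, global const float* restrict K, global const float* restrict V,
    global float* restrict O,
    const float scale)
{
    const float adjusted_scale = scale * M_LOG2E_F;

    int to = get_group_id(0);
    int nh = get_group_id(1);
    int b = get_group_id(2);
    int tt = get_local_id(0);

    global float* o = O + b * NH * T * D + nh * T * D + to * D;
    global const float* q = Q + b * NH * T * D + nh * T * D + to * D;

    float mi = -INFINITY;
    float di = 0.0f;
    float acc = 0.0f;   // un-scaled o[tt], kept in a register until the epilog

    local float scratch[BLOCK_N];

    for (int ti = 0; ti < T; ti += BLOCK_N) {
        global const float* k = K + b * NH * T * D + nh * T * D + ti * D;
        global const float* v = V + b * NH * T * D + nh * T * D + ti * D;

        // Each work-item scores one key column of this block.
        float qk = 0.0f;
        for (int d = 0; d < D; d++) {
            qk += q[d] * k[(tt * D) + d];
        }
        qk *= adjusted_scale;
        if (CAUSAL) {
            qk = (to < ti + tt) ? -INFINITY : qk;
        }

        // Work-group max reduction through local memory.
        barrier(CLK_LOCAL_MEM_FENCE);
        scratch[tt] = qk;
        barrier(CLK_LOCAL_MEM_FENCE);
        float mim1 = mi;
        for (int d = 0; d < BLOCK_N; d++) {
            mi = fmax(mi, scratch[d]);
        }

        // Share every column's softmax weight, then sum them.
        float smqk = native_exp2(qk - mi);
        barrier(CLK_LOCAL_MEM_FENCE);
        scratch[tt] = smqk;
        barrier(CLK_LOCAL_MEM_FENCE);
        float l = 0.0f;
        for (int d = 0; d < BLOCK_N; d++) {
            l += scratch[d];
        }

        float alpha = native_exp2(mim1 - mi);
        di = di * alpha + l;

        // Full output update: every column's weighted V row contributes to
        // this work-item's output element, not just its own column.
        acc *= alpha;
        for (int d = 0; d < BLOCK_N; d++) {
            acc += scratch[d] * v[(d * D) + tt];
        }
    }

    // Epilog scaling (flash attention 2).
    o[tt] = acc * native_recip(di);
}
#undef BLOCK_N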
59 changes: 57 additions & 2 deletions samples/99_attentionexperiments/main.cpp
@@ -761,7 +761,61 @@ static void flash_attention_forward(
         const float softmax_scale = 1.0f / (float)std::sqrt(D);
 
         cl::NDRange flash_attention_gws(T, NH, B);
-        cl::NDRange flash_attention_lws(32, 1, 1);
+        flash_attention.setArg(0, q);
+        flash_attention.setArg(1, k);
+        flash_attention.setArg(2, v);
+        flash_attention.setArg(3, out);
+        flash_attention.setArg(4, softmax_scale);
+
+        if (!skipinit) {
+            queue.enqueueFillBuffer(out, 0, 0, out_ref.size() * sizeof(out_ref[0]));
+        }
+
+        float best = 999.0f;
+        for (int test = 0; test < testIterations; test++) {
+            auto start = test_clock::now();
+            queue.enqueueNDRangeKernel(flash_attention, cl::NullRange, flash_attention_gws);
+            queue.finish();
+            auto end = test_clock::now();
+            std::chrono::duration<float> sw_time = end - start;
+            auto elapsed = sw_time.count();
+            best = std::min(best, elapsed);
+        }
+        printf("Best in %f seconds\n", best);
+
+        if (validate) {
+            printf("Checking results: out... "); fflush(stdout);
+            std::vector<float> out_check(out_ref.size());
+            queue.enqueueReadBuffer(out, CL_TRUE, 0, out_check.size() * sizeof(out_check[0]), out_check.data());
+            check_results(out_check, out_ref);
+            printf(" done!\n");
+        }
+    }
+}
+
+static void flash_attention_forward_wg(
+    const std::string& kernelName,
+    cl::Context& context, cl::Program& program, cl::CommandQueue& queue,
+    cl::Buffer& out,
+    cl::Buffer& q, cl::Buffer& k, cl::Buffer& v,
+    size_t wgSize,
+    const std::vector<float>& out_ref)
+{
+    std::string label(__FUNCTION__);
+    label += "(";
+    label += kernelName;
+    label += ")";
+
+    printf("%80s: ", label.c_str()); fflush(stdout);
+
+    cl::Kernel flash_attention{program, kernelName.c_str()};
+    if (flash_attention() == nullptr) {
+        printf("unsupported.\n");
+    } else {
+        const float softmax_scale = 1.0f / (float)std::sqrt(D);
+
+        cl::NDRange flash_attention_gws(T * D, NH, B);
+        cl::NDRange flash_attention_lws(D, 1, 1);
         flash_attention.setArg(0, q);
         flash_attention.setArg(1, k);
         flash_attention.setArg(2, v);
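
The work-group variant dispatches differently from the other kernels: the global size is (T * D, NH, B) with a local size of (D, 1, 1), so dimension 0 splits into (T * D) / D = T work-groups, and get_group_id(0) in the kernel is the row index "to". The enqueue itself lies outside this hunk; a hedged sketch of what it would look like, assuming the same queue and kernel objects as above:

// Assumed dispatch (the actual enqueue is outside this hunk). The local size
// must match the kernel's reqd_work_group_size(BLOCK_N, 1, 1), with BLOCK_N == D.
queue.enqueueNDRangeKernel(flash_attention, cl::NullRange,
                           flash_attention_gws, flash_attention_lws);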
@@ -1046,10 +1100,11 @@ int main(int argc, char** argv)

     if (mask & 0x20) {
         flash_attention_forward("flash_attention", context, program, queue, out, q, k, v, wgSize, out_vec);
+        flash_attention_forward("flash_attention_blocked", context, program, queue, out, q, k, v, wgSize, out_vec);
     }
 
     if (mask & 0x40) {
-        flash_attention_forward("flash_attention_blocked", context, program, queue, out, q, k, v, wgSize, out_vec);
+        flash_attention_forward_wg("flash_attention_wg", context, program, queue, out, q, k, v, wgSize, out_vec);
     }
 
     printf("Done.\n");
