Skip to content

Commit

Permalink
added a blocked version of flash attention with caching
Browse files Browse the repository at this point in the history
  • Loading branch information
bashbaug committed Dec 9, 2024
1 parent a130f00 commit 332e630
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 74 deletions.
131 changes: 79 additions & 52 deletions samples/99_attentionexperiments/attention_kernels.cl
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
#define NH 12
#endif

#define D (C / NH) // head size

kernel void naive_3p_query_key(global float* preatt, global const float* q_base, global const float* k_base)
{
const int D = C / NH; // head size

// Note: Global work size is B * NH * T * T
int idx = get_global_id(0);

Expand Down Expand Up @@ -86,8 +86,6 @@ kernel void naive_3p_softmax(global float* att, global const float* preatt)

kernel void naive_3p_value(global float* out, global const float* att, global const float* v_base)
{
const int D = C / NH; // head size

// Note: Global work size is B * T * NH
int idx = get_global_id(0);
int nh = idx % NH;
Expand Down Expand Up @@ -117,7 +115,6 @@ kernel void naive_3p_value(global float* out, global const float* att, global co
// https://github.com/tspeterkim/flash-attention-minimal
kernel void flash_attention_minimal(
global const float* Q, global const float* K, global const float* V,
const int Tc, const int Tr,
const int Bc, const int Br, // Note: this implementation requires Bc == Br!!
const float softmax_scale,
global float* l, global float* m, global float* O,
Expand All @@ -126,13 +123,15 @@ kernel void flash_attention_minimal(
local float* Vc,
local float* SP) // Used for both S and P
{
const int D = C / NH; // head size
// scale the scale, so we can use exp2 instead of exp
const float adjusted_scale = softmax_scale * M_LOG2E_F;

// Note: Global Work Size is (B * Bc, NH)
// Note: Local Work Size is (Bc, 1)
// --> Group ID is (batch index, head index)
const int b = get_group_id(0);
const int nh = get_group_id(1);

// --> Local ID is row or column index
const int rc = get_local_id(0);

Expand All @@ -142,20 +141,20 @@ kernel void flash_attention_minimal(
// Note: l, m are (B, NH, T)
const int lm_offset = (b * NH * T) + (nh * T); // offset for l and m

for (int j = 0; j < Tc; j++) {
for (int to = 0; to < T; to += Bc) {
// Load K, V to SLM
// Each work-item loads one row of Kc and Vc.
for (int d = 0; d < D; d++) {
Kc[(rc * D) + d] = K[qkv_offset + ((Bc * j + rc) * D) + d];
Vc[(rc * D) + d] = V[qkv_offset + ((Bc * j + rc) * D) + d];
Kc[(rc * D) + d] = K[qkv_offset + ((to + rc) * D) + d];
Vc[(rc * D) + d] = V[qkv_offset + ((to + rc) * D) + d];
}
barrier(CLK_LOCAL_MEM_FENCE);

for (int i = 0; i < Tr; i++) {
for (int ti = 0; ti < T; ti += Br) {
// Load Q to SLM
// Each work-item loads one row of Qc
for (int d = 0; d < D; d++) {
Qc[(rc * D) + d] = Q[qkv_offset + ((Bc * i + rc) * D) + d];
Qc[(rc * D) + d] = Q[qkv_offset + ((ti + rc) * D) + d];
}

// Compute SP = QK^T, mi_local = rowmax(SP)
Expand All @@ -164,14 +163,14 @@ kernel void flash_attention_minimal(
float mi_local = -INFINITY;
for (int y = 0; y < Bc; y++) {
float xi = 0;
if (CAUSAL && i * Br + rc < j * Bc + y) {
if (CAUSAL && ti + rc < to + y) {
xi = -INFINITY;
}
else {
for (int d = 0; d < D; d++) {
xi += Qc[(rc * D) + d] * Kc[(y * D) + d];
}
xi *= softmax_scale;
xi *= adjusted_scale;
}
SP[(Bc * rc) + y] = xi;
mi_local = fmax(xi, mi_local);
Expand All @@ -181,84 +180,91 @@ kernel void flash_attention_minimal(
// implement softmax with causal masking
float vm = 0;
for (int y = 0; y < Bc; y++) {
SP[(Bc * rc) + y] = native_exp(SP[(Bc * rc) + y] - mi_local);
SP[(Bc * rc) + y] = native_exp2(SP[(Bc * rc) + y] - mi_local);
vm += SP[(Bc * rc) + y];
}

// Compute new m and l
float mim1 = m[lm_offset + (Br * i) + rc];
float dim1 = l[lm_offset + (Br * i) + rc];
float mim1 = m[lm_offset + ti + rc];
float dim1 = l[lm_offset + ti + rc];

float mi = fmax(mim1, mi_local);
float di = dim1 * native_exp(mim1 - mi) + vm * native_exp(mi_local - mi);
float di = dim1 * native_exp2(mim1 - mi) + vm * native_exp2(mi_local - mi);

float om = dim1 * native_exp(mim1 - mi) / di;
vm = native_exp(mi_local - mi) / di;
float om = dim1 * native_exp2(mim1 - mi) / di;
vm = native_exp2(mi_local - mi) / di;

// Write O, l, m to HBM
for (int d = 0; d < D; d++) {
float pv = 0; // Pij * Vc
for (int y = 0; y < Bc; y++) {
//if (j * Bc + y >= T) {
//if (to * Bc + y >= T) {
// break;
//}
pv += SP[(Bc * rc) + y] * Vc[(y * D) + d];
}
// O is (B, NH, T, D)
O[qkv_offset + ((Bc * i + rc) * D) + d] =
om * O[qkv_offset + ((Bc * i + rc) * D) + d] +
vm * pv;
O[qkv_offset + ((ti + rc) * D) + d] =
om * O[qkv_offset + ((ti + rc) * D) + d] +
vm * pv;
}

m[lm_offset + (Br * i) + rc] = mi;
l[lm_offset + (Br * i) + rc] = di;
m[lm_offset + ti + rc] = mi;
l[lm_offset + ti + rc] = di;
}
barrier(CLK_LOCAL_MEM_FENCE); // otherwise, thread can use the wrong Kc, Vc in inner loop
}
}

// This is a very basic flash attention kernel.
// For this kernel, each work-item computes one row of D elements of the output.
// There is no caching of the Q or O data.
// There is also no sharing of the K or V data.
kernel void flash_attention(
global const float* Q, global const float* K, global const float* V,
global float* O,
const float scale)
{
const int D = C / NH; // head size
// Note: all data is arranged: B x NH x T x D

// scale the scale, so we can use exp2 instead of exp
const float adjusted_scale = scale * M_LOG2E_F;

int b = get_global_id(0);
int to = get_global_id(0);
int nh = get_global_id(1);
int to = get_global_id(2);
int b = get_global_id(2);

float* o = O + b * NH * T * D + nh * T * D + to * D;
global float* o = O + b * NH * T * D + nh * T * D + to * D;

const float* q = Q + b * NH * T * D + nh * T * D + to * D;
global const float* q = Q + b * NH * T * D + nh * T * D + to * D;
float mi = -INFINITY;
float di = 0.0f;

for (int ti = 0; ti < T; ti++) {
const float* k = K + b * NH * T * D + nh * T * D + ti * D;
const float* v = V + b * NH * T * D + nh * T * D + ti * D;
global const float* k = K + b * NH * T * D + nh * T * D + ti * D;
global const float* v = V + b * NH * T * D + nh * T * D + ti * D;

// Compute xi = QK^T
float xi = 0.0;
float xi = 0.0f;
if (CAUSAL && to < ti) {
xi = -INFINITY;
}
else {
for (int d = 0; d < D; d++) {
xi += q[d] * k[d];
}
xi *= scale;
xi *= adjusted_scale;
}

// Update the running maximum
float mim1 = mi;
mi = fmax(mim1, xi);

// softmax(xi)
float smxi = native_exp(xi - mi);
float smxi = native_exp2(xi - mi);

// Update di
float alpha = native_exp(mim1 - mi);
float alpha = native_exp2(mim1 - mi);
di = di * alpha + smxi;

// Update the un-scaled output from softmax(xi) and V
Expand All @@ -273,58 +279,79 @@ kernel void flash_attention(
}
}

kernel void flash_attention_colblock(
// This is a slightly more complicated flash attention kernel.
// For this kernel, each work-item still computes one row of D elements of the output.
// There is caching for the Q, O, K, and V data.
__attribute__((reqd_work_group_size(32, 1, 1)))
kernel void flash_attention_blocked(
global const float* Q, global const float* K, global const float* V,
global float* O,
const float scale)
{
const int D = C / NH; // head size
// scale the scale, so we can use exp2 instead of exp
const float adjusted_scale = scale * M_LOG2E_F;

int b = get_global_id(0);
int nh = get_global_id(1);
int to = get_global_id(2);
int to = get_global_id(0);
int nh = get_group_id(1);
int b = get_group_id(2);

float* o = O + b * NH * T * D + nh * T * D + to * D;
global float* o = O + b * NH * T * D + nh * T * D + to * D;

const float* q = Q + b * NH * T * D + nh * T * D + to * D;
global const float* q = Q + b * NH * T * D + nh * T * D + to * D;
float mi = -INFINITY;
float di = 0.0f;

float oc[D];
float qc[D];
for (int d = 0; d < D; d++) {
qc[d] = q[d];
}

local float kc[D];
local float vc[D];

for (int ti = 0; ti < T; ti++) {
const float* k = K + b * NH * T * D + nh * T * D + ti * D;
const float* v = V + b * NH * T * D + nh * T * D + ti * D;
global const float* k = K + b * NH * T * D + nh * T * D + ti * D;
global const float* v = V + b * NH * T * D + nh * T * D + ti * D;

barrier(CLK_LOCAL_MEM_FENCE);
for (int d = get_local_id(0); d < D; d += get_local_size(0)) {
kc[d] = k[d];
vc[d] = v[d];
}
barrier(CLK_LOCAL_MEM_FENCE);

// Compute xi = QK^T
float xi = 0.0;
float xi = 0.0f;
if (CAUSAL && to < ti) {
xi = -INFINITY;
}
else {
for (int d = 0; d < D; d++) {
xi += q[d] * k[d];
xi += qc[d] * kc[d];
}
xi *= scale;
xi *= adjusted_scale;
}

// Update the running maximum
float mim1 = mi;
mi = fmax(mim1, xi);

// softmax(xi)
float smxi = native_exp(xi - mi);
float smxi = native_exp2(xi - mi);

// Update di
float alpha = native_exp(mim1 - mi);
float alpha = native_exp2(mim1 - mi);
di = di * alpha + smxi;

// Update the un-scaled output from softmax(xi) and V
for (int d = 0; d < D; d++) {
o[d] = o[d] * alpha + smxi * v[d];
oc[d] = oc[d] * alpha + smxi * vc[d];
}
}

// Epilog scaling (flash attention 2)
for (int d = 0; d < D; d++) {
o[d] = o[d] * native_recip(di);
o[d] = oc[d] * native_recip(di);
}
}
Loading

0 comments on commit 332e630

Please sign in to comment.