
Commit

Merge pull request #28 from fxmarty/fix-llmm1-kernel
Fix LLMM1 kernel
shajrawi authored Jun 14, 2024
2 parents 12079c3 + 54fcf12 commit d3da246
Showing 1 changed file with 40 additions and 90 deletions.
130 changes: 40 additions & 90 deletions csrc/custom/custom_kernels.cu
@@ -26,55 +26,48 @@ __device__ __forceinline__ float4 load_ntmprl(const float4* addr) {
 // TBlock fetches entire rows of A, and entire col of B (K dimension); assume
 // N=1 for time being grid is M/A_NUM_ROWS blocks
 template <int NUM_A_ROWS_PER_BLOCK>
-__global__ void LLGemm1_kernel(float4* af4, __half2* bf4, __half2* c) {
+__global__ void LLGemm1_kernel(float4* af4, __half2* bf4, __half2* c,
+                               const int K) {
   __shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE];
-  const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK * blockDim.x;
-  // int row_addr_1 = row_addr + CUDA_NUM_THREADS;
-  // int row_addr_2 = row_addr_1 + CUDA_NUM_THREADS;
-  // int row_addr_3 = row_addr_2 + CUDA_NUM_THREADS;
+  const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK * K / 8;
   const int threadid = threadIdx.x;
   const int warp = threadIdx.x / WARP_SIZE;
   const int lane = threadIdx.x % WARP_SIZE;
   const int num_warps = blockDim.x / WARP_SIZE;
   const int qwarpid = threadid / 16;
   const int qthreadid = threadid % 16;
   float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK];
-  // float4 colB_elem4;
   __half2 colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w;
-  float4 sum4; //[NUM_A_ROWS_PER_BLOCK];
-  float acc[NUM_A_ROWS_PER_BLOCK];  //= 0.0;
+  float4 sum4;  //[NUM_A_ROWS_PER_BLOCK];
+  float acc[NUM_A_ROWS_PER_BLOCK] = {0.0};
   __half2 acch2;
   __half2 oval;
 
-  // rowA_elem4 = af4[row_addr + threadid];
-  //__syncthreads();
-  // rowA_elem4_1 = af4[row_addr_1 + threadid];
-  // rowA_elem4_2 = af4[row_addr_2 + threadid];
-  // rowA_elem4_3 = af4[row_addr_3 + threadid];
+  // As we later use warp shuffle operations, we may have more threads in the
+  // block than the actual available data, hence the if guard here.
+  if (threadid * 8 < K) {
 #pragma unroll
-  for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) {
-    rowA_elem4[i] = load_ntmprl(&af4[row_addr + i * blockDim.x + threadid]);
-    // rowA_elem4[i] = af4[row_addr + i*blockDim.x + threadid];
-    //__syncthreads();
+    for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) {
+      // rowA_elem4[i] holds 8 * half numbers seen as a single float4.
+      rowA_elem4[i] = load_ntmprl(&af4[row_addr + threadid + K / 8 * i]);
+    }
   }
 
   colB_elem4x = bf4[threadid * 4 + 0];
   colB_elem4y = bf4[threadid * 4 + 1];
   colB_elem4z = bf4[threadid * 4 + 2];
   colB_elem4w = bf4[threadid * 4 + 3];
 
-  // __syncthreads();
   __half2 Af2;
   __half2 Bf2;
   float2 S;
-  // auto Bh2ptr = reinterpret_cast<__half2 *>(&colB_elem4);
-  // auto Bf2x = *Bh2ptr;
-  // auto Bf2y = *(Bh2ptr+1);
-  // auto Bf2z = *(Bh2ptr+2);
-  // auto Bf2w = *(Bh2ptr+3);
 
   auto Ah2ptr = reinterpret_cast<__half2*>(&rowA_elem4);
   __half2* ah2lptr;
 
 #pragma unroll
   for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) {
+    // Multiply-add on 8 half.
     ah2lptr = Ah2ptr + i * 4;
     Af2 = *(ah2lptr);
     acch2 = __hmul2(Af2, colB_elem4x);
@@ -85,9 +78,14 @@ __global__ void LLGemm1_kernel(float4* af4, __half2* bf4, __half2* c) {
     Af2 = *(ah2lptr + 3);
     acch2 = __hfma2(Af2, colB_elem4w, acch2);
     S = __half22float2(acch2);
-    acc[i] = S.x + S.y;
+
+    // See comment above concerning the if guard.
+    if (threadid * 8 < K) {
+      acc[i] = S.x + S.y;  // accumulation on float
+    }
   }
 
+  // all reduce across warp.
 #pragma unroll
   for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
 #pragma unroll
@@ -97,104 +95,56 @@ __global__ void LLGemm1_kernel(float4* af4, __half2* bf4, __half2* c) {
   }
 
   // Warp leaders store the data to shared memory.
-  // if (lane == 0) {
-  // #pragma unroll
-  //    for (int i=0; i<NUM_A_ROWS_PER_BLOCK; i++) {
-  //      red_smem[i][warp] = acc[i];
-  //    }
-  //  }
 
   if (lane < NUM_A_ROWS_PER_BLOCK) {
     red_smem[lane][warp] = acc[lane];
   }
 
   // Make sure the data is in shared memory.
   __syncthreads();
 
   if (qwarpid < NUM_A_ROWS_PER_BLOCK) {
-    // if (threadid<64) {
-    // #pragma unroll
-    //  for (int i=0; i<NUM_A_ROWS_PER_BLOCK/2; i++) {
-    //    acc[i+2*qwarpid] = 0.0;
-    //  }
-    ////acc[qwarpid] = 0.0;
-
-    ////if (qthreadid<num_warps) {
-    // #pragma unroll
-    //  for (int i=0; i<NUM_A_ROWS_PER_BLOCK/2; i++) {
-    //  acc[i+2*qwarpid] = red_smem[i+2*qwarpid][qthreadid];
-    //  }
-    ////acc[qwarpid] = red_smem[qwarpid][qthreadid];
-
-    ////}
     acc[qwarpid] = qthreadid < num_warps ? red_smem[qwarpid][qthreadid] : 0.f;
-    // if (threadid<32) {
 #pragma unroll
     for (int mask = 16 / 2; mask >= 1; mask /= 2) {
-      // #pragma unroll
-      //  for (int i=0; i<NUM_A_ROWS_PER_BLOCK/2; i++) {
-      //    acc[i+2*qwarpid] += __shfl_xor(acc[i+2*qwarpid], mask);
-      //  }
       acc[qwarpid] += __shfl_xor(acc[qwarpid], mask);
     }
     float oval2 = __shfl_xor(acc[qwarpid], 16);
-    // acc[1] = __shfl_xor(acc[1],16);
-    // acc[3] = __shfl_xor(acc[3],16);
-    //}
-    // __syncthreads();
-    // if (threadid < NUM_A_ROWS_PER_BLOCK/2) {
 
     if (threadid % WARP_SIZE == 0 or threadid % WARP_SIZE == 32) {
-      // oval =
-      // __float22half2_rn(make_float2(acc[2*threadid],acc[2*threadid+1])); oval
-      // = __float22half2_rn(make_float2(acc[2*qwarpid],acc[2*qwarpid+1])); oval
-      // = __float22half2_rn(make_float2(acc[qwarpid],acc[qwarpid+1]));
       oval = __float22half2_rn(make_float2(acc[qwarpid], oval2));
       c[blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 + qwarpid / 2] = oval;
     }
   }  // threadid<WARP_SIZE
 
-  // if (threadid < NUM_A_ROWS_PER_BLOCK/2) {
-  //   acc[2*threadid] = 0.0;
-  //   acc[2*threadid+1] = 0.0;
-  //
-  //   if (num_warps>8) {
-  // #pragma unroll
-  //     for (int j=0; j<8; j++) {
-  //       acc[2*threadid] += red_smem[2*threadid][j];
-  //       acc[2*threadid+1] += red_smem[2*threadid+1][j];
-  //     }
-  //   }
-  // #pragma unroll
-  //   for (int j=0; j<num_warps-8; j++) {
-  //     acc[2*threadid] += red_smem[2*threadid][j+8];
-  //     acc[2*threadid+1] += red_smem[2*threadid+1][j+8];
-  //   }
-
-  //   oval =
-  //   __float22half2_rn(make_float2(acc[2*threadid],acc[2*threadid+1]));
-  //   c[blockIdx.x*NUM_A_ROWS_PER_BLOCK/2+threadid] = oval;
-  //}
-}
 }
 
 // define the kernel calling code:
-// template <typename T>
 void LLGemm1(void* in_a, void* in_b, void* out_c, const int M, const int K,
              cudaStream_t stream, const int rows_per_block = 4) {
   float4* af4 = reinterpret_cast<float4*>(in_a);
   auto* bf4 = reinterpret_cast<__half2*>(in_b);
   auto* c = reinterpret_cast<__half2*>(out_c);
-  // constexpr int A_ROWS_PER_BLOCK = 8;
-  const int NUM_THREADS = K * 2 / 16;
+
+  // NUM_TREADS need to be a multiple of WARP_SIZE, as we are using warp shuffle
+  // operations.
+  const int NUM_THREADS =
+      K * 2 / 16 % WARP_SIZE == 0
+          ? K * 2 / 16
+          : K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE);
 
   int NUM_BLOCKS = M / rows_per_block;
 
   if (rows_per_block == 2) {
-    LLGemm1_kernel<2><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c);
+    LLGemm1_kernel<2><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, K);
   } else if (rows_per_block == 4) {
-    LLGemm1_kernel<4><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c);
+    LLGemm1_kernel<4><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, K);
   } else if (rows_per_block == 8) {
-    LLGemm1_kernel<8><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c);
+    LLGemm1_kernel<8><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, K);
   } else if (rows_per_block == 16) {
-    LLGemm1_kernel<16><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c);
+    LLGemm1_kernel<16><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, K);
   } else {
     NUM_BLOCKS = M / 4;
-    LLGemm1_kernel<4><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c);
+    LLGemm1_kernel<4><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, K);
   }
 
   cudaError_t err = cudaGetLastError();
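Note (not part of the commit): the fixed launcher sizes each block at K * 2 / 16 = K / 8 threads, one float4 (8 half values) of a row of A per thread, rounded up to a multiple of WARP_SIZE because the reduction relies on warp shuffles. The new `if (threadid * 8 < K)` guard in the kernel masks out the padding threads that this rounding creates. A minimal host-side sketch of that arithmetic, assuming WARP_SIZE is 64 (the AMD wavefront size) and using a made-up helper name:

```cuda
// Sketch only: mirrors the NUM_THREADS arithmetic added by this change.
// kWarpSize = 64 is an assumption (AMD wavefront); the kernel uses whatever
// WARP_SIZE macro custom_kernels.cu defines. K is assumed to be a multiple of
// 8, as the float4 loads of half data require.
#include <cstdio>

constexpr int kWarpSize = 64;  // assumed

// K / 8 threads per block (one float4 = 8 halves per thread), rounded up to a
// multiple of the warp size so that warp shuffles cover every lane.
int llgemm1_num_threads(int K) {
  const int raw = K * 2 / 16;  // == K / 8
  return raw % kWarpSize == 0 ? raw : raw + (kWarpSize - raw % kWarpSize);
}

int main() {
  const int ks[] = {1024, 4096, 5504};
  for (int K : ks) {
    const int threads = llgemm1_num_threads(K);
    // Threads with threadid * 8 >= K carry no data; the kernel's if guard
    // keeps them from reading past the end of the row.
    printf("K=%5d -> %d threads per block (%d idle)\n", K, threads,
           threads - K / 8);
  }
  return 0;
}
```

When K is already a multiple of 8 * WARP_SIZE no threads are idle; otherwise only the last partial warp is padded and skipped by the guard.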

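A usage sketch of the public entry point, again not from the commit: the declaration below restates the `LLGemm1` signature shown in the diff, while the allocation and copy boilerplate around it is assumed. Following the reinterpret_casts in the launcher, A is an M x K fp16 matrix read as float4, B is a length-K fp16 vector read as __half2, and C receives M fp16 outputs.

```cuda
// Usage sketch: y = A * x with A of shape [M, K] in fp16 and x of length K,
// dispatched through the fixed LLGemm1 launcher. Setup code is illustrative.
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <vector>

// Signature as shown in the diff; defined in csrc/custom/custom_kernels.cu.
void LLGemm1(void* in_a, void* in_b, void* out_c, const int M, const int K,
             cudaStream_t stream, const int rows_per_block = 4);

int main() {
  const int M = 4096, K = 4096;  // M must be divisible by rows_per_block
  std::vector<__half> h_a(size_t(M) * K, __float2half(0.01f));
  std::vector<__half> h_b(K, __float2half(1.0f));

  __half *d_a, *d_b, *d_c;
  cudaMalloc(&d_a, h_a.size() * sizeof(__half));
  cudaMalloc(&d_b, h_b.size() * sizeof(__half));
  cudaMalloc(&d_c, M * sizeof(__half));
  cudaMemcpy(d_a, h_a.data(), h_a.size() * sizeof(__half),
             cudaMemcpyHostToDevice);
  cudaMemcpy(d_b, h_b.data(), h_b.size() * sizeof(__half),
             cudaMemcpyHostToDevice);

  cudaStream_t stream;
  cudaStreamCreate(&stream);
  // rows_per_block = 4 takes the launcher's default dispatch path.
  LLGemm1(d_a, d_b, d_c, M, K, stream, /*rows_per_block=*/4);
  cudaStreamSynchronize(stream);

  std::vector<__half> h_c(M);
  cudaMemcpy(h_c.data(), d_c, M * sizeof(__half), cudaMemcpyDeviceToHost);

  cudaFree(d_a);
  cudaFree(d_b);
  cudaFree(d_c);
  cudaStreamDestroy(stream);
  return 0;
}
```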