Fix LLMM1 kernel #28

Merged
merged 5 commits on Jun 14, 2024
130 changes: 40 additions & 90 deletions csrc/custom/custom_kernels.cu
@@ -26,55 +26,48 @@ __device__ __forceinline__ float4 load_ntmprl(const float4* addr) {
// TBlock fetches entire rows of A, and entire col of B (K dimension); assume
// N=1 for the time being; grid is M/A_NUM_ROWS blocks
template <int NUM_A_ROWS_PER_BLOCK>
__global__ void LLGemm1_kernel(float4* af4, __half2* bf4, __half2* c) {
__global__ void LLGemm1_kernel(float4* af4, __half2* bf4, __half2* c,
const int K) {
__shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE];
const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK * blockDim.x;
// int row_addr_1 = row_addr + CUDA_NUM_THREADS;
// int row_addr_2 = row_addr_1 + CUDA_NUM_THREADS;
// int row_addr_3 = row_addr_2 + CUDA_NUM_THREADS;
const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK * K / 8;
const int threadid = threadIdx.x;
const int warp = threadIdx.x / WARP_SIZE;
const int lane = threadIdx.x % WARP_SIZE;
const int num_warps = blockDim.x / WARP_SIZE;
const int qwarpid = threadid / 16;
const int qthreadid = threadid % 16;
float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK];
// float4 colB_elem4;
__half2 colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w;
float4 sum4; //[NUM_A_ROWS_PER_BLOCK];
float acc[NUM_A_ROWS_PER_BLOCK]; //= 0.0;
float4 sum4; //[NUM_A_ROWS_PER_BLOCK];
float acc[NUM_A_ROWS_PER_BLOCK] = {0.0};
__half2 acch2;
__half2 oval;

// rowA_elem4 = af4[row_addr + threadid];
//__syncthreads();
// rowA_elem4_1 = af4[row_addr_1 + threadid];
// rowA_elem4_2 = af4[row_addr_2 + threadid];
// rowA_elem4_3 = af4[row_addr_3 + threadid];
// As we later use warp shuffle operations, we may have more threads in the
// block than the actual available data, hence the if guard here.
if (threadid * 8 < K) {
#pragma unroll
for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) {
rowA_elem4[i] = load_ntmprl(&af4[row_addr + i * blockDim.x + threadid]);
// rowA_elem4[i] = af4[row_addr + i*blockDim.x + threadid];
//__syncthreads();
for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) {
// rowA_elem4[i] holds 8 * half numbers seen as a single float4.
rowA_elem4[i] = load_ntmprl(&af4[row_addr + threadid + K / 8 * i]);
}
}

colB_elem4x = bf4[threadid * 4 + 0];
colB_elem4y = bf4[threadid * 4 + 1];
colB_elem4z = bf4[threadid * 4 + 2];
colB_elem4w = bf4[threadid * 4 + 3];

// __syncthreads();
__half2 Af2;
__half2 Bf2;
float2 S;
// auto Bh2ptr = reinterpret_cast<__half2 *>(&colB_elem4);
// auto Bf2x = *Bh2ptr;
// auto Bf2y = *(Bh2ptr+1);
// auto Bf2z = *(Bh2ptr+2);
// auto Bf2w = *(Bh2ptr+3);

auto Ah2ptr = reinterpret_cast<__half2*>(&rowA_elem4);
__half2* ah2lptr;

#pragma unroll
for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) {
// Multiply-add on 8 halves.
ah2lptr = Ah2ptr + i * 4;
Af2 = *(ah2lptr);
acch2 = __hmul2(Af2, colB_elem4x);
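Aside: the new addressing treats each row of A as K/8 float4 values (8 halves per float4); row_addr is the float4 offset of the block's first row, and thread t of local row i reads af4[row_addr + t + (K/8) * i]. A small host-side sketch of that arithmetic, with the helper name and sizes invented here for illustration (not part of the patch):

// Sketch: float4 index read by thread `tid` for local row `i` of block `blk`,
// mirroring row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK * K / 8 and
// af4[row_addr + threadid + K / 8 * i] in the kernel above.
static int a_float4_index(int blk, int i, int tid, int rows_per_block, int K) {
  const int row_addr = blk * rows_per_block * (K / 8);
  return row_addr + tid + (K / 8) * i;
}
// e.g. K = 4096, rows_per_block = 4: consecutive local rows sit exactly
// K / 8 = 512 float4s apart, and block 1 starts 4 * 512 float4s after block 0.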
@@ -85,9 +78,14 @@ __global__ void LLGemm1_kernel(float4* af4, __half2* bf4, __half2* c) {
Af2 = *(ah2lptr + 3);
acch2 = __hfma2(Af2, colB_elem4w, acch2);
S = __half22float2(acch2);
acc[i] = S.x + S.y;

// See comment above concerning the if guard.
if (threadid * 8 < K) {
acc[i] = S.x + S.y; // accumulation on float
}
}

// all reduce across warp.
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
#pragma unroll
@@ -97,104 +95,56 @@
}

// Warp leaders store the data to shared memory.
// if (lane == 0) {
// #pragma unroll
// for (int i=0; i<NUM_A_ROWS_PER_BLOCK; i++) {
// red_smem[i][warp] = acc[i];
// }
// }

if (lane < NUM_A_ROWS_PER_BLOCK) {
red_smem[lane][warp] = acc[lane];
}

// Make sure the data is in shared memory.
__syncthreads();

if (qwarpid < NUM_A_ROWS_PER_BLOCK) {
// if (threadid<64) {
// #pragma unroll
// for (int i=0; i<NUM_A_ROWS_PER_BLOCK/2; i++) {
// acc[i+2*qwarpid] = 0.0;
// }
////acc[qwarpid] = 0.0;

////if (qthreadid<num_warps) {
// #pragma unroll
// for (int i=0; i<NUM_A_ROWS_PER_BLOCK/2; i++) {
// acc[i+2*qwarpid] = red_smem[i+2*qwarpid][qthreadid];
// }
////acc[qwarpid] = red_smem[qwarpid][qthreadid];

////}
acc[qwarpid] = qthreadid < num_warps ? red_smem[qwarpid][qthreadid] : 0.f;
// if (threadid<32) {
#pragma unroll
for (int mask = 16 / 2; mask >= 1; mask /= 2) {
// #pragma unroll
// for (int i=0; i<NUM_A_ROWS_PER_BLOCK/2; i++) {
// acc[i+2*qwarpid] += __shfl_xor(acc[i+2*qwarpid], mask);
// }
acc[qwarpid] += __shfl_xor(acc[qwarpid], mask);
}
float oval2 = __shfl_xor(acc[qwarpid], 16);
// acc[1] = __shfl_xor(acc[1],16);
// acc[3] = __shfl_xor(acc[3],16);
//}
// __syncthreads();
// if (threadid < NUM_A_ROWS_PER_BLOCK/2) {

if (threadid % WARP_SIZE == 0 or threadid % WARP_SIZE == 32) {
// oval =
// __float22half2_rn(make_float2(acc[2*threadid],acc[2*threadid+1])); oval
// = __float22half2_rn(make_float2(acc[2*qwarpid],acc[2*qwarpid+1])); oval
// = __float22half2_rn(make_float2(acc[qwarpid],acc[qwarpid+1]));
oval = __float22half2_rn(make_float2(acc[qwarpid], oval2));
c[blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 + qwarpid / 2] = oval;
}
} // threadid<WARP_SIZE

// if (threadid < NUM_A_ROWS_PER_BLOCK/2) {
// acc[2*threadid] = 0.0;
// acc[2*threadid+1] = 0.0;
//
// if (num_warps>8) {
// #pragma unroll
// for (int j=0; j<8; j++) {
// acc[2*threadid] += red_smem[2*threadid][j];
// acc[2*threadid+1] += red_smem[2*threadid+1][j];
// }
// }
// #pragma unroll
// for (int j=0; j<num_warps-8; j++) {
// acc[2*threadid] += red_smem[2*threadid][j+8];
// acc[2*threadid+1] += red_smem[2*threadid+1][j+8];
// }

// oval =
// __float22half2_rn(make_float2(acc[2*threadid],acc[2*threadid+1]));
// c[blockIdx.x*NUM_A_ROWS_PER_BLOCK/2+threadid] = oval;
//}
}
}
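Aside: the reduction loops above follow the usual butterfly pattern over __shfl_xor, where each step halves the distance between exchanged lanes until every lane holds the full sum. A standalone sketch of that pattern (illustrative only; assumes WARP_SIZE is the 64-wide wavefront these AMD kernels target):

// Sketch of the warp-level butterfly reduction used in the kernel above.
__device__ float warp_allreduce_sum(float v) {
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    v += __shfl_xor(v, mask);  // exchange with lane (lane ^ mask) and add
  }
  return v;  // every lane now holds the sum over the warp
}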

// define the kernel calling code:
// template <typename T>
void LLGemm1(void* in_a, void* in_b, void* out_c, const int M, const int K,
cudaStream_t stream, const int rows_per_block = 4) {
float4* af4 = reinterpret_cast<float4*>(in_a);
auto* bf4 = reinterpret_cast<__half2*>(in_b);
auto* c = reinterpret_cast<__half2*>(out_c);
// constexpr int A_ROWS_PER_BLOCK = 8;
const int NUM_THREADS = K * 2 / 16;

// NUM_THREADS needs to be a multiple of WARP_SIZE, as we are using warp
// shuffle operations.
const int NUM_THREADS =
K * 2 / 16 % WARP_SIZE == 0
? K * 2 / 16
: K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE);

int NUM_BLOCKS = M / rows_per_block;

if (rows_per_block == 2) {
LLGemm1_kernel<2><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c);
LLGemm1_kernel<2><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, K);
} else if (rows_per_block == 4) {
LLGemm1_kernel<4><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c);
LLGemm1_kernel<4><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, K);
} else if (rows_per_block == 8) {
LLGemm1_kernel<8><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c);
LLGemm1_kernel<8><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, K);
} else if (rows_per_block == 16) {
LLGemm1_kernel<16><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c);
LLGemm1_kernel<16><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, K);
} else {
NUM_BLOCKS = M / 4;
LLGemm1_kernel<4><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c);
LLGemm1_kernel<4><<<NUM_BLOCKS, NUM_THREADS, 0, stream>>>(af4, bf4, c, K);
}

cudaError_t err = cudaGetLastError();
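Aside: K * 2 / 16 is K / 8, i.e. one thread per float4 of a row, and the new expression rounds that up to the next multiple of WARP_SIZE so the shuffle-based reductions always operate on full warps. The same rounding in isolation (the helper name and the WARP_SIZE = 64 example are illustrative, not from the patch):

// Round x up to the next multiple of `mult` (same arithmetic as NUM_THREADS).
constexpr int round_up(int x, int mult) {
  return x % mult == 0 ? x : x + (mult - x % mult);
}
// With WARP_SIZE = 64: K = 4096 -> 512 threads (already a multiple);
//                      K = 4104 -> 513 -> rounded up to 576 threads.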