From 9e3785b2557db892eda7a72051ff7f405cffb35a Mon Sep 17 00:00:00 2001 From: Hashem Hashemi Date: Mon, 17 Jun 2024 15:31:02 -0500 Subject: [PATCH 01/12] adds wvSpltK optimization for skinny gemm. --- csrc/custom/custom.cu | 15 + csrc/custom/custom_kernels.cu | 1436 ++++++++++++++++++++++ vllm/model_executor/layers/tuned_gemm.py | 14 +- 3 files changed, 1461 insertions(+), 4 deletions(-) diff --git a/csrc/custom/custom.cu b/csrc/custom/custom.cu index 3da25ece3e87c..5c8beed37b304 100644 --- a/csrc/custom/custom.cu +++ b/csrc/custom/custom.cu @@ -39,6 +39,20 @@ void LLMM1(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, at::cuda::getCurrentCUDAStream(), rows_per_block); } +void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M, const int K, const int N, + cudaStream_t stream, const int CuCount); + +void wvSpltK(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, + const int N_in, const int CuCount) { + int M = in_a.size(0); + int K = in_a.size(1); + int N = N_in; + wvSpltK_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, N, + at::cuda::getCurrentCUDAStream(), CuCount); +} + + + void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, cudaStream_t stream, const int solidx); @@ -90,5 +104,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("LLZZ", &LLZZ); m.def("paged_attention_custom", &paged_attention_custom, "PagedAttention LL4Mi Custom."); + m.def("wvSpltK", &wvSpltK); // m.def("MMCustomGPU", &MMCustomGPU); } diff --git a/csrc/custom/custom_kernels.cu b/csrc/custom/custom_kernels.cu index afecf82eb3d77..09c28d1a46eca 100644 --- a/csrc/custom/custom_kernels.cu +++ b/csrc/custom/custom_kernels.cu @@ -2,6 +2,9 @@ #include #include #include +#include +#include "hsa/hsa.h" +#include "hsa/hsa_ext_amd.h" constexpr int WARP_SIZE = 64; @@ -309,3 +312,1436 @@ void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, if (cudaSuccess != err) throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); } + +///////////////////////////////////////////// + +using half8 = __attribute__((__vector_size__(4 * sizeof(float)))) float; + +/*template +__device__ __forceinline__ T loadnt(T* addr) { + return __builtin_nontemporal_load(addr); + //return *((T*)addr); +}*/ + +#define THRDS 64 +#define YTILE 2 +#define WvPrGrp 16 +#define A_CHUNK 8 +#define UNRL 2 +#define M 1 +#define DTYPE half + +__global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + __int128_t b128; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! 
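+  // (for the last tile, the start column is shifted back to N - YTILE and
+  //  the columns overlapping the preceding tile are masked off below so
+  //  they are not committed twice)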
+ //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) + { + commitColumn[i] = 1; + } + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if(n < N && (n + YTILE) >= N) + { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) + { + commitColumn[i] = 0; + } + n = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * M, 32*1024); k += THRDS * WvPrGrp * A_CHUNK) + { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + //uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for bank-conflict-free readback + + if (k_in >= min(K * M, 32*1024)) break; + + ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of availble columns + //---------------------------------------------------- + while (n < N) + { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // splitted across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int m=0; m= 2) + bigType bigB1[UNRL]; +#endif +#if (YTILE >= 3) + bigType bigB2[UNRL]; +#endif +#if (YTILE >= 4) + bigType bigB3[UNRL]; +#endif +#if (YTILE >= 5) + bigType bigB4[UNRL]; +#endif +#if (YTILE >= 6) + bigType bigB5[UNRL]; +#endif +#if (YTILE >= 7) + bigType bigB6[UNRL]; +#endif +#if (YTILE >= 8) + bigType bigB7[UNRL]; +#endif +#if (YTILE >= 9) + bigType bigB8[UNRL]; +#endif +#if (YTILE >= 10) + bigType bigB9[UNRL]; +#endif +#if (YTILE >= 11) + bigType bigB10[UNRL]; +#endif + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! 
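+  //   (each of the YTILE columns handled by a wave is a contiguous row of
+  //    length K in B, addressed as B[(n + y) * K + k])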
+ // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) + { + // Fetch the weight matrix from memory! +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) + { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + //if (k_ >= K) break; + //bool skip = (k_ >= K); + //bool dummy = (k_ >= K); + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- +#if (YTILE >= 2) + //if (n+1>=N) continue; + bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); +#endif +#if (YTILE >= 3) + //if (n+2>=N) continue; + bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); +#endif +#if (YTILE >= 4) + //if (n+3>=N) continue; + bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); +#endif +#if (YTILE >= 5) + //if (n+4>=N) continue; + bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); +#endif +#if (YTILE >= 6) + //if (n+5>=N) continue; + bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); +#endif +#if (YTILE >= 7) + //if (n+6>=N) continue; + bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); +#endif +#if (YTILE >= 8) + //if (n+7>=N) continue; + bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); +#endif +/* +#if (YTILE >= 9) + if (n+8>=N) continue; bigB8[k2].h8 = (loadnt((half8*)(&B_[8 * K]))); +#endif +#if (YTILE >= 10) + if (n+9>=N) continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); +#endif +#if (YTILE >= 11) + if (n+10>=N) continue; bigB10[k2].h8 = (loadnt((half8*)(&B_[10 * K]))); +#endif +*/ + } + + // Fetch activation matrix from either just LDS or from both LDS / memory +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) + { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m=0; m < M; m++) + { + if (k_+K*m < 32*1024) + bigA[m][k2] = *((const bigType*)(&(s[k_+K*m]))); + else + bigA[m][k2] = *((const bigType*)(&(A[k_+K*m]))); + } + } + + // Do the matrix multiplication in interleaved manner +#pragma unroll + for (uint32_t m = 0; m < M; m++) + { +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) + { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Do the matrix multiplication of activation and weight matrix + // - Rememeber the accumulation is happening for K-split of 64! 
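+          // Each v_dot2c_f32_f16 below multiplies one packed pair of fp16
+          // values from A and B and accumulates into the fp32 sum; the
+          // per-lane partial sums are combined after the K loop.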
+#pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) + { + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][0]) : "0"(sum[m][0]), "v" (bigA[m][k2].f[b]), "v" (bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- +#if (YTILE >= 2) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][1]) : "0"(sum[m][1]), "v" (bigA[m][k2].f[b]), "v" (bigB1[k2].f[b])); +#endif +#if (YTILE >= 3) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][2]) : "0"(sum[m][2]), "v" (bigA[m][k2].f[b]), "v" (bigB2[k2].f[b])); +#endif +#if (YTILE >= 4) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][3]) : "0"(sum[m][3]), "v" (bigA[m][k2].f[b]), "v" (bigB3[k2].f[b])); +#endif +#if (YTILE >= 5) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][4]) : "0"(sum[m][4]), "v" (bigA[m][k2].f[b]), "v" (bigB4[k2].f[b])); +#endif +#if (YTILE >= 6) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][5]) : "0"(sum[m][5]), "v" (bigA[m][k2].f[b]), "v" (bigB5[k2].f[b])); +#endif +#if (YTILE >= 7) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][6]) : "0"(sum[m][6]), "v" (bigA[m][k2].f[b]), "v" (bigB6[k2].f[b])); +#endif +#if (YTILE >= 8) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][7]) : "0"(sum[m][7]), "v" (bigA[m][k2].f[b]), "v" (bigB7[k2].f[b])); +#endif +#if (YTILE >= 9) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][8]) : "0"(sum[m][8]), "v" (bigA[m][k2].f[b]), "v" (bigB8[k2].f[b])); +#endif +#if (YTILE >= 10) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][9]) : "0"(sum[m][9]), "v" (bigA[m][k2].f[b]), "v" (bigB9[k2].f[b])); +#endif +#if (YTILE >= 11) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][10]) : "0"(sum[m][10]), "v" (bigA[m][k2].f[b]), "v" (bigB10[k2].f[b])); +#endif + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) + { + for (int y = 0; y < YTILE; y++) + { + //for (int offset = 64 / 2; offset > 4 ; offset /= 2) { + // sum[y] += __shfl_down(sum[y], offset); + //} + sum[m][y] += __shfl_down(sum[m][y], 32); + sum[m][y] += __shfl_down(sum[m][y], 16); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shl:1 bound_ctrl:0" : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + + } + } + + if (threadIdx.x == 0) + { + for (int m = 0; m < M; m++) + { + for (int i = 0; i < YTILE; i++) + { + if (commitColumn[i]) + C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + + n += CuCount * WvPrGrp * YTILE; + + //if (threadIdx.x == 0) + //n = atomicAdd(((unsigned int*)(C)), YTILE); + //n = __shfl(n, 0, 64); + + // Check whether there will be fragmenation! + // This will happen only for the last wave! 
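+    // Re-apply the same edge handling after advancing n: shift the last
+    // tile back to N - YTILE and mask columns that would otherwise be
+    // committed twice.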
+ if(n < N && (n + YTILE) >= N) + { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) + { + commitColumn[i] = 0; + } + n = startColumn; + } + } + + +} + +#undef YTILE +#undef UNRL +#undef M + + +#define YTILE 2 +#define UNRL 2 +#define M 2 + +__global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + __int128_t b128; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) + { + commitColumn[i] = 1; + } + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if(n < N && (n + YTILE) >= N) + { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) + { + commitColumn[i] = 0; + } + n = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * M, 32*1024); k += THRDS * WvPrGrp * A_CHUNK) + { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + //uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for bank-conflict-free readback + + if (k_in >= min(K * M, 32*1024)) break; + + ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! 
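+  // - Each pass advances n by CuCount * WvPrGrp * YTILE, so a single
+  //   launch of CuCount work-groups covers all N columns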
+ // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of availble columns + //---------------------------------------------------- + while (n < N) + { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // splitted across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int m=0; m= 2) + bigType bigB1[UNRL]; +#endif +#if (YTILE >= 3) + bigType bigB2[UNRL]; +#endif +#if (YTILE >= 4) + bigType bigB3[UNRL]; +#endif +#if (YTILE >= 5) + bigType bigB4[UNRL]; +#endif +#if (YTILE >= 6) + bigType bigB5[UNRL]; +#endif +#if (YTILE >= 7) + bigType bigB6[UNRL]; +#endif +#if (YTILE >= 8) + bigType bigB7[UNRL]; +#endif +#if (YTILE >= 9) + bigType bigB8[UNRL]; +#endif +#if (YTILE >= 10) + bigType bigB9[UNRL]; +#endif +#if (YTILE >= 11) + bigType bigB10[UNRL]; +#endif + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) + { + // Fetch the weight matrix from memory! 
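+      // loadnt() issues non-temporal loads, hinting that the streamed B
+      // values are not reused by this kernel.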
+#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) + { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + //if (k_ >= K) break; + //bool skip = (k_ >= K); + //bool dummy = (k_ >= K); + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- +#if (YTILE >= 2) + //if (n+1>=N) continue; + bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); +#endif +#if (YTILE >= 3) + //if (n+2>=N) continue; + bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); +#endif +#if (YTILE >= 4) + //if (n+3>=N) continue; + bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); +#endif +#if (YTILE >= 5) + //if (n+4>=N) continue; + bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); +#endif +#if (YTILE >= 6) + //if (n+5>=N) continue; + bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); +#endif +#if (YTILE >= 7) + //if (n+6>=N) continue; + bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); +#endif +#if (YTILE >= 8) + //if (n+7>=N) continue; + bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); +#endif +/* +#if (YTILE >= 9) + if (n+8>=N) continue; bigB8[k2].h8 = (loadnt((half8*)(&B_[8 * K]))); +#endif +#if (YTILE >= 10) + if (n+9>=N) continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); +#endif +#if (YTILE >= 11) + if (n+10>=N) continue; bigB10[k2].h8 = (loadnt((half8*)(&B_[10 * K]))); +#endif +*/ + } + + // Fetch activation matrix from either just LDS or from both LDS / memory +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) + { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m=0; m < M; m++) + { + if (k_+K*m < 32*1024) + bigA[m][k2] = *((const bigType*)(&(s[k_+K*m]))); + else + bigA[m][k2] = *((const bigType*)(&(A[k_+K*m]))); + } + } + + // Do the matrix multiplication in interleaved manner +#pragma unroll + for (uint32_t m = 0; m < M; m++) + { +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) + { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Do the matrix multiplication of activation and weight matrix + // - Rememeber the accumulation is happening for K-split of 64! 
+#pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) + { + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][0]) : "0"(sum[m][0]), "v" (bigA[m][k2].f[b]), "v" (bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- +#if (YTILE >= 2) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][1]) : "0"(sum[m][1]), "v" (bigA[m][k2].f[b]), "v" (bigB1[k2].f[b])); +#endif +#if (YTILE >= 3) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][2]) : "0"(sum[m][2]), "v" (bigA[m][k2].f[b]), "v" (bigB2[k2].f[b])); +#endif +#if (YTILE >= 4) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][3]) : "0"(sum[m][3]), "v" (bigA[m][k2].f[b]), "v" (bigB3[k2].f[b])); +#endif +#if (YTILE >= 5) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][4]) : "0"(sum[m][4]), "v" (bigA[m][k2].f[b]), "v" (bigB4[k2].f[b])); +#endif +#if (YTILE >= 6) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][5]) : "0"(sum[m][5]), "v" (bigA[m][k2].f[b]), "v" (bigB5[k2].f[b])); +#endif +#if (YTILE >= 7) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][6]) : "0"(sum[m][6]), "v" (bigA[m][k2].f[b]), "v" (bigB6[k2].f[b])); +#endif +#if (YTILE >= 8) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][7]) : "0"(sum[m][7]), "v" (bigA[m][k2].f[b]), "v" (bigB7[k2].f[b])); +#endif +#if (YTILE >= 9) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][8]) : "0"(sum[m][8]), "v" (bigA[m][k2].f[b]), "v" (bigB8[k2].f[b])); +#endif +#if (YTILE >= 10) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][9]) : "0"(sum[m][9]), "v" (bigA[m][k2].f[b]), "v" (bigB9[k2].f[b])); +#endif +#if (YTILE >= 11) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][10]) : "0"(sum[m][10]), "v" (bigA[m][k2].f[b]), "v" (bigB10[k2].f[b])); +#endif + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) + { + for (int y = 0; y < YTILE; y++) + { + //for (int offset = 64 / 2; offset > 4 ; offset /= 2) { + // sum[y] += __shfl_down(sum[y], offset); + //} + sum[m][y] += __shfl_down(sum[m][y], 32); + sum[m][y] += __shfl_down(sum[m][y], 16); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shl:1 bound_ctrl:0" : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + + } + } + + if (threadIdx.x == 0) + { + for (int m = 0; m < M; m++) + { + for (int i = 0; i < YTILE; i++) + { + if (commitColumn[i]) + C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + + n += CuCount * WvPrGrp * YTILE; + + //if (threadIdx.x == 0) + //n = atomicAdd(((unsigned int*)(C)), YTILE); + //n = __shfl(n, 0, 64); + + // Check whether there will be fragmenation! + // This will happen only for the last wave! 
+ if(n < N && (n + YTILE) >= N) + { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) + { + commitColumn[i] = 0; + } + n = startColumn; + } + } + + +} + +#undef YTILE +#undef UNRL +#undef M + +#define YTILE 5 +#define UNRL 2 +#define M 3 + +__global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + __int128_t b128; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) + { + commitColumn[i] = 1; + } + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if(n < N && (n + YTILE) >= N) + { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) + { + commitColumn[i] = 0; + } + n = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * M, 32*1024); k += THRDS * WvPrGrp * A_CHUNK) + { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + //uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for bank-conflict-free readback + + if (k_in >= min(K * M, 32*1024)) break; + + ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! 
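+  // - Each wave here also carries M = 3 rows of the activation matrix,
+  //   producing a 3 x YTILE block of C per set of columns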
+ // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of availble columns + //---------------------------------------------------- + while (n < N) + { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // splitted across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int m=0; m= 2) + bigType bigB1[UNRL]; +#endif +#if (YTILE >= 3) + bigType bigB2[UNRL]; +#endif +#if (YTILE >= 4) + bigType bigB3[UNRL]; +#endif +#if (YTILE >= 5) + bigType bigB4[UNRL]; +#endif +#if (YTILE >= 6) + bigType bigB5[UNRL]; +#endif +#if (YTILE >= 7) + bigType bigB6[UNRL]; +#endif +#if (YTILE >= 8) + bigType bigB7[UNRL]; +#endif +#if (YTILE >= 9) + bigType bigB8[UNRL]; +#endif +#if (YTILE >= 10) + bigType bigB9[UNRL]; +#endif +#if (YTILE >= 11) + bigType bigB10[UNRL]; +#endif + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) + { + // Fetch the weight matrix from memory! 
+#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) + { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + //if (k_ >= K) break; + //bool skip = (k_ >= K); + //bool dummy = (k_ >= K); + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- +#if (YTILE >= 2) + //if (n+1>=N) continue; + bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); +#endif +#if (YTILE >= 3) + //if (n+2>=N) continue; + bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); +#endif +#if (YTILE >= 4) + //if (n+3>=N) continue; + bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); +#endif +#if (YTILE >= 5) + //if (n+4>=N) continue; + bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); +#endif +#if (YTILE >= 6) + //if (n+5>=N) continue; + bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); +#endif +#if (YTILE >= 7) + //if (n+6>=N) continue; + bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); +#endif +#if (YTILE >= 8) + //if (n+7>=N) continue; + bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); +#endif +/* +#if (YTILE >= 9) + if (n+8>=N) continue; bigB8[k2].h8 = (loadnt((half8*)(&B_[8 * K]))); +#endif +#if (YTILE >= 10) + if (n+9>=N) continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); +#endif +#if (YTILE >= 11) + if (n+10>=N) continue; bigB10[k2].h8 = (loadnt((half8*)(&B_[10 * K]))); +#endif +*/ + } + + // Fetch activation matrix from either just LDS or from both LDS / memory +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) + { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m=0; m < M; m++) + { + if (k_+K*m < 32*1024) + bigA[m][k2] = *((const bigType*)(&(s[k_+K*m]))); + else + bigA[m][k2] = *((const bigType*)(&(A[k_+K*m]))); + } + } + + // Do the matrix multiplication in interleaved manner +#pragma unroll + for (uint32_t m = 0; m < M; m++) + { +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) + { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Do the matrix multiplication of activation and weight matrix + // - Rememeber the accumulation is happening for K-split of 64! 
+#pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) + { + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][0]) : "0"(sum[m][0]), "v" (bigA[m][k2].f[b]), "v" (bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- +#if (YTILE >= 2) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][1]) : "0"(sum[m][1]), "v" (bigA[m][k2].f[b]), "v" (bigB1[k2].f[b])); +#endif +#if (YTILE >= 3) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][2]) : "0"(sum[m][2]), "v" (bigA[m][k2].f[b]), "v" (bigB2[k2].f[b])); +#endif +#if (YTILE >= 4) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][3]) : "0"(sum[m][3]), "v" (bigA[m][k2].f[b]), "v" (bigB3[k2].f[b])); +#endif +#if (YTILE >= 5) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][4]) : "0"(sum[m][4]), "v" (bigA[m][k2].f[b]), "v" (bigB4[k2].f[b])); +#endif +#if (YTILE >= 6) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][5]) : "0"(sum[m][5]), "v" (bigA[m][k2].f[b]), "v" (bigB5[k2].f[b])); +#endif +#if (YTILE >= 7) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][6]) : "0"(sum[m][6]), "v" (bigA[m][k2].f[b]), "v" (bigB6[k2].f[b])); +#endif +#if (YTILE >= 8) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][7]) : "0"(sum[m][7]), "v" (bigA[m][k2].f[b]), "v" (bigB7[k2].f[b])); +#endif +#if (YTILE >= 9) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][8]) : "0"(sum[m][8]), "v" (bigA[m][k2].f[b]), "v" (bigB8[k2].f[b])); +#endif +#if (YTILE >= 10) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][9]) : "0"(sum[m][9]), "v" (bigA[m][k2].f[b]), "v" (bigB9[k2].f[b])); +#endif +#if (YTILE >= 11) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][10]) : "0"(sum[m][10]), "v" (bigA[m][k2].f[b]), "v" (bigB10[k2].f[b])); +#endif + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) + { + for (int y = 0; y < YTILE; y++) + { + //for (int offset = 64 / 2; offset > 4 ; offset /= 2) { + // sum[y] += __shfl_down(sum[y], offset); + //} + sum[m][y] += __shfl_down(sum[m][y], 32); + sum[m][y] += __shfl_down(sum[m][y], 16); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shl:1 bound_ctrl:0" : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + + } + } + + if (threadIdx.x == 0) + { + for (int m = 0; m < M; m++) + { + for (int i = 0; i < YTILE; i++) + { + if (commitColumn[i]) + C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + + n += CuCount * WvPrGrp * YTILE; + + //if (threadIdx.x == 0) + //n = atomicAdd(((unsigned int*)(C)), YTILE); + //n = __shfl(n, 0, 64); + + // Check whether there will be fragmenation! + // This will happen only for the last wave! 
+ if(n < N && (n + YTILE) >= N) + { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) + { + commitColumn[i] = 0; + } + n = startColumn; + } + } + + +} + +#undef YTILE +#undef UNRL +#undef M + +#define YTILE 7 +#define UNRL 1 +#define M 4 + +__global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + __int128_t b128; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) + { + commitColumn[i] = 1; + } + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if(n < N && (n + YTILE) >= N) + { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) + { + commitColumn[i] = 0; + } + n = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * M, 32*1024); k += THRDS * WvPrGrp * A_CHUNK) + { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + //uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for bank-conflict-free readback + + if (k_in >= min(K * M, 32*1024)) break; + + ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! 
+ // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of availble columns + //---------------------------------------------------- + while (n < N) + { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // splitted across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int m=0; m= 2) + bigType bigB1[UNRL]; +#endif +#if (YTILE >= 3) + bigType bigB2[UNRL]; +#endif +#if (YTILE >= 4) + bigType bigB3[UNRL]; +#endif +#if (YTILE >= 5) + bigType bigB4[UNRL]; +#endif +#if (YTILE >= 6) + bigType bigB5[UNRL]; +#endif +#if (YTILE >= 7) + bigType bigB6[UNRL]; +#endif +#if (YTILE >= 8) + bigType bigB7[UNRL]; +#endif +#if (YTILE >= 9) + bigType bigB8[UNRL]; +#endif +#if (YTILE >= 10) + bigType bigB9[UNRL]; +#endif +#if (YTILE >= 11) + bigType bigB10[UNRL]; +#endif + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) + { + // Fetch the weight matrix from memory! 
+#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) + { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + //if (k_ >= K) break; + //bool skip = (k_ >= K); + //bool dummy = (k_ >= K); + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- +#if (YTILE >= 2) + //if (n+1>=N) continue; + bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); +#endif +#if (YTILE >= 3) + //if (n+2>=N) continue; + bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); +#endif +#if (YTILE >= 4) + //if (n+3>=N) continue; + bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); +#endif +#if (YTILE >= 5) + //if (n+4>=N) continue; + bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); +#endif +#if (YTILE >= 6) + //if (n+5>=N) continue; + bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); +#endif +#if (YTILE >= 7) + //if (n+6>=N) continue; + bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); +#endif +#if (YTILE >= 8) + //if (n+7>=N) continue; + bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); +#endif +/* +#if (YTILE >= 9) + if (n+8>=N) continue; bigB8[k2].h8 = (loadnt((half8*)(&B_[8 * K]))); +#endif +#if (YTILE >= 10) + if (n+9>=N) continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); +#endif +#if (YTILE >= 11) + if (n+10>=N) continue; bigB10[k2].h8 = (loadnt((half8*)(&B_[10 * K]))); +#endif +*/ + } + + // Fetch activation matrix from either just LDS or from both LDS / memory +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) + { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m=0; m < M; m++) + { + if (k_+K*m < 32*1024) + bigA[m][k2] = *((const bigType*)(&(s[k_+K*m]))); + else + bigA[m][k2] = *((const bigType*)(&(A[k_+K*m]))); + } + } + + // Do the matrix multiplication in interleaved manner +#pragma unroll + for (uint32_t m = 0; m < M; m++) + { +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) + { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Do the matrix multiplication of activation and weight matrix + // - Rememeber the accumulation is happening for K-split of 64! 
+#pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) + { + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][0]) : "0"(sum[m][0]), "v" (bigA[m][k2].f[b]), "v" (bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- +#if (YTILE >= 2) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][1]) : "0"(sum[m][1]), "v" (bigA[m][k2].f[b]), "v" (bigB1[k2].f[b])); +#endif +#if (YTILE >= 3) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][2]) : "0"(sum[m][2]), "v" (bigA[m][k2].f[b]), "v" (bigB2[k2].f[b])); +#endif +#if (YTILE >= 4) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][3]) : "0"(sum[m][3]), "v" (bigA[m][k2].f[b]), "v" (bigB3[k2].f[b])); +#endif +#if (YTILE >= 5) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][4]) : "0"(sum[m][4]), "v" (bigA[m][k2].f[b]), "v" (bigB4[k2].f[b])); +#endif +#if (YTILE >= 6) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][5]) : "0"(sum[m][5]), "v" (bigA[m][k2].f[b]), "v" (bigB5[k2].f[b])); +#endif +#if (YTILE >= 7) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][6]) : "0"(sum[m][6]), "v" (bigA[m][k2].f[b]), "v" (bigB6[k2].f[b])); +#endif +#if (YTILE >= 8) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][7]) : "0"(sum[m][7]), "v" (bigA[m][k2].f[b]), "v" (bigB7[k2].f[b])); +#endif +#if (YTILE >= 9) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][8]) : "0"(sum[m][8]), "v" (bigA[m][k2].f[b]), "v" (bigB8[k2].f[b])); +#endif +#if (YTILE >= 10) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][9]) : "0"(sum[m][9]), "v" (bigA[m][k2].f[b]), "v" (bigB9[k2].f[b])); +#endif +#if (YTILE >= 11) + asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][10]) : "0"(sum[m][10]), "v" (bigA[m][k2].f[b]), "v" (bigB10[k2].f[b])); +#endif + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) + { + for (int y = 0; y < YTILE; y++) + { + //for (int offset = 64 / 2; offset > 4 ; offset /= 2) { + // sum[y] += __shfl_down(sum[y], offset); + //} + sum[m][y] += __shfl_down(sum[m][y], 32); + sum[m][y] += __shfl_down(sum[m][y], 16); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shl:1 bound_ctrl:0" : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); + + } + } + + if (threadIdx.x == 0) + { + for (int m = 0; m < M; m++) + { + for (int i = 0; i < YTILE; i++) + { + if (commitColumn[i]) + C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + + n += CuCount * WvPrGrp * YTILE; + + //if (threadIdx.x == 0) + //n = atomicAdd(((unsigned int*)(C)), YTILE); + //n = __shfl(n, 0, 64); + + // Check whether there will be fragmenation! + // This will happen only for the last wave! 
+ if(n < N && (n + YTILE) >= N) + { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) + { + commitColumn[i] = 0; + } + n = startColumn; + } + } + + +} + + + +void wvSpltK_(void *in_a, void *in_b, void *out_c, const int M_in, const int K_in,const int N_in, cudaStream_t stream, const int CuCount = 0) { + dim3 grid(CuCount); + dim3 block(THRDS, WvPrGrp); + half* af4 = reinterpret_cast(in_a); + const half* bf4 = reinterpret_cast(in_b); + auto *c = reinterpret_cast(out_c); + switch(N_in) { + case 1: + wvSpltK_hf_m1_<<>>(K_in, M_in, af4, bf4, c, CuCount); + break; + case 2: + wvSpltK_hf_m2_<<>>(K_in, M_in, af4, bf4, c, CuCount); + break; + case 3: + wvSpltK_hf_m3_<<>>(K_in, M_in, af4, bf4, c, CuCount); + break; + case 4: + wvSpltK_hf_m4_<<>>(K_in, M_in, af4, bf4, c, CuCount); + break; + default: + throw std::runtime_error("Unsupported N value: " + std::to_string(M_in) + "," + std::to_string(K_in) + "," + std::to_string(N_in)); + } + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) { + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); + } +} + + + diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index a3d299c05caef..6c78430adf4d0 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -23,6 +23,7 @@ def __init__(self): self.bestsols = {} self.load_best_sols() self.create_ds() + self.CuCount = torch.cuda.get_device_properties(device='cuda').multi_processor_count if (self.save_gemm == 1): self.tuned_df = pd.DataFrame(columns=['M', 'N', 'K']) @@ -69,13 +70,12 @@ def mm(self, inp, weights): k = inp_view.shape[1] soltype, solidx = self.query_sol(m=m, n=n, k=k) if soltype == 1: - #print(">>> found hipblas") + print(">>> found hipblas") out = hipb_mm(inp_view, weights.t(), solidx) elif soltype == 2: - #print(">>> found rocblas") + print(">>> found rocblas") out = rocb_mm(inp_view, weights.t(), solidx) else: - if (self.save_gemm == 1): #print('>>>Tgemm Default',inp_view.shape, # inp.shape,weights.shape,soltype,solidx) @@ -89,7 +89,13 @@ def mm(self, inp, weights): ]).drop_duplicates() self.tuned_df.to_csv(self.untune_path, index=False) - if n == 1 and inp_view.dtype == torch.float16: + if ((n == 4 or n == 3 or n== 2 or n == 1 ) and inp_view.dtype == torch.float16) : + out = torch.empty(inp_view.shape[0], + weights.shape[0], + dtype=inp_view.dtype, + device='cuda') + _custom_C.wvSpltK(weights, inp_view, out, n, self.CuCount) + elif n == 1 and inp_view.dtype == torch.float16: out = torch.empty(inp_view.shape[0], weights.shape[0], dtype=inp_view.dtype, From 0649fcdb427f4a9a8f7e087606b2a2f005249615 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 17 Jun 2024 21:06:27 +0000 Subject: [PATCH 02/12] style adjustments, remove hipbls/rocblas comments. 
--- csrc/custom/custom_kernels.cu | 2492 +++++++++++----------- vllm/model_executor/layers/tuned_gemm.py | 4 +- 2 files changed, 1261 insertions(+), 1235 deletions(-) diff --git a/csrc/custom/custom_kernels.cu b/csrc/custom/custom_kernels.cu index 09c28d1a46eca..5c37701189098 100644 --- a/csrc/custom/custom_kernels.cu +++ b/csrc/custom/custom_kernels.cu @@ -331,691 +331,702 @@ __device__ __forceinline__ T loadnt(T* addr) { #define M 1 #define DTYPE half -__global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { - - union bigType { - DTYPE h[A_CHUNK]; - float f[A_CHUNK / 2]; - float2 f2[A_CHUNK / 4]; - double d[A_CHUNK / 4]; - __int128_t b128; - half8 h8; - }; +__global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + __int128_t b128; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } - //---------------------------------------------------- - // Reserving 64 KB of LDS to have 1 WG / CU - // Goal is to bring the activation matrix A to the LDS - // and use it across the lifetime of the work group - // TODO: When activation matrix is larger than 64 KB - // then this is not goint to work! - //---------------------------------------------------- - __shared__ half s[1024 * 32]; - - //---------------------------------------------------- - // Computation of columns that need to be committed to memory! - //---------------------------------------------------- - uint32_t commitColumn[YTILE]; - for (uint32_t i = 0; i < YTILE; i++) - { - commitColumn[i] = 1; + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! 
+ if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; } + n = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * M, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for + // bank-conflict-free readback + + if (k_in >= min(K * M, 32 * 1024)) break; + + ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of availble columns + //---------------------------------------------------- + while (n < N) { //---------------------------------------------------- - // Indexing function into the column of weight matrix B - // Algorithm does 64 lane k-splitting / wave and uses - // WG ID and Thread ID to find the index. + // 'sum' accumulates the matrix A x B computation + // splitted across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. //---------------------------------------------------- - uint32_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; - - // Check whether there will be fragmenation! - // This will happen only for the last wave! - if(n < N && (n + YTILE) >= N) - { - uint32_t startColumn = N - YTILE; - for (uint32_t i = 0; i < (n - startColumn); i++) - { - commitColumn[i] = 0; - } - n = startColumn; - } + for (int i = 0; i < YTILE; i++) + for (int m = 0; m < M; m++) sum[m][i] = 0; + bigType bigA[M][UNRL]; + bigType bigB0[UNRL]; +#if (YTILE >= 2) + bigType bigB1[UNRL]; +#endif +#if (YTILE >= 3) + bigType bigB2[UNRL]; +#endif +#if (YTILE >= 4) + bigType bigB3[UNRL]; +#endif +#if (YTILE >= 5) + bigType bigB4[UNRL]; +#endif +#if (YTILE >= 6) + bigType bigB5[UNRL]; +#endif +#if (YTILE >= 7) + bigType bigB6[UNRL]; +#endif +#if (YTILE >= 8) + bigType bigB7[UNRL]; +#endif +#if (YTILE >= 9) + bigType bigB8[UNRL]; +#endif +#if (YTILE >= 10) + bigType bigB9[UNRL]; +#endif +#if (YTILE >= 11) + bigType bigB10[UNRL]; +#endif //---------------------------------------------------- - // Fetch the activation matrix to LDS - // Loop iteration: + // Fetch weight matrix B in interleaved K-split! 
// - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements - // - Each WG will fetch 512 * 16 => 8K elements - // - Then the WG will move to another 8 K elements - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- - for (uint32_t k = 0; k < min(K * M, 32*1024); k += THRDS * WvPrGrp * A_CHUNK) - { - uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - - // Transpose of A implementation - //uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for bank-conflict-free readback - - if (k_in >= min(K * M, 32*1024)) break; - - ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; - //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; - } - __syncthreads(); - - float sum[M][YTILE]; - - //---------------------------------------------------- - // Each wave works on a single column of weight matrix. - // There are 16 waves per WG, and hence, each WG is - // working on 16 columns of weight matrix. Moreover, - // we tile in column direction by YTILE, so when YTILE=1 - // the above math is right, however, when YTILE=2 then - // each wave will be working on 2 columns and WG will - // be working on 32 columns. + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. // - // Top level loop that makes WGs persistent! - // - WGs iterates across columns of weight matrix - // - Each wave within WG works on a given column(s) - // - After completing first set of columns, WGs start - // working on the next set of availble columns + // TODO: Logic below will only work when K is multiple of 8 //---------------------------------------------------- - while (n < N) - { + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // if (k_ >= K) break; + // bool skip = (k_ >= K); + // bool dummy = (k_ >= K); + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); //---------------------------------------------------- - // 'sum' accumulates the matrix A x B computation - // splitted across 64 lanes. - // - // YTILE represents how many column of weight matrix - // are being worked on by each wave. 
+ // The following code with YTILE > 1 has to be deleted //---------------------------------------------------- - for (int i = 0; i < YTILE; i++) - for (int m=0; m= 2) - bigType bigB1[UNRL]; + // if (n+1>=N) continue; + bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); #endif #if (YTILE >= 3) - bigType bigB2[UNRL]; + // if (n+2>=N) continue; + bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); #endif #if (YTILE >= 4) - bigType bigB3[UNRL]; + // if (n+3>=N) continue; + bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); #endif #if (YTILE >= 5) - bigType bigB4[UNRL]; + // if (n+4>=N) continue; + bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); #endif #if (YTILE >= 6) - bigType bigB5[UNRL]; + // if (n+5>=N) continue; + bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); #endif #if (YTILE >= 7) - bigType bigB6[UNRL]; + // if (n+6>=N) continue; + bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); #endif #if (YTILE >= 8) - bigType bigB7[UNRL]; + // if (n+7>=N) continue; + bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); +#endif + /* + #if (YTILE >= 9) + if (n+8>=N) continue; bigB8[k2].h8 = + (loadnt((half8*)(&B_[8 * K]))); #endif #if (YTILE >= 10) if (n+9>=N) + continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); #endif #if + (YTILE >= 11) if (n+10>=N) continue; bigB10[k2].h8 = + (loadnt((half8*)(&B_[10 * K]))); #endif + */ + } + + // Fetch activation matrix from either just LDS or from both LDS / memory +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m = 0; m < M; m++) { + if (k_ + K * m < 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + else + bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); + } + } + + // Do the matrix multiplication in interleaved manner +#pragma unroll + for (uint32_t m = 0; m < M; m++) { +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Do the matrix multiplication of activation and weight matrix + // - Rememeber the accumulation is happening for K-split of 64! 
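// How the 128-bit loads above are consumed (a sketch, with bigTypeSketch as
// a renamed stand-in for bigType): one b128 load brings in 8 halves of B or
// A, and the same bytes are later read back two halves at a time through the
// float view f[], which is the operand shape the packed dot-product below
// expects.  loadnt appears to be a nontemporal-load helper defined earlier
// in this file.
union bigTypeSketch {
  half h[8];         // 8 fp16 values
  float f[4];        // 4 pairs of packed fp16, one pair per dot2 issue
  __int128_t b128;   // one 16-byte load/store
};

__device__ void load8(const half* src, bigTypeSketch& dst) {
  dst.b128 = *reinterpret_cast<const __int128_t*>(src);  // single 16 B read
}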
+#pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][0]) + : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- +#if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); +#endif +#if (YTILE >= 3) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][2]) + : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); +#endif +#if (YTILE >= 4) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][3]) + : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); +#endif +#if (YTILE >= 5) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][4]) + : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); +#endif +#if (YTILE >= 6) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][5]) + : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); +#endif +#if (YTILE >= 7) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][6]) + : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); +#endif +#if (YTILE >= 8) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][7]) + : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); #endif #if (YTILE >= 9) - bigType bigB8[UNRL]; + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][8]) + : "0"(sum[m][8]), "v"(bigA[m][k2].f[b]), "v"(bigB8[k2].f[b])); #endif #if (YTILE >= 10) - bigType bigB9[UNRL]; + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][9]) + : "0"(sum[m][9]), "v"(bigA[m][k2].f[b]), "v"(bigB9[k2].f[b])); #endif #if (YTILE >= 11) - bigType bigB10[UNRL]; + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][10]) + : "0"(sum[m][10]), "v"(bigA[m][k2].f[b]), "v"(bigB10[k2].f[b])); #endif - //---------------------------------------------------- - // Fetch weight matrix B in interleaved K-split! - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements (1024B) - // - YTILE represents the number of column being serviced - // by wave - // - Loop for fetching weight matrix (B) are unrolled - // - // Fetch activation matrix A from LDS - // - Loop for fetching activation matrix (A) are unrolled - // - // Finally, do the matrix multiplication in an unrolled - // fashion. This provides lot of food for compiler - // scheduling. - // - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- - for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) - { - // Fetch the weight matrix from memory! 
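// A scalar reading of one v_dot2c_f32_f16 issued in the new code above
// (sketch): each 32-bit f[b] operand holds a pair of fp16 values, and the
// instruction accumulates their 2-element dot product into the fp32 sum kept
// in the same register.
__device__ float dot2c_ref(float acc, float packedA, float packedB) {
  __half2 a = *reinterpret_cast<const __half2*>(&packedA);
  __half2 b = *reinterpret_cast<const __half2*>(&packedB);
  return acc + __half2float(__low2half(a)) * __half2float(__low2half(b)) +
         __half2float(__high2half(a)) * __half2float(__high2half(b));
}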
-#pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) - { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - //if (k_ >= K) break; - //bool skip = (k_ >= K); - //bool dummy = (k_ >= K); - - const half* B_ = &B[(n + 0) * K + k_]; - bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); - //---------------------------------------------------- - // The following code with YTILE > 1 has to be deleted - //---------------------------------------------------- -#if (YTILE >= 2) - //if (n+1>=N) continue; - bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); -#endif -#if (YTILE >= 3) - //if (n+2>=N) continue; - bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); -#endif -#if (YTILE >= 4) - //if (n+3>=N) continue; - bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); -#endif -#if (YTILE >= 5) - //if (n+4>=N) continue; - bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); -#endif -#if (YTILE >= 6) - //if (n+5>=N) continue; - bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); -#endif -#if (YTILE >= 7) - //if (n+6>=N) continue; - bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); -#endif -#if (YTILE >= 8) - //if (n+7>=N) continue; - bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); -#endif -/* -#if (YTILE >= 9) - if (n+8>=N) continue; bigB8[k2].h8 = (loadnt((half8*)(&B_[8 * K]))); -#endif -#if (YTILE >= 10) - if (n+9>=N) continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); -#endif -#if (YTILE >= 11) - if (n+10>=N) continue; bigB10[k2].h8 = (loadnt((half8*)(&B_[10 * K]))); -#endif -*/ - } - - // Fetch activation matrix from either just LDS or from both LDS / memory -#pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) - { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - // Fetch A activation matrix in interleaved fashion from LDS or memory - - for (int m=0; m < M; m++) - { - if (k_+K*m < 32*1024) - bigA[m][k2] = *((const bigType*)(&(s[k_+K*m]))); - else - bigA[m][k2] = *((const bigType*)(&(A[k_+K*m]))); - } - } - - // Do the matrix multiplication in interleaved manner -#pragma unroll - for (uint32_t m = 0; m < M; m++) - { -#pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) - { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - // Do the matrix multiplication of activation and weight matrix - // - Rememeber the accumulation is happening for K-split of 64! 
-#pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) - { - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][0]) : "0"(sum[m][0]), "v" (bigA[m][k2].f[b]), "v" (bigB0[k2].f[b])); - - //---------------------------------------------------- - // The following code with YTILE > 1 - //---------------------------------------------------- -#if (YTILE >= 2) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][1]) : "0"(sum[m][1]), "v" (bigA[m][k2].f[b]), "v" (bigB1[k2].f[b])); -#endif -#if (YTILE >= 3) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][2]) : "0"(sum[m][2]), "v" (bigA[m][k2].f[b]), "v" (bigB2[k2].f[b])); -#endif -#if (YTILE >= 4) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][3]) : "0"(sum[m][3]), "v" (bigA[m][k2].f[b]), "v" (bigB3[k2].f[b])); -#endif -#if (YTILE >= 5) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][4]) : "0"(sum[m][4]), "v" (bigA[m][k2].f[b]), "v" (bigB4[k2].f[b])); -#endif -#if (YTILE >= 6) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][5]) : "0"(sum[m][5]), "v" (bigA[m][k2].f[b]), "v" (bigB5[k2].f[b])); -#endif -#if (YTILE >= 7) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][6]) : "0"(sum[m][6]), "v" (bigA[m][k2].f[b]), "v" (bigB6[k2].f[b])); -#endif -#if (YTILE >= 8) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][7]) : "0"(sum[m][7]), "v" (bigA[m][k2].f[b]), "v" (bigB7[k2].f[b])); -#endif -#if (YTILE >= 9) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][8]) : "0"(sum[m][8]), "v" (bigA[m][k2].f[b]), "v" (bigB8[k2].f[b])); -#endif -#if (YTILE >= 10) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][9]) : "0"(sum[m][9]), "v" (bigA[m][k2].f[b]), "v" (bigB9[k2].f[b])); -#endif -#if (YTILE >= 11) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][10]) : "0"(sum[m][10]), "v" (bigA[m][k2].f[b]), "v" (bigB10[k2].f[b])); -#endif - } - } - } - } - - //---------------------------------------------------- - // Final reduction step using shuffle - //---------------------------------------------------- - for (int m = 0; m < M; m++) - { - for (int y = 0; y < YTILE; y++) - { - //for (int offset = 64 / 2; offset > 4 ; offset /= 2) { - // sum[y] += __shfl_down(sum[y], offset); - //} - sum[m][y] += __shfl_down(sum[m][y], 32); - sum[m][y] += __shfl_down(sum[m][y], 16); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shl:1 bound_ctrl:0" : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); - - } - } - - if (threadIdx.x == 0) - { - for (int m = 0; m < M; m++) - { - for (int i = 0; i < YTILE; i++) - { - if (commitColumn[i]) - C[n + i + m * N] = __float2half(sum[m][i]); - } - } + } } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + // for (int offset = 64 / 2; offset > 4 ; offset /= 2) { + // sum[y] += __shfl_down(sum[y], offset); + // } + sum[m][y] += __shfl_down(sum[m][y], 32); + sum[m][y] += __shfl_down(sum[m][y], 16); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + 
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shl:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } - n += CuCount * WvPrGrp * YTILE; - - //if (threadIdx.x == 0) - //n = atomicAdd(((unsigned int*)(C)), YTILE); - //n = __shfl(n, 0, 64); - - // Check whether there will be fragmenation! - // This will happen only for the last wave! - if(n < N && (n + YTILE) >= N) - { - uint32_t startColumn = N - YTILE; - for (uint32_t i = 0; i < (n - startColumn); i++) - { - commitColumn[i] = 0; - } - n = startColumn; + if (threadIdx.x == 0) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); } + } } + n += CuCount * WvPrGrp * YTILE; + + // if (threadIdx.x == 0) + // n = atomicAdd(((unsigned int*)(C)), YTILE); + // n = __shfl(n, 0, 64); + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + } } #undef YTILE #undef UNRL #undef M - #define YTILE 2 #define UNRL 2 -#define M 2 - -__global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { - - union bigType { - DTYPE h[A_CHUNK]; - float f[A_CHUNK / 2]; - float2 f2[A_CHUNK / 4]; - double d[A_CHUNK / 4]; - __int128_t b128; - half8 h8; - }; +#define M 2 + +__global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + __int128_t b128; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } - //---------------------------------------------------- - // Reserving 64 KB of LDS to have 1 WG / CU - // Goal is to bring the activation matrix A to the LDS - // and use it across the lifetime of the work group - // TODO: When activation matrix is larger than 64 KB - // then this is not goint to work! - //---------------------------------------------------- - __shared__ half s[1024 * 32]; - - //---------------------------------------------------- - // Computation of columns that need to be committed to memory! 
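// Shuffle-only equivalent of the wavefront reduction in the new code above
// (sketch).  The kernel finishes the last steps with DPP row_shl adds rather
// than shuffles, but the intent is the same as the commented-out loop it
// keeps around: after the reduction, lane 0 of the 64-wide wavefront holds
// the complete dot product for one output column, matching the
// threadIdx.x == 0 write-back.
__device__ float reduce_wave64(float v) {
  for (int offset = 32; offset >= 1; offset >>= 1)
    v += __shfl_down(v, offset);
  return v;  // meaningful in lane 0 only
}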
- //---------------------------------------------------- - uint32_t commitColumn[YTILE]; - for (uint32_t i = 0; i < YTILE; i++) - { - commitColumn[i] = 1; + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; } + n = startColumn; + } + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * M, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for + // bank-conflict-free readback + + if (k_in >= min(K * M, 32 * 1024)) break; + + ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of availble columns + //---------------------------------------------------- + while (n < N) { //---------------------------------------------------- - // Indexing function into the column of weight matrix B - // Algorithm does 64 lane k-splitting / wave and uses - // WG ID and Thread ID to find the index. + // 'sum' accumulates the matrix A x B computation + // splitted across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. //---------------------------------------------------- - uint32_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; - - // Check whether there will be fragmenation! - // This will happen only for the last wave! 
- if(n < N && (n + YTILE) >= N) - { - uint32_t startColumn = N - YTILE; - for (uint32_t i = 0; i < (n - startColumn); i++) - { - commitColumn[i] = 0; - } - n = startColumn; - } + for (int i = 0; i < YTILE; i++) + for (int m = 0; m < M; m++) sum[m][i] = 0; + bigType bigA[M][UNRL]; + bigType bigB0[UNRL]; +#if (YTILE >= 2) + bigType bigB1[UNRL]; +#endif +#if (YTILE >= 3) + bigType bigB2[UNRL]; +#endif +#if (YTILE >= 4) + bigType bigB3[UNRL]; +#endif +#if (YTILE >= 5) + bigType bigB4[UNRL]; +#endif +#if (YTILE >= 6) + bigType bigB5[UNRL]; +#endif +#if (YTILE >= 7) + bigType bigB6[UNRL]; +#endif +#if (YTILE >= 8) + bigType bigB7[UNRL]; +#endif +#if (YTILE >= 9) + bigType bigB8[UNRL]; +#endif +#if (YTILE >= 10) + bigType bigB9[UNRL]; +#endif +#if (YTILE >= 11) + bigType bigB10[UNRL]; +#endif //---------------------------------------------------- - // Fetch the activation matrix to LDS - // Loop iteration: + // Fetch weight matrix B in interleaved K-split! // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements - // - Each WG will fetch 512 * 16 => 8K elements - // - Then the WG will move to another 8 K elements - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- - for (uint32_t k = 0; k < min(K * M, 32*1024); k += THRDS * WvPrGrp * A_CHUNK) - { - uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - - // Transpose of A implementation - //uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for bank-conflict-free readback - - if (k_in >= min(K * M, 32*1024)) break; - - ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; - //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; - } - __syncthreads(); - - float sum[M][YTILE]; - - //---------------------------------------------------- - // Each wave works on a single column of weight matrix. - // There are 16 waves per WG, and hence, each WG is - // working on 16 columns of weight matrix. Moreover, - // we tile in column direction by YTILE, so when YTILE=1 - // the above math is right, however, when YTILE=2 then - // each wave will be working on 2 columns and WG will - // be working on 32 columns. + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled // - // Top level loop that makes WGs persistent! - // - WGs iterates across columns of weight matrix - // - Each wave within WG works on a given column(s) - // - After completing first set of columns, WGs start - // working on the next set of availble columns + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 //---------------------------------------------------- - while (n < N) - { + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! 
+#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // if (k_ >= K) break; + // bool skip = (k_ >= K); + // bool dummy = (k_ >= K); + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); //---------------------------------------------------- - // 'sum' accumulates the matrix A x B computation - // splitted across 64 lanes. - // - // YTILE represents how many column of weight matrix - // are being worked on by each wave. + // The following code with YTILE > 1 has to be deleted //---------------------------------------------------- - for (int i = 0; i < YTILE; i++) - for (int m=0; m= 2) - bigType bigB1[UNRL]; + // if (n+1>=N) continue; + bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); #endif #if (YTILE >= 3) - bigType bigB2[UNRL]; + // if (n+2>=N) continue; + bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); #endif #if (YTILE >= 4) - bigType bigB3[UNRL]; + // if (n+3>=N) continue; + bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); #endif #if (YTILE >= 5) - bigType bigB4[UNRL]; + // if (n+4>=N) continue; + bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); #endif #if (YTILE >= 6) - bigType bigB5[UNRL]; + // if (n+5>=N) continue; + bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); #endif #if (YTILE >= 7) - bigType bigB6[UNRL]; + // if (n+6>=N) continue; + bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); #endif #if (YTILE >= 8) - bigType bigB7[UNRL]; + // if (n+7>=N) continue; + bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); +#endif + /* + #if (YTILE >= 9) + if (n+8>=N) continue; bigB8[k2].h8 = + (loadnt((half8*)(&B_[8 * K]))); #endif #if (YTILE >= 10) if (n+9>=N) + continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); #endif #if + (YTILE >= 11) if (n+10>=N) continue; bigB10[k2].h8 = + (loadnt((half8*)(&B_[10 * K]))); #endif + */ + } + + // Fetch activation matrix from either just LDS or from both LDS / memory +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m = 0; m < M; m++) { + if (k_ + K * m < 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + else + bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); + } + } + + // Do the matrix multiplication in interleaved manner +#pragma unroll + for (uint32_t m = 0; m < M; m++) { +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Do the matrix multiplication of activation and weight matrix + // - Rememeber the accumulation is happening for K-split of 64! 
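// The activation fetch above reads A through LDS only while the flat index
// fits in the 32 K-half staging window; once K * M exceeds that, the tail is
// read straight from global memory.  A condensed sketch of that choice:
__device__ half pick_A(const half* s, const half* A, int idx) {
  return (idx < 32 * 1024) ? s[idx] : A[idx];  // LDS copy, else global A
}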
+#pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][0]) + : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- +#if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); +#endif +#if (YTILE >= 3) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][2]) + : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); +#endif +#if (YTILE >= 4) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][3]) + : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); +#endif +#if (YTILE >= 5) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][4]) + : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); +#endif +#if (YTILE >= 6) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][5]) + : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); +#endif +#if (YTILE >= 7) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][6]) + : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); +#endif +#if (YTILE >= 8) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][7]) + : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); #endif #if (YTILE >= 9) - bigType bigB8[UNRL]; + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][8]) + : "0"(sum[m][8]), "v"(bigA[m][k2].f[b]), "v"(bigB8[k2].f[b])); #endif #if (YTILE >= 10) - bigType bigB9[UNRL]; + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][9]) + : "0"(sum[m][9]), "v"(bigA[m][k2].f[b]), "v"(bigB9[k2].f[b])); #endif #if (YTILE >= 11) - bigType bigB10[UNRL]; + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][10]) + : "0"(sum[m][10]), "v"(bigA[m][k2].f[b]), "v"(bigB10[k2].f[b])); #endif - //---------------------------------------------------- - // Fetch weight matrix B in interleaved K-split! - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements (1024B) - // - YTILE represents the number of column being serviced - // by wave - // - Loop for fetching weight matrix (B) are unrolled - // - // Fetch activation matrix A from LDS - // - Loop for fetching activation matrix (A) are unrolled - // - // Finally, do the matrix multiplication in an unrolled - // fashion. This provides lot of food for compiler - // scheduling. - // - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- - for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) - { - // Fetch the weight matrix from memory! 
-#pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) - { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - //if (k_ >= K) break; - //bool skip = (k_ >= K); - //bool dummy = (k_ >= K); - - const half* B_ = &B[(n + 0) * K + k_]; - bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); - //---------------------------------------------------- - // The following code with YTILE > 1 has to be deleted - //---------------------------------------------------- -#if (YTILE >= 2) - //if (n+1>=N) continue; - bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); -#endif -#if (YTILE >= 3) - //if (n+2>=N) continue; - bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); -#endif -#if (YTILE >= 4) - //if (n+3>=N) continue; - bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); -#endif -#if (YTILE >= 5) - //if (n+4>=N) continue; - bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); -#endif -#if (YTILE >= 6) - //if (n+5>=N) continue; - bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); -#endif -#if (YTILE >= 7) - //if (n+6>=N) continue; - bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); -#endif -#if (YTILE >= 8) - //if (n+7>=N) continue; - bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); -#endif -/* -#if (YTILE >= 9) - if (n+8>=N) continue; bigB8[k2].h8 = (loadnt((half8*)(&B_[8 * K]))); -#endif -#if (YTILE >= 10) - if (n+9>=N) continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); -#endif -#if (YTILE >= 11) - if (n+10>=N) continue; bigB10[k2].h8 = (loadnt((half8*)(&B_[10 * K]))); -#endif -*/ - } - - // Fetch activation matrix from either just LDS or from both LDS / memory -#pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) - { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - // Fetch A activation matrix in interleaved fashion from LDS or memory - - for (int m=0; m < M; m++) - { - if (k_+K*m < 32*1024) - bigA[m][k2] = *((const bigType*)(&(s[k_+K*m]))); - else - bigA[m][k2] = *((const bigType*)(&(A[k_+K*m]))); - } - } - - // Do the matrix multiplication in interleaved manner -#pragma unroll - for (uint32_t m = 0; m < M; m++) - { -#pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) - { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - // Do the matrix multiplication of activation and weight matrix - // - Rememeber the accumulation is happening for K-split of 64! 
-#pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) - { - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][0]) : "0"(sum[m][0]), "v" (bigA[m][k2].f[b]), "v" (bigB0[k2].f[b])); - - //---------------------------------------------------- - // The following code with YTILE > 1 - //---------------------------------------------------- -#if (YTILE >= 2) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][1]) : "0"(sum[m][1]), "v" (bigA[m][k2].f[b]), "v" (bigB1[k2].f[b])); -#endif -#if (YTILE >= 3) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][2]) : "0"(sum[m][2]), "v" (bigA[m][k2].f[b]), "v" (bigB2[k2].f[b])); -#endif -#if (YTILE >= 4) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][3]) : "0"(sum[m][3]), "v" (bigA[m][k2].f[b]), "v" (bigB3[k2].f[b])); -#endif -#if (YTILE >= 5) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][4]) : "0"(sum[m][4]), "v" (bigA[m][k2].f[b]), "v" (bigB4[k2].f[b])); -#endif -#if (YTILE >= 6) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][5]) : "0"(sum[m][5]), "v" (bigA[m][k2].f[b]), "v" (bigB5[k2].f[b])); -#endif -#if (YTILE >= 7) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][6]) : "0"(sum[m][6]), "v" (bigA[m][k2].f[b]), "v" (bigB6[k2].f[b])); -#endif -#if (YTILE >= 8) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][7]) : "0"(sum[m][7]), "v" (bigA[m][k2].f[b]), "v" (bigB7[k2].f[b])); -#endif -#if (YTILE >= 9) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][8]) : "0"(sum[m][8]), "v" (bigA[m][k2].f[b]), "v" (bigB8[k2].f[b])); -#endif -#if (YTILE >= 10) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][9]) : "0"(sum[m][9]), "v" (bigA[m][k2].f[b]), "v" (bigB9[k2].f[b])); -#endif -#if (YTILE >= 11) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][10]) : "0"(sum[m][10]), "v" (bigA[m][k2].f[b]), "v" (bigB10[k2].f[b])); -#endif - } - } - } - } - - //---------------------------------------------------- - // Final reduction step using shuffle - //---------------------------------------------------- - for (int m = 0; m < M; m++) - { - for (int y = 0; y < YTILE; y++) - { - //for (int offset = 64 / 2; offset > 4 ; offset /= 2) { - // sum[y] += __shfl_down(sum[y], offset); - //} - sum[m][y] += __shfl_down(sum[m][y], 32); - sum[m][y] += __shfl_down(sum[m][y], 16); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shl:1 bound_ctrl:0" : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); - - } - } - - if (threadIdx.x == 0) - { - for (int m = 0; m < M; m++) - { - for (int i = 0; i < YTILE; i++) - { - if (commitColumn[i]) - C[n + i + m * N] = __float2half(sum[m][i]); - } - } + } } + } + } - n += CuCount * WvPrGrp * YTILE; - - //if (threadIdx.x == 0) - //n = atomicAdd(((unsigned int*)(C)), YTILE); - //n = __shfl(n, 0, 64); - - // Check whether there will be fragmenation! - // This will happen only for the last wave! 
- if(n < N && (n + YTILE) >= N) - { - uint32_t startColumn = N - YTILE; - for (uint32_t i = 0; i < (n - startColumn); i++) - { - commitColumn[i] = 0; - } - n = startColumn; + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + // for (int offset = 64 / 2; offset > 4 ; offset /= 2) { + // sum[y] += __shfl_down(sum[y], offset); + // } + sum[m][y] += __shfl_down(sum[m][y], 32); + sum[m][y] += __shfl_down(sum[m][y], 16); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shl:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } + + if (threadIdx.x == 0) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); } + } } + n += CuCount * WvPrGrp * YTILE; + + // if (threadIdx.x == 0) + // n = atomicAdd(((unsigned int*)(C)), YTILE); + // n = __shfl(n, 0, 64); + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + } } #undef YTILE @@ -1024,345 +1035,351 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, const D #define YTILE 5 #define UNRL 2 -#define M 3 - -__global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { - - union bigType { - DTYPE h[A_CHUNK]; - float f[A_CHUNK / 2]; - float2 f2[A_CHUNK / 4]; - double d[A_CHUNK / 4]; - __int128_t b128; - half8 h8; - }; +#define M 3 + +__global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + __int128_t b128; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } - //---------------------------------------------------- - // Reserving 64 KB of LDS to have 1 WG / CU - // Goal is to bring the activation matrix A to the LDS - // and use it across the lifetime of the work group - // TODO: When activation matrix is larger than 64 KB - // then this is not goint to work! 
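// Host-side sketch of the column ownership implied by the indexing above.
// Parameter names and any example values are illustrative assumptions, not
// taken from the patch: a wave starts at (block * wv_per_grp + wave) * ytile,
// strides by cu_count * wv_per_grp * ytile, and a fragmented last tile is
// clamped to n_cols - ytile with its leading overlap masked, mirroring
// commitColumn[].
#include <cstdio>

void print_column_tiles(int n_cols, int ytile, int cu_count, int wv_per_grp,
                        int block, int wave) {
  int n = (block * wv_per_grp + wave) * ytile;
  while (n < n_cols) {
    int start = n, masked = 0;
    if (n + ytile >= n_cols) {  // fragmented last tile
      start = n_cols - ytile;
      masked = n - start;       // columns another tile already covers
      n = start;
    }
    std::printf("block %d wave %d -> columns [%d, %d), first %d not committed\n",
                block, wave, start, start + ytile, masked);
    n += cu_count * wv_per_grp * ytile;  // persistent-WG stride
  }
}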
- //---------------------------------------------------- - __shared__ half s[1024 * 32]; - - //---------------------------------------------------- - // Computation of columns that need to be committed to memory! - //---------------------------------------------------- - uint32_t commitColumn[YTILE]; - for (uint32_t i = 0; i < YTILE; i++) - { - commitColumn[i] = 1; + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; } + n = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * M, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for + // bank-conflict-free readback + + if (k_in >= min(K * M, 32 * 1024)) break; + + ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of availble columns + //---------------------------------------------------- + while (n < N) { //---------------------------------------------------- - // Indexing function into the column of weight matrix B - // Algorithm does 64 lane k-splitting / wave and uses - // WG ID and Thread ID to find the index. + // 'sum' accumulates the matrix A x B computation + // splitted across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. //---------------------------------------------------- - uint32_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; - - // Check whether there will be fragmenation! - // This will happen only for the last wave! 
- if(n < N && (n + YTILE) >= N) - { - uint32_t startColumn = N - YTILE; - for (uint32_t i = 0; i < (n - startColumn); i++) - { - commitColumn[i] = 0; - } - n = startColumn; - } + for (int i = 0; i < YTILE; i++) + for (int m = 0; m < M; m++) sum[m][i] = 0; + bigType bigA[M][UNRL]; + bigType bigB0[UNRL]; +#if (YTILE >= 2) + bigType bigB1[UNRL]; +#endif +#if (YTILE >= 3) + bigType bigB2[UNRL]; +#endif +#if (YTILE >= 4) + bigType bigB3[UNRL]; +#endif +#if (YTILE >= 5) + bigType bigB4[UNRL]; +#endif +#if (YTILE >= 6) + bigType bigB5[UNRL]; +#endif +#if (YTILE >= 7) + bigType bigB6[UNRL]; +#endif +#if (YTILE >= 8) + bigType bigB7[UNRL]; +#endif +#if (YTILE >= 9) + bigType bigB8[UNRL]; +#endif +#if (YTILE >= 10) + bigType bigB9[UNRL]; +#endif +#if (YTILE >= 11) + bigType bigB10[UNRL]; +#endif //---------------------------------------------------- - // Fetch the activation matrix to LDS - // Loop iteration: + // Fetch weight matrix B in interleaved K-split! // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements - // - Each WG will fetch 512 * 16 => 8K elements - // - Then the WG will move to another 8 K elements - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- - for (uint32_t k = 0; k < min(K * M, 32*1024); k += THRDS * WvPrGrp * A_CHUNK) - { - uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - - // Transpose of A implementation - //uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for bank-conflict-free readback - - if (k_in >= min(K * M, 32*1024)) break; - - ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; - //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; - } - __syncthreads(); - - float sum[M][YTILE]; - - //---------------------------------------------------- - // Each wave works on a single column of weight matrix. - // There are 16 waves per WG, and hence, each WG is - // working on 16 columns of weight matrix. Moreover, - // we tile in column direction by YTILE, so when YTILE=1 - // the above math is right, however, when YTILE=2 then - // each wave will be working on 2 columns and WG will - // be working on 32 columns. + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. // - // Top level loop that makes WGs persistent! - // - WGs iterates across columns of weight matrix - // - Each wave within WG works on a given column(s) - // - After completing first set of columns, WGs start - // working on the next set of availble columns + // TODO: Logic below will only work when K is multiple of 8 //---------------------------------------------------- - while (n < N) - { + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! 
+#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // if (k_ >= K) break; + // bool skip = (k_ >= K); + // bool dummy = (k_ >= K); + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); //---------------------------------------------------- - // 'sum' accumulates the matrix A x B computation - // splitted across 64 lanes. - // - // YTILE represents how many column of weight matrix - // are being worked on by each wave. + // The following code with YTILE > 1 has to be deleted //---------------------------------------------------- - for (int i = 0; i < YTILE; i++) - for (int m=0; m= 2) - bigType bigB1[UNRL]; + // if (n+1>=N) continue; + bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); #endif #if (YTILE >= 3) - bigType bigB2[UNRL]; + // if (n+2>=N) continue; + bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); #endif #if (YTILE >= 4) - bigType bigB3[UNRL]; + // if (n+3>=N) continue; + bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); #endif #if (YTILE >= 5) - bigType bigB4[UNRL]; + // if (n+4>=N) continue; + bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); #endif #if (YTILE >= 6) - bigType bigB5[UNRL]; + // if (n+5>=N) continue; + bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); #endif #if (YTILE >= 7) - bigType bigB6[UNRL]; + // if (n+6>=N) continue; + bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); #endif #if (YTILE >= 8) - bigType bigB7[UNRL]; + // if (n+7>=N) continue; + bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); +#endif + /* + #if (YTILE >= 9) + if (n+8>=N) continue; bigB8[k2].h8 = + (loadnt((half8*)(&B_[8 * K]))); #endif #if (YTILE >= 10) if (n+9>=N) + continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); #endif #if + (YTILE >= 11) if (n+10>=N) continue; bigB10[k2].h8 = + (loadnt((half8*)(&B_[10 * K]))); #endif + */ + } + + // Fetch activation matrix from either just LDS or from both LDS / memory +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m = 0; m < M; m++) { + if (k_ + K * m < 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + else + bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); + } + } + + // Do the matrix multiplication in interleaved manner +#pragma unroll + for (uint32_t m = 0; m < M; m++) { +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Do the matrix multiplication of activation and weight matrix + // - Rememeber the accumulation is happening for K-split of 64! 
+#pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][0]) + : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- +#if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); +#endif +#if (YTILE >= 3) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][2]) + : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); +#endif +#if (YTILE >= 4) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][3]) + : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); +#endif +#if (YTILE >= 5) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][4]) + : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); +#endif +#if (YTILE >= 6) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][5]) + : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); +#endif +#if (YTILE >= 7) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][6]) + : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); +#endif +#if (YTILE >= 8) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][7]) + : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); #endif #if (YTILE >= 9) - bigType bigB8[UNRL]; + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][8]) + : "0"(sum[m][8]), "v"(bigA[m][k2].f[b]), "v"(bigB8[k2].f[b])); #endif #if (YTILE >= 10) - bigType bigB9[UNRL]; + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][9]) + : "0"(sum[m][9]), "v"(bigA[m][k2].f[b]), "v"(bigB9[k2].f[b])); #endif #if (YTILE >= 11) - bigType bigB10[UNRL]; + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][10]) + : "0"(sum[m][10]), "v"(bigA[m][k2].f[b]), "v"(bigB10[k2].f[b])); #endif - //---------------------------------------------------- - // Fetch weight matrix B in interleaved K-split! - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements (1024B) - // - YTILE represents the number of column being serviced - // by wave - // - Loop for fetching weight matrix (B) are unrolled - // - // Fetch activation matrix A from LDS - // - Loop for fetching activation matrix (A) are unrolled - // - // Finally, do the matrix multiplication in an unrolled - // fashion. This provides lot of food for compiler - // scheduling. - // - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- - for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) - { - // Fetch the weight matrix from memory! 
-#pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) - { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - //if (k_ >= K) break; - //bool skip = (k_ >= K); - //bool dummy = (k_ >= K); - - const half* B_ = &B[(n + 0) * K + k_]; - bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); - //---------------------------------------------------- - // The following code with YTILE > 1 has to be deleted - //---------------------------------------------------- -#if (YTILE >= 2) - //if (n+1>=N) continue; - bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); -#endif -#if (YTILE >= 3) - //if (n+2>=N) continue; - bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); -#endif -#if (YTILE >= 4) - //if (n+3>=N) continue; - bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); -#endif -#if (YTILE >= 5) - //if (n+4>=N) continue; - bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); -#endif -#if (YTILE >= 6) - //if (n+5>=N) continue; - bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); -#endif -#if (YTILE >= 7) - //if (n+6>=N) continue; - bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); -#endif -#if (YTILE >= 8) - //if (n+7>=N) continue; - bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); -#endif -/* -#if (YTILE >= 9) - if (n+8>=N) continue; bigB8[k2].h8 = (loadnt((half8*)(&B_[8 * K]))); -#endif -#if (YTILE >= 10) - if (n+9>=N) continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); -#endif -#if (YTILE >= 11) - if (n+10>=N) continue; bigB10[k2].h8 = (loadnt((half8*)(&B_[10 * K]))); -#endif -*/ - } - - // Fetch activation matrix from either just LDS or from both LDS / memory -#pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) - { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - // Fetch A activation matrix in interleaved fashion from LDS or memory - - for (int m=0; m < M; m++) - { - if (k_+K*m < 32*1024) - bigA[m][k2] = *((const bigType*)(&(s[k_+K*m]))); - else - bigA[m][k2] = *((const bigType*)(&(A[k_+K*m]))); - } - } - - // Do the matrix multiplication in interleaved manner -#pragma unroll - for (uint32_t m = 0; m < M; m++) - { -#pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) - { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - // Do the matrix multiplication of activation and weight matrix - // - Rememeber the accumulation is happening for K-split of 64! 
-#pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) - { - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][0]) : "0"(sum[m][0]), "v" (bigA[m][k2].f[b]), "v" (bigB0[k2].f[b])); - - //---------------------------------------------------- - // The following code with YTILE > 1 - //---------------------------------------------------- -#if (YTILE >= 2) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][1]) : "0"(sum[m][1]), "v" (bigA[m][k2].f[b]), "v" (bigB1[k2].f[b])); -#endif -#if (YTILE >= 3) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][2]) : "0"(sum[m][2]), "v" (bigA[m][k2].f[b]), "v" (bigB2[k2].f[b])); -#endif -#if (YTILE >= 4) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][3]) : "0"(sum[m][3]), "v" (bigA[m][k2].f[b]), "v" (bigB3[k2].f[b])); -#endif -#if (YTILE >= 5) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][4]) : "0"(sum[m][4]), "v" (bigA[m][k2].f[b]), "v" (bigB4[k2].f[b])); -#endif -#if (YTILE >= 6) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][5]) : "0"(sum[m][5]), "v" (bigA[m][k2].f[b]), "v" (bigB5[k2].f[b])); -#endif -#if (YTILE >= 7) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][6]) : "0"(sum[m][6]), "v" (bigA[m][k2].f[b]), "v" (bigB6[k2].f[b])); -#endif -#if (YTILE >= 8) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][7]) : "0"(sum[m][7]), "v" (bigA[m][k2].f[b]), "v" (bigB7[k2].f[b])); -#endif -#if (YTILE >= 9) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][8]) : "0"(sum[m][8]), "v" (bigA[m][k2].f[b]), "v" (bigB8[k2].f[b])); -#endif -#if (YTILE >= 10) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][9]) : "0"(sum[m][9]), "v" (bigA[m][k2].f[b]), "v" (bigB9[k2].f[b])); -#endif -#if (YTILE >= 11) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][10]) : "0"(sum[m][10]), "v" (bigA[m][k2].f[b]), "v" (bigB10[k2].f[b])); -#endif - } - } - } - } - - //---------------------------------------------------- - // Final reduction step using shuffle - //---------------------------------------------------- - for (int m = 0; m < M; m++) - { - for (int y = 0; y < YTILE; y++) - { - //for (int offset = 64 / 2; offset > 4 ; offset /= 2) { - // sum[y] += __shfl_down(sum[y], offset); - //} - sum[m][y] += __shfl_down(sum[m][y], 32); - sum[m][y] += __shfl_down(sum[m][y], 16); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shl:1 bound_ctrl:0" : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); - - } - } - - if (threadIdx.x == 0) - { - for (int m = 0; m < M; m++) - { - for (int i = 0; i < YTILE; i++) - { - if (commitColumn[i]) - C[n + i + m * N] = __float2half(sum[m][i]); - } - } + } } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + // for (int offset = 64 / 2; offset > 4 ; offset /= 2) { + // sum[y] += __shfl_down(sum[y], offset); + // } + sum[m][y] += __shfl_down(sum[m][y], 32); + sum[m][y] += __shfl_down(sum[m][y], 16); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + 
asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shl:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } - n += CuCount * WvPrGrp * YTILE; - - //if (threadIdx.x == 0) - //n = atomicAdd(((unsigned int*)(C)), YTILE); - //n = __shfl(n, 0, 64); - - // Check whether there will be fragmenation! - // This will happen only for the last wave! - if(n < N && (n + YTILE) >= N) - { - uint32_t startColumn = N - YTILE; - for (uint32_t i = 0; i < (n - startColumn); i++) - { - commitColumn[i] = 0; - } - n = startColumn; + if (threadIdx.x == 0) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); } + } } + n += CuCount * WvPrGrp * YTILE; + + // if (threadIdx.x == 0) + // n = atomicAdd(((unsigned int*)(C)), YTILE); + // n = __shfl(n, 0, 64); + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + } } #undef YTILE @@ -1371,377 +1388,386 @@ __global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B, const D #define YTILE 7 #define UNRL 1 -#define M 4 - -__global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { - - union bigType { - DTYPE h[A_CHUNK]; - float f[A_CHUNK / 2]; - float2 f2[A_CHUNK / 4]; - double d[A_CHUNK / 4]; - __int128_t b128; - half8 h8; - }; +#define M 4 + +__global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + __int128_t b128; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } - //---------------------------------------------------- - // Reserving 64 KB of LDS to have 1 WG / CU - // Goal is to bring the activation matrix A to the LDS - // and use it across the lifetime of the work group - // TODO: When activation matrix is larger than 64 KB - // then this is not goint to work! - //---------------------------------------------------- - __shared__ half s[1024 * 32]; - - //---------------------------------------------------- - // Computation of columns that need to be committed to memory! 
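// Tile configuration of this last variant (from the #define block above):
// M = 4 rows of A, YTILE = 7 columns per wave, UNRL = 1.  Relative to the
// M = 2 and M = 3 variants (YTILE 2 / UNRL 2 and YTILE 5 / UNRL 2), the
// widest batch takes a wider column tile and drops the inner unroll,
// presumably to keep the bigA/bigB register footprint in check.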
- //---------------------------------------------------- - uint32_t commitColumn[YTILE]; - for (uint32_t i = 0; i < YTILE; i++) - { - commitColumn[i] = 1; + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; } + n = startColumn; + } - //---------------------------------------------------- - // Indexing function into the column of weight matrix B - // Algorithm does 64 lane k-splitting / wave and uses - // WG ID and Thread ID to find the index. - //---------------------------------------------------- - uint32_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; - - // Check whether there will be fragmenation! - // This will happen only for the last wave! - if(n < N && (n + YTILE) >= N) - { - uint32_t startColumn = N - YTILE; - for (uint32_t i = 0; i < (n - startColumn); i++) - { - commitColumn[i] = 0; - } - n = startColumn; - } + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * M, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for + // bank-conflict-free readback + + if (k_in >= min(K * M, 32 * 1024)) break; + + ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! 
+ // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of availble columns + //---------------------------------------------------- + while (n < N) { //---------------------------------------------------- - // Fetch the activation matrix to LDS - // Loop iteration: - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements - // - Each WG will fetch 512 * 16 => 8K elements - // - Then the WG will move to another 8 K elements - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- - for (uint32_t k = 0; k < min(K * M, 32*1024); k += THRDS * WvPrGrp * A_CHUNK) - { - uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - - // Transpose of A implementation - //uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for bank-conflict-free readback - - if (k_in >= min(K * M, 32*1024)) break; - - ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; - //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; - } - __syncthreads(); - - float sum[M][YTILE]; - - //---------------------------------------------------- - // Each wave works on a single column of weight matrix. - // There are 16 waves per WG, and hence, each WG is - // working on 16 columns of weight matrix. Moreover, - // we tile in column direction by YTILE, so when YTILE=1 - // the above math is right, however, when YTILE=2 then - // each wave will be working on 2 columns and WG will - // be working on 32 columns. + // 'sum' accumulates the matrix A x B computation + // splitted across 64 lanes. // - // Top level loop that makes WGs persistent! - // - WGs iterates across columns of weight matrix - // - Each wave within WG works on a given column(s) - // - After completing first set of columns, WGs start - // working on the next set of availble columns + // YTILE represents how many column of weight matrix + // are being worked on by each wave. //---------------------------------------------------- - while (n < N) - { - //---------------------------------------------------- - // 'sum' accumulates the matrix A x B computation - // splitted across 64 lanes. - // - // YTILE represents how many column of weight matrix - // are being worked on by each wave. - //---------------------------------------------------- - for (int i = 0; i < YTILE; i++) - for (int m=0; m= 2) - bigType bigB1[UNRL]; + bigType bigB1[UNRL]; #endif #if (YTILE >= 3) - bigType bigB2[UNRL]; + bigType bigB2[UNRL]; #endif #if (YTILE >= 4) - bigType bigB3[UNRL]; + bigType bigB3[UNRL]; #endif #if (YTILE >= 5) - bigType bigB4[UNRL]; + bigType bigB4[UNRL]; #endif #if (YTILE >= 6) - bigType bigB5[UNRL]; + bigType bigB5[UNRL]; #endif #if (YTILE >= 7) - bigType bigB6[UNRL]; + bigType bigB6[UNRL]; #endif #if (YTILE >= 8) - bigType bigB7[UNRL]; + bigType bigB7[UNRL]; #endif #if (YTILE >= 9) - bigType bigB8[UNRL]; + bigType bigB8[UNRL]; #endif #if (YTILE >= 10) - bigType bigB9[UNRL]; + bigType bigB9[UNRL]; #endif #if (YTILE >= 11) - bigType bigB10[UNRL]; + bigType bigB10[UNRL]; #endif - //---------------------------------------------------- - // Fetch weight matrix B in interleaved K-split! 
- // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements (1024B) - // - YTILE represents the number of column being serviced - // by wave - // - Loop for fetching weight matrix (B) are unrolled - // - // Fetch activation matrix A from LDS - // - Loop for fetching activation matrix (A) are unrolled - // - // Finally, do the matrix multiplication in an unrolled - // fashion. This provides lot of food for compiler - // scheduling. - // - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- - for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) - { - // Fetch the weight matrix from memory! -#pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) - { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - //if (k_ >= K) break; - //bool skip = (k_ >= K); - //bool dummy = (k_ >= K); - - const half* B_ = &B[(n + 0) * K + k_]; - bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); - //---------------------------------------------------- - // The following code with YTILE > 1 has to be deleted - //---------------------------------------------------- -#if (YTILE >= 2) - //if (n+1>=N) continue; - bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); -#endif -#if (YTILE >= 3) - //if (n+2>=N) continue; - bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); -#endif -#if (YTILE >= 4) - //if (n+3>=N) continue; - bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); -#endif -#if (YTILE >= 5) - //if (n+4>=N) continue; - bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); -#endif -#if (YTILE >= 6) - //if (n+5>=N) continue; - bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); -#endif -#if (YTILE >= 7) - //if (n+6>=N) continue; - bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); -#endif -#if (YTILE >= 8) - //if (n+7>=N) continue; - bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); -#endif -/* -#if (YTILE >= 9) - if (n+8>=N) continue; bigB8[k2].h8 = (loadnt((half8*)(&B_[8 * K]))); -#endif -#if (YTILE >= 10) - if (n+9>=N) continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); -#endif -#if (YTILE >= 11) - if (n+10>=N) continue; bigB10[k2].h8 = (loadnt((half8*)(&B_[10 * K]))); -#endif -*/ - } - - // Fetch activation matrix from either just LDS or from both LDS / memory -#pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) - { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - // Fetch A activation matrix in interleaved fashion from LDS or memory - - for (int m=0; m < M; m++) - { - if (k_+K*m < 32*1024) - bigA[m][k2] = *((const bigType*)(&(s[k_+K*m]))); - else - bigA[m][k2] = *((const bigType*)(&(A[k_+K*m]))); - } - } - - // Do the matrix multiplication in interleaved manner -#pragma unroll - for (uint32_t m = 0; m < M; m++) - { -#pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) - { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - // Do the matrix multiplication of activation and weight matrix - // - Rememeber the accumulation is happening for K-split of 64! 
-#pragma unroll - for (uint32_t b = 0; b < A_CHUNK / 2; b++) - { - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][0]) : "0"(sum[m][0]), "v" (bigA[m][k2].f[b]), "v" (bigB0[k2].f[b])); - - //---------------------------------------------------- - // The following code with YTILE > 1 - //---------------------------------------------------- -#if (YTILE >= 2) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][1]) : "0"(sum[m][1]), "v" (bigA[m][k2].f[b]), "v" (bigB1[k2].f[b])); -#endif -#if (YTILE >= 3) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][2]) : "0"(sum[m][2]), "v" (bigA[m][k2].f[b]), "v" (bigB2[k2].f[b])); -#endif -#if (YTILE >= 4) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][3]) : "0"(sum[m][3]), "v" (bigA[m][k2].f[b]), "v" (bigB3[k2].f[b])); -#endif -#if (YTILE >= 5) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][4]) : "0"(sum[m][4]), "v" (bigA[m][k2].f[b]), "v" (bigB4[k2].f[b])); -#endif -#if (YTILE >= 6) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][5]) : "0"(sum[m][5]), "v" (bigA[m][k2].f[b]), "v" (bigB5[k2].f[b])); -#endif -#if (YTILE >= 7) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][6]) : "0"(sum[m][6]), "v" (bigA[m][k2].f[b]), "v" (bigB6[k2].f[b])); -#endif -#if (YTILE >= 8) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][7]) : "0"(sum[m][7]), "v" (bigA[m][k2].f[b]), "v" (bigB7[k2].f[b])); -#endif -#if (YTILE >= 9) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][8]) : "0"(sum[m][8]), "v" (bigA[m][k2].f[b]), "v" (bigB8[k2].f[b])); -#endif -#if (YTILE >= 10) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][9]) : "0"(sum[m][9]), "v" (bigA[m][k2].f[b]), "v" (bigB9[k2].f[b])); -#endif -#if (YTILE >= 11) - asm ("v_dot2c_f32_f16 %0, %2, %3" : "=v" (sum[m][10]) : "0"(sum[m][10]), "v" (bigA[m][k2].f[b]), "v" (bigB10[k2].f[b])); -#endif - } - } - } - } + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! 
+#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // if (k_ >= K) break; + // bool skip = (k_ >= K); + // bool dummy = (k_ >= K); + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); //---------------------------------------------------- - // Final reduction step using shuffle + // The following code with YTILE > 1 has to be deleted //---------------------------------------------------- - for (int m = 0; m < M; m++) - { - for (int y = 0; y < YTILE; y++) - { - //for (int offset = 64 / 2; offset > 4 ; offset /= 2) { - // sum[y] += __shfl_down(sum[y], offset); - //} - sum[m][y] += __shfl_down(sum[m][y], 32); - sum[m][y] += __shfl_down(sum[m][y], 16); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shl:1 bound_ctrl:0" : "=v" (sum[m][y]) : "0"(sum[m][y]), "v" (sum[m][y]), "v" (sum[m][y])); - - } - } - - if (threadIdx.x == 0) - { - for (int m = 0; m < M; m++) - { - for (int i = 0; i < YTILE; i++) - { - if (commitColumn[i]) - C[n + i + m * N] = __float2half(sum[m][i]); - } - } +#if (YTILE >= 2) + // if (n+1>=N) continue; + bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); +#endif +#if (YTILE >= 3) + // if (n+2>=N) continue; + bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); +#endif +#if (YTILE >= 4) + // if (n+3>=N) continue; + bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); +#endif +#if (YTILE >= 5) + // if (n+4>=N) continue; + bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); +#endif +#if (YTILE >= 6) + // if (n+5>=N) continue; + bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); +#endif +#if (YTILE >= 7) + // if (n+6>=N) continue; + bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); +#endif +#if (YTILE >= 8) + // if (n+7>=N) continue; + bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); +#endif + /* + #if (YTILE >= 9) + if (n+8>=N) continue; bigB8[k2].h8 = + (loadnt((half8*)(&B_[8 * K]))); #endif #if (YTILE >= 10) if (n+9>=N) + continue; bigB9[k2].h8 = (loadnt((half8*)(&B_[9 * K]))); #endif #if + (YTILE >= 11) if (n+10>=N) continue; bigB10[k2].h8 = + (loadnt((half8*)(&B_[10 * K]))); #endif + */ + } + + // Fetch activation matrix from either just LDS or from both LDS / memory +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m = 0; m < M; m++) { + if (k_ + K * m < 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + else + bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); } + } - n += CuCount * WvPrGrp * YTILE; - - //if (threadIdx.x == 0) - //n = atomicAdd(((unsigned int*)(C)), YTILE); - //n = __shfl(n, 0, 64); - - // Check whether there will be fragmenation! - // This will happen only for the last wave! 
- if(n < N && (n + YTILE) >= N) - { - uint32_t startColumn = N - YTILE; - for (uint32_t i = 0; i < (n - startColumn); i++) - { - commitColumn[i] = 0; - } - n = startColumn; + // Do the matrix multiplication in interleaved manner +#pragma unroll + for (uint32_t m = 0; m < M; m++) { +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Do the matrix multiplication of activation and weight matrix + // - Rememeber the accumulation is happening for K-split of 64! +#pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][0]) + : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- +#if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); +#endif +#if (YTILE >= 3) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][2]) + : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); +#endif +#if (YTILE >= 4) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][3]) + : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); +#endif +#if (YTILE >= 5) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][4]) + : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); +#endif +#if (YTILE >= 6) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][5]) + : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); +#endif +#if (YTILE >= 7) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][6]) + : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); +#endif +#if (YTILE >= 8) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][7]) + : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); +#endif +#if (YTILE >= 9) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][8]) + : "0"(sum[m][8]), "v"(bigA[m][k2].f[b]), "v"(bigB8[k2].f[b])); +#endif +#if (YTILE >= 10) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][9]) + : "0"(sum[m][9]), "v"(bigA[m][k2].f[b]), "v"(bigB9[k2].f[b])); +#endif +#if (YTILE >= 11) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][10]) + : "0"(sum[m][10]), "v"(bigA[m][k2].f[b]), "v"(bigB10[k2].f[b])); +#endif + } } + } } + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + // for (int offset = 64 / 2; offset > 4 ; offset /= 2) { + // sum[y] += __shfl_down(sum[y], offset); + // } + sum[m][y] += __shfl_down(sum[m][y], 32); + sum[m][y] += __shfl_down(sum[m][y], 16); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shl:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } -} + if (threadIdx.x == 0) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + n += 
CuCount * WvPrGrp * YTILE; + // if (threadIdx.x == 0) + // n = atomicAdd(((unsigned int*)(C)), YTILE); + // n = __shfl(n, 0, 64); -void wvSpltK_(void *in_a, void *in_b, void *out_c, const int M_in, const int K_in,const int N_in, cudaStream_t stream, const int CuCount = 0) { - dim3 grid(CuCount); - dim3 block(THRDS, WvPrGrp); - half* af4 = reinterpret_cast(in_a); - const half* bf4 = reinterpret_cast(in_b); - auto *c = reinterpret_cast(out_c); - switch(N_in) { - case 1: - wvSpltK_hf_m1_<<>>(K_in, M_in, af4, bf4, c, CuCount); - break; - case 2: - wvSpltK_hf_m2_<<>>(K_in, M_in, af4, bf4, c, CuCount); - break; - case 3: - wvSpltK_hf_m3_<<>>(K_in, M_in, af4, bf4, c, CuCount); - break; - case 4: - wvSpltK_hf_m4_<<>>(K_in, M_in, af4, bf4, c, CuCount); - break; - default: - throw std::runtime_error("Unsupported N value: " + std::to_string(M_in) + "," + std::to_string(K_in) + "," + std::to_string(N_in)); - } - - cudaError_t err = cudaGetLastError(); - if (cudaSuccess != err) { - throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; } + } } +void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, + const int K_in, const int N_in, cudaStream_t stream, + const int CuCount = 0) { + dim3 grid(CuCount); + dim3 block(THRDS, WvPrGrp); + half* af4 = reinterpret_cast(in_a); + const half* bf4 = reinterpret_cast(in_b); + auto* c = reinterpret_cast(out_c); + switch (N_in) { + case 1: + wvSpltK_hf_m1_<<>>(K_in, M_in, af4, bf4, c, + CuCount); + break; + case 2: + wvSpltK_hf_m2_<<>>(K_in, M_in, af4, bf4, c, + CuCount); + break; + case 3: + wvSpltK_hf_m3_<<>>(K_in, M_in, af4, bf4, c, + CuCount); + break; + case 4: + wvSpltK_hf_m4_<<>>(K_in, M_in, af4, bf4, c, + CuCount); + break; + default: + throw std::runtime_error("Unsupported N value: " + std::to_string(M_in) + + "," + std::to_string(K_in) + "," + + std::to_string(N_in)); + } - + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) { + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); + } +} diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index 6c78430adf4d0..5e1f4f35d3d87 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -70,10 +70,10 @@ def mm(self, inp, weights): k = inp_view.shape[1] soltype, solidx = self.query_sol(m=m, n=n, k=k) if soltype == 1: - print(">>> found hipblas") + #print(">>> found hipblas") out = hipb_mm(inp_view, weights.t(), solidx) elif soltype == 2: - print(">>> found rocblas") + #print(">>> found rocblas") out = rocb_mm(inp_view, weights.t(), solidx) else: if (self.save_gemm == 1): From 05c96a1c669d0bc2126c2a50c08b0c3b38f68d21 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 17 Jun 2024 21:40:18 +0000 Subject: [PATCH 03/12] more style adjustments --- csrc/custom/custom.cu | 12 +++++------- vllm/model_executor/layers/tuned_gemm.py | 6 ++++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/csrc/custom/custom.cu b/csrc/custom/custom.cu index 5c8beed37b304..9e92187967d47 100644 --- a/csrc/custom/custom.cu +++ b/csrc/custom/custom.cu @@ -39,20 +39,18 @@ void LLMM1(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, at::cuda::getCurrentCUDAStream(), rows_per_block); } -void wvSpltK_(void* in_a, void* in_b, 
void* out_c, const int M, const int K, const int N, - cudaStream_t stream, const int CuCount); +void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M, const int K, + const int N, cudaStream_t stream, const int CuCount); -void wvSpltK(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, - const int N_in, const int CuCount) { +void wvSpltK(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, const int N_in, + const int CuCount) { int M = in_a.size(0); int K = in_a.size(1); int N = N_in; wvSpltK_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, N, - at::cuda::getCurrentCUDAStream(), CuCount); + at::cuda::getCurrentCUDAStream(), CuCount); } - - void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, cudaStream_t stream, const int solidx); diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index 5e1f4f35d3d87..8b9b20722fb9e 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -23,7 +23,8 @@ def __init__(self): self.bestsols = {} self.load_best_sols() self.create_ds() - self.CuCount = torch.cuda.get_device_properties(device='cuda').multi_processor_count + self.CuCount = torch.cuda.get_device_properties( + device='cuda').multi_processor_count if (self.save_gemm == 1): self.tuned_df = pd.DataFrame(columns=['M', 'N', 'K']) @@ -89,7 +90,8 @@ def mm(self, inp, weights): ]).drop_duplicates() self.tuned_df.to_csv(self.untune_path, index=False) - if ((n == 4 or n == 3 or n== 2 or n == 1 ) and inp_view.dtype == torch.float16) : + if ((n == 4 or n == 3 or n== 2 or n == 1) + and inp_view.dtype == torch.float16): out = torch.empty(inp_view.shape[0], weights.shape[0], dtype=inp_view.dtype, From b54e8f91656ef53361cc95b103bbd53a226318f7 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 17 Jun 2024 21:52:55 +0000 Subject: [PATCH 04/12] more style adjustments (2) --- vllm/model_executor/layers/tuned_gemm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index 8b9b20722fb9e..eb3811ad5dbfc 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -91,11 +91,11 @@ def mm(self, inp, weights): self.tuned_df.to_csv(self.untune_path, index=False) if ((n == 4 or n == 3 or n== 2 or n == 1) - and inp_view.dtype == torch.float16): + and inp_view.dtype == torch.float16): out = torch.empty(inp_view.shape[0], - weights.shape[0], - dtype=inp_view.dtype, - device='cuda') + weights.shape[0], + dtype=inp_view.dtype, + device='cuda') _custom_C.wvSpltK(weights, inp_view, out, n, self.CuCount) elif n == 1 and inp_view.dtype == torch.float16: out = torch.empty(inp_view.shape[0], From 610a87b3a10e9bba20bd5945ad52a6d0ae32580d Mon Sep 17 00:00:00 2001 From: root Date: Mon, 17 Jun 2024 22:00:47 +0000 Subject: [PATCH 05/12] adds missing condition --- vllm/model_executor/layers/tuned_gemm.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index eb3811ad5dbfc..427a0336f4645 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -90,7 +90,8 @@ def mm(self, inp, weights): ]).drop_duplicates() self.tuned_df.to_csv(self.untune_path, index=False) - if ((n == 4 or n == 3 or n== 2 or n == 1) + if ((n == 4 or n == 3 or n == 2 or n == 1) + and k % 8 == 0 and inp_view.dtype == torch.float16): out = 
torch.empty(inp_view.shape[0], weights.shape[0], From 6979b3f2ead399ebf42b8abc4e526bc5e372df82 Mon Sep 17 00:00:00 2001 From: root Date: Mon, 17 Jun 2024 22:26:04 +0000 Subject: [PATCH 06/12] another sytle adjustment (3) --- vllm/model_executor/layers/tuned_gemm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/model_executor/layers/tuned_gemm.py b/vllm/model_executor/layers/tuned_gemm.py index 427a0336f4645..3ecacafd31977 100644 --- a/vllm/model_executor/layers/tuned_gemm.py +++ b/vllm/model_executor/layers/tuned_gemm.py @@ -90,8 +90,7 @@ def mm(self, inp, weights): ]).drop_duplicates() self.tuned_df.to_csv(self.untune_path, index=False) - if ((n == 4 or n == 3 or n == 2 or n == 1) - and k % 8 == 0 + if ((n == 4 or n == 3 or n == 2 or n == 1) and k % 8 == 0 and inp_view.dtype == torch.float16): out = torch.empty(inp_view.shape[0], weights.shape[0], From 8ba4e751e65d5768b038324ac46b61b5a17b691d Mon Sep 17 00:00:00 2001 From: root Date: Tue, 18 Jun 2024 03:37:41 +0000 Subject: [PATCH 07/12] Perf Optimization for smaller Gemms --- csrc/custom/custom_kernels.cu | 189 ++++++++++++++++++++++++++++------ 1 file changed, 160 insertions(+), 29 deletions(-) diff --git a/csrc/custom/custom_kernels.cu b/csrc/custom/custom_kernels.cu index 5c37701189098..8771908537279 100644 --- a/csrc/custom/custom_kernels.cu +++ b/csrc/custom/custom_kernels.cu @@ -331,6 +331,131 @@ __device__ __forceinline__ T loadnt(T* addr) { #define M 1 #define DTYPE half +__global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + __int128_t b128; + half8 h8; + }; + + __shared__ half s[1024 * 32]; + + uint32_t n = (blockIdx.x * WvPrGrp + threadIdx.y) * YTILE; + + for (uint32_t k = 0; k < min(K * M, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + ((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + + float sum[M][YTILE]; + + while (n < N) { + for (int i = 0; i < YTILE; i++) + for (int m = 0; m < M; m++) sum[m][i] = 0; + + bigType bigA[M][UNRL]; + bigType bigB0[UNRL]; +#if (YTILE >= 2) + bigType bigB1[UNRL]; +#endif + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); +#if (YTILE >= 2) + bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); +#endif + } + // Fetch activation matrix from either just LDS or from both LDS / memory +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + for (int m = 0; m < M; m++) { + bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + } + } + + // Do the matrix multiplication in interleaved manner +#pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; +#pragma unroll + for (uint32_t m = 0; m < M; m++) { + + // Do the matrix multiplication of activation and weight matrix + // - Rememeber the accumulation is happening for K-split 
of 64! +#pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][0]) + : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- +#if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); +#endif + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + sum[m][y] += __shfl_down(sum[m][y], 32); + sum[m][y] += __shfl_down(sum[m][y], 16); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shl:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } + + if (threadIdx.x == 0) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + + n += CuCount * WvPrGrp * YTILE; + } +} + + + __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { @@ -552,15 +677,15 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, // Do the matrix multiplication in interleaved manner #pragma unroll - for (uint32_t m = 0; m < M; m++) { + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; #pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; + for (uint32_t m = 0; m < M; m++) { - // Do the matrix multiplication of activation and weight matrix - // - Rememeber the accumulation is happening for K-split of 64! + // Do the matrix multiplication of activation and weight matrix + // - Rememeber the accumulation is happening for K-split of 64! #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { asm("v_dot2c_f32_f16 %0, %2, %3" @@ -905,15 +1030,15 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, // Do the matrix multiplication in interleaved manner #pragma unroll - for (uint32_t m = 0; m < M; m++) { + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; #pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; + for (uint32_t m = 0; m < M; m++) { - // Do the matrix multiplication of activation and weight matrix - // - Rememeber the accumulation is happening for K-split of 64! + // Do the matrix multiplication of activation and weight matrix + // - Rememeber the accumulation is happening for K-split of 64! 
#pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { asm("v_dot2c_f32_f16 %0, %2, %3" @@ -1258,15 +1383,15 @@ __global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B, // Do the matrix multiplication in interleaved manner #pragma unroll - for (uint32_t m = 0; m < M; m++) { + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; #pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; + for (uint32_t m = 0; m < M; m++) { - // Do the matrix multiplication of activation and weight matrix - // - Rememeber the accumulation is happening for K-split of 64! + // Do the matrix multiplication of activation and weight matrix + // - Rememeber the accumulation is happening for K-split of 64! #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { asm("v_dot2c_f32_f16 %0, %2, %3" @@ -1611,15 +1736,15 @@ __global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B, // Do the matrix multiplication in interleaved manner #pragma unroll - for (uint32_t m = 0; m < M; m++) { + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; #pragma unroll - for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; + for (uint32_t m = 0; m < M; m++) { - // Do the matrix multiplication of activation and weight matrix - // - Rememeber the accumulation is happening for K-split of 64! + // Do the matrix multiplication of activation and weight matrix + // - Rememeber the accumulation is happening for K-split of 64! #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { asm("v_dot2c_f32_f16 %0, %2, %3" @@ -1745,8 +1870,14 @@ void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, auto* c = reinterpret_cast(out_c); switch (N_in) { case 1: - wvSpltK_hf_m1_<<>>(K_in, M_in, af4, bf4, c, + if ((K_in <= 32*1024) && (M_in % 2 == 0)) { + wvSpltK_hf_m1_sml_<<>>(K_in, M_in, af4, bf4, + c, CuCount); + } + else { + wvSpltK_hf_m1_<<>>(K_in, M_in, af4, bf4, c, CuCount); + } break; case 2: wvSpltK_hf_m2_<<>>(K_in, M_in, af4, bf4, c, From 37fd8aeb27b40a268bc81182a7149f6b4b615b2d Mon Sep 17 00:00:00 2001 From: root Date: Tue, 18 Jun 2024 20:12:07 +0000 Subject: [PATCH 08/12] Spelling errors corrected --- csrc/custom/custom_kernels.cu | 28 +++++++++++++--------------- 1 file changed, 13 insertions(+), 15 deletions(-) diff --git a/csrc/custom/custom_kernels.cu b/csrc/custom/custom_kernels.cu index 8771908537279..c6329e929694a 100644 --- a/csrc/custom/custom_kernels.cu +++ b/csrc/custom/custom_kernels.cu @@ -3,8 +3,6 @@ #include #include #include -#include "hsa/hsa.h" -#include "hsa/hsa_ext_amd.h" constexpr int WARP_SIZE = 64; @@ -400,7 +398,7 @@ __global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B, for (uint32_t m = 0; m < M; m++) { // Do the matrix multiplication of activation and weight matrix - // - Rememeber the accumulation is happening for K-split of 64! + // - Remember the accumulation is happening for K-split of 64! 
#pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { asm("v_dot2c_f32_f16 %0, %2, %3" @@ -541,12 +539,12 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, // - WGs iterates across columns of weight matrix // - Each wave within WG works on a given column(s) // - After completing first set of columns, WGs start - // working on the next set of availble columns + // working on the next set of available columns //---------------------------------------------------- while (n < N) { //---------------------------------------------------- // 'sum' accumulates the matrix A x B computation - // splitted across 64 lanes. + // split across 64 lanes. // // YTILE represents how many column of weight matrix // are being worked on by each wave. @@ -685,7 +683,7 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, for (uint32_t m = 0; m < M; m++) { // Do the matrix multiplication of activation and weight matrix - // - Rememeber the accumulation is happening for K-split of 64! + // - Remember the accumulation is happening for K-split of 64! #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { asm("v_dot2c_f32_f16 %0, %2, %3" @@ -894,12 +892,12 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, // - WGs iterates across columns of weight matrix // - Each wave within WG works on a given column(s) // - After completing first set of columns, WGs start - // working on the next set of availble columns + // working on the next set of available columns //---------------------------------------------------- while (n < N) { //---------------------------------------------------- // 'sum' accumulates the matrix A x B computation - // splitted across 64 lanes. + // split across 64 lanes. // // YTILE represents how many column of weight matrix // are being worked on by each wave. @@ -1038,7 +1036,7 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, for (uint32_t m = 0; m < M; m++) { // Do the matrix multiplication of activation and weight matrix - // - Rememeber the accumulation is happening for K-split of 64! + // - Remember the accumulation is happening for K-split of 64! #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { asm("v_dot2c_f32_f16 %0, %2, %3" @@ -1247,12 +1245,12 @@ __global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B, // - WGs iterates across columns of weight matrix // - Each wave within WG works on a given column(s) // - After completing first set of columns, WGs start - // working on the next set of availble columns + // working on the next set of available columns //---------------------------------------------------- while (n < N) { //---------------------------------------------------- // 'sum' accumulates the matrix A x B computation - // splitted across 64 lanes. + // split across 64 lanes. // // YTILE represents how many column of weight matrix // are being worked on by each wave. @@ -1391,7 +1389,7 @@ __global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B, for (uint32_t m = 0; m < M; m++) { // Do the matrix multiplication of activation and weight matrix - // - Rememeber the accumulation is happening for K-split of 64! + // - Remember the accumulation is happening for K-split of 64! 
#pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { asm("v_dot2c_f32_f16 %0, %2, %3" @@ -1600,12 +1598,12 @@ __global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B, // - WGs iterates across columns of weight matrix // - Each wave within WG works on a given column(s) // - After completing first set of columns, WGs start - // working on the next set of availble columns + // working on the next set of available columns //---------------------------------------------------- while (n < N) { //---------------------------------------------------- // 'sum' accumulates the matrix A x B computation - // splitted across 64 lanes. + // split across 64 lanes. // // YTILE represents how many column of weight matrix // are being worked on by each wave. @@ -1744,7 +1742,7 @@ __global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B, for (uint32_t m = 0; m < M; m++) { // Do the matrix multiplication of activation and weight matrix - // - Rememeber the accumulation is happening for K-split of 64! + // - Remember the accumulation is happening for K-split of 64! #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { asm("v_dot2c_f32_f16 %0, %2, %3" From 63c7d7ef7c3a97cd4c78799072a3f6c9abb17561 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 18 Jun 2024 21:42:55 +0000 Subject: [PATCH 09/12] remove bf16 include, not used here --- csrc/custom/custom_kernels.cu | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/custom/custom_kernels.cu b/csrc/custom/custom_kernels.cu index c6329e929694a..8154431234378 100644 --- a/csrc/custom/custom_kernels.cu +++ b/csrc/custom/custom_kernels.cu @@ -2,7 +2,6 @@ #include #include #include -#include constexpr int WARP_SIZE = 64; From 3de25e936b8bbcd935a3fba76af4f59fc23060f0 Mon Sep 17 00:00:00 2001 From: root Date: Tue, 18 Jun 2024 22:02:19 +0000 Subject: [PATCH 10/12] line adjustments for lint error --- csrc/custom/custom_kernels.cu | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/csrc/custom/custom_kernels.cu b/csrc/custom/custom_kernels.cu index 8154431234378..c51ae332af8e0 100644 --- a/csrc/custom/custom_kernels.cu +++ b/csrc/custom/custom_kernels.cu @@ -329,8 +329,8 @@ __device__ __forceinline__ T loadnt(T* addr) { #define DTYPE half __global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B, - const DTYPE* __restrict__ A, DTYPE* C, - const int CuCount) { + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { union bigType { DTYPE h[A_CHUNK]; float f[A_CHUNK / 2]; @@ -395,9 +395,8 @@ __global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B, if (k_ >= K) break; #pragma unroll for (uint32_t m = 0; m < M; m++) { - - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { asm("v_dot2c_f32_f16 %0, %2, %3" @@ -680,7 +679,6 @@ __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, if (k_ >= K) break; #pragma unroll for (uint32_t m = 0; m < M; m++) { - // Do the matrix multiplication of activation and weight matrix // - Remember the accumulation is happening for K-split of 64! 
#pragma unroll @@ -1033,7 +1031,6 @@ __global__ void wvSpltK_hf_m2_(const int K, const int N, const DTYPE* B, if (k_ >= K) break; #pragma unroll for (uint32_t m = 0; m < M; m++) { - // Do the matrix multiplication of activation and weight matrix // - Remember the accumulation is happening for K-split of 64! #pragma unroll @@ -1386,7 +1383,6 @@ __global__ void wvSpltK_hf_m3_(const int K, const int N, const DTYPE* B, if (k_ >= K) break; #pragma unroll for (uint32_t m = 0; m < M; m++) { - // Do the matrix multiplication of activation and weight matrix // - Remember the accumulation is happening for K-split of 64! #pragma unroll @@ -1739,7 +1735,6 @@ __global__ void wvSpltK_hf_m4_(const int K, const int N, const DTYPE* B, if (k_ >= K) break; #pragma unroll for (uint32_t m = 0; m < M; m++) { - // Do the matrix multiplication of activation and weight matrix // - Remember the accumulation is happening for K-split of 64! #pragma unroll From 674807c3080bc2b9eea44dac128b86c3f70f644d Mon Sep 17 00:00:00 2001 From: root Date: Tue, 18 Jun 2024 22:11:20 +0000 Subject: [PATCH 11/12] more lint line adjustments --- csrc/custom/custom_kernels.cu | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/csrc/custom/custom_kernels.cu b/csrc/custom/custom_kernels.cu index c51ae332af8e0..5155a6150a19e 100644 --- a/csrc/custom/custom_kernels.cu +++ b/csrc/custom/custom_kernels.cu @@ -1862,13 +1862,12 @@ void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, auto* c = reinterpret_cast(out_c); switch (N_in) { case 1: - if ((K_in <= 32*1024) && (M_in % 2 == 0)) { - wvSpltK_hf_m1_sml_<<>>(K_in, M_in, af4, bf4, - c, CuCount); - } - else { - wvSpltK_hf_m1_<<>>(K_in, M_in, af4, bf4, c, - CuCount); + if ((K_in <= 32 * 1024) && (M_in % 2 == 0)) { + wvSpltK_hf_m1_sml_<<>>(K_in, M_in, af4, bf4, + c, CuCount); + } else { + wvSpltK_hf_m1_<<>>(K_in, M_in, af4, bf4, c, + CuCount); } break; case 2: From 7af9c414749ad1c0af562550cfbac2d0840580da Mon Sep 17 00:00:00 2001 From: root Date: Tue, 18 Jun 2024 22:45:22 +0000 Subject: [PATCH 12/12] ialignment fixes --- csrc/custom/custom_kernels.cu | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/csrc/custom/custom_kernels.cu b/csrc/custom/custom_kernels.cu index 5155a6150a19e..2c4698533332e 100644 --- a/csrc/custom/custom_kernels.cu +++ b/csrc/custom/custom_kernels.cu @@ -383,7 +383,7 @@ __global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B, // Fetch A activation matrix in interleaved fashion from LDS or memory for (int m = 0; m < M; m++) { - bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); } } @@ -450,8 +450,6 @@ __global__ void wvSpltK_hf_m1_sml_(const int K, const int N, const DTYPE* B, } } - - __global__ void wvSpltK_hf_m1_(const int K, const int N, const DTYPE* B, const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { @@ -1863,8 +1861,8 @@ void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, switch (N_in) { case 1: if ((K_in <= 32 * 1024) && (M_in % 2 == 0)) { - wvSpltK_hf_m1_sml_<<>>(K_in, M_in, af4, bf4, - c, CuCount); + wvSpltK_hf_m1_sml_<<>>(K_in, M_in, af4, bf4, c, + CuCount); } else { wvSpltK_hf_m1_<<>>(K_in, M_in, af4, bf4, c, CuCount);