Commit
adds wvSpltK optimization for skinny gemm. (#54)
* Adds wvSpltK optimization for skinny GEMM.


---------

Co-authored-by: Hashem Hashemi <[email protected]>
amd-hhashemi and Hashem Hashemi authored Jun 18, 2024
1 parent 2e23c13 commit 131b217
Showing 3 changed files with 1,605 additions and 2 deletions.
13 changes: 13 additions & 0 deletions csrc/custom/custom.cu
@@ -39,6 +39,18 @@ void LLMM1(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c,
at::cuda::getCurrentCUDAStream(), rows_per_block);
}

// Kernel launcher for the wvSpltK skinny-GEMM path; the kernel itself is
// defined in the kernel source added elsewhere in this commit.
void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M, const int K,
              const int N, cudaStream_t stream, const int CuCount);

// PyTorch-facing wrapper: M and K are taken from the shape of in_a, N is
// passed in explicitly, and the launch goes to the current CUDA stream
// along with the device's compute-unit count.
void wvSpltK(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, const int N_in,
             const int CuCount) {
  int M = in_a.size(0);
  int K = in_a.size(1);
  int N = N_in;
  wvSpltK_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, N,
           at::cuda::getCurrentCUDAStream(), CuCount);
}

void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K,
cudaStream_t stream, const int solidx);

@@ -90,5 +102,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("LLZZ", &LLZZ);
m.def("paged_attention_custom", &paged_attention_custom,
"PagedAttention LL4Mi Custom.");
m.def("wvSpltK", &wvSpltK);
// m.def("MMCustomGPU", &MMCustomGPU);
}
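
For context, here is a minimal usage sketch of the new binding from Python; it is an illustration, not code from this commit. The module name `custom` is an assumption (the real name comes from TORCH_EXTENSION_NAME at build time), and the layouts and dtypes of in_b and out_c are assumptions as well; only in_a being M x K follows from the wrapper above. The name wvSpltK suggests a wave-level split along the K (reduction) dimension, which would explain why the launcher takes the device's compute-unit count.

import torch
import custom  # module name set by TORCH_EXTENSION_NAME at build time (assumption)

M, K, N = 8192, 4096, 1  # "skinny" GEMM: N is very small (e.g. a decode batch of 1)

a = torch.randn(M, K, dtype=torch.float16, device="cuda")  # M x K, per the wrapper
b = torch.randn(N, K, dtype=torch.float16, device="cuda")  # assumed layout for in_b
c = torch.empty(N, M, dtype=torch.float16, device="cuda")  # assumed output layout

# CuCount is the number of compute units on the device; PyTorch reports it
# as multi_processor_count.
cu_count = torch.cuda.get_device_properties(a.device).multi_processor_count

custom.wvSpltK(a, b, c, N, cu_count)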