Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adds wvSpltK optimization for skinny gemm. #54

Merged
merged 12 commits into from
Jun 18, 2024
13 changes: 13 additions & 0 deletions csrc/custom/custom.cu
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,18 @@ void LLMM1(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c,
at::cuda::getCurrentCUDAStream(), rows_per_block);
}

void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M, const int K,
const int N, cudaStream_t stream, const int CuCount);

void wvSpltK(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, const int N_in,
const int CuCount) {
int M = in_a.size(0);
int K = in_a.size(1);
int N = N_in;
wvSpltK_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, N,
at::cuda::getCurrentCUDAStream(), CuCount);
}

void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K,
cudaStream_t stream, const int solidx);

Expand Down Expand Up @@ -90,5 +102,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("LLZZ", &LLZZ);
m.def("paged_attention_custom", &paged_attention_custom,
"PagedAttention LL4Mi Custom.");
m.def("wvSpltK", &wvSpltK);
// m.def("MMCustomGPU", &MMCustomGPU);
}
Loading
Loading