Commit d553bf3
move the cuda kernel into this repo from https://github.com/qwopqwop2…
winglian committed Apr 17, 2023
1 parent fe6d135 commit d553bf3
Showing 3 changed files with 599 additions and 1 deletion.
setup.py (21 changes: 20 additions & 1 deletion)
@@ -1,4 +1,6 @@
-from setuptools import setup, find_packages
+import sys
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 
 install_requires = []
 with open("./requirements.txt", "r") as requirements_file:
@@ -7,6 +9,21 @@
     for r in reqs:
         install_requires.append(r)
 
+quant_cuda_module = CUDAExtension(
+    'alpaca_lora_4bit.quant_cuda',
+    sources=[
+        'src/alpaca_lora_4bit/quant_cuda/quant_cuda.cpp',
+        'src/alpaca_lora_4bit/quant_cuda/quant_cuda_kernel.cu'
+    ])
+
+# conditionally only install the cuda extension explicitly
+ext_modules = []
+cmdclass = {}
+if '--cuda' in sys.argv or any(["cuda" in arg for arg in sys.argv]):
+    ext_modules.append(quant_cuda_module)
+    cmdclass = {'build_ext': BuildExtension}
+    sys.argv.remove('--cuda')
+
 
 setup(
     name='alpaca_lora_4bit',
@@ -19,4 +36,6 @@
         'cuda': 'gptq_llama @ git+https://github.com/sterlind/GPTQ-for-LLaMa.git@lora_4bit',
         'triton': 'triton',
     },
+    ext_modules=ext_modules,
+    cmdclass=cmdclass,
 )
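
With this change the CUDA extension is compiled only when the build is invoked with a cuda argument (for example, python setup.py install --cuda; the flag is stripped from sys.argv before setup() runs). Below is a minimal sketch of how downstream code might probe for the optional extension; the import guard and the BACKEND name are illustrative assumptions, not part of this commit.

# Hypothetical consumer-side guard (an assumption, not in this commit):
# alpaca_lora_4bit.quant_cuda exists only when the package was built with --cuda.
try:
    from alpaca_lora_4bit import quant_cuda  # compiled CUDAExtension from above
    BACKEND = "cuda"
except ImportError:
    quant_cuda = None  # fall back to the 'triton' extra or a pure-PyTorch path
    BACKEND = "fallback"

print(f"using {BACKEND} backend")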
src/alpaca_lora_4bit/quant_cuda/quant_cuda.cpp (new file, 70 additions & 0 deletions)
@@ -0,0 +1,70 @@
+#include <torch/all.h>
+#include <torch/python.h>
+#include <c10/cuda/CUDAGuard.h>
+
+void vecquant2matmul_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  torch::Tensor g_idx
+);
+
+void vecquant2matmul(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  torch::Tensor g_idx
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant2matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
+}
+
+void vecquant3matmul_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  torch::Tensor g_idx
+);
+
+void vecquant3matmul(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  torch::Tensor g_idx
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant3matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
+}
+
+void vecquant4matmul_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  torch::Tensor g_idx
+);
+
+void vecquant4matmul(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  torch::Tensor g_idx
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant4matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
+}
+
+void vecquant8matmul_cuda(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  torch::Tensor g_idx
+);
+
+void vecquant8matmul(
+  torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
+  torch::Tensor scales, torch::Tensor zeros,
+  torch::Tensor g_idx
+) {
+  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
+  vecquant8matmul_cuda(vec, mat, mul, scales, zeros, g_idx);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("vecquant2matmul", &vecquant2matmul, "Vector 2-bit Quantized Matrix Multiplication (CUDA)");
+  m.def("vecquant3matmul", &vecquant3matmul, "Vector 3-bit Quantized Matrix Multiplication (CUDA)");
+  m.def("vecquant4matmul", &vecquant4matmul, "Vector 4-bit Quantized Matrix Multiplication (CUDA)");
+  m.def("vecquant8matmul", &vecquant8matmul, "Vector 8-bit Quantized Matrix Multiplication (CUDA)");
+}
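
Each wrapper above pins the active CUDA device to the device of vec via OptionalCUDAGuard before forwarding to its kernel, so callers need not set the device themselves. Below is a hedged sketch of driving the 4-bit entry point from Python. The shapes and packing follow the common GPTQ convention (eight 4-bit values per int32 word, per-group scales and packed zero points, g_idx mapping each input row to a group); they are assumptions about this kernel, not spelled out by the commit.

import torch
from alpaca_lora_4bit import quant_cuda  # requires a build with the --cuda flag

# Illustrative GPTQ-style shapes (assumed, not documented by this commit).
batch, infeatures, outfeatures, groups = 1, 4096, 4096, 32

vec = torch.randn(batch, infeatures, device="cuda", dtype=torch.float32)
mat = torch.randint(torch.iinfo(torch.int32).max, (infeatures // 8, outfeatures),
                    device="cuda", dtype=torch.int32)    # packed 4-bit weights
scales = torch.ones(groups, outfeatures, device="cuda", dtype=torch.float32)
zeros = torch.randint(torch.iinfo(torch.int32).max, (groups, outfeatures // 8),
                      device="cuda", dtype=torch.int32)  # packed zero points
g_idx = (torch.arange(infeatures, device="cuda") // (infeatures // groups)).to(torch.int32)
mul = torch.zeros(batch, outfeatures, device="cuda", dtype=torch.float32)  # output buffer

# mul is accumulated in place: roughly vec @ dequantize(mat, scales, zeros, g_idx)
quant_cuda.vecquant4matmul(vec, mat, mul, scales, zeros, g_idx)
print(mul.shape)  # torch.Size([1, 4096])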