diff --git a/.gitmodules b/.gitmodules deleted file mode 100644 index 3a4d360..0000000 --- a/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "third-party/quip"] - path = third-party/quip - url = https://github.com/oseyosey/quip-sharp.git diff --git a/llmtools/engine/inference/cuda/quant_cuda.cpp b/kernels/quant_cuda/quant_cuda.cpp similarity index 100% rename from llmtools/engine/inference/cuda/quant_cuda.cpp rename to kernels/quant_cuda/quant_cuda.cpp diff --git a/llmtools/engine/inference/cuda/quant_cuda_kernel.cu b/kernels/quant_cuda/quant_cuda_kernel.cu similarity index 100% rename from llmtools/engine/inference/cuda/quant_cuda_kernel.cu rename to kernels/quant_cuda/quant_cuda_kernel.cu diff --git a/kernels/quiptools_cuda/quiptools.cu b/kernels/quiptools_cuda/quiptools.cu new file mode 100644 index 0000000..f186087 --- /dev/null +++ b/kernels/quiptools_cuda/quiptools.cu @@ -0,0 +1,675 @@ +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +using namespace torch::indexing; +using namespace nvcuda; + +#define FULL_MASK 0xffffffff +#define HALF_MASK 0x0000ffff + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while(false) +#define gpuErrchk(ans) do { gpuAssert((ans), __FILE__, __LINE__); } while (false) + + +__host__ static inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true) +{ + if (code != cudaSuccess) + { + fprintf(stderr, "GPUassert[%s:%d]: %s\n", file, line, cudaGetErrorString(code)); + if (abort) exit(code); + } +} + + + +__global__ void cuda_lookupmatmul_d4_k8_kernel( + const c10::Half* __restrict__ X, // k x n + const uint8_t* __restrict__ YIs, // m x (n/4) + const c10::Half* __restrict__ CB, // 256 x 4 + c10::Half* __restrict__ Z, // k x m + size_t K, + size_t M, + size_t N) { + + long m1 = blockIdx.x; + long k1 = blockIdx.y; + + __shared__ c10::Half Y_cache[32*16]; + + wmma::fragment a; // 8 x 16 + wmma::fragment b; // 32 x 16 + wmma::fragment c; // 8 x 32 + fill_fragment(c, __float2half(0.0)); + + for (long jn = 0; jn < N / 16; jn++) { +# pragma unroll 4 + for (long r = 0; r < 4; r++) { + uint8_t yidxs = *(uint8_t*)(YIs + jn*(4*M) + m1*4*32 + threadIdx.x*4 + r); + ((uint64_t*)Y_cache)[threadIdx.x*4 + r] = ((uint64_t*)CB)[(yidxs & 255)]; + } + load_matrix_sync(a, (const __half*)(X + 8*N*k1 + 16*jn), N); + load_matrix_sync(b, (const __half*)Y_cache, 16); + mma_sync(c, a, b, c); + } + + store_matrix_sync((__half*)(&Z[8*M*k1 + 32*m1]), c, M, wmma::mem_row_major); +} + + +void lookupmatmul_d4_k8( + torch::Tensor X, // k x n + torch::Tensor YIs, // m x (n/4) + torch::Tensor CB, // 256 x 4 + torch::Tensor Z // k x m +) { + auto k = X.sizes()[0]; + auto m = YIs.sizes()[0]; + auto n = X.sizes()[1]; + + assert(X.dtype() == torch::kFloat16); + assert(YIs.dtype() == torch::kUInt8); + assert(CB.dtype() == torch::kFloat16); + assert(Z.dtype() == torch::kFloat16); + + assert(Z.sizes()[0] == k); + assert(YIs.sizes()[1] * 4 == n); + assert(Z.sizes()[1] == m); + + assert(k % 8 == 0); // if you want larger k, use k = 16 + assert(m % 32 == 0); + assert(n % 16 == 0); + + const dim3 threads(32); + const dim3 blocks(m/32,k/8); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + cuda_lookupmatmul_d4_k8_kernel<<>>( + X.data_ptr(), + 
YIs.data_ptr(), + CB.data_ptr(), + Z.data_ptr(), + k,m,n + ); +} + + + +__global__ void cuda_lookupmatmul_d4_k16_kernel( + const c10::Half* __restrict__ X, // k x n + const uint8_t* __restrict__ YIs, // m x (n/4) + const c10::Half* __restrict__ CB, // 256 x 4 + c10::Half* __restrict__ Z, // k x m + size_t K, + size_t M, + size_t N) { + + long m1 = blockIdx.x; + long k1 = blockIdx.y; + + __shared__ c10::Half Y_cache[32*16]; + + wmma::fragment a; + wmma::fragment b; + wmma::fragment c0; + fill_fragment(c0, __float2half(0.0)); + + wmma::fragment c1; + fill_fragment(c1, __float2half(0.0)); + + for (long jn = 0; jn < N / 16; jn++) { + for (long r = 0; r < 4; r++) { + uint8_t yidxs = *(uint8_t*)(YIs + jn*(4*M) + m1*4*32 + threadIdx.x*4 + r); + ((uint64_t*)Y_cache)[threadIdx.x*4 + r] = ((uint64_t*)CB)[(yidxs & 255)]; + } + + load_matrix_sync(a, (const __half*)(X + 16*N*k1 + 16*jn), N); + + load_matrix_sync(b, (const __half*)Y_cache, 16); + mma_sync(c0, a, b, c0); + + load_matrix_sync(b, (const __half*)Y_cache + 16*16, 16); + mma_sync(c1, a, b, c1); + } + + store_matrix_sync((__half*)(&Z[16*M*k1 + 32*m1 + 0]), c0, M, wmma::mem_row_major); + store_matrix_sync((__half*)(&Z[16*M*k1 + 32*m1 + 16]), c1, M, wmma::mem_row_major); +} + + +void lookupmatmul_d4_k16( + torch::Tensor X, // k x n + torch::Tensor YIs, // m x (n/4) + torch::Tensor CB, // 256 x 4 + torch::Tensor Z // k x m +) { + auto k = X.sizes()[0]; + auto m = YIs.sizes()[0]; + auto n = X.sizes()[1]; + + assert(X.dtype() == torch::kFloat16); + assert(YIs.dtype() == torch::kUInt8); + assert(CB.dtype() == torch::kFloat16); + assert(Z.dtype() == torch::kFloat16); + + assert(Z.sizes()[0] == k); + assert(YIs.sizes()[1] * 4 == n); + assert(Z.sizes()[1] == m); + + assert(k % 16 == 0); + assert(m % 32 == 0); + assert(n % 16 == 0); + + const dim3 threads(32); + const dim3 blocks(m/32,k/16); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + cuda_lookupmatmul_d4_k16_kernel<<>>( + X.data_ptr(), + YIs.data_ptr(), + CB.data_ptr(), + Z.data_ptr(), + k,m,n + ); +} + + +__global__ void cuda_lookupmatmul_d4_k32_kernel( + const c10::Half* __restrict__ X, // k x n + const uint8_t* __restrict__ YIs, // m x (n/4) + const c10::Half* __restrict__ CB, // 256 x 4 + c10::Half* __restrict__ Z, // k x m + size_t K, + size_t M, + size_t N) { + + long m1 = blockIdx.x; + long k1 = blockIdx.y; + + __shared__ c10::Half Y_cache[32*16]; + + wmma::fragment a; + wmma::fragment b; + wmma::fragment c0; + fill_fragment(c0, __float2half(0.0)); + + wmma::fragment c1; + fill_fragment(c1, __float2half(0.0)); + + wmma::fragment c2; + fill_fragment(c2, __float2half(0.0)); + + wmma::fragment c3; + fill_fragment(c3, __float2half(0.0)); + + for (long jn = 0; jn < N / 16; jn++) { + for (long r = 0; r < 4; r++) { + uint8_t yidxs = *(uint8_t*)(YIs + jn*(4*M) + m1*4*32 + threadIdx.x*4 + r); + ((uint64_t*)Y_cache)[threadIdx.x*4 + r] = ((uint64_t*)CB)[(yidxs & 255)]; + } + + load_matrix_sync(a, (const __half*)(X + 16*N*(2*k1+0) + 16*jn), N); + + load_matrix_sync(b, (const __half*)Y_cache, 16); + mma_sync(c0, a, b, c0); + + load_matrix_sync(b, (const __half*)Y_cache + 16*16, 16); + mma_sync(c1, a, b, c1); + + load_matrix_sync(a, (const __half*)(X + 16*N*(2*k1+1) + 16*jn), N); + mma_sync(c3, a, b, c3); + + load_matrix_sync(b, (const __half*)Y_cache, 16); + mma_sync(c2, a, b, c2); + } + + store_matrix_sync((__half*)(&Z[16*M*(2*k1+0) + 32*m1 + 0]), c0, M, wmma::mem_row_major); + store_matrix_sync((__half*)(&Z[16*M*(2*k1+0) + 32*m1 + 16]), c1, M, wmma::mem_row_major); + 
store_matrix_sync((__half*)(&Z[16*M*(2*k1+1) + 32*m1 + 0]), c2, M, wmma::mem_row_major); + store_matrix_sync((__half*)(&Z[16*M*(2*k1+1) + 32*m1 + 16]), c3, M, wmma::mem_row_major); +} + + +void lookupmatmul_d4_k32( + torch::Tensor X, // k x n + torch::Tensor YIs, // m x (n/4) + torch::Tensor CB, // 256 x 4 + torch::Tensor Z // k x m +) { + auto k = X.sizes()[0]; + auto m = YIs.sizes()[0]; + auto n = X.sizes()[1]; + + assert(X.dtype() == torch::kFloat16); + assert(YIs.dtype() == torch::kUInt8); + assert(CB.dtype() == torch::kFloat16); + assert(Z.dtype() == torch::kFloat16); + + assert(Z.sizes()[0] == k); + assert(YIs.sizes()[1] * 4 == n); + assert(Z.sizes()[1] == m); + + assert(k % 16 == 0); + assert(m % 32 == 0); + assert(n % 16 == 0); + + const dim3 threads(32); + const dim3 blocks(m/32,k/32); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + cuda_lookupmatmul_d4_k32_kernel<<>>( + X.data_ptr(), + YIs.data_ptr(), + CB.data_ptr(), + Z.data_ptr(), + k,m,n + ); +} + +#define DECOMPRESS_D4_BLOCK_SIZE 256 + +__global__ void cuda_decompress_d4_origorder_kernel( + const uint8_t* __restrict__ YIs, // m x (n/4) + const c10::Half* __restrict__ CB, // 256 x 4 + c10::Half* __restrict__ Y // m x n +) { + const long i = threadIdx.x + DECOMPRESS_D4_BLOCK_SIZE * blockIdx.x; + + for(long r = 0; r < 4; r++) { + uint8_t yidx = ((uint8_t*)YIs)[i*4 + r]; + ((uint64_t*)Y)[i*4 + r] = ((uint64_t*)CB)[yidx & 255]; + } +} + + +void decompress_d4_origorder( + torch::Tensor YIs, // m x (n/4) + torch::Tensor CB, // 256 x 4 + torch::Tensor Y // m x n +) { + size_t m = Y.sizes()[0]; + size_t n = Y.sizes()[1]; + + assert(YIs.is_contiguous()); + assert(CB.is_contiguous()); + assert(Y.is_contiguous()); + + assert(YIs.sizes()[0] == m); + assert(YIs.sizes()[1] * 4 == n); + assert(CB.sizes()[0] == 256); + assert(CB.sizes()[1] == 4); + + const dim3 threads(DECOMPRESS_D4_BLOCK_SIZE); + const dim3 blocks(m*n/(16*DECOMPRESS_D4_BLOCK_SIZE)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + cuda_decompress_d4_origorder_kernel<<>>( + YIs.data_ptr(), + CB.data_ptr(), + Y.data_ptr() + ); +} + + +__global__ void cuda_decompress_d4_kernel( + const uint8_t* __restrict__ YIs, // m x (n/4) + const c10::Half* __restrict__ CB, // 256 x 4 + c10::Half* __restrict__ Y, // m x n + size_t M, + size_t N +) { + const long i = threadIdx.x + DECOMPRESS_D4_BLOCK_SIZE * blockIdx.x; + + const long j = (i % (N/16))*M + (i / (N/16)); + + for(long r = 0; r < 4; r++) { + uint8_t yidx = ((uint8_t*)YIs)[j*4 + r]; + ((uint64_t*)Y)[i*4 + r] = ((uint64_t*)CB)[yidx & 255]; + } +} + + +void decompress_d4( + torch::Tensor YIs, // m x (n/4) + torch::Tensor CB, // 256 x 4 + torch::Tensor Y // m x n +) { + size_t m = Y.sizes()[0]; + size_t n = Y.sizes()[1]; + + assert(YIs.is_contiguous()); + assert(CB.is_contiguous()); + assert(Y.is_contiguous()); + + assert(YIs.sizes()[0] == m); + assert(YIs.sizes()[1] * 4 == n); + assert(CB.sizes()[0] == 256); + assert(CB.sizes()[1] == 4); + + const dim3 threads(DECOMPRESS_D4_BLOCK_SIZE); + const dim3 blocks(m*n/(16*DECOMPRESS_D4_BLOCK_SIZE)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + cuda_decompress_d4_kernel<<>>( + YIs.data_ptr(), + CB.data_ptr(), + Y.data_ptr(), + m,n + ); +} + + + + +// This is a terrible kernel, only use this to not call the pytorch version + +#define DECOMPRESS_HI4B1C_BLOCK_SIZE 128 + +__global__ void cuda_decompress_hi4b1c_packed_kernel( + const int32_t* __restrict__ YIs, // m x (n/8) + const c10::Half* __restrict__ CB, // 16 x 1 + c10::Half* 
__restrict__ Y // m x n +) { + const long i = threadIdx.x + DECOMPRESS_HI4B1C_BLOCK_SIZE * blockIdx.x; + + // 0 2 4 6 1 3 5 7 + uint32_t packed = YIs[i]; + Y[i*8 + 7] = CB[packed & 15]; + Y[i*8 + 5] = CB[(packed >> 4) & 15]; + Y[i*8 + 3] = CB[(packed >> 8) & 15]; + Y[i*8 + 1] = CB[(packed >> 12) & 15]; + Y[i*8 + 6] = CB[(packed >> 16) & 15]; + Y[i*8 + 4] = CB[(packed >> 20) & 15]; + Y[i*8 + 2] = CB[(packed >> 24) & 15]; + Y[i*8 + 0] = CB[(packed >> 28) & 15]; +} + + +void decompress_hi4b1c_packed( + torch::Tensor YIs, // m x (n/8) + torch::Tensor CB, + torch::Tensor &Y // m x n +) { + size_t m = Y.sizes()[0]; + size_t n = Y.sizes()[1]; + + assert(YIs.is_contiguous()); + assert(Y.is_contiguous()); + + assert(YIs.sizes()[0] == m); + assert(YIs.sizes()[1] * 8 == n); + + assert(CB.sizes()[0] == 16); + assert(CB.sizes()[1] == 1); + + + const dim3 threads(DECOMPRESS_HI4B1C_BLOCK_SIZE); + const dim3 blocks(m*n/(8*DECOMPRESS_HI4B1C_BLOCK_SIZE)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + cuda_decompress_hi4b1c_packed_kernel<<>>( + YIs.data_ptr(), + CB.data_ptr(), + Y.data_ptr() + ); +} + + +// This is a terrible kernel, only use this to not call the pytorch version + +#define DECOMPRESS_HI3B1C_BLOCK_SIZE 128 + +__global__ void cuda_decompress_hi3b1c_packed_kernel( + const int32_t* __restrict__ YIs, // m x (n/8) + const c10::Half* __restrict__ CB, // 16 x 1 + c10::Half* __restrict__ Y // m x n +) { + const long i = threadIdx.x + DECOMPRESS_HI3B1C_BLOCK_SIZE * blockIdx.x; + + // 0 2 4 6 1 3 5 7 + uint32_t packed = YIs[i]; + Y[i*8 + 7] = CB[packed & 15]; + Y[i*8 + 5] = CB[(packed >> 4) & 15]; + Y[i*8 + 3] = CB[(packed >> 8) & 15]; + Y[i*8 + 1] = CB[(packed >> 12) & 15]; + Y[i*8 + 6] = CB[(packed >> 16) & 15]; + Y[i*8 + 4] = CB[(packed >> 20) & 15]; + Y[i*8 + 2] = CB[(packed >> 24) & 15]; + Y[i*8 + 0] = CB[(packed >> 28) & 15]; +} + + +void decompress_hi3b1c_packed( + torch::Tensor YIs, // m x (n/8) + torch::Tensor CB, + torch::Tensor &Y // m x n +) { + size_t m = Y.sizes()[0]; + size_t n = Y.sizes()[1]; + + assert(YIs.is_contiguous()); + assert(Y.is_contiguous()); + + assert(YIs.sizes()[0] == m); + assert(YIs.sizes()[1] * 8 == n); + + assert(CB.sizes()[0] == 8); + assert(CB.sizes()[1] == 1); + + + const dim3 threads(DECOMPRESS_HI3B1C_BLOCK_SIZE); + const dim3 blocks(m*n/(8*DECOMPRESS_HI3B1C_BLOCK_SIZE)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + cuda_decompress_hi3b1c_packed_kernel<<>>( + YIs.data_ptr(), + CB.data_ptr(), + Y.data_ptr() + ); +} + +// This is a terrible kernel, only use this to not call the pytorch version + +#define DECOMPRESS_HI2B1C_BLOCK_SIZE 128 + +__global__ void cuda_decompress_hi2b1c_packed_kernel( + const int32_t* __restrict__ YIs, // m x (n/8) + const c10::Half* __restrict__ CB, // 16 x 1 + c10::Half* __restrict__ Y // m x n +) { + const long i = threadIdx.x + DECOMPRESS_HI2B1C_BLOCK_SIZE * blockIdx.x; + + // 0 2 4 6 1 3 5 7 + uint32_t packed = YIs[i]; + Y[i*8 + 7] = CB[packed & 15]; + Y[i*8 + 5] = CB[(packed >> 4) & 15]; + Y[i*8 + 3] = CB[(packed >> 8) & 15]; + Y[i*8 + 1] = CB[(packed >> 12) & 15]; + Y[i*8 + 6] = CB[(packed >> 16) & 15]; + Y[i*8 + 4] = CB[(packed >> 20) & 15]; + Y[i*8 + 2] = CB[(packed >> 24) & 15]; + Y[i*8 + 0] = CB[(packed >> 28) & 15]; +} + + +void decompress_hi2b1c_packed( + torch::Tensor YIs, // m x (n/8) + torch::Tensor CB, + torch::Tensor &Y // m x n +) { + size_t m = Y.sizes()[0]; + size_t n = Y.sizes()[1]; + + assert(YIs.is_contiguous()); + assert(Y.is_contiguous()); + + assert(YIs.sizes()[0] == 
m); + assert(YIs.sizes()[1] * 8 == n); + + assert(CB.sizes()[0] == 4); + assert(CB.sizes()[1] == 1); + + + const dim3 threads(DECOMPRESS_HI2B1C_BLOCK_SIZE); + const dim3 blocks(m*n/(8*DECOMPRESS_HI2B1C_BLOCK_SIZE)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + cuda_decompress_hi2b1c_packed_kernel<<>>( + YIs.data_ptr(), + CB.data_ptr(), + Y.data_ptr() + ); +} + + + +// This is a terrible kernel, only use this to not call the pytorch version + +#define DECOMPRESS_E81B_BLOCK_SIZE 4 + +__global__ void cuda_decompress_e81b_packed_kernel( + const int64_t* __restrict__ YIs, // m x (n/8) + const c10::Half* __restrict__ CB, // 256 x 8 + c10::Half* __restrict__ Y // m x n +) { + const long i = threadIdx.x + DECOMPRESS_E81B_BLOCK_SIZE * blockIdx.x; + + uint64_t packed = YIs[i]; + +#pragma unroll + for (long j = 0; j < 8; j++) { + uint64_t yidx = packed & 255; + ((uint64_t*)Y)[(i*8 + j)*2] = ((uint64_t*)CB)[yidx*2]; + ((uint64_t*)Y)[(i*8 + j)*2 + 1] = ((uint64_t*)CB)[yidx*2 + 1]; + packed = packed >> 8; + } + +} + +void decompress_e81b_packed( + torch::Tensor YIs, // m x (n/8) + torch::Tensor CB, + torch::Tensor &Y // m x n +) { + size_t m = Y.sizes()[0]; + size_t n = Y.sizes()[1]; + + assert(YIs.is_contiguous()); + assert(Y.is_contiguous()); + + assert(YIs.sizes()[0] == m); + assert(YIs.sizes()[1] * 64 == n); + + assert(CB.sizes()[0] == 256); + assert(CB.sizes()[1] == 8); + + at::DeviceGuard guard(CB.device()); + const dim3 threads(DECOMPRESS_E81B_BLOCK_SIZE); + const dim3 blocks(m*n/(64*DECOMPRESS_E81B_BLOCK_SIZE)); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + cuda_decompress_e81b_packed_kernel<<>>( + YIs.data_ptr(), + CB.data_ptr(), + Y.data_ptr() + ); +} + + + +__global__ void cuda_lookupmatmul_e81b_k8_kernel( + const c10::Half* __restrict__ X, // k x n + const int64_t* __restrict__ YIs, // m x (n/64) + const c10::Half* __restrict__ CB, // 256 x 8 + float* __restrict__ Z, + size_t K, + size_t M, + size_t N) { + + long m1 = blockIdx.x; + long k1 = blockIdx.y; + + __shared__ c10::Half Y_cache0[32*16]; + wmma::fragment a0; // 8 x 16 + wmma::fragment b0; // 32 x 16 + + __shared__ c10::Half Y_cache1[32*16]; + wmma::fragment a1; // 8 x 16 + wmma::fragment b1; // 32 x 16 + + wmma::fragment c; // 8 x 32 + fill_fragment(c, 0.0); + + +#pragma unroll + for (long jn = 0; jn < N / 32; jn++) { + uint32_t packed = ((uint32_t*)YIs)[(m1*32 + threadIdx.x)*(N/32) + jn]; +#pragma unroll + for (long r = 0; r < 2; r++) { + uint32_t yidx = packed & 255; + ((uint64_t*)Y_cache0)[(threadIdx.x*2 + r)*2] = ((uint64_t*)CB)[yidx*2]; + ((uint64_t*)Y_cache0)[(threadIdx.x*2 + r)*2 + 1] = ((uint64_t*)CB)[yidx*2 + 1]; + packed = packed >> 8; + } +#pragma unroll + for (long r = 0; r < 2; r++) { + uint32_t yidx = packed & 255; + ((uint64_t*)Y_cache1)[(threadIdx.x*2 + r)*2] = ((uint64_t*)CB)[yidx*2]; + ((uint64_t*)Y_cache1)[(threadIdx.x*2 + r)*2 + 1] = ((uint64_t*)CB)[yidx*2 + 1]; + packed = packed >> 8; + } + + load_matrix_sync(a0, (const __half*)(X + 8*N*k1 + 32*jn), N); + load_matrix_sync(b0, (const __half*)Y_cache0, 16); + mma_sync(c, a0, b0, c); + + load_matrix_sync(a1, (const __half*)(X + 8*N*k1 + 32*jn + 16), N); + load_matrix_sync(b1, (const __half*)Y_cache1, 16); + mma_sync(c, a1, b1, c); + + } + + store_matrix_sync(&Z[8*M*k1 + 32*m1], c, M, wmma::mem_row_major); +} + + +void lookupmatmul_e81b_k8( + torch::Tensor X, // k x n + torch::Tensor YIs, // m x (n/64) + torch::Tensor CB, // 256 x 8 + torch::Tensor Z // k x m +) { + auto k = X.sizes()[0]; + auto m = YIs.sizes()[0]; + auto 
n = X.sizes()[1]; + + assert(Z.sizes()[0] == k); + assert(YIs.sizes()[1] * 64 == n); + assert(Z.sizes()[1] == m); + + assert(k <= 8); + assert(m % 32 == 0); + assert(n % 32 == 0); + + at::DeviceGuard guard(CB.device()); + const dim3 threads(32); + const dim3 blocks(m/32, k/8); + cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); + cuda_lookupmatmul_e81b_k8_kernel<<>>( + X.data_ptr(), + YIs.data_ptr(), + CB.data_ptr(), + Z.data_ptr(), + k,m,n + ); +} diff --git a/kernels/quiptools_cuda/quiptools_e8p_gemv.cu b/kernels/quiptools_cuda/quiptools_e8p_gemv.cu new file mode 100644 index 0000000..de199a6 --- /dev/null +++ b/kernels/quiptools_cuda/quiptools_e8p_gemv.cu @@ -0,0 +1,585 @@ +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +using namespace torch::indexing; +using namespace nvcuda; + +#define FULL_MASK 0xffffffff +#define HALF_MASK 0x0000ffff + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_INPUT(x) do { CHECK_CUDA(x); CHECK_CONTIGUOUS(x); } while(false) +#define gpuErrchk(ans) do { gpuAssert((ans), __FILE__, __LINE__); } while (false) + + +__host__ static inline void gpuAssert(cudaError_t code, const char *file, int line, bool abort=true) +{ + if (code != cudaSuccess) + { + fprintf(stderr, "GPUassert[%s:%d]: %s\n", file, line, cudaGetErrorString(code)); + if (abort) exit(code); + } +} + +__device__ static inline uint32_t add_as_half2(uint32_t x, uint32_t y) { + uint32_t z; + asm("add.f16x2 %0,%1,%2;" : "=r"(z) : "r"(x), "r"(y)); + return z; +} + + +__device__ static inline uint32_t mask_lop3(uint32_t x, uint32_t m0, uint32_t m1) { + uint32_t y; + asm("lop3.b32 %0, %1, %2, %3, 0xEA;" : "=r"(y) : "r"(x), "r"(m0), "r"(m1)); + return y; + // return (x & m0) | m1; +} + +#define BASE_OFFSET 0xd080d080 +#define XMASK 0x00f000f0 +#define WMASK 0x50085008 + + +__global__ static void +// __launch_bounds__(1024, 1024) +decode_matvec_e8p_kernel( + float *__restrict__ output, + const uint2 *__restrict__ input, + const uint2 *__restrict__ weights_compressed, + const uint32_t *__restrict__ codebook_abs, + int N, + int K +) { + int warpId = threadIdx.y; + int laneId = threadIdx.x; + + // __shared__ float sum_scratch[16*32]; + + // __shared__ uint32_t codebook_local[256*32]; + // for (int icb = warpId; icb < 256; icb += 32) { + // codebook_local[icb*32 + laneId] = codebook_abs[icb]; + // } + // __syncthreads(); + + __shared__ uint2 shared_weights[1024*2]; + + for (int iin = blockIdx.x; iin < (N >> 4); iin += gridDim.x) { + + float z0 = 0.0; + float z1 = 0.0; + float z2 = 0.0; + float z3 = 0.0; + + // int shwo = laneId + 32*warpId; + + // __pipeline_memcpy_async(shared_weights + shwo, weights_compressed + laneId + 32*warpId + 1024*0 + (K >> 1)*iin, 8); + // __pipeline_commit(); + + for (int iik = warpId; iik < (K >> 6); iik += 32) { + // if (iik + 1 < (K >> 11)) { + // __pipeline_memcpy_async(shared_weights + (shwo ^ 1024), weights_compressed + laneId + 32*iik + 1024 + (K >> 1)*iin, 8); + // __pipeline_commit(); + // __pipeline_wait_prior(1); + // shwo = shwo ^ 1024; + // } + // else { + // __pipeline_wait_prior(0); + // } + + // uint2 w_compr = shared_weights[shwo]; // weights_compressed[laneId + 32*warpId + 1024*iik + (K >> 1)*iin]; + uint2 w_compr = weights_compressed[laneId + 32*iik + (K >> 1)*iin]; + uint32_t a = 
w_compr.x; + uint32_t b = w_compr.y; + + uint32_t s = b; + s = s ^ (s >> 4); + s = s ^ (s >> 8); + s = s ^ (s >> 16); + uint32_t sb = (s & 15); + s = b ^ sb; + sb = sb | (sb << 16); + + uint32_t input_to_warp = ((const uint32_t*)(&input[16*iik]))[laneId]; + uint32_t shifted_laneId = (laneId & 3) << 3; + + /// BLOCK 01 + { + uint32_t x = codebook_abs[(a >> 0) & 255]; + x = x ^ ((s & 0x11111111) * 14); + + uint32_t o = BASE_OFFSET | ((sb & 0x00010001) << 4); + + uint32_t w00 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o); + uint32_t w01 = add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o); + uint32_t w02 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o); + uint32_t w03 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o); + + x = codebook_abs[(a >> 8) & 255]; + x = x ^ ((s & 0x22222222) * 7); + + o = BASE_OFFSET | ((sb & 0x00020002) << 3); + + uint32_t w10 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o); + uint32_t w11 = add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o); + uint32_t w12 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o); + uint32_t w13 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o); + + // uint2 x_in = input[0 + (laneId & 3)*4 + 16*warpId + 16*32*iik]; + // uint32_t x_in0 = x_in.x; + // uint32_t x_in1 = x_in.y; + + uint32_t x_in0 = __shfl_sync(FULL_MASK, input_to_warp, shifted_laneId | 0); + uint32_t x_in1 = __shfl_sync(FULL_MASK, input_to_warp, shifted_laneId | 1); + + asm( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + " { %0, %1, %2, %3 }," + " { %4, %5, %6, %7 }," + " { %8, %9 }," + " { %0, %1, %2, %3 };" + : "+f"(z0), "+f"(z1), "+f"(z2), "+f"(z3) + : "r"(w00), "r"(w10), "r"(w01), "r"(w11), + "r"(x_in0), "r"(x_in1) + ); + + + // x_in = input[1 + (laneId & 3)*4 + 16*warpId + 16*32*iik]; + // x_in0 = x_in.x; + // x_in1 = x_in.y; + + x_in0 = __shfl_sync(FULL_MASK, input_to_warp, shifted_laneId | 2); + x_in1 = __shfl_sync(FULL_MASK, input_to_warp, shifted_laneId | 3); + + asm( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + " { %0, %1, %2, %3 }," + " { %4, %5, %6, %7 }," + " { %8, %9 }," + " { %0, %1, %2, %3 };" + : "+f"(z0), "+f"(z1), "+f"(z2), "+f"(z3) + : "r"(w02), "r"(w12), "r"(w03), "r"(w13), + "r"(x_in0), "r"(x_in1) + ); + } + /// BLOCK 23 + { + uint32_t x = codebook_abs[(a >> 16) & 255]; + s = s >> 2; + x = x ^ ((s & 0x11111111) * 14); + + uint32_t o = BASE_OFFSET | ((sb & 0x00040004) << 2); + + uint32_t w00 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o); + uint32_t w01 = add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o); + uint32_t w02 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o); + uint32_t w03 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o); + + x = codebook_abs[(a >> 24) & 255]; + x = x ^ ((s & 0x22222222) * 7); + + o = BASE_OFFSET | ((sb & 0x00080008) << 1); + + uint32_t w10 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o); + uint32_t w11 = add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o); + uint32_t w12 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o); + uint32_t w13 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o); + + + // uint2 x_in = input[2 + (laneId & 3)*4 + 16*warpId + 16*32*iik]; + // uint32_t x_in0 = x_in.x; + // uint32_t x_in1 = x_in.y; + + uint32_t x_in0 = __shfl_sync(FULL_MASK, input_to_warp, shifted_laneId | 4); + uint32_t x_in1 = __shfl_sync(FULL_MASK, input_to_warp, shifted_laneId | 5); + + asm( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + " { %0, %1, %2, %3 }," + " { %4, %5, %6, %7 }," + " { %8, %9 }," + " { %0, %1, %2, %3 };" + : "+f"(z0), "+f"(z1), "+f"(z2), "+f"(z3) + : "r"(w00), "r"(w10), 
"r"(w01), "r"(w11), + "r"(x_in0), "r"(x_in1) + ); + + + // x_in = input[3 + (laneId & 3)*4 + 16*warpId + 16*32*iik]; + // x_in0 = x_in.x; + // x_in1 = x_in.y; + + x_in0 = __shfl_sync(FULL_MASK, input_to_warp, shifted_laneId | 6); + x_in1 = __shfl_sync(FULL_MASK, input_to_warp, shifted_laneId | 7); + + asm( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + " { %0, %1, %2, %3 }," + " { %4, %5, %6, %7 }," + " { %8, %9 }," + " { %0, %1, %2, %3 };" + : "+f"(z0), "+f"(z1), "+f"(z2), "+f"(z3) + : "r"(w02), "r"(w12), "r"(w03), "r"(w13), + "r"(x_in0), "r"(x_in1) + ); + } + } + + // we produced 16 outputs, so only 16 threads + if ((laneId & 1) == 0) { + atomicAdd(output + (iin << 4) + (laneId >> 1), (laneId & 2) ? z2 : z0); + } + + // if ((laneId & 3) == 0) { + // sum_scratch[warpId + ((laneId >> 1) + 0) * 32] = z0; + // sum_scratch[warpId + ((laneId >> 1) + 1) * 32] = z2; + // } + // __syncthreads(); + + // // load and sum + // if (warpId < 16) { + // float acc = sum_scratch[laneId + warpId*32]; + // for (int offset = 16; offset > 0; offset /= 2) { + // acc += __shfl_down_sync(FULL_MASK, acc, offset); + // } + // if (laneId == 0) { + // output[(iin << 4) + warpId] = acc; + // } + // } + } +} + + +__host__ extern torch::Tensor decode_matvec_e8p( + torch::Tensor x, + torch::Tensor weights_compressed, + torch::Tensor codebook_abs +) { + + CHECK_INPUT(x); + CHECK_INPUT(weights_compressed); + CHECK_INPUT(codebook_abs); + + TORCH_CHECK(x.dim() == 1); + TORCH_CHECK(weights_compressed.dim() == 4); + TORCH_CHECK(weights_compressed.size(3) == 4); + TORCH_CHECK(weights_compressed.size(2) == 8); + TORCH_CHECK(codebook_abs.dim() == 1); + TORCH_CHECK(x.scalar_type() == torch::kFloat16); + TORCH_CHECK(weights_compressed.scalar_type() == torch::kInt64); + TORCH_CHECK(codebook_abs.scalar_type() == torch::kInt32); + TORCH_CHECK(x.size(-1) == weights_compressed.size(1) << 6); + TORCH_CHECK(codebook_abs.size(-1) == 256); + + int64_t N = weights_compressed.size(0) * 16; + int64_t K = x.size(-1); + + TORCH_CHECK(K % 64 == 0, "K is not divisible by 64"); + TORCH_CHECK(N % 16 == 0, "N is not divisible by 16"); + + TORCH_CHECK(K < 65536, "K is not too large"); + TORCH_CHECK(N < 65536, "N is not too large"); + + at::DeviceGuard guard(x.device()); + torch::TensorOptions options = torch::TensorOptions() + .dtype(torch::kFloat32) + .layout(torch::kStrided) + .device(torch::kCUDA) + .requires_grad(false); + torch::Tensor output = torch::zeros(std::vector{N}, options); + + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, x.get_device()); + int64_t grid_size = static_cast(deviceProp.multiProcessorCount); + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); + + const dim3 block_size(32,32); + + decode_matvec_e8p_kernel<<>>( + output.data_ptr(), + (const uint2*)x.data_ptr(), + (const uint2*)weights_compressed.data_ptr(), + (const uint32_t*)codebook_abs.data_ptr(), + N, + K); + + gpuErrchk(cudaPeekAtLastError()); + + return output; +} + + + +__global__ static void +test_tc_kernel(float *__restrict__ output) { + int laneId = threadIdx.x; + + uint32_t w0 = (laneId == 0) ? 0x3C003C00 : 0x00000000; + uint32_t w1 = 0x00000000; + uint32_t w2 = 0x00000000; + uint32_t w3 = 0x00000000; + + uint32_t x0 = (laneId == 0) ? 
0x3C003C00 : 0x00000000; + uint32_t x1 = 0x00000000; + + float z0 = 0.0; + float z1 = 0.0; + float z2 = 0.0; + float z3 = 0.0; + + asm( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32" + " { %0, %1, %2, %3 }," + " { %4, %5, %6, %7 }," + " { %8, %9 }," + " { %0, %1, %2, %3 };" + : "+f"(z0), "+f"(z1), "+f"(z2), "+f"(z3) + : "r"(w0), "r"(w1), "r"(w2), "r"(w3), + "r"(x0), "r"(x1) + ); + + output[laneId*4 + 0] = z0; + output[laneId*4 + 1] = z1; + output[laneId*4 + 2] = z2; + output[laneId*4 + 3] = z3; +} + +__host__ extern torch::Tensor test_tc() { + + torch::TensorOptions options = torch::TensorOptions() + .dtype(torch::kFloat32) + .layout(torch::kStrided) + .device(torch::kCUDA) + .requires_grad(false); + torch::Tensor output = torch::zeros(std::vector{32*4}, options); + + test_tc_kernel<<<1, 32>>>(output.data_ptr()); + + gpuErrchk(cudaPeekAtLastError()); + + return output; +} + + + + +__global__ static void +test_codebook_expand_kernel(uint32_t *__restrict__ output, const uint32_t *__restrict__ codebook_abs) { + uint32_t a = threadIdx.x; + uint32_t b = 0; + + for (int i = 0; i < 8; i++) { + b |= (((blockIdx.x >> i) & 1) << (4*i)); + } + + uint32_t s = b; + s = s ^ (s >> 4); + s = s ^ (s >> 8); + s = s ^ (s >> 16); + uint32_t sb = (s & 15); + s = b ^ sb; + sb = sb | (sb << 16); + + uint32_t x = codebook_abs[(a >> 0) & 255]; + x = x ^ ((s & 0x11111111) * 14); + + uint32_t o = BASE_OFFSET | ((sb & 0x00010001) << 4); + + uint32_t w0 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o); + uint32_t w1 = add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o); + uint32_t w2 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o); + uint32_t w3 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o); + + output[blockIdx.x*256*4 + threadIdx.x*4 + 0] = w0; + output[blockIdx.x*256*4 + threadIdx.x*4 + 1] = w1; + output[blockIdx.x*256*4 + threadIdx.x*4 + 2] = w2; + output[blockIdx.x*256*4 + threadIdx.x*4 + 3] = w3; +} + +__host__ extern torch::Tensor test_codebook_expand(torch::Tensor codebook_abs) { + + torch::TensorOptions options = torch::TensorOptions() + .dtype(torch::kFloat16) + .layout(torch::kStrided) + .device(torch::kCUDA) + .requires_grad(false); + torch::Tensor output = torch::zeros(std::vector{256*256,8}, options); + + test_codebook_expand_kernel<<<256, 256>>>((uint32_t*)output.data_ptr(), (const uint32_t*)codebook_abs.data_ptr()); + + gpuErrchk(cudaPeekAtLastError()); + + return output; +} + + + + +__global__ static void +// __launch_bounds__(1024, 1024) +decompress_packed_e8p_kernel( + uint32_t *__restrict__ output, + const uint2 *__restrict__ weights_compressed, + const uint32_t *__restrict__ codebook_abs, + int N, + int K +) { + int warpId = threadIdx.y; + int laneId = threadIdx.x; + + for (int iin = blockIdx.x; iin < (N >> 4); iin += gridDim.x) { + + for (int iik = warpId; iik < (K >> 6); iik += 32) { + uint2 w_compr = weights_compressed[laneId + 32*iik + (K >> 1)*iin]; + uint32_t a = w_compr.x; + uint32_t b = w_compr.y; + + uint32_t s = b; + s = s ^ (s >> 4); + s = s ^ (s >> 8); + s = s ^ (s >> 16); + uint32_t sb = (s & 15); + s = b ^ sb; + sb = sb | (sb << 16); + + /// BLOCK 01 + { + uint32_t x = codebook_abs[(a >> 0) & 255]; + x = x ^ ((s & 0x11111111) * 14); + + uint32_t o = BASE_OFFSET | ((sb & 0x00010001) << 4); + + uint32_t w00 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o); + uint32_t w01 = add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o); + uint32_t w02 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o); + uint32_t w03 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o); + + x = 
codebook_abs[(a >> 8) & 255]; + x = x ^ ((s & 0x22222222) * 7); + + o = BASE_OFFSET | ((sb & 0x00020002) << 3); + + uint32_t w10 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o); + uint32_t w11 = add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o); + uint32_t w12 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o); + uint32_t w13 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o); + + output[iin*8*K + (laneId >> 2)*K + 0 * (K >> 1) + iik*32 + 0*4 + ((laneId & 3) << 3) + 0] = w00; + output[iin*8*K + (laneId >> 2)*K + 0 * (K >> 1) + iik*32 + 0*4 + ((laneId & 3) << 3) + 1] = w01; + output[iin*8*K + (laneId >> 2)*K + 1 * (K >> 1) + iik*32 + 0*4 + ((laneId & 3) << 3) + 0] = w10; + output[iin*8*K + (laneId >> 2)*K + 1 * (K >> 1) + iik*32 + 0*4 + ((laneId & 3) << 3) + 1] = w11; + + output[iin*8*K + (laneId >> 2)*K + 0 * (K >> 1) + iik*32 + 0*4 + ((laneId & 3) << 3) + 2] = w02; + output[iin*8*K + (laneId >> 2)*K + 0 * (K >> 1) + iik*32 + 0*4 + ((laneId & 3) << 3) + 3] = w03; + output[iin*8*K + (laneId >> 2)*K + 1 * (K >> 1) + iik*32 + 0*4 + ((laneId & 3) << 3) + 2] = w12; + output[iin*8*K + (laneId >> 2)*K + 1 * (K >> 1) + iik*32 + 0*4 + ((laneId & 3) << 3) + 3] = w13; + + } + /// BLOCK 23 + { + uint32_t x = codebook_abs[(a >> 16) & 255]; + s = s >> 2; + x = x ^ ((s & 0x11111111) * 14); + + uint32_t o = BASE_OFFSET | ((sb & 0x00040004) << 2); + + uint32_t w00 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o); + uint32_t w01 = add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o); + uint32_t w02 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o); + uint32_t w03 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o); + + x = codebook_abs[(a >> 24) & 255]; + x = x ^ ((s & 0x22222222) * 7); + + o = BASE_OFFSET | ((sb & 0x00080008) << 1); + + uint32_t w10 = add_as_half2(mask_lop3(x << 4, XMASK, WMASK), o); + uint32_t w11 = add_as_half2(mask_lop3(x << 0, XMASK, WMASK), o); + uint32_t w12 = add_as_half2(mask_lop3(x >> 4, XMASK, WMASK), o); + uint32_t w13 = add_as_half2(mask_lop3(x >> 8, XMASK, WMASK), o); + + output[iin*8*K + (laneId >> 2)*K + 0 * (K >> 1) + iik*32 + 1*4 + ((laneId & 3) << 3) + 0] = w00; + output[iin*8*K + (laneId >> 2)*K + 0 * (K >> 1) + iik*32 + 1*4 + ((laneId & 3) << 3) + 1] = w01; + output[iin*8*K + (laneId >> 2)*K + 1 * (K >> 1) + iik*32 + 1*4 + ((laneId & 3) << 3) + 0] = w10; + output[iin*8*K + (laneId >> 2)*K + 1 * (K >> 1) + iik*32 + 1*4 + ((laneId & 3) << 3) + 1] = w11; + + output[iin*8*K + (laneId >> 2)*K + 0 * (K >> 1) + iik*32 + 1*4 + ((laneId & 3) << 3) + 2] = w02; + output[iin*8*K + (laneId >> 2)*K + 0 * (K >> 1) + iik*32 + 1*4 + ((laneId & 3) << 3) + 3] = w03; + output[iin*8*K + (laneId >> 2)*K + 1 * (K >> 1) + iik*32 + 1*4 + ((laneId & 3) << 3) + 2] = w12; + output[iin*8*K + (laneId >> 2)*K + 1 * (K >> 1) + iik*32 + 1*4 + ((laneId & 3) << 3) + 3] = w13; + } + } + } +} + + +__host__ extern torch::Tensor decompress_packed_e8p( + torch::Tensor weights_compressed, + torch::Tensor codebook_abs +) { + CHECK_INPUT(weights_compressed); + CHECK_INPUT(codebook_abs); + + TORCH_CHECK(weights_compressed.dim() == 4); + TORCH_CHECK(weights_compressed.size(3) == 4); + TORCH_CHECK(weights_compressed.size(2) == 8); + TORCH_CHECK(codebook_abs.dim() == 1); + TORCH_CHECK(weights_compressed.scalar_type() == torch::kInt64); + TORCH_CHECK(codebook_abs.scalar_type() == torch::kInt32); + TORCH_CHECK(codebook_abs.size(-1) == 256); + + int64_t N = weights_compressed.size(0) * 16; + int64_t K = weights_compressed.size(1) << 6; + + TORCH_CHECK(K % 64 == 0, "K is not divisible by 64"); + TORCH_CHECK(N % 16 == 0, 
"N is not divisible by 16"); + + TORCH_CHECK(K < 65536, "K is not too large"); + TORCH_CHECK(N < 65536, "N is not too large"); + + at::DeviceGuard guard(codebook_abs.device()); + torch::TensorOptions options = torch::TensorOptions() + .dtype(torch::kFloat16) + .layout(torch::kStrided) + .device(torch::kCUDA) + .requires_grad(false); + torch::Tensor output = torch::zeros(std::vector{N,K}, options); + + cudaDeviceProp deviceProp; + cudaGetDeviceProperties(&deviceProp, weights_compressed.get_device()); + int64_t grid_size = static_cast(deviceProp.multiProcessorCount); + at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream(); + + const dim3 block_size(32,32); + + decompress_packed_e8p_kernel<<>>( + (uint32_t*)output.data_ptr(), + (const uint2*)weights_compressed.data_ptr(), + (const uint32_t*)codebook_abs.data_ptr(), + N, + K); + + gpuErrchk(cudaPeekAtLastError()); + + return output; +} \ No newline at end of file diff --git a/kernels/quiptools_cuda/quiptools_wrapper.cpp b/kernels/quiptools_cuda/quiptools_wrapper.cpp new file mode 100644 index 0000000..97a6b9f --- /dev/null +++ b/kernels/quiptools_cuda/quiptools_wrapper.cpp @@ -0,0 +1,95 @@ +#include + +#include +#include + +void lookupmatmul_d4_k8( + torch::Tensor X, // k x n + torch::Tensor YIs, // m x (n/4) + torch::Tensor CB, // 256 x 4 + torch::Tensor Z // k x m +); + +void lookupmatmul_d4_k16( + torch::Tensor X, // k x n + torch::Tensor YIs, // m x (n/4) + torch::Tensor CB, // 256 x 4 + torch::Tensor Z // k x m +); + +void lookupmatmul_d4_k32( + torch::Tensor X, // k x n + torch::Tensor YIs, // m x (n/4) + torch::Tensor CB, // 256 x 4 + torch::Tensor Z // k x m +); + +void decompress_d4( + torch::Tensor YIs, // m x (n/4) + torch::Tensor CB, // 256 x 4 + torch::Tensor Y // m x n +); + +void decompress_d4_origorder( + torch::Tensor YIs, // m x (n/4) + torch::Tensor CB, // 256 x 4 + torch::Tensor Y // m x n +); + +torch::Tensor decompress_packed_e8p( + torch::Tensor weights_compressed, // m x (n/8) + torch::Tensor codebook_abs // 256 x 8 +); + +torch::Tensor decode_matvec_e8p( + torch::Tensor x, + torch::Tensor weights_compressed, + torch::Tensor codebook_abs +); + +void decompress_hi4b1c_packed( + torch::Tensor YIs, // m x (n/8) + torch::Tensor CB, // 16 x 1 + torch::Tensor &Y // m x n +); + +void decompress_hi3b1c_packed( + torch::Tensor YIs, // m x (n/8) + torch::Tensor CB, // 16 x 1 + torch::Tensor &Y // m x n +); + +void decompress_hi2b1c_packed( + torch::Tensor YIs, // m x (n/8) + torch::Tensor CB, // 16 x 1 + torch::Tensor &Y // m x n +); + +void decompress_e81b_packed( + torch::Tensor YIs, // m x (n/8) + torch::Tensor CB, // 256 x 8 + torch::Tensor &Y // m x n +); + +void lookupmatmul_e81b_k8( + torch::Tensor X, // k x n + torch::Tensor YIs, // m x (n/4) + torch::Tensor CB, // 256 x 4 + torch::Tensor Z // k x m +); + + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("lookupmatmul_d4_k8", &lookupmatmul_d4_k8, "lookupmatmul_d4_k8"); + m.def("lookupmatmul_d4_k16", &lookupmatmul_d4_k16, "lookupmatmul_d4_k16"); + m.def("lookupmatmul_d4_k32", &lookupmatmul_d4_k32, "lookupmatmul_d4_k32"); + m.def("decompress_d4", &decompress_d4, "decompress_d4"); + m.def("decompress_d4_origorder", &decompress_d4_origorder, "decompress_d4_origorder"); + m.def("decompress_packed_e8p", &decompress_packed_e8p, "decompress_packed_e8p"); + m.def("decode_matvec_e8p", &decode_matvec_e8p, "decode_matvec_e8p"); + m.def("decompress_hi4b1c_packed", &decompress_hi4b1c_packed, "decompress_hi4b1c_packed"); + m.def("decompress_hi3b1c_packed", 
&decompress_hi3b1c_packed, "decompress_hi3b1c_packed"); + m.def("decompress_hi2b1c_packed", &decompress_hi2b1c_packed, "decompress_hi2b1c_packed"); + m.def("decompress_e81b_packed", &decompress_e81b_packed, "decompress_e81b_packed"); + m.def("lookupmatmul_e81b_k8", &lookupmatmul_e81b_k8, "lookupmatmul_e81b_k8"); +} diff --git a/llmtools/engine/inference/modules.py b/llmtools/engine/inference/modules.py index 81ff200..47d412c 100644 --- a/llmtools/engine/inference/modules.py +++ b/llmtools/engine/inference/modules.py @@ -8,14 +8,14 @@ ## QUIP Implementation -import quiptools_cuda -from quip.lib.utils import matmul_hadU_cuda, matmul_hadUt_cuda, dtype_from_str -from quip.lib import codebook +# import quiptools_cuda +# from quip_sharp.lib.utils import matmul_hadU_cuda, matmul_hadUt_cuda, dtype_from_str +# from quip_sharp.lib import codebook -try: - import quant_cuda -except: - print('CUDA extension not installed. Inference will not work for (OPTQ)') +# try: +# import quant_cuda +# except: +# print('CUDA extension not installed. Inference will not work for (OPTQ)') # Assumes layer is perfectly divisible into 256 * 256 blocks diff --git a/llmtools/engine/lora/lora_old.py b/llmtools/engine/lora/lora_old.py index ff1846d..12d477d 100644 --- a/llmtools/engine/lora/lora_old.py +++ b/llmtools/engine/lora/lora_old.py @@ -22,8 +22,8 @@ from llmtools.engine.inference.modules import QuantLinear from llmtools.engine.lora.peft import quant_peft #* QUIP Quant Linear Layer path *# -from quip.lib.linear.quantized_linear import QuantizedLinear -from quip.lib.linear.fused_quantized_linear import FusedQuantizedLinear +from quip_sharp.lib.linear.quantized_linear import QuantizedLinear +from quip_sharp.lib.linear.fused_quantized_linear import FusedQuantizedLinear # hacky way to do imports for now LoraLayer = quant_peft.tuners.lora.LoraLayer diff --git a/llmtools/llms/llamaquip/model.py b/llmtools/llms/llamaquip/model.py index cca7ed6..4c21344 100644 --- a/llmtools/llms/llamaquip/model.py +++ b/llmtools/llms/llamaquip/model.py @@ -1,9 +1,8 @@ -import glog import os import json from transformers import LlamaTokenizer, LlamaConfig -from quip.model.llama import LlamaForCausalLM #* This is the latest llama with the fused quantized linear layer *# -#from quip.model.llama_nofuse import LlamaForCausalLM #* This is the latest llama with the unfused quantized linear layer *# +from quip_sharp.model.llama import LlamaForCausalLM #* This is the latest llama with the fused quantized linear layer *# +#from quip_sharp.model.llama_nofuse import LlamaForCausalLM #* This is the latest llama with the unfused quantized linear layer *# def load_llama_quip(hf_path, device_map='auto'): model = LlamaForCausalLM.from_pretrained(hf_path, @@ -22,7 +21,6 @@ def load_llama_quip(hf_path, device_map='auto'): _name_or_path = _config['_name_or_path'] #tokenizer = LlamaTokenizer.from_pretrained(_name_or_path, add_eos_token=True) # Append EOS token to end of the sentence - glog.info('loaded model!') #tokenizer.pad_token = tokenizer.eos_token return model, _name_or_path, _config \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..6d76f11 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,36 @@ +[build-system] +requires = ["setuptools", "torch"] + +[tool.setuptools.packages.find] +include = ["llmtools*"] + +[project] +name = "llmtools" +version = "0.1.0" +authors = [ + { name="Volodymyr Kuleshov", email="kuleshov@cornell.edu" }, + { name="Junjie Oscar Yin", email="jyin27@jhu.edu" } +] + +description 
= "llmtools: Finetuning Large Language Models on One Consumer GPU in Under 4 Bits" +readme = "README.md" +requires-python = ">=3.8" +classifiers = [ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: MIT License", +] +dependencies = [ + "torch >= 2", + "transformers>=4.36.0", + "peft==0.5.0", + "datasets", + "bitsandbytes", + "quip_sharp@git+https://github.com/osbm/quip-sharp@main" +] + +[project.scripts] +console_scripts = "llmtools.run:main" + +[project.urls] +Homepage = "https://github.com/kuleshov-group/llmtools" +Issues = "https://github.com/kuleshov-group/llmtools/issues" diff --git a/setup.py b/setup.py index 1f03e36..03a6271 100644 --- a/setup.py +++ b/setup.py @@ -1,31 +1,28 @@ from setuptools import setup, find_packages, Extension from torch.utils import cpp_extension -setup( - name='llmtools', - version='0.1.0', - packages=find_packages(include=['llmtools', 'llmtools.*']), - entry_points={ - 'console_scripts': ['llmtools=llmtools.run:main'] - } -) - -setup( - name='quip', - version='0.1.0', - packages=find_packages(where='third-party'), - package_dir={'': 'third-party'}, -) setup( - name='quant_cuda', - ext_modules=[cpp_extension.CUDAExtension( - 'quant_cuda', - [ - 'llmtools/engine/inference/cuda/quant_cuda.cpp', - 'llmtools/engine/inference/cuda/quant_cuda_kernel.cu' - ] - )], + ext_modules=[ + cpp_extension.CUDAExtension( + 'quant_cuda', + [ + 'kernels/quant_cuda/quant_cuda.cpp', + 'kernels/quant_cuda/quant_cuda_kernel.cu' + ] + ), + cpp_extension.CUDAExtension( + 'quiptools_cuda', + [ + 'kernels/quiptools_cuda/quiptools_wrapper.cpp', + 'kernels/quiptools_cuda/quiptools.cu', + 'kernels/quiptools_cuda/quiptools_e8p_gemv.cu' + ], + extra_compile_args={ + 'cxx': ['-g', '-lineinfo'], + 'nvcc': ['-O2', '-g', '-Xcompiler', '-rdynamic', '-lineinfo'] + } + ) + ], cmdclass={'build_ext': cpp_extension.BuildExtension} ) - diff --git a/third-party/quip b/third-party/quip deleted file mode 160000 index 186a4bb..0000000 --- a/third-party/quip +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 186a4bb7f28bd72e64111020b2119fac7f3c6bef