Skip to content

Commit

Permalink
fix: Fix the bug where the Torch library is not correctly linked. (#152)
Browse files Browse the repository at this point in the history
  • Loading branch information
lcy-seso authored Dec 31, 2024
1 parent be44b0d commit f1c0204
Show file tree
Hide file tree
Showing 10 changed files with 100 additions and 61 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# --------------------------------------------------------------------------

cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
project(vptq_cuda_ops LANGUAGES C CXX CUDA)
project(vptq LANGUAGES C CXX CUDA)

# Prohibit in-source builds
if(${CMAKE_SOURCE_DIR} STREQUAL ${CMAKE_BINARY_DIR})
Expand Down
6 changes: 4 additions & 2 deletions csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@
# Copyright (c) Microsoft Corporation. All rights reserved. Licensed under the
# MIT License.
# --------------------------------------------------------------------------
set(TARGET "cuda_ops")
file(GLOB_RECURSE SOURCES "*.cu")
set(TARGET "vptq")
file(GLOB_RECURSE SOURCES "*.cu" "*.cc")
message(STATUS "Building ${TARGET} with ${SOURCES}")

cuda_add_library(${TARGET} SHARED ${SOURCES})

Expand All @@ -29,3 +30,4 @@ target_compile_options(
--use_fast_math
--generate-line-info>)
target_compile_features(${TARGET} PUBLIC cxx_std_17 cuda_std_17)
target_link_libraries(${TARGET} "${TORCH_LIBRARIES}")
4 changes: 2 additions & 2 deletions csrc/common.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
Expand Down Expand Up @@ -39,4 +39,4 @@ inline void gpuAssert(cudaError_t code, const char* file, int line) {
line);
TORCH_CHECK(false, cudaGetErrorString(code));
}
}
}
6 changes: 3 additions & 3 deletions csrc/dequant_impl_packed.cu
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ __global__ void DequantizeWithOutliers_PackIndice(
// @param weight_bias
// @return torch::Tensor
torch::Tensor launch_deqantize_outliers_cuda_packkernel(
const int* outf_x_inf, const torch::Tensor& q_indice,
const int64_t* outf_x_inf, const torch::Tensor& q_indice,
const torch::Tensor& centroids,
const c10::optional<torch::Tensor>& q_indice_residual,
const c10::optional<torch::Tensor>& residual_centroids,
Expand Down Expand Up @@ -534,7 +534,7 @@ torch::Tensor launch_deqantize_outliers_cuda_packkernel(
// @param bias
// @return torch::Tensor
torch::Tensor launch_gemv_outliers_cuda_packkernel(
const int out_features, const torch::Tensor& input,
const int64_t out_features, const torch::Tensor& input,
const torch::Tensor& q_indice, const torch::Tensor& centroids,
const c10::optional<torch::Tensor>& q_indice_residual,
const c10::optional<torch::Tensor>& residual_centroids,
Expand All @@ -544,7 +544,7 @@ torch::Tensor launch_gemv_outliers_cuda_packkernel(
const torch::Tensor& weight_bias,
const c10::optional<torch::Tensor>& bias) {
OptionalCUDAGuard cudaguard(input.device().index());
const int base_groupsize = centroids.size(-1);
const int64_t base_groupsize = centroids.size(-1);
int index_bits = log2(centroids.size(1));
int res_index_bits = residual_centroids.has_value()
? log2(residual_centroids.value().size(1))
Expand Down
67 changes: 44 additions & 23 deletions csrc/ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
#include "common.h"
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <torch/extension.h>
#include <torch/library.h>

#define CHECK_CUDA(x) \
TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor")
Expand All @@ -15,7 +17,7 @@
CHECK_CONTIGUOUS(x)

torch::Tensor launch_deqantize_outliers_cuda_packkernel(
const int* outf_x_inf, const torch::Tensor& q_indice,
const int64_t* outf_x_inf, const torch::Tensor& q_indice,
const torch::Tensor& centroids,
const c10::optional<torch::Tensor>& q_indice_residual,
const c10::optional<torch::Tensor>& residual_centroids,
Expand All @@ -25,7 +27,7 @@ torch::Tensor launch_deqantize_outliers_cuda_packkernel(
const torch::Tensor& weight_bias);

torch::Tensor launch_gemv_outliers_cuda_packkernel(
const int out_features, const torch::Tensor& input,
const int64_t out_features, const torch::Tensor& input,
const torch::Tensor& q_indice, const torch::Tensor& centroids,
const c10::optional<torch::Tensor>& q_indice_residual,
const c10::optional<torch::Tensor>& residual_centroids,
Expand All @@ -42,8 +44,8 @@ torch::Tensor dequant(const torch::Tensor& q_indice,
const c10::optional<torch::Tensor>& outliers_centroids,
const c10::optional<torch::Tensor>& invperm,
const torch::Tensor& weight_scale,
const torch::Tensor& weight_bias, int groupsize,
int in_features, int out_features) {
const torch::Tensor& weight_bias, int64_t groupsize,
int64_t in_features, int64_t out_features) {
auto dev_index = q_indice.device().index();

CHECK_INPUT(q_indice);
Expand Down Expand Up @@ -85,7 +87,7 @@ torch::Tensor dequant(const torch::Tensor& q_indice,

at::cuda::OptionalCUDAGuard guard(q_indice.device());
torch::Tensor output;
const int out_f_x_in_f[2] = {out_features, in_features};
const int64_t out_f_x_in_f[2] = {out_features, in_features};

output = launch_deqantize_outliers_cuda_packkernel(
out_f_x_in_f, q_indice, centroids, q_indice_residual, residual_centroids,
Expand All @@ -106,8 +108,9 @@ torch::Tensor wqA16Gemm(const torch::Tensor& input,
const c10::optional<torch::Tensor>& invperm,
const torch::Tensor& weight_scale,
const torch::Tensor& weight_bias,
const c10::optional<torch::Tensor>& bias, int groupsize,
int in_features, int out_features) {
const c10::optional<torch::Tensor>& bias,
int64_t groupsize, int64_t in_features,
int64_t out_features) {
CHECK_INPUT(q_indice);
CHECK_INPUT(input);
if (q_indice_residual.has_value()) {
Expand Down Expand Up @@ -155,22 +158,40 @@ torch::Tensor wqA16Gemm(const torch::Tensor& input,
return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("dequant", &dequant,
R"DOC(Dequantize matrix weights to fp16.
function type:
const torch::Tensor& qweight,
const torch::Tensor& scales,
const torch::Tensor& qzeros,
Tensor g_idx, int groupsize, int bits, int in_features
)DOC");
TORCH_LIBRARY_IMPL(vptq, CUDA, m) {
m.impl("dequant", dequant);
m.impl("gemm", wqA16Gemm);
}

m.def("gemm", &wqA16Gemm,
R"DOC(Compute the gemm output, usually gemv.
function type:
const torch::Tensor& qweight,
const torch::Tensor& scales,
const torch::Tensor& qzeros,
tensor g_idx, int groupsize, int bits, int in_features
TORCH_LIBRARY(vptq, m) {
m.def(
R"DOC(dequant(Tensor q_indice,
Tensor centroids,
Tensor? q_indice_residual,
Tensor? residual_centroids,
Tensor? q_indice_outliers,
Tensor? outliers_centroids,
Tensor? invperm,
Tensor weight_scale,
Tensor weight_bias,
int groupsize,
int in_features,
int out_features) -> Tensor
)DOC");
m.def(
R"DOC(gemm(Tensor input,
Tensor q_indice,
Tensor centroids,
Tensor? q_indice_residual,
Tensor? residual_centroids,
Tensor? q_indice_outliers,
Tensor? outliers_centroids,
Tensor? invperm,
Tensor weight_scale,
Tensor weight_bias,
Tensor? bias,
int groupsize,
int in_features,
int out_features) -> Tensor
)DOC");
}
15 changes: 11 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,28 @@ def get_requirements():
class CMakeExtension(Extension):
""" specify the root folder of the CMake projects"""

def __init__(self, name="cuda_ops", cmake_lists_dir=".", **kwargs):
def __init__(self, name, cmake_lists_dir=".", **kwargs):
Extension.__init__(self, name, sources=[], **kwargs)
self.cmake_lists_dir = os.path.abspath(cmake_lists_dir)


class CMakeBuildExt(build_ext):
"""launches the CMake build."""

def get_ext_filename(self, name):
return f"lib{name}.so"

def copy_extensions_to_source(self) -> None:
build_py = self.get_finalized_command("build_py")
for ext in self.extensions:
source_path = os.path.join(self.build_lib, "lib" + ext.name + ".so")
source_path = os.path.join(
self.build_lib, self.get_ext_filename(ext.name)
)
inplace_file, _ = self._get_inplace_equivalent(build_py, ext)

target_path = os.path.join(build_py.build_lib, "vptq", inplace_file)
target_path = os.path.join(
build_py.build_lib, "vptq", "ops", inplace_file
)

# Always copy, even if source is older than destination, to ensure
# that the right extensions for the current Python/platform are
Expand Down Expand Up @@ -169,7 +176,7 @@ def run(self):
version=get_version(),
description=description,
author="Wang Yang, Wen JiCheng",
ext_modules=[CMakeExtension()],
ext_modules=[CMakeExtension("vptq")],
cmdclass={
"build_ext": CMakeBuildExt,
"clean": Clean,
Expand Down
5 changes: 4 additions & 1 deletion vptq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,7 @@

__version__ = importlib.metadata.version("vptq")

__all__ = ["AutoModelForCausalLM", "VQuantLinear"]
__all__ = [
"AutoModelForCausalLM",
"VQuantLinear",
]
5 changes: 4 additions & 1 deletion vptq/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,7 @@

from vptq.layers.model_base import AutoModelForCausalLM, VQuantLinear

__all__ = ["AutoModelForCausalLM", "VQuantLinear"]
__all__ = [
"AutoModelForCausalLM",
"VQuantLinear",
]
13 changes: 0 additions & 13 deletions vptq/layers/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
# --------------------------------------------------------------------------

import glob
import importlib.util
from pathlib import Path

import accelerate
Expand Down Expand Up @@ -203,18 +202,6 @@ def from_pretrained(
preload_module_classes=["VQuantLinear"]
)

# check cuda kernel exist
if importlib.util.find_spec("vptq.cuda_ops") is not None:
pass
else:
print((
"!!! Warning !!!: CUDA kernels are not found, "
"please check CUDA and VPTQ installation."
))
print((
"!!! Warning !!!: Running on Torch implementations, "
"which is extremely slow."
))
model.eval()

torch.cuda.empty_cache()
Expand Down
38 changes: 27 additions & 11 deletions vptq/ops/quant_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,38 @@
# --------------------------------------------------------------------------

__all__ = [
'dequant',
'quant_gemm',
"dequant",
"quant_gemm",
]

import math
import os

import torch
from torch.nn import functional as F

# isort: off
# we need to import the CUDA kernels after importing torch
__cuda_ops_installed = True
try:
from vptq import cuda_ops
except ImportError:
__cuda_ops_installed = False

def _load_library(filename: str) -> bool:
"""Load a shared library from the given filename."""
try:
libdir = os.path.dirname(os.path.dirname(__file__))
torch.ops.load_library(os.path.join(libdir, filename))
print(f"Successfully loaded: '{filename}'")
return True
except Exception as error:
print((
f"{error}\n"
"!!! Warning !!!: CUDA kernels are not found, "
"please check CUDA and VPTQ installation."
))
print((
"!!! Warning !!!: Running on Torch implementations, "
"which is extremely slow."
))
return False


__cuda_ops_installed: bool = _load_library("libvptq.so")


def unpack_index_tensor(
Expand Down Expand Up @@ -226,7 +242,7 @@ def quant_gemm(
enable_norm = weight_scale is not None and weight_bias is not None

if (x.numel() // x.shape[-1] < 3) and __cuda_ops_installed:
out = cuda_ops.gemm(
out = torch.ops.vptq.gemm(
x,
indices,
centroids_,
Expand All @@ -245,7 +261,7 @@ def quant_gemm(
return out
else:
if __cuda_ops_installed:
weight = cuda_ops.dequant(
weight = torch.ops.vptq.dequant(
indices,
centroids_,
residual_indices,
Expand Down

0 comments on commit f1c0204

Please sign in to comment.