Skip to content

Commit

Permalink
Fix the bug where the Torch library is not correctly linked.
Browse files: browse the repository at this point in the history
  • Loading branch information
lcy-seso committed Dec 30, 2024
1 parent be44b0d commit b0bb880
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 8 deletions.
1 change: 1 addition & 0 deletions csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,4 @@ target_compile_options(
--use_fast_math
--generate-line-info>)
target_compile_features(${TARGET} PUBLIC cxx_std_17 cuda_std_17)
target_link_libraries(${TARGET} "${TORCH_LIBRARIES}")
3 changes: 2 additions & 1 deletion csrc/ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "common.h"
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <torch/extension.h>

#define CHECK_CUDA(x) \
Expand Down Expand Up @@ -155,7 +156,7 @@ torch::Tensor wqA16Gemm(const torch::Tensor& input,
return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
TORCH_LIBRARY(cuda_ops, m) {
m.def("dequant", &dequant,
R"DOC(Dequantize matrix weights to fp16.
function type:
Expand Down
11 changes: 9 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,20 @@ def __init__(self, name="cuda_ops", cmake_lists_dir=".", **kwargs):
class CMakeBuildExt(build_ext):
"""launches the CMake build."""

# Build the file name for an extension as "lib<name>.so"; for the extension
# name "cuda_ops" this yields "libcuda_ops.so".
# NOTE(review): presumably this overrides build_ext.get_ext_filename so the
# default (platform/ABI-tagged) file name is not used -- confirm.
def get_ext_filename(self, name):
return f"lib{name}.so"

def copy_extensions_to_source(self) -> None:
build_py = self.get_finalized_command("build_py")
for ext in self.extensions:
source_path = os.path.join(self.build_lib, "lib" + ext.name + ".so")
source_path = os.path.join(
self.build_lib, self.get_ext_filename(ext.name)
)
inplace_file, _ = self._get_inplace_equivalent(build_py, ext)

target_path = os.path.join(build_py.build_lib, "vptq", inplace_file)
target_path = os.path.join(
build_py.build_lib, "vptq", "ops", inplace_file
)

# Always copy, even if source is older than destination, to ensure
# that the right extensions for the current Python/platform are
Expand Down
5 changes: 4 additions & 1 deletion vptq/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,7 @@

__version__ = importlib.metadata.version("vptq")

__all__ = ["AutoModelForCausalLM", "VQuantLinear"]
__all__ = [
"AutoModelForCausalLM",
"VQuantLinear",
]
17 changes: 13 additions & 4 deletions vptq/ops/quant_gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
]

import math
import os

import torch
from torch.nn import functional as F
Expand All @@ -17,8 +18,16 @@
# we need to import the CUDA kernels after importing torch
__cuda_ops_installed = True
try:
from vptq import cuda_ops
except ImportError:
# from vptq import cuda_ops

torch.ops.load_library(
os.path.join(os.path.dirname(__file__), "libcuda_ops.so")
)
except Exception:
print((
"Customized CUDA operator is not installed. "
"PyTorch's implementation is used for quantized GEMM."
))
__cuda_ops_installed = False


Expand Down Expand Up @@ -226,7 +235,7 @@ def quant_gemm(
enable_norm = weight_scale is not None and weight_bias is not None

if (x.numel() // x.shape[-1] < 3) and __cuda_ops_installed:
out = cuda_ops.gemm(
out = torch.ops.cuda_ops.gemm(
x,
indices,
centroids_,
Expand All @@ -245,7 +254,7 @@ def quant_gemm(
return out
else:
if __cuda_ops_installed:
weight = cuda_ops.dequant(
weight = torch.ops.cuda_ops.dequant(
indices,
centroids_,
residual_indices,
Expand Down

0 comments on commit b0bb880

Please sign in to comment.