feat: fix and use marlin kernel for awq by default (#326)
guocuimi authored Sep 4, 2024
1 parent 2a947e1 commit 5c8c82a
Showing 6 changed files with 27 additions and 10 deletions.
7 changes: 5 additions & 2 deletions src/layers/linear.cpp
@@ -8,6 +8,7 @@

 #include "linear_impl.h"
 #include "quantization/qlinear_awq_impl.h"
+#include "quantization/qlinear_awq_marlin_impl.h"
 #include "quantization/qlinear_exllamav2_impl.h"
 #include "quantization/qlinear_gptq_impl.h"
 #include "quantization/qlinear_gptq_marlin_impl.h"
@@ -133,7 +134,8 @@ std::shared_ptr<ParallelLinearImpl> create_column_parallel_qlinear(
   if (boost::iequals(quant_args.quant_method(), "awq") ||
       boost::iequals(quant_args.quant_method(), "GEMM")) {
     // default to use awq implementation for gemm
-    return MAKE_COLUMN_PARALLEL_QLINEAR(ColumnParallelQLinearAWQImpl);
+    // return MAKE_COLUMN_PARALLEL_QLINEAR(ColumnParallelQLinearAWQImpl);
+    return MAKE_COLUMN_PARALLEL_QLINEAR(ColumnParallelQLinearAWQMarlinImpl);
   }
   // not supported quant method
   LOG(FATAL) << "Unsupported quant method: " << quant_args.quant_method();
@@ -163,7 +165,8 @@ std::shared_ptr<ParallelLinearImpl> create_row_parallel_qlinear(
   if (boost::iequals(quant_args.quant_method(), "awq") ||
       boost::iequals(quant_args.quant_method(), "GEMM")) {
     // default to use awq implementation for gemm
-    return MAKE_ROW_PARALLEL_QLINEAR(RowParallelQLinearAWQImpl);
+    // return MAKE_ROW_PARALLEL_QLINEAR(RowParallelQLinearAWQImpl);
+    return MAKE_ROW_PARALLEL_QLINEAR(RowParallelQLinearAWQMarlinImpl);
   }
   // not supported quant method
   LOG(FATAL) << "Unsupported quant method: " << quant_args.quant_method();
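
For readers unfamiliar with the dispatch above: AWQ checkpoints may report their quant_method as "awq" or the legacy "GEMM", and the comparison is case-insensitive, so both spellings now route to the Marlin-backed implementation. A minimal, self-contained sketch of that matching logic (is_awq_method is a hypothetical helper for illustration, not part of the repo):

    // Sketch: case-insensitive quant-method matching, as used in linear.cpp.
    #include <boost/algorithm/string/predicate.hpp>
    #include <iostream>
    #include <string>

    bool is_awq_method(const std::string& quant_method) {
      return boost::iequals(quant_method, "awq") ||
             boost::iequals(quant_method, "GEMM");
    }

    int main() {
      std::cout << is_awq_method("AWQ") << is_awq_method("gemm")
                << is_awq_method("gptq") << '\n';  // prints 110
    }
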
2 changes: 2 additions & 0 deletions src/model_loader/args_overrider.cpp
@@ -57,6 +57,7 @@ DEFINE_string(bits, "", "number of bits for quantization");
 DEFINE_string(group_size, "", "group size for quantization");
 DEFINE_string(desc_act, "", "desc_act for quantization");
 DEFINE_string(is_sym, "", "is_sym for quantization");
+DEFINE_string(zero_point, "", "zero_point for quantization");

 // define gflags for all tokenizer args defined in
 DEFINE_string(tokenizer_type,
@@ -189,6 +190,7 @@ void override_args_from_gflag(ModelArgs& args,
   OVERRIDE_ARG_FROM_GFLAG(quant_args, group_size);
   OVERRIDE_ARG_FROM_GFLAG(quant_args, desc_act);
   OVERRIDE_ARG_FROM_GFLAG(quant_args, is_sym);
+  OVERRIDE_ARG_FROM_GFLAG(quant_args, zero_point);

   // override tokenizer args from gflag
   OVERRIDE_ARG_FROM_GFLAG(tokenizer_args, tokenizer_type);
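
The new zero_point flag follows the pattern of the other quant flags: it is declared as a string so that an empty value can mean "not set". A hypothetical sketch of what OVERRIDE_ARG_FROM_GFLAG plausibly does for this flag (the real macro is defined elsewhere in the repo and may differ; QuantArgsSketch stands in for QuantArgs):

    #include <string>

    struct QuantArgsSketch {
      bool zero_point_value = false;
      bool& zero_point() { return zero_point_value; }
    };

    // Only override when the flag was explicitly set (non-empty string).
    void override_zero_point(QuantArgsSketch& quant_args,
                             const std::string& flags_zero_point) {
      if (!flags_zero_point.empty()) {
        quant_args.zero_point() =
            (flags_zero_point == "true" || flags_zero_point == "1");
      }
    }
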
1 change: 1 addition & 0 deletions src/model_loader/args_overrider.h
@@ -39,6 +39,7 @@ DECLARE_string(bits);
 DECLARE_string(group_size);
 DECLARE_string(desc_act);
 DECLARE_string(is_sym);
+DECLARE_string(zero_point);

 // tokenizer flags
 DECLARE_string(tokenizer_type);
6 changes: 6 additions & 0 deletions src/model_loader/model_loader.cpp
@@ -148,6 +148,9 @@ bool HFModelLoader::load_model_args(const std::string& model_weights_path) {
     if (auto v = reader.value<bool>("quantization_config.sym")) {
       quant_args_.is_sym() = v.value();
     }
+    if (auto v = reader.value<bool>("quantization_config.zero_point")) {
+      quant_args_.zero_point() = v.value();
+    }
   }

   // load quantization args for awq if exists
@@ -188,6 +191,9 @@ bool HFModelLoader::load_model_args(const std::string& model_weights_path) {
     if (auto v = gptq_reader.value<bool>("sym")) {
       quant_args_.is_sym() = v.value();
     }
+    if (auto v = gptq_reader.value<bool>("zero_point")) {
+      quant_args_.zero_point() = v.value();
+    }
   }

   // load tokenizer args from tokenizer_config.json if exists
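
These reads target the quantization_config block that AWQ exporters such as AutoAWQ write into the model's config.json; the second hunk picks the same zero_point key out of the separate quantize-config reader (gptq_reader) when that file is present. An illustrative example of the config.json fields involved (the exact field set varies by exporter):

    "quantization_config": {
      "quant_method": "awq",
      "bits": 4,
      "group_size": 128,
      "zero_point": true,
      "version": "gemm"
    }
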
17 changes: 9 additions & 8 deletions src/quantization/qlinear_awq_marlin_impl.cpp
@@ -19,16 +19,16 @@ int64_t round_up(int64_t num, int64_t multiple) {
 }

 void check_awq_quant_args(const QuantArgs& quant_args) {
-  CHECK(quant_args.is_sym())
-      << "Only symmetric quantization is supported for GPTQ";
+  CHECK(quant_args.zero_point() && !quant_args.is_sym())
+      << "Only zero_point is supported for AWQ";

   const auto bits = quant_args.bits();
-  CHECK(bits == 4 || bits == 8) << "Only 4 and 8 bits are supported for GPTQ";
+  CHECK(bits == 4 || bits == 8) << "Only 4 and 8 bits are supported for AWQ";

   const auto group_size = quant_args.group_size();
   CHECK(group_size == -1 || group_size == 32 || group_size == 64 ||
         group_size == 128)
-      << "Only group_size of -1, 32, 64, 128 are supported for GPTQ";
+      << "Only group_size of -1, 32, 64, 128 are supported for AWQ";
 }

 const std::vector<int64_t> kScalesPerm = {
@@ -104,10 +104,12 @@ void repack_weight(torch::Tensor& qweight,
   auto marlin_qweights = torch::empty(
       {qweight.size(0) / 16, qweight.size(1) * 16}, qweight.options());
   marlin::awq_repack(qweight, marlin_qweights, num_bits);
+  // (k, n/pack_factor) -> (k/16, n*16/pack_factor)
   qweight.set_data(marlin_qweights);

   // permute and repack qzeros
   auto marlin_qzeros = repack_qzeros(qzeros, num_bits);
+  // (n_groups, n/pack_factor) -> (n_groups, n/pack_factor)
   qzeros.set_data(marlin_qzeros);

   // permute scales
@@ -118,6 +120,7 @@ void repack_weight(torch::Tensor& qweight,
       scales.reshape({-1, perm_len})
           .index_select(/*dim=*/1, torch::tensor(scale_perm, scales.device()));
   marlin_scales = marlin_scales.reshape(scales.sizes()).contiguous();
+  // (n_groups, n) -> (n_groups, n)
   scales.set_data(marlin_scales);
 }

@@ -231,8 +234,7 @@ torch::Tensor ColumnParallelQLinearAWQMarlinImpl::forward(torch::Tensor input) {
     weight_repacked_ = true;
   }

-  auto output =
-      torch::empty({input.size(0), qweight_.size(1)}, input.options());
+  auto output = torch::empty({input.size(0), scales_.size(1)}, input.options());
   marlin::gptq_gemm(input,
                     qweight_,
                     output,
@@ -338,8 +340,7 @@ torch::Tensor RowParallelQLinearAWQMarlinImpl::forward(torch::Tensor input) {
     input = scatter_to_model_parallel_region(input, parallel_args_);
   }

-  auto output =
-      torch::empty({input.size(0), qweight_.size(1)}, input.options());
+  auto output = torch::empty({input.size(0), scales_.size(1)}, input.options());
   marlin::gptq_gemm(input,
                     qweight_,
                     output,
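
The output-shape change in both forward() methods is the "fix" in the commit title: per the shape comments added above, the repacked qweight_ no longer has the output feature count n as its second dimension, while scales_ (shape (n_groups, n)) still does. A back-of-envelope sketch of the arithmetic, assuming weights are packed into int32 so pack_factor = 32 / bits:

    // Shape arithmetic sketch for 4-bit AWQ weights packed into int32.
    #include <cassert>
    #include <cstdint>

    int main() {
      const int64_t n = 11008;                // example output feature count
      const int64_t bits = 4;
      const int64_t pack_factor = 32 / bits;  // eight 4-bit values per int32

      const int64_t cols_before = n / pack_factor;      // AWQ: (k, 1376)
      const int64_t cols_after = n * 16 / pack_factor;  // Marlin: (k/16, 22016)

      // Neither packed width equals n, so sizing the output from
      // qweight_.size(1) was wrong; scales_.size(1) == n is correct.
      assert(cols_before != n && cols_after != n);
      return 0;
    }
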
4 changes: 4 additions & 0 deletions src/quantization/quant_args.h
@@ -22,6 +22,9 @@ struct QuantArgs {
   // whether the input is symmetric
   DEFINE_ARG(bool, is_sym) = false;

+  // whether has zero point
+  DEFINE_ARG(bool, zero_point) = false;
+
   // check if weights can be fused
   bool can_be_fused() const {
     // can't fuse quantized weights if desc_act is true
@@ -36,6 +39,7 @@ inline std::ostream& operator<<(std::ostream& os, const QuantArgs& args) {
   os << ", group_size: " << args.group_size();
   os << ", desc_act: " << args.desc_act();
   os << ", is_sym: " << args.is_sym();
+  os << ", zero_point: " << args.zero_point();
   os << "]";
   return os;
 }
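
For completeness: DEFINE_ARG(bool, zero_point) = false; generates an accessor that returns a mutable reference on non-const objects, which is why the loader can write quant_args_.zero_point() = v.value(). A usage sketch, assuming glog for logging (include paths and namespace qualification are approximate):

    #include <glog/logging.h>
    #include "quantization/quant_args.h"

    int main(int argc, char** argv) {
      google::InitGoogleLogging(argv[0]);
      QuantArgs quant_args;
      quant_args.zero_point() = true;  // mutable-reference setter
      // streams "[..., zero_point: 1]" via the operator<< defined above
      LOG(INFO) << quant_args;
      return 0;
    }
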
