Commit

wip
nihui committed Oct 15, 2024
1 parent 5deab1a commit 0f0f310
Showing 3 changed files with 152 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/developer-guide/operators.md
@@ -1277,6 +1277,7 @@ y = affine(out)
| 4 | vdim | int | embed_dim | |
| 5 | attn_mask | int | 0 | |
| 6 | scale | float | 1.f / sqrt(embed_dim / num_heads) | |
| 18 | int8_scale_term | int | 0 | |

| weight | type | shape |
| ------------- | ----- | --------------------- |
@@ -1288,6 +1289,10 @@ y = affine(out)
| v_bias_data | float | [embed_dim] |
| out_weight_data| float/fp16/int8 | [qdim * embed_dim] |
| out_bias_data | float | [qdim] |
| q_weight_data_int8_scales | float | [embed_dim] |
| k_weight_data_int8_scales | float | [embed_dim] |
| v_weight_data_int8_scales | float | [embed_dim] |
| out_weight_data_int8_scales | float | [1] |

# MVN
```
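Editor's note: the `*_weight_data_int8_scales` blobs documented above carry one symmetric scale per output row of the q/k/v projection weights (`embed_dim` rows each), while the output projection keeps a single per-tensor scale. The sketch below distills the per-row scale math used by the quantizer change further down in this commit; the helper name and data layout are illustrative, not ncnn API.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Per-row symmetric int8 quantization: scale[i] = 127 / absmax(row i),
// so inference can recover W[i][j] ~= W_int8[i][j] / scale[i].
static void quantize_rows_int8(const std::vector<float>& w, int rows, int cols,
                               std::vector<float>& scales, std::vector<int8_t>& wq)
{
    scales.resize(rows);
    wq.resize((size_t)rows * cols);
    for (int i = 0; i < rows; i++)
    {
        float absmax = 0.f;
        for (int j = 0; j < cols; j++)
            absmax = std::max(absmax, std::fabs(w[(size_t)i * cols + j]));

        scales[i] = absmax == 0.f ? 1.f : 127 / absmax;

        for (int j = 0; j < cols; j++)
        {
            int v = (int)std::lround(w[(size_t)i * cols + j] * scales[i]);
            wq[(size_t)i * cols + j] = (int8_t)std::min(127, std::max(-127, v));
        }
    }
}
```

For the q weight, rows = embed_dim and cols = qdim; for k and v, cols are kdim and vdim. The out projection is treated as a single row spanning all qdim * embed_dim weights, hence its `[1]` scale shape.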
14 changes: 14 additions & 0 deletions tools/modelwriter.h
@@ -2038,6 +2038,7 @@ int ModelWriter::save(const char* parampath, const char* binpath)
fprintf_param_value(" 4=%d", vdim)
fprintf_param_value(" 5=%d", attn_mask)
fprintf_param_value(" 6=%e", scale)
fprintf_param_value(" 18=%d", int8_scale_term)

fwrite_weight_tag_data(op->q_weight_data, bp);
fwrite_weight_data(op->q_bias_data, bp);
@@ -2047,6 +2048,19 @@ int ModelWriter::save(const char* parampath, const char* binpath)
fwrite_weight_data(op->v_bias_data, bp);
fwrite_weight_tag_data(op->out_weight_data, bp);
fwrite_weight_data(op->out_bias_data, bp);

#if NCNN_INT8
// write int8_scale data
if (op->int8_scale_term)
{
fwrite_weight_data(op->q_weight_data_int8_scales, bp, 90, 100);
fwrite_weight_data(op->k_weight_data_int8_scales, bp, 90, 100);
fwrite_weight_data(op->v_weight_data_int8_scales, bp, 90, 100);
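// the output projection has a single per-tensor scale; wrap it in a 1-element Mat so it can be written like the other scale blobs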
ncnn::Mat out_weight_data_int8_scales(1);
out_weight_data_int8_scales[0] = op->out_weight_data_int8_scale;
fwrite_weight_data(out_weight_data_int8_scales, bp, 90, 100);
}
#endif // NCNN_INT8
}
else if (layer->type == "MVN")
{
133 changes: 133 additions & 0 deletions tools/quantize/ncnn2int8.cpp
@@ -135,6 +135,7 @@ class NetQuantize : public ModelWriter

int quantize_embed();
int quantize_gemm();
int quantize_multiheadattention();

int fuse_requantize();
};
@@ -721,6 +722,137 @@ int NetQuantize::quantize_gemm()
return 0;
}

int NetQuantize::quantize_multiheadattention()
{
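// weight-only int8: one scale per output row for the q/k/v projection weights, a single per-tensor scale for the out projection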
for (size_t i = 0; i < layers.size(); i++)
{
if (layers[i]->type != "MultiHeadAttention")
continue;

// MultiHeadAttention - quantize weight from fp32 to int8
ncnn::MultiHeadAttention* mha = (ncnn::MultiHeadAttention*)layers[i];

fprintf(stderr, "quantize_multiheadattention %s\n", mha->name.c_str());

// TODO move to ncnn2table

const int qdim = mha->weight_data_size / mha->embed_dim;
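// (qdim above is inferred from the stored q weight size: weight_data_size = qdim * embed_dim)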

{
mha->q_weight_data_int8_scales.create(mha->embed_dim);
for (int i = 0; i < mha->embed_dim; i++)
{
float absmax = 0.f;

const float* ptr = (const float*)mha->q_weight_data + i * qdim;
for (int j = 0; j < qdim; j++)
{
absmax = std::max(absmax, (float)fabs(ptr[j]));
}

mha->q_weight_data_int8_scales[i] = absmax == 0.f ? 1.f : 127 / absmax;
}
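// e.g. a row [0.5f, -2.f, 1.f] has absmax = 2.f and scale = 63.5f, so quantize_to_int8
// below maps it to [32, -127, 64] (illustrative numbers, not from a real model)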

ncnn::Mat q_weight_data = mha->q_weight_data.reshape(qdim, mha->embed_dim);
ncnn::Mat q_weight_data_int8;

ncnn::Option opt_q = opt;
opt_q.blob_allocator = q_weight_data.allocator;
opt_q.use_packing_layout = false;
ncnn::quantize_to_int8(q_weight_data, q_weight_data_int8, mha->q_weight_data_int8_scales, opt_q);
if (q_weight_data_int8.empty())
return -100;

mha->q_weight_data = q_weight_data_int8.reshape(qdim * mha->embed_dim);
}

{
mha->k_weight_data_int8_scales.create(mha->embed_dim);
for (int i = 0; i < mha->embed_dim; i++)
{
float absmax = 0.f;

const float* ptr = (const float*)mha->k_weight_data + i * mha->kdim;
for (int j = 0; j < mha->kdim; j++)
{
absmax = std::max(absmax, (float)fabs(ptr[j]));
}

mha->k_weight_data_int8_scales[i] = absmax == 0.f ? 1.f : 127 / absmax;
}

ncnn::Mat k_weight_data = mha->k_weight_data.reshape(mha->kdim, mha->embed_dim);
ncnn::Mat k_weight_data_int8;

ncnn::Option opt_q = opt;
opt_q.blob_allocator = k_weight_data.allocator;
opt_q.use_packing_layout = false;
ncnn::quantize_to_int8(k_weight_data, k_weight_data_int8, mha->k_weight_data_int8_scales, opt_q);
if (k_weight_data_int8.empty())
return -100;

mha->k_weight_data = k_weight_data_int8.reshape(mha->kdim * mha->embed_dim);
}

{
mha->v_weight_data_int8_scales.create(mha->embed_dim);
for (int i = 0; i < mha->embed_dim; i++)
{
float absmax = 0.f;

const float* ptr = (const float*)mha->v_weight_data + i * mha->vdim;
for (int j = 0; j < mha->vdim; j++)
{
absmax = std::max(absmax, (float)fabs(ptr[j]));
}

mha->v_weight_data_int8_scales[i] = absmax == 0.f ? 1.f : 127 / absmax;
}

ncnn::Mat v_weight_data = mha->v_weight_data.reshape(mha->vdim, mha->embed_dim);
ncnn::Mat v_weight_data_int8;

ncnn::Option opt_q = opt;
opt_q.blob_allocator = v_weight_data.allocator;
opt_q.use_packing_layout = false;
ncnn::quantize_to_int8(v_weight_data, v_weight_data_int8, mha->v_weight_data_int8_scales, opt_q);
if (v_weight_data_int8.empty())
return -100;

mha->v_weight_data = v_weight_data_int8.reshape(mha->vdim * mha->embed_dim);
}

{
const float* ptr = mha->out_weight_data;
float absmax = 0.f;
for (int j = 0; j < mha->out_weight_data.w; j++)
{
absmax = std::max(absmax, (float)fabs(ptr[j]));
}

mha->out_weight_data_int8_scale = absmax == 0.f ? 1.f : 127 / absmax;

ncnn::Mat out_weight_data_int8_scales(1);
out_weight_data_int8_scales[0] = mha->out_weight_data_int8_scale;

ncnn::Mat out_weight_data_int8;

ncnn::Option opt_q = opt;
opt_q.blob_allocator = mha->out_weight_data.allocator;
opt_q.use_packing_layout = false;
ncnn::quantize_to_int8(mha->out_weight_data, out_weight_data_int8, out_weight_data_int8_scales, opt_q);
if (out_weight_data_int8.empty())
return -100;

mha->out_weight_data = out_weight_data_int8;
}

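// flag the layer as int8-quantized so ModelWriter::save emits 18=2 and the int8 scale blobs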
mha->int8_scale_term = 2;
}

return 0;
}

int NetQuantize::fuse_requantize()
{
const size_t layer_count = layers.size();
@@ -970,6 +1102,7 @@ int main(int argc, char** argv)
quantizer.quantize_gru();
quantizer.quantize_embed();
quantizer.quantize_gemm();
quantizer.quantize_multiheadattention();

quantizer.fuse_requantize();

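Editor's note: with this commit the regular ncnn2int8 flow also quantizes MultiHeadAttention weights; `quantize_multiheadattention()` is called unconditionally from `main()` alongside the existing GRU/Embed/Gemm passes shown above, so an invocation such as `ncnn2int8 model.param model.bin model-int8.param model-int8.bin calibration.table` (argument order illustrative) needs no new flag.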
