Skip to content

Commit

Permalink
changed post-quant methods (PaddlePaddle#713)
Browse files Browse the repository at this point in the history
  • Loading branch information
XGZhang11 committed Apr 16, 2021
1 parent 8fad8d4 commit e7a02b5
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 10 deletions.
6 changes: 3 additions & 3 deletions demo/quant/quant_post/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ python quant_post_static.py --model_path ./inference_model/MobileNet --save_path

运行以上命令后,可在``${save_path}``下看到量化后的模型文件和参数文件。

> 使用的量化算法为``'KL'``, 使用训练集中的160张图片进行量化参数的校正
> 使用的量化算法为``'hist'``, 使用训练集中的32张图片进行量化参数的校正

### 测试精度
Expand All @@ -67,6 +67,6 @@ python eval.py --model_path ./quant_model_train/MobileNet --model_name __model__

精度输出为
```
top1_acc/top5_acc= [0.70141864 0.89086477]
top1_acc/top5_acc= [0.70328485 0.89183184]
```
从以上精度对比可以看出,对``mobilenet````imagenet``上的分类模型进行离线量化后 ``top1``精度损失为``0.77%````top5``精度损失为``0.46%``.
从以上精度对比可以看出,对``mobilenet````imagenet``上的分类模型进行离线量化后 ``top1``精度损失为``0.59%````top5``精度损失为``0.36%``.
10 changes: 7 additions & 3 deletions demo/quant/quant_post/quant_post.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 16, "Minibatch size.")
add_arg('batch_num', int, 10, "Batch number")
add_arg('batch_size', int, 32, "Minibatch size.")
add_arg('batch_num', int, 1, "Batch number")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model_path', str, "./inference_model/MobileNet/", "model dir")
add_arg('save_path', str, "./quant_model/MobileNet/", "model dir to save quanted model")
add_arg('model_filename', str, None, "model file name")
add_arg('params_filename', str, None, "params file name")
add_arg('algo', str, 'hist', "calibration algorithm")
add_arg('hist_percent', float, 0.9999, "The percentile of algo:hist")
# yapf: enable


Expand All @@ -46,7 +48,9 @@ def quantize(args):
model_filename=args.model_filename,
params_filename=args.params_filename,
batch_size=args.batch_size,
batch_nums=args.batch_num)
batch_nums=args.batch_num,
algo=args.algo,
hist_percent=args.hist_percent)


def main():
Expand Down
18 changes: 14 additions & 4 deletions paddleslim/quant/quanter.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,9 @@ def quant_post_static(
batch_size=16,
batch_nums=None,
scope=None,
algo='KL',
algo='hist',
hist_percent=0.9999,
bias_correction=False,
quantizable_op_type=["conv2d", "depthwise_conv2d", "mul"],
is_full_quantize=False,
weight_bits=8,
Expand Down Expand Up @@ -358,9 +360,15 @@ def quant_post_static(
generated by sample_generator as calibrate data.
scope(paddle.static.Scope, optional): The scope to run program, use it to load
and save variables. If scope is None, will use paddle.static.global_scope().
algo(str, optional): If algo=KL, use KL-divergenc method to
get the more precise scale factor. If algo='direct', use
abs_max method to get the scale factor. Default: 'KL'.
algo(str, optional): If algo='KL', use KL-divergenc method to
get the scale factor. If algo='hist', use the hist_percent of histogram
to get the scale factor. If algo='mse', search for the best scale factor which
makes the mse loss minimal. Use one batch of data for mse is enough. If
algo='avg', use the average of abs_max values to get the scale factor. If
algo='abs_max', use abs_max method to get the scale factor. Default: 'hist'.
hist_percent(float, optional): The percentile of histogram for algo hist.Default:0.9999.
bias_correction(bool, optional): Bias correction method of https://arxiv.org/abs/1810.05723.
Default: False.
quantizable_op_type(list[str], optional): The list of op types
that will be quantized. Default: ["conv2d", "depthwise_conv2d",
"mul"].
Expand Down Expand Up @@ -397,6 +405,8 @@ def quant_post_static(
batch_nums=batch_nums,
scope=scope,
algo=algo,
hist_percent=hist_percent,
bias_correction=bias_correction,
quantizable_op_type=quantizable_op_type,
is_full_quantize=is_full_quantize,
weight_bits=weight_bits,
Expand Down

0 comments on commit e7a02b5

Please sign in to comment.