Unstructured pruning (PaddlePaddle#710)
minghaoBD authored Apr 14, 2021
1 parent 065f644 commit 8fad8d4
Showing 20 changed files with 1,491 additions and 0 deletions.
106 changes: 106 additions & 0 deletions demo/dygraph/unstructured_pruning/README.md
@@ -0,0 +1,106 @@
# Unstructured Sparsity -- Dygraph Pruning (supports both threshold and ratio modes)

## Introduction

In model compression, the two common forms of sparsity are structured and unstructured. The former sparsifies along a specific dimension (feature channels, convolution kernels, etc.), while the latter treats every single parameter as the unit of sparsification and therefore relies more heavily on hardware acceleration of sparse matrix operations. This directory contains the unstructured sparsity algorithm developed on PaddlePaddle and PaddleSlim. In a sparsification experiment with MobileNetV1 on ImageNet, it reaches a pruning ratio of 55.19% with no loss in accuracy.

## Requirements
```bash
python3.5+
paddlepaddle>=2.0.0
paddleslim>=2.1.0
```

Please install [paddlepaddle](https://github.com/PaddlePaddle/Paddle) and [paddleslim](https://github.com/PaddlePaddle/PaddleSlim) by following the instructions on GitHub.

## Usage

Before training:
- After downloading the training data, you can adapt ../imagenet_reader.py and call it from train.py / evaluate.py.
- Developers can define their own unstructured sparsity strategy by overriding UnstructuredPruner.mask_parameters() and UnstructuredPruner.update_threshold() in paddleslim.dygraph.prune.unstructured_pruner.py (the current strategy prunes the parameters with the smallest absolute values).
- Developers can also pass a custom skip_params_func when initializing UnstructuredPruner to specify which parameters should not be pruned. By default, the parameters of all normalization layers are skipped. The default implementation (path: paddleslim.dygraph.prune.unstructured_pruner._get_skip_params()) is shown below, followed by a sketch of passing a custom function.

```python
def _get_skip_params(model):
    """
    This function is used to check whether the given model's layers are valid to be pruned.
    Usually, the convolutions are to be pruned while we skip the normalization-related parameters.
    Developers could replace this function by passing their own when initializing the UnstructuredPruner instance.
    Args:
      - model(Paddle.nn.Layer): the current model waiting to be checked.
    Return:
      - skip_params(set<String>): a set of parameters' names.
    """
    skip_params = set()
    for _, sub_layer in model.named_sublayers():
        if type(sub_layer).__name__.split('.')[-1] in paddle.nn.norm.__all__:
            skip_params.add(sub_layer.full_name())
    return skip_params
```
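
For reference, here is a minimal sketch of supplying your own function. It reuses the default normalization-skipping rule and additionally keeps one extra layer dense; the layer name `conv2d_0` is purely hypothetical, and the call assumes the `skip_params_func` keyword argument described above.

```python
import paddle
from paddle.vision.models import mobilenet_v1
from paddleslim.dygraph.prune.unstructured_pruner import UnstructuredPruner


def my_skip_params(model):
    skip_params = set()
    for _, sub_layer in model.named_sublayers():
        # Default rule: never prune normalization layers.
        if type(sub_layer).__name__.split('.')[-1] in paddle.nn.norm.__all__:
            skip_params.add(sub_layer.full_name())
        # Extra, purely illustrative rule: keep a hypothetical layer named 'conv2d_0' dense.
        if sub_layer.full_name() == 'conv2d_0':
            skip_params.add(sub_layer.full_name())
    return skip_params


model = mobilenet_v1(num_classes=1000, pretrained=False)
pruner = UnstructuredPruner(
    model, mode='ratio', ratio=0.5, skip_params_func=my_skip_params)
```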

Training:
```bash
python3 train.py --data cifar10 --lr 0.1 --pruning_mode ratio --ratio=0.5
```

Evaluation:
```bash
python3 evaluate.py --pruned_model models/model-pruned.pdparams --data cifar10
```

Example of training with pruning:
```python
model = mobilenet_v1(num_classes=class_dim, pretrained=True)
# STEP 1: initialize the pruner
pruner = UnstructuredPruner(model, mode='ratio', ratio=0.5)

for epoch in range(epochs):
    for batch_id, data in enumerate(train_loader):
        loss = calculate_loss()
        loss.backward()
        opt.step()
        opt.clear_grad()
        # STEP 2: update the pruner's threshold given the updated parameters
        pruner.step()

    if epoch % args.test_period == 0:
        # STEP 3: before evaluating during training, zero out the values that
        # opt.step() re-introduced at positions whose cached masks are zero
        pruner.update_params()
        eval(epoch)

    if epoch % args.model_period == 0:
        # STEP 4: same purpose as STEP 3
        pruner.update_params()
        paddle.save(model.state_dict(), "model-pruned.pdparams")
        paddle.save(opt.state_dict(), "opt-pruned.pdopt")
```
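
The snippet above uses the ratio mode. For the threshold mode listed in the results table, only the pruner construction changes; the following line is a sketch that assumes a `threshold` keyword argument (0.01 is an illustrative value, the table below uses 0.05 and 0.075 for YOLOv3):

```python
# Sketch: in threshold mode, parameters whose absolute value falls below the
# threshold are pruned; the rest of the training loop stays the same.
pruner = UnstructuredPruner(model, mode='threshold', threshold=0.01)
```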

Example of evaluating the pruned model:
```python
model = mobilenet_v1(num_classes=class_dim, pretrained=True)
model.set_state_dict(paddle.load("model-pruned.pdparams"))
# Note: total_sparse is a static method, so it can be called without creating an
# UnstructuredPruner instance, which is convenient for evaluation-only code.
print(UnstructuredPruner.total_sparse(model))
test()
```

For more options, refer to the shell scripts or run:
```bash
python3 train.py --help
python3 evaluate.py --help
```

## Experimental Results (validation with the dygraph code has just started; the results below were obtained with the static-graph code)

| Model | Dataset | Method | Compression Ratio | Top-1/Top-5 Acc | lr | threshold | epoch |
|:--:|:---:|:--:|:--:|:--:|:--:|:--:|:--:|
| MobileNetV1 | ImageNet | Baseline | - | 70.99%/89.68% | - | - | - |
| MobileNetV1 | ImageNet | ratio | -55.19% | 70.87%/89.80% (-0.12%/+0.12%) | 0.005 | - | 68 |
| YOLO v3 | VOC | Baseline | - | 76.24% | - | - | - |
| YOLO v3 | VOC | threshold | -41.35% | 75.29% (-0.95%) | 0.005 | 0.05 | 100k |
| YOLO v3 | VOC | threshold | -53.00% | 75.00% (-1.24%) | 0.005 | 0.075 | 100k |

## TODO

- [ ] Finish the experiments to validate the performance under dygraph mode and obtain the compressed models.
- [ ] Add more criteria for measuring parameter importance (currently only the absolute value is used).
111 changes: 111 additions & 0 deletions demo/dygraph/unstructured_pruning/evaluate.py
@@ -0,0 +1,111 @@
import paddle
import os
import sys
import argparse
import numpy as np
sys.path.append(
    os.path.join(os.path.dirname("__file__"), os.path.pardir, os.path.pardir))
from paddleslim.dygraph.prune.unstructured_pruner import UnstructuredPruner
from utility import add_arguments, print_arguments
import paddle.vision.transforms as T
import paddle.nn.functional as F
import functools
from paddle.vision.models import mobilenet_v1
import time
import logging
from paddleslim.common import get_logger

_logger = get_logger(__name__, level=logging.INFO)

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# yapf: disable
add_arg('batch_size', int, 64, "Minibatch size.")
add_arg('use_gpu', bool, True, "Whether to use GPU or not.")
add_arg('model', str, "MobileNet", "The target model.")
add_arg('pruned_model', str, "dymodels/model-pruned.pdparams", "Whether to use pretrained model.")
add_arg('data', str, "cifar10", "Which data to use. 'cifar10' or 'imagenet'.")
add_arg('log_period', int, 100, "Log period in batches.")
# yapf: enable


def compress(args):
    test_reader = None
    if args.data == "imagenet":
        import imagenet_reader as reader
        val_dataset = reader.ImageNetDataset(data_dir='/data', mode='val')
        class_dim = 1000
    elif args.data == "cifar10":
        normalize = T.Normalize(
            mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], data_format='CHW')
        transform = T.Compose([T.Transpose(), normalize])
        val_dataset = paddle.vision.datasets.Cifar10(
            mode='test', backend='cv2', transform=transform)
        class_dim = 10
    else:
        raise ValueError("{} is not supported.".format(args.data))

    places = paddle.static.cuda_places(
    ) if args.use_gpu else paddle.static.cpu_places()
    batch_size_per_card = int(args.batch_size / len(places))
    valid_loader = paddle.io.DataLoader(
        val_dataset,
        places=places,
        drop_last=False,
        return_list=True,
        batch_size=batch_size_per_card,
        shuffle=False,
        use_shared_memory=True)

    # model definition
    model = mobilenet_v1(num_classes=class_dim, pretrained=True)

    def test(epoch):
        model.eval()
        acc_top1_ns = []
        acc_top5_ns = []
        for batch_id, data in enumerate(valid_loader):
            start_time = time.time()
            x_data = data[0]
            y_data = paddle.to_tensor(data[1])
            if args.data == 'cifar10':
                y_data = paddle.unsqueeze(y_data, 1)
            end_time = time.time()

            logits = model(x_data)
            loss = F.cross_entropy(logits, y_data)
            acc_top1 = paddle.metric.accuracy(logits, y_data, k=1)
            acc_top5 = paddle.metric.accuracy(logits, y_data, k=5)

            acc_top1_ns.append(acc_top1.numpy())
            acc_top5_ns.append(acc_top5.numpy())
            if batch_id % args.log_period == 0:
                _logger.info(
                    "Eval epoch[{}] batch[{}] - acc_top1: {}; acc_top5: {}; time: {}".
                    format(epoch, batch_id,
                           np.mean(acc_top1.numpy()),
                           np.mean(acc_top5.numpy()), end_time - start_time))
            acc_top1_ns.append(np.mean(acc_top1.numpy()))
            acc_top5_ns.append(np.mean(acc_top5.numpy()))

        _logger.info("Final eval epoch[{}] - acc_top1: {}; acc_top5: {}".format(
            epoch,
            np.mean(np.array(
                acc_top1_ns, dtype="object")),
            np.mean(np.array(
                acc_top5_ns, dtype="object"))))

    model.set_state_dict(paddle.load(args.pruned_model))
    _logger.info("The current density of the pruned model is: {}%".format(
        round(100 * UnstructuredPruner.total_sparse(model), 2)))
    test(0)


def main():
    args = parser.parse_args()
    print_arguments(args)
    compress(args)


if __name__ == '__main__':
    main()
5 changes: 5 additions & 0 deletions demo/dygraph/unstructured_pruning/evaluate_cifar10.sh
@@ -0,0 +1,5 @@
#!/bin/bash
export CUDA_VISIBLE_DEVICES=3
python3.7 evaluate.py \
--pruned_model="models/model-pruned.pdparams" \
--data="cifar10"
5 changes: 5 additions & 0 deletions demo/dygraph/unstructured_pruning/evaluate_imagenet.sh
@@ -0,0 +1,5 @@
#!/bin/bash
export CUDA_VISIBLE_DEVICES=3
python3.7 evaluate.py \
--pruned_model="models/model-pruned.pdparams" \
--data="imagenet"