Implementation with PyTorch. This implementation is based on soft-filter-pruning.
FPGM has been re-implemented in PyTorch and NNI.
Usage in PyTorch
from torch.ao.pruning._experimental.pruner import FPGMPruner
# `model` is assumed to be an nn.Module whose conv layers are named conv2d1 and conv2d2
# set network-level sparsity: all layers have a sparsity level of 30%
pruner = FPGMPruner(sparsity_level=0.3)
# set layer-level sparsity: conv2d1 keeps the network-level 30%, conv2d2 uses 50%
config = [
    {"tensor_fqn": "conv2d1.weight"},
    {"tensor_fqn": "conv2d2.weight", "sparsity_level": 0.5}
]
pruner.prepare(model, config)
pruner.enable_mask_update = True
pruner.step()
# Get the real pruned model (without zeros)
pruned_model = pruner.prune()
See source code here and official test code here.
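The pruner above selects filters by the geometric-median criterion. As a rough, dependency-free illustration of that idea (not the code behind FPGMPruner or this repository), the sketch below treats each filter of one conv layer as a flattened weight vector, sums its Euclidean distances to all other filters, and marks the filters with the smallest sums, i.e. those nearest the layer's geometric median and therefore most easily represented by the remaining filters. The function name fpgm_prune_indices and the truncation of the prune count are assumptions; math.dist needs Python 3.8 or later.

```python
import math
from typing import List, Sequence


def fpgm_prune_indices(filters: Sequence[Sequence[float]], sparsity: float) -> List[int]:
    """Return sorted indices of the filters to prune in one conv layer.

    Each filter is a flattened weight vector. A filter whose summed Euclidean
    distance to all other filters is smallest lies closest to the geometric
    median of the layer and is treated as the most redundant.
    """
    n = len(filters)
    num_prune = int(n * sparsity)  # assumption: truncate fractional counts
    dist_sums = [
        sum(math.dist(filters[i], filters[j]) for j in range(n) if j != i)
        for i in range(n)
    ]
    ranked = sorted(range(n), key=lambda i: dist_sums[i])  # closest to the median first
    return sorted(ranked[:num_prune])


filters = [[0.0, 0.0], [2.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
print(fpgm_prune_indices(filters, 0.25))  # the central filter [1, 1] has the smallest sum
```

With these four 2-D "filters" and sparsity 0.25, only index 3 (the central filter) is selected; at sparsity 0.5, indices 0 and 3 are. Real implementations work on tensors and may round the prune count differently.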
Usage in NNI
from nni.algorithms.compression.pytorch.pruning import FPGMPruner
config_list = [{
    'sparsity': 0.5,
    'op_types': ['Conv2d']
}]
pruner = FPGMPruner(model, config_list)
pruner.compress()
See explanation here.
- Requirements
- Models and log files
- Training ResNet on ImageNet
- Training ResNet on Cifar-10
- Training VGGNet on Cifar-10
- Notes
- Citation
Requirements
- Python 3.6
- PyTorch 0.3.1
- TorchVision 0.3.0
Models and log files
The trained models with log files can be found in Google Drive. Specifically:
- models for pruning ResNet on ImageNet
- models for pruning ResNet on CIFAR-10
- models for pruning VGGNet on CIFAR-10
For the pruned model without zeros, refer to this issue.
Training ResNet on ImageNet
We train each model from scratch by default. If you wish to start from pre-trained models, please use the options --use_pretrain --lr 0.01.
To run pruning training of ResNet (depth 152, 101, 50, 34, 18) on ImageNet:
python pruning_imagenet.py -a resnet152 --save_path ./snapshots/resnet152-rate-0.7 --rate_norm 1 --rate_dist 0.4 --layer_begin 0 --layer_end 462 --layer_inter 3 /path/to/Imagenet2012
python pruning_imagenet.py -a resnet101 --save_path ./snapshots/resnet101-rate-0.7 --rate_norm 1 --rate_dist 0.4 --layer_begin 0 --layer_end 309 --layer_inter 3 /path/to/Imagenet2012
python pruning_imagenet.py -a resnet50 --save_path ./snapshots/resnet50-rate-0.7 --rate_norm 1 --rate_dist 0.4 --layer_begin 0 --layer_end 156 --layer_inter 3 /path/to/Imagenet2012
python pruning_imagenet.py -a resnet34 --save_path ./snapshots/resnet34-rate-0.7 --rate_norm 1 --rate_dist 0.4 --layer_begin 0 --layer_end 105 --layer_inter 3 /path/to/Imagenet2012
python pruning_imagenet.py -a resnet18 --save_path ./snapshots/resnet18-rate-0.7 --rate_norm 1 --rate_dist 0.4 --layer_begin 0 --layer_end 57 --layer_inter 3 /path/to/Imagenet2012
Explanation:
Note 1: rate_norm = 0.9 means pruning 10% of the filters by the norm-based criterion, and rate_dist = 0.2 means pruning 20% of the filters by the distance-based criterion.
Note 2: layer_begin and layer_end are the indices of the first and last conv layers, and layer_inter selects the conv layers rather than the BN layers.
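Read together, the two notes can be sketched in plain Python. This is an interpretation under stated assumptions, not code from pruning_imagenet.py: it assumes every conv weight in model.parameters() is followed by a BN weight and a BN bias (so a stride of layer_inter = 3 visits only conv weights), and that fractional filter counts are truncated. Both helper names are hypothetical.

```python
from typing import List, Tuple


def masked_param_indices(layer_begin: int, layer_end: int, layer_inter: int) -> List[int]:
    """Positions in model.parameters() that would be pruned (layer_end inclusive).

    Assumes each conv weight is followed by a BN weight and a BN bias, so a
    stride of 3 starting at 0 visits only conv weights.
    """
    return list(range(layer_begin, layer_end + 1, layer_inter))


def filters_to_prune(num_filters: int, rate_norm: float, rate_dist: float) -> Tuple[int, int]:
    """Return (pruned by norm, pruned by distance) for one conv layer.

    rate_norm is the fraction kept by the norm-based criterion; rate_dist is the
    fraction pruned by the distance-based criterion. Counts are truncated.
    """
    by_norm = int(num_filters * (1 - rate_norm))
    by_dist = int(num_filters * rate_dist)
    return by_norm, by_dist


indices = masked_param_indices(0, 57, 3)  # the ResNet-18 command above
print(len(indices), indices[-1])
print(filters_to_prune(64, 0.9, 0.2))  # the Note 1 example on a 64-filter layer
print(filters_to_prune(64, 1, 0.4))  # the rates used in the commands above
```

For the ResNet-18 command this gives 20 indices ending at 57, which matches the 20 conv layers of torchvision's ResNet-18 (counting its three downsample convs). On a 64-filter layer, the Note 1 rates give 6 norm-based and 12 distance-based filters.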
To train a baseline ResNet without pruning (100 epochs):
python original_train.py -a resnet50 --save_dir ./snapshots/resnet50-baseline /path/to/Imagenet2012 --workers 36
To run inference with the pruned model (pruned filters kept as zeros):
sh function/inference_pruned.sh
For the pruned model without zeros, refer to this issue.
To train the ImageNet models with or without pruning, see the directory scripts.
The full script is here.
Training ResNet on Cifar-10
sh scripts/pruning_cifar10.sh
Please take care to set the hyper-parameter layer_end correctly for each depth of ResNet.
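Because layer_end points at the last conv weight, it depends on depth. A hedged rule of thumb, assuming the CIFAR ResNets here have depth 6n + 2 with depth - 1 conv layers (one stem conv plus two per residual block, parameter-free shortcuts), each followed by a BN weight and bias, and layer_begin = 0, is layer_end = layer_inter * (depth - 2). The helper below is hypothetical; check its output against the values in scripts/pruning_cifar10.sh before relying on it.

```python
def cifar_resnet_layer_end(depth: int, layer_inter: int = 3) -> int:
    """Index of the last conv weight in model.parameters() for a CIFAR ResNet.

    Assumes depth = 6n + 2, depth - 1 conv layers, each conv weight followed by
    a BN weight and bias, shortcuts without parameters, and layer_begin = 0.
    """
    if depth < 8 or (depth - 2) % 6 != 0:
        raise ValueError("expected a CIFAR ResNet depth of the form 6n + 2, e.g. 20, 32, 56, 110")
    num_convs = depth - 1  # stem conv + 3 stages * n blocks * 2 convs
    return layer_inter * (num_convs - 1)  # conv weights sit at 0, 3, 6, ...


for depth in (20, 32, 56, 110):
    print(depth, cifar_resnet_layer_end(depth))
```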
To reproduce the ablation study on CIFAR-10:
sh scripts/ablation_pruning_cifar10.sh
Training VGGNet on Cifar-10
Refer to the directory VGG_cifar.
To reproduce the previous paper, Pruning Filters for Efficient ConvNets:
sh VGG_cifar/scripts/PFEC_train_prune.sh
The script includes four functions: training the baseline, pruning from a pretrained model, pruning from scratch, and fine-tuning the pruned model.
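For contrast with the distance-based criterion, Pruning Filters for Efficient ConvNets ranks filters by magnitude: it removes the filters whose weights have the smallest sum of absolute values (L1 norm). A minimal pure-Python sketch of that ranking, with a hypothetical function name and truncating rounding assumed, not taken from the VGG_cifar scripts:

```python
from typing import List, Sequence


def l1_prune_indices(filters: Sequence[Sequence[float]], prune_ratio: float) -> List[int]:
    """Return sorted indices of the filters with the smallest L1 norms."""
    num_prune = int(len(filters) * prune_ratio)  # assumption: truncate fractional counts
    l1_norms = [sum(abs(w) for w in f) for f in filters]
    ranked = sorted(range(len(filters)), key=lambda i: l1_norms[i])  # smallest norm first
    return sorted(ranked[:num_prune])


filters = [[0.1, -0.2], [1.0, 1.0], [-0.05, 0.0], [0.5, -0.5]]
print(l1_prune_indices(filters, 0.5))
```

Here the two smallest-norm filters (indices 0 and 2) are selected, regardless of how similar they are to the other filters.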
To run our method (FPGM) on VGGNet:
sh VGG_cifar/scripts/pruning_vgg_my_method.sh
This script includes pruning from a pretrained model and pruning from scratch.
Notes
We use torchvision 0.3.0. If your torchvision version is 0.2.0, replace transforms.RandomResizedCrop with transforms.RandomSizedCrop and transforms.Resize with transforms.Scale.
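If you need code that runs under both torchvision versions, one option (a hedged sketch, not part of this repository) is to look each transform up by its 0.3.0 name and fall back to the 0.2.0 name. The helper name resolve_transform is hypothetical; it only uses hasattr and getattr, so it works with any module-like object.

```python
from typing import Any


def resolve_transform(transforms_module: Any, new_name: str, old_name: str) -> Any:
    """Return transforms_module.<new_name> if it exists, else transforms_module.<old_name>."""
    if hasattr(transforms_module, new_name):
        return getattr(transforms_module, new_name)
    return getattr(transforms_module, old_name)


# Usage with torchvision (requires torchvision to be installed):
# from torchvision import transforms
# RandomResizedCrop = resolve_transform(transforms, "RandomResizedCrop", "RandomSizedCrop")
# Resize = resolve_transform(transforms, "Resize", "Scale")
```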
We train for 100 epochs because this can improve the accuracy slightly.
We follow Facebook's ImageNet preparation process: the directory "/path/to/Imagenet2012" contains two subfolders, "train" and "val". The corresponding code is here.
Refer to the file.
Citation
@inproceedings{he2019filter,
title = {Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration},
author = {He, Yang and Liu, Ping and Wang, Ziwei and Hu, Zhilan and Yang, Yi},
booktitle = {Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
year = {2019}
}