A pytorch toolkit for structured neural network pruning automatically
完全自动化的模型剪枝工具
用户层:人人都会用的剪枝工具,仅需二行代码即可完成全自动化剪枝
中间层:提供统一接口,让开发者可以自己封装SOTA剪枝算法,不断更新工具
系统底层:自动分析网络结构并构建剪枝关系
模型类型 | 支持 | 已测试 |
---|---|---|
分类模型 | √ | AlexNet,VGG,ResNet系列等 |
检测模型 | √ | CenterNet,YOLO系列等 |
分割模型 | √ | 正在测试 |
函数名 | 剪枝算法 |
---|---|
l1_norm_pruning | Pruning Filters for Efficient ConvNets |
l2_norm_pruning | Pruning Filters for Efficient ConvNets |
fpgm_pruning | Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration |
在原始剪枝算法上,做了部分调整。此外,后续会支持更多的剪枝算法。
pip install -e ./
model可以来源于torchvision,也可以是自己在Pytorch中构建的model
import torch_pruning as pruning
from torchvision.models import resnet18
import torch
# 模型建立
model = resnet18()
flops_raw, params_raw = pruning.get_model_complexity_info(
model, (3, 224, 224), as_strings=True, print_per_layer_stat=False)
print('-[INFO] before pruning flops: ' + flops_raw)
print('-[INFO] before pruning params: ' + params_raw)
# 选择裁剪方式
mod = 'fpgm'
# 剪枝引擎建立
slim = pruning.Autoslim(model, inputs=torch.randn(
1, 3, 224, 224), compression_ratio=0.5)
if mod == 'fpgm':
config = {
'layer_compression_ratio': None,
'norm_rate': 1.0, 'prune_shortcut': 1,
'dist_type': 'l1', 'pruning_func': 'fpgm'
}
elif mod == 'l1':
config = {
'layer_compression_ratio': None,
'norm_rate': 1.0, 'prune_shortcut': 1,
'global_pruning': False, 'pruning_func': 'l1'
}
slim.base_prunging(config)
flops_new, params_new = pruning.get_model_complexity_info(
model, (3, 224, 224), as_strings=True, print_per_layer_stat=False)
print('\n-[INFO] after pruning flops: ' + flops_new)
print('-[INFO] after pruning params: ' + params_new)
python prune_resnet18_cifar10.py --mode train --round 0
python prune_resnet18_cifar10.py --mode prune --round 1 --total_epochs 60
python cifar100_prune.py --mode train --round 2 --total_epochs 10 --batch_size 512
感谢以下仓库: