
[FEATURE] add quant algo Learned Step Size Quantization #346

Merged (54 commits) on Nov 11, 2022
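This PR adds Learned Step Size Quantization (LSQ, Esser et al., ICLR 2020), which treats each quantizer's step size as a trainable parameter updated by backpropagation. A minimal per-tensor sketch of the idea (illustrative only; class and method names are assumptions, not MMRazor's actual API):

```python
import torch
import torch.nn as nn


class LSQFakeQuantize(nn.Module):
    """Illustrative per-tensor LSQ fake quantizer (not MMRazor's implementation)."""

    def __init__(self, bits: int = 8, signed: bool = True):
        super().__init__()
        if signed:
            self.qmin, self.qmax = -(2 ** (bits - 1)), 2 ** (bits - 1) - 1
        else:
            self.qmin, self.qmax = 0, 2 ** bits - 1
        # The step size is a learned parameter, updated alongside the weights.
        self.scale = nn.Parameter(torch.tensor(1.0))

    @torch.no_grad()
    def init_scale(self, x: torch.Tensor) -> None:
        # Initialization suggested in the paper: 2 * mean(|x|) / sqrt(qmax).
        self.scale.copy_(2 * x.abs().mean() / (self.qmax ** 0.5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gradient scale g balances step-size updates against weight updates.
        g = 1.0 / ((x.numel() * self.qmax) ** 0.5)
        # Value equals self.scale; its gradient is scaled by g.
        scale = self.scale * g + (self.scale * (1 - g)).detach()
        q = torch.clamp(x / scale, self.qmin, self.qmax)
        q = (q.round() - q).detach() + q  # straight-through estimator for round()
        return q * scale
```

During quantization-aware training such a fake quantizer wraps weights and activations, so the optimizer sees the quantization error while gradients still flow to both the tensor and the step size.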
34a2d51
update
Oct 9, 2022
6659a34
Fix a bug in make_divisible. (#333)
LKJacky Oct 24, 2022
a6a337b
[Fix] Fix counter mapping bug (#331)
gaoyang07 Oct 24, 2022
31052ea
[Docs]Add MMYOLO projects link (#334)
kitecats Oct 24, 2022
972fd8e
[Features]Support `MethodInputsRecorder` and `FunctionInputsRecorder`…
HIT-cwh Oct 24, 2022
55c5499
updated
Oct 24, 2022
8f57a52
retina loss & predict & tesnor DONE
Oct 25, 2022
8c7cdb3
[Feature] Add deit-base (#332)
HIT-cwh Oct 25, 2022
1e8f886
[Feature]Feature map visualization (#293)
HIT-cwh Oct 26, 2022
db32b32
[Feature] Add kd examples (#305)
HIT-cwh Oct 26, 2022
5eaa225
[Doc] add documents about pruning. (#313)
LKJacky Oct 26, 2022
86c6153
[Feature] PyTorch version of `PKD: General Distillation Framework for…
HIT-cwh Oct 26, 2022
0fb34d3
for RFC
fpshuang Oct 27, 2022
7e533cb
Customed FX initialize
fpshuang Oct 31, 2022
1180cda
add UT init
fpshuang Oct 31, 2022
d37829e
[Refactor] Refactor Mutables and Mutators (#324)
pppppM Nov 1, 2022
d90c786
[Fix] Update readme (#341)
HIT-cwh Nov 1, 2022
9c567e4
Bump version to 1.0.0rc1 (#338)
pppppM Nov 1, 2022
fe13c44
init demo
Aug 2, 2022
1f449b4
add customer_tracer
Aug 3, 2022
4c72514
add quantizer
Aug 4, 2022
45d40aa
add fake_quant, loop, config
Aug 5, 2022
fa06c01
remove CPatcher in custome_tracer
Aug 8, 2022
71ea8e8
demo_try
Aug 8, 2022
f82ec7d
init version
Aug 11, 2022
10f59a5
modified base.py
Aug 26, 2022
4f90606
pre-rebase
Sep 20, 2022
aca5023
wip of adaround series
Sep 22, 2022
619bb20
adaround experiment
Sep 29, 2022
b1bfd1e
trasfer to s2
Sep 30, 2022
b75b685
update api
Sep 30, 2022
515325a
point at sub_reconstruction
Sep 30, 2022
376034a
pre-checkout
Oct 20, 2022
0ed8622
export onnx
Oct 24, 2022
3241f29
add customtracer
Oct 26, 2022
8ac6576
fix lint
Nov 2, 2022
e3d5933
move custom tracer
Nov 2, 2022
902d9d4
fix import
Nov 3, 2022
97ccb95
TDO: UTs
fpshuang Nov 3, 2022
fe55319
merge fx Tracer & quant init
fpshuang Nov 3, 2022
1191b29
Successfully RUN
fpshuang Nov 8, 2022
f98702e
update loop
fpshuang Nov 8, 2022
04c2835
update loop docstrings
fpshuang Nov 9, 2022
a7727de
update quantizer docstrings
fpshuang Nov 9, 2022
7a8f719
update qscheme docstrings
fpshuang Nov 9, 2022
8a17020
update qobserver docstrings
fpshuang Nov 9, 2022
8745593
update tracer docstrings
fpshuang Nov 9, 2022
c73e477
update UTs init
fpshuang Nov 9, 2022
f29694b
update UTs init
fpshuang Nov 9, 2022
af12bac
fix review comments
fpshuang Nov 10, 2022
8ea1812
merge dev-1.x
fpshuang Nov 10, 2022
f5865fe
fix CI
fpshuang Nov 10, 2022
44bee5c
fix UTs
fpshuang Nov 10, 2022
9a7f4a7
update torch requirements
fpshuang Nov 10, 2022
1 change: 1 addition & 0 deletions README.md
@@ -187,6 +187,7 @@ This project is released under the [Apache 2.0 license](LICENSE).
- [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab detection toolbox and benchmark.
- [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab's next-generation platform for general 3D object detection.
- [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab rotated object detection toolbox and benchmark.
- [MMYOLO](https://github.com/open-mmlab/mmyolo): OpenMMLab YOLO series toolbox and benchmark.
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab semantic segmentation toolbox and benchmark.
- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab text detection, recognition, and understanding toolbox.
- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab pose estimation toolbox and benchmark.
1 change: 1 addition & 0 deletions README_zh-CN.md
@@ -160,6 +160,7 @@ MMRazor 是一款由来自不同高校和企业的研发人员共同参与贡献
- [MMDetection](https://github.com/open-mmlab/mmdetection): OpenMMLab 目标检测工具箱
- [MMDetection3D](https://github.com/open-mmlab/mmdetection3d): OpenMMLab 新一代通用 3D 目标检测平台
- [MMRotate](https://github.com/open-mmlab/mmrotate): OpenMMLab 旋转框检测工具箱与测试基准
- [MMYOLO](https://github.com/open-mmlab/mmyolo): OpenMMLab YOLO 系列工具箱与测试基准
- [MMSegmentation](https://github.com/open-mmlab/mmsegmentation): OpenMMLab 语义分割工具箱
- [MMOCR](https://github.com/open-mmlab/mmocr): OpenMMLab 全流程文字检测识别理解工具箱
- [MMPose](https://github.com/open-mmlab/mmpose): OpenMMLab 姿态估计工具箱
45 changes: 45 additions & 0 deletions configs/distill/mmcls/deit/README.md
@@ -0,0 +1,45 @@
# DeiT

> [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877)

<!-- [ALGORITHM] -->

## Abstract

Recently, neural networks purely based on attention were shown to address image understanding tasks such as image classification. However, these visual transformers are pre-trained with hundreds of millions of images using an expensive infrastructure, thereby limiting their adoption. In this work, we produce a competitive convolution-free transformer by training on Imagenet only. We train them on a single computer in less than 3 days. Our reference vision transformer (86M parameters) achieves top-1 accuracy of 83.1% (single-crop evaluation) on ImageNet with no external data. More importantly, we introduce a teacher-student strategy specific to transformers. It relies on a distillation token ensuring that the student learns from the teacher through attention. We show the interest of this token-based distillation, especially when using a convnet as a teacher. This leads us to report results competitive with convnets for both Imagenet (where we obtain up to 85.2% accuracy) and when transferring to other tasks. We share our code and models.

<div align=center>
<img src="https://user-images.githubusercontent.com/26739999/143225703-c287c29e-82c9-4c85-a366-dfae30d198cd.png" width="40%"/>
</div>

## Results and models

### Classification

| Dataset | Model | Teacher | Top-1 (%) | Top-5 (%) | Configs | Download |
| -------- | --------- | ----------- | --------- | --------- | ------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| ImageNet | Deit-base | RegNety-160 | 83.24 | 96.33 | [config](deit-base_regnety160_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.pth?versionId=CAEQThiBgMCFteW0oBgiIDdmMWY2NGRiOGY1YzRmZWZiOTExMzQ2NjNlMjk2Nzcz) \| [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.json?versionId=CAEQThiBgIDGos20oBgiIGVlNDgyM2M2ZTk5MzQyYjFhNTgwNGIzMjllZjg3YmZm) |

```{warning}
Before training, please install `timm` first, either from PyPI:

    pip install timm

or from source:

    git clone https://github.com/rwightman/pytorch-image-models
    cd pytorch-image-models && pip install -e .
```

## Citation

```
@InProceedings{pmlr-v139-touvron21a,
title = {Training data-efficient image transformers &amp; distillation through attention},
author = {Touvron, Hugo and Cord, Matthieu and Douze, Matthijs and Massa, Francisco and Sablayrolles, Alexandre and Jegou, Herve},
booktitle = {International Conference on Machine Learning},
pages = {10347--10357},
year = {2021},
volume = {139},
month = {July}
}
```
64 changes: 64 additions & 0 deletions configs/distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py
@@ -0,0 +1,64 @@
_base_ = ['mmcls::deit/deit-base_pt-16xb64_in1k.py']

# student settings
student = _base_.model
student.backbone.type = 'DistilledVisionTransformer'
student.head = dict(
    type='mmrazor.DeiTClsHead',
    num_classes=1000,
    in_channels=768,
    loss=dict(
        type='mmcls.LabelSmoothLoss',
        label_smooth_val=0.1,
        mode='original',
        loss_weight=0.5))

data_preprocessor = dict(
    type='mmcls.ClsDataPreprocessor', batch_augments=student.train_cfg)

# teacher settings
checkpoint_path = 'https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth'  # noqa: E501
teacher = dict(
    _scope_='mmcls',
    type='ImageClassifier',
    backbone=dict(
        type='TIMMBackbone', model_name='regnety_160', pretrained=True),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=3024,
        loss=dict(
            type='LabelSmoothLoss',
            label_smooth_val=0.1,
            mode='original',
            loss_weight=0.5),
        topk=(1, 5),
        init_cfg=dict(
            type='Pretrained', checkpoint=checkpoint_path, prefix='head.')))

model = dict(
    _scope_='mmrazor',
    _delete_=True,
    type='SingleTeacherDistill',
    architecture=student,
    teacher=teacher,
    distiller=dict(
        type='ConfigurableDistiller',
        student_recorders=dict(
            fc=dict(type='ModuleOutputs', source='head.layers.head_dist')),
        teacher_recorders=dict(
            fc=dict(type='ModuleOutputs', source='head.fc')),
        distill_losses=dict(
            loss_distill=dict(
                type='CrossEntropyLoss',
                loss_weight=0.5,
            )),
        loss_forward_mappings=dict(
            loss_distill=dict(
                preds_S=dict(from_student=True, recorder='fc'),
                preds_T=dict(from_student=False, recorder='fc')))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')
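The `ModuleOutputs` recorders above capture intermediate outputs by submodule name (the student's distilled head, the teacher's `fc`), and `loss_forward_mappings` routes them into the distillation loss. A simplified sketch of the recording mechanism using forward hooks (toy modules standing in for the configured classifiers; not MMRazor's actual recorder classes):

```python
import torch
import torch.nn as nn


def record_output(model: nn.Module, name: str, store: dict):
    """Register a forward hook that stores the output of the named submodule."""
    module = dict(model.named_modules())[name]
    return module.register_forward_hook(
        lambda mod, inputs, output: store.__setitem__(name, output))


# Toy student and teacher standing in for the configured classifiers.
student = nn.Sequential(nn.Flatten(), nn.Linear(8, 4))
teacher = nn.Sequential(nn.Flatten(), nn.Linear(8, 4))

s_store, t_store = {}, {}
record_output(student, '1', s_store)  # cf. source='head.layers.head_dist'
record_output(teacher, '1', t_store)  # cf. source='head.fc'

x = torch.randn(2, 8)
student(x)
with torch.no_grad():
    teacher(x)

# cf. loss_forward_mappings: preds_S comes from the student recorder,
# preds_T from the teacher recorder; soft-target cross-entropy as the loss.
loss = nn.functional.cross_entropy(s_store['1'], t_store['1'].softmax(dim=1))
```

The hooks fire on every forward pass, so the distiller can run student and teacher once per batch and assemble the loss purely from recorded tensors.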
34 changes: 34 additions & 0 deletions configs/distill/mmcls/deit/metafile.yml
@@ -0,0 +1,34 @@
Collections:
  - Name: DEIT
    Metadata:
      Training Data:
        - ImageNet-1k
    Paper:
      URL: https://arxiv.org/abs/2012.12877
      Title: Training data-efficient image transformers & distillation through attention
    README: configs/distill/mmcls/deit/README.md

Models:
  - Name: deit-base_regnety160_pt-16xb64_in1k
    In Collection: DEIT
    Metadata:
      Student:
        Config: mmcls::deit/deit-base_pt-16xb64_in1k.py
        Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base_pt-16xb64_in1k_20220216-db63c16c.pth
        Metrics:
          Top 1 Accuracy: 81.76
          Top 5 Accuracy: 95.81
      Teacher:
        Config: mmrazor::distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py
        Weights: https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth
        Metrics:
          Top 1 Accuracy: 82.83
          Top 5 Accuracy: 96.42
    Results:
      - Task: Classification
        Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 83.24
          Top 5 Accuracy: 96.33
    Weights: https://download.openmmlab.com/mmrazor/v1/deit/deit-base/deit-base_regnety160_pt-16xb64_in1k_20221011_113403-a67bf475.pth?versionId=CAEQThiBgMCFteW0oBgiIDdmMWY2NGRiOGY1YzRmZWZiOTExMzQ2NjNlMjk2Nzcz
    Config: configs/distill/mmcls/deit/deit-base_regnety160_pt-16xb64_in1k.py
8 changes: 5 additions & 3 deletions configs/distill/mmcls/kd/README.md
@@ -14,9 +14,11 @@ A very simple way to improve the performance of almost any machine learning algo

### Classification

| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :------: | :------: | :----------------------------------------------------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :-------------------------------------------------------: | :------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| logits | ImageNet | [resnet34](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet34_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet18_8xb32_in1k.py) | 71.54 | 73.62 | 69.90 | [config](./wsld_cls_head_resnet34_resnet18_8xb32_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/distill/wsld/wsld_cls_head_resnet34_resnet18_8xb32_in1k/wsld_cls_head_resnet34_resnet18_8xb32_in1k_acc-71.54_20211222-91f28cf6.pth?versionId=CAEQHxiBgMC6memK7xciIGMzMDFlYTA4YzhlYTRiMTNiZWU0YTVhY2I5NjVkMjY2) \| [log](https://openmmlab-share.oss-cn-hangzhou.aliyuncs.com/mmrazor/v0.1/distill/wsld/wsld_cls_head_resnet34_resnet18_8xb32_in1k/wsld_cls_head_resnet34_resnet18_8xb32_in1k_20211221_181516.log.json?versionId=CAEQHxiBgIDLmemK7xciIGNkM2FiN2Y4N2E5YjRhNDE4NDVlNmExNDczZDIxN2E5) |
| Location | Dataset | Teacher | Student | Acc | Acc(T) | Acc(S) | Config | Download |
| :------: | :------: | :-----------------------------------------------------------------------------------------------------------: | :--------------------------------------------------------------------------------------------------------------------------------: | :---: | :----: | :----: | :------------------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| logits | ImageNet | [resnet34](https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/resnet/resnet34_8xb32_in1k.py) | [resnet18](https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/resnet/resnet18_8xb32_in1k.py) | 71.81 | 73.62 | 69.90 | [config](./kd_logits_resnet34_resnet18_8xb32_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/kd/kl_r18_w3/kd_logits_resnet34_resnet18_8xb32_in1k_w3_20221011_181115-5c6a834d.pth?versionId=CAEQThiBgID1_Me0oBgiIDE3NTk3MDgxZmU2YjRlMjVhMzg1ZTQwMmRhNmYyNGU2) \| [log](https://download.openmmlab.com/mmrazor/v1/kd/kl_r18_w3/kd_logits_resnet34_resnet18_8xb32_in1k_w3_20221011_181115-5c6a834d.json?versionId=CAEQThiBgMDx_se0oBgiIDQxNTM2MWZjZGRhNjRhZDZiZTIzY2Y0NDU3NDA4ODBl) |
| logits | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/resnet/resnet50_8xb32_in1k.py) | [mobilenet-v2](https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/mobilenet_v2/mobilenet-v2_8xb32_in1k.py) | 73.56 | 76.55 | 71.86 | [config](./kd_logits_resnet50_mobilenet-v2_8xb32_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/kd/kl_mbv2_w3t1/kd_logits_resnet50_mobilenet-v2_8xb32_in1k_20221025_212407-6ea9e2a5.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/kd/kl_mbv2_w3t1/kd_logits_resnet50_mobilenet-v2_8xb32_in1k_20221025_212407-6ea9e2a5.json) |
| logits | ImageNet | [resnet50](https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/resnet/resnet50_8xb32_in1k.py) | [shufflenet-v2](https://github.com/open-mmlab/mmclassification/blob/dev-1.x/configs/shufflenet_v2/shufflenet-v2-1x_16xb64_in1k.py) | 70.87 | 76.55 | 69.55 | [config](./kd_logits_resnet50_shufflenet-v2-1x_16xb64_in1k.py) | [teacher](https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth) \|[model](https://download.openmmlab.com/mmrazor/v1/kd/kl_shuffle_w3t1/kd_logits_resnet50_shufflenet-v2-1x_16xb64_in1k_20221025_224424-5d748c1b.pth) \| [log](https://download.openmmlab.com/mmrazor/v1/kd/kl_shuffle_w3t1/kd_logits_resnet50_shufflenet-v2-1x_16xb64_in1k_20221025_224424-5d748c1b.json) |
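The `KLDivergence(tau=1, loss_weight=3)` entries in these configs correspond to standard logit distillation. A hedged sketch of such a loss (an assumed form for illustration, not necessarily MMRazor's exact code):

```python
import torch
import torch.nn.functional as F


def kd_kl_loss(preds_S: torch.Tensor, preds_T: torch.Tensor,
               tau: float = 1.0, loss_weight: float = 3.0) -> torch.Tensor:
    """KL divergence between temperature-softened teacher and student logits."""
    log_p_s = F.log_softmax(preds_S / tau, dim=1)
    p_t = F.softmax(preds_T / tau, dim=1)
    # tau**2 keeps gradient magnitudes comparable across temperatures
    # (Hinton et al., "Distilling the Knowledge in a Neural Network").
    return loss_weight * tau ** 2 * F.kl_div(log_p_s, p_t, reduction='batchmean')
```

With `tau=1` the teacher's raw softmax is the target; raising `tau` softens both distributions and transfers more of the teacher's inter-class similarity structure.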

## Citation

@@ -4,6 +4,8 @@
    'mmcls::_base_/default_runtime.py'
]

teacher_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet34_8xb32_in1k_20210831-f257d4e6.pth'  # noqa: E501

model = dict(
    _scope_='mmrazor',
    type='SingleTeacherDistill',
@@ -17,16 +19,16 @@
    architecture=dict(
        cfg_path='mmcls::resnet/resnet18_8xb32_in1k.py', pretrained=False),
    teacher=dict(
        cfg_path='mmcls::resnet/resnet34_8xb32_in1k.py', pretrained=True),
    teacher_ckpt='resnet34_8xb32_in1k_20210831-f257d4e6.pth',
        cfg_path='mmcls::resnet/resnet34_8xb32_in1k.py', pretrained=False),
    teacher_ckpt=teacher_ckpt,
    distiller=dict(
        type='ConfigurableDistiller',
        student_recorders=dict(
            fc=dict(type='ModuleOutputs', source='head.fc')),
        teacher_recorders=dict(
            fc=dict(type='ModuleOutputs', source='head.fc')),
        distill_losses=dict(
            loss_kl=dict(type='KLDivergence', tau=1, loss_weight=5)),
            loss_kl=dict(type='KLDivergence', tau=1, loss_weight=3)),
        loss_forward_mappings=dict(
            loss_kl=dict(
                preds_S=dict(from_student=True, recorder='fc'),
@@ -0,0 +1,37 @@
_base_ = ['mmcls::mobilenet_v2/mobilenet-v2_8xb32_in1k.py']

student = _base_.model

teacher_ckpt = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth'  # noqa: E501

model = dict(
    _scope_='mmrazor',
    _delete_=True,
    type='SingleTeacherDistill',
    data_preprocessor=dict(
        type='ImgDataPreprocessor',
        # RGB format normalization parameters
        mean=[123.675, 116.28, 103.53],
        std=[58.395, 57.12, 57.375],
        # convert image from BGR to RGB
        bgr_to_rgb=True),
    architecture=student,
    teacher=dict(
        cfg_path='mmcls::resnet/resnet50_8xb32_in1k.py', pretrained=False),
    teacher_ckpt=teacher_ckpt,
    distiller=dict(
        type='ConfigurableDistiller',
        student_recorders=dict(
            fc=dict(type='ModuleOutputs', source='head.fc')),
        teacher_recorders=dict(
            fc=dict(type='ModuleOutputs', source='head.fc')),
        distill_losses=dict(
            loss_kl=dict(type='KLDivergence', tau=1, loss_weight=3)),
        loss_forward_mappings=dict(
            loss_kl=dict(
                preds_S=dict(from_student=True, recorder='fc'),
                preds_T=dict(from_student=False, recorder='fc')))))

find_unused_parameters = True

val_cfg = dict(_delete_=True, type='mmrazor.SingleTeacherDistillValLoop')