-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add davit * fix mixup config * convert scripts * lint * test * test * Add checkpoint links. Co-authored-by: mzr1996 <[email protected]>
- Loading branch information
Showing
17 changed files
with
1,304 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
# dataset settings | ||
dataset_type = 'ImageNet' | ||
data_preprocessor = dict( | ||
num_classes=1000, | ||
# RGB format normalization parameters | ||
mean=[123.675, 116.28, 103.53], | ||
std=[58.395, 57.12, 57.375], | ||
# convert image from BGR to RGB | ||
to_rgb=True, | ||
) | ||
|
||
bgr_mean = data_preprocessor['mean'][::-1] | ||
bgr_std = data_preprocessor['std'][::-1] | ||
|
||
train_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict( | ||
type='RandomResizedCrop', | ||
scale=224, | ||
backend='pillow', | ||
interpolation='bicubic'), | ||
dict(type='RandomFlip', prob=0.5, direction='horizontal'), | ||
dict( | ||
type='RandAugment', | ||
policies='timm_increasing', | ||
num_policies=2, | ||
total_level=10, | ||
magnitude_level=9, | ||
magnitude_std=0.5, | ||
hparams=dict( | ||
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')), | ||
dict( | ||
type='RandomErasing', | ||
erase_prob=0.25, | ||
mode='rand', | ||
min_area_ratio=0.02, | ||
max_area_ratio=1 / 3, | ||
fill_color=bgr_mean, | ||
fill_std=bgr_std), | ||
dict(type='PackClsInputs'), | ||
] | ||
|
||
test_pipeline = [ | ||
dict(type='LoadImageFromFile'), | ||
dict( | ||
type='ResizeEdge', | ||
scale=236, | ||
edge='short', | ||
backend='pillow', | ||
interpolation='bicubic'), | ||
dict(type='CenterCrop', crop_size=224), | ||
dict(type='PackClsInputs'), | ||
] | ||
|
||
train_dataloader = dict( | ||
batch_size=64, | ||
num_workers=5, | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root='data/imagenet', | ||
ann_file='meta/train.txt', | ||
data_prefix='train', | ||
pipeline=train_pipeline), | ||
sampler=dict(type='DefaultSampler', shuffle=True), | ||
persistent_workers=True, | ||
) | ||
|
||
val_dataloader = dict( | ||
batch_size=64, | ||
num_workers=5, | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root='data/imagenet', | ||
ann_file='meta/val.txt', | ||
data_prefix='val', | ||
pipeline=test_pipeline), | ||
sampler=dict(type='DefaultSampler', shuffle=False), | ||
persistent_workers=True, | ||
) | ||
val_evaluator = dict(type='Accuracy', topk=(1, 5)) | ||
|
||
# If you want standard test, please manually configure the test dataset | ||
test_dataloader = val_dataloader | ||
test_evaluator = val_evaluator |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
model = dict( | ||
type='ImageClassifier', | ||
backbone=dict( | ||
type='DaViT', arch='base', out_indices=(3, ), drop_path_rate=0.4), | ||
neck=dict(type='GlobalAveragePooling'), | ||
head=dict( | ||
type='LinearClsHead', | ||
num_classes=1000, | ||
in_channels=1024, | ||
loss=dict( | ||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), | ||
), | ||
train_cfg=dict(augments=[ | ||
dict(type='Mixup', alpha=0.8), | ||
dict(type='CutMix', alpha=1.0) | ||
])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
model = dict( | ||
type='ImageClassifier', | ||
backbone=dict( | ||
type='DaViT', arch='small', out_indices=(3, ), drop_path_rate=0.2), | ||
neck=dict(type='GlobalAveragePooling'), | ||
head=dict( | ||
type='LinearClsHead', | ||
num_classes=1000, | ||
in_channels=768, | ||
loss=dict( | ||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), | ||
), | ||
train_cfg=dict(augments=[ | ||
dict(type='Mixup', alpha=0.8), | ||
dict(type='CutMix', alpha=1.0) | ||
])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
model = dict( | ||
type='ImageClassifier', | ||
backbone=dict( | ||
type='DaViT', arch='t', out_indices=(3, ), drop_path_rate=0.1), | ||
neck=dict(type='GlobalAveragePooling'), | ||
head=dict( | ||
type='LinearClsHead', | ||
num_classes=1000, | ||
in_channels=768, | ||
loss=dict( | ||
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'), | ||
), | ||
train_cfg=dict(augments=[ | ||
dict(type='Mixup', alpha=0.8), | ||
dict(type='CutMix', alpha=1.0) | ||
])) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
# DaViT | ||
|
||
> [DaViT: Dual Attention Vision Transformers](https://arxiv.org/abs/2204.03645v1) | ||
<!-- [ALGORITHM] --> | ||
|
||
## Abstract | ||
|
||
In this work, we introduce Dual Attention Vision Transformers (DaViT), a simple yet effective vision transformer architecture that is able to capture global context while maintaining computational efficiency. We propose approaching the problem from an orthogonal angle: exploiting self-attention mechanisms with both "spatial tokens" and "channel tokens". With spatial tokens, the spatial dimension defines the token scope, and the channel dimension defines the token feature dimension. With channel tokens, we have the inverse: the channel dimension defines the token scope, and the spatial dimension defines the token feature dimension. We further group tokens along the sequence direction for both spatial and channel tokens to maintain the linear complexity of the entire model. We show that these two self-attentions complement each other: (i) since each channel token contains an abstract representation of the entire image, the channel attention naturally captures global interactions and representations by taking all spatial positions into account when computing attention scores between channels; (ii) the spatial attention refines the local representations by performing fine-grained interactions across spatial locations, which in turn helps the global information modeling in channel attention. Extensive experiments show our DaViT achieves state-of-the-art performance on four different tasks with efficient computations. Without extra data, DaViT-Tiny, DaViT-Small, and DaViT-Base achieve 82.8%, 84.2%, and 84.6% top-1 accuracy on ImageNet-1K with 28.3M, 49.7M, and 87.9M parameters, respectively. When we further scale up DaViT with 1.5B weakly supervised image and text pairs, DaViT-Gaint reaches 90.4% top-1 accuracy on ImageNet-1K. | ||
|
||
<div align=center> | ||
<img src="https://user-images.githubusercontent.com/24734142/196125065-e232409b-f710-4729-b657-4e5f9158f2d1.png" width="90%"/> | ||
</div> | ||
|
||
## Results and models | ||
|
||
### ImageNet-1k | ||
|
||
| Model | Pretrain | resolution | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download | | ||
| :-------: | :----------: | :--------: | :-------: | :------: | :-------: | :-------: | :------------------------------------: | :----------------------------------------------------------------------------------------------: | | ||
| DaViT-T\* | From scratch | 224x224 | 28.36 | 4.54 | 82.24 | 96.13 | [config](./davit-tiny_4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/davit/davit-tiny_3rdparty_in1k_20221116-700fdf7d.pth) | | ||
| DaViT-S\* | From scratch | 224x224 | 49.74 | 8.79 | 83.61 | 96.75 | [config](./davit-small_4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/davit/davit-small_3rdparty_in1k_20221116-51a849a6.pth) | | ||
| DaViT-B\* | From scratch | 224x224 | 87.95 | 15.5 | 84.09 | 96.82 | [config](./davit-base_4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/davit/davit-base_3rdparty_in1k_20221116-19e0d956.pth) | | ||
|
||
*Models with * are converted from the [official repo](https://github.com/dingmyu/davit). The config files of these models are only for validation. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.* | ||
|
||
Note: Inference accuracy is a bit lower than paper result because of inference code for classification doesn't exist. | ||
|
||
## Citation | ||
|
||
``` | ||
@inproceedings{ding2022davit, | ||
title={DaViT: Dual Attention Vision Transformer}, | ||
author={Ding, Mingyu and Xiao, Bin and Codella, Noel and Luo, Ping and Wang, Jingdong and Yuan, Lu}, | ||
booktitle={ECCV}, | ||
year={2022}, | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
_base_ = [ | ||
'../_base_/models/davit/davit-base.py', | ||
'../_base_/datasets/imagenet_bs256_davit_224.py', | ||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py', | ||
'../_base_/default_runtime.py' | ||
] | ||
|
||
# data settings | ||
train_dataloader = dict(batch_size=256) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
_base_ = [ | ||
'../_base_/models/davit/davit-small.py', | ||
'../_base_/datasets/imagenet_bs256_davit_224.py', | ||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py', | ||
'../_base_/default_runtime.py' | ||
] | ||
|
||
# data settings | ||
train_dataloader = dict(batch_size=256) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
_base_ = [ | ||
'../_base_/models/davit/davit-tiny.py', | ||
'../_base_/datasets/imagenet_bs256_davit_224.py', | ||
'../_base_/schedules/imagenet_bs1024_adamw_swin.py', | ||
'../_base_/default_runtime.py' | ||
] | ||
|
||
# data settings | ||
train_dataloader = dict(batch_size=256) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
Collections: | ||
- Name: DaViT | ||
Metadata: | ||
Architecture: | ||
- GELU | ||
- Layer Normalization | ||
- Multi-Head Attention | ||
- Scaled Dot-Product Attention | ||
Paper: | ||
URL: https://arxiv.org/abs/2204.03645v1 | ||
Title: 'DaViT: Dual Attention Vision Transformers' | ||
README: configs/davit/README.md | ||
Code: | ||
URL: https://github.com/open-mmlab/mmclassification/blob/v1.0.0rc3/mmcls/models/backbones/davit.py | ||
Version: v1.0.0rc3 | ||
|
||
Models: | ||
- Name: davit-tiny_3rdparty_in1k | ||
In Collection: DaViT | ||
Metadata: | ||
FLOPs: 4539698688 | ||
Parameters: 28360168 | ||
Training Data: | ||
- ImageNet-1k | ||
Results: | ||
- Dataset: ImageNet-1k | ||
Task: Image Classification | ||
Metrics: | ||
Top 1 Accuracy: 82.24 | ||
Top 5 Accuracy: 96.13 | ||
Weights: https://download.openmmlab.com/mmclassification/v0/davit/davit-tiny_3rdparty_in1k_20221116-700fdf7d.pth | ||
Converted From: | ||
Weights: https://drive.google.com/file/d/1RSpi3lxKaloOL5-or20HuG975tbPwxRZ/view?usp=sharing | ||
Code: https://github.com/dingmyu/davit/blob/main/mmdet/mmdet/models/backbones/davit.py#L355 | ||
Config: configs/davit/davit-tiny_4xb256_in1k.py | ||
- Name: davit-small_3rdparty_in1k | ||
In Collection: DaViT | ||
Metadata: | ||
FLOPs: 8799942144 | ||
Parameters: 49745896 | ||
Training Data: | ||
- ImageNet-1k | ||
Results: | ||
- Dataset: ImageNet-1k | ||
Task: Image Classification | ||
Metrics: | ||
Top 1 Accuracy: 83.61 | ||
Top 5 Accuracy: 96.75 | ||
Weights: https://download.openmmlab.com/mmclassification/v0/davit/davit-small_3rdparty_in1k_20221116-51a849a6.pth | ||
Converted From: | ||
Weights: https://drive.google.com/file/d/1q976ruj45mt0RhO9oxhOo6EP_cmj4ahQ/view?usp=sharing | ||
Code: https://github.com/dingmyu/davit/blob/main/mmdet/mmdet/models/backbones/davit.py#L355 | ||
Config: configs/davit/davit-small_4xb256_in1k.py | ||
- Name: davit-base_3rdparty_in1k | ||
In Collection: DaViT | ||
Metadata: | ||
FLOPs: 15509702656 | ||
Parameters: 87954408 | ||
Training Data: | ||
- ImageNet-1k | ||
Results: | ||
- Dataset: ImageNet-1k | ||
Task: Image Classification | ||
Metrics: | ||
Top 1 Accuracy: 84.09 | ||
Top 5 Accuracy: 96.82 | ||
Weights: https://download.openmmlab.com/mmclassification/v0/davit/davit-base_3rdparty_in1k_20221116-19e0d956.pth | ||
Converted From: | ||
Weights: https://drive.google.com/file/d/1u9sDBEueB-YFuLigvcwf4b2YyA4MIVsZ/view?usp=sharing | ||
Code: https://github.com/dingmyu/davit/blob/main/mmdet/mmdet/models/backbones/davit.py#L355 | ||
Config: configs/davit/davit-base_4xb256_in1k.py |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -65,6 +65,7 @@ Backbones | |
Conformer | ||
ConvMixer | ||
ConvNeXt | ||
DaViT | ||
DeiT3 | ||
DenseNet | ||
DistilledVisionTransformer | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.