[Feature] Support DaViT. (#1105)
* add davit

* fix mixup config

* convert scripts

* lint

* test

* test

* Add checkpoint links.

Co-authored-by: mzr1996 <[email protected]>
okotaku and mzr1996 authored Nov 16, 2022
1 parent 992d13e commit c4f3883
Showing 17 changed files with 1,304 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -150,6 +150,7 @@ Results and models are available in the [model zoo](https://mmclassification.rea
- [x] [MViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mvit)
- [x] [HorNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hornet)
- [x] [MobileViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobilevit)
- [x] [DaViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/davit)

</details>

1 change: 1 addition & 0 deletions README_zh-CN.md
@@ -149,6 +149,7 @@ mim install -e .
- [x] [MViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mvit)
- [x] [HorNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/hornet)
- [x] [MobileViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/mobilevit)
- [x] [DaViT](https://github.com/open-mmlab/mmclassification/tree/1.x/configs/davit)

</details>

84 changes: 84 additions & 0 deletions configs/_base_/datasets/imagenet_bs256_davit_224.py
@@ -0,0 +1,84 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
    num_classes=1000,
    # RGB format normalization parameters
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    # convert image from BGR to RGB
    to_rgb=True,
)

bgr_mean = data_preprocessor['mean'][::-1]
bgr_std = data_preprocessor['std'][::-1]

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        scale=224,
        backend='pillow',
        interpolation='bicubic'),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(
        type='RandAugment',
        policies='timm_increasing',
        num_policies=2,
        total_level=10,
        magnitude_level=9,
        magnitude_std=0.5,
        hparams=dict(
            pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
    dict(
        type='RandomErasing',
        erase_prob=0.25,
        mode='rand',
        min_area_ratio=0.02,
        max_area_ratio=1 / 3,
        fill_color=bgr_mean,
        fill_std=bgr_std),
    dict(type='PackClsInputs'),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='ResizeEdge',
        scale=236,
        edge='short',
        backend='pillow',
        interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='PackClsInputs'),
]

train_dataloader = dict(
    batch_size=64,
    num_workers=5,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        ann_file='meta/train.txt',
        data_prefix='train',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
    persistent_workers=True,
)

val_dataloader = dict(
    batch_size=64,
    num_workers=5,
    dataset=dict(
        type=dataset_type,
        data_root='data/imagenet',
        ann_file='meta/val.txt',
        data_prefix='val',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))

# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
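
The derived constants in this dataset config can be checked by hand. The following is a minimal sketch in plain Python (no mmcls import) that mirrors how `bgr_mean` and the RandAugment `pad_val` are obtained from the RGB `mean`, and computes the crop ratio implied by `ResizeEdge(scale=236)` followed by `CenterCrop(crop_size=224)`. The names `pad_val` and `crop_ratio` are introduced here for illustration only.

```python
# RGB normalization statistics copied from data_preprocessor in the config
mean = [123.675, 116.28, 103.53]
std = [58.395, 57.12, 57.375]

# Reversing the list turns the RGB channel order into BGR
bgr_mean = mean[::-1]
bgr_std = std[::-1]

# RandAugment pads with integer pixel values, so each mean is rounded
pad_val = [round(x) for x in bgr_mean]

# Test time: resize the short edge to 236, then center-crop 224
crop_ratio = 224 / 236

print(bgr_mean)              # channel order is now B, G, R
print(pad_val)               # [104, 116, 124]
print(round(crop_ratio, 3))  # 0.949
```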
16 changes: 16 additions & 0 deletions configs/_base_/models/davit/davit-base.py
@@ -0,0 +1,16 @@
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='DaViT', arch='base', out_indices=(3, ), drop_path_rate=0.4),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1024,
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
    ),
    train_cfg=dict(augments=[
        dict(type='Mixup', alpha=0.8),
        dict(type='CutMix', alpha=1.0)
    ]))
16 changes: 16 additions & 0 deletions configs/_base_/models/davit/davit-small.py
@@ -0,0 +1,16 @@
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='DaViT', arch='small', out_indices=(3, ), drop_path_rate=0.2),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=768,
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
    ),
    train_cfg=dict(augments=[
        dict(type='Mixup', alpha=0.8),
        dict(type='CutMix', alpha=1.0)
    ]))
16 changes: 16 additions & 0 deletions configs/_base_/models/davit/davit-tiny.py
@@ -0,0 +1,16 @@
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='DaViT', arch='t', out_indices=(3, ), drop_path_rate=0.1),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=768,
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
    ),
    train_cfg=dict(augments=[
        dict(type='Mixup', alpha=0.8),
        dict(type='CutMix', alpha=1.0)
    ]))
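
All three model configs use `LabelSmoothLoss` with `label_smooth_val=0.1`. As a rough illustration of what that value does to a training target, here is a sketch of classic label smoothing in plain Python. It assumes `mode='original'` corresponds to the standard formulation `(1 - eps) * one_hot + eps / num_classes`; the helper `smooth_one_hot` is hypothetical and is not part of mmcls.

```python
def smooth_one_hot(label, num_classes, eps):
    """Classic label smoothing: (1 - eps) * one_hot + eps / num_classes."""
    # every class starts with the uniform share of the smoothing mass
    target = [eps / num_classes] * num_classes
    # the true class additionally keeps the remaining (1 - eps)
    target[label] += 1.0 - eps
    return target


# values taken from the head config: 1000 classes, smoothing 0.1
target = smooth_one_hot(label=3, num_classes=1000, eps=0.1)

print(target[3])    # probability mass kept on the true class
print(target[0])    # mass assigned to each of the other classes
print(sum(target))  # the smoothed target still sums to 1 (up to float error)
```

With 1000 classes the true class keeps about 0.9001 of the probability mass and every other class receives 0.0001.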
38 changes: 38 additions & 0 deletions configs/davit/README.md
@@ -0,0 +1,38 @@
# DaViT

> [DaViT: Dual Attention Vision Transformers](https://arxiv.org/abs/2204.03645v1)
<!-- [ALGORITHM] -->

## Abstract

In this work, we introduce Dual Attention Vision Transformers (DaViT), a simple yet effective vision transformer architecture that is able to capture global context while maintaining computational efficiency. We propose approaching the problem from an orthogonal angle: exploiting self-attention mechanisms with both "spatial tokens" and "channel tokens". With spatial tokens, the spatial dimension defines the token scope, and the channel dimension defines the token feature dimension. With channel tokens, we have the inverse: the channel dimension defines the token scope, and the spatial dimension defines the token feature dimension. We further group tokens along the sequence direction for both spatial and channel tokens to maintain the linear complexity of the entire model. We show that these two self-attentions complement each other: (i) since each channel token contains an abstract representation of the entire image, the channel attention naturally captures global interactions and representations by taking all spatial positions into account when computing attention scores between channels; (ii) the spatial attention refines the local representations by performing fine-grained interactions across spatial locations, which in turn helps the global information modeling in channel attention. Extensive experiments show our DaViT achieves state-of-the-art performance on four different tasks with efficient computations. Without extra data, DaViT-Tiny, DaViT-Small, and DaViT-Base achieve 82.8%, 84.2%, and 84.6% top-1 accuracy on ImageNet-1K with 28.3M, 49.7M, and 87.9M parameters, respectively. When we further scale up DaViT with 1.5B weakly supervised image and text pairs, DaViT-Giant reaches 90.4% top-1 accuracy on ImageNet-1K.

<div align=center>
<img src="https://user-images.githubusercontent.com/24734142/196125065-e232409b-f710-4729-b657-4e5f9158f2d1.png" width="90%"/>
</div>
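
The difference between "spatial tokens" and "channel tokens" described in the abstract can be illustrated with a toy example. The sketch below is a conceptual simplification in plain Python, not the DaViT implementation: it uses a single head with identity projections (Q = K = V), and it omits the window and channel grouping that the paper uses to keep complexity linear. All helper names are hypothetical.

```python
import math


def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def transpose(m):
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)]
            for row in a]


def self_attention(tokens):
    """Single-head attention with identity projections (Q = K = V)."""
    dim = len(tokens[0])
    scores = matmul(tokens, transpose(tokens))
    weights = [softmax([s / math.sqrt(dim) for s in row]) for row in scores]
    return matmul(weights, tokens), weights


# toy feature map with N = 4 spatial positions and C = 2 channels
x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]

# spatial tokens: each position is a token, channels are its features
spatial_out, spatial_w = self_attention(x)

# channel tokens: each channel is a token, positions are its features
channel_out_t, channel_w = self_attention(transpose(x))
channel_out = transpose(channel_out_t)  # back to N x C layout

print(len(spatial_w), len(spatial_w[0]))  # attention map over positions
print(len(channel_w), len(channel_w[0]))  # attention map over channels
```

The spatial attention map is N x N while the channel attention map is C x C, and each channel token is built from all spatial positions, which is why channel attention sees the whole image.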

## Results and models

### ImageNet-1k

| Model | Pretrain | resolution | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
| :-------: | :----------: | :--------: | :-------: | :------: | :-------: | :-------: | :------------------------------------: | :----------------------------------------------------------------------------------------------: |
| DaViT-T\* | From scratch | 224x224 | 28.36 | 4.54 | 82.24 | 96.13 | [config](./davit-tiny_4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/davit/davit-tiny_3rdparty_in1k_20221116-700fdf7d.pth) |
| DaViT-S\* | From scratch | 224x224 | 49.74 | 8.79 | 83.61 | 96.75 | [config](./davit-small_4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/davit/davit-small_3rdparty_in1k_20221116-51a849a6.pth) |
| DaViT-B\* | From scratch | 224x224 | 87.95 | 15.5 | 84.09 | 96.82 | [config](./davit-base_4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/davit/davit-base_3rdparty_in1k_20221116-19e0d956.pth) |

*Models marked with \* are converted from the [official repo](https://github.com/dingmyu/davit). The config files of these models are provided for validation only. We do not guarantee the training accuracy of these config files, and you are welcome to contribute your reproduction results.*

Note: Inference accuracy is slightly lower than the results reported in the paper because no inference code for classification is available for reference.

## Citation

```
@inproceedings{ding2022davit,
    title={DaViT: Dual Attention Vision Transformers},
    author={Ding, Mingyu and Xiao, Bin and Codella, Noel and Luo, Ping and Wang, Jingdong and Yuan, Lu},
    booktitle={ECCV},
    year={2022},
}
```
9 changes: 9 additions & 0 deletions configs/davit/davit-base_4xb256_in1k.py
@@ -0,0 +1,9 @@
_base_ = [
    '../_base_/models/davit/davit-base.py',
    '../_base_/datasets/imagenet_bs256_davit_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../_base_/default_runtime.py'
]

# data settings
train_dataloader = dict(batch_size=256)
9 changes: 9 additions & 0 deletions configs/davit/davit-small_4xb256_in1k.py
@@ -0,0 +1,9 @@
_base_ = [
    '../_base_/models/davit/davit-small.py',
    '../_base_/datasets/imagenet_bs256_davit_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../_base_/default_runtime.py'
]

# data settings
train_dataloader = dict(batch_size=256)
9 changes: 9 additions & 0 deletions configs/davit/davit-tiny_4xb256_in1k.py
@@ -0,0 +1,9 @@
_base_ = [
    '../_base_/models/davit/davit-tiny.py',
    '../_base_/datasets/imagenet_bs256_davit_224.py',
    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
    '../_base_/default_runtime.py'
]

# data settings
train_dataloader = dict(batch_size=256)
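
Each of the three top-level configs overrides only `batch_size` and inherits everything else from the `_base_` files. The sketch below is a simplified model of that behaviour in plain Python, assuming nested dicts are merged key by key. It is not the actual config loader, and `merge`, `base_loader` and `num_gpus` are hypothetical names. It also shows why the `4xb256` suffix lines up with the `bs1024` schedule file.

```python
import copy


def merge(base, override):
    """Recursively merge `override` into a copy of `base`.

    Dict values are merged key by key; any other value replaces the base.
    """
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged


# subset of train_dataloader from the base dataset config
base_loader = dict(batch_size=64, num_workers=5, persistent_workers=True)

# the override written in davit-*_4xb256_in1k.py
train_dataloader = merge(base_loader, dict(batch_size=256))

num_gpus = 4  # assumption: read off the "4x" in the config file name
total_batch = num_gpus * train_dataloader['batch_size']

print(train_dataloader)  # batch_size replaced, other keys inherited
print(total_batch)       # 1024
```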
71 changes: 71 additions & 0 deletions configs/davit/metafile.yml
@@ -0,0 +1,71 @@
Collections:
  - Name: DaViT
    Metadata:
      Architecture:
        - GELU
        - Layer Normalization
        - Multi-Head Attention
        - Scaled Dot-Product Attention
    Paper:
      URL: https://arxiv.org/abs/2204.03645v1
      Title: 'DaViT: Dual Attention Vision Transformers'
    README: configs/davit/README.md
    Code:
      URL: https://github.com/open-mmlab/mmclassification/blob/v1.0.0rc3/mmcls/models/backbones/davit.py
      Version: v1.0.0rc3

Models:
  - Name: davit-tiny_3rdparty_in1k
    In Collection: DaViT
    Metadata:
      FLOPs: 4539698688
      Parameters: 28360168
      Training Data:
        - ImageNet-1k
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 82.24
          Top 5 Accuracy: 96.13
    Weights: https://download.openmmlab.com/mmclassification/v0/davit/davit-tiny_3rdparty_in1k_20221116-700fdf7d.pth
    Converted From:
      Weights: https://drive.google.com/file/d/1RSpi3lxKaloOL5-or20HuG975tbPwxRZ/view?usp=sharing
      Code: https://github.com/dingmyu/davit/blob/main/mmdet/mmdet/models/backbones/davit.py#L355
    Config: configs/davit/davit-tiny_4xb256_in1k.py
  - Name: davit-small_3rdparty_in1k
    In Collection: DaViT
    Metadata:
      FLOPs: 8799942144
      Parameters: 49745896
      Training Data:
        - ImageNet-1k
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 83.61
          Top 5 Accuracy: 96.75
    Weights: https://download.openmmlab.com/mmclassification/v0/davit/davit-small_3rdparty_in1k_20221116-51a849a6.pth
    Converted From:
      Weights: https://drive.google.com/file/d/1q976ruj45mt0RhO9oxhOo6EP_cmj4ahQ/view?usp=sharing
      Code: https://github.com/dingmyu/davit/blob/main/mmdet/mmdet/models/backbones/davit.py#L355
    Config: configs/davit/davit-small_4xb256_in1k.py
  - Name: davit-base_3rdparty_in1k
    In Collection: DaViT
    Metadata:
      FLOPs: 15509702656
      Parameters: 87954408
      Training Data:
        - ImageNet-1k
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 84.09
          Top 5 Accuracy: 96.82
    Weights: https://download.openmmlab.com/mmclassification/v0/davit/davit-base_3rdparty_in1k_20221116-19e0d956.pth
    Converted From:
      Weights: https://drive.google.com/file/d/1u9sDBEueB-YFuLigvcwf4b2YyA4MIVsZ/view?usp=sharing
      Code: https://github.com/dingmyu/davit/blob/main/mmdet/mmdet/models/backbones/davit.py#L355
    Config: configs/davit/davit-base_4xb256_in1k.py
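
The metafile stores raw parameter and FLOP counts, while the README table reports them in millions and billions. A small sketch of that unit conversion follows, in plain Python with the values copied from the metafile above. The dictionary layout is illustrative only, and the last printed digit can differ from the README table depending on whether a value is rounded or truncated.

```python
# raw counts copied from configs/davit/metafile.yml
models = {
    'davit-tiny': dict(flops=4539698688, params=28360168),
    'davit-small': dict(flops=8799942144, params=49745896),
    'davit-base': dict(flops=15509702656, params=87954408),
}

for name, meta in models.items():
    params_m = meta['params'] / 1e6  # millions of parameters
    flops_g = meta['flops'] / 1e9    # billions of floating point operations
    print(f'{name}: {params_m:.2f} M params, {flops_g:.2f} GFLOPs')
```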
1 change: 1 addition & 0 deletions docs/en/api/models.rst
@@ -65,6 +65,7 @@ Backbones
Conformer
ConvMixer
ConvNeXt
DaViT
DeiT3
DenseNet
DistilledVisionTransformer
2 changes: 2 additions & 0 deletions mmcls/models/backbones/__init__.py
@@ -4,6 +4,7 @@
from .convmixer import ConvMixer
from .convnext import ConvNeXt
from .cspnet import CSPDarkNet, CSPNet, CSPResNet, CSPResNeXt
from .davit import DaViT
from .deit import DistilledVisionTransformer
from .deit3 import DeiT3
from .densenet import DenseNet
@@ -93,4 +94,5 @@
'DeiT3',
'HorNet',
'MobileViT',
'DaViT',
]