Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] [CodeCamp #63] Add VIG Backbone #1304

Merged
merged 30 commits into from
Jan 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
19e72e3
添加vig源文件
szwlh-c Jan 4, 2023
bbdc383
某些模块修改到mmcls风格
szwlh-c Jan 4, 2023
2ca8607
修改到mmcls风格
szwlh-c Jan 4, 2023
5f678a4
修改
szwlh-c Jan 4, 2023
eff1b76
添加VIG模型及源文件
szwlh-c Jan 6, 2023
3b6068e
update model file
szwlh-c Jan 6, 2023
7e57f8c
update model file and config
szwlh-c Jan 6, 2023
f6ec7af
change class name and some variable name
szwlh-c Jan 10, 2023
7835c79
change class name and some variable name
szwlh-c Jan 10, 2023
91232b8
update
szwlh-c Jan 10, 2023
229f4c9
update
szwlh-c Jan 10, 2023
6c5e4ec
change nn.BatchNorm to mmcv.cnn.build_norm_layer
szwlh-c Jan 10, 2023
04f9baf
update
szwlh-c Jan 10, 2023
b7bc7ed
change nn.Seq to mmcls
szwlh-c Jan 10, 2023
d0e946c
change backbone to stage_blocks
szwlh-c Jan 11, 2023
99f2c64
add vig_head
szwlh-c Jan 11, 2023
8a23717
update
szwlh-c Jan 11, 2023
62d350b
update config file
szwlh-c Jan 11, 2023
6426925
update
szwlh-c Jan 12, 2023
f8b7433
add readme and metafile
szwlh-c Jan 12, 2023
6b7e7ca
Update Branch Merge branch 'dev-1.x' into add_new_net
szwlh-c Jan 12, 2023
ea6d2e9
update model-index.yml
szwlh-c Jan 12, 2023
6749b73
update model file
szwlh-c Jan 12, 2023
f19ecb9
rename config file and add docstring
szwlh-c Jan 16, 2023
451d858
variable rename
szwlh-c Jan 16, 2023
ac45ecc
update readme and metafile
szwlh-c Jan 16, 2023
6e52ed9
update readme
szwlh-c Jan 16, 2023
db4aee4
update
szwlh-c Jan 16, 2023
344317c
Update VIG backbone implementation and docs.
mzr1996 Jan 17, 2023
dbd10cc
Fix configs.
mzr1996 Jan 17, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions .dev_scripts/benchmark_regression/1-benchmark_valid.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,16 +96,19 @@ def inference(config_file, checkpoint, work_dir, args, exp_name):
data = default_collate([data] * args.batch_size)
resolution = tuple(data['inputs'].shape[-2:])
model = Runner.from_cfg(cfg).model
load_checkpoint(model, checkpoint, map_location='cpu')
model.eval()
forward = model.val_step
else:
# For configs only for get model.
model = init_model(cfg)
load_checkpoint(model, checkpoint, map_location='cpu')
model.eval()
data = torch.empty(1, 3, 224, 224).to(model.data_preprocessor.device)
resolution = (224, 224)
forward = model.extract_feat

if checkpoint is not None:
load_checkpoint(model, checkpoint, map_location='cpu')

# forward the model
result = {'resolution': resolution}
with torch.no_grad():
Expand Down
82 changes: 82 additions & 0 deletions configs/_base_/datasets/imagenet_bs128_vig_224.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# dataset settings
dataset_type = 'ImageNet'
data_preprocessor = dict(
num_classes=1000,
# RGB format normalization parameters
mean=[127.5, 127.5, 127.5],
std=[127.5, 127.5, 127.5],
# convert image from BGR to RGB
to_rgb=True,
)

bgr_mean = data_preprocessor['mean'][::-1]
bgr_std = data_preprocessor['std'][::-1]

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=224,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies='timm_increasing',
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5,
hparams=dict(
pad_val=[round(x) for x in bgr_mean], interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=bgr_mean,
fill_std=bgr_std),
dict(type='PackClsInputs'),
]

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='ResizeEdge',
scale=248,
edge='short',
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='PackClsInputs'),
]

train_dataloader = dict(
batch_size=128,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/train.txt',
data_prefix='train',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
)

val_dataloader = dict(
batch_size=128,
num_workers=5,
dataset=dict(
type=dataset_type,
data_root='data/imagenet',
ann_file='meta/val.txt',
data_prefix='val',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, 5))

# If you want standard test, please manually configure the test dataset
test_dataloader = val_dataloader
test_evaluator = val_evaluator
32 changes: 32 additions & 0 deletions configs/_base_/models/vig/pyramid_vig_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='PyramidVig',
arch='base',
k=9,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='BN'),
graph_conv_type='mr',
graph_conv_bias=True,
epsilon=0.2,
use_stochastic=False,
drop_path=0.1,
norm_eval=False,
frozen_stages=0),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='VigClsHead',
num_classes=1000,
in_channels=1024,
hidden_dim=1024,
act_cfg=dict(type='GELU'),
dropout=0.,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]),
)
32 changes: 32 additions & 0 deletions configs/_base_/models/vig/pyramid_vig_medium.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='PyramidVig',
arch='medium',
k=9,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='BN'),
graph_conv_type='mr',
graph_conv_bias=True,
epsilon=0.2,
use_stochastic=False,
drop_path=0.1,
norm_eval=False,
frozen_stages=0),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='VigClsHead',
num_classes=1000,
in_channels=768,
hidden_dim=1024,
act_cfg=dict(type='GELU'),
dropout=0.,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]),
)
32 changes: 32 additions & 0 deletions configs/_base_/models/vig/pyramid_vig_small.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='PyramidVig',
arch='small',
k=9,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='BN'),
graph_conv_type='mr',
graph_conv_bias=True,
epsilon=0.2,
use_stochastic=False,
drop_path=0.1,
norm_eval=False,
frozen_stages=0),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='VigClsHead',
num_classes=1000,
in_channels=640,
hidden_dim=1024,
act_cfg=dict(type='GELU'),
dropout=0.,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]),
)
32 changes: 32 additions & 0 deletions configs/_base_/models/vig/pyramid_vig_tiny.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='PyramidVig',
arch='tiny',
k=9,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='BN'),
graph_conv_type='mr',
graph_conv_bias=True,
epsilon=0.2,
use_stochastic=False,
drop_path=0.1,
norm_eval=False,
frozen_stages=0),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='VigClsHead',
num_classes=1000,
in_channels=384,
hidden_dim=1024,
act_cfg=dict(type='GELU'),
dropout=0.,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]),
)
33 changes: 33 additions & 0 deletions configs/_base_/models/vig/vig_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
model = dict(
type='ImageClassifier',
backbone=dict(
type='Vig',
arch='base',
k=9,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='BN'),
graph_conv_type='mr',
graph_conv_bias=True,
epsilon=0.2,
use_dilation=True,
use_stochastic=False,
drop_path=0.1,
relative_pos=False,
norm_eval=False,
frozen_stages=0),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='VigClsHead',
num_classes=1000,
in_channels=640,
hidden_dim=1024,
act_cfg=dict(type='GELU'),
dropout=0.,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]),
)
33 changes: 33 additions & 0 deletions configs/_base_/models/vig/vig_small.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
model = dict(
type='ImageClassifier',
backbone=dict(
type='Vig',
arch='small',
k=9,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='BN'),
graph_conv_type='mr',
graph_conv_bias=True,
epsilon=0.2,
use_dilation=True,
use_stochastic=False,
drop_path=0.1,
relative_pos=False,
norm_eval=False,
frozen_stages=0),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='VigClsHead',
num_classes=1000,
in_channels=320,
hidden_dim=1024,
act_cfg=dict(type='GELU'),
dropout=0.,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]),
)
33 changes: 33 additions & 0 deletions configs/_base_/models/vig/vig_tiny.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
model = dict(
type='ImageClassifier',
backbone=dict(
type='Vig',
arch='tiny',
k=9,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='BN'),
graph_conv_type='mr',
graph_conv_bias=True,
epsilon=0.2,
use_dilation=True,
use_stochastic=False,
drop_path=0.1,
relative_pos=False,
norm_eval=False,
frozen_stages=0),
neck=dict(type='GlobalAveragePooling'),
head=dict(
type='VigClsHead',
num_classes=1000,
in_channels=192,
hidden_dim=1024,
act_cfg=dict(type='GELU'),
dropout=0.,
loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
topk=(1, 5),
),
train_cfg=dict(augments=[
dict(type='Mixup', alpha=0.8),
dict(type='CutMix', alpha=1.0)
]),
)
40 changes: 40 additions & 0 deletions configs/vig/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# VIG

> [Vision GNN: An Image is Worth Graph of Nodes](https://arxiv.org/abs/2206.00272)

<!-- [ALGORITHM] -->

## Abstract

Network architecture plays a key role in the deep learning-based computer vision system. The widely-used convolutional neural network and transformer treat the image as a grid or sequence structure, which is not flexible to capture irregular and complex objects. In this paper, we propose to represent the image as a graph structure and introduce a new Vision GNN (ViG) architecture to extract graph-level feature for visual tasks. We first split the image to a number of patches which are viewed as nodes, and construct a graph by connecting the nearest neighbors. Based on the graph representation of images, we build our ViG model to transform and exchange information among all the nodes. ViG consists of two basic modules: Grapher module with graph convolution for aggregating and updating graph information, and FFN module with two linear layers for node feature transformation. Both isotropic and pyramid architectures of ViG are built with different model sizes. Extensive experiments on image recognition and object detection tasks demonstrate the superiority of our ViG architecture. We hope this pioneering study of GNN on general visual tasks will provide useful inspiration and experience for future research.

<div align=center>
<img src="https://user-images.githubusercontent.com/26739999/212789461-f085e4da-9ce9-435f-93c0-e1b84d10b79f.png" width="50%"/>
</div>

## Results and models

### ImageNet-1k

| Model | Pretrain | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
| :-------------------------: | :----------: | :-------: | :------: | :-------: | :-------: | :------------------------------------: | :--------------------------------------------------------------------------------------: |
| vig-tiny_3rdparty_in1k\* | From scratch | 7.18 | 1.31 | 74.40 | 92.34 | [config](./vig-tiny_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/vig/vig-tiny_3rdparty_in1k_20230117-6414c684.pth) |
| vig-small_3rdparty_in1k\* | From scratch | 22.75 | 4.54 | 80.61 | 95.28 | [config](./vig-small_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/vig/vig-small_3rdparty_in1k_20230117-5338bf3b.pth) |
| vig-base_3rdparty_in1k\* | From scratch | 20.68 | 17.68 | 82.64 | 96.04 | [config](./vig-base_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/vig/vig-base_3rdparty_in1k_20230117-92f6f12f.pth) |
| pvig-tiny_3rdparty_in1k\* | From scratch | 9.46 | 1.71 | 78.38 | 94.38 | [config](./pvig-tiny_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/vig/pvig-tiny_3rdparty_in1k_20230117-eb77347d.pth) |
| pvig-small_3rdparty_in1k\* | From scratch | 29.02 | 4.57 | 82.00 | 95.97 | [config](./pvig-small_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/vig/pvig-small_3rdparty_in1k_20230117-9433dc96.pth) |
| pvig-medium_3rdparty_in1k\* | From scratch | 51.68 | 8.89 | 83.12 | 96.35 | [config](./pvig-medium_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/vig/pvig-medium_3rdparty_in1k_20230117-21057a6d.pth) |
| pvig-base_3rdparty_in1k\* | From scratch | 95.21 | 16.86 | 83.59 | 96.52 | [config](./pvig-base_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/vig/pvig-base_3rdparty_in1k_20230117-dbab3c85.pth) |

*Models with * are converted from the [official repo](https://github.com/huawei-noah/Efficient-AI-Backbones/tree/master/vig_pytorch). The config files of these models are only for inference. We don't ensure these config files' training accuracy and welcome you to contribute your reproduction results.*

## Citation

```bibtex
@inproceedings{han2022vig,
title={Vision GNN: An Image is Worth Graph of Nodes},
author={Kai Han and Yunhe Wang and Jianyuan Guo and Yehui Tang and Enhua Wu},
booktitle={NeurIPS},
year={2022}
}
```
Loading