[Feature] Support GLIP #1308

Merged · 6 commits · Apr 17, 2023
1 change: 1 addition & 0 deletions README.md
@@ -200,6 +200,7 @@ Results and models are available in the [model zoo](https://mmpretrain.readthedo
<li><a href="configs/xcit">XCiT</a></li>
<li><a href="configs/levit">LeViT</a></li>
<li><a href="configs/riformer">RIFormer</a></li>
<li><a href="configs/glip">GLIP</a></li>
</ul>
</td>
<td>
1 change: 1 addition & 0 deletions README_zh-CN.md
@@ -196,6 +196,7 @@ mim install -e .
<li><a href="configs/xcit">XCiT</a></li>
<li><a href="configs/levit">LeViT</a></li>
<li><a href="configs/riformer">RIFormer</a></li>
<li><a href="configs/glip">GLIP</a></li>
</ul>
</td>
<td>
57 changes: 57 additions & 0 deletions configs/glip/README.md
@@ -0,0 +1,57 @@
# GLIP

> [Grounded Language-Image Pre-training](https://arxiv.org/abs/2112.03857)

<!-- [ALGORITHM] -->

## Abstract

This paper presents a grounded language-image pre-training (GLIP) model for learning object-level, language-aware, and semantic-rich visual representations. GLIP unifies object detection and phrase grounding for pre-training. The unification brings two benefits: 1) it allows GLIP to learn from both detection and grounding data to improve both tasks and bootstrap a good grounding model; 2) GLIP can leverage massive image-text pairs by generating grounding boxes in a self-training fashion, making the learned representation semantic-rich. In our experiments, we pre-train GLIP on 27M grounding data, including 3M human-annotated and 24M web-crawled image-text pairs. The learned representations demonstrate strong zero-shot and few-shot transferability to various object-level recognition tasks. 1) When directly evaluated on COCO and LVIS (without seeing any images in COCO during pre-training), GLIP achieves 49.8 AP and 26.9 AP, respectively, surpassing many supervised baselines. 2) After fine-tuned on COCO, GLIP achieves 60.8 AP on val and 61.5 AP on test-dev, surpassing prior SoTA. 3) When transferred to 13 downstream object detection tasks, a 1-shot GLIP rivals with a fully-supervised Dynamic Head.

<div align="center">
<img src="https://github.com/microsoft/GLIP/blob/main/docs/lead.png" width="70%"/>
</div>

## How to use it?

<!-- [TABS-BEGIN] -->

**Use the model**

```python
import torch
from mmpretrain import get_model
model = get_model('swin-t_glip-pre_3rdparty', pretrained=True)
inputs = torch.rand(1, 3, 224, 224)
out = model(inputs)
print(type(out))
# To extract features.
feats = model.extract_feat(inputs)
print(type(feats))
```

<!-- [TABS-END] -->
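
Since both checkpoints are headless (no neck or head), `model(inputs)` and `model.extract_feat(inputs)` return the backbone features directly: a tuple with one feature map per selected stage (`out_indices=(1, 2, 3)`). The same API works for the 384-px GLIP-L weights; a minimal sketch, using only the model name registered in this PR (the printed shapes are illustrative, not verified here):

```python
import torch
from mmpretrain import get_model

# GLIP-L backbone pre-trained at 384x384 (name from the table below).
model = get_model('swin-l_glip-pre_3rdparty_384px', pretrained=True)
feats = model.extract_feat(torch.rand(1, 3, 384, 384))
print([f.shape for f in feats])  # three multi-scale feature maps, one per stage
```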

## Results and models

### Pre-trained models

The pre-trained models are intended for fine-tuning on downstream tasks, and therefore have no evaluation results.

| Model | Pretrain | resolution | Download |
| :------------------------------------------ | :------------------------: | :--------: | :-------------------------------------------------------------------------------------------------------------------: |
| GLIP-T (`swin-t_glip-pre_3rdparty`)\* | O365,GoldG,CC3M,SBU | 224x224 | [model](https://download.openmmlab.com/mmclassification/v1/glip/swin-t_glip-pre_3rdparty_20230413-d85813b5.pth) |
| GLIP-L (`swin-l_glip-pre_3rdparty_384px`)\* | FourODs,GoldG,CC3M+12M,SBU | 384x384 | [model](https://download.openmmlab.com/mmclassification/v1/glip/swin-l_glip-pre_3rdparty_384px_20230413-04b198e8.pth) |

*Models with * are converted from the [official repo](https://github.com/microsoft/GLIP).*
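
For example, a downstream config can initialize its backbone from one of these checkpoints via MMEngine's `Pretrained` init. The snippet below is a sketch, not an official fine-tuning recipe from this PR; the neck/head choices are illustrative:

```python
# Hypothetical downstream config: reuse the GLIP-T Swin backbone and attach a
# classification neck/head for fine-tuning.
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='SwinTransformer',
        arch='tiny',
        img_size=224,
        out_indices=(3, ),  # only the last stage is needed for classification
        init_cfg=dict(
            type='Pretrained',
            checkpoint='https://download.openmmlab.com/mmclassification/v1/glip/swin-t_glip-pre_3rdparty_20230413-d85813b5.pth',
            prefix='backbone.')),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=768,
        loss=dict(type='CrossEntropyLoss')))
```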

## Citation

```bibtex
@inproceedings{li2021grounded,
title={Grounded Language-Image Pre-training},
author={Liunian Harold Li* and Pengchuan Zhang* and Haotian Zhang* and Jianwei Yang and Chunyuan Li and Yiwu Zhong and Lijuan Wang and Lu Yuan and Lei Zhang and Jenq-Neng Hwang and Kai-Wei Chang and Jianfeng Gao},
year={2022},
booktitle={CVPR},
}
```
18 changes: 18 additions & 0 deletions configs/glip/glip-l_headless.py
@@ -0,0 +1,18 @@
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='SwinTransformer',
        arch='large',
        img_size=384,
        out_indices=(1, 2, 3),  # original weight is for detection (multi-scale features)
        stage_cfgs=dict(block_cfgs=dict(window_size=12))),
    neck=None,
    head=None)

data_preprocessor = dict(
    # BGR-order normalization parameters (the original GLIP weights expect BGR input)
    mean=[103.53, 116.28, 123.675],
    std=[57.375, 57.12, 58.395],
    # keep the BGR channel order; do not convert BGR to RGB
    to_rgb=False,
)
18 changes: 18 additions & 0 deletions configs/glip/glip-t_headless.py
@@ -0,0 +1,18 @@
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='SwinTransformer',
        arch='tiny',
        img_size=224,
        out_indices=(1, 2, 3),  # original weight is for detection (multi-scale features)
    ),
    neck=None,
    head=None)

data_preprocessor = dict(
    # BGR-order normalization parameters (the original GLIP weights expect BGR input)
    mean=[103.53, 116.28, 123.675],
    std=[57.375, 57.12, 58.395],
    # keep the BGR channel order; do not convert BGR to RGB
    to_rgb=False,
)
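
These headless configs only describe the backbone, so building them yields a model without neck or head. A quick sanity-check sketch (assuming it is run from the repository root with mmpretrain installed):

```python
import torch
from mmengine.config import Config
from mmpretrain.registry import MODELS

cfg = Config.fromfile('configs/glip/glip-t_headless.py')
model = MODELS.build(cfg.model)
print(model.with_neck, model.with_head)  # False False

# out_indices=(1, 2, 3) gives one feature map per selected Swin stage.
feats = model.extract_feat(torch.rand(1, 3, 224, 224))
print([tuple(f.shape) for f in feats])
```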
49 changes: 49 additions & 0 deletions configs/glip/metafile.yml
@@ -0,0 +1,49 @@
Collections:
  - Name: GLIP
    Metadata:
      Training Techniques:
        - AdamW
        - Weight Decay
      Architecture:
        - Shift Window Multihead Self Attention
    Paper:
      URL: https://arxiv.org/abs/2112.03857
      Title: "Grounded Language-Image Pre-training"
    README: configs/glip/README.md
    Code:
      URL: https://github.com/open-mmlab/mmpretrain/blob/main/mmpretrain/models/backbones/vit.py
      Version: v1.0.0rc8

Models:
  - Name: swin-t_glip-pre_3rdparty
    In Collection: GLIP
    Metadata:
      FLOPs: 4508464128
      Parameters: 29056354
      Training Data:
        - O365
        - GoldG
        - CC3M
        - SBU
    Results: null
    Weights: https://download.openmmlab.com/mmclassification/v1/glip/swin-t_glip-pre_3rdparty_20230413-d85813b5.pth
    Converted From:
      Weights: https://penzhanwu2bbs.blob.core.windows.net/data/GLIPv1_Open/models/glip_tiny_model_o365_goldg_cc_sbu.pth
      Code: https://github.com/microsoft/GLIP
    Config: configs/glip/glip-t_headless.py
  - Name: swin-l_glip-pre_3rdparty_384px
    In Collection: GLIP
    Metadata:
      FLOPs: 104080343040
      Parameters: 196735516
      Training Data:
        - FourODs
        - GoldG
        - CC3M+12M
        - SBU
    Results: null
    Weights: https://download.openmmlab.com/mmclassification/v1/glip/swin-l_glip-pre_3rdparty_384px_20230413-04b198e8.pth
    Converted From:
      Weights: https://penzhanwu2bbs.blob.core.windows.net/data/GLIPv1_Open/models/glip_large_model.pth
      Code: https://github.com/microsoft/GLIP
    Config: configs/glip/glip-l_headless.py
1 change: 1 addition & 0 deletions model-index.yml
@@ -68,3 +68,4 @@ Import:
- configs/milan/metafile.yml
- configs/riformer/metafile.yml
- configs/sam/metafile.yml
- configs/glip/metafile.yml
76 changes: 76 additions & 0 deletions tools/model_converters/glip_to_mmpretrain.py
@@ -0,0 +1,76 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import os.path as osp
from collections import OrderedDict

import mmengine
import torch
from mmengine.runner import CheckpointLoader


def convert_glip(ckpt):

    def correct_unfold_reduction_order(x):
        # Reorder the PatchMerging reduction weight so that the 2x2-patch
        # concatenation order of the official Swin checkpoint matches the
        # nn.Unfold-based PatchMerging used by mmpretrain.
        out_channel, in_channel = x.shape
        x = x.reshape(out_channel, 4, in_channel // 4)
        x = x[:, [0, 2, 1, 3], :].transpose(1,
                                            2).reshape(out_channel, in_channel)
        return x

    def correct_unfold_norm_order(x):
        # Apply the same reordering to the 1-D PatchMerging norm parameters.
        in_channel = x.shape[0]
        x = x.reshape(4, in_channel // 4)
        x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
        return x

    new_ckpt = OrderedDict()

    for k, v in list(ckpt.items()):
        # Keep only the visual backbone; skip the language backbone, FPN and
        # detection-specific weights.
        if 'language_backbone' in k or 'backbone' not in k or 'fpn' in k:
            continue
        new_v = v
        new_k = k.replace('body.', '')
        new_k = new_k.replace('module.', '')
        if new_k.startswith('backbone.layers'):
            new_k = new_k.replace('backbone.layers', 'backbone.stages')
        if 'mlp' in new_k:
            new_k = new_k.replace('mlp.fc1', 'ffn.layers.0.0')
            new_k = new_k.replace('mlp.fc2', 'ffn.layers.1')
        elif 'attn' in new_k:
            new_k = new_k.replace('attn', 'attn.w_msa')
        elif 'patch_embed' in k:
            new_k = new_k.replace('proj', 'projection')
        elif 'downsample' in new_k:
            if 'reduction.' in k:
                new_v = correct_unfold_reduction_order(new_v)
            elif 'norm.' in k:
                new_v = correct_unfold_norm_order(new_v)

        new_ckpt[new_k] = new_v
    return new_ckpt


def main():
    parser = argparse.ArgumentParser(
        description='Convert keys in pretrained glip models to mmcls style.')
    parser.add_argument('src', help='src model path or url')
    # The dst path must be a full path of the new checkpoint.
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')

    if 'model' in checkpoint:
        state_dict = checkpoint['model']
    else:
        state_dict = checkpoint

    weight = convert_glip(state_dict)
    mmengine.mkdir_or_exist(osp.dirname(args.dst))
    torch.save(weight, args.dst)

    print('Done!!')


if __name__ == '__main__':
    main()
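
To make the unfold-order fix concrete, here is a small self-contained illustration (not part of the PR) of what `correct_unfold_norm_order` does to a toy 8-element PatchMerging norm weight, i.e. two channels per 2x2 sub-patch:

```python
import torch


def correct_unfold_norm_order(x):
    # Same helper as in the converter above.
    in_channel = x.shape[0]
    x = x.reshape(4, in_channel // 4)
    x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
    return x


w = torch.arange(8)
print(correct_unfold_norm_order(w))
# tensor([0, 4, 2, 6, 1, 5, 3, 7]): the group-major layout stored in the
# official Swin checkpoint is permuted into the interleaved layout expected by
# mmcv's unfold-based PatchMerging.
```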