Commit 46a523e
[Feature]: Add GQA dataset

* [Feature]: Add GQA
* [Feature]: Add GQA UT
* [Fix]: Fix hint
* [Feature]: Add BLIP2 GQA
* [Fix]: Fix lint
* [Feature]: Update anno link
* [Fix]: Update docstring
* [Feature]: Update all links
1 parent: 4dd8a86

Showing 8 changed files with 353 additions and 4 deletions.
@@ -0,0 +1,81 @@
# data settings

data_preprocessor = dict(
    mean=[122.770938, 116.7460125, 104.09373615],
    std=[68.5005327, 66.6321579, 70.32316305],
    to_rgb=True,
)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        scale=384,
        interpolation='bicubic',
        backend='pillow'),
    dict(
        type='PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=['question_id', 'image_id'],
    ),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        scale=(480, 480),
        interpolation='bicubic',
        backend='pillow'),
    dict(
        type='CleanCaption',
        keys=['question'],
    ),
    dict(
        type='PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=['question_id', 'image_id'],
    ),
]

train_dataloader = dict(
    batch_size=16,
    num_workers=8,
    dataset=dict(
        type='GQA',
        data_root='data/gqa',
        data_prefix='images',
        ann_file='annotations/train_balanced_questions.json',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
    persistent_workers=True,
    drop_last=True,
)

val_dataloader = dict(
    batch_size=16,
    num_workers=8,
    dataset=dict(
        type='GQA',
        data_root='data/gqa',
        data_prefix='images',
        ann_file='annotations/testdev_balanced_questions.json',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)
val_evaluator = dict(type='GQAAcc')

test_dataloader = dict(
    batch_size=16,
    num_workers=8,
    dataset=dict(
        type='GQA',
        data_root='data/gqa',
        data_prefix='images',
        ann_file='annotations/testdev_balanced_questions.json',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)
test_evaluator = val_evaluator
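A quick sanity check of the dataset config above, which the BLIP2 config below pulls in as '../_base_/datasets/gqa.py'. A minimal sketch, assuming the file lives at configs/_base_/datasets/gqa.py as that relative reference suggests:

from mmengine.config import Config

# Path assumed from the `_base_` reference in the next file.
cfg = Config.fromfile('configs/_base_/datasets/gqa.py')
print(cfg.train_dataloader.dataset.ann_file)
# annotations/train_balanced_questions.json
print(cfg.val_evaluator)
# {'type': 'GQAAcc'}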
@@ -0,0 +1,87 @@
_base_ = [
    '../_base_/datasets/gqa.py',
    '../_base_/default_runtime.py',
]

# model settings
model = dict(
    type='Blip2VQA',
    tokenizer=dict(
        type='AutoTokenizer', name_or_path='facebook/opt-2.7b',
        use_fast=False),
    vision_backbone=dict(
        type='BEiTViT',
        # eva-g without the final layer
        arch=dict(
            embed_dims=1408,
            num_layers=39,
            num_heads=16,
            feedforward_channels=6144,
        ),
        img_size=364,
        patch_size=14,
        out_indices=-2,
        layer_scale_init_value=0.0,
        use_abs_pos_emb=True,
        use_rel_pos_bias=False,
        frozen_stages=39,
        final_norm=False,
        use_shared_rel_pos_bias=False,
        out_type='raw'),
    text_backbone=dict(
        type='OPTForCausalLM', name_or_path='facebook/opt-2.7b'),
    multimodal_backbone=dict(
        type='Qformer',
        model_style='bert-base-uncased',
        vision_model_width=1408,
        add_cross_attention=True,
        cross_attention_freq=2,
        num_query_token=32),
    vision_neck=dict(
        type='LinearClsHead',
        in_channels=768,
        num_classes=2560,
    ),
    prompt='Question: {} Short Answer:',
    max_txt_len=10)

# data settings
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', scale=224),
    dict(type='PackInputs', algorithm_keys=['question', 'gt_answer']),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        scale=(224, 224),
        interpolation='bicubic',
        backend='pillow'),
    dict(
        type='CleanCaption',
        keys=['question'],
    ),
    dict(type='PackInputs', algorithm_keys=['question', 'gt_answer']),
]

train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

# schedule settings
optim_wrapper = dict(optimizer=dict(type='AdamW', lr=1e-5, weight_decay=0.05))

param_scheduler = [
    dict(
        type='CosineAnnealingLR',
        by_epoch=True,
        begin=0,
        end=10,
    )
]

train_cfg = dict(max_epochs=10)
val_cfg = dict()
test_cfg = dict()
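The dataloader overrides above replace only the nested pipeline key, so batch size, workers, annotation files, and the GQAAcc evaluator are all inherited from the base dataset config. A minimal sketch of that merge behavior; the config filename here is hypothetical, since the commit view does not show it:

from mmengine.config import Config

cfg = Config.fromfile('configs/blip2/blip2-opt2.7b_gqa.py')  # hypothetical path
assert cfg.train_dataloader.batch_size == 16  # inherited from the base config
assert cfg.train_dataloader.dataset.pipeline[1]['scale'] == 224  # overridden here
assert cfg.test_evaluator == {'type': 'GQAAcc'}  # inherited from the base config

The next files add the GQA dataset class and the GQAAcc metric that these configs reference.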
@@ -0,0 +1,70 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List

import mmengine
from mmengine.dataset import BaseDataset

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class GQA(BaseDataset):
    """GQA dataset.
    We use the annotation files from LAVIS, and you can download them from
    the following links:

    train:
        https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/train_balanced_questions.json  # noqa: E501
    val:
        https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/testdev_balanced_questions.json  # noqa: E501
    test:
        https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/test_balanced_questions.json  # noqa: E501

    The images are available from the official website:
    https://cs.stanford.edu/people/dorarad/gqa/index.html

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        data_prefix (str): The directory of images.
        ann_file (str, optional): Annotation file path for training and
            validation. Defaults to an empty string.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str,
                 ann_file: str = '',
                 **kwargs):
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    def load_data_list(self) -> List[dict]:
        """Load data list."""
        annotations = mmengine.load(self.ann_file)

        data_list = []
        for ann in annotations:
            # An annotation example:
            # {
            #     'question': 'Is it overcast?',
            #     'answer': 'no',
            #     'image': 'n161313.jpg',
            #     'question_id': 262148000,
            #     ...
            # }
            data_info = dict()
            data_info['img_path'] = osp.join(self.data_prefix['img_path'],
                                             ann['image'])
            data_info['question'] = ann['question']
            data_info['gt_answer'] = ann['answer']

            data_list.append(data_info)

        return data_list
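A minimal usage sketch for the dataset class above, assuming the LAVIS annotations and GQA images have been placed under data/gqa/ as in the configs:

from mmpretrain.registry import DATASETS

dataset = DATASETS.build(
    dict(
        type='GQA',
        data_root='data/gqa',
        data_prefix='images',
        ann_file='annotations/testdev_balanced_questions.json',
        pipeline=[],  # empty pipeline: __getitem__ returns the raw data info
    ))
info = dataset[0]
print(info['question'], info['gt_answer'])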
@@ -0,0 +1,78 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

from mmengine.evaluator import BaseMetric

from mmpretrain.evaluation.metrics.vqa import (_process_digit_article,
                                               _process_punctuation)
from mmpretrain.registry import METRICS


@METRICS.register_module()
class GQAAcc(BaseMetric):
"""GQA Acc metric. | ||
Compute GQA accuracy. | ||
Args: | ||
collect_device (str): Device name used for collecting results from | ||
different ranks during distributed training. Must be 'cpu' or | ||
'gpu'. Defaults to 'cpu'. | ||
prefix (str, optional): The prefix that will be added in the metric | ||
names to disambiguate homonymous metrics of different evaluators. | ||
If prefix is not provided in the argument, self.default_prefix | ||
will be used instead. Should be modified according to the | ||
`retrieval_type` for unambiguous results. Defaults to TR. | ||
""" | ||
    default_prefix = 'GQA'

    def __init__(self,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)

    def process(self, data_batch, data_samples) -> None:
        """Process one batch of data samples.

        The processed results are stored in ``self.results``, which will be
        used to compute the metrics once all batches have been processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for sample in data_samples:
            gt_answer = sample.get('gt_answer')
            result = {
                'pred_answer': sample.get('pred_answer'),
                'gt_answer': gt_answer
            }

            self.results.append(result)

    def compute_metrics(self, results: List) -> dict:
        """Compute the metrics from the processed results.

        Args:
            results (List): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the
            metrics, and the values are the corresponding results.
        """
        acc = []
        for result in results:
            pred_answer = self._process_answer(result['pred_answer'])
            gt_answer = self._process_answer(result['gt_answer'])
            gqa_acc = 1 if pred_answer == gt_answer else 0
            acc.append(gqa_acc)

        accuracy = sum(acc) / len(acc)

        metrics = {'acc': accuracy}
        return metrics

    def _process_answer(self, answer) -> str:
        """Normalize an answer string before comparison: strip punctuation,
        then map number words to digits and drop articles."""
        answer = _process_punctuation(answer)
        answer = _process_digit_article(answer)
        return answer
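Since process only calls dict.get on each sample, the metric can also be exercised outside a runner with plain dicts. A minimal sketch, built through the registry:

from mmpretrain.registry import METRICS

metric = METRICS.build(dict(type='GQAAcc'))
metric.process(
    data_batch=None,  # unused by this metric
    data_samples=[
        {'pred_answer': 'no', 'gt_answer': 'no'},
        {'pred_answer': 'no.', 'gt_answer': 'no'},  # punctuation is stripped
    ])
print(metric.compute_metrics(metric.results))  # {'acc': 1.0}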
@@ -0,0 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.evaluator import Evaluator

from mmpretrain.structures import DataSample


class TestGQAAcc:

    def test_evaluate(self):
        # Every prediction matches the ground truth.
        meta_info = {
            'pred_answer': 'dog',
            'gt_answer': 'dog',
        }
        data_sample = DataSample(metainfo=meta_info)
        data_samples = [data_sample for _ in range(10)]
        evaluator = Evaluator(dict(type='mmpretrain.GQAAcc'))
        evaluator.process(data_samples)
        # Evaluate as if the dataset had 4 samples; the extra processed
        # results are truncated before the metric is computed.
        res = evaluator.evaluate(4)
        assert res['GQA/acc'] == 1.0

        # No prediction matches the ground truth.
        meta_info = {
            'pred_answer': 'dog',
            'gt_answer': 'cat',
        }
        data_sample = DataSample(metainfo=meta_info)
        data_samples = [data_sample for _ in range(10)]
        evaluator = Evaluator(dict(type='mmpretrain.GQAAcc'))
        evaluator.process(data_samples)
        res = evaluator.evaluate(4)
        assert res['GQA/acc'] == 0.0
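The test drives GQAAcc through mmengine's Evaluator exactly as a runner would. One way to run it in isolation; the test file path is an assumption, since the commit view does not show it:

import pytest

# Hypothetical location of the unit test added by this commit.
pytest.main(['tests/test_evaluation/test_metrics/test_gqa.py', '-q'])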