[Feature] Add GQA dataset. (#1585)
* [Feature]: Add GQA dataset

* [Feature]: Add GQA

* [Feature]: Add GQA UT

* [Fix]: Fix hint

* [Feature]: Add BLIP2 GQA

* [Fix]: Fix lint

* [Feature]: Update anno link

* [Fix]: Update docstring

* [Feature]: Update all links
YuanLiuuuuuu authored May 23, 2023
1 parent 4dd8a86 commit 46a523e
Showing 8 changed files with 353 additions and 4 deletions.
81 changes: 81 additions & 0 deletions configs/_base_/datasets/gqa.py
@@ -0,0 +1,81 @@
# data settings

data_preprocessor = dict(
    mean=[122.770938, 116.7460125, 104.09373615],
    std=[68.5005327, 66.6321579, 70.32316305],
    to_rgb=True,
)

train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        scale=384,
        interpolation='bicubic',
        backend='pillow'),
    dict(
        type='PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=['question_id', 'image_id'],
    ),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        scale=(480, 480),
        interpolation='bicubic',
        backend='pillow'),
    dict(
        type='CleanCaption',
        keys=['question'],
    ),
    dict(
        type='PackInputs',
        algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
        meta_keys=['question_id', 'image_id'],
    ),
]

train_dataloader = dict(
    batch_size=16,
    num_workers=8,
    dataset=dict(
        type='GQA',
        data_root='data/gqa',
        data_prefix='images',
        ann_file='annotations/train_balanced_questions.json',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
    persistent_workers=True,
    drop_last=True,
)

val_dataloader = dict(
    batch_size=16,
    num_workers=8,
    dataset=dict(
        type='GQA',
        data_root='data/gqa',
        data_prefix='images',
        ann_file='annotations/testdev_balanced_questions.json',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)
val_evaluator = dict(type='GQAAcc')

test_dataloader = dict(
    batch_size=16,
    num_workers=8,
    dataset=dict(
        type='GQA',
        data_root='data/gqa',
        data_prefix='images',
        ann_file='annotations/testdev_balanced_questions.json',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
    persistent_workers=True,
)
test_evaluator = val_evaluator
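
For reference, the base config above can be exercised on its own once the GQA images and the LAVIS annotation files are in place under data/gqa/. A minimal sketch, assuming mmpretrain's register_all_modules helper is available to populate the registries:

# Minimal sketch: build the GQA training split from the base dataset config.
# Assumes data/gqa/images and data/gqa/annotations exist as referenced above.
from mmengine.config import Config

from mmpretrain.registry import DATASETS
from mmpretrain.utils import register_all_modules

register_all_modules()  # registers the GQA dataset and the pipeline transforms

cfg = Config.fromfile('configs/_base_/datasets/gqa.py')
dataset = DATASETS.build(cfg.train_dataloader.dataset)

print(len(dataset))              # number of question/answer pairs
print(dataset.get_data_info(0))  # img_path, question, gt_answer, ...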
87 changes: 87 additions & 0 deletions configs/blip2/blip2-opt2.7b_8xb16_gqa.py
@@ -0,0 +1,87 @@
_base_ = [
    '../_base_/datasets/gqa.py',
    '../_base_/default_runtime.py',
]

# model settings
model = dict(
    type='Blip2VQA',
    tokenizer=dict(
        type='AutoTokenizer', name_or_path='facebook/opt-2.7b',
        use_fast=False),
    vision_backbone=dict(
        type='BEiTViT',
        # eva-g without the final layer
        arch=dict(
            embed_dims=1408,
            num_layers=39,
            num_heads=16,
            feedforward_channels=6144,
        ),
        img_size=364,
        patch_size=14,
        out_indices=-2,
        layer_scale_init_value=0.0,
        use_abs_pos_emb=True,
        use_rel_pos_bias=False,
        frozen_stages=39,
        final_norm=False,
        use_shared_rel_pos_bias=False,
        out_type='raw'),
    text_backbone=dict(
        type='OPTForCausalLM', name_or_path='facebook/opt-2.7b'),
    multimodal_backbone=dict(
        type='Qformer',
        model_style='bert-base-uncased',
        vision_model_width=1408,
        add_cross_attention=True,
        cross_attention_freq=2,
        num_query_token=32),
    vision_neck=dict(
        type='LinearClsHead',
        in_channels=768,
        num_classes=2560,
    ),
    prompt='Question: {} Short Answer:',
    max_txt_len=10)

# data settings
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', scale=224),
    dict(type='PackInputs', algorithm_keys=['question', 'gt_answer']),
]

test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='Resize',
        scale=(224, 224),
        interpolation='bicubic',
        backend='pillow'),
    dict(
        type='CleanCaption',
        keys=['question'],
    ),
    dict(type='PackInputs', algorithm_keys=['question', 'gt_answer']),
]

train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

# schedule settings
optim_wrapper = dict(optimizer=dict(type='AdamW', lr=1e-5, weight_decay=0.05))

param_scheduler = [
    dict(
        type='CosineAnnealingLR',
        by_epoch=True,
        begin=0,
        end=10,
    )
]

train_cfg = dict(max_epochs=10)
val_cfg = dict()
test_cfg = dict()
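
This config can be driven through MMEngine's Runner; the sketch below is single-process, the work directory and checkpoint path are placeholders, and an actual run would normally be launched distributed (8 GPUs x batch 16, per the 8xb16 name):

# Single-process sketch; the work_dir and checkpoint path are placeholders,
# not files provided by this commit.
from mmengine.config import Config
from mmengine.runner import Runner

cfg = Config.fromfile('configs/blip2/blip2-opt2.7b_8xb16_gqa.py')
cfg.work_dir = 'work_dirs/blip2-opt2.7b_gqa'
cfg.load_from = 'path/to/blip2_checkpoint.pth'  # placeholder checkpoint

runner = Runner.from_cfg(cfg)
runner.test()  # evaluates on testdev_balanced with the GQAAcc metric
# runner.train() would instead fine-tune with the 10-epoch schedule above.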
2 changes: 2 additions & 0 deletions mmpretrain/datasets/__init__.py
@@ -38,6 +38,7 @@
 from .coco_retrieval import COCORetrieval
 from .coco_vqa import COCOVQA
 from .flamingo import FlamingoEvalCOCOCaption, FlamingoEvalCOCOVQA
+from .gqa_dataset import GQA
 from .refcoco import RefCOCO
 from .scienceqa import ScienceQA
 from .visual_genome import VisualGenomeQA
@@ -51,4 +52,5 @@
     'RefCOCO',
     'VisualGenomeQA',
     'ScienceQA',
+    'GQA',
 ])
70 changes: 70 additions & 0 deletions mmpretrain/datasets/gqa_dataset.py
@@ -0,0 +1,70 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List

import mmengine
from mmengine.dataset import BaseDataset

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class GQA(BaseDataset):
    """GQA dataset.

    We use the annotation files from LAVIS, which can be downloaded from the
    following links:

    train:
        https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/train_balanced_questions.json  # noqa: E501
    val:
        https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/testdev_balanced_questions.json  # noqa: E501
    test:
        https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/test_balanced_questions.json  # noqa: E501

    The images can be downloaded from the official website:
    https://cs.stanford.edu/people/dorarad/gqa/index.html

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        data_prefix (str): The directory of images.
        ann_file (str, optional): Annotation file path for training and
            validation. Defaults to an empty string.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str,
                 ann_file: str = '',
                 **kwargs):
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    def load_data_list(self) -> List[dict]:
        """Load data list."""
        annotations = mmengine.load(self.ann_file)

        data_list = []
        for ann in annotations:
            # Example annotation entry:
            # {
            #     'question': 'Is it overcast?',
            #     'answer': 'no',
            #     'image': 'n161313.jpg',
            #     'question_id': 262148000,
            #     ...
            # }
            data_info = dict()
            data_info['img_path'] = osp.join(self.data_prefix['img_path'],
                                             ann['image'])
            data_info['question'] = ann['question']
            data_info['gt_answer'] = ann['answer']

            data_list.append(data_info)

        return data_list
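
A small usage sketch for the dataset class by itself, assuming the LAVIS annotation file has been downloaded to the path below; no pipeline is passed, so indexing returns the raw fields produced by load_data_list:

# Paths are examples; point them at wherever the GQA data actually lives.
from mmpretrain.datasets import GQA

dataset = GQA(
    data_root='data/gqa',
    data_prefix='images',
    ann_file='annotations/testdev_balanced_questions.json')

info = dataset[0]
# Expected keys per load_data_list above: 'img_path', 'question', 'gt_answer'
# (plus BaseDataset bookkeeping such as 'sample_idx').
print(info['question'], '->', info['gt_answer'])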
3 changes: 2 additions & 1 deletion mmpretrain/evaluation/metrics/__init__.py
@@ -1,5 +1,6 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .caption import COCOCaption
+from .gqa import GQAAcc
 from .multi_label import AveragePrecision, MultiLabelMetric
 from .multi_task import MultiTasksMetric
 from .retrieval import RetrievalRecall
@@ -13,5 +14,5 @@
     'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision',
     'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric',
     'ConfusionMatrix', 'RetrievalRecall', 'VQAAcc', 'ReportVQA', 'COCOCaption',
-    'VisualGroundingMetric', 'ScienceQAMetric'
+    'VisualGroundingMetric', 'ScienceQAMetric', 'GQAAcc'
 ]
78 changes: 78 additions & 0 deletions mmpretrain/evaluation/metrics/gqa.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

from mmengine.evaluator import BaseMetric

from mmpretrain.evaluation.metrics.vqa import (_process_digit_article,
                                               _process_punctuation)
from mmpretrain.registry import METRICS


@METRICS.register_module()
class GQAAcc(BaseMetric):
    """GQA accuracy metric.

    Computes the exact-match accuracy between normalized predicted and
    ground-truth answers on the GQA dataset.

    Args:
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added to the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, ``self.default_prefix``
            ('GQA') will be used instead. Defaults to None.
    """
    default_prefix = 'GQA'

    def __init__(self,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)

    def process(self, data_batch, data_samples) -> None:
        """Process one batch of data samples.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for sample in data_samples:
            gt_answer = sample.get('gt_answer')
            result = {
                'pred_answer': sample.get('pred_answer'),
                'gt_answer': gt_answer
            }

            self.results.append(result)

    def compute_metrics(self, results: List) -> dict:
        """Compute the metrics from processed results.

        Args:
            results (List): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the metrics,
            and the values are corresponding results.
        """
        acc = []
        for result in results:
            pred_answer = self._process_answer(result['pred_answer'])
            gt_answer = self._process_answer(result['gt_answer'])
            gqa_acc = 1 if pred_answer == gt_answer else 0
            acc.append(gqa_acc)

        accuracy = sum(acc) / len(acc)

        metrics = {'acc': accuracy}
        return metrics

    def _process_answer(self, answer) -> str:
        """Normalize an answer string with the shared VQA helpers."""
        answer = _process_punctuation(answer)
        answer = _process_digit_article(answer)
        return answer
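
The metric can also be checked in isolation with plain dicts, since process() only calls .get() on each sample; the answer strings below are made up, and the expected normalization (article and punctuation stripping) follows the shared VQA helpers imported above:

# Standalone sanity check of GQAAcc; no dataloader or model involved.
from mmpretrain.evaluation.metrics import GQAAcc

metric = GQAAcc()
metric.process(None, [
    {'pred_answer': 'a dog.', 'gt_answer': 'dog'},  # both normalize to 'dog'
    {'pred_answer': 'blue', 'gt_answer': 'red'},    # mismatch
])
print(metric.evaluate(size=2))  # expected: {'GQA/acc': 0.5}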
6 changes: 3 additions & 3 deletions mmpretrain/evaluation/metrics/scienceqa.py
@@ -86,9 +86,9 @@ def process(self, data_batch, data_samples) -> None:
             result['answer'] = data_sample.get('gt_answer')
             hint = data_sample.get('hint')
             has_image = data_sample.get('has_image', False)
-            result[
-                'no_context'] = True if not has_image and hint is None else False  # noqa
-            result['has_text'] = True if hint is not None else False
+            result['no_context'] = True if not has_image and len(
+                hint) == 0 else False  # noqa
+            result['has_text'] = True if len(hint) > 0 else False
             result['has_image'] = has_image

             # Save the result to `self.results`.
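
The behavioural change in this hunk is subtle, so a tiny illustration (the empty-string hint is an assumption about how ScienceQA samples without a hint are encoded):

# Old check vs. new check for a sample with no image and an empty hint string.
hint, has_image = '', False

old_no_context = not has_image and hint is None    # False -> context wrongly assumed
new_no_context = not has_image and len(hint) == 0  # True

old_has_text = hint is not None  # True  -> empty hint counted as text
new_has_text = len(hint) > 0     # False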
30 changes: 30 additions & 0 deletions tests/test_evaluation/test_metrics/test_gqa.py
@@ -0,0 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.evaluator import Evaluator

from mmpretrain.structures import DataSample


class TestGQAAcc:

    def test_evaluate(self):
        meta_info = {
            'pred_answer': 'dog',
            'gt_answer': 'dog',
        }
        data_sample = DataSample(metainfo=meta_info)
        data_samples = [data_sample for _ in range(10)]
        evaluator = Evaluator(dict(type='mmpretrain.GQAAcc'))
        evaluator.process(data_samples)
        res = evaluator.evaluate(4)
        assert res['GQA/acc'] == 1.0

        meta_info = {
            'pred_answer': 'dog',
            'gt_answer': 'cat',
        }
        data_sample = DataSample(metainfo=meta_info)
        data_samples = [data_sample for _ in range(10)]
        evaluator = Evaluator(dict(type='mmpretrain.GQAAcc'))
        evaluator.process(data_samples)
        res = evaluator.evaluate(4)
        assert res['GQA/acc'] == 0.0
