[Feature]: Add GQA #1585

Merged: 9 commits, May 23, 2023
81 changes: 81 additions & 0 deletions configs/_base_/datasets/gqa.py
@@ -0,0 +1,81 @@
# data settings

data_preprocessor = dict(
mean=[122.770938, 116.7460125, 104.09373615],
std=[68.5005327, 66.6321579, 70.32316305],
to_rgb=True,
)

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=384,
interpolation='bicubic',
backend='pillow'),
dict(
type='PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
),
]

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(480, 480),
interpolation='bicubic',
backend='pillow'),
dict(
type='CleanCaption',
keys=['question'],
),
dict(
type='PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
),
]

train_dataloader = dict(
batch_size=16,
num_workers=8,
dataset=dict(
type='GQA',
data_root='data/gqa',
data_prefix='images',
ann_file='annotations/train_balanced_questions.json',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
drop_last=True,
)

val_dataloader = dict(
batch_size=16,
num_workers=8,
dataset=dict(
type='GQA',
data_root='data/gqa',
data_prefix='images',
ann_file='annotations/testdev_balanced_questions.json',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='GQAAcc')

test_dataloader = dict(
batch_size=16,
num_workers=8,
dataset=dict(
type='GQA',
data_root='data/gqa',
data_prefix='images',
ann_file='annotations/testdev_balanced_questions.json',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
test_evaluator = val_evaluator
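
For a quick local sanity check, the following is a minimal sketch (not part of this PR) of how the dataset referenced by this config can be instantiated directly. It assumes the LAVIS annotation files and GQA images have already been placed under data/gqa as in the paths above; the pipeline is left empty just to inspect the loaded annotations.

# Minimal sketch: build the GQA dataset with the same paths as the config above.
# Assumes data/gqa/images/ and data/gqa/annotations/testdev_balanced_questions.json
# exist locally.
from mmpretrain.datasets import GQA

dataset = GQA(
    data_root='data/gqa',
    data_prefix='images',
    ann_file='annotations/testdev_balanced_questions.json',
    pipeline=[],  # no transforms, only inspect raw samples
)
print(len(dataset))            # number of question-answer pairs
print(dataset[0]['question'])  # question string of the first sample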
87 changes: 87 additions & 0 deletions configs/blip2/blip2-opt2.7b_8xb16_gqa.py
@@ -0,0 +1,87 @@
_base_ = [
'../_base_/datasets/gqa.py',
'../_base_/default_runtime.py',
]

# model settings
model = dict(
type='Blip2VQA',
tokenizer=dict(
type='AutoTokenizer', name_or_path='facebook/opt-2.7b',
use_fast=False),
vision_backbone=dict(
type='BEiTViT',
# eva-g without the final layer
arch=dict(
embed_dims=1408,
num_layers=39,
num_heads=16,
feedforward_channels=6144,
),
img_size=364,
patch_size=14,
out_indices=-2,
layer_scale_init_value=0.0,
use_abs_pos_emb=True,
use_rel_pos_bias=False,
frozen_stages=39,
final_norm=False,
use_shared_rel_pos_bias=False,
out_type='raw'),
text_backbone=dict(
type='OPTForCausalLM', name_or_path='facebook/opt-2.7b'),
multimodal_backbone=dict(
type='Qformer',
model_style='bert-base-uncased',
vision_model_width=1408,
add_cross_attention=True,
cross_attention_freq=2,
num_query_token=32),
vision_neck=dict(
type='LinearClsHead',
in_channels=768,
num_classes=2560,
),
prompt='Question: {} Short Answer:',
max_txt_len=10)

# data settings
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomResizedCrop', scale=224),
dict(type='PackInputs', algorithm_keys=['question', 'gt_answer']),
]

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(224, 224),
interpolation='bicubic',
backend='pillow'),
dict(
type='CleanCaption',
keys=['question'],
),
dict(type='PackInputs', algorithm_keys=['question', 'gt_answer']),
]

train_dataloader = dict(dataset=dict(pipeline=train_pipeline))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
test_dataloader = val_dataloader

# schedule settings
optim_wrapper = dict(optimizer=dict(type='AdamW', lr=1e-5, weight_decay=0.05))

param_scheduler = [
dict(
type='CosineAnnealingLR',
by_epoch=True,
begin=0,
end=10,
)
]

train_cfg = dict(max_epochs=10)
val_cfg = dict()
test_cfg = dict()
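
As a rough illustration of how this file composes with the _base_ GQA dataset config, the sketch below (not part of this PR) loads the merged config with mmengine and prints a few fields. It assumes an mmpretrain checkout where both config files sit at the paths used in this PR.

# Sketch only: inspect the merged config.
from mmengine.config import Config

cfg = Config.fromfile('configs/blip2/blip2-opt2.7b_8xb16_gqa.py')
print(cfg.model.type)                     # 'Blip2VQA'
print(cfg.train_dataloader.dataset.type)  # 'GQA', inherited from the base file
print(cfg.train_dataloader.dataset.pipeline[1]['type'])  # 'RandomResizedCrop', overridden here
print(cfg.val_evaluator['type'])          # 'GQAAcc'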
2 changes: 2 additions & 0 deletions mmpretrain/datasets/__init__.py
@@ -38,6 +38,7 @@
from .coco_retrieval import COCORetrieval
from .coco_vqa import COCOVQA
from .flamingo import FlamingoEvalCOCOCaption, FlamingoEvalCOCOVQA
from .gqa_dataset import GQA
from .refcoco import RefCOCO
from .scienceqa import ScienceQA
from .visual_genome import VisualGenomeQA
@@ -51,4 +52,5 @@
'RefCOCO',
'VisualGenomeQA',
'ScienceQA',
'GQA',
])
70 changes: 70 additions & 0 deletions mmpretrain/datasets/gqa_dataset.py
@@ -0,0 +1,70 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from typing import List

import mmengine
from mmengine.dataset import BaseDataset

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class GQA(BaseDataset):
"""GQA dataset.

We use the annotation file from LAVIS, and you can download all annotation files from the following links: # noqa: E501

train:
https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/train_balanced_questions.json # noqa: E501
val:
https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/testdev_balanced_questions.json # noqa: E501
test:
https://storage.googleapis.com/sfr-vision-language-research/LAVIS/datasets/gqa/test_balanced_questions.json # noqa: E501

and images from the official website:
https://cs.stanford.edu/people/dorarad/gqa/index.html

Args:
data_root (str): The root directory for ``data_prefix``, ``ann_file``
and ``question_file``.
data_prefix (str): The directory of images.
ann_file (str, optional): Annotation file path for training and
validation. Defaults to an empty string.
**kwargs: Other keyword arguments in :class:`BaseDataset`.
"""

def __init__(self,
data_root: str,
data_prefix: str,
ann_file: str = '',
**kwarg):
super().__init__(
data_root=data_root,
data_prefix=dict(img_path=data_prefix),
ann_file=ann_file,
**kwarg,
)

def load_data_list(self) -> List[dict]:
"""Load data list."""
annotations = mmengine.load(self.ann_file)

data_list = []
for ann in annotations:
# ann example
# {
# 'question': "Is it overcast?",
# 'answer': 'no',
# 'image': 'n161313.jpg',
# 'question_id': 262148000,
# ....
# }
data_info = dict()
data_info['img_path'] = osp.join(self.data_prefix['img_path'],
ann['image'])
data_info['question'] = ann['question']
data_info['gt_answer'] = ann['answer']

data_list.append(data_info)

return data_list
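
For reference, each entry emitted by load_data_list is a plain dict whose keys match what the PackInputs transforms in the configs expect. A sketch of one entry, assuming data_root='data/gqa' and data_prefix='images' as in the base config, and using the annotation example from the inline comment:

# Illustrative only; values come from the example annotation above.
data_info = {
    'img_path': 'data/gqa/images/n161313.jpg',  # data_prefix joined with ann['image']
    'question': 'Is it overcast?',
    'gt_answer': 'no',
}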
3 changes: 2 additions & 1 deletion mmpretrain/evaluation/metrics/__init__.py
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .caption import COCOCaption
from .gqa import GQAAcc
from .multi_label import AveragePrecision, MultiLabelMetric
from .multi_task import MultiTasksMetric
from .retrieval import RetrievalRecall
@@ -13,5 +14,5 @@
'Accuracy', 'SingleLabelMetric', 'MultiLabelMetric', 'AveragePrecision',
'MultiTasksMetric', 'VOCAveragePrecision', 'VOCMultiLabelMetric',
'ConfusionMatrix', 'RetrievalRecall', 'VQAAcc', 'ReportVQA', 'COCOCaption',
'VisualGroundingMetric', 'ScienceQAMetric'
'VisualGroundingMetric', 'ScienceQAMetric', 'GQAAcc'
]
78 changes: 78 additions & 0 deletions mmpretrain/evaluation/metrics/gqa.py
@@ -0,0 +1,78 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

from mmengine.evaluator import BaseMetric

from mmpretrain.evaluation.metrics.vqa import (_process_digit_article,
_process_punctuation)
from mmpretrain.registry import METRICS


@METRICS.register_module()
class GQAAcc(BaseMetric):
"""GQA Acc metric.

Compute GQA accuracy.

Args:
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
prefix (str, optional): The prefix that will be added in the metric
names to disambiguate homonymous metrics of different evaluators.
If prefix is not provided in the argument, self.default_prefix
will be used instead. Defaults to None.
"""
default_prefix = 'GQA'

def __init__(self,
collect_device: str = 'cpu',
prefix: Optional[str] = None) -> None:
super().__init__(collect_device=collect_device, prefix=prefix)

def process(self, data_batch, data_samples) -> None:
"""Process one batch of data samples.

The processed results should be stored in ``self.results``, which will
be used to compute the metrics when all batches have been processed.

Args:
data_batch: A batch of data from the dataloader.
data_samples (Sequence[dict]): A batch of outputs from the model.
"""
for sample in data_samples:
gt_answer = sample.get('gt_answer')
result = {
'pred_answer': sample.get('pred_answer'),
'gt_answer': gt_answer
}

self.results.append(result)

def compute_metrics(self, results: List) -> dict:
"""Compute the metrics from processed results.

Args:
results (List): The processed results of each batch.

Returns:
Dict: The computed metrics. The keys are the names of the metrics,
and the values are corresponding results.
"""
acc = []
for result in results:
pred_answer = self._process_answer(result['pred_answer'])
gt_answer = self._process_answer(result['gt_answer'])
gqa_acc = 1 if pred_answer == gt_answer else 0
acc.append(gqa_acc)

accuracy = sum(acc) / len(acc)

metrics = {'acc': accuracy}
return metrics

def _process_answer(self, answer) -> str:
answer = _process_punctuation(answer)
answer = _process_digit_article(answer)
return answer
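
A minimal usage sketch (not part of this PR) exercising the metric outside a Runner. GQAAcc only reads 'pred_answer' and 'gt_answer' from each sample dict; the answers below are illustrative.

# Sketch: feed pre-computed answers straight into the metric.
from mmpretrain.evaluation.metrics import GQAAcc

metric = GQAAcc()
metric.process(
    data_batch=None,
    data_samples=[
        {'pred_answer': 'no', 'gt_answer': 'no'},    # counted as correct
        {'pred_answer': 'cat', 'gt_answer': 'dog'},  # counted as wrong
    ])
print(metric.compute_metrics(metric.results))  # {'acc': 0.5}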
6 changes: 3 additions & 3 deletions mmpretrain/evaluation/metrics/scienceqa.py
@@ -86,9 +86,9 @@ def process(self, data_batch, data_samples) -> None:
result['answer'] = data_sample.get('gt_answer')
hint = data_sample.get('hint')
has_image = data_sample.get('has_image', False)
- result[
- 'no_context'] = True if not has_image and hint is None else False # noqa
- result['has_text'] = True if hint is not None else False
+ result['no_context'] = True if not has_image and len(
+ hint) == 0 else False # noqa
+ result['has_text'] = True if len(hint) > 0 else False
result['has_image'] = has_image

# Save the result to `self.results`.
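
A small sketch of the behaviour this hunk fixes, assuming (as the new code implies) that samples without a hint carry an empty string rather than None:

# Illustration only: with an empty-string hint, the old `hint is None`
# check never fires, while the length-based check does.
hint = ''              # assumed value for a sample that has no hint
print(hint is None)    # False -> old code never marked the sample as no-context
print(len(hint) == 0)  # True  -> new code marks it correctly
print(len(hint) > 0)   # False -> 'has_text' is now False, as intended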
30 changes: 30 additions & 0 deletions tests/test_evaluation/test_metrics/test_gqa.py
@@ -0,0 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.evaluator import Evaluator

from mmpretrain.structures import DataSample


class TestGQAAcc:

def test_evaluate(self):
meta_info = {
'pred_answer': 'dog',
'gt_answer': 'dog',
}
data_sample = DataSample(metainfo=meta_info)
data_samples = [data_sample for _ in range(10)]
evaluator = Evaluator(dict(type='mmpretrain.GQAAcc'))
evaluator.process(data_samples)
res = evaluator.evaluate(4)
assert res['GQA/acc'] == 1.0

meta_info = {
'pred_answer': 'dog',
'gt_answer': 'cat',
}
data_sample = DataSample(metainfo=meta_info)
data_samples = [data_sample for _ in range(10)]
evaluator = Evaluator(dict(type='mmpretrain.GQAAcc'))
evaluator.process(data_samples)
res = evaluator.evaluate(4)
assert res['GQA/acc'] == 0.0
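
Assuming the same Evaluator plumbing as above, a mixed batch could be covered with one more case. This is a sketch, not part of the PR; the method name is hypothetical.

    def test_evaluate_partial(self):
        # Half-correct predictions should give acc == 0.5.
        samples = [
            DataSample(metainfo={'pred_answer': 'dog', 'gt_answer': 'dog'}),
            DataSample(metainfo={'pred_answer': 'dog', 'gt_answer': 'cat'}),
        ]
        evaluator = Evaluator(dict(type='mmpretrain.GQAAcc'))
        evaluator.process(samples)
        res = evaluator.evaluate(2)
        assert res['GQA/acc'] == 0.5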