Skip to content

Commit

Permalink
Merge 292c69f into 7581b76
Browse files Browse the repository at this point in the history
  • Loading branch information
yyk-wew authored Jun 15, 2023
2 parents 7581b76 + 292c69f commit 5a76048
Show file tree
Hide file tree
Showing 3 changed files with 194 additions and 1 deletion.
80 changes: 80 additions & 0 deletions configs/_base_/datasets/vizwiz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# data settings

data_preprocessor = dict(
mean=[122.770938, 116.7460125, 104.09373615],
std=[68.5005327, 66.6321579, 70.32316305],
to_rgb=True,
)

train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
scale=384,
interpolation='bicubic',
backend='pillow'),
dict(
type='PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
),
]

test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
scale=(480, 480),
interpolation='bicubic',
backend='pillow'),
dict(
type='CleanCaption',
keys=['question'],
),
dict(
type='PackInputs',
algorithm_keys=['question', 'gt_answer', 'gt_answer_weight'],
meta_keys=['question_id', 'image_id'],
),
]

train_dataloader = dict(
batch_size=16,
num_workers=8,
dataset=dict(
type='VizWiz',
data_root='data/vizwiz/Images',
data_prefix='',
ann_file='Annotations/train.json',
pipeline=train_pipeline),
sampler=dict(type='DefaultSampler', shuffle=True),
persistent_workers=True,
drop_last=True,
)

val_dataloader = dict(
batch_size=16,
num_workers=8,
dataset=dict(
type='VizWiz',
data_root='data/vizwiz/Images',
data_prefix='',
ann_file='Annotations/val.json',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
persistent_workers=True,
)
val_evaluator = dict(type='VizWizAcc')

test_dataloader = dict(
batch_size=16,
num_workers=8,
dataset=dict(
type='VizWiz',
data_root='data/vizwiz/Images',
data_prefix='',
ann_file='Annotations/test.json',
pipeline=test_pipeline),
sampler=dict(type='DefaultSampler', shuffle=False),
)
test_evaluator = dict(type='ReportVQA', file_path='vqa_test.json')
3 changes: 2 additions & 1 deletion mmpretrain/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,10 @@
from .textvqa import TextVQA
from .visual_genome import VisualGenomeQA
from .vsr import VSR
from .vizwiz import VizWiz

__all__.extend([
'COCOCaption', 'COCORetrieval', 'COCOVQA', 'FlamingoEvalCOCOCaption',
'FlamingoEvalCOCOVQA', 'OCRVQA', 'RefCOCO', 'VisualGenomeQA',
'ScienceQA', 'NoCaps', 'GQA', 'TextVQA', 'VSR'
'ScienceQA', 'NoCaps', 'GQA', 'TextVQA', 'VSR', 'VizWiz'
])
112 changes: 112 additions & 0 deletions mmpretrain/datasets/vizwiz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright (c) OpenMMLab. All rights reserved.
from collections import Counter
from typing import List

import mmengine
from mmengine.dataset import BaseDataset

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class VizWiz(BaseDataset):
"""VizWiz dataset.
Args:
data_root (str): The root directory for ``data_prefix``, ``ann_file``
and ``question_file``.
data_prefix (str): The directory of images.
ann_file (str, optional): Annotation file path for training and
validation. Defaults to an empty string.
**kwargs: Other keyword arguments in :class:`BaseDataset`.
"""

def __init__(self,
data_root: str,
data_prefix: str,
ann_file: str = '',
**kwarg):
super().__init__(
data_root=data_root,
data_prefix=dict(img_path=data_prefix),
ann_file=ann_file,
**kwarg,
)

def load_data_list(self) -> List[dict]:
"""Load data list."""
annotations = mmengine.load(self.ann_file)

data_list = []
for ann in annotations:
# {
# "image": "VizWiz_val_00000001.jpg",
# "question": "Can you tell me what this medicine is please?",
# "answers": [
# {
# "answer": "no",
# "answer_confidence": "yes"
# },
# {
# "answer": "unanswerable",
# "answer_confidence": "yes"
# },
# {
# "answer": "night time",
# "answer_confidence": "maybe"
# },
# {
# "answer": "unanswerable",
# "answer_confidence": "yes"
# },
# {
# "answer": "night time",
# "answer_confidence": "maybe"
# },
# {
# "answer": "night time cold medicine",
# "answer_confidence": "maybe"
# },
# {
# "answer": "night time",
# "answer_confidence": "maybe"
# },
# {
# "answer": "night time",
# "answer_confidence": "maybe"
# },
# {
# "answer": "night time",
# "answer_confidence": "maybe"
# },
# {
# "answer": "night time medicine",
# "answer_confidence": "yes"
# }
# ],
# "answer_type": "other",
# "answerable": 1
# },
data_info = dict()
data_info['question'] = ann['question']
data_info['img_path'] = mmengine.join_path(
self.data_prefix['img_path'], ann['image'])

if 'answerable' not in ann:
data_list.append(data_info)
else:
if ann['answerable'] == 1:
# add answer_weight & answer_count, delete duplicate answer
answers = []
for item in ann.pop('answers'):
if item['answer_confidence'] == 'yes' and item[
'answer'] != 'unanswerable':
answers.append(item['answer'])
count = Counter(answers)
answer_weight = [i / len(answers) for i in count.values()]
data_info['gt_answer'] = list(count.keys())
data_info['gt_answer_weight'] = answer_weight
# data_info.update(ann)
data_list.append(data_info)

return data_list

0 comments on commit 5a76048

Please sign in to comment.