Skip to content

Commit

Permalink
[Tools] Support respliting data_batch with tag (open-mmlab#7641)
Browse files Browse the repository at this point in the history
* support respliting data_batch with tag

* add citations

* add a unit test

* fix lint
  • Loading branch information
Czm369 authored and SakiRinn committed Mar 17, 2023
1 parent 697fab2 commit 1ee91e7
Show file tree
Hide file tree
Showing 3 changed files with 142 additions and 1 deletion.
3 changes: 2 additions & 1 deletion mmdet/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from .logger import get_caller_name, get_root_logger, log_img_scale
from .misc import find_latest_checkpoint, update_data_root
from .setup_env import setup_multi_processes
from .split_batch import split_batch

__all__ = [
'get_root_logger', 'collect_env', 'find_latest_checkpoint',
'update_data_root', 'setup_multi_processes', 'get_caller_name',
'log_img_scale', 'compat_cfg'
'log_img_scale', 'compat_cfg', 'split_batch'
]
45 changes: 45 additions & 0 deletions mmdet/utils/split_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch


def split_batch(img, img_metas, kwargs):
"""Split data_batch by tags.
Code is modified from
<https://github.com/microsoft/SoftTeacher/blob/main/ssod/utils/structure_utils.py> # noqa: E501
Args:
img (Tensor): of shape (N, C, H, W) encoding input images.
Typically these should be mean centered and std scaled.
img_metas (list[dict]): List of image info dict where each dict
has: 'img_shape', 'scale_factor', 'flip', and may also contain
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
For details on the values of these keys, see
:class:`mmdet.datasets.pipelines.Collect`.
kwargs (dict): Specific to concrete implementation.
Returns:
data_groups (dict): a dict that data_batch splited by tags,
such as 'sup', 'unsup_teacher', and 'unsup_student'.
"""

# only stack img in the batch
def fuse_list(obj_list, obj):
return torch.stack(obj_list) if isinstance(obj,
torch.Tensor) else obj_list

# select data with tag from data_batch
def select_group(data_batch, current_tag):
group_flag = [tag == current_tag for tag in data_batch['tag']]
return {
k: fuse_list([vv for vv, gf in zip(v, group_flag) if gf], v)
for k, v in data_batch.items()
}

kwargs.update({'img': img, 'img_metas': img_metas})
kwargs.update({'tag': [meta['tag'] for meta in img_metas]})
tags = list(set(kwargs['tag']))
data_groups = {tag: select_group(kwargs, tag) for tag in tags}
for tag, group in data_groups.items():
group.pop('tag')
return data_groups
95 changes: 95 additions & 0 deletions tests/test_utils/test_split_batch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from copy import deepcopy

import mmcv
import numpy as np
import torch

from mmdet.utils import split_batch


def test_split_batch():
img_root = osp.join(osp.dirname(__file__), '../data/color.jpg')
img = mmcv.imread(img_root, 'color')
h, w, _ = img.shape
gt_bboxes = np.array([[0.2 * w, 0.2 * h, 0.4 * w, 0.4 * h],
[0.6 * w, 0.6 * h, 0.8 * w, 0.8 * h]],
dtype=np.float32)
gt_lables = np.ones(gt_bboxes.shape[0], dtype=np.int64)

img = torch.tensor(img).permute(2, 0, 1)
meta = dict()
meta['filename'] = img_root
meta['ori_shape'] = img.shape
meta['img_shape'] = img.shape
meta['img_norm_cfg'] = {
'mean': np.array([103.53, 116.28, 123.675], dtype=np.float32),
'std': np.array([1., 1., 1.], dtype=np.float32),
'to_rgb': False
}
meta['pad_shape'] = img.shape
# For example, tag include sup, unsup_teacher and unsup_student,
# in order to distinguish the difference between the three groups of data,
# the scale_factor of sup is [0.5, 0.5, 0.5, 0.5]
# the scale_factor of unsup_teacher is [1.0, 1.0, 1.0, 1.0]
# the scale_factor of unsup_student is [2.0, 2.0, 2.0, 2.0]
imgs = img.unsqueeze(0).repeat(9, 1, 1, 1)
img_metas = []
tags = [
'sup', 'unsup_teacher', 'unsup_student', 'unsup_teacher',
'unsup_student', 'unsup_teacher', 'unsup_student', 'unsup_teacher',
'unsup_student'
]
for tag in tags:
img_meta = deepcopy(meta)
if tag == 'sup':
img_meta['scale_factor'] = [0.5, 0.5, 0.5, 0.5]
img_meta['tag'] = 'sup'
elif tag == 'unsup_teacher':
img_meta['scale_factor'] = [1.0, 1.0, 1.0, 1.0]
img_meta['tag'] = 'unsup_teacher'
elif tag == 'unsup_student':
img_meta['scale_factor'] = [2.0, 2.0, 2.0, 2.0]
img_meta['tag'] = 'unsup_student'
else:
continue
img_metas.append(img_meta)
kwargs = dict()
kwargs['gt_bboxes'] = [torch.tensor(gt_bboxes)] + [torch.zeros(0, 4)] * 8
kwargs['gt_lables'] = [torch.tensor(gt_lables)] + [torch.zeros(0, )] * 8
data_groups = split_batch(imgs, img_metas, kwargs)
assert set(data_groups.keys()) == set(tags)
assert data_groups['sup']['img'].shape == (1, 3, h, w)
assert data_groups['unsup_teacher']['img'].shape == (4, 3, h, w)
assert data_groups['unsup_student']['img'].shape == (4, 3, h, w)
# the scale_factor of sup is [0.5, 0.5, 0.5, 0.5]
assert data_groups['sup']['img_metas'][0]['scale_factor'] == [
0.5, 0.5, 0.5, 0.5
]
# the scale_factor of unsup_teacher is [1.0, 1.0, 1.0, 1.0]
assert data_groups['unsup_teacher']['img_metas'][0]['scale_factor'] == [
1.0, 1.0, 1.0, 1.0
]
assert data_groups['unsup_teacher']['img_metas'][1]['scale_factor'] == [
1.0, 1.0, 1.0, 1.0
]
assert data_groups['unsup_teacher']['img_metas'][2]['scale_factor'] == [
1.0, 1.0, 1.0, 1.0
]
assert data_groups['unsup_teacher']['img_metas'][3]['scale_factor'] == [
1.0, 1.0, 1.0, 1.0
]
# the scale_factor of unsup_student is [2.0, 2.0, 2.0, 2.0]
assert data_groups['unsup_student']['img_metas'][0]['scale_factor'] == [
2.0, 2.0, 2.0, 2.0
]
assert data_groups['unsup_student']['img_metas'][1]['scale_factor'] == [
2.0, 2.0, 2.0, 2.0
]
assert data_groups['unsup_student']['img_metas'][2]['scale_factor'] == [
2.0, 2.0, 2.0, 2.0
]
assert data_groups['unsup_student']['img_metas'][3]['scale_factor'] == [
2.0, 2.0, 2.0, 2.0
]

0 comments on commit 1ee91e7

Please sign in to comment.