-
Notifications
You must be signed in to change notification settings - Fork 2.6k
/
voc_aug.py
87 lines (69 loc) · 2.85 KB
/
voc_aug.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import argparse
import os.path as osp
from functools import partial
import mmcv
import numpy as np
from PIL import Image
from scipy.io import loadmat
# Expected number of entries in the augmented training list; main() checks
# the generated list against this value before writing it out.
AUG_LEN = 10582
def convert_mat(mat_file, in_dir, out_dir):
    """Convert one ``.mat`` annotation file into a PNG segmentation mask.

    Args:
        mat_file (str): Name of the ``.mat`` file, relative to ``in_dir``.
        in_dir (str): Directory the ``.mat`` file is read from.
        out_dir (str): Directory the resulting ``.png`` file is written to.
    """
    data = loadmat(osp.join(in_dir, mat_file))
    # The mask lives under GTcls -> Segmentation; it is cast to uint8 before
    # being handed to PIL.
    mask = data['GTcls'][0]['Segmentation'][0].astype(np.uint8)
    # Swap only the trailing extension. ``str.replace`` would also rewrite a
    # '.mat' that happens to appear earlier in the name.
    stem, _ = osp.splitext(mat_file)
    seg_filename = osp.join(out_dir, stem + '.png')
    Image.fromarray(mask).save(seg_filename, 'PNG')
def generate_aug_list(merged_list, excluded_list):
    """Return the unique entries of ``merged_list`` not in ``excluded_list``.

    Entries are returned in the order they first appear in ``merged_list``.
    A plain set difference yields an arbitrary order, which for strings can
    change from one interpreter run to the next, so files written from the
    result would not be reproducible.

    Args:
        merged_list (list): Candidate entries; may contain duplicates.
        excluded_list (list): Entries that must not appear in the result.

    Returns:
        list: Deduplicated entries of ``merged_list`` minus
            ``excluded_list``, in first-seen order.
    """
    # Build the lookup set once so each membership test is O(1).
    excluded = set(excluded_list)
    # dict keeps insertion order, so fromkeys() deduplicates without
    # reordering.
    return list(
        dict.fromkeys(
            item for item in merged_list if item not in excluded))
def parse_args():
    """Parse the command line options of this script.

    Returns:
        argparse.Namespace: Holds ``devkit_path`` and ``aug_path`` (required
            positionals), ``out_dir`` (``None`` unless ``-o``/``--out_dir``
            is given) and ``nproc`` (int, 1 by default).
    """
    cli = argparse.ArgumentParser(
        description='Convert PASCAL VOC annotations to mmsegmentation format')
    # Required positionals, registered in the order they must be given.
    for dest, help_text in (('devkit_path', 'pascal voc devkit path'),
                            ('aug_path', 'pascal voc aug path')):
        cli.add_argument(dest, help=help_text)
    # Optional settings.
    cli.add_argument('-o', '--out_dir', help='output path')
    cli.add_argument(
        '--nproc', type=int, default=1, help='number of process')
    return cli.parse_args()
def _read_list(list_file):
    """Return the whitespace-stripped lines of ``list_file``."""
    with open(list_file) as f:
        return [line.strip() for line in f]


def _write_list(list_file, names):
    """Write ``names`` to ``list_file``, one entry per line."""
    with open(list_file, 'w') as f:
        f.writelines(name + '\n' for name in names)


def main():
    """Convert the augmented annotations and write the split list files.

    Steps:
        1. Convert every ``.mat`` file under ``<aug_path>/dataset/cls`` to a
           PNG in the output directory (``--out_dir`` if given, otherwise
           ``<devkit_path>/VOC2012/SegmentationClassAug``).
        2. Write ``trainaug.txt``: original train list plus the aug trainval
           list, minus the val list.
        3. Write ``aug.txt``: the aug trainval list minus the original train
           and val lists.

    Raises:
        AssertionError: If either generated list does not have the expected
            number of entries.
    """
    args = parse_args()
    devkit_path = args.devkit_path
    aug_path = args.aug_path
    nproc = args.nproc
    if args.out_dir is None:
        out_dir = osp.join(devkit_path, 'VOC2012', 'SegmentationClassAug')
    else:
        out_dir = args.out_dir
    mmcv.mkdir_or_exist(out_dir)
    in_dir = osp.join(aug_path, 'dataset', 'cls')

    mmcv.track_parallel_progress(
        partial(convert_mat, in_dir=in_dir, out_dir=out_dir),
        list(mmcv.scandir(in_dir, suffix='.mat')),
        nproc=nproc)

    # All split lists of the devkit are read from / written to this folder.
    split_dir = osp.join(devkit_path, 'VOC2012/ImageSets/Segmentation')

    full_aug_list = _read_list(osp.join(aug_path, 'dataset', 'trainval.txt'))
    ori_train_list = _read_list(osp.join(split_dir, 'train.txt'))
    val_list = _read_list(osp.join(split_dir, 'val.txt'))

    aug_train_list = generate_aug_list(ori_train_list + full_aug_list,
                                       val_list)
    # Raised explicitly instead of using ``assert`` so the check still runs
    # under ``python -O``; exception type and message are kept unchanged.
    if len(aug_train_list) != AUG_LEN:
        raise AssertionError('len(aug_train_list) != {}'.format(AUG_LEN))
    _write_list(osp.join(split_dir, 'trainaug.txt'), aug_train_list)

    aug_list = generate_aug_list(full_aug_list, ori_train_list + val_list)
    expected_aug_len = AUG_LEN - len(ori_train_list)
    if len(aug_list) != expected_aug_len:
        raise AssertionError(
            'len(aug_list) != {}'.format(expected_aug_len))
    _write_list(osp.join(split_dir, 'aug.txt'), aug_list)

    print('Done!')
# Run the conversion only when this file is executed as a script.
if __name__ == '__main__':
    main()