prepare_train_data.py
from __future__ import division
import argparse
import os
from glob import glob

import numpy as np
import scipy.misc  # NOTE: scipy.misc.imsave requires an older SciPy release (it was removed in newer versions)
from joblib import Parallel, delayed

parser = argparse.ArgumentParser()
parser.add_argument("--dataset_dir", type=str, required=True, help="where the dataset is stored")
parser.add_argument("--dataset_name", type=str, required=True,
                    choices=["kitti_raw_eigen", "kitti_raw_stereo", "kitti_odom", "cityscapes"])
parser.add_argument("--dump_root", type=str, required=True, help="where to dump the data")
parser.add_argument("--seq_length", type=int, required=True, help="length of each training sequence")
parser.add_argument("--img_height", type=int, default=128, help="image height")
parser.add_argument("--img_width", type=int, default=416, help="image width")
parser.add_argument("--num_threads", type=int, default=4, help="number of threads to use")
args = parser.parse_args()


def concat_image_seq(seq):
    """Concatenate a list of frames horizontally into a single image."""
    for i, im in enumerate(seq):
        if i == 0:
            res = im
        else:
            res = np.hstack((res, im))
    return res
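

# Shape note (illustration only, assuming each frame is (img_height, img_width, 3)):
# concat_image_seq turns a sequence of N such frames into one array of shape
# (img_height, N * img_width, 3), i.e. the frames laid out side by side.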
def dump_example(n, args):
    if n % 2000 == 0:
        print('Progress %d/%d....' % (n, data_loader.num_train))
    example = data_loader.get_train_example_with_idx(n)
    if example is False:
        return
    image_seq = concat_image_seq(example['image_seq'])
    intrinsics = example['intrinsics']
    fx = intrinsics[0, 0]
    fy = intrinsics[1, 1]
    cx = intrinsics[0, 2]
    cy = intrinsics[1, 2]
    dump_dir = os.path.join(args.dump_root, example['folder_name'])
    # if not os.path.isdir(dump_dir):
    #     os.makedirs(dump_dir, exist_ok=True)
    try:
        os.makedirs(dump_dir)
    except OSError:
        if not os.path.isdir(dump_dir):
            raise
    dump_img_file = dump_dir + '/%s.jpg' % example['file_name']
    scipy.misc.imsave(dump_img_file, image_seq.astype(np.uint8))
    dump_cam_file = dump_dir + '/%s_cam.txt' % example['file_name']
    with open(dump_cam_file, 'w') as f:
        # Flattened 3x3 camera intrinsics: fx, 0, cx, 0, fy, cy, 0, 0, 1
        f.write('%f,0.,%f,0.,%f,%f,0.,0.,1.' % (fx, cx, fy, cy))
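

# Reading the *_cam.txt file back (hypothetical helper, not used by this script):
# each file stores the flattened 3x3 intrinsics matrix
# [[fx, 0, cx], [0, fy, cy], [0, 0, 1]], so it could be recovered with e.g.
#
#   def load_intrinsics(cam_file):
#       with open(cam_file) as f:
#           vals = [float(v) for v in f.read().split(',')]
#       return np.array(vals).reshape(3, 3)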
def main():
    if not os.path.exists(args.dump_root):
        os.makedirs(args.dump_root)

    global data_loader
    if args.dataset_name == 'kitti_odom':
        from kitti.kitti_odom_loader import kitti_odom_loader
        data_loader = kitti_odom_loader(args.dataset_dir,
                                        img_height=args.img_height,
                                        img_width=args.img_width,
                                        seq_length=args.seq_length)

    if args.dataset_name == 'kitti_raw_eigen':
        from kitti.kitti_raw_loader import kitti_raw_loader
        data_loader = kitti_raw_loader(args.dataset_dir,
                                       split='eigen',
                                       img_height=args.img_height,
                                       img_width=args.img_width,
                                       seq_length=args.seq_length)

    if args.dataset_name == 'kitti_raw_stereo':
        from kitti.kitti_raw_loader import kitti_raw_loader
        data_loader = kitti_raw_loader(args.dataset_dir,
                                       split='stereo',
                                       img_height=args.img_height,
                                       img_width=args.img_width,
                                       seq_length=args.seq_length)

    if args.dataset_name == 'cityscapes':
        from cityscapes.cityscapes_loader import cityscapes_loader
        data_loader = cityscapes_loader(args.dataset_dir,
                                        img_height=args.img_height,
                                        img_width=args.img_width,
                                        seq_length=args.seq_length)

    Parallel(n_jobs=args.num_threads)(
        delayed(dump_example)(n, args) for n in range(data_loader.num_train))

    # Split the dumped frames into train/val lists
    np.random.seed(8964)
    subfolders = os.listdir(args.dump_root)
    with open(os.path.join(args.dump_root, 'train.txt'), 'w') as tf:
        with open(os.path.join(args.dump_root, 'val.txt'), 'w') as vf:
            for s in subfolders:
                if not os.path.isdir(os.path.join(args.dump_root, s)):
                    continue
                imfiles = glob(os.path.join(args.dump_root, s, '*.jpg'))
                frame_ids = [os.path.basename(fi).split('.')[0] for fi in imfiles]
                for frame in frame_ids:
                    # Hold out roughly 10% of frames for validation
                    if np.random.random() < 0.1:
                        vf.write('%s %s\n' % (s, frame))
                    else:
                        tf.write('%s %s\n' % (s, frame))


main()
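
# Example invocation (sketch; paths are placeholders, adjust to your setup):
#   python prepare_train_data.py \
#       --dataset_dir=/path/to/kitti/raw \
#       --dataset_name=kitti_raw_eigen \
#       --dump_root=/path/to/resized_dump \
#       --seq_length=3 \
#       --img_width=416 \
#       --img_height=128 \
#       --num_threads=4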