# data_util.py
# Author: Wentao Yuan ([email protected]) 05/31/2018
import numpy as np
import tensorflow as tf
from tensorpack import dataflow


def resample_pcd(pcd, n):
    """Drop or duplicate points so that pcd has exactly n points."""
    idx = np.random.permutation(pcd.shape[0])
    if idx.shape[0] < n:
        idx = np.concatenate([idx, np.random.randint(pcd.shape[0], size=n - pcd.shape[0])])
    return pcd[idx[:n]]
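
# Illustrative note (not in the original file): for an (m, 3) cloud the result is
# always (n, 3), e.g. resample_pcd(np.random.rand(1200, 3), 2048) pads with randomly
# duplicated points up to 2048, while resample_pcd(np.random.rand(5000, 3), 2048)
# randomly subsamples down to 2048.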


class PreprocessData(dataflow.ProxyDataFlow):
    """Resample every (id, input, gt) datapoint to fixed input/output sizes."""

    def __init__(self, ds, input_size, output_size):
        super(PreprocessData, self).__init__(ds)
        self.input_size = input_size
        self.output_size = output_size

    def get_data(self):
        for id, input, gt in self.ds.get_data():
            input = resample_pcd(input, self.input_size)
            gt = resample_pcd(gt, self.output_size)
            yield id, input, gt


def lmdb_dataflow(lmdb_path, batch_size, input_size, output_size, is_training, test_speed=False):
    """Load an LMDB dataset and build a tensorpack dataflow that yields batches."""
    df = dataflow.LMDBSerializer.load(lmdb_path, shuffle=False)
    size = df.size()
    if is_training:
        # Shuffle within a sliding buffer and prefetch in a background process.
        df = dataflow.LocallyShuffleData(df, buffer_size=2000)
        df = dataflow.PrefetchData(df, nr_prefetch=500, nr_proc=1)
    df = PreprocessData(df, input_size, output_size)
    if is_training:
        # Run the resampling in parallel worker processes over ZMQ.
        df = dataflow.PrefetchDataZMQ(df, nr_proc=8)
    df = dataflow.BatchData(df, batch_size, use_list=True)
    df = dataflow.RepeatedData(df, -1)
    if test_speed:
        dataflow.TestDataSpeed(df, size=1000).start()
    df.reset_state()
    return df, size


def get_queued_data(generator, dtypes, shapes, queue_capacity=10):
    """Feed batches from a Python generator into a TensorFlow FIFO queue via a queue runner."""
    assert len(dtypes) == len(shapes), 'dtypes and shapes must have the same length'
    queue = tf.FIFOQueue(queue_capacity, dtypes, shapes)
    placeholders = [tf.placeholder(dtype, shape) for dtype, shape in zip(dtypes, shapes)]
    enqueue_op = queue.enqueue(placeholders)
    close_op = queue.close(cancel_pending_enqueues=True)
    # Each enqueue pulls the next batch from the generator and feeds the placeholders.
    feed_fn = lambda: {placeholder: value for placeholder, value in zip(placeholders, next(generator))}
    queue_runner = tf.contrib.training.FeedingQueueRunner(queue, [enqueue_op], close_op, feed_fns=[feed_fn])
    tf.train.add_queue_runner(queue_runner)
    return queue.dequeue()
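

# A minimal usage sketch, not part of the original file. It assumes a TF 1.x
# environment with tf.contrib (as the code above requires), the older tensorpack
# get_data() API used above, a hypothetical LMDB file at 'data/train.lmdb', and
# illustrative sizes (batch 32, 3000 input points, 16384 ground-truth points).
if __name__ == '__main__':
    df_train, num_train = lmdb_dataflow(
        'data/train.lmdb', batch_size=32, input_size=3000, output_size=16384,
        is_training=True)
    batch_gen = df_train.get_data()

    # ids are strings; inputs and ground truths are fixed-size float point clouds.
    ids, inputs, gts = get_queued_data(
        batch_gen,
        dtypes=[tf.string, tf.float32, tf.float32],
        shapes=[[32], [32, 3000, 3], [32, 16384, 3]])

    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        id_batch, input_batch, gt_batch = sess.run([ids, inputs, gts])
        print(id_batch[0], input_batch.shape, gt_batch.shape)
        coord.request_stop()
        coord.join(threads)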