forked from bluerythem/pcn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
demo.py
59 lines (48 loc) · 1.97 KB
/
demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
# Author: Wentao Yuan ([email protected]) 05/31/2018
import argparse
import importlib
import models
import numpy as np
import tensorflow as tf
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from open3d import *
from pdb import set_trace as st
def plot_pcd(ax, pcd):
    """Draw ``pcd`` as a 3-D scatter on ``ax`` inside a fixed viewing cube.

    Args:
        ax: matplotlib axes created with ``projection='3d'`` (as in the
            ``__main__`` block of this script).
        pcd: array indexed as ``pcd[:, k]`` for k = 0, 1, 2; the three
            columns are passed to ``scatter`` as x, y, z.

    Points are coloured by their first coordinate on the ``'Reds'`` colormap,
    normalised to [-1, 0.5]. The axes are hidden and all three limits are
    fixed to [-0.3, 0.3]. This presumably assumes the cloud is normalised to
    fit that cube (TODO confirm).
    """
    x, y, z = pcd[:, 0], pcd[:, 1], pcd[:, 2]
    # zdir='y' is forwarded to Axes3D.scatter. Per matplotlib it swaps which
    # data column is drawn as the depth axis, presumably to show the cloud y-up.
    ax.scatter(x, y, z, zdir='y', c=x, s=0.5, cmap='Reds', vmin=-1, vmax=0.5)
    ax.set_axis_off()
    half_extent = 0.3
    for set_lim in (ax.set_xlim, ax.set_ylim, ax.set_zlim):
        set_lim(-half_extent, half_extent)
def resample_pcd(pcd, n=None):
    """Return ``pcd`` resampled along its first axis to exactly ``n`` rows.

    If ``pcd`` has at least ``n`` points, ``n`` distinct rows are drawn at
    random (no duplicates). If it has fewer, every row is kept once and the
    remainder is filled with randomly repeated rows.

    Args:
        pcd: numpy array whose first axis indexes points.
        n: target number of points, or ``None`` to return ``pcd`` unchanged.

    Returns:
        ``pcd`` itself when ``n`` is ``None``; otherwise a new array with
        ``n`` rows taken from ``pcd``.

    Raises:
        ValueError: if ``n`` is not positive, or ``pcd`` has no points.
    """
    if n is None:
        return pcd
    if n <= 0:
        raise ValueError('n must be a positive integer, got %r' % (n,))
    num_points = pcd.shape[0]
    if num_points == 0:
        raise ValueError('cannot resample an empty point cloud')
    # A random permutation gives sampling without replacement.
    idx = np.random.permutation(num_points)
    if num_points < n:
        # Too few points: keep each once, then pad with random repeats.
        extra = np.random.randint(num_points, size=n - num_points)
        idx = np.concatenate([idx, extra])
    return pcd[idx[:n]]


if __name__ == '__main__':
    import open3d

    parser = argparse.ArgumentParser(
        description='Complete a partial point cloud with a trained model '
                    'and plot the input next to the completion.')
    parser.add_argument('--input_path', default='demo_data/car.pcd',
                        help='point cloud file to complete (read with open3d)')
    parser.add_argument('--model_type', default='pcn_cd',
                        help='name of the module in the models package that defines Model')
    parser.add_argument('--checkpoint', default='data/trained_models/pcn_cd',
                        help='checkpoint path passed to tf.train.Saver.restore')
    parser.add_argument('--num_gt_points', type=int, default=16384,
                        help='number of points in the ground-truth placeholder')
    parser.add_argument('--num_samp', type=int,
                        help='resample the input to this many points '
                             '(default: use every point in the file)')
    args = parser.parse_args()
    if args.num_samp is not None and args.num_samp <= 0:
        parser.error('--num_samp must be a positive integer')

    # Load the input first so a bad path fails before the model is built and
    # the checkpoint restored. Newer open3d releases expose read_point_cloud
    # only under open3d.io; older ones expose it at the package top level.
    read_pcd = getattr(open3d, 'io', open3d).read_point_cloud
    partial = np.array(read_pcd(args.input_path).points)
    if len(partial) == 0:
        # Presumably open3d returns an empty cloud (rather than raising) for
        # unreadable files; fail with a clear message instead of a numpy error.
        parser.error('no points could be read from %s' % args.input_path)
    partial = resample_pcd(partial, args.num_samp)

    # Batch of one cloud with any number of 3-D points.
    inputs = tf.placeholder(tf.float32, (1, None, 3))
    # Required by the Model constructor but never fed below, so model.outputs
    # cannot depend on it (presumably it is only used by the training loss).
    gt = tf.placeholder(tf.float32, (1, args.num_gt_points, 3))
    model_module = importlib.import_module('.%s' % args.model_type, 'models')
    # NOTE(review): the meaning of the third Model argument is defined in
    # models/; TODO confirm.
    model = model_module.Model(inputs, gt, tf.constant(1.0))

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True  # grab GPU memory on demand, not all up front
    config.allow_soft_placement = True  # fall back to a supported device if needed
    saver = tf.train.Saver()
    # The context manager closes the session once inference is done.
    with tf.Session(config=config) as sess:
        saver.restore(sess, args.checkpoint)
        # [0] drops the batch dimension of size 1.
        complete = sess.run(model.outputs, feed_dict={inputs: [partial]})[0]

    fig = plt.figure(figsize=(8, 4))
    for position, cloud, title in ((121, partial, 'Input'), (122, complete, 'Output')):
        ax = fig.add_subplot(position, projection='3d')
        plot_pcd(ax, cloud)
        ax.set_title(title)
    plt.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0)
    plt.show()