From b6c38208f8ff907253331ddaa0d39552e3f419e2 Mon Sep 17 00:00:00 2001 From: zhengdi Date: Sun, 8 Mar 2020 14:27:56 +0800 Subject: [PATCH] [FRONTEND][TENSORFLOW] support multiply outputs (#4980) * [FRONTEND][TENSORFLOW] support multiply outputs * [TENSORFLOW][TEST] add tf_testing.AddShapesToGraphDef test * update frontend test * retrigger CI --- python/tvm/relay/testing/tf.py | 9 +++++++-- tests/python/frontend/tensorflow/test_forward.py | 6 +----- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/python/tvm/relay/testing/tf.py b/python/tvm/relay/testing/tf.py index 1dbbf14b41f3..dfe5801b39cf 100644 --- a/python/tvm/relay/testing/tf.py +++ b/python/tvm/relay/testing/tf.py @@ -66,6 +66,11 @@ def ProcessGraphDefParam(graph_def): return graph_def +def convert_to_list(x): + if not isinstance(x, list): + x = [x] + return x + def AddShapesToGraphDef(session, out_node): """ Add shapes attribute to nodes of the graph. Input graph here is the default graph in context. @@ -74,7 +79,7 @@ def AddShapesToGraphDef(session, out_node): ---------- session : tf.Session Tensorflow session - out_node : String + out_node : String or List Final output node of the graph. Returns @@ -87,7 +92,7 @@ def AddShapesToGraphDef(session, out_node): graph_def = tf_compat_v1.graph_util.convert_variables_to_constants( session, session.graph.as_graph_def(add_shapes=True), - [out_node], + convert_to_list(out_node), ) return graph_def diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 31c548044846..3d033d368c19 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -171,11 +171,7 @@ def name_without_num(name): with tf.Session() as sess: if init_global_variables: sess.run(variables.global_variables_initializer()) - final_graph_def = tf.graph_util.convert_variables_to_constants( - sess, - sess.graph.as_graph_def(add_shapes=True), - out_node, - ) + final_graph_def = tf_testing.AddShapesToGraphDef(sess, out_node) tf_output = run_tf_graph(sess, in_data, in_name, out_name) for device in ["llvm", "cuda"]: