From 277dc469da786860f0a3b171ac660d604d3acc19 Mon Sep 17 00:00:00 2001 From: zhengdi Date: Tue, 3 Mar 2020 18:39:42 +0800 Subject: [PATCH 1/2] [FRONTEND][TENSORFLOW] support multiply outputs --- python/tvm/relay/testing/tf.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/tvm/relay/testing/tf.py b/python/tvm/relay/testing/tf.py index 1dbbf14b41f3..dfe5801b39cf 100644 --- a/python/tvm/relay/testing/tf.py +++ b/python/tvm/relay/testing/tf.py @@ -66,6 +66,11 @@ def ProcessGraphDefParam(graph_def): return graph_def +def convert_to_list(x): + if not isinstance(x, list): + x = [x] + return x + def AddShapesToGraphDef(session, out_node): """ Add shapes attribute to nodes of the graph. Input graph here is the default graph in context. @@ -74,7 +79,7 @@ def AddShapesToGraphDef(session, out_node): ---------- session : tf.Session Tensorflow session - out_node : String + out_node : String or List Final output node of the graph. Returns @@ -87,7 +92,7 @@ def AddShapesToGraphDef(session, out_node): graph_def = tf_compat_v1.graph_util.convert_variables_to_constants( session, session.graph.as_graph_def(add_shapes=True), - [out_node], + convert_to_list(out_node), ) return graph_def From 24e2b52790a77e71b380f4ac498af263504f7a6d Mon Sep 17 00:00:00 2001 From: zhengdi Date: Thu, 5 Mar 2020 09:32:25 +0800 Subject: [PATCH 2/2] [TENSORFLOW][TEST] add tf_testing.AddShapesToGraphDef test * update frontend test * retrigger CI --- tests/python/frontend/tensorflow/test_forward.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tests/python/frontend/tensorflow/test_forward.py b/tests/python/frontend/tensorflow/test_forward.py index 42408b706111..fae39e6b0183 100644 --- a/tests/python/frontend/tensorflow/test_forward.py +++ b/tests/python/frontend/tensorflow/test_forward.py @@ -171,11 +171,7 @@ def name_without_num(name): with tf.Session() as sess: if init_global_variables: sess.run(variables.global_variables_initializer()) - final_graph_def = tf.graph_util.convert_variables_to_constants( - sess, - sess.graph.as_graph_def(add_shapes=True), - out_node, - ) + final_graph_def = tf_testing.AddShapesToGraphDef(sess, out_node) tf_output = run_tf_graph(sess, in_data, in_name, out_name) for device in ["llvm", "cuda"]: