Skip to content

Commit

Permalink
[FRONTEND][TENSORFLOW] support multiply outputs (apache#4980)
Browse files Browse the repository at this point in the history
* [FRONTEND][TENSORFLOW] support multiply outputs

* [TENSORFLOW][TEST] add tf_testing.AddShapesToGraphDef test

* update frontend test

* retrigger CI
  • Loading branch information
vv1133 authored and zhiics committed Apr 17, 2020
1 parent a8287db commit 48e262d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
9 changes: 7 additions & 2 deletions python/tvm/relay/testing/tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ def ProcessGraphDefParam(graph_def):
return graph_def


def convert_to_list(x):
if not isinstance(x, list):
x = [x]
return x

def AddShapesToGraphDef(session, out_node):
""" Add shapes attribute to nodes of the graph.
Input graph here is the default graph in context.
Expand All @@ -74,7 +79,7 @@ def AddShapesToGraphDef(session, out_node):
----------
session : tf.Session
Tensorflow session
out_node : String
out_node : String or List
Final output node of the graph.
Returns
Expand All @@ -87,7 +92,7 @@ def AddShapesToGraphDef(session, out_node):
graph_def = tf_compat_v1.graph_util.convert_variables_to_constants(
session,
session.graph.as_graph_def(add_shapes=True),
[out_node],
convert_to_list(out_node),
)
return graph_def

Expand Down
6 changes: 1 addition & 5 deletions tests/python/frontend/tensorflow/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,11 +171,7 @@ def name_without_num(name):
with tf.Session() as sess:
if init_global_variables:
sess.run(variables.global_variables_initializer())
final_graph_def = tf.graph_util.convert_variables_to_constants(
sess,
sess.graph.as_graph_def(add_shapes=True),
out_node,
)
final_graph_def = tf_testing.AddShapesToGraphDef(sess, out_node)
tf_output = run_tf_graph(sess, in_data, in_name, out_name)

for device in ["llvm", "cuda"]:
Expand Down

0 comments on commit 48e262d

Please sign in to comment.