Change Loop op maximum iterations input M to empty string (#1971)

* make Loop op maximum iterations input M an empty string to match the ONNX spec

Signed-off-by: Deyu Huang <[email protected]>
hwangdeyu authored Jun 17, 2022
1 parent 89c4c5c commit b027bb2
Showing 1 changed file with 17 additions and 14 deletions.
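For context, ONNX treats the Loop input `M` (the maximum trip count, int64) as optional, and an optional input is skipped by passing an empty string in the node's input list. The sketch below only illustrates that convention with `onnx.helper`; the body graph and the tensor names (`cond_init`, `state_init`, `max_trip_count`) are made up for illustration and are not part of this commit or of tf2onnx.

    from onnx import TensorProto, helper

    # Minimal Loop body (structure only): cond is passed through unchanged and
    # 1.0 is added to the single loop-carried state each iteration.
    body = helper.make_graph(
        nodes=[
            helper.make_node("Identity", ["cond_in"], ["cond_out"]),
            helper.make_node("Add", ["state_in", "one"], ["state_out"]),
        ],
        name="loop_body",
        inputs=[
            helper.make_tensor_value_info("iteration_num", TensorProto.INT64, []),
            helper.make_tensor_value_info("cond_in", TensorProto.BOOL, []),
            helper.make_tensor_value_info("state_in", TensorProto.FLOAT, []),
        ],
        outputs=[
            helper.make_tensor_value_info("cond_out", TensorProto.BOOL, []),
            helper.make_tensor_value_info("state_out", TensorProto.FLOAT, []),
        ],
        initializer=[helper.make_tensor("one", TensorProto.FLOAT, [], [1.0])],
    )

    # No trip-count limit: the optional first input M is skipped with "".
    loop_unbounded = helper.make_node(
        "Loop", inputs=["", "cond_init", "state_init"],
        outputs=["state_final"], body=body)

    # Bounded loop: a scalar int64 tensor named "max_trip_count" feeds M.
    loop_bounded = helper.make_node(
        "Loop", inputs=["max_trip_count", "cond_init", "state_init"],
        outputs=["state_final_bounded"], body=body)

With the first form the loop runs until the body's cond output becomes false; with the second it additionally stops after max_trip_count iterations.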
31 changes: 17 additions & 14 deletions tf2onnx/onnx_opset/controlflow.py
@@ -381,29 +381,32 @@ def version_7(cls, ctx, node, **kwargs):
         # may be removed from output_names below
         output_names = node.output.copy()

-        # Make maximum_iterations int64 and replace -1(tf) with maxsize(onnx). If the const node has no other
+        # Make maximum_iterations int64. If the const node has no other
         # consumers, modify it in place. Otherwise, make a new const node and leave the original unchanged.
         # if maximum_iterations is not const,should add an cast node(cast to int64)
         maximum_iterations_name = node.input[1]
         if node.inputs[1].is_const():
             maximum_iterations = node.inputs[1].get_tensor_value()
-            if maximum_iterations == -1:
-                maximum_iterations = np.iinfo(np.int64).max
-            consumers = ctx.find_output_consumers(maximum_iterations_name)
-            external_consumers = [c for c in consumers if c != node and c.type != 'TensorListReserve']
-            if len(external_consumers) == 0:
-                ctx.remove_node(node.inputs[1].name)
+            # maximum_iterations with -1(tf) means it doesn't set the maximum count.
+            # For onnx Loop op optional input `M`(int64), represents a maximum trip-count. Set empty string to skip.
+            if maximum_iterations != -1:
+                consumers = ctx.find_output_consumers(maximum_iterations_name)
+                external_consumers = [c for c in consumers if c != node and c.type != 'TensorListReserve']
+                if len(external_consumers) == 0:
+                    ctx.remove_node(node.inputs[1].name)
+                else:
+                    maximum_iterations_name = utils.make_name(node.inputs[1].name)
+                ctx.make_const(maximum_iterations_name, np.array(maximum_iterations, dtype=np.int64))
+                ctx.replace_input(node, node.input[1], maximum_iterations_name, 1)
+                maximum_iterations_m = maximum_iterations_name
             else:
-                maximum_iterations_name = utils.make_name(node.inputs[1].name)
-            ctx.make_const(maximum_iterations_name, np.array(maximum_iterations, dtype=np.int64))
-            ctx.replace_input(node, node.input[1], maximum_iterations_name, 1)
-            maximum_iterations_int64 = maximum_iterations_name
+                maximum_iterations_m = ""
         else:
             cast_inputs = [maximum_iterations_name]
             attr = {"to": onnx_pb.TensorProto.INT64}
             cast_name = node.name + "_cast"
             cast_node = ctx.make_node("Cast", cast_inputs, attr, name=cast_name)
-            maximum_iterations_int64 = cast_node.output[0]
+            maximum_iterations_m = cast_node.output[0]

         cond_name = node.get_attr_str("cond")
         cond_graph = find_function(cond_name)
@@ -427,7 +430,7 @@ def version_7(cls, ctx, node, **kwargs):
                 cond_input_to_state_var[cond_graph.input_names[idx]] = maximum_iterations_name
                 continue
             if idx < 2:
-               # skip [0,1] loop_counter, max_iterations
+                # skip [0,1] loop_counter, max_iterations
                 continue
             n = node.inputs[idx]
             if n.type in ["TensorListReserve", "TensorListResize"]:
@@ -511,7 +514,7 @@ def version_7(cls, ctx, node, **kwargs):
         output_names = output_names[2:]

         branches = {"body": body}
-        loop_node = ctx.make_node("Loop", [maximum_iterations_int64, cond_outputs[0]] + loop_vars,
+        loop_node = ctx.make_node("Loop", [maximum_iterations_m, cond_outputs[0]] + loop_vars,
                                   output_count=len(output_shapes), name=node.name + "_loop",
                                   shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True,
                                   branches=branches)
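In short, the conversion of TensorFlow's maximum_iterations now behaves like the sketch below (a hypothetical standalone illustration, not the tf2onnx API; loop_m_input and fake_make_const are invented names): -1, TensorFlow's "no limit" marker, no longer becomes np.iinfo(np.int64).max but an empty input name, so the ONNX Loop carries no `M` at all, while any other constant value is still wired in as an int64 scalar.

    import numpy as np

    def loop_m_input(max_iterations, make_const):
        # Hypothetical sketch of the rule this commit implements (not the tf2onnx API).
        # TF encodes "no limit" as -1; instead of substituting np.iinfo(np.int64).max,
        # the converter now leaves the ONNX Loop optional input M empty.
        if max_iterations == -1:
            return ""  # empty input name: M is omitted, the loop is bounded by cond only
        # otherwise wire in an int64 scalar const that holds the trip-count bound
        return make_const(np.array(max_iterations, dtype=np.int64))

    # Toy const factory standing in for graph const creation: returns a tensor name.
    consts = {}
    def fake_make_const(value):
        name = "max_trip_count_%d" % len(consts)
        consts[name] = value
        return name

    print(loop_m_input(-1, fake_make_const))  # -> ""  (M skipped)
    print(loop_m_input(16, fake_make_const))  # -> "max_trip_count_0"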
