diff --git a/tf2onnx/onnx_opset/controlflow.py b/tf2onnx/onnx_opset/controlflow.py index b6dd5a14b..b244bd3f1 100644 --- a/tf2onnx/onnx_opset/controlflow.py +++ b/tf2onnx/onnx_opset/controlflow.py @@ -381,29 +381,32 @@ def version_7(cls, ctx, node, **kwargs): # may be removed from output_names below output_names = node.output.copy() - # Make maximum_iterations int64 and replace -1(tf) with maxsize(onnx). If the const node has no other + # Make maximum_iterations int64. If the const node has no other # consumers, modify it in place. Otherwise, make a new const node and leave the original unchanged. # if maximum_iterations is not const,should add an cast node(cast to int64) maximum_iterations_name = node.input[1] if node.inputs[1].is_const(): maximum_iterations = node.inputs[1].get_tensor_value() - if maximum_iterations == -1: - maximum_iterations = np.iinfo(np.int64).max - consumers = ctx.find_output_consumers(maximum_iterations_name) - external_consumers = [c for c in consumers if c != node and c.type != 'TensorListReserve'] - if len(external_consumers) == 0: - ctx.remove_node(node.inputs[1].name) + # maximum_iterations with -1(tf) means it doesn't set the maximum count. + # For onnx Loop op optional input `M`(int64), represents a maximum trip-count. Set empty string to skip. + if maximum_iterations != -1: + consumers = ctx.find_output_consumers(maximum_iterations_name) + external_consumers = [c for c in consumers if c != node and c.type != 'TensorListReserve'] + if len(external_consumers) == 0: + ctx.remove_node(node.inputs[1].name) + else: + maximum_iterations_name = utils.make_name(node.inputs[1].name) + ctx.make_const(maximum_iterations_name, np.array(maximum_iterations, dtype=np.int64)) + ctx.replace_input(node, node.input[1], maximum_iterations_name, 1) + maximum_iterations_m = maximum_iterations_name else: - maximum_iterations_name = utils.make_name(node.inputs[1].name) - ctx.make_const(maximum_iterations_name, np.array(maximum_iterations, dtype=np.int64)) - ctx.replace_input(node, node.input[1], maximum_iterations_name, 1) - maximum_iterations_int64 = maximum_iterations_name + maximum_iterations_m = "" else: cast_inputs = [maximum_iterations_name] attr = {"to": onnx_pb.TensorProto.INT64} cast_name = node.name + "_cast" cast_node = ctx.make_node("Cast", cast_inputs, attr, name=cast_name) - maximum_iterations_int64 = cast_node.output[0] + maximum_iterations_m = cast_node.output[0] cond_name = node.get_attr_str("cond") cond_graph = find_function(cond_name) @@ -427,7 +430,7 @@ def version_7(cls, ctx, node, **kwargs): cond_input_to_state_var[cond_graph.input_names[idx]] = maximum_iterations_name continue if idx < 2: - # skip [0,1] loop_counter, max_iterations + # skip [0,1] loop_counter, max_iterations continue n = node.inputs[idx] if n.type in ["TensorListReserve", "TensorListResize"]: @@ -511,7 +514,7 @@ def version_7(cls, ctx, node, **kwargs): output_names = output_names[2:] branches = {"body": body} - loop_node = ctx.make_node("Loop", [maximum_iterations_int64, cond_outputs[0]] + loop_vars, + loop_node = ctx.make_node("Loop", [maximum_iterations_m, cond_outputs[0]] + loop_vars, output_count=len(output_shapes), name=node.name + "_loop", shapes=output_shapes, dtypes=output_dtypes, skip_conversion=True, branches=branches)