Skip to content

Commit

Permalink
[Relay][Frontend][ONNX] Allow importing models with malformed Loop no…
Browse files Browse the repository at this point in the history
…des. (#8475)

* Snapshot

* Undo comments.

* Add testing for malformed loop nodes.

* Format oops.
  • Loading branch information
Josh Fromm authored Jul 15, 2021
1 parent bd88ee2 commit ce15ca6
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 15 deletions.
26 changes: 13 additions & 13 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -2742,24 +2742,24 @@ def get_var(name, val, scan=False):
loop_var_names = [v.name_hint for v in loop_vars]

num_scan_outputs = len(body.output) - (1 + num_deps)
# TODO (jwfromm) Test with strided slice once type unifier for this case is fixed.
if num_scan_outputs != 0 and "Slice" in [n.op_type for n in body.node]:
warnings.warn(
"""
Using scan outputs in a loop with strided slice
currently may cause errors during compilation.
"""
)

# Construct variables and intial empty tensors for any scan outputs.
# To do this, we'll figure out the output shapes of the body subgraph by importing
# it and doing type inference.
scan_output_vars = []
scan_output_init = []
if num_scan_outputs > 0:
with subgraph_scope:
loop_outputs = subgraph_scope.from_onnx(
body, graph_scope.opset, get_output_expr=True
)
loop_outputs = _expr.TupleWrapper(loop_outputs, len(body.output))

for i in range(num_scan_outputs):
name, shape, dtype, _ = get_info(body.output[i + 1 + num_deps])
if dtype is None:
dtype = infer_type(loop_deps[i]).checked_type.dtype
if dtype == "float":
dtype = "float32"
name, _, _, _ = get_info(body.output[i + 1 + num_deps])
output_node = infer_type(loop_outputs[i + 1 + num_deps])
shape = get_const_tuple(output_node.checked_type.shape)
dtype = output_node.checked_type.dtype
scan_output_vars.append(
_expr.var(name, shape=([_ty.Any()] * (len(shape) + 1)), dtype=dtype)
)
Expand Down
13 changes: 11 additions & 2 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -4043,7 +4043,7 @@ def verify_count_loop():
verify_with_ort_with_inputs(loop_model, input_vals, use_vm=True, freeze_params=True)


def verify_tensor_loop():
def verify_tensor_loop(shapeless_output=False):
y_in = helper.make_tensor_value_info("y_in", TensorProto.FLOAT, [3, 3, 3, 3])
y_out = helper.make_tensor_value_info("y_out", TensorProto.FLOAT, [3, 3, 3, 3])
scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, [3, 3, 3, 3])
Expand Down Expand Up @@ -4076,6 +4076,13 @@ def verify_tensor_loop():

trip_count = np.array(5).astype(np.int64)
cond = np.array(1).astype(bool)

# Allow testing of malformed nodes since pytorch likes to create these.
if shapeless_output:
scan_shape = None
else:
scan_shape = [5, 3, 3, 3, 3]

loop_graph = onnx.helper.make_graph(
[loop_node],
"loop_outer",
Expand All @@ -4086,7 +4093,7 @@ def verify_tensor_loop():
],
outputs=[
onnx.helper.make_tensor_value_info("res_y", onnx.TensorProto.FLOAT, [3, 3, 3, 3]),
onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, [5, 3, 3, 3, 3]),
onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, scan_shape),
],
)
loop_model = onnx.helper.make_model(loop_graph)
Expand All @@ -4106,6 +4113,8 @@ def test_loop():
verify_count_loop()
# Test a loop that uses an array output.
verify_tensor_loop()
# Test a loop that is malformed and has no output shape defined.
verify_tensor_loop(shapeless_output=True)


def verify_if(cond_array, num_outputs):
Expand Down

0 comments on commit ce15ca6

Please sign in to comment.