diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 59b9185106f3..9392ee6fcaf3 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -2742,24 +2742,24 @@ def get_var(name, val, scan=False): loop_var_names = [v.name_hint for v in loop_vars] num_scan_outputs = len(body.output) - (1 + num_deps) - # TODO (jwfromm) Test with strided slice once type unifier for this case is fixed. - if num_scan_outputs != 0 and "Slice" in [n.op_type for n in body.node]: - warnings.warn( - """ - Using scan outputs in a loop with strided slice - currently may cause errors during compilation. - """ - ) # Construct variables and intial empty tensors for any scan outputs. + # To do this, we'll figure out the output shapes of the body subgraph by importing + # it and doing type inference. scan_output_vars = [] scan_output_init = [] + if num_scan_outputs > 0: + with subgraph_scope: + loop_outputs = subgraph_scope.from_onnx( + body, graph_scope.opset, get_output_expr=True + ) + loop_outputs = _expr.TupleWrapper(loop_outputs, len(body.output)) + for i in range(num_scan_outputs): - name, shape, dtype, _ = get_info(body.output[i + 1 + num_deps]) - if dtype is None: - dtype = infer_type(loop_deps[i]).checked_type.dtype - if dtype == "float": - dtype = "float32" + name, _, _, _ = get_info(body.output[i + 1 + num_deps]) + output_node = infer_type(loop_outputs[i + 1 + num_deps]) + shape = get_const_tuple(output_node.checked_type.shape) + dtype = output_node.checked_type.dtype scan_output_vars.append( _expr.var(name, shape=([_ty.Any()] * (len(shape) + 1)), dtype=dtype) ) diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index f7bae5da79e1..bf7899551df9 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -4043,7 +4043,7 @@ def verify_count_loop(): verify_with_ort_with_inputs(loop_model, input_vals, use_vm=True, freeze_params=True) -def verify_tensor_loop(): +def verify_tensor_loop(shapeless_output=False): y_in = helper.make_tensor_value_info("y_in", TensorProto.FLOAT, [3, 3, 3, 3]) y_out = helper.make_tensor_value_info("y_out", TensorProto.FLOAT, [3, 3, 3, 3]) scan_out = helper.make_tensor_value_info("scan_out", TensorProto.FLOAT, [3, 3, 3, 3]) @@ -4076,6 +4076,13 @@ def verify_tensor_loop(): trip_count = np.array(5).astype(np.int64) cond = np.array(1).astype(bool) + + # Allow testing of malformed nodes since pytorch likes to create these. + if shapeless_output: + scan_shape = None + else: + scan_shape = [5, 3, 3, 3, 3] + loop_graph = onnx.helper.make_graph( [loop_node], "loop_outer", @@ -4086,7 +4093,7 @@ def verify_tensor_loop(): ], outputs=[ onnx.helper.make_tensor_value_info("res_y", onnx.TensorProto.FLOAT, [3, 3, 3, 3]), - onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, [5, 3, 3, 3, 3]), + onnx.helper.make_tensor_value_info("res_scan", onnx.TensorProto.FLOAT, scan_shape), ], ) loop_model = onnx.helper.make_model(loop_graph) @@ -4106,6 +4113,8 @@ def test_loop(): verify_count_loop() # Test a loop that uses an array output. verify_tensor_loop() + # Test a loop that is malformed and has no output shape defined. + verify_tensor_loop(shapeless_output=True) def verify_if(cond_array, num_outputs):