From 1ff3c6012a9f531ac4b6b4a5a665a1f9fbd68c86 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Mon, 17 Aug 2020 15:22:51 -0700 Subject: [PATCH] [ONNX] Update slice to infer attributes when not graph inputs (#6276) * Update ONNX Slice converter to infer slice attributes when necessary. * Linting --- python/tvm/relay/frontend/onnx.py | 25 ++++--- tests/python/frontend/onnx/test_forward.py | 84 ++++++++++++++++------ 2 files changed, 77 insertions(+), 32 deletions(-) diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index f54a145882a9..bc44431df3eb 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1050,21 +1050,24 @@ def _impl_v1(cls, inputs, attr, params): @classmethod def _impl_v10(cls, inputs, attr, params): - starts = params[get_name(inputs[1])].asnumpy() - ends = params[get_name(inputs[2])].asnumpy() - - # Update the starts and ends according to axes if required. + attrs = {'starts' : inputs[1], 'ends' : inputs[2]} if len(inputs) >= 4: - axes = params[get_name(inputs[3])].asnumpy() + attrs['axes'] = inputs[3] + attrs = {k : (v, get_name(v)) for (k, v) in attrs.items()} + attrs = {k : params[v[1]].asnumpy() if v[1] in params else + infer_value_simulated(v[0], params).asnumpy() + for (k, v) in attrs.items()} - if max(axes + 1) != len(axes): + # Update the starts and ends according to axes if required. + if 'axes' in attrs: + if max(attrs['axes'] + 1) != len(attrs['axes']): new_starts, new_ends, _ = cls._common( - starts, ends, axes) - starts = new_starts - ends = new_ends + attrs['starts'], attrs['ends'], attrs['axes']) + attrs['starts'] = new_starts + attrs['ends'] = new_ends return _op.strided_slice(inputs[0], - begin=_expr.const(starts, dtype="int64"), - end=_expr.const(ends, dtype="int64")) + begin=_expr.const(attrs['starts'], dtype="int64"), + end=_expr.const(attrs['ends'], dtype="int64")) class Gather(OnnxOpConverter): diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py index c376c9aa78ea..c09580e57301 100644 --- a/tests/python/frontend/onnx/test_forward.py +++ b/tests/python/frontend/onnx/test_forward.py @@ -465,14 +465,10 @@ def _test_slice_iteration_v1(indata, outdata, starts, ends, axes=None): tvm.testing.assert_allclose(outdata, tvm_out) - -def _test_slice_iteration_v10(indata, outdata, starts, ends, axes=None): - if isinstance(starts, int): - starts = (starts, ) - if isinstance(ends, int): - ends = (ends, ) - if isinstance(axes, int): - axes = (axes, ) +def _test_slice_iteration_v10(indata, outdata, **attrs): + starts = attrs['starts'] + ends = attrs['ends'] + axes = None if 'axes' not in attrs else attrs['axes'] starts = np.asarray(starts) ends = np.asarray(ends) inputs = [ @@ -488,21 +484,59 @@ def _test_slice_iteration_v10(indata, outdata, starts, ends, axes=None): starts), helper.make_tensor("ends", TensorProto.INT64, list(ends.shape), ends) ] + nodes = [] + + if 'add_noop_to_input_attrs' in attrs: + def add_noop_to_input_attr(attr_name, attr): + output_name = attr_name+"_output" + + ref_shape = list(np.array(attr).shape) + ref_shape.insert(0, 1) + ref_shape = tuple(ref_shape) + ref_array = np.array(ref_shape) + ref_node = onnx.helper.make_node('Constant', + inputs=[], + outputs=['ref_in_'+attr_name], + value=onnx.helper.make_tensor(name='const_tensor__1_'+attr_name, + data_type=onnx.TensorProto.INT64, + dims=ref_array.shape, + vals=ref_array.flatten().astype(int))) + in_shape = np.array(attr).shape + in_array = np.array(in_shape) + ref_node2 = onnx.helper.make_node('Constant', + inputs=[], + outputs=['input_shape_'+attr_name], + value=onnx.helper.make_tensor(name='const_tensor__2_'+attr_name, + data_type=onnx.TensorProto.INT64, + dims=in_array.shape, + vals=in_array.flatten().astype(int))) + + reshape1_node = helper.make_node("Reshape", [attr_name, "ref_in_"+attr_name], ["reshape_"+attr_name]) + reshape2_node = helper.make_node("Reshape", ["reshape_"+attr_name, "input_shape_"+attr_name], [output_name]) + return [ref_node, ref_node2, reshape1_node, reshape2_node] + + slice_inputs = [] + for attr_name in ["starts", "ends", "axes"]: + if attr_name == "axes" and not axes: + continue + if "add_noop_to_input_attrs" in attrs and attr_name in attrs["add_noop_to_input_attrs"]: + nodes.extend(add_noop_to_input_attr(attr_name, attrs[attr_name])) + slice_inputs.append(attr_name + "_output") + else: + slice_inputs.append(attr_name) if axes: axes = np.asarray(axes) - y = helper.make_node("Slice", ["data", "starts", "ends", "axes"], - ["out"]) inputs.append( helper.make_tensor_value_info("axes", TensorProto.INT32, list(axes.shape))) initializer.append( helper.make_tensor("axes", TensorProto.INT32, list(axes.shape), axes)) - else: - y = helper.make_node("Slice", ["data", "starts", "ends"], ["out"]) + y = helper.make_node("Slice", ["data", *slice_inputs], ["out"]) - graph = helper.make_graph([y], + nodes.append(y) + graph = helper.make_graph(nodes, 'slice_test', inputs=inputs, outputs=[ @@ -527,15 +561,23 @@ def _test_slice_iteration_v10(indata, outdata, starts, ends, axes=None): def test_slice(): x = np.random.randn(20, 10, 5).astype(np.float32) - _test_slice_iteration_v1(x, x[0:3, 0:10], (0, 0), (3, 10), (0, 1)) - _test_slice_iteration_v1(x, x[:, :, 3:4], (0, 0, 3), (20, 10, 4)) - _test_slice_iteration_v1(x, x[:, 1:1000], (1), (1000), (1)) - _test_slice_iteration_v1(x, x[:, 0:-1], (0), (-1), (1)) - _test_slice_iteration_v10(x, x[0:3, 0:10], (0, 0), (3, 10), (0, 1)) - _test_slice_iteration_v10(x, x[:, :, 3:4], (0, 0, 3), (20, 10, 4)) - _test_slice_iteration_v10(x, x[:, 1:1000], (1), (1000), (1)) + _test_slice_iteration_v1(x, x[0:3, 0:10], starts=(0, 0), ends=(3, 10), axes=(0, 1)) + _test_slice_iteration_v1(x, x[:, :, 3:4], starts=(0, 0, 3), ends=(20, 10, 4)) + _test_slice_iteration_v1(x, x[:, 1:1000], starts=(1,), ends=(1000,), axes=(1,)) + _test_slice_iteration_v1(x, x[:, 0:-1], starts=(0,), ends=(-1,), axes=(1,)) + _test_slice_iteration_v10(x, x[0:3, 0:10], starts=(0, 0), ends=(3, 10), axes=(0, 1)) + _test_slice_iteration_v10(x, x[:, :, 3:4], starts=(0, 0, 3), ends=(20, 10, 4)) + _test_slice_iteration_v10(x, x[:, 1:1000], starts=(1,), ends=(1000,), axes=(1,)) + _test_slice_iteration_v10(x, x[:, 0:-1], starts=(0,), ends=(-1,), axes=(1,)) + _test_slice_iteration_v10(x, x[0:3, 0:10], starts=(0, 0), ends=(3, 10), axes=(0, 1), add_noop_to_input_attrs=["starts"]) + _test_slice_iteration_v10(x, x[:, :, 3:4], starts=(0, 0, 3), ends=(20, 10, 4), add_noop_to_input_attrs=["ends"]) + _test_slice_iteration_v10(x, x[:, 1:1000], starts=(1,), ends=(1000,), axes=(1,), add_noop_to_input_attrs=["axes"]) + _test_slice_iteration_v10(x, x[:, 0:-1], starts=(0,), ends=(-1,), axes=(1,), add_noop_to_input_attrs=["starts", "ends"]) + _test_slice_iteration_v10(x, x[0:3, 0:10], starts=(0, 0), ends=(3, 10), axes=(0, 1), add_noop_to_input_attrs=["ends", "axes"]) + _test_slice_iteration_v10(x, x[:, :, 3:4], starts=(0, 0, 3), ends=(20, 10, 4), add_noop_to_input_attrs=["starts", "axes"]) + _test_slice_iteration_v10(x, x[:, 1:1000], starts=(1,), ends=(1000,), axes=(1,), add_noop_to_input_attrs=["starts", "ends", "axes"]) x = np.random.randn(1, 1, 1, 128).astype(np.float32) - _test_slice_iteration_v10(x, x, (0, 0), (9223372036854775807, 9223372036854775807), (0, 3)) + _test_slice_iteration_v10(x, x, starts=(0, 0), ends=(9223372036854775807, 9223372036854775807), axes=(0, 3)) def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs):