Skip to content

Commit

Permalink
[ONNX] Update slice to infer attributes when not graph inputs (apache…
Browse files Browse the repository at this point in the history
…#6276)

* Update ONNX Slice converter to infer slice attributes when necessary.

* Linting
  • Loading branch information
csullivan authored and Trevor Morris committed Aug 26, 2020
1 parent 3fb8a70 commit 60e761c
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 32 deletions.
25 changes: 14 additions & 11 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1050,21 +1050,24 @@ def _impl_v1(cls, inputs, attr, params):

@classmethod
def _impl_v10(cls, inputs, attr, params):
starts = params[get_name(inputs[1])].asnumpy()
ends = params[get_name(inputs[2])].asnumpy()

# Update the starts and ends according to axes if required.
attrs = {'starts' : inputs[1], 'ends' : inputs[2]}
if len(inputs) >= 4:
axes = params[get_name(inputs[3])].asnumpy()
attrs['axes'] = inputs[3]
attrs = {k : (v, get_name(v)) for (k, v) in attrs.items()}
attrs = {k : params[v[1]].asnumpy() if v[1] in params else
infer_value_simulated(v[0], params).asnumpy()
for (k, v) in attrs.items()}

if max(axes + 1) != len(axes):
# Update the starts and ends according to axes if required.
if 'axes' in attrs:
if max(attrs['axes'] + 1) != len(attrs['axes']):
new_starts, new_ends, _ = cls._common(
starts, ends, axes)
starts = new_starts
ends = new_ends
attrs['starts'], attrs['ends'], attrs['axes'])
attrs['starts'] = new_starts
attrs['ends'] = new_ends
return _op.strided_slice(inputs[0],
begin=_expr.const(starts, dtype="int64"),
end=_expr.const(ends, dtype="int64"))
begin=_expr.const(attrs['starts'], dtype="int64"),
end=_expr.const(attrs['ends'], dtype="int64"))


class Gather(OnnxOpConverter):
Expand Down
84 changes: 63 additions & 21 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -465,14 +465,10 @@ def _test_slice_iteration_v1(indata, outdata, starts, ends, axes=None):

tvm.testing.assert_allclose(outdata, tvm_out)


def _test_slice_iteration_v10(indata, outdata, starts, ends, axes=None):
if isinstance(starts, int):
starts = (starts, )
if isinstance(ends, int):
ends = (ends, )
if isinstance(axes, int):
axes = (axes, )
def _test_slice_iteration_v10(indata, outdata, **attrs):
starts = attrs['starts']
ends = attrs['ends']
axes = None if 'axes' not in attrs else attrs['axes']
starts = np.asarray(starts)
ends = np.asarray(ends)
inputs = [
Expand All @@ -488,21 +484,59 @@ def _test_slice_iteration_v10(indata, outdata, starts, ends, axes=None):
starts),
helper.make_tensor("ends", TensorProto.INT64, list(ends.shape), ends)
]
nodes = []

if 'add_noop_to_input_attrs' in attrs:
def add_noop_to_input_attr(attr_name, attr):
output_name = attr_name+"_output"

ref_shape = list(np.array(attr).shape)
ref_shape.insert(0, 1)
ref_shape = tuple(ref_shape)
ref_array = np.array(ref_shape)
ref_node = onnx.helper.make_node('Constant',
inputs=[],
outputs=['ref_in_'+attr_name],
value=onnx.helper.make_tensor(name='const_tensor__1_'+attr_name,
data_type=onnx.TensorProto.INT64,
dims=ref_array.shape,
vals=ref_array.flatten().astype(int)))
in_shape = np.array(attr).shape
in_array = np.array(in_shape)
ref_node2 = onnx.helper.make_node('Constant',
inputs=[],
outputs=['input_shape_'+attr_name],
value=onnx.helper.make_tensor(name='const_tensor__2_'+attr_name,
data_type=onnx.TensorProto.INT64,
dims=in_array.shape,
vals=in_array.flatten().astype(int)))

reshape1_node = helper.make_node("Reshape", [attr_name, "ref_in_"+attr_name], ["reshape_"+attr_name])
reshape2_node = helper.make_node("Reshape", ["reshape_"+attr_name, "input_shape_"+attr_name], [output_name])
return [ref_node, ref_node2, reshape1_node, reshape2_node]

slice_inputs = []
for attr_name in ["starts", "ends", "axes"]:
if attr_name == "axes" and not axes:
continue
if "add_noop_to_input_attrs" in attrs and attr_name in attrs["add_noop_to_input_attrs"]:
nodes.extend(add_noop_to_input_attr(attr_name, attrs[attr_name]))
slice_inputs.append(attr_name + "_output")
else:
slice_inputs.append(attr_name)

if axes:
axes = np.asarray(axes)
y = helper.make_node("Slice", ["data", "starts", "ends", "axes"],
["out"])
inputs.append(
helper.make_tensor_value_info("axes", TensorProto.INT32,
list(axes.shape)))
initializer.append(
helper.make_tensor("axes", TensorProto.INT32, list(axes.shape),
axes))
else:
y = helper.make_node("Slice", ["data", "starts", "ends"], ["out"])
y = helper.make_node("Slice", ["data", *slice_inputs], ["out"])

graph = helper.make_graph([y],
nodes.append(y)
graph = helper.make_graph(nodes,
'slice_test',
inputs=inputs,
outputs=[
Expand All @@ -527,15 +561,23 @@ def _test_slice_iteration_v10(indata, outdata, starts, ends, axes=None):

def test_slice():
x = np.random.randn(20, 10, 5).astype(np.float32)
_test_slice_iteration_v1(x, x[0:3, 0:10], (0, 0), (3, 10), (0, 1))
_test_slice_iteration_v1(x, x[:, :, 3:4], (0, 0, 3), (20, 10, 4))
_test_slice_iteration_v1(x, x[:, 1:1000], (1), (1000), (1))
_test_slice_iteration_v1(x, x[:, 0:-1], (0), (-1), (1))
_test_slice_iteration_v10(x, x[0:3, 0:10], (0, 0), (3, 10), (0, 1))
_test_slice_iteration_v10(x, x[:, :, 3:4], (0, 0, 3), (20, 10, 4))
_test_slice_iteration_v10(x, x[:, 1:1000], (1), (1000), (1))
_test_slice_iteration_v1(x, x[0:3, 0:10], starts=(0, 0), ends=(3, 10), axes=(0, 1))
_test_slice_iteration_v1(x, x[:, :, 3:4], starts=(0, 0, 3), ends=(20, 10, 4))
_test_slice_iteration_v1(x, x[:, 1:1000], starts=(1,), ends=(1000,), axes=(1,))
_test_slice_iteration_v1(x, x[:, 0:-1], starts=(0,), ends=(-1,), axes=(1,))
_test_slice_iteration_v10(x, x[0:3, 0:10], starts=(0, 0), ends=(3, 10), axes=(0, 1))
_test_slice_iteration_v10(x, x[:, :, 3:4], starts=(0, 0, 3), ends=(20, 10, 4))
_test_slice_iteration_v10(x, x[:, 1:1000], starts=(1,), ends=(1000,), axes=(1,))
_test_slice_iteration_v10(x, x[:, 0:-1], starts=(0,), ends=(-1,), axes=(1,))
_test_slice_iteration_v10(x, x[0:3, 0:10], starts=(0, 0), ends=(3, 10), axes=(0, 1), add_noop_to_input_attrs=["starts"])
_test_slice_iteration_v10(x, x[:, :, 3:4], starts=(0, 0, 3), ends=(20, 10, 4), add_noop_to_input_attrs=["ends"])
_test_slice_iteration_v10(x, x[:, 1:1000], starts=(1,), ends=(1000,), axes=(1,), add_noop_to_input_attrs=["axes"])
_test_slice_iteration_v10(x, x[:, 0:-1], starts=(0,), ends=(-1,), axes=(1,), add_noop_to_input_attrs=["starts", "ends"])
_test_slice_iteration_v10(x, x[0:3, 0:10], starts=(0, 0), ends=(3, 10), axes=(0, 1), add_noop_to_input_attrs=["ends", "axes"])
_test_slice_iteration_v10(x, x[:, :, 3:4], starts=(0, 0, 3), ends=(20, 10, 4), add_noop_to_input_attrs=["starts", "axes"])
_test_slice_iteration_v10(x, x[:, 1:1000], starts=(1,), ends=(1000,), axes=(1,), add_noop_to_input_attrs=["starts", "ends", "axes"])
x = np.random.randn(1, 1, 1, 128).astype(np.float32)
_test_slice_iteration_v10(x, x, (0, 0), (9223372036854775807, 9223372036854775807), (0, 3))
_test_slice_iteration_v10(x, x, starts=(0, 0), ends=(9223372036854775807, 9223372036854775807), axes=(0, 3))


def _test_onnx_op_elementwise(inshape, outfunc, npargs, dtype, opname, kwargs):
Expand Down

0 comments on commit 60e761c

Please sign in to comment.