Skip to content

Commit

Permalink
strided slice with axes support
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 27, 2021
1 parent 2cde3dc commit 763ac37
Show file tree
Hide file tree
Showing 6 changed files with 431 additions and 48 deletions.
2 changes: 2 additions & 0 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,7 @@ struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
Optional<Array<Integer>> end;
Optional<Array<Integer>> strides;
std::string slice_mode;
Optional<Array<Integer>> axes;

TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") {
TVM_ATTR_FIELD(begin).describe("Indices for begin of slice, begin index is also inclusive");
Expand All @@ -317,6 +318,7 @@ struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
"size - The input strides will be ignored, input end in this mode indicates the size"
"of a slice starting at the location specified by begin. If end[i] is -1,"
"all remaining elements in that dimension are included in the slice");
TVM_ATTR_FIELD(axes).describe("TODO");
}
};

Expand Down
18 changes: 18 additions & 0 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1343,6 +1343,24 @@ def _impl_v10(cls, inputs, attr, params):

data_rank = len(infer_shape(inputs[0]))

def has_static_axes():
return (isinstance(axes, _expr.Constant) and
isinstance(starts, _expr.Constant) and
isinstance(ends, _expr.Constant) and
(steps is None or isinstance(steps, _expr.Constant)))

# Update the starts and ends according to axes if required.
if axes is not None and has_static_axes():
axes_np = axes.data.asnumpy().astype("int64")
begin_np = starts.data.asnumpy().astype("int64")
end_np = ends.data.asnumpy().astype("int64")
if steps is None:
strides_np = np.ones_like(begin_np).astype("int64")
else:
strides_np = steps.data.asnumpy().astype("int64")

return _op.strided_slice(inputs[0], list(begin_np), list(end_np), list(strides_np), axes=list(axes_np))

# Update the starts and ends according to axes if required.
if axes is not None:
data_shape = shape_of(inputs[0], dtype=infer_type(ends).checked_type.dtype)
Expand Down
59 changes: 56 additions & 3 deletions python/tvm/relay/op/_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,16 +244,69 @@ def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice
return out


@script
def _strided_slice_shape_func_with_axes(data_shape, begin, end, strides, slice_mode, axes):
ndim = data_shape.shape[0]
out = output_tensor((ndim,), "int64")
for i in const_range(ndim):
out[i] = data_shape[i]

for i in const_range(len(axes)):
axis = int64(axes[i])
cbegin = int64(0)
cend = int64(data_shape[axis])
cstride = int64(1)
if len(strides) > i:
cstride = int64(strides[i])
if len(begin) > i:
cbegin = int64(begin[i])
if cbegin < 0:
cbegin += int64(data_shape[axis])
if len(end) <= i:
cend = int64(data_shape[axis])
elif slice_mode != 0:
cstride = int64(1)
if end[i] < 0:
cend = int64(data_shape[axis])
else:
cend = cbegin + int64(end[i])
else:
if end[i] > data_shape[i]:
cend = int64(data_shape[axis])
elif end[i] < -data_shape[i]:
cend = int64(-1)
else:
cend = int64(end[i])
if cend < 0:
cend += int64(data_shape[axis])
assert cstride != 0, "Strides can't be zero."
if cstride < 0:
slice_range = cbegin - cend
step = -cstride
else:
slice_range = cend - cbegin
step = cstride

out[axis] = int64(ceil_div(slice_range, step))
return out


@_reg.register_shape_func("strided_slice", False)
def strided_slice_shape_func(attrs, inputs, _):
"""
Shape func for strided_slice
"""
slice_mode = convert(0 if attrs.slice_mode == "end" else 1)
if len(attrs.axes) == 0:
return [
_strided_slice_shape_func_input_shape(
inputs[0], attrs.begin, attrs.end, attrs.strides, slice_mode
)
]
return [
_strided_slice_shape_func_input_shape(
inputs[0], attrs.begin, attrs.end, attrs.strides, slice_mode
)
_strided_slice_shape_func_with_axes(
inputs[0], attrs.begin, attrs.end, attrs.strides, slice_mode, attrs.axes
)
]


Expand Down
7 changes: 5 additions & 2 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,7 +867,7 @@ def split(data, indices_or_sections, axis=0):
return TupleWrapper(_make.split(data, indices_or_sections, axis), ret_size)


def strided_slice(data, begin, end, strides=None, slice_mode="end"):
def strided_slice(data, begin, end, strides=None, slice_mode="end", axes=None):
"""Strided slice of an array.
Parameters
Expand All @@ -892,6 +892,9 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"):
the size of a slice starting at the location specified by begin. If end[i]
is -1, all remaining elements in that dimension are included in the slice.
axes : List[int]
TODO
Returns
-------
ret : relay.Expr
Expand All @@ -917,7 +920,7 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"):
begin = _make.where(begin < cast_like(const(0), begin), begin + ishape_slice, begin)
begin = _make.where(begin >= ishape_slice, ishape_slice, begin)
return _dyn_make.strided_slice(data, begin, end, strides, slice_mode)
return _make.strided_slice(data, begin, end, strides, slice_mode)
return _make.strided_slice(data, begin, end, strides, slice_mode, axes)


def strided_set(data, v, begin, end, strides=None):
Expand Down
2 changes: 1 addition & 1 deletion src/relay/op/make_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ Expr MakeStack(Expr data, int axis);
Expr MakeTranspose(Expr data, Array<Integer> axes);

Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides,
String slice_mode);
String slice_mode, Optional<Array<Integer>> axes=NullValue<Array<Integer>>());

Expr MakeTile(Expr data, Array<Integer> reps);

Expand Down
Loading

0 comments on commit 763ac37

Please sign in to comment.