Skip to content

Commit

Permalink
support axes argument in topi cpp strided slice
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Jun 2, 2021
1 parent cb96228 commit bd5ae6c
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 7 deletions.
11 changes: 9 additions & 2 deletions python/tvm/topi/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def reverse_sequence(a, seq_lengths, seq_axis=1, batch_axis=0):
return cpp.reverse_sequence(a, seq_lengths, seq_axis, batch_axis)


def strided_slice(a, begin, end, strides=None, slice_mode="end"):
def strided_slice(a, begin, end, strides=None, axes=None, slice_mode="end"):
"""Slice of an array.
Parameters
Expand All @@ -189,6 +189,10 @@ def strided_slice(a, begin, end, strides=None, slice_mode="end"):
in that case, the input tensor will be reversed
in that particular axis.
axes : list of int, optional
Axes along which slicing is applied. When it is specified, begin, end
strides, and axes need to a list of integers of the same length.
slice_mode : str, optional
The slice mode [end, size].
end - The ending indices for the slice [default].
Expand All @@ -205,6 +209,7 @@ def strided_slice(a, begin, end, strides=None, slice_mode="end"):
or isinstance(end, tvm.te.Tensor)
or isinstance(strides, tvm.te.Tensor)
):
assert axes is None, "axes argument is not supported by dynamic strided slice yet."
if not isinstance(begin, tvm.te.Tensor):
begin = const_vector(begin)
if not isinstance(end, tvm.te.Tensor):
Expand All @@ -216,7 +221,9 @@ def strided_slice(a, begin, end, strides=None, slice_mode="end"):
return cpp.dynamic_strided_slice(a, begin, end, strides)
if strides is None:
strides = []
return cpp.strided_slice(a, begin, end, strides, slice_mode)
if axes is None:
axes = []
return cpp.strided_slice(a, begin, end, strides, axes, slice_mode)


@tvm.te.tag_scope(tag=tag.INJECTIVE + ",strided_set")
Expand Down
9 changes: 7 additions & 2 deletions src/topi/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,17 @@ TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue*
Array<PrimExpr> begin = args[1];
Array<PrimExpr> end = args[2];
Array<PrimExpr> strides = args[3];
std::string slice_mode = args[4];
if (IsConstIntArray(begin) && IsConstIntArray(end) && IsConstIntArray(strides)) {
Array<Integer> begin_static = args[1];
Array<Integer> end_static = args[2];
Array<Integer> strides_static = args[3];
*rv = strided_slice(x, begin_static, end_static, strides_static, slice_mode);
Array<Integer> axes = args[4];
std::string slice_mode = args[5];
if (axes.size() > 0) {
*rv = strided_slice_with_axes(x, begin_static, end_static, strides_static, axes, slice_mode);
} else {
*rv = strided_slice(x, begin_static, end_static, strides_static, slice_mode);
}
} else {
*rv = dynamic_strided_slice(x, begin, end, strides);
}
Expand Down
9 changes: 6 additions & 3 deletions tests/python/topi/python/test_topi_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,10 +398,12 @@ def check_device(target):
check_device(target)


def verify_strided_slice(in_shape, begin, end, strides=None):
def verify_strided_slice(in_shape, begin, end, strides=None, axes=None):
A = te.placeholder(shape=in_shape, name="A")
strides = [1, 1, 1] if strides is None else strides
B = topi.strided_slice(A, begin, end, strides) + 1
if axes:
strides = [strides[axis] for axis in axes]
B = topi.strided_slice(A, begin, end, strides, axes) + 1

def check_device(target):
dev = tvm.device(target, 0)
Expand All @@ -414,7 +416,7 @@ def check_device(target):

foo = tvm.build(s, [A, B], target, name="stride_slice")
x_np = np.random.uniform(size=in_shape).astype(A.dtype)
out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end, strides) + 1
out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end, strides, axes) + 1
data_nd = tvm.nd.array(x_np, dev)
out_nd = tvm.nd.empty(out_npy.shape, device=dev, dtype=A.dtype)
foo(data_nd, out_nd)
Expand Down Expand Up @@ -819,6 +821,7 @@ def test_strided_slice():
verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3])
verify_strided_slice((3, 4, 3), [0, 2, 0], [1, 2, 3])
verify_strided_slice((3, 4, 3), [0, 0, 0], [None, None, None])
verify_strided_slice((3, 4, 3), [0], [2], None, axes=[1])


@tvm.testing.uses_gpu
Expand Down

0 comments on commit bd5ae6c

Please sign in to comment.