diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py index df30ff775f60..b4d0167be2b1 100644 --- a/python/tvm/topi/transform.py +++ b/python/tvm/topi/transform.py @@ -170,7 +170,7 @@ def reverse_sequence(a, seq_lengths, seq_axis=1, batch_axis=0): return cpp.reverse_sequence(a, seq_lengths, seq_axis, batch_axis) -def strided_slice(a, begin, end, strides=None, slice_mode="end"): +def strided_slice(a, begin, end, strides=None, axes=None, slice_mode="end"): """Slice of an array. Parameters @@ -189,6 +189,10 @@ def strided_slice(a, begin, end, strides=None, slice_mode="end"): in that case, the input tensor will be reversed in that particular axis. + axes : list of int, optional + Axes along which slicing is applied. When it is specified, begin, end + strides, and axes need to a list of integers of the same length. + slice_mode : str, optional The slice mode [end, size]. end - The ending indices for the slice [default]. @@ -205,6 +209,7 @@ def strided_slice(a, begin, end, strides=None, slice_mode="end"): or isinstance(end, tvm.te.Tensor) or isinstance(strides, tvm.te.Tensor) ): + assert axes is None, "axes argument is not supported by dynamic strided slice yet." if not isinstance(begin, tvm.te.Tensor): begin = const_vector(begin) if not isinstance(end, tvm.te.Tensor): @@ -216,7 +221,9 @@ def strided_slice(a, begin, end, strides=None, slice_mode="end"): return cpp.dynamic_strided_slice(a, begin, end, strides) if strides is None: strides = [] - return cpp.strided_slice(a, begin, end, strides, slice_mode) + if axes is None: + axes = [] + return cpp.strided_slice(a, begin, end, strides, axes, slice_mode) @tvm.te.tag_scope(tag=tag.INJECTIVE + ",strided_set") diff --git a/src/topi/transform.cc b/src/topi/transform.cc index 7c6e491dcc26..db54d5a99a91 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -178,12 +178,17 @@ TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue* Array begin = args[1]; Array end = args[2]; Array strides = args[3]; - std::string slice_mode = args[4]; if (IsConstIntArray(begin) && IsConstIntArray(end) && IsConstIntArray(strides)) { Array begin_static = args[1]; Array end_static = args[2]; Array strides_static = args[3]; - *rv = strided_slice(x, begin_static, end_static, strides_static, slice_mode); + Array axes = args[4]; + std::string slice_mode = args[5]; + if (axes.size() > 0) { + *rv = strided_slice_with_axes(x, begin_static, end_static, strides_static, axes, slice_mode); + } else { + *rv = strided_slice(x, begin_static, end_static, strides_static, slice_mode); + } } else { *rv = dynamic_strided_slice(x, begin, end, strides); } diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py index 20172f07fd9e..c98478885a10 100644 --- a/tests/python/topi/python/test_topi_transform.py +++ b/tests/python/topi/python/test_topi_transform.py @@ -398,10 +398,12 @@ def check_device(target): check_device(target) -def verify_strided_slice(in_shape, begin, end, strides=None): +def verify_strided_slice(in_shape, begin, end, strides=None, axes=None): A = te.placeholder(shape=in_shape, name="A") strides = [1, 1, 1] if strides is None else strides - B = topi.strided_slice(A, begin, end, strides) + 1 + if axes: + strides = [strides[axis] for axis in axes] + B = topi.strided_slice(A, begin, end, strides, axes) + 1 def check_device(target): dev = tvm.device(target, 0) @@ -414,7 +416,7 @@ def check_device(target): foo = tvm.build(s, [A, B], target, name="stride_slice") x_np = np.random.uniform(size=in_shape).astype(A.dtype) - out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end, strides) + 1 + out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end, strides, axes) + 1 data_nd = tvm.nd.array(x_np, dev) out_nd = tvm.nd.empty(out_npy.shape, device=dev, dtype=A.dtype) foo(data_nd, out_nd) @@ -819,6 +821,7 @@ def test_strided_slice(): verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3]) verify_strided_slice((3, 4, 3), [0, 2, 0], [1, 2, 3]) verify_strided_slice((3, 4, 3), [0, 0, 0], [None, None, None]) + verify_strided_slice((3, 4, 3), [0], [2], None, axes=[1]) @tvm.testing.uses_gpu