From d2538ae980a1c731b646467c615d742efeb65e25 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 24 May 2021 07:56:05 +0900 Subject: [PATCH] restore another slice variant --- include/tvm/topi/transform.h | 26 ++++++++++++++++++++++++++ python/tvm/relay/frontend/onnx.py | 2 +- python/tvm/topi/cuda/sort.py | 1 + src/topi/transform.cc | 26 +++++++++++++++++++++----- 4 files changed, 49 insertions(+), 6 deletions(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 28daff9da139..5c3734b6e4c0 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -593,6 +593,32 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b name, tag); } +inline Tensor dynamic_strided_slice(const Tensor& x, const Array& begin, + const Array& end, const Array& strides, + std::string slice_mode = "end", + std::string name = "T_strided_slice", + std::string tag = kInjective) { + size_t src_tensor_dim = static_cast(x->shape.size()); + ICHECK_EQ(begin.size(), src_tensor_dim); + ICHECK_EQ(end.size(), src_tensor_dim); + ICHECK_EQ(strides.size(), src_tensor_dim); + + Array out_shape; + for (size_t i = 0; i < src_tensor_dim; ++i) { + out_shape.push_back(indexdiv(end[i] - begin[i], strides[i])); + } + return te::compute( + out_shape, + [&](const Array& indices) { + Array real_indices; + for (size_t i = 0; i < src_tensor_dim; ++i) { + real_indices.push_back(indices[i] * strides[i] + begin[i]); + } + return x(real_indices); + }, + name, tag); +} + inline Tensor strided_slice_dynamic_input(const Tensor& input, const Array& begin, const Array& end, const Array& strides, std::string slice_mode = "end", diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py index 40958116517f..6671abb2e263 100644 --- a/python/tvm/relay/frontend/onnx.py +++ b/python/tvm/relay/frontend/onnx.py @@ -1359,7 +1359,7 @@ def has_static_axes(): else: strides_np = steps.data.asnumpy().astype("int64") - return _op.strided_slice(inputs[0], list(begin_np), list(end_np), list(strides_np), axes=list(axes_np)) + # return _op.strided_slice(inputs[0], list(begin_np), list(end_np), list(strides_np), axes=list(axes_np)) # Update the starts and ends according to axes if required. if axes is not None: diff --git a/python/tvm/topi/cuda/sort.py b/python/tvm/topi/cuda/sort.py index 25cc7a4e2cfb..a9ad55c72c81 100644 --- a/python/tvm/topi/cuda/sort.py +++ b/python/tvm/topi/cuda/sort.py @@ -962,6 +962,7 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"): end.append(dshape[i]) if ret_type == "both": values_out, indices_out = output + print("end:", end, k) values_out = strided_slice(values_out, beg, end, strides) indices_out = strided_slice(indices_out, beg, end, strides) output = [values_out, indices_out] diff --git a/src/topi/transform.cc b/src/topi/transform.cc index dd7962bdb1cf..e30daf3f3503 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -174,14 +174,30 @@ TVM_REGISTER_GLOBAL("topi.einsum").set_body([](TVMArgs args, TVMRetValue* rv) { }); TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) { - Array begin = args[1]; - Array end = args[2]; - Array strides = args[3]; - *rv = strided_slice(args[0], begin, end, strides, args[4]); + Tensor x = args[0]; + Array begin = args[1]; + Array end = args[2]; + Array strides = args[3]; + std::string slice_mode = args[4]; + if (IsConstIntArray(begin) && IsConstIntArray(end) && IsConstIntArray(strides)) { + Array begin_static = args[1]; + Array end_static = args[2]; + Array strides_static = args[3]; + if (IsConstIntArray(x->shape)) { + *rv = strided_slice(x, begin_static, end_static, strides_static, slice_mode); + } else { + *rv = strided_slice_dynamic_input(x, begin_static, end_static, strides_static, slice_mode); + } + } else { + *rv = dynamic_strided_slice(x, begin, end, strides, slice_mode); + } }); TVM_REGISTER_GLOBAL("topi.dynamic_strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = dynamic_strided_slice(args[0], args[1], args[2], args[3]); + te::Tensor begin = args[1]; + te::Tensor end = args[2]; + te::Tensor strides = args[3]; + *rv = dynamic_strided_slice(args[0], begin, end, strides); }); TVM_REGISTER_GLOBAL("topi.one_hot").set_body([](TVMArgs args, TVMRetValue* rv) {