From cbe3dcad5e3f9358af8d2d79e2880f92718a5d0b Mon Sep 17 00:00:00 2001 From: masahi Date: Thu, 3 Jun 2021 06:33:00 +0900 Subject: [PATCH] [Relay, TOPI] Refactor strided_slice and add axes argument (#8165) * Initial import commit 667011f10320918e4dcd47ac2b57fe49849e5440 Author: Masahiro Masuda Date: Thu May 27 16:28:57 2021 +0900 Squashed commit of the following: commit 95242d86ea5de96925430c0a74b6e91e299fb5ab Author: Masahiro Masuda Date: Thu May 27 15:45:19 2021 +0900 Add function attribute for shape func for profiling commit b8ede24ff987eb152bde7cc15afce004a88aeb5f Author: Masahiro Masuda Date: Thu May 27 10:21:06 2021 +0900 layout transform support complete commit 5782b7070288eb0de122f5dab91b38c26166a7d7 Author: Masahiro Masuda Date: Thu May 27 08:31:11 2021 +0900 support layout transform part1 commit e94aa6b2a916607234c89eddcd07afdfa8085786 Author: Masahiro Masuda Date: Mon May 24 19:47:46 2021 +0900 moved utilities to its own file commit 8bf88913b9bc02730120a0695138ed3fb8ed49ae Author: Masahiro Masuda Date: Mon May 24 17:39:50 2021 +0900 fix format commit e89d599d6a10021167feb241483693260aa535f2 Author: Masahiro Masuda Date: Mon May 24 17:33:02 2021 +0900 ToVec -> ConvertToVec commit 001982ce1419504f1f0e1d116d57dd34f0180008 Author: Masahiro Masuda Date: Mon May 24 17:26:56 2021 +0900 format commit fae57f9bd67b29880a7552b5149d03120924cdac Author: Masahiro Masuda Date: Mon May 24 17:24:35 2021 +0900 use Any for relay type rel path commit 053eee2e6f58749af0b68cca52fd530afc0f6454 Author: Masahiro Masuda Date: Mon May 24 17:14:44 2021 +0900 fix commit fbb099c8e66caf846c773e180c66a2b336bd64a3 Author: Masahiro Masuda Date: Mon May 24 16:39:37 2021 +0900 refactor type rel commit ecfe3cd43e3968375505d5393959ec19da4b5c01 Author: Masahiro Masuda Date: Mon May 24 16:23:47 2021 +0900 working commit b357c2f8825603d8ba9ee2424a7f572e12c29852 Author: Masahiro Masuda Date: Mon May 24 16:07:07 2021 +0900 refactoring output shape calc commit f69ef407cf003c7977a4564185949d3f6b5c0219 Author: Masahiro Masuda Date: Mon May 24 14:23:36 2021 +0900 bug fix end param init commit a5611c9a1f243f4b9a56539e7e8a15661374920c Author: Masahiro Masuda Date: Mon May 24 13:42:31 2021 +0900 fix test shape commit e79931a264f0d8ed63a333ec4ab10a72cff22a84 Author: Masahiro Masuda Date: Mon May 24 13:42:03 2021 +0900 dyn slice tests left as todo now work commit 7db4cea31378eed85dfae1cb03fb5a97394f7fe3 Author: Masahiro Masuda Date: Mon May 24 13:36:30 2021 +0900 remove dynamic input specific op commit 510bce6a181604e5eb3f2bd1951ae035a4090700 Author: Masahiro Masuda Date: Mon May 24 12:52:30 2021 +0900 refactoring dynamic slice commit 1b3969ade9ee98651b8157ecab1c675410a84ee5 Author: Masahiro Masuda Date: Mon May 24 09:06:46 2021 +0900 fix slice axes dispatch commit 9a795606fb71ec08cefe5bfa904f1ab32c18da4b Author: Masahiro Masuda Date: Mon May 24 08:32:54 2021 +0900 refactor compute commit 80442f86bbf9f0582823d5903021e3bae61a4662 Author: Masahiro Masuda Date: Mon May 24 08:11:18 2021 +0900 fixed output shape, refactored version working commit d2538ae980a1c731b646467c615d742efeb65e25 Author: Masahiro Masuda Date: Mon May 24 07:56:05 2021 +0900 restore another slice variant commit 36aa777eacd8426a850d08b528e9addcd36a4894 Author: Masahiro Masuda Date: Mon May 24 06:41:50 2021 +0900 refactoring slice with axes commit 32698b74df211829777e5493e82bf7425364acb4 Author: Masahiro Masuda Date: Sat May 22 13:11:01 2021 +0900 fix axes null check commit 54fb723d23d351551b75d879198aafb1eac2dede Author: Masahiro Masuda Date: Sat May 22 12:52:18 
2021 +0900 Revert "[Bugfix][Vulkan] Call VulkanDeviceAPI destructor on program exit (#7997)" This reverts commit 58c3413a30e5b03208b6281651d38ee02c44f9c1. commit 37eaf579d47190bc42ad64f9ac34c93a9dac3ce5 Author: Masahiro Masuda Date: Sat May 22 04:30:37 2021 +0900 remove wip layout transform support for slice with axes commit 9bcb2ada60fadadd1f29a6d09e6b4fc5104efd3f Author: Masahiro Masuda Date: Fri May 21 18:01:59 2021 +0900 fix pylint commit 7063a09ef1b98849e98194e8a9e47455cd1b5fa3 Author: Masahiro Masuda Date: Fri May 21 17:57:03 2021 +0900 minor fix commit 96c9231b5b2cbf2f36b4096d54f1f5ac4033d361 Author: Masahiro Masuda Date: Fri May 21 17:54:16 2021 +0900 support dynamic scatter nd commit d4a4db8a8b518b1ef9e6abacfa23a9e1b76fd1b0 Author: Masahiro Masuda Date: Fri May 21 17:33:19 2021 +0900 gather_dim -> num_indices_per_tuple commit a489375f0b31948a13e41f5967960305453c7049 Author: Masahiro Masuda Date: Fri May 21 17:23:46 2021 +0900 add dynamic gather_nd test commit 533854a006c16359842451b8690cb8639b47635d Author: Masahiro Masuda Date: Fri May 21 17:18:26 2021 +0900 refactor gather_nd ref funcs commit 36a4501a151070760559f6ce4cfa574202b4d0c8 Author: Masahiro Masuda Date: Fri May 21 14:36:34 2021 +0900 add gather_nd shape func commit 1853c35d883e501e484d5f74adb3081f916761d5 Author: Masahiro Masuda Date: Sat May 22 04:20:39 2021 +0900 add eyelike support commit 150e945290cbc595bd370dcae7e96e24597fbf04 Author: Masahiro Masuda Date: Fri May 21 04:08:37 2021 +0900 migrating inlined topi compute to topi/transform.h commit 763ac37f725c2cb89a3221621b69da0e6ac39ed8 Author: Masahiro Masuda Date: Fri May 21 03:45:37 2021 +0900 strided slice with axes support * fix bad merge * fix cpplint * fix pylint * more cpplint fix * fix compiler warning * add doc * add tests * typo fixed * support axes argument in topi cpp strided slice * Properly test axes argument in relay tests * fix bad merge (revert vm change) * fix tests --- include/tvm/relay/attrs/transform.h | 4 + include/tvm/topi/detail/strided_slice.h | 156 ++++++++ include/tvm/topi/nn.h | 2 +- include/tvm/topi/transform.h | 279 +++++++------- python/tvm/relay/frontend/onnx.py | 25 +- python/tvm/relay/op/_transform.py | 56 ++- python/tvm/relay/op/transform.py | 12 +- .../tvm/topi/testing/strided_slice_python.py | 21 +- python/tvm/topi/transform.py | 11 +- src/relay/op/make_op.h | 3 +- src/relay/op/tensor/transform.cc | 350 ++++++++---------- src/topi/transform.cc | 24 +- tests/python/relay/test_any.py | 2 +- tests/python/relay/test_op_level4.py | 87 ++++- .../python/relay/test_pass_alter_op_layout.py | 56 +++ .../relay/test_pass_convert_op_layout.py | 50 ++- .../python/topi/python/test_topi_transform.py | 9 +- 17 files changed, 790 insertions(+), 357 deletions(-) create mode 100644 include/tvm/topi/detail/strided_slice.h diff --git a/include/tvm/relay/attrs/transform.h b/include/tvm/relay/attrs/transform.h index 027b3fe1df5f..69a9c64a4588 100644 --- a/include/tvm/relay/attrs/transform.h +++ b/include/tvm/relay/attrs/transform.h @@ -310,6 +310,7 @@ struct StridedSliceAttrs : public tvm::AttrsNode { Optional> end; Optional> strides; std::string slice_mode; + Optional> axes; TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") { TVM_ATTR_FIELD(begin).describe("Indices for begin of slice, begin index is also inclusive"); @@ -324,6 +325,9 @@ struct StridedSliceAttrs : public tvm::AttrsNode { "size - The input strides will be ignored, input end in this mode indicates the size" "of a slice starting at the location specified by begin. 
If end[i] is -1," "all remaining elements in that dimension are included in the slice"); + TVM_ATTR_FIELD(axes).describe( + "Axes along which slicing is applied. When it is specified, the length of begin, end, " + "strides, and axes must be equal."); } }; diff --git a/include/tvm/topi/detail/strided_slice.h b/include/tvm/topi/detail/strided_slice.h new file mode 100644 index 000000000000..da76022c552b --- /dev/null +++ b/include/tvm/topi/detail/strided_slice.h @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file strided_slice.h + * \brief Utility functions for strided_slice op + */ +#ifndef TVM_TOPI_DETAIL_STRIDED_SLICE_H_ +#define TVM_TOPI_DETAIL_STRIDED_SLICE_H_ + +#include + +#include +#include +#include +#include +#include + +#include "constant_utils.h" + +namespace tvm { +namespace topi { +namespace detail { + +using namespace tvm::te; + +inline int64_t CanonicalizeIndex(int64_t index, int64_t extent, int64_t stride) { + int64_t begin_range = stride < 0 ? -1 : 0; + int64_t end_range = stride < 0 ? extent - 1 : extent; + if (index < 0) { + index += extent; + } + return std::min(std::max(index, begin_range), end_range); +} + +inline std::tuple, std::vector, std::vector> ConvertToVec( + const Array& begin, const Array& end, const Array& strides, + std::string slice_mode) { + std::vector stride_vec(strides.size(), 1); + if (slice_mode == "end") { + for (size_t i = 0; i < strides.size(); ++i) { + ICHECK(strides[i].defined()); + stride_vec[i] = GetConstInt(strides[i]); + } + } + const int64_t max_range = std::numeric_limits::max(); + std::vector begin_vec; + for (size_t i = 0; i < begin.size(); ++i) { + if (!begin[i].defined()) { + // value=None + begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); + } else { + begin_vec.push_back(GetConstInt(begin[i])); + } + } + std::vector end_vec; + for (size_t i = 0; i < end.size(); ++i) { + // allow end to be None + if (!end[i].defined()) { + end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); + } else if (slice_mode == "size") { + int64_t end_val = GetConstInt(end[i]); + if (end_val < 0) { + end_vec.push_back(stride_vec[i] < 0 ? 
0 : max_range); + } else { + end_vec.push_back(begin_vec[i] + end_val); + } + } else { + end_vec.push_back(GetConstInt(end[i])); + } + } + return std::make_tuple(begin_vec, end_vec, stride_vec); +} + +inline Array StridedSliceCanonicalizeBegin(const Array& ishape, + const std::vector& begin, + const std::vector& strides, + const Array& axes, DataType dtype, + std::string slice_mode = "end") { + Array begin_expr; + for (size_t i = 0; i < axes.size(); ++i) { + if (ishape[axes[i]]->IsInstance()) { + int64_t dim_i = GetConstInt(ishape[axes[i]]); + int64_t begin_i = CanonicalizeIndex(begin[i], dim_i, strides[i]); + begin_expr.push_back(make_const(dtype, begin_i)); + } else { + auto idim = ishape[axes[i]]; + auto b_expr = make_const(dtype, begin[i]); + PrimExpr b = begin[i] < 0 ? b_expr + idim : b_expr; + auto s = strides[i]; + if (s < 0) { + b = tvm::min(b, idim - 1); + } else { + b = tvm::if_then_else(b < 0, 0, b); + } + begin_expr.push_back(b); + } + } + return begin_expr; +} + +inline Array StridedSliceOutputShape(const Array& ishape, + const std::vector& begin, + const std::vector& end, + const std::vector& strides, + const Array& axes, std::string slice_mode, + const Array& begin_canonicalized, + bool use_any = false) { + const size_t src_tensor_dim = ishape.size(); + Array out_shape; + for (size_t i = 0; i < src_tensor_dim; ++i) { + out_shape.push_back(ishape[i]); + } + + for (size_t i = 0; i < axes.size(); ++i) { + if (ishape[axes[i]]->IsInstance()) { + const int64_t dim_i = GetConstInt(ishape[axes[i]]); + ICHECK(begin_canonicalized[i]->IsInstance()); + int64_t begin_i = GetConstInt(begin_canonicalized[i]); + int64_t end_i = CanonicalizeIndex(end[i], dim_i, strides[i]); + int interval = std::abs(end_i - begin_i); + int slice_size = + static_cast((interval + std::abs(strides[i]) - 1) / std::abs(strides[i])); + ICHECK(strides[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i)) + << ": Input [Begin=" << begin[i] << ", End=" << end[i] << "] is invalid for axis=" << i; + out_shape.Set(axes[i], cast(out_shape[i].dtype(), PrimExpr(slice_size))); + } else if (use_any) { + out_shape.Set(axes[i], tvm::tir::Any()); + } else { + out_shape.Set(axes[i], tvm::tir::Var("dim", out_shape[i]->dtype)); + } + } + + return out_shape; +} + +} // namespace detail +} // namespace topi +} // namespace tvm +#endif // TVM_TOPI_DETAIL_STRIDED_SLICE_H_ diff --git a/include/tvm/topi/nn.h b/include/tvm/topi/nn.h index 29c3156ab5d6..d3328c59afb4 100644 --- a/include/tvm/topi/nn.h +++ b/include/tvm/topi/nn.h @@ -619,7 +619,7 @@ inline tvm::te::Tensor batch_to_space_nd(const tvm::te::Tensor& data, out = reshape(out, r_p_shape); // Crop the start and end of dimensions of out - Array begin_idx, end_idx, strides; + Array begin_idx, end_idx, strides; for (size_t i = 0; i < r_p_shape.size(); ++i) { strides.push_back(Integer(1)); if (i > 0 && i <= num_block_dims) { diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 36acc7376c7c..8d1a49a4cc5f 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -27,8 +27,10 @@ #include #include #include +#include #include #include +#include #include #include @@ -39,8 +41,6 @@ #include #include -#include "detail/broadcast.h" - namespace tvm { namespace topi { @@ -551,7 +551,7 @@ inline Array split(const Tensor& x, Array split_indices, int a } /*! 
- * \brief strided_slice of a tensor with dynamic begin/end/stride + * \brief strided_slice of a tensor where begin/end/stride can be mixed static and dynamic * * \param x The input tensor * \param begin The indices to begin with in the slicing @@ -561,31 +561,45 @@ inline Array split(const Tensor& x, Array split_indices, int a * \param name The name of the operation * \param tag The tag to mark the operation * - * \return A Tensor whose op member is the split operation + * \return A Tensor whose op member is the dynamic_strided_slice operation */ -inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& begin, - const te::Tensor& end, const te::Tensor& strides, - std::string name = "T_strided_slice_dynamic", - std::string tag = topi::kInjective) { - int64_t src_tensor_dim = x->shape.size(); +inline Tensor dynamic_strided_slice(const Tensor& x, const Array& begin, + const Array& end, const Array& strides, + std::string name = "T_dynamic_strided_slice", + std::string tag = kInjective) { + const size_t src_tensor_dim = x->shape.size(); + ICHECK_LE(begin.size(), src_tensor_dim); + ICHECK_LE(end.size(), src_tensor_dim); + ICHECK_LE(strides.size(), src_tensor_dim); + ICHECK_EQ(begin.size(), end.size()); + ICHECK_EQ(begin.size(), strides.size()); + + const size_t num_slice_axes = begin.size(); Array out_shape; - const int64_t num_dynamic_axes = begin->shape[0].as()->value; - for (int64_t i = 0; i < num_dynamic_axes; ++i) { - out_shape.push_back(tvm::tir::Var("dim")); + + for (size_t i = 0; i < num_slice_axes; ++i) { + auto d = indexdiv(end[i] - begin[i], strides[i]); + if (d->IsInstance()) { + // Preserve static dimension if possible + out_shape.push_back(d); + } else { + out_shape.push_back(tvm::tir::Var("dim")); + } } - for (int64_t i = num_dynamic_axes; i < src_tensor_dim; ++i) { + + for (size_t i = num_slice_axes; i < src_tensor_dim; ++i) { out_shape.push_back(x->shape[i]); } + return te::compute( out_shape, [&](const Array& indices) { Array real_indices; - // dynamic slicing - for (int32_t i = 0; i < num_dynamic_axes; ++i) { - real_indices.push_back(indices[i] * strides(i) + tvm::min(begin(i), x->shape[i] - 1)); + for (size_t i = 0; i < num_slice_axes; ++i) { + real_indices.push_back(indices[i] * strides[i] + tvm::min(begin[i], x->shape[i] - 1)); } // keep input dim - for (int32_t i = num_dynamic_axes; i < src_tensor_dim; ++i) { + for (size_t i = num_slice_axes; i < src_tensor_dim; ++i) { real_indices.push_back(indices[i]); } return x(real_indices); @@ -594,137 +608,152 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b } /*! 
- * \brief strided_slice of a tensor + * \brief strided_slice of a tensor with dynamic begin/end/stride * * \param x The input tensor * \param begin The indices to begin with in the slicing * \param end Indicies indicating end of the slice * \param strides Specifies the stride values, it can be negative * in that case, the input tensor will be reversed in that particular axis - * \param slice_mode Specifies the slice mode * \param name The name of the operation * \param tag The tag to mark the operation * - * \return A Tensor whose op member is the split operation + * \return A Tensor whose op member is the dynamic_strided_slice operation */ -inline Tensor strided_slice(const Tensor& x, const Array& begin, - const Array& end, const Array& strides, - std::string slice_mode = "end", std::string name = "T_strided_slice", - std::string tag = kInjective) { - size_t src_tensor_dim = static_cast(x->shape.size()); - // Quick path for dynamic shape strided slice. - // This is for ease of use to dynamice strided slice in topi. - bool is_static = IsConstIntArray(x->shape); - is_static &= IsConstIntArray(begin); - is_static &= IsConstIntArray(end); - is_static &= IsConstIntArray(strides); - - Array out_shape; - if (!is_static) { - ICHECK_EQ(strides.size(), src_tensor_dim); - for (size_t i = 0; i < src_tensor_dim; ++i) { - out_shape.push_back(indexdiv(end[i] - begin[i], strides[i])); - } - return te::compute( - out_shape, - [&](const Array& indices) { - Array real_indices; - for (size_t i = 0; i < src_tensor_dim; ++i) { - real_indices.push_back(indices[i] * strides[i] + begin[i]); - } - return x(real_indices); - }, - name, tag); - } - - // Setup the ranges. - // NOTE: this code duplicates the shape inference logic relay.op - // Consider to refactor in the future. - std::vector stride_vec(src_tensor_dim, 1); - for (size_t i = 0; i < strides.size(); ++i) { - ICHECK(strides[i].defined()); - stride_vec[i] = GetConstInt(strides[i]); - } - - const int64_t max_range = std::numeric_limits::max(); +inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& begin, + const te::Tensor& end, const te::Tensor& strides, + std::string name = "T_strided_slice_dynamic", + std::string tag = topi::kInjective) { + const int64_t num_dynamic_axes = begin->shape[0].as()->value; + ICHECK_EQ(end->shape[0].as()->value, num_dynamic_axes); + ICHECK_EQ(strides->shape[0].as()->value, num_dynamic_axes); - std::vector begin_vec; - for (size_t i = 0; i < begin.size(); ++i) { - if (!begin[i].defined()) { - // value=None - begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); - } else { - begin_vec.push_back(GetConstInt(begin[i])); - } - } - for (size_t i = begin_vec.size(); i < src_tensor_dim; ++i) { - begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range); + Array begin_expr, end_expr, strides_expr; + for (int64_t i = 0; i < num_dynamic_axes; ++i) { + auto i64_ind = IntImm(DataType::Int(64), i); + begin_expr.push_back(begin(i64_ind)); + end_expr.push_back(end(i64_ind)); + strides_expr.push_back(strides(i64_ind)); } + return dynamic_strided_slice(x, begin_expr, end_expr, strides_expr, name, tag); +} - std::vector end_vec; - for (size_t i = 0; i < end.size(); ++i) { - // allow end to be None +/*! 
+ * \brief Calculate the output shape of strided_slice, the entry point for Relay type relation
+ *
+ * \param ishape The input tensor shape
+ * \param begin The indices to begin with in the slicing
+ * \param end Indices indicating end of the slice
+ * \param strides Specifies the stride values, it can be negative
+ * in that case, the input tensor will be reversed in that particular axis
+ * \param axes Axes along which slicing is applied. When it is specified, the length of begin, end,
+ * strides, and axes must be equal
+ * \param slice_mode Specifies the slice mode
+ *
+ * \return The output shape of strided_slice using the arguments above
+ */
+inline Array<PrimExpr> StridedSliceOutputShape(
+    const Array<PrimExpr>& ishape, const Array<Integer>& begin, const Array<Integer>& end,
+    const Array<Integer>& strides, const Array<Integer>& axes, const std::string& slice_mode) {
+  ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size());
+  std::vector<int64_t> begin_vec, end_vec, strides_vec;
+  std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode);
+  auto begin_canonicalized = StridedSliceCanonicalizeBegin(ishape, begin_vec, strides_vec, axes,
+                                                           begin[0]->dtype, slice_mode);
+  return StridedSliceOutputShape(ishape, begin_vec, end_vec, strides_vec, axes, slice_mode,
+                                 begin_canonicalized, true);
+}
+
+/*!
+ * \brief strided_slice of a tensor
+ *
+ * \param x The input tensor
+ * \param begin The indices to begin with in the slicing
+ * \param end Indices indicating end of the slice
+ * \param strides Specifies the stride values, it can be negative
+ * in that case, the input tensor will be reversed in that particular axis
+ * \param axes Axes along which slicing is applied. When it is specified, the length of begin, end,
+ * strides, and axes must be equal
+ * \param slice_mode Specifies the slice mode
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the strided_slice operation
+ */
+inline Tensor strided_slice_with_axes(const Tensor& x, const Array<Integer>& begin,
+                                      const Array<Integer>& end, const Array<Integer>& strides,
+                                      const Array<Integer>& axes, std::string slice_mode = "end",
+                                      std::string name = "T_strided_slice_with_axes",
+                                      std::string tag = kInjective) {
+  const size_t src_tensor_dim = x->shape.size();
+  ICHECK(axes.size() <= src_tensor_dim);
+  ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size());
+
+  std::vector<int64_t> begin_vec, end_vec, strides_vec;
+  std::tie(begin_vec, end_vec, strides_vec) = ConvertToVec(begin, end, strides, slice_mode);
+
+  auto begin_expr = StridedSliceCanonicalizeBegin(x->shape, begin_vec, strides_vec, axes,
+                                                  begin[0]->dtype, slice_mode);
+  auto out_shape = StridedSliceOutputShape(x->shape, begin_vec, end_vec, strides_vec, axes,
+                                           slice_mode, begin_expr);
 
-  return compute(
+  return te::compute(
       out_shape,
-      [&](const Array<Var>& indices) {
+      [&](const Array<Var>& indices) {
         Array<PrimExpr> real_indices;
-        for (size_t i = 0; i < src_tensor_dim; ++i) {
-          real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]);
+        for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]);
+        for (size_t i = 0; i < axes.size(); ++i) {
+          auto stride = make_const(strides[i].dtype(), strides_vec[i]);
+          PrimExpr ind = indices[axes[i]] * stride + begin_expr[i];
+          real_indices.Set(axes[i], ind);
         }
         return x(real_indices);
       },
       name, tag);
 }
 
+/*!
+ * \brief strided_slice of a tensor
+ *
+ * \param x The input tensor
+ * \param begin The indices to begin with in the slicing
+ * \param end Indices indicating end of the slice
+ * \param strides Specifies the stride values, it can be negative
+ * in that case, the input tensor will be reversed in that particular axis
+ * \param slice_mode Specifies the slice mode
+ * \param name The name of the operation
+ * \param tag The tag to mark the operation
+ *
+ * \return A Tensor whose op member is the strided_slice operation
+ */
+inline Tensor strided_slice(const Tensor& x, const Array<Integer>& begin, const Array<Integer>& end,
+                            const Array<Integer>& strides, std::string slice_mode = "end",
+                            std::string name = "T_strided_slice", std::string tag = kInjective) {
+  size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
+  Array<Integer> axes;
+  for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i);
+  Array<Integer> begin_full(begin);
+  Array<Integer> end_full(end);
+  Array<Integer> strides_full(strides);
+
+  const IntImm one = IntImm(DataType::Int(64), 1);
+  const IntImm zero = IntImm(DataType::Int(64), 0);
+  const IntImm max_range = IntImm(DataType::Int(64), std::numeric_limits<int64_t>::max());
+
+  for (size_t i = strides.size(); i < src_tensor_dim; ++i) {
+    strides_full.push_back(one);
+  }
+  for (size_t i = begin.size(); i < src_tensor_dim; ++i) {
+    begin_full.push_back(GetConstInt(strides_full[i]) > 0 ? zero : max_range);
+  }
+  for (size_t i = end.size(); i < src_tensor_dim; ++i) {
+    end_full.push_back(GetConstInt(strides_full[i]) < 0 ? zero : max_range);
+  }
+
+  return strided_slice_with_axes(x, begin_full, end_full, strides_full, axes, slice_mode, name,
+                                 tag);
+}
+
 /*!
 * \brief Split a tensor into a number of sub-tensors
 *
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 49c744edcf3f..09ff6b7de5b5 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -1341,7 +1341,30 @@ def _impl_v10(cls, inputs, attr, params):
         axes = inputs[3]
         steps = inputs[4]
 
-        data_rank = len(infer_shape(inputs[0]))
+        ishape = infer_shape(inputs[0])
+        data_rank = len(ishape)
+
+        def has_static_axes():
+            return (
+                isinstance(axes, _expr.Constant)
+                and isinstance(starts, _expr.Constant)
+                and isinstance(ends, _expr.Constant)
+                and (steps is None or isinstance(steps, _expr.Constant))
+            )
+
+        if axes is not None and has_static_axes():
+            axes_np = axes.data.asnumpy().astype("int64")
+            begin_np = starts.data.asnumpy().astype("int64")
+            end_np = ends.data.asnumpy().astype("int64")
+            if steps is None:
+                strides_np = np.ones_like(begin_np).astype("int64")
+            else:
+                strides_np = steps.data.asnumpy().astype("int64")
+
+            if all([isinstance(ishape[i], int) for i in axes_np]):
+                return _op.strided_slice(
+                    inputs[0], list(begin_np), list(end_np), list(strides_np), axes=list(axes_np)
+                )
 
         # Update the starts and ends according to axes if required.
         if axes is not None:
diff --git a/python/tvm/relay/op/_transform.py b/python/tvm/relay/op/_transform.py
index fee3eacf1aec..f87b5ed0b8ef 100644
--- a/python/tvm/relay/op/_transform.py
+++ b/python/tvm/relay/op/_transform.py
@@ -244,15 +244,67 @@ def _strided_slice_shape_func_input_shape(data_shape, begin, end, strides, slice_mode):
     return out
 
 
+@script
+def _strided_slice_shape_func_with_axes(data_shape, begin, end, strides, slice_mode, axes):
+    ndim = data_shape.shape[0]
+    out = output_tensor((ndim,), "int64")
+    for i in const_range(ndim):
+        out[i] = data_shape[i]
+
+    for i in const_range(len(axes)):
+        cbegin = int64(0)
+        cend = int64(data_shape[axes[i]])
+        cstride = int64(1)
+        if len(strides) > i:
+            cstride = int64(strides[i])
+        if len(begin) > i:
+            cbegin = int64(begin[i])
+            if cbegin < 0:
+                cbegin += int64(data_shape[axes[i]])
+        if len(end) <= i:
+            cend = int64(data_shape[axes[i]])
+        elif slice_mode != 0:
+            cstride = int64(1)
+            if end[i] < 0:
+                cend = int64(data_shape[axes[i]])
+            else:
+                cend = cbegin + int64(end[i])
+        else:
+            if end[i] > data_shape[axes[i]]:
+                cend = int64(data_shape[axes[i]])
+            elif end[i] < -data_shape[axes[i]]:
+                cend = int64(-1)
+            else:
+                cend = int64(end[i])
+                if cend < 0:
+                    cend += int64(data_shape[axes[i]])
+        assert cstride != 0, "Strides can't be zero."
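+        # Slice extent along axes[i] is ceil(|cend - cbegin| / |cstride|);
+        # the branches below arrange slice_range and step to both be positive.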
+        if cstride < 0:
+            slice_range = cbegin - cend
+            step = -cstride
+        else:
+            slice_range = cend - cbegin
+            step = cstride
+
+        out[axes[i]] = int64(ceil_div(slice_range, step))
+    return out
+
+
 @_reg.register_shape_func("strided_slice", False)
 def strided_slice_shape_func(attrs, inputs, _):
     """
     Shape func for strided_slice
     """
     slice_mode = convert(0 if attrs.slice_mode == "end" else 1)
+    if attrs.axes is None:
+        return [
+            _strided_slice_shape_func_input_shape(
+                inputs[0], attrs.begin, attrs.end, attrs.strides, slice_mode
+            )
+        ]
     return [
-        _strided_slice_shape_func_input_shape(
-            inputs[0], attrs.begin, attrs.end, attrs.strides, slice_mode
+        _strided_slice_shape_func_with_axes(
+            inputs[0], attrs.begin, attrs.end, attrs.strides, slice_mode, attrs.axes
         )
     ]
diff --git a/python/tvm/relay/op/transform.py b/python/tvm/relay/op/transform.py
index 71ecc0076285..80913e5f0cbd 100644
--- a/python/tvm/relay/op/transform.py
+++ b/python/tvm/relay/op/transform.py
@@ -867,7 +867,7 @@ def split(data, indices_or_sections, axis=0):
     return TupleWrapper(_make.split(data, indices_or_sections, axis), ret_size)
 
 
-def strided_slice(data, begin, end, strides=None, slice_mode="end"):
+def strided_slice(data, begin, end, strides=None, axes=None, slice_mode="end"):
     """Strided slice of an array.
 
     Parameters
@@ -885,6 +885,12 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"):
         Specifies the stride values, it can be negative in that case,
         the input tensor will be reversed in that particular axis.
 
+    axes : Tuple[int] or List[int], optional
+        Axes along which slicing is applied. When it is specified, the length of begin, end,
+        strides, and axes must be equal. Moreover, begin, end, strides, and axes must be
+        static (cannot be relay.Expr). Axes argument for dynamic parameter slicing is
+        not supported yet.
+
     slice_mode : str, optional
         The slice mode [end, size].
         end: The ending indices for the slice [default].
@@ -916,8 +922,10 @@ def strided_slice(data, begin, end, strides=None, slice_mode="end"):
         ishape_slice = slice_like(ishape, begin)
         begin = _make.where(begin < cast_like(const(0), begin), begin + ishape_slice, begin)
         begin = _make.where(begin >= ishape_slice, ishape_slice, begin)
+        # TODO(masahi): Support axes argument in dynamic strided slice
+        assert axes is None, "Axes argument for dynamic parameter slicing is not supported yet."
         return _dyn_make.strided_slice(data, begin, end, strides, slice_mode)
-    return _make.strided_slice(data, begin, end, strides, slice_mode)
+    return _make.strided_slice(data, begin, end, strides, slice_mode, axes)
 
 
 def strided_set(data, v, begin, end, strides=None):
diff --git a/python/tvm/topi/testing/strided_slice_python.py b/python/tvm/topi/testing/strided_slice_python.py
index 30466c785778..3843d0996777 100644
--- a/python/tvm/topi/testing/strided_slice_python.py
+++ b/python/tvm/topi/testing/strided_slice_python.py
@@ -17,7 +17,7 @@
 """strided_slice/set in python"""
 
 
-def strided_slice_python(data, begin, end, strides, slice_mode="end"):
+def strided_slice_python(data, begin, end, strides, slice_mode="end", axes=None):
     """Python version of strided slice operator.
 
     Parameters
@@ -41,6 +41,8 @@ def strided_slice_python(data, begin, end, strides, slice_mode="end"):
         the sizeof a slice starting at the location specified by begin. If end[i] is -1,
         all remaining elements in that dimension are included in the slice.
+    axes : list, optional
+        Axes along which slicing is applied
 
     Returns
     -------
     result : numpy.ndarray
         The sliced result.
     """
     strides = [] if strides is None else strides
+    if axes is not None:
+        rank = len(data.shape)
+        new_begin = [0] * rank
+        new_end = [data.shape[i] for i in range(rank)]
+        new_strides = [1] * rank
+
+        for i, axis in enumerate(axes):
+            new_begin[axis] = begin[i]
+            new_end[axis] = end[i]
+            if len(strides) > i:
+                new_strides[axis] = strides[i]
+
+        begin = new_begin
+        end = new_end
+        strides = new_strides
+
     slices = []
     for i in range(len(data.shape)):
         new_stride = None
@@ -66,6 +84,7 @@ def strided_slice_python(data, begin, end, strides, slice_mode="end"):
             new_end = end[i]
 
         slices.append(slice(new_begin, new_end, new_stride))
+
     return data[tuple(slices)]
 
 
diff --git a/python/tvm/topi/transform.py b/python/tvm/topi/transform.py
index df30ff775f60..b4d0167be2b1 100644
--- a/python/tvm/topi/transform.py
+++ b/python/tvm/topi/transform.py
@@ -170,7 +170,7 @@ def reverse_sequence(a, seq_lengths, seq_axis=1, batch_axis=0):
     return cpp.reverse_sequence(a, seq_lengths, seq_axis, batch_axis)
 
 
-def strided_slice(a, begin, end, strides=None, slice_mode="end"):
+def strided_slice(a, begin, end, strides=None, axes=None, slice_mode="end"):
     """Slice of an array.
 
     Parameters
@@ -189,6 +189,10 @@ def strided_slice(a, begin, end, strides=None, slice_mode="end"):
         in that case, the input tensor will be reversed
         in that particular axis.
 
+    axes : list of int, optional
+        Axes along which slicing is applied. When it is specified, begin, end,
+        strides, and axes need to be lists of integers of the same length.
+
     slice_mode : str, optional
         The slice mode [end, size].
         end - The ending indices for the slice [default].
@@ -205,6 +209,7 @@ def strided_slice(a, begin, end, strides=None, slice_mode="end"):
         or isinstance(end, tvm.te.Tensor)
         or isinstance(strides, tvm.te.Tensor)
     ):
+        assert axes is None, "axes argument is not supported by dynamic strided slice yet."
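+        # List inputs are converted to 1-D tensors below so that the single
+        # dynamic strided_slice entry point can consume begin/end/strides.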
         if not isinstance(begin, tvm.te.Tensor):
             begin = const_vector(begin)
         if not isinstance(end, tvm.te.Tensor):
             end = const_vector(end)
@@ -216,7 +221,9 @@ def strided_slice(a, begin, end, strides=None, slice_mode="end"):
         return cpp.dynamic_strided_slice(a, begin, end, strides)
     if strides is None:
         strides = []
-    return cpp.strided_slice(a, begin, end, strides, slice_mode)
+    if axes is None:
+        axes = []
+    return cpp.strided_slice(a, begin, end, strides, axes, slice_mode)
 
 
 @tvm.te.tag_scope(tag=tag.INJECTIVE + ",strided_set")
diff --git a/src/relay/op/make_op.h b/src/relay/op/make_op.h
index bbfef5883e3d..81de4bc90ad7 100644
--- a/src/relay/op/make_op.h
+++ b/src/relay/op/make_op.h
@@ -78,7 +78,8 @@ Expr MakeStack(Expr data, int axis);
 Expr MakeTranspose(Expr data, Array<Integer> axes);
 
 Expr MakeStridedSlice(Expr data, Array<Integer> begin, Array<Integer> end, Array<Integer> strides,
-                      String slice_mode);
+                      String slice_mode,
+                      Optional<Array<Integer>> axes = NullValue<Array<Integer>>());
 
 Expr MakeTile(Expr data, Array<Integer> reps);
 
diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc
index d6caefbb4e2c..9361e1996796 100644
--- a/src/relay/op/tensor/transform.cc
+++ b/src/relay/op/tensor/transform.cc
@@ -2445,99 +2445,40 @@ bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     return false;
   }
 
-  auto dshape = data->shape;
-  int64_t num_axis = dshape.size();
-
-  // calculate output shape
-  std::vector<IndexExpr> oshape(num_axis);
-  if (param->begin && param->end && param->strides) {
-    // stride will be set as 1 if slice mode is enabled
-    std::vector<int64_t> stride_vec(num_axis, 1);
-    if (param->slice_mode == "end") {
-      for (size_t i = 0; i < param->strides.value().size(); ++i) {
-        ICHECK(param->strides.value()[i].defined());
-        stride_vec[i] = param->strides.value()[i]->value;
-      }
-    }
-    const int64_t max_range = std::numeric_limits<int64_t>::max();
-    std::vector<int64_t> begin_vec;
-    for (size_t i = 0; i < param->begin.value().size(); ++i) {
-      if (!param->begin.value()[i].defined()) {
-        begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
-      } else {
-        begin_vec.push_back(param->begin.value()[i]->value);
-      }
-    }
-    for (int64_t i = begin_vec.size(); i < num_axis; ++i) {
-      begin_vec.push_back(stride_vec[i] > 0 ? 0 : max_range);
-    }
+  ICHECK(param->begin) << "strided_slice received invalid begin " << param->begin;
+  ICHECK(param->end) << "strided_slice received invalid end " << param->end;
+  ICHECK(param->strides) << "strided_slice received invalid strides " << param->strides;
+
+  auto begin = param->begin.value();
+  auto end = param->end.value();
+  auto strides = param->strides.value();
+
+  const size_t src_tensor_dim = static_cast<size_t>(data->shape.size());
+  Array<Integer> axes;
+  if (param->axes) {
+    axes = param->axes.value();
+    ICHECK(axes.size() == begin.size() && axes.size() == end.size() &&
+           axes.size() == strides.size())
+        << "axes, begin, end, and strides must have the same length";
+  } else {
+    for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i);
 
-    std::vector<int64_t> end_vec;
-    for (size_t i = 0; i < param->end.value().size(); ++i) {
-      // allow end to be None
-      if (!param->end.value()[i].defined()) {
-        end_vec.push_back(stride_vec[i] < 0 ?
0 : max_range); - } else if (param->slice_mode == "size") { - if (param->end.value()[i]->value < 0) { - end_vec.push_back(max_range); - } else { - end_vec.push_back(begin_vec[i] + param->end.value()[i]->value); - } - } else if (param->slice_mode == "end") { - end_vec.push_back(param->end.value()[i]->value); - } else { - LOG(FATAL) << "Unsupported slice mode: " << param->slice_mode; - } + const IntImm one = IntImm(DataType::Int(64), 1); + const IntImm zero = IntImm(DataType::Int(64), 0); + const IntImm max_range = IntImm(DataType::Int(64), std::numeric_limits::max()); + + for (size_t i = strides.size(); i < src_tensor_dim; ++i) { + strides.push_back(one); } - for (int64_t i = end_vec.size(); i < num_axis; ++i) { - end_vec.push_back(stride_vec[i] < 0 ? 0 : max_range); + for (size_t i = begin.size(); i < src_tensor_dim; ++i) { + begin.push_back(topi::GetConstInt(strides[i]) > 0 ? zero : max_range); } - - for (int64_t i = 0; i < num_axis; ++i) { - int64_t stride_v = stride_vec[i]; - int64_t begin_v = begin_vec[i]; - int64_t end_v = end_vec[i]; - - if ((stride_v == 1 && begin_v == 0 && end_v == max_range) || - (stride_v == -1 && begin_v == max_range && end_v == 0)) { - // Quick path, do not slice this dimension. - oshape[i] = dshape[i]; - continue; - } - // Normal path, require the shape to be concrete integer. - // Require concrete integer as symbolic inference of min/max - // can get complicated and not very helpful. - const int64_t* p_dim_size = tir::as_const_int(dshape[i]); - if (!p_dim_size) { - oshape[i] = dshape[i]; - continue; - } - int64_t dim_size = p_dim_size[0]; - begin_v = (begin_v < 0) ? dim_size + begin_v : begin_v; - end_v = (end_v < 0) ? dim_size + end_v : end_v; - - int64_t slice_range, step; - if (stride_v < 0) { - if (end_v < -1) end_v = -1; - ICHECK_LE(end_v, begin_v) << "strided_slice get empty slice at axis " << i; - begin_v = std::min(dim_size - 1, begin_v); - slice_range = begin_v - end_v; - step = -stride_v; - } else { - if (begin_v < 0) begin_v = 0; - ICHECK_GE(stride_v, 0); - ICHECK_LE(begin_v, end_v) << "strided_slice get invalid slice at axis " << i; - end_v = std::min(dim_size, end_v); - slice_range = end_v - begin_v; - step = stride_v; - } - oshape[i] = tir::make_const(dshape[i].dtype(), (slice_range + step - 1) / step); + for (size_t i = end.size(); i < src_tensor_dim; ++i) { + end.push_back(topi::GetConstInt(strides[i]) < 0 ? 
zero : max_range); } - } else { - ICHECK(param->begin) << "strided_slice recieved invalid begin " << param->begin; - ICHECK(param->end) << "strided_slice recieved invalid end " << param->end; - ICHECK(param->strides) << "strided_slice recieved invalid strides " << param->strides; } + auto oshape = + topi::StridedSliceOutputShape(data->shape, begin, end, strides, axes, param->slice_mode); reporter->Assign(types[1], TensorType(oshape, data->dtype)); return true; } @@ -2596,78 +2537,130 @@ Array> StridedSliceInferCorrectLayout(const Attrs& attrs, // Not support NHW4c -> NCHW return {{Layout::Undef()}, {Layout::Undef()}}; } else { - for (size_t i = 0; i < new_layout_name.size(); ++i) { - auto index = layout.IndexOf(new_layout[i]); - if (index == -1) { - return {{Layout::Undef()}, {Layout::Undef()}}; + if (params->axes) { + auto axes = params->axes.value(); + Array new_axes; + + for (size_t i = 0; i < axes.size(); ++i) { + auto old_idx = axes[i]; + auto new_idx = new_layout.IndexOf(layout[old_idx]); + new_begin.push_back(begin[i]); + new_end.push_back(end[i]); + new_strides.push_back(strides[i]); + new_axes.push_back(new_idx); } + params->axes = new_axes; - size_t new_index = static_cast(index); - int64_t bg, ed, st; - if (strides.defined() && new_index < strides.size() && strides[new_index].defined()) { - st = strides[new_index]->value; - } else { - st = 1; - } - if (new_index < begin.size() && begin[new_index].defined()) { - bg = begin[new_index]->value; - } else { - bg = 0; - } - if (new_index < end.size() && end[new_index].defined()) { - ed = end[new_index]->value; - } else { - ed = shape[new_index].as()->value; - } + } else { + for (size_t i = 0; i < new_layout_name.size(); ++i) { + auto index = layout.IndexOf(new_layout[i]); + if (index == -1) { + return {{Layout::Undef()}, {Layout::Undef()}}; + } - new_begin.push_back(IntImm(begin[0]->dtype, bg)); - new_end.push_back(IntImm(end[0]->dtype, ed)); - new_strides.push_back(IntImm(strides[0]->dtype, st)); + size_t new_index = static_cast(index); + int64_t bg, ed, st; + if (strides.defined() && new_index < strides.size() && strides[new_index].defined()) { + st = strides[new_index]->value; + } else { + st = 1; + } + if (new_index < begin.size() && begin[new_index].defined()) { + bg = begin[new_index]->value; + } else { + bg = 0; + } + if (new_index < end.size() && end[new_index].defined()) { + ed = end[new_index]->value; + } else { + ed = shape[new_index].as()->value; + } + + new_begin.push_back(IntImm(begin[0]->dtype, bg)); + new_end.push_back(IntImm(end[0]->dtype, ed)); + new_strides.push_back(IntImm(strides[0]->dtype, st)); + } } + params->begin = new_begin; params->end = new_end; params->strides = new_strides; layout = new_layout; } } else { - for (size_t i = 0; i < begin.size(); i++) { - const LayoutAxis& axis = layout[i]; - if (!axis.IsPrimal()) { - // original layout that contains splitted axes is not supported - return {{Layout::Undef()}, {Layout::Undef()}}; - } - auto factor = new_layout.FactorOf(axis); - if (factor == -1) { - new_begin.push_back(IntImm(begin[i]->dtype, begin[i])); - new_end.push_back(IntImm(end[i]->dtype, end[i])); - } else { - if (strides.defined() && i < strides.size()) { - auto stride = strides[i]; - // arbitrary stride is not supported - if (stride.defined() && stride->value != 1) { + if (params->axes) { + auto axes = params->axes.value(); + Array new_axes; + + for (size_t i = 0; i < axes.size(); ++i) { + auto old_idx = axes[i]; + auto new_idx = new_layout.IndexOf(layout[old_idx]); + new_axes.push_back(new_idx); + 
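+        // Map the sliced axis into the new layout; if the axis was split
+        // (factor != -1 below), begin/end must be rescaled by that factor.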
+ const LayoutAxis& axis = layout[old_idx]; + if (!axis.IsPrimal()) { + // original layout that contains splitted axes is not supported + return {{Layout::Undef()}, {Layout::Undef()}}; + } + + auto factor = new_layout.FactorOf(axis); + + if (factor == -1) { + new_begin.push_back(begin[i]); + new_end.push_back(end[i]); + } else { + int64_t bg = begin[i]; + int64_t ed = end[i]; + if (bg % factor || ed % factor) { + // transform to original layout return {{Layout::Undef()}, {Layout::Undef()}}; } + new_begin.push_back(IntImm(begin[0]->dtype, (bg / factor))); + new_end.push_back(IntImm(end[0]->dtype, (ed / factor))); } - int64_t bg = begin[i].defined() ? begin[i]->value : 0; - int64_t ed; - if (!end[i].defined()) { - ed = shape[i].as()->value; - } else if (params->slice_mode == "size") { - if (end[i]->value < 0) { + } + params->axes = new_axes; + + } else { + for (size_t i = 0; i < begin.size(); i++) { + const LayoutAxis& axis = layout[i]; + if (!axis.IsPrimal()) { + // original layout that contains splitted axes is not supported + return {{Layout::Undef()}, {Layout::Undef()}}; + } + auto factor = new_layout.FactorOf(axis); + if (factor == -1) { + new_begin.push_back(IntImm(begin[i]->dtype, begin[i])); + new_end.push_back(IntImm(end[i]->dtype, end[i])); + } else { + if (strides.defined() && i < strides.size()) { + auto stride = strides[i]; + // arbitrary stride is not supported + if (stride.defined() && stride->value != 1) { + return {{Layout::Undef()}, {Layout::Undef()}}; + } + } + int64_t bg = begin[i].defined() ? begin[i]->value : 0; + int64_t ed; + if (!end[i].defined()) { ed = shape[i].as()->value; + } else if (params->slice_mode == "size") { + if (end[i]->value < 0) { + ed = shape[i].as()->value; + } else { + ed = bg + end[i]->value; + } } else { - ed = bg + end[i]->value; + ed = end[i]->value; } - } else { - ed = end[i]->value; - } - if (bg % factor || ed % factor) { - // transform to original layout - return {{Layout::Undef()}, {Layout::Undef()}}; + if (bg % factor || ed % factor) { + // transform to original layout + return {{Layout::Undef()}, {Layout::Undef()}}; + } + new_begin.push_back(IntImm(begin[0]->dtype, (bg / factor))); + new_end.push_back(IntImm(end[0]->dtype, (ed / factor))); } - new_begin.push_back(IntImm(begin[0]->dtype, (bg / factor))); - new_end.push_back(IntImm(end[0]->dtype, (ed / factor))); } } @@ -2683,63 +2676,27 @@ Array StridedSliceCompute(const Attrs& attrs, const Array(); ICHECK(param != nullptr); - Array begin, end, strides; - Array begin_expr, end_expr, strides_expr; - begin = param->begin.value(); - end = param->end.value(); - strides = param->strides.value(); - if (IsDynamic(out_type)) { - auto input = inputs[0]; - size_t src_tensor_dim = input->shape.size(); - ICHECK(begin.size() == src_tensor_dim) - << "for dynamic inputs, len(begin) must equal the input dimension"; - Array out_shape; - for (size_t i = 0; i < src_tensor_dim; ++i) { - out_shape.push_back(tvm::tir::Var("dim")); - } - for (size_t i = 0; i < src_tensor_dim; ++i) { - int64_t begin_i = begin[i]->value; - if (begin_i < 0) { - begin_i += topi::detail::GetConstInt(input->shape[i]); - } - begin_expr.push_back(tir::make_const(begin[0].dtype(), begin_i)); - strides_expr.push_back( - tir::make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()), - (i < strides.size() ? 
strides[i]->value : 1))); - } - return Array{te::compute( - out_shape, - [&](const Array& indices) { - Array real_indices; - for (size_t i = 0; i < src_tensor_dim; ++i) { - real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]); - } - return input(real_indices); - }, - std::string{"T_strided_slice_dynamic"}, std::string{topi::kInjective})}; - } else { - for (size_t i = 0; i < begin.size(); ++i) { - begin_expr.push_back(begin[i]); - } - for (size_t i = 0; i < end.size(); ++i) { - end_expr.push_back(end[i]); - } - for (size_t i = 0; i < strides.size(); ++i) { - strides_expr.push_back(strides[i]); - } + ICHECK(param->begin && param->end && param->strides); + Array begin = param->begin.value(); + Array end = param->end.value(); + Array strides = param->strides.value(); + if (param->axes) { + auto axes = param->axes.value(); + return Array{ + topi::strided_slice_with_axes(inputs[0], begin, end, strides, axes, param->slice_mode)}; } - return Array{ - topi::strided_slice(inputs[0], begin_expr, end_expr, strides_expr, param->slice_mode)}; + return Array{topi::strided_slice(inputs[0], begin, end, strides, param->slice_mode)}; } // Positional relay function to create StridedSlice operator used by frontend FFI. Expr MakeStridedSlice(Expr data, Array begin, Array end, Array strides, - String slice_mode) { + String slice_mode, Optional> axes) { auto attrs = make_object(); attrs->begin = std::move(begin); attrs->end = std::move(end); attrs->strides = std::move(strides); attrs->slice_mode = slice_mode; + attrs->axes = std::move(axes); static const Op& op = Op::Get("strided_slice"); return Call(op, {data}, Attrs(attrs), {}); } @@ -3057,16 +3014,21 @@ Array SliceLikeCompute(const Attrs& attrs, const Array& ICHECK(param != nullptr); Array src_shape = inputs[0]->shape; Array target_shape = inputs[1]->shape; - Array begin_idx, end_idx, strides; + Array begin_idx, end_idx, strides; for (size_t i = 0; i < src_shape.size(); ++i) { begin_idx.push_back(0); strides.push_back(1); } - end_idx = Array(src_shape); + for (auto s : src_shape) { + ICHECK(s->IsInstance()) << "slice_like does not support dynamic input shape"; + end_idx.push_back(topi::GetConstInt(s)); + } if (!param->axes.defined()) { for (size_t i = 0; i < src_shape.size(); ++i) { if (i < target_shape.size()) { - end_idx.Set(i, target_shape[i]); + ICHECK(target_shape[i]->IsInstance()) + << "slice_like does not support dynamic output shape"; + end_idx.Set(i, topi::GetConstInt(target_shape[i])); ICHECK_LE(topi::GetConstInt(end_idx[i]), topi::GetConstInt(src_shape[i])) << "End index of axis " << i << " exceeds input shape: " << topi::GetConstInt(end_idx[i]) << " vs " @@ -3078,7 +3040,9 @@ Array SliceLikeCompute(const Attrs& attrs, const Array& if (axis < 0) { axis = static_cast(src_shape.size()) + axis; } - end_idx.Set(axis, target_shape[axis]); + ICHECK(target_shape[axis]->IsInstance()) + << "slice_like does not support dynamic output shape"; + end_idx.Set(axis, topi::GetConstInt(target_shape[axis])); ICHECK_LE(topi::GetConstInt(end_idx[axis]), topi::GetConstInt(src_shape[axis])) << "End index of axis " << axis << " exceeds input shape: " << topi::GetConstInt(end_idx[axis]) << " vs " diff --git a/src/topi/transform.cc b/src/topi/transform.cc index 0bce3bbc7f53..db54d5a99a91 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -174,11 +174,31 @@ TVM_REGISTER_GLOBAL("topi.einsum").set_body([](TVMArgs args, TVMRetValue* rv) { }); TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = 
strided_slice(args[0], args[1], args[2], args[3], args[4]); + Tensor x = args[0]; + Array begin = args[1]; + Array end = args[2]; + Array strides = args[3]; + if (IsConstIntArray(begin) && IsConstIntArray(end) && IsConstIntArray(strides)) { + Array begin_static = args[1]; + Array end_static = args[2]; + Array strides_static = args[3]; + Array axes = args[4]; + std::string slice_mode = args[5]; + if (axes.size() > 0) { + *rv = strided_slice_with_axes(x, begin_static, end_static, strides_static, axes, slice_mode); + } else { + *rv = strided_slice(x, begin_static, end_static, strides_static, slice_mode); + } + } else { + *rv = dynamic_strided_slice(x, begin, end, strides); + } }); TVM_REGISTER_GLOBAL("topi.dynamic_strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = dynamic_strided_slice(args[0], args[1], args[2], args[3]); + te::Tensor begin = args[1]; + te::Tensor end = args[2]; + te::Tensor strides = args[3]; + *rv = dynamic_strided_slice(args[0], begin, end, strides); }); TVM_REGISTER_GLOBAL("topi.one_hot").set_body([](TVMArgs args, TVMRetValue* rv) { diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py index 8016e435618a..74b8ec51e1fa 100644 --- a/tests/python/relay/test_any.py +++ b/tests/python/relay/test_any.py @@ -1032,7 +1032,7 @@ def verify_any_strided_slice( mod = tvm.IRModule() data = relay.var("data", shape=data_shape, dtype="float32") if const_attrs: - data = relay.var("data", shape=data_np_shape, dtype="float32") + data = relay.var("data", shape=data_shape, dtype="float32") begin = relay.const(np_begin) end = relay.const(np_end) strides = relay.const(np_strides) diff --git a/tests/python/relay/test_op_level4.py b/tests/python/relay/test_op_level4.py index c49e3de62662..c4d26a1811b1 100644 --- a/tests/python/relay/test_op_level4.py +++ b/tests/python/relay/test_op_level4.py @@ -379,7 +379,17 @@ def test_mean_var_std(): @tvm.testing.uses_gpu def test_strided_slice(): - def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True, dtype="int32"): + def verify( + dshape, + begin, + end, + strides, + output, + axes=None, + slice_mode="end", + test_ref=True, + dtype="int32", + ): x = relay.var("x", relay.TensorType(dshape, "float32")) ndim = len(dshape) begin = begin if begin else [0] * ndim @@ -387,12 +397,21 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True, # target numpy result x_data = np.random.uniform(size=dshape).astype("float32") - ref_res = tvm.topi.testing.strided_slice_python(x_data, begin, end, strides, slice_mode) + ref_res = tvm.topi.testing.strided_slice_python( + x_data, + begin, + end, + strides, + slice_mode, + axes=axes, + ) if strides: - z = relay.strided_slice(x, begin=begin, end=end, strides=strides, slice_mode=slice_mode) + z = relay.strided_slice( + x, begin=begin, end=end, strides=strides, axes=axes, slice_mode=slice_mode + ) else: - z = relay.strided_slice(x, begin=begin, end=end, slice_mode=slice_mode) + z = relay.strided_slice(x, begin=begin, end=end, axes=axes, slice_mode=slice_mode) func = relay.Function([x], z) func = run_infer_type(func) @@ -436,24 +455,43 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True, (3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False ) verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1], (2, 2, 3), slice_mode="size", test_ref=True) + verify((3, 4, 3), [1], [4], None, None, axes=[1]) @tvm.testing.uses_gpu def test_dyn_strided_slice(): - def verify(dshape, begin, end, 
strides, output, slice_mode="end", test_ref=True, dtype="int32"):
+    def verify(
+        dshape,
+        begin,
+        end,
+        strides,
+        output,
+        axes=None,
+        ishape=None,
+        slice_mode="end",
+        test_ref=True,
+        dtype="int32",
+    ):
         ndim = len(dshape)
         begin = begin if begin else [0] * ndim
         end = end if end else list(dshape)
 
         # target numpy result
         x_data = np.random.uniform(size=dshape).astype("float32")
-        ref_res = tvm.topi.testing.strided_slice_python(x_data, begin, end, strides, slice_mode)
+        ref_res = tvm.topi.testing.strided_slice_python(
+            x_data, begin, end, strides, slice_mode, axes=axes
+        )
 
-        x = relay.var("x", relay.TensorType((relay.Any(),) * ndim, "float32"))
+        if ishape is None:
+            ishape = (relay.Any(),) * ndim
+
+        x = relay.var("x", relay.TensorType(ishape, "float32"))
         if strides:
-            z = relay.strided_slice(x, begin=begin, end=end, strides=strides, slice_mode=slice_mode)
+            z = relay.strided_slice(
+                x, begin=begin, end=end, strides=strides, axes=axes, slice_mode=slice_mode
+            )
         else:
-            z = relay.strided_slice(x, begin=begin, end=end, slice_mode=slice_mode)
+            z = relay.strided_slice(x, begin=begin, end=end, axes=axes, slice_mode=slice_mode)
         func = relay.Function([x], z)
 
         func = run_infer_type(func)
@@ -483,13 +521,21 @@ def verify(dshape, begin, end, strides, output, slice_mode="end", test_ref=True, dtype="int32"):
     verify((3, 4, 3), [1, 1, 0], [4, 1000, 3], None, (2, 3, 3))
     verify((3, 4, 3), [1, 1, 0], [4, 4, 4], None, (2, 3, 3))
     verify((3, 4, 3), [1, 1, 0], [4, 4, 3], None, (2, 3, 3))
-    # TODO(mbrookhart): fix static strided_slice with dynamic input and negative begin
-    # verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3))
-    # verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3))
+    verify((3, 4, 3), [1, -1, 0], [4, -5, 3], [2, -1, 1], (1, 4, 3))
+    verify((3, 4, 3), [1, -1, 0], [2, -3, 3], [1, -1, 1], (1, 2, 3))
     verify(
         (3, 4, 3), [1, 0, 0], [3, -1, 3], [1, 1, 1], (2, 4, 3), slice_mode="size", test_ref=False
     )
     verify((3, 4, 3), [1, 0, 0], [-1, 2, 3], [1, 1, 1], (2, 2, 3), slice_mode="size", test_ref=True)
+    verify(
+        (3, 4, 3, 2),
+        [1, 0],
+        [3, 1],
+        [1, 1],
+        None,
+        axes=[1, 3],
+        ishape=(relay.Any(), 4, relay.Any(), 2),
+    )
 
 
 @tvm.testing.uses_gpu
@@ -534,11 +580,12 @@ def verify(dshape, begin, end, strides, vshape, test_ref=True):
 
 if __name__ == "__main__":
     test_strided_slice()
+    test_dyn_strided_slice()
     test_strided_set()
     test_binary_op()
     test_cmp_type()
     test_binary_int_broadcast_1()
     test_binary_int_broadcast_2()
     test_where()
     test_reduce_functions()
     test_mean_var_std()
diff --git a/tests/python/relay/test_pass_alter_op_layout.py b/tests/python/relay/test_pass_alter_op_layout.py
index 3031c55379ae..5c2793c607a9 100644
--- a/tests/python/relay/test_pass_alter_op_layout.py
+++ b/tests/python/relay/test_pass_alter_op_layout.py
@@ -770,6 +770,61 @@ def expected():
     )
 
 
+@tvm.testing.uses_gpu
+def test_alter_layout_strided_slice_axes_nhwc():
+    """Test rewriting strided_slice with axes during alter_op_layout"""
+
+    def before():
+        x = relay.var("x", shape=(1, 28, 28, 32))
+        weight = relay.var("weight", shape=(3, 3, 32, 32))
+        y = relay.nn.conv2d(
+            x,
+            weight,
+            channels=32,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC",
+            kernel_layout="HWIO",
+        )
+        y = relay.strided_slice(y, begin=[0, 16], end=[1, 32], strides=[1, 1], axes=[0, 3])
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    def alter_conv2d(attrs, inputs, tinfos, out_type):
+        data, weight = inputs
+        new_attrs = dict(attrs)
+        new_attrs["data_layout"] = "NHWC4c"
+        return relay.nn.conv2d(data, weight, **new_attrs)
+
+    def expected():
+        x = relay.var("x", shape=(1, 28, 28, 32))
+        weight = relay.var("weight", shape=(3, 3, 32, 32))
+        x = relay.layout_transform(x, "NHWC", "NHWC4c")
+        y = relay.op.nn.conv2d(
+            x,
+            weight,
+            channels=32,
+            kernel_size=(3, 3),
+            padding=(1, 1),
+            data_layout="NHWC4c",
+            kernel_layout="HWIO",
+        )
+        y = relay.strided_slice(y, begin=[0, 4], end=[1, 8], strides=[1, 1], axes=[0, 3])
+        y = relay.layout_transform(y, "NHWC4c", "NHWC")
+        y = relay.Function(analysis.free_vars(y), y)
+        return y
+
+    with TempOpAttr("nn.conv2d", "FTVMAlterOpLayout", alter_conv2d):
+        a = run_opt_pass(before(), transform.AlterOpLayout())
+        b = run_opt_pass(expected(), transform.InferType())
+
+    mod_before = tvm.IRModule()
+    mod_new = tvm.IRModule()
+    mod_before["main"] = a
+    mod_new["main"] = b
+    assert tvm.ir.structural_equal(mod_before, mod_new)
+
+
 def test_alter_layout_depthwise_conv2d():
     """Test depthwise_conv2d operator"""
 
@@ -1298,3 +1353,4 @@ def expected():
     test_alter_layout_nhwc_int8_aarch64()
     test_alter_op_with_global_var()
     test_alter_op_dense()
+    test_alter_layout_strided_slice_axes_nhwc()
diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py
index dd2dc979a731..4710d50ea8e4 100644
--- a/tests/python/relay/test_pass_convert_op_layout.py
+++ b/tests/python/relay/test_pass_convert_op_layout.py
@@ -735,7 +735,7 @@ def expected():
 
 
 def test_conv_bn_convert_layout():
-    """ Check that layout transforms are propagated through bn. """
+    """Check that layout transforms are propagated through bn."""
 
     def before():
         x = relay.var("x", shape=(1, 56, 56, 64))
@@ -1097,7 +1097,7 @@ def expected():
 
 
 def test_conv_convert_kernel_layout():
-    """ Check that convolution kernel layout is correctly transformed.
""" + """Check that convolution kernel layout is correctly transformed.""" def before(): x = relay.var("x", shape=(1, 56, 56, 64)) @@ -1235,6 +1235,49 @@ def expected(): assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) +def test_conv_strided_slice_axes_convert_layout(): + def before(): + x = relay.var("x", shape=(1, 28, 28, 32)) + weight = relay.var("weight", shape=(3, 3, 32, 32)) + y = relay.nn.conv2d( + x, + weight, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NHWC", + kernel_layout="HWIO", + ) + y = relay.strided_slice(y, begin=[0, 16], end=[1, 33], strides=[1, 1], axes=[0, 3]) + y = relay.Function(analysis.free_vars(y), y) + return y + + def expected(): + x = relay.var("x", shape=(1, 28, 28, 32)) + weight = relay.var("weight", shape=(3, 3, 32, 32)) + weight = relay.layout_transform(weight, "HWIO", "OIHW") + x = relay.layout_transform(x, "NHWC", "NCHW") + y = relay.nn.conv2d( + x, + weight, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout="NCHW", + kernel_layout="OIHW", + ) + y = relay.strided_slice(y, begin=[0, 16], end=[1, 33], strides=[1, 1], axes=[0, 1]) + + y = relay.layout_transform(y, "NCHW", "NHWC") + y = relay.Function(analysis.free_vars(y), y) + return y + + a = run_opt_pass(before(), transform.ConvertLayout({"nn.conv2d": ["NCHW", "default"]})) + b = run_opt_pass(expected(), transform.InferType()) + + assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a) + + def test_conv_roi_pool_convert_layout(): def before(): x = relay.var("x", shape=(1, 64, 56, 56)) @@ -1289,7 +1332,7 @@ def expected(): def test_default_keyword(): - """ Check that the default keyword selects correct TVM default layout. """ + """Check that the default keyword selects correct TVM default layout.""" def before(): x = relay.var("x", shape=(1, 64, 56, 56)) @@ -1784,3 +1827,4 @@ def expected(): test_convert_with_config() test_conv_squeeze_convert_layout() test_conv_reduce_convert_layout() + test_conv_strided_slice_axes_convert_layout() diff --git a/tests/python/topi/python/test_topi_transform.py b/tests/python/topi/python/test_topi_transform.py index 20172f07fd9e..ddde2e20e754 100644 --- a/tests/python/topi/python/test_topi_transform.py +++ b/tests/python/topi/python/test_topi_transform.py @@ -398,10 +398,12 @@ def check_device(target): check_device(target) -def verify_strided_slice(in_shape, begin, end, strides=None): +def verify_strided_slice(in_shape, begin, end, strides=None, axes=None): A = te.placeholder(shape=in_shape, name="A") strides = [1, 1, 1] if strides is None else strides - B = topi.strided_slice(A, begin, end, strides) + 1 + if axes: + strides = [strides[axis] for axis in axes] + B = topi.strided_slice(A, begin, end, strides, axes) + 1 def check_device(target): dev = tvm.device(target, 0) @@ -414,7 +416,7 @@ def check_device(target): foo = tvm.build(s, [A, B], target, name="stride_slice") x_np = np.random.uniform(size=in_shape).astype(A.dtype) - out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end, strides) + 1 + out_npy = tvm.topi.testing.strided_slice_python(x_np, begin, end, strides, axes=axes) + 1 data_nd = tvm.nd.array(x_np, dev) out_nd = tvm.nd.empty(out_npy.shape, device=dev, dtype=A.dtype) foo(data_nd, out_nd) @@ -819,6 +821,7 @@ def test_strided_slice(): verify_strided_slice((3, 4, 3), [1, 1, 0], [4, 4, 3]) verify_strided_slice((3, 4, 3), [0, 2, 0], [1, 2, 3]) verify_strided_slice((3, 4, 3), [0, 0, 0], [None, None, None]) + verify_strided_slice((3, 4, 3), [0], [2], None, axes=[1]) @tvm.testing.uses_gpu