diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 095124794b15..fe0130ff997c 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -634,9 +634,9 @@ inline int64_t CanonicalizeIndex(int64_t index, int64_t extent, int64_t stride) return std::min(std::max(index, begin_range), end_range); } -inline std::tuple, std::vector, std::vector> -ToVec(const Array& begin, const Array& end, - const Array& strides, std::string slice_mode) { +inline std::tuple, std::vector, std::vector> ToVec( + const Array& begin, const Array& end, const Array& strides, + std::string slice_mode) { std::vector stride_vec(strides.size(), 1); if (slice_mode == "end") { for (size_t i = 0; i < strides.size(); ++i) { @@ -673,11 +673,10 @@ ToVec(const Array& begin, const Array& end, return std::make_tuple(begin_vec, end_vec, stride_vec); } - -inline Array StridedSliceCanonicalizeBegin(const Tensor& x, const std::vector& begin, - const std::vector& strides, - const Array& axes, - DataType dtype, +inline Array StridedSliceCanonicalizeBegin(const Tensor& x, + const std::vector& begin, + const std::vector& strides, + const Array& axes, DataType dtype, std::string slice_mode = "end") { Array begin_expr; for (size_t i = 0; i < axes.size(); ++i) { @@ -688,7 +687,7 @@ inline Array StridedSliceCanonicalizeBegin(const Tensor& x, const std: } else { auto idim = x->shape[axes[i]]; auto b_expr = make_const(dtype, begin[i]); - auto b = tvm::if_then_else(begin[i] < 0, b_expr + idim, b_expr); + PrimExpr b = begin[i] < 0 ? b_expr + idim : b_expr; auto s = strides[i]; if (s < 0) { b = tvm::min(b, idim - 1); @@ -705,7 +704,7 @@ inline Array StridedSliceOutputShape(const Tensor& x, const std::vecto const std::vector& end, const std::vector& strides, const Array& axes, std::string slice_mode, - const Array& begin_expr) { + const Array& begin_canonicalized) { size_t src_tensor_dim = x->shape.size(); Array out_shape; for (size_t i = 0; i < src_tensor_dim; ++i) { @@ -715,14 +714,14 @@ inline Array StridedSliceOutputShape(const Tensor& x, const std::vecto for (size_t i = 0; i < axes.size(); ++i) { if (x->shape[axes[i]]->IsInstance()) { const int64_t dim_i = GetConstInt(x->shape[axes[i]]); - int64_t begin_i = GetConstInt(begin_expr[i]); + ICHECK(begin_canonicalized[i]->IsInstance()); + int64_t begin_i = GetConstInt(begin_canonicalized[i]); int64_t end_i = CanonicalizeIndex(end[i], dim_i, strides[i]); int interval = std::abs(end_i - begin_i); int slice_size = static_cast((interval + std::abs(strides[i]) - 1) / std::abs(strides[i])); ICHECK(strides[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i)) - << ": Input [Begin=" << begin[i] << ", End=" << end[i] - << "] is invalid for axis=" << i; + << ": Input [Begin=" << begin[i] << ", End=" << end[i] << "] is invalid for axis=" << i; out_shape.Set(axes[i], cast(out_shape[i].dtype(), PrimExpr(slice_size))); } else { out_shape.Set(axes[i], tvm::tir::Var("dim", out_shape[i]->dtype)); @@ -732,37 +731,35 @@ inline Array StridedSliceOutputShape(const Tensor& x, const std::vecto return out_shape; } -// inline Array StridedSliceOutputShape(const Tensor& x, const Array& begin, -// const Array& end, -// const Array& strides, -// const Array& axes, std::string slice_mode) { -// Array begin_expr = StridedSliceCanonicalizeBegin(x, begin, strides, axes, slice_mode); -// return StridedSliceOutputShape(x, begin, end, strides, axes, slice_mode, begin_expr); -// } +inline Array StridedSliceOutputShape(const Tensor& x, const Array& begin, + const Array& end, + const Array& strides, + const Array& axes, + const std::string& slice_mode) { + std::vector begin_vec, end_vec, strides_vec; + std::tie(begin_vec, end_vec, strides_vec) = ToVec(begin, end, strides, slice_mode); + auto begin_canonicalized = + StridedSliceCanonicalizeBegin(x, begin_vec, strides_vec, axes, begin[0]->dtype, slice_mode); + return StridedSliceOutputShape(x, begin_vec, end_vec, strides_vec, axes, slice_mode, + begin_canonicalized); +} inline Tensor strided_slice_with_axes(const Tensor& x, const Array& begin, const Array& end, const Array& strides, const Array& axes, std::string slice_mode = "end", std::string name = "T_strided_slice_with_axes", std::string tag = kInjective) { - size_t src_tensor_dim = x->shape.size(); - + const size_t src_tensor_dim = x->shape.size(); ICHECK(axes.size() <= src_tensor_dim); ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size()); std::vector begin_vec, end_vec, strides_vec; std::tie(begin_vec, end_vec, strides_vec) = ToVec(begin, end, strides, slice_mode); - Array begin_expr = StridedSliceCanonicalizeBegin(x, begin_vec, strides_vec, axes, begin[0]->dtype, slice_mode); - Array out_shape = + + auto begin_expr = + StridedSliceCanonicalizeBegin(x, begin_vec, strides_vec, axes, begin[0]->dtype, slice_mode); + auto out_shape = StridedSliceOutputShape(x, begin_vec, end_vec, strides_vec, axes, slice_mode, begin_expr); - Array strides_expr; - for (size_t i = 0; i < axes.size(); ++i) { - if (slice_mode == "end") { - strides_expr.push_back(make_const(strides[i].dtype(), GetConstInt(strides[i]))); - } else { - strides_expr.push_back(make_const(strides[i].dtype(), 1)); - } - } return te::compute( out_shape, @@ -770,7 +767,8 @@ inline Tensor strided_slice_with_axes(const Tensor& x, const Array& beg Array real_indices; for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]); for (size_t i = 0; i < axes.size(); ++i) { - PrimExpr ind = indices[axes[i]] * strides_expr[i] + begin_expr[i]; + auto stride = make_const(strides[i].dtype(), strides_vec[i]); + PrimExpr ind = indices[axes[i]] * stride + begin_expr[i]; real_indices.Set(axes[i], ind); } return x(real_indices);