From 150e945290cbc595bd370dcae7e96e24597fbf04 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 21 May 2021 04:08:37 +0900 Subject: [PATCH] migrating inlined topi compute to topi/transform.h --- include/tvm/topi/transform.h | 100 +++++++++++++++++++++++++++++++ src/relay/op/tensor/transform.cc | 98 +++--------------------------- src/topi/transform.cc | 5 +- 3 files changed, 113 insertions(+), 90 deletions(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 36acc7376c7c..a81ac691dadd 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -725,6 +725,106 @@ inline Tensor strided_slice(const Tensor& x, const Array& begin, name, tag); } +inline Tensor strided_slice(const Tensor& x, const Array& begin, const Array& end, + const Array& strides, std::string slice_mode = "end", + std::string name = "T_strided_slice", std::string tag = kInjective) { + Array begin_expr, end_expr, strides_expr; + for (size_t i = 0; i < begin.size(); ++i) { + begin_expr.push_back(begin[i]); + } + for (size_t i = 0; i < end.size(); ++i) { + end_expr.push_back(end[i]); + } + for (size_t i = 0; i < strides.size(); ++i) { + strides_expr.push_back(strides[i]); + } + return strided_slice(x, begin_expr, end_expr, strides_expr, slice_mode, name, tag); +} + +inline Tensor strided_slice_dynamic_input(const Tensor& input, const Array& begin, + const Array& end, const Array& strides, + std::string slice_mode = "end", + std::string name = "T_strided_slice_dynamic_input", + std::string tag = kInjective) { + size_t src_tensor_dim = input->shape.size(); + ICHECK(begin.size() == src_tensor_dim) + << "for dynamic inputs, len(begin) must equal the input dimension"; + Array out_shape; + for (size_t i = 0; i < src_tensor_dim; ++i) { + out_shape.push_back(tvm::tir::Var("dim")); + } + Array begin_expr, end_expr, strides_expr; + for (size_t i = 0; i < src_tensor_dim; ++i) { + int64_t begin_i = begin[i]->value; + if (begin_i < 0) { + begin_i += topi::detail::GetConstInt(input->shape[i]); + } + begin_expr.push_back(tir::make_const(begin[0].dtype(), begin_i)); + strides_expr.push_back( + tir::make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()), + (i < strides.size() ? strides[i]->value : 1))); + } + return te::compute( + out_shape, + [&](const Array& indices) { + Array real_indices; + for (size_t i = 0; i < src_tensor_dim; ++i) { + real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]); + } + return input(real_indices); + }, + std::string{"T_strided_slice_dynamic_input"}, std::string{topi::kInjective}); +} + +inline Tensor strided_slice_with_axes(const Tensor& input, const Array& begin, + const Array& end, const Array& strides, + const Array& axes, std::string slice_mode = "end", + std::string name = "T_strided_slice_dynamic_input", + std::string tag = kInjective) { + size_t src_tensor_dim = input->shape.size(); + + ICHECK(axes.size() <= src_tensor_dim); + ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size()); + + Array out_shape; + for (size_t i = 0; i < src_tensor_dim; ++i) { + out_shape.push_back(input->shape[i]); + } + Array begin_expr; + for (size_t i = 0; i < axes.size(); ++i) { + auto idim = input->shape[axes[i]]; + auto b = tvm::if_then_else(begin[i] < 0, begin[i] + idim, begin[i]); + auto e = tvm::if_then_else(end[i] < 0, end[i] + idim, end[i]); + auto s = strides[i]->value; + PrimExpr range; + if (s < 0) { + b = tvm::min(b, idim - 1); + e = tvm::if_then_else(e < -1, -1, e); + range = b - e; + s = -s; + } else { + b = tvm::if_then_else(b < 0, 0, b); + e = tvm::min(e, idim); + range = e - b; + } + PrimExpr odim = indexdiv(range + tvm::PrimExpr(static_cast(s - 1)), s); + out_shape.Set(axes[i], cast(out_shape[i].dtype(), odim)); + begin_expr.push_back(b); + } + return te::compute( + out_shape, + [&](const Array& indices) { + Array real_indices; + for (size_t i = 0; i < src_tensor_dim; ++i) real_indices.push_back(indices[i]); + for (size_t i = 0; i < axes.size(); ++i) { + PrimExpr ind = indices[axes[i]] * strides[i] + begin_expr[i]; + real_indices.Set(axes[i], ind); + } + return input(real_indices); + }, + std::string{"T_strided_slice_with_axes"}, std::string{topi::kInjective}); +} + /*! * \brief Split a tensor into a number of sub-tensors * diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index 1279e0acde9f..8c012ecb47d3 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2943,99 +2943,19 @@ Array StridedSliceCompute(const Attrs& attrs, const Array(); ICHECK(param != nullptr); - Array begin, end, strides; - Array begin_expr, end_expr, strides_expr; - begin = param->begin.value(); - end = param->end.value(); - strides = param->strides.value(); + ICHECK(param->begin && param->end && param->strides); + Array begin = param->begin.value(); + Array end = param->end.value(); + Array strides = param->strides.value(); if (param->axes) { auto axes = param->axes.value(); - auto input = inputs[0]; - size_t src_tensor_dim = input->shape.size(); - - ICHECK(axes.size() <= src_tensor_dim); - ICHECK(axes.size() == begin.size() && axes.size() == end.size() && - axes.size() == strides.size()); - - Array out_shape; - for (size_t i = 0; i < src_tensor_dim; ++i) { - out_shape.push_back(input->shape[i]); - } - Array begin_expr; - for (size_t i = 0; i < axes.size(); ++i) { - auto idim = input->shape[axes[i]]; - auto b = tvm::if_then_else(begin[i] < 0, begin[i] + idim, begin[i]); - auto e = tvm::if_then_else(end[i] < 0, end[i] + idim, end[i]); - auto s = strides[i]->value; - PrimExpr range; - if (s < 0) { - b = tvm::min(b, idim - 1); - e = tvm::if_then_else(e < -1, -1, e); - range = b - e; - s = -s; - } else { - b = tvm::if_then_else(b < 0, 0, b); - e = tvm::min(e, idim); - range = e - b; - } - PrimExpr odim = indexdiv(range + tvm::PrimExpr(static_cast(s - 1)), s); - out_shape.Set(axes[i], cast(out_shape[i].dtype(), odim)); - begin_expr.push_back(b); - } - return Array{te::compute( - out_shape, - [&](const Array& indices) { - Array real_indices; - for (size_t i = 0; i < src_tensor_dim; ++i) real_indices.push_back(indices[i]); - for (size_t i = 0; i < axes.size(); ++i) { - PrimExpr ind = indices[axes[i]] * strides[i] + begin_expr[i]; - real_indices.Set(axes[i], ind); - } - return input(real_indices); - }, - std::string{"T_strided_slice_with_axes"}, std::string{topi::kInjective})}; + return Array{ + topi::strided_slice_with_axes(inputs[0], begin, end, strides, axes, param->slice_mode)}; } else if (IsDynamic(out_type)) { - auto input = inputs[0]; - size_t src_tensor_dim = input->shape.size(); - ICHECK(begin.size() == src_tensor_dim) - << "for dynamic inputs, len(begin) must equal the input dimension"; - Array out_shape; - for (size_t i = 0; i < src_tensor_dim; ++i) { - out_shape.push_back(tvm::tir::Var("dim")); - } - for (size_t i = 0; i < src_tensor_dim; ++i) { - int64_t begin_i = begin[i]->value; - if (begin_i < 0) { - begin_i += topi::detail::GetConstInt(input->shape[i]); - } - begin_expr.push_back(tir::make_const(begin[0].dtype(), begin_i)); - strides_expr.push_back( - tir::make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()), - (i < strides.size() ? strides[i]->value : 1))); - } - return Array{te::compute( - out_shape, - [&](const Array& indices) { - Array real_indices; - for (size_t i = 0; i < src_tensor_dim; ++i) { - real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]); - } - return input(real_indices); - }, - std::string{"T_strided_slice_dynamic"}, std::string{topi::kInjective})}; - } else { - for (size_t i = 0; i < begin.size(); ++i) { - begin_expr.push_back(begin[i]); - } - for (size_t i = 0; i < end.size(); ++i) { - end_expr.push_back(end[i]); - } - for (size_t i = 0; i < strides.size(); ++i) { - strides_expr.push_back(strides[i]); - } + return Array{ + topi::strided_slice_dynamic_input(inputs[0], begin, end, strides, param->slice_mode)}; } - return Array{ - topi::strided_slice(inputs[0], begin_expr, end_expr, strides_expr, param->slice_mode)}; + return Array{topi::strided_slice(inputs[0], begin, end, strides, param->slice_mode)}; } // Positional relay function to create StridedSlice operator used by frontend FFI. diff --git a/src/topi/transform.cc b/src/topi/transform.cc index 0bce3bbc7f53..dfea643217d6 100644 --- a/src/topi/transform.cc +++ b/src/topi/transform.cc @@ -174,7 +174,10 @@ TVM_REGISTER_GLOBAL("topi.einsum").set_body([](TVMArgs args, TVMRetValue* rv) { }); TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) { - *rv = strided_slice(args[0], args[1], args[2], args[3], args[4]); + Array begin = args[1]; + Array end = args[2]; + Array strides = args[3]; + *rv = strided_slice(args[0], begin, end, strides, args[4]); }); TVM_REGISTER_GLOBAL("topi.dynamic_strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) {