From 053eee2e6f58749af0b68cca52fd530afc0f6454 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 24 May 2021 17:14:44 +0900 Subject: [PATCH] fix --- include/tvm/topi/transform.h | 16 ++++++++++++++++ src/relay/op/tensor/transform.cc | 26 +++++++++++++++++++++----- 2 files changed, 37 insertions(+), 5 deletions(-) diff --git a/include/tvm/topi/transform.h b/include/tvm/topi/transform.h index 7dd5df79da53..fddbb926d978 100644 --- a/include/tvm/topi/transform.h +++ b/include/tvm/topi/transform.h @@ -736,6 +736,7 @@ inline Array StridedSliceOutputShape(const Array& ishape, co const Array& strides, const Array& axes, const std::string& slice_mode) { + ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size()); std::vector begin_vec, end_vec, strides_vec; std::tie(begin_vec, end_vec, strides_vec) = ToVec(begin, end, strides, slice_mode); auto begin_canonicalized = @@ -744,6 +745,21 @@ inline Array StridedSliceOutputShape(const Array& ishape, co begin_canonicalized); } + +/*! + * \brief strided_slice of a tensor + * + * \param x The input tensor + * \param begin The indices to begin with in the slicing + * \param end Indicies indicating end of the slice + * \param strides Specifies the stride values, it can be negative + * in that case, the input tensor will be reversed in that particular axis + * \param slice_mode Specifies the slice mode + * \param name The name of the operation + * \param tag The tag to mark the operation + * + * \return A Tensor whose op member is the split operation + */ inline Tensor strided_slice_with_axes(const Tensor& x, const Array& begin, const Array& end, const Array& strides, const Array& axes, std::string slice_mode = "end", diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index b35863bc8066..1a70b915f9bf 100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -2449,20 +2449,36 @@ bool StridedSliceRel(const Array& types, int num_inputs, const Attrs& attr ICHECK(param->end) << "strided_slice recieved invalid end " << param->end; ICHECK(param->strides) << "strided_slice recieved invalid strides " << param->strides; + auto begin = param->begin.value(); + auto end = param->end.value(); + auto strides = param->strides.value(); + const size_t src_tensor_dim = static_cast(data->shape.size()); Array axes; if (param->axes) { axes = param->axes.value(); + ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size()) + << "axes, begin, end, and strides must have the same length"; } else { for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i); + + const IntImm one = IntImm(DataType::Int(64), 1); + const IntImm zero = IntImm(DataType::Int(64), 0); + const IntImm max_range = IntImm(DataType::Int(64), std::numeric_limits::max()); + + for (size_t i = strides.size(); i < src_tensor_dim; ++i) { + strides.push_back(one); + } + for (size_t i = begin.size(); i < src_tensor_dim; ++i) { + begin.push_back(topi::GetConstInt(strides[i]) > 0 ? zero : max_range); + } + for (size_t i = end.size(); i < src_tensor_dim; ++i) { + end.push_back(topi::GetConstInt(strides[i]) < 0 ? zero : max_range); + } } - auto begin = param->begin.value(); - auto end = param->end.value(); - auto strides = param->strides.value(); - ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size()) - << "Axes, begin, end, and strides must have the same length"; auto oshape = topi::StridedSliceOutputShape(data->shape, begin, end, strides, axes, param->slice_mode); + LOG(INFO) << "oshape: " << oshape; reporter->Assign(types[1], TensorType(oshape, data->dtype)); return true; }