Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 27, 2021
1 parent fbb099c commit 053eee2
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 5 deletions.
16 changes: 16 additions & 0 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -736,6 +736,7 @@ inline Array<PrimExpr> StridedSliceOutputShape(const Array<PrimExpr>& ishape, co
const Array<Integer>& strides,
const Array<Integer>& axes,
const std::string& slice_mode) {
ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size());
std::vector<int64_t> begin_vec, end_vec, strides_vec;
std::tie(begin_vec, end_vec, strides_vec) = ToVec(begin, end, strides, slice_mode);
auto begin_canonicalized =
Expand All @@ -744,6 +745,21 @@ inline Array<PrimExpr> StridedSliceOutputShape(const Array<PrimExpr>& ishape, co
begin_canonicalized);
}


/*!
* \brief strided_slice of a tensor
*
* \param x The input tensor
* \param begin The indices to begin with in the slicing
* \param end Indicies indicating end of the slice
* \param strides Specifies the stride values, it can be negative
* in that case, the input tensor will be reversed in that particular axis
* \param slice_mode Specifies the slice mode
* \param name The name of the operation
* \param tag The tag to mark the operation
*
* \return A Tensor whose op member is the split operation
*/
inline Tensor strided_slice_with_axes(const Tensor& x, const Array<Integer>& begin,
const Array<Integer>& end, const Array<Integer>& strides,
const Array<Integer>& axes, std::string slice_mode = "end",
Expand Down
26 changes: 21 additions & 5 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2449,20 +2449,36 @@ bool StridedSliceRel(const Array<Type>& types, int num_inputs, const Attrs& attr
ICHECK(param->end) << "strided_slice recieved invalid end " << param->end;
ICHECK(param->strides) << "strided_slice recieved invalid strides " << param->strides;

auto begin = param->begin.value();
auto end = param->end.value();
auto strides = param->strides.value();

const size_t src_tensor_dim = static_cast<size_t>(data->shape.size());
Array<Integer> axes;
if (param->axes) {
axes = param->axes.value();
ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size())
<< "axes, begin, end, and strides must have the same length";
} else {
for (size_t i = 0; i < src_tensor_dim; ++i) axes.push_back(i);

const IntImm one = IntImm(DataType::Int(64), 1);
const IntImm zero = IntImm(DataType::Int(64), 0);
const IntImm max_range = IntImm(DataType::Int(64), std::numeric_limits<int64_t>::max());

for (size_t i = strides.size(); i < src_tensor_dim; ++i) {
strides.push_back(one);
}
for (size_t i = begin.size(); i < src_tensor_dim; ++i) {
begin.push_back(topi::GetConstInt(strides[i]) > 0 ? zero : max_range);
}
for (size_t i = end.size(); i < src_tensor_dim; ++i) {
end.push_back(topi::GetConstInt(strides[i]) < 0 ? zero : max_range);
}
}
auto begin = param->begin.value();
auto end = param->end.value();
auto strides = param->strides.value();
ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size())
<< "Axes, begin, end, and strides must have the same length";
auto oshape =
topi::StridedSliceOutputShape(data->shape, begin, end, strides, axes, param->slice_mode);
LOG(INFO) << "oshape: " << oshape;
reporter->Assign(types[1], TensorType(oshape, data->dtype));
return true;
}
Expand Down

0 comments on commit 053eee2

Please sign in to comment.