Skip to content

Commit

Permalink
restore another slice variant
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 27, 2021
1 parent 36aa777 commit d2538ae
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 6 deletions.
26 changes: 26 additions & 0 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,32 @@ inline te::Tensor dynamic_strided_slice(const te::Tensor& x, const te::Tensor& b
name, tag);
}

inline Tensor dynamic_strided_slice(const Tensor& x, const Array<PrimExpr>& begin,
const Array<PrimExpr>& end, const Array<PrimExpr>& strides,
std::string slice_mode = "end",
std::string name = "T_strided_slice",
std::string tag = kInjective) {
size_t src_tensor_dim = static_cast<size_t>(x->shape.size());
ICHECK_EQ(begin.size(), src_tensor_dim);
ICHECK_EQ(end.size(), src_tensor_dim);
ICHECK_EQ(strides.size(), src_tensor_dim);

Array<PrimExpr> out_shape;
for (size_t i = 0; i < src_tensor_dim; ++i) {
out_shape.push_back(indexdiv(end[i] - begin[i], strides[i]));
}
return te::compute(
out_shape,
[&](const Array<tvm::tir::Var>& indices) {
Array<PrimExpr> real_indices;
for (size_t i = 0; i < src_tensor_dim; ++i) {
real_indices.push_back(indices[i] * strides[i] + begin[i]);
}
return x(real_indices);
},
name, tag);
}

inline Tensor strided_slice_dynamic_input(const Tensor& input, const Array<Integer>& begin,
const Array<Integer>& end, const Array<Integer>& strides,
std::string slice_mode = "end",
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -1359,7 +1359,7 @@ def has_static_axes():
else:
strides_np = steps.data.asnumpy().astype("int64")

return _op.strided_slice(inputs[0], list(begin_np), list(end_np), list(strides_np), axes=list(axes_np))
# return _op.strided_slice(inputs[0], list(begin_np), list(end_np), list(strides_np), axes=list(axes_np))

# Update the starts and ends according to axes if required.
if axes is not None:
Expand Down
1 change: 1 addition & 0 deletions python/tvm/topi/cuda/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,6 +962,7 @@ def topk(data, k=1, axis=-1, ret_type="both", is_ascend=False, dtype="int64"):
end.append(dshape[i])
if ret_type == "both":
values_out, indices_out = output
print("end:", end, k)
values_out = strided_slice(values_out, beg, end, strides)
indices_out = strided_slice(indices_out, beg, end, strides)
output = [values_out, indices_out]
Expand Down
26 changes: 21 additions & 5 deletions src/topi/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,30 @@ TVM_REGISTER_GLOBAL("topi.einsum").set_body([](TVMArgs args, TVMRetValue* rv) {
});

TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) {
Array<Integer> begin = args[1];
Array<Integer> end = args[2];
Array<Integer> strides = args[3];
*rv = strided_slice(args[0], begin, end, strides, args[4]);
Tensor x = args[0];
Array<PrimExpr> begin = args[1];
Array<PrimExpr> end = args[2];
Array<PrimExpr> strides = args[3];
std::string slice_mode = args[4];
if (IsConstIntArray(begin) && IsConstIntArray(end) && IsConstIntArray(strides)) {
Array<Integer> begin_static = args[1];
Array<Integer> end_static = args[2];
Array<Integer> strides_static = args[3];
if (IsConstIntArray(x->shape)) {
*rv = strided_slice(x, begin_static, end_static, strides_static, slice_mode);
} else {
*rv = strided_slice_dynamic_input(x, begin_static, end_static, strides_static, slice_mode);
}
} else {
*rv = dynamic_strided_slice(x, begin, end, strides, slice_mode);
}
});

TVM_REGISTER_GLOBAL("topi.dynamic_strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = dynamic_strided_slice(args[0], args[1], args[2], args[3]);
te::Tensor begin = args[1];
te::Tensor end = args[2];
te::Tensor strides = args[3];
*rv = dynamic_strided_slice(args[0], begin, end, strides);
});

TVM_REGISTER_GLOBAL("topi.one_hot").set_body([](TVMArgs args, TVMRetValue* rv) {
Expand Down

0 comments on commit d2538ae

Please sign in to comment.