Skip to content

Commit

Permalink
working
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 27, 2021
1 parent b357c2f commit ecfe3cd
Showing 1 changed file with 31 additions and 33 deletions.
64 changes: 31 additions & 33 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -634,9 +634,9 @@ inline int64_t CanonicalizeIndex(int64_t index, int64_t extent, int64_t stride)
return std::min(std::max(index, begin_range), end_range);
}

inline std::tuple<std::vector<int64_t>, std::vector<int64_t>, std::vector<int64_t>>
ToVec(const Array<Integer>& begin, const Array<Integer>& end,
const Array<Integer>& strides, std::string slice_mode) {
inline std::tuple<std::vector<int64_t>, std::vector<int64_t>, std::vector<int64_t>> ToVec(
const Array<Integer>& begin, const Array<Integer>& end, const Array<Integer>& strides,
std::string slice_mode) {
std::vector<int64_t> stride_vec(strides.size(), 1);
if (slice_mode == "end") {
for (size_t i = 0; i < strides.size(); ++i) {
Expand Down Expand Up @@ -673,11 +673,10 @@ ToVec(const Array<Integer>& begin, const Array<Integer>& end,
return std::make_tuple(begin_vec, end_vec, stride_vec);
}


inline Array<PrimExpr> StridedSliceCanonicalizeBegin(const Tensor& x, const std::vector<int64_t>& begin,
const std::vector<int64_t>& strides,
const Array<Integer>& axes,
DataType dtype,
inline Array<PrimExpr> StridedSliceCanonicalizeBegin(const Tensor& x,
const std::vector<int64_t>& begin,
const std::vector<int64_t>& strides,
const Array<Integer>& axes, DataType dtype,
std::string slice_mode = "end") {
Array<PrimExpr> begin_expr;
for (size_t i = 0; i < axes.size(); ++i) {
Expand All @@ -688,7 +687,7 @@ inline Array<PrimExpr> StridedSliceCanonicalizeBegin(const Tensor& x, const std:
} else {
auto idim = x->shape[axes[i]];
auto b_expr = make_const(dtype, begin[i]);
auto b = tvm::if_then_else(begin[i] < 0, b_expr + idim, b_expr);
PrimExpr b = begin[i] < 0 ? b_expr + idim : b_expr;
auto s = strides[i];
if (s < 0) {
b = tvm::min(b, idim - 1);
Expand All @@ -705,7 +704,7 @@ inline Array<PrimExpr> StridedSliceOutputShape(const Tensor& x, const std::vecto
const std::vector<int64_t>& end,
const std::vector<int64_t>& strides,
const Array<Integer>& axes, std::string slice_mode,
const Array<PrimExpr>& begin_expr) {
const Array<PrimExpr>& begin_canonicalized) {
size_t src_tensor_dim = x->shape.size();
Array<PrimExpr> out_shape;
for (size_t i = 0; i < src_tensor_dim; ++i) {
Expand All @@ -715,14 +714,14 @@ inline Array<PrimExpr> StridedSliceOutputShape(const Tensor& x, const std::vecto
for (size_t i = 0; i < axes.size(); ++i) {
if (x->shape[axes[i]]->IsInstance<tvm::IntImmNode>()) {
const int64_t dim_i = GetConstInt(x->shape[axes[i]]);
int64_t begin_i = GetConstInt(begin_expr[i]);
ICHECK(begin_canonicalized[i]->IsInstance<tvm::IntImmNode>());
int64_t begin_i = GetConstInt(begin_canonicalized[i]);
int64_t end_i = CanonicalizeIndex(end[i], dim_i, strides[i]);
int interval = std::abs(end_i - begin_i);
int slice_size =
static_cast<int>((interval + std::abs(strides[i]) - 1) / std::abs(strides[i]));
ICHECK(strides[i] < 0 ? (end_i <= begin_i) : (begin_i <= end_i))
<< ": Input [Begin=" << begin[i] << ", End=" << end[i]
<< "] is invalid for axis=" << i;
<< ": Input [Begin=" << begin[i] << ", End=" << end[i] << "] is invalid for axis=" << i;
out_shape.Set(axes[i], cast(out_shape[i].dtype(), PrimExpr(slice_size)));
} else {
out_shape.Set(axes[i], tvm::tir::Var("dim", out_shape[i]->dtype));
Expand All @@ -732,45 +731,44 @@ inline Array<PrimExpr> StridedSliceOutputShape(const Tensor& x, const std::vecto
return out_shape;
}

// inline Array<PrimExpr> StridedSliceOutputShape(const Tensor& x, const Array<Integer>& begin,
// const Array<Integer>& end,
// const Array<Integer>& strides,
// const Array<Integer>& axes, std::string slice_mode) {
// Array<PrimExpr> begin_expr = StridedSliceCanonicalizeBegin(x, begin, strides, axes, slice_mode);
// return StridedSliceOutputShape(x, begin, end, strides, axes, slice_mode, begin_expr);
// }
inline Array<PrimExpr> StridedSliceOutputShape(const Tensor& x, const Array<Integer>& begin,
const Array<Integer>& end,
const Array<Integer>& strides,
const Array<Integer>& axes,
const std::string& slice_mode) {
std::vector<int64_t> begin_vec, end_vec, strides_vec;
std::tie(begin_vec, end_vec, strides_vec) = ToVec(begin, end, strides, slice_mode);
auto begin_canonicalized =
StridedSliceCanonicalizeBegin(x, begin_vec, strides_vec, axes, begin[0]->dtype, slice_mode);
return StridedSliceOutputShape(x, begin_vec, end_vec, strides_vec, axes, slice_mode,
begin_canonicalized);
}

inline Tensor strided_slice_with_axes(const Tensor& x, const Array<Integer>& begin,
const Array<Integer>& end, const Array<Integer>& strides,
const Array<Integer>& axes, std::string slice_mode = "end",
std::string name = "T_strided_slice_with_axes",
std::string tag = kInjective) {
size_t src_tensor_dim = x->shape.size();

const size_t src_tensor_dim = x->shape.size();
ICHECK(axes.size() <= src_tensor_dim);
ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size());

std::vector<int64_t> begin_vec, end_vec, strides_vec;
std::tie(begin_vec, end_vec, strides_vec) = ToVec(begin, end, strides, slice_mode);
Array<PrimExpr> begin_expr = StridedSliceCanonicalizeBegin(x, begin_vec, strides_vec, axes, begin[0]->dtype, slice_mode);
Array<PrimExpr> out_shape =

auto begin_expr =
StridedSliceCanonicalizeBegin(x, begin_vec, strides_vec, axes, begin[0]->dtype, slice_mode);
auto out_shape =
StridedSliceOutputShape(x, begin_vec, end_vec, strides_vec, axes, slice_mode, begin_expr);
Array<PrimExpr> strides_expr;
for (size_t i = 0; i < axes.size(); ++i) {
if (slice_mode == "end") {
strides_expr.push_back(make_const(strides[i].dtype(), GetConstInt(strides[i])));
} else {
strides_expr.push_back(make_const(strides[i].dtype(), 1));
}
}

return te::compute(
out_shape,
[&](const Array<tir::Var>& indices) {
Array<PrimExpr> real_indices;
for (size_t i = 0; i < out_shape.size(); ++i) real_indices.push_back(indices[i]);
for (size_t i = 0; i < axes.size(); ++i) {
PrimExpr ind = indices[axes[i]] * strides_expr[i] + begin_expr[i];
auto stride = make_const(strides[i].dtype(), strides_vec[i]);
PrimExpr ind = indices[axes[i]] * stride + begin_expr[i];
real_indices.Set(axes[i], ind);
}
return x(real_indices);
Expand Down

0 comments on commit ecfe3cd

Please sign in to comment.