Skip to content

Commit

Permalink
migrating inlined topi compute to topi/transform.h
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 27, 2021
1 parent 763ac37 commit 150e945
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 90 deletions.
100 changes: 100 additions & 0 deletions include/tvm/topi/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -725,6 +725,106 @@ inline Tensor strided_slice(const Tensor& x, const Array<PrimExpr>& begin,
name, tag);
}

inline Tensor strided_slice(const Tensor& x, const Array<Integer>& begin, const Array<Integer>& end,
const Array<Integer>& strides, std::string slice_mode = "end",
std::string name = "T_strided_slice", std::string tag = kInjective) {
Array<PrimExpr> begin_expr, end_expr, strides_expr;
for (size_t i = 0; i < begin.size(); ++i) {
begin_expr.push_back(begin[i]);
}
for (size_t i = 0; i < end.size(); ++i) {
end_expr.push_back(end[i]);
}
for (size_t i = 0; i < strides.size(); ++i) {
strides_expr.push_back(strides[i]);
}
return strided_slice(x, begin_expr, end_expr, strides_expr, slice_mode, name, tag);
}

inline Tensor strided_slice_dynamic_input(const Tensor& input, const Array<Integer>& begin,
const Array<Integer>& end, const Array<Integer>& strides,
std::string slice_mode = "end",
std::string name = "T_strided_slice_dynamic_input",
std::string tag = kInjective) {
size_t src_tensor_dim = input->shape.size();
ICHECK(begin.size() == src_tensor_dim)
<< "for dynamic inputs, len(begin) must equal the input dimension";
Array<PrimExpr> out_shape;
for (size_t i = 0; i < src_tensor_dim; ++i) {
out_shape.push_back(tvm::tir::Var("dim"));
}
Array<PrimExpr> begin_expr, end_expr, strides_expr;
for (size_t i = 0; i < src_tensor_dim; ++i) {
int64_t begin_i = begin[i]->value;
if (begin_i < 0) {
begin_i += topi::detail::GetConstInt(input->shape[i]);
}
begin_expr.push_back(tir::make_const(begin[0].dtype(), begin_i));
strides_expr.push_back(
tir::make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()),
(i < strides.size() ? strides[i]->value : 1)));
}
return te::compute(
out_shape,
[&](const Array<tir::Var>& indices) {
Array<PrimExpr> real_indices;
for (size_t i = 0; i < src_tensor_dim; ++i) {
real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]);
}
return input(real_indices);
},
std::string{"T_strided_slice_dynamic_input"}, std::string{topi::kInjective});
}

inline Tensor strided_slice_with_axes(const Tensor& input, const Array<Integer>& begin,
const Array<Integer>& end, const Array<Integer>& strides,
const Array<Integer>& axes, std::string slice_mode = "end",
std::string name = "T_strided_slice_dynamic_input",
std::string tag = kInjective) {
size_t src_tensor_dim = input->shape.size();

ICHECK(axes.size() <= src_tensor_dim);
ICHECK(axes.size() == begin.size() && axes.size() == end.size() && axes.size() == strides.size());

Array<PrimExpr> out_shape;
for (size_t i = 0; i < src_tensor_dim; ++i) {
out_shape.push_back(input->shape[i]);
}
Array<PrimExpr> begin_expr;
for (size_t i = 0; i < axes.size(); ++i) {
auto idim = input->shape[axes[i]];
auto b = tvm::if_then_else(begin[i] < 0, begin[i] + idim, begin[i]);
auto e = tvm::if_then_else(end[i] < 0, end[i] + idim, end[i]);
auto s = strides[i]->value;
PrimExpr range;
if (s < 0) {
b = tvm::min(b, idim - 1);
e = tvm::if_then_else(e < -1, -1, e);
range = b - e;
s = -s;
} else {
b = tvm::if_then_else(b < 0, 0, b);
e = tvm::min(e, idim);
range = e - b;
}
PrimExpr odim = indexdiv(range + tvm::PrimExpr(static_cast<int32_t>(s - 1)), s);
out_shape.Set(axes[i], cast(out_shape[i].dtype(), odim));
begin_expr.push_back(b);
}
return te::compute(
out_shape,
[&](const Array<tir::Var>& indices) {
Array<PrimExpr> real_indices;
for (size_t i = 0; i < src_tensor_dim; ++i) real_indices.push_back(indices[i]);
for (size_t i = 0; i < axes.size(); ++i) {
PrimExpr ind = indices[axes[i]] * strides[i] + begin_expr[i];
real_indices.Set(axes[i], ind);
}
return input(real_indices);
},
std::string{"T_strided_slice_with_axes"}, std::string{topi::kInjective});
}

/*!
* \brief Split a tensor into a number of sub-tensors
*
Expand Down
98 changes: 9 additions & 89 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2943,99 +2943,19 @@ Array<te::Tensor> StridedSliceCompute(const Attrs& attrs, const Array<te::Tensor
const Type& out_type) {
const StridedSliceAttrs* param = attrs.as<StridedSliceAttrs>();
ICHECK(param != nullptr);
Array<Integer> begin, end, strides;
Array<PrimExpr> begin_expr, end_expr, strides_expr;
begin = param->begin.value();
end = param->end.value();
strides = param->strides.value();
ICHECK(param->begin && param->end && param->strides);
Array<Integer> begin = param->begin.value();
Array<Integer> end = param->end.value();
Array<Integer> strides = param->strides.value();
if (param->axes) {
auto axes = param->axes.value();
auto input = inputs[0];
size_t src_tensor_dim = input->shape.size();

ICHECK(axes.size() <= src_tensor_dim);
ICHECK(axes.size() == begin.size() && axes.size() == end.size() &&
axes.size() == strides.size());

Array<IndexExpr> out_shape;
for (size_t i = 0; i < src_tensor_dim; ++i) {
out_shape.push_back(input->shape[i]);
}
Array<PrimExpr> begin_expr;
for (size_t i = 0; i < axes.size(); ++i) {
auto idim = input->shape[axes[i]];
auto b = tvm::if_then_else(begin[i] < 0, begin[i] + idim, begin[i]);
auto e = tvm::if_then_else(end[i] < 0, end[i] + idim, end[i]);
auto s = strides[i]->value;
PrimExpr range;
if (s < 0) {
b = tvm::min(b, idim - 1);
e = tvm::if_then_else(e < -1, -1, e);
range = b - e;
s = -s;
} else {
b = tvm::if_then_else(b < 0, 0, b);
e = tvm::min(e, idim);
range = e - b;
}
PrimExpr odim = indexdiv(range + tvm::PrimExpr(static_cast<int32_t>(s - 1)), s);
out_shape.Set(axes[i], cast(out_shape[i].dtype(), odim));
begin_expr.push_back(b);
}
return Array<te::Tensor>{te::compute(
out_shape,
[&](const Array<tir::Var>& indices) {
Array<PrimExpr> real_indices;
for (size_t i = 0; i < src_tensor_dim; ++i) real_indices.push_back(indices[i]);
for (size_t i = 0; i < axes.size(); ++i) {
PrimExpr ind = indices[axes[i]] * strides[i] + begin_expr[i];
real_indices.Set(axes[i], ind);
}
return input(real_indices);
},
std::string{"T_strided_slice_with_axes"}, std::string{topi::kInjective})};
return Array<te::Tensor>{
topi::strided_slice_with_axes(inputs[0], begin, end, strides, axes, param->slice_mode)};
} else if (IsDynamic(out_type)) {
auto input = inputs[0];
size_t src_tensor_dim = input->shape.size();
ICHECK(begin.size() == src_tensor_dim)
<< "for dynamic inputs, len(begin) must equal the input dimension";
Array<IndexExpr> out_shape;
for (size_t i = 0; i < src_tensor_dim; ++i) {
out_shape.push_back(tvm::tir::Var("dim"));
}
for (size_t i = 0; i < src_tensor_dim; ++i) {
int64_t begin_i = begin[i]->value;
if (begin_i < 0) {
begin_i += topi::detail::GetConstInt(input->shape[i]);
}
begin_expr.push_back(tir::make_const(begin[0].dtype(), begin_i));
strides_expr.push_back(
tir::make_const((strides.size() != 0 ? strides[0].dtype() : begin[0].dtype()),
(i < strides.size() ? strides[i]->value : 1)));
}
return Array<te::Tensor>{te::compute(
out_shape,
[&](const Array<tir::Var>& indices) {
Array<PrimExpr> real_indices;
for (size_t i = 0; i < src_tensor_dim; ++i) {
real_indices.push_back(indices[i] * strides_expr[i] + begin_expr[i]);
}
return input(real_indices);
},
std::string{"T_strided_slice_dynamic"}, std::string{topi::kInjective})};
} else {
for (size_t i = 0; i < begin.size(); ++i) {
begin_expr.push_back(begin[i]);
}
for (size_t i = 0; i < end.size(); ++i) {
end_expr.push_back(end[i]);
}
for (size_t i = 0; i < strides.size(); ++i) {
strides_expr.push_back(strides[i]);
}
return Array<te::Tensor>{
topi::strided_slice_dynamic_input(inputs[0], begin, end, strides, param->slice_mode)};
}
return Array<te::Tensor>{
topi::strided_slice(inputs[0], begin_expr, end_expr, strides_expr, param->slice_mode)};
return Array<te::Tensor>{topi::strided_slice(inputs[0], begin, end, strides, param->slice_mode)};
}

// Positional relay function to create StridedSlice operator used by frontend FFI.
Expand Down
5 changes: 4 additions & 1 deletion src/topi/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,10 @@ TVM_REGISTER_GLOBAL("topi.einsum").set_body([](TVMArgs args, TVMRetValue* rv) {
});

TVM_REGISTER_GLOBAL("topi.strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) {
*rv = strided_slice(args[0], args[1], args[2], args[3], args[4]);
Array<PrimExpr> begin = args[1];
Array<PrimExpr> end = args[2];
Array<PrimExpr> strides = args[3];
*rv = strided_slice(args[0], begin, end, strides, args[4]);
});

TVM_REGISTER_GLOBAL("topi.dynamic_strided_slice").set_body([](TVMArgs args, TVMRetValue* rv) {
Expand Down

0 comments on commit 150e945

Please sign in to comment.