Skip to content

Commit

Permalink
* Review comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
srkreddy1238 committed Oct 17, 2018
1 parent 3e38f99 commit f3490b2
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 10 deletions.
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {
TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") {
TVM_ATTR_FIELD(indices_or_sections)
.describe("Number of outputs to be splitted");
TVM_ATTR_FIELD(axis).set_lower_bound(0).set_default(1)
TVM_ATTR_FIELD(axis).set_default(0)
.describe("the axis to be splitted.");
TVM_ATTR_FIELD(equal_split).set_default(false)
.describe("Is it equal split of input");
Expand Down
24 changes: 15 additions & 9 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -712,43 +712,49 @@ bool SplitRel(const Array<Type>& types,
CHECK_EQ(types.size(), 2);
const auto* data = types[0].as<TensorTypeNode>();
CHECK(data != nullptr);
CHECK_NE(data->shape.size(), 0);
CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty";
const auto param = attrs.as<SplitAttrs>();
CHECK(param != nullptr);

auto axis = param->axis;
if (axis < 0) {
axis += data->shape.size();
}
CHECK_LT(axis, data->shape.size())
<< "axis should be within the input dimension range.";
CHECK_GT(axis, 0)
<< "axis should be within the input dimension range.";

if (param->equal_split) {
const auto num_outputs = as_const_int(param->indices_or_sections[0]);
CHECK_LT(param->axis, data->shape.size());
// CHECK(reporter->Assert(data->shape[param->axis] %
// CHECK(reporter->Assert(data->shape[axis] %
// param->indices_or_sections[0] == make_zero(Int(64))))
// << "indices_or_sections need to be able to divide input.shape[axis]";

std::vector<Type> fields;
for (int i = 0; i < *num_outputs; ++i) {
std::vector<IndexExpr>&& oshape = AsVector(data->shape);
oshape[param->axis] /= param->indices_or_sections[0];
oshape[axis] /= param->indices_or_sections[0];
auto vec_type = TensorTypeNode::make(oshape, data->dtype);
fields.push_back(vec_type);
}
reporter->Assign(types[1], TupleTypeNode::make(Array<Type>(fields)));
} else {
const auto num_outputs = param->indices_or_sections.size() + 1;
CHECK_LT(param->axis, data->shape.size());
auto begin = make_zero(Int(32));
std::vector<Type> fields;
for (uint i = 0; i < num_outputs - 1; ++i) {
// CHECK(reporter->Assert(param->indices_or_sections[i] > begin))
// << "indices_or_sections need to be a sorted ascending list";
std::vector<IndexExpr>&& oshape = AsVector(data->shape);
oshape[param->axis] = param->indices_or_sections[i] - begin;
oshape[axis] = param->indices_or_sections[i] - begin;
begin = param->indices_or_sections[i];
auto vec_type = TensorTypeNode::make(oshape, data->dtype);
fields.push_back(vec_type);
}
// CHECK(reporter->Assert(begin < data->shape[param->axis]))
// CHECK(reporter->Assert(begin < data->shape[axis]))
// << "The sum of sections must match the input.shape[axis]";
std::vector<IndexExpr>&& oshape = AsVector(data->shape);
oshape[param->axis] = data->shape[param->axis] - begin;
oshape[axis] = data->shape[axis] - begin;
auto vec_type = TensorTypeNode::make(oshape, data->dtype);
fields.push_back(vec_type);

Expand Down

0 comments on commit f3490b2

Please sign in to comment.