Skip to content

Commit

Permalink
* review comments. Use NodeRef and keep numpy consistency.
Browse files Browse the repository at this point in the history
  • Loading branch information
srkreddy1238 committed Oct 23, 2018
1 parent 1e48380 commit 889baf2
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 39 deletions.
11 changes: 6 additions & 5 deletions include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,17 +97,18 @@ struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
}; // struct SqueezeAttrs

struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {
Array<IndexExpr> indices_or_sections;
NodeRef indices_or_sections;
int axis;
bool equal_split;

TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") {
TVM_ATTR_FIELD(indices_or_sections)
.describe("Number of outputs to be splitted");
.describe("Indices or sections to split into. Accepts an int or a tuple"
"If indices_or_sections is an integer, the input will be divided equally"
"along given axis. If such a split is not possible, an error is raised."
"If indices_or_sections is a tuple of sorted integers,"
"the entries indicate where along axis the array is split.");
TVM_ATTR_FIELD(axis).set_default(0)
.describe("the axis to be splitted.");
TVM_ATTR_FIELD(equal_split).set_default(false)
.describe("Is it equal split of input");
}
};

Expand Down
8 changes: 5 additions & 3 deletions python/tvm/relay/op/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,14 +300,16 @@ def split(data, indices_or_sections, axis=0):
indices_or_sections : int or tuple of int
Indices or sections to split into. Accepts an int or a tuple
axis : int, optional
axis : int, optional
The axis over which to split.
Returns
-------
ret : relay.Tuple([relay.Expr, relay.Expr])
The computed result.
"""
ret_size = indices_or_sections if isinstance(indices_or_sections, int)
else len(indices_or_sections)+1
if isinstance(indices_or_sections, int):
ret_size = indices_or_sections
else:
ret_size = len(indices_or_sections) + 1
return TupleWrapper(_make.split(data, indices_or_sections, axis), ret_size)
48 changes: 26 additions & 22 deletions src/relay/op/tensor/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@
#include <tvm/relay/op.h>
#include <tvm/relay/attrs/transform.h>
#include <tvm/ir_operator.h>
#include <tvm/ir.h>
#include <vector>
#include "../op_common.h"


namespace tvm {
namespace relay {
using ir::IntImm;

/* relay.expand_dims */

Expand Down Expand Up @@ -793,7 +795,6 @@ bool SplitRel(const Array<Type>& types,
CHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty";
const auto param = attrs.as<SplitAttrs>();
CHECK(param != nullptr);

auto axis = param->axis;
if (axis < 0) {
axis += data->shape.size();
Expand All @@ -803,29 +804,31 @@ bool SplitRel(const Array<Type>& types,
CHECK_GT(axis, 0)
<< "axis should be within the input dimension range.";

if (param->equal_split) {
const auto num_outputs = as_const_int(param->indices_or_sections[0]);
if (param->indices_or_sections.as<IntImm>()) {
const auto sections = make_const(Int(32),
param->indices_or_sections.as<IntImm>()->value);
CHECK(reporter->Assert(data->shape[axis] %
param->indices_or_sections[0] == make_zero(Int(64))))
sections == make_zero(Int(64))))
<< "indices_or_sections need to be able to divide input.shape[axis]";
std::vector<Type> fields;
for (int i = 0; i < *num_outputs; ++i) {
for (int i = 0; i < *as_const_int(sections); ++i) {
std::vector<IndexExpr>&& oshape = AsVector(data->shape);
oshape[axis] /= param->indices_or_sections[0];
oshape[axis] /= sections;
auto vec_type = TensorTypeNode::make(oshape, data->dtype);
fields.push_back(vec_type);
}
reporter->Assign(types[1], TupleTypeNode::make(Array<Type>(fields)));
} else {
const auto num_outputs = param->indices_or_sections.size() + 1;
auto begin = make_zero(Int(32));
auto indices = param->indices_or_sections.as<ArrayNode>()->data;
const auto num_outputs = indices.size() + 1;
auto begin = IndexExpr(make_zero(Int(32)));
std::vector<Type> fields;
for (uint i = 0; i < num_outputs - 1; ++i) {
CHECK(reporter->Assert(param->indices_or_sections[i] > begin))
CHECK(reporter->Assert(IndexExpr(indices[i]) > begin))
<< "indices_or_sections need to be a sorted ascending list";
std::vector<IndexExpr>&& oshape = AsVector(data->shape);
oshape[axis] = param->indices_or_sections[i] - begin;
begin = param->indices_or_sections[i];
oshape[axis] = IndexExpr(indices[i]) - begin;
begin = IndexExpr(indices[i]);
auto vec_type = TensorTypeNode::make(oshape, data->dtype);
fields.push_back(vec_type);
}
Expand All @@ -835,38 +838,39 @@ bool SplitRel(const Array<Type>& types,
oshape[axis] = data->shape[axis] - begin;
auto vec_type = TensorTypeNode::make(oshape, data->dtype);
fields.push_back(vec_type);

reporter->Assign(types[1], TupleTypeNode::make(Array<Type>(fields)));
}
return true;
}

Expr MakeSplit(Expr data,
Array<IndexExpr> indices_or_sections,
int axis,
bool equal_split) {
NodeRef indices_or_sections,
int axis) {
auto attrs = make_node<SplitAttrs>();
attrs->axis = axis;
attrs->indices_or_sections = std::move(indices_or_sections);
attrs->equal_split = equal_split;
static const Op& op = Op::Get("split");
return CallNode::make(op, {data}, Attrs(attrs), {});
}

TVM_REGISTER_API("relay.op._make.split")
.set_body([](const TVMArgs& args, TVMRetValue* rv) {
runtime::detail::unpack_call<Expr, 4>(MakeSplit, args, rv);
if (args.type_codes[1] == kDLInt) {
*rv = MakeSplit(args[0], make_const(Int(64), int64_t(args[1])), args[2]);
} else {
*rv = MakeSplit(args[0], args[1], args[2]);
}
});

RELAY_REGISTER_OP("split")
.describe(R"code(Splits an array along a particular axis into multiple sub-arrays.
While equal_split is true `indices_or_sections` should be of size 1 and it indicates
number of sections to solit into and the dimension along given axis should be a
multiple of indices_or_section[0].
Indices or sections to split into. Accepts an int or a tuple
If indices_or_sections is an integer, the input will be divided equally
along given axis. If such a split is not possible, an error is raised.
With equal_split being false indices_or_section ia an ascending ordered list with in 0
and dimention of given axis. Here the input is split at the given indices.
If indices_or_sections is a tuple of sorted integers,
the entries indicate where along axis the array is split.
)code" TVM_ADD_FILELINE)
.set_attrs_type_key("relay.attrs.SplitAttrs")
Expand Down
16 changes: 7 additions & 9 deletions tests/python/relay/test_op_level3.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,38 +100,36 @@ def verify_take(dshape, indices_shape, oshape, axis=None):
verify_take((d1, d2, d3, d4), (d5, d6), (d1, d2, d5, d6, d4), -2)

def test_split_infer_type():
def verify_split(dshape, indices_or_sections, ret_type, axis=None, equal_split=True):
def verify_split(dshape, indices_or_sections, ret_type, axis=None):
x = relay.var("x", relay.ty.TensorType(dshape, "float32"))
y = relay.split(x, indices_or_sections, axis=axis, equal_split=equal_split)
y = relay.split(x, indices_or_sections, axis=axis)
yy = relay.ir_pass.infer_type(y)
assert yy.checked_type == ret_type

d1, d2, d3, d4 = tvm.var("d1"), tvm.var("d2"), tvm.var("d3"), tvm.var("d4")
axis = tvm.var("axis")
verify_split((5, 5, 2, 2), (5,),
verify_split((5, 5, 2, 2), 5,
relay.ty.TupleType(tvm.convert([
relay.ty.TensorType((5, 1, 2, 2), "float32"),
relay.ty.TensorType((5, 1, 2, 2), "float32"),
relay.ty.TensorType((5, 1, 2, 2), "float32"),
relay.ty.TensorType((5, 1, 2, 2), "float32"),
relay.ty.TensorType((5, 1, 2, 2), "float32")])),
axis=1, equal_split=True)

verify_split((d1, d2, d3, d4), (4,),
axis=1)
verify_split((d1, d2, d3, d4), 4,
relay.ty.TupleType(tvm.convert([
relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
relay.ty.TensorType((d1, d2, d3/4, d4), "float32"),
relay.ty.TensorType((d1, d2, d3/4, d4), "float32")])),
axis=2, equal_split=True)

axis=2)
verify_split((d1, d2, d3, d4), (2, 4, 7),
relay.ty.TupleType(tvm.convert([
relay.ty.TensorType((d1, 2, d3, d4), "float32"),
relay.ty.TensorType((d1, 2, d3, d4), "float32"),
relay.ty.TensorType((d1, 3, d3, d4), "float32"),
relay.ty.TensorType((d1, (d2-7), d3, d4), "float32")])),
axis=1, equal_split=False)
axis=1)

def test_full():
# default settings: match input dtype
Expand Down

0 comments on commit 889baf2

Please sign in to comment.