Skip to content

Commit

Permalink
Merge pull request #5 from ZihengJiang/master
Browse files Browse the repository at this point in the history
Add tile operation
  • Loading branch information
tqchen authored Dec 11, 2016
2 parents 0c72ca9 + 0e99e8a commit f650216
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 0 deletions.
3 changes: 3 additions & 0 deletions include/tvm/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,9 @@ class Schedule : public NodeRef {
* \return reference to self.
*/
Schedule& reorder(const Array<IterVar>& order); // NOLINT(*)
Schedule& tile(IterVar x_parent, IterVar y_parent, IterVar* p_x_outer,
IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor); // NOLINT(*)
};

/*!
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,8 @@ def reorder(self, *args):
The order to be ordered
"""
_function_internal._ScheduleReorder(self, args)

def tile(self, x_parent, y_parent, x_factor, y_factor):
x_outer, y_outer, x_inner, y_inner = _function_internal._ScheduleTile(
self, x_parent, y_parent, x_factor, y_factor)
return x_outer, y_outer, x_inner, y_inner
8 changes: 8 additions & 0 deletions src/c_api/c_api_lang.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,5 +151,13 @@ TVM_REGISTER_API(_ScheduleReorder)
.reorder(args.at(1));
});

TVM_REGISTER_API(_ScheduleTile)
.set_body([](const ArgStack& args, RetValue *ret) {
IterVar x_outer, y_outer, x_inner, y_inner;
args.at(0).operator Schedule()
.tile(args.at(1), args.at(2), &x_outer, &y_outer,
&x_inner, &y_inner, args.at(3), args.at(4));
*ret = Array<IterVar>({x_outer, y_outer, x_inner, y_inner});
});

} // namespace tvm
10 changes: 10 additions & 0 deletions src/lang/schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,16 @@ Schedule& Schedule::reorder(const Array<IterVar>& order) { // NOLINT(*)
return *this;
}

Schedule& Schedule::tile(IterVar x_parent, IterVar y_parent, IterVar* p_x_outer,
IterVar* p_y_outer, IterVar* p_x_inner, IterVar* p_y_inner,
Expr x_factor, Expr y_factor) { // NOLINT(*)

split(x_parent, p_x_outer, p_x_inner, x_factor);
split(y_parent, p_y_outer, p_y_inner, y_factor);
reorder(Array<IterVar>({*p_x_inner, *p_y_inner, *p_x_outer, *p_y_outer}));
return *this;
}

IterVarRelation SplitNode::make(
IterVar parent, IterVar outer,
IterVar inner, Expr factor) {
Expand Down
10 changes: 10 additions & 0 deletions tests/python/test_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,18 @@ def test_reorder():
sch_T.reorder(*order)
assert tuple(sch_T.leaf_iter_vars) == order

def test_tile():
m = tvm.Var('m')
n = tvm.Var('n')
A = tvm.placeholder((m, n), name='A')
T = tvm.compute((m, n), lambda i, j: A[i, j])

sch_T = tvm.Schedule(T.op, scope="shared")
xo, yo, xi, yi = sch_T.tile(T.op.dim_var[0], T.op.dim_var[1], x_factor=10, y_factor=5)
assert tuple(sch_T.leaf_iter_vars) == (xi, yi, xo, yo)

if __name__ == "__main__":
test_schedule_create()
test_reorder()
test_tile()

0 comments on commit f650216

Please sign in to comment.