Skip to content

Commit

Permalink
[TE] Add LegalizeInvalidAttach to legalize the compute_at location af…
Browse files Browse the repository at this point in the history
…ter split or fuse (apache#5917)

* Add LegalizeInvalidAttach

* lint & typo

* lint & typo

* address comment

* fix lint
  • Loading branch information
merrymercy authored and Trevor Morris committed Jun 30, 2020
1 parent 19ccf75 commit 317af72
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 2 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ tvm_t.*
.python_history
.pytest_cache
.local
cmake-build-debug

# Visual Studio Code
.vscode
Expand Down
79 changes: 77 additions & 2 deletions src/te/schedule/schedule_dataflow_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ Tensor Schedule::cache_write(const Tensor& tensor, const std::string& scope) {
}
}

void RebaseNonZeroMinLoop(const Schedule& sch) {
void RebaseNonZeroMinLoop(ScheduleNode* sch) {
std::unordered_map<IterVar, IterVar> rebase_map;
for (Stage s : sch->stages) {
if (s->attach_type == kInlinedAlready) continue;
Expand Down Expand Up @@ -614,10 +614,85 @@ void InjectInline(ScheduleNode* sch) {
}
}

void LegalizeInvalidAttach(ScheduleNode* sch) {
// Legalize the compute_at location if the target iterator of compute_at is split or fused.
// Case 1: If the target of compute_at is split,
// we will move the compute_at location to the inner iterator.
// Case 2: If the target of compute_at is fused,
// we will move the compute_at location to the newly fused iterator.
// Note that case 2 can only happen if the target of compute_at
// is the innermost operand of fuse operation.

// Map an old invalid attach point to its new valid attach point
std::unordered_map<IterVar, IterVar> replace_map;

for (Stage stage : sch->stages) {
for (Stage s = stage; s.defined();) {
// The following logic is simiar to the `CreateAttachPath` in `src/te/schedule/graph.h`,
// because we follow the validation check in that function to legalize the attach.
Stage spec = s.GetAttachSpec();
if (spec->attach_type != kScope) {
break;
}
bool start_attach = false;
IterVar attach_ivar = spec->attach_ivar;
s = spec->attach_stage;
CHECK(attach_ivar.defined());
CHECK(s.defined());

for (size_t i = s->leaf_iter_vars.size(); i != 0; --i) {
IterVar iv = s->leaf_iter_vars[i - 1];
if (!start_attach && iv.same_as(attach_ivar)) {
start_attach = true;
break;
}
}

if (!start_attach) {
IterVar new_attach_ivar = attach_ivar;
bool updated = true;
// recursively update the relations
while (updated) {
updated = false;
for (const auto& rel : s->relations) {
if (const FuseNode* r = rel.as<FuseNode>()) {
if (new_attach_ivar.same_as(r->inner)) {
new_attach_ivar = r->fused;
updated = true;
}
} else if (const SplitNode* r = rel.as<SplitNode>()) {
if (new_attach_ivar.same_as(r->parent)) {
new_attach_ivar = r->inner;
updated = true;
}
}
}
replace_map[attach_ivar] = new_attach_ivar;
}
}
}
}

// remap the parent relation
for (Stage s : sch->stages) {
if (s->attach_type != kScope) continue;
if (replace_map.count(s->attach_ivar)) {
s->attach_ivar = replace_map.at(s->attach_ivar);
}
}
for (Stage s : sch->groups) {
if (s->attach_type != kScope) continue;
if (replace_map.count(s->attach_ivar)) {
s->attach_ivar = replace_map.at(s->attach_ivar);
}
}
}

Schedule Schedule::normalize() {
Schedule sn = copy();
InjectInline(sn.operator->());
RebaseNonZeroMinLoop(sn);
RebaseNonZeroMinLoop(sn.operator->());
LegalizeInvalidAttach(sn.operator->());
return sn;
}

Expand Down
20 changes: 20 additions & 0 deletions tests/python/unittest/test_te_schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,25 @@ def intrin_func(ins, outs, sp):
assert str(stmt.body.body.value.args[3]) == "(i: int32*i)"
assert str(stmt.body.body.value.args[4]) == "(i: int32 + j: int32)"

def test_legalize_invalid_attach():
A = te.compute((10, 10), lambda i, j: 1.0, name='A')
B = te.compute((10, 10), lambda i, j: A[i][j], name='B')

# Case 1: Split an axis which is the target of a compute_at
s = te.create_schedule([B.op])
s[A].compute_at(s[B], B.op.axis[1])
s[B].split(B.op.axis[1], 2)

stmt = tvm.lower(s, [A, B], simple_mode=True)['main'].body
assert isinstance(stmt.body.body, tvm.tir.stmt.For)

# Case 2: Fuse an axis which is the target of a compute_at
s = te.create_schedule([B.op])
s[A].compute_at(s[B], B.op.axis[1])
s[B].fuse(B.op.axis[0], B.op.axis[1])
stmt = tvm.lower(s, [A, B], simple_mode=True)['main'].body
assert isinstance(stmt, tvm.tir.stmt.For)

if __name__ == "__main__":
test_singleton()
test_pragma()
Expand All @@ -305,3 +324,4 @@ def intrin_func(ins, outs, sp):
test_fuse_with_out_of_order_axis_with_reorder()
test_vectorize()
test_vectorize_commreduce()
test_legalize_invalid_attach()

0 comments on commit 317af72

Please sign in to comment.