Skip to content

Commit

Permalink
modified schedule_dataflow_rewrite.cc to fix Stale Tensor during Data…
Browse files Browse the repository at this point in the history
…flow Rewrite apache#738 (apache#747)

* modified schedule_dataflow_rewrite.cc to fix losing tensor problem

* modified schedule_dataflow_rewrite.cc for lint scan

* modified schedule_dataflow_rewrite.cc for lint scan

* using tensor's value_index to index output of stage op
  • Loading branch information
libing4752 authored and tqchen committed Jan 3, 2018
1 parent d8be197 commit c492737
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/schedule/schedule_dataflow_rewrite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ Tensor Schedule::cache_read(const Tensor& tensor,
return tensor(Array<Expr>(i.begin(), i.end()));
}, os.str());
std::unordered_map<Tensor, Tensor> vsub;
vsub[tensor] = cache;
Stage s = operator[](tensor->op);
Tensor sugar_tensor = s->op.output(tensor->value_index);
vsub[sugar_tensor] = cache;

std::unordered_map<Tensor, Tensor> vmap;
for (Operation op : readers) {
Expand Down
20 changes: 20 additions & 0 deletions tests/python/unittest/test_schedule_schedule_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,25 @@ def test_schedule_cache():
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)

def test_schedule_middle_cache():
m = tvm.var('m')
n = tvm.var('n')
A = tvm.placeholder((m, n), name='A')
B = tvm.placeholder((m, n), name='B')

C = tvm.compute((m, n), lambda i, j: A(i, j) * B(i, j), name='C')
D = tvm.compute((m, n), lambda i, j: C(i , j) , name='D')

s = tvm.create_schedule(D.op)
AA = s.cache_read(A, "local", readers=[C])
BB = s.cache_read(B, "local", readers=[C])
CC = s.cache_read(C, "local", readers=[D])
DD = s.cache_write(D, "local")
#s[AA].compute_at(s[CC], CC.op.axis[0])
bounds = tvm.schedule.InferBound(s)
stmt = tvm.schedule.ScheduleOps(s, bounds)



def test_schedule_cache_relayout1():
m = tvm.var('m')
Expand Down Expand Up @@ -231,6 +250,7 @@ def test_schedule_cache_relayout3():


if __name__ == "__main__":
test_schedule_middle_cache()
test_inline_multi_reduce()
test_schedule_cache_relayout3()
test_schedule_cache_relayout2()
Expand Down

0 comments on commit c492737

Please sign in to comment.