From c492737826261a84b37138a215ac4fe7170a8c72 Mon Sep 17 00:00:00 2001 From: libing4752 Date: Thu, 4 Jan 2018 05:37:56 +0800 Subject: [PATCH] modified schedule_dataflow_rewrite.cc to fix Stale Tensor during Dataflow Rewrite #738 (#747) * modified schedule_dataflow_rewrite.cc to fix losing tensor problem * modified schedule_dataflow_rewrite.cc for lint scan * modified schedule_dataflow_rewrite.cc for lint scan * using tensor's value_index to index output of stage op --- src/schedule/schedule_dataflow_rewrite.cc | 4 +++- .../unittest/test_schedule_schedule_ops.py | 20 +++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index d1a69ecf0203..b58df9d0481f 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -86,7 +86,9 @@ Tensor Schedule::cache_read(const Tensor& tensor, return tensor(Array(i.begin(), i.end())); }, os.str()); std::unordered_map vsub; - vsub[tensor] = cache; + Stage s = operator[](tensor->op); + Tensor sugar_tensor = s->op.output(tensor->value_index); + vsub[sugar_tensor] = cache; std::unordered_map vmap; for (Operation op : readers) { diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index a85db2a23e86..03b8dbf48c8c 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -182,6 +182,25 @@ def test_schedule_cache(): bounds = tvm.schedule.InferBound(s) stmt = tvm.schedule.ScheduleOps(s, bounds) +def test_schedule_middle_cache(): + m = tvm.var('m') + n = tvm.var('n') + A = tvm.placeholder((m, n), name='A') + B = tvm.placeholder((m, n), name='B') + + C = tvm.compute((m, n), lambda i, j: A(i, j) * B(i, j), name='C') + D = tvm.compute((m, n), lambda i, j: C(i , j) , name='D') + + s = tvm.create_schedule(D.op) + AA = s.cache_read(A, "local", readers=[C]) + BB = s.cache_read(B, "local", readers=[C]) + CC = s.cache_read(C, "local", readers=[D]) + DD = s.cache_write(D, "local") + #s[AA].compute_at(s[CC], CC.op.axis[0]) + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + + def test_schedule_cache_relayout1(): m = tvm.var('m') @@ -231,6 +250,7 @@ def test_schedule_cache_relayout3(): if __name__ == "__main__": + test_schedule_middle_cache() test_inline_multi_reduce() test_schedule_cache_relayout3() test_schedule_cache_relayout2()