diff --git a/src/schedule/schedule_dataflow_rewrite.cc b/src/schedule/schedule_dataflow_rewrite.cc index d1a69ecf0203b..b58df9d0481fa 100644 --- a/src/schedule/schedule_dataflow_rewrite.cc +++ b/src/schedule/schedule_dataflow_rewrite.cc @@ -86,7 +86,9 @@ Tensor Schedule::cache_read(const Tensor& tensor, return tensor(Array(i.begin(), i.end())); }, os.str()); std::unordered_map vsub; - vsub[tensor] = cache; + Stage s = operator[](tensor->op); + Tensor sugar_tensor = s->op.output(tensor->value_index); + vsub[sugar_tensor] = cache; std::unordered_map vmap; for (Operation op : readers) { diff --git a/tests/python/unittest/test_schedule_schedule_ops.py b/tests/python/unittest/test_schedule_schedule_ops.py index a85db2a23e86c..03b8dbf48c8c9 100644 --- a/tests/python/unittest/test_schedule_schedule_ops.py +++ b/tests/python/unittest/test_schedule_schedule_ops.py @@ -182,6 +182,25 @@ def test_schedule_cache(): bounds = tvm.schedule.InferBound(s) stmt = tvm.schedule.ScheduleOps(s, bounds) +def test_schedule_middle_cache(): + m = tvm.var('m') + n = tvm.var('n') + A = tvm.placeholder((m, n), name='A') + B = tvm.placeholder((m, n), name='B') + + C = tvm.compute((m, n), lambda i, j: A(i, j) * B(i, j), name='C') + D = tvm.compute((m, n), lambda i, j: C(i , j) , name='D') + + s = tvm.create_schedule(D.op) + AA = s.cache_read(A, "local", readers=[C]) + BB = s.cache_read(B, "local", readers=[C]) + CC = s.cache_read(C, "local", readers=[D]) + DD = s.cache_write(D, "local") + #s[AA].compute_at(s[CC], CC.op.axis[0]) + bounds = tvm.schedule.InferBound(s) + stmt = tvm.schedule.ScheduleOps(s, bounds) + + def test_schedule_cache_relayout1(): m = tvm.var('m') @@ -231,6 +250,7 @@ def test_schedule_cache_relayout3(): if __name__ == "__main__": + test_schedule_middle_cache() test_inline_multi_reduce() test_schedule_cache_relayout3() test_schedule_cache_relayout2()