From 843e3a3f4bb9fda62fec731ecc5cd31ae67b0c24 Mon Sep 17 00:00:00 2001 From: Zihao Ye Date: Sun, 9 Oct 2022 03:21:20 -0700 Subject: [PATCH] [bugfix] Fix the behavior of IsHorizontalFuse (#49) * init * upd * upd --- src/tir/schedule/analysis/analysis.cc | 4 ++-- src/tir/schedule/primitive/cache_read_write.cc | 3 ++- src/tir/schedule/primitive/compute_at.cc | 13 +++++++------ src/tir/schedule/primitive/loop_transformation.cc | 9 +++++---- src/tir/schedule/utils.h | 11 +++++------ 5 files changed, 21 insertions(+), 19 deletions(-) diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 3e0cf652f619..459fea2a2393 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -256,7 +256,7 @@ int CheckCompleteBlockErrorCode(const ScheduleState& self, const StmtSRef& block // Cond 2. Dominant: the block is the only writer of its output, // dominating the reader of its output buffers if (!IsDominantBlock(self, scope_root_sref, block_sref)) { - if (!IsHorizontalFuse(self)) { + if (!IsHorizontalFuse(self, block_sref)) { return 2; } } @@ -358,7 +358,7 @@ int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& bloc // Cond 4. Dominant: the block is the only writer of its output, dominating the reader of its // output buffers. if (!IsDominantBlock(self, scope_root_sref, block_sref)) { - if (!IsHorizontalFuse(self)) { + if (!IsHorizontalFuse(self, block_sref)) { return 4; } } diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc index 39322a56988c..df99b9f037f1 100644 --- a/src/tir/schedule/primitive/cache_read_write.cc +++ b/src/tir/schedule/primitive/cache_read_write.cc @@ -928,7 +928,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu info.annotations = block->annotations; // Step 3. Check the only writer block. - if (!IsHorizontalFuse(self)) { + if (!IsHorizontalFuse(self, block_sref)) { ICHECK_EQ(block_sref.get(), GetOnlyWriteBlock(self, scope_sref, write_buffer).get()); } @@ -1072,6 +1072,7 @@ StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref, int re StmtSRef ReverseCacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index, const String& storage_scope) { LOG(FATAL) << "Not implemented yet."; + return StmtSRef(); } /******** Instruction Registration ********/ diff --git a/src/tir/schedule/primitive/compute_at.cc b/src/tir/schedule/primitive/compute_at.cc index 782fe0078a73..5853f9f47b11 100644 --- a/src/tir/schedule/primitive/compute_at.cc +++ b/src/tir/schedule/primitive/compute_at.cc @@ -134,13 +134,12 @@ class NotInSameScopeError : public ScheduleError { * \throws ScheduleError if there is no such insertion point found */ template -int FindInsertionPoint( - const ScheduleState& self, const Array& subtrees, const Array& producer_srefs, - const Array& consumer_srefs, - std::unordered_map* block2realize) { +int FindInsertionPoint(const ScheduleState& self, const Array& subtrees, + const Array& producer_srefs, const Array& consumer_srefs, + std::unordered_map* block2realize, + bool horizontal_fuse) { ProducerConsumerSplit split = ProducerConsumerSplit::Find(self, subtrees, producer_srefs, consumer_srefs, block2realize); - bool horizontal_fuse = IsHorizontalFuse(self); // Step 1. Check if all the producers are visited in the subtrees, if required to if (require_all_producers_visited && !horizontal_fuse) { int num_producers = producer_srefs.size(); @@ -582,11 +581,13 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s // Check condition 5): all the required block are under the given loop std::unordered_map block2realize; block2realize.reserve(self->block_info.size()); + bool is_horizontal_fuse = IsHorizontalFuse(self, block_sref); int insert_position = FindInsertionPoint( /*self=*/self, /*subtrees=*/AsArray(loop->body), /*producer_srefs=*/producer_srefs, - /*consumer_srefs=*/consumer_srefs, /*block2realize=*/&block2realize); + /*consumer_srefs=*/consumer_srefs, /*block2realize=*/&block2realize, + /*horizontal_fuse=*/is_horizontal_fuse); // Step 4. Calculate the region provided by a single execution instance of `block`, // as well as the region required by dependent blocks under `loop`. // Here is the definition of `provide` and `require`: diff --git a/src/tir/schedule/primitive/loop_transformation.cc b/src/tir/schedule/primitive/loop_transformation.cc index 2d8b5568e7a5..7aedbe9e0668 100644 --- a/src/tir/schedule/primitive/loop_transformation.cc +++ b/src/tir/schedule/primitive/loop_transformation.cc @@ -420,7 +420,8 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, const PrimExpr& factor = factors[i]; Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i)); // NOTE(zihao): fix the bug in hybrid spmm. - if (!is_one(factor) || IsHorizontalFuse(self)) substitute_value = substitute_value * factor + var; + if (!is_one(factor) || IsHorizontalFuse(self, loop_sref)) + substitute_value = substitute_value * factor + var; analyzer.Bind(var, Range::FromMinExtent(0, factor)); new_loop_vars.emplace_back(std::move(var)); } @@ -444,9 +445,9 @@ Array Split(ScheduleState self, const StmtSRef& loop_sref, for (int i = n - 1; i >= 0; i--) { new_stmt = For(new_loop_vars[i], 0, factors[i], ForKind::kSerial, new_stmt); } - if (!IsHorizontalFuse(self)) { - new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(std::move(new_stmt), GetLoops(loop_sref), - opaque_block_reuse.CopyOnWrite()); + if (!IsHorizontalFuse(self, loop_sref)) { + new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings( + std::move(new_stmt), GetLoops(loop_sref), opaque_block_reuse.CopyOnWrite()); } self->Replace(loop_sref, new_stmt, opaque_block_reuse); Array result_srefs; diff --git a/src/tir/schedule/utils.h b/src/tir/schedule/utils.h index 4699ad484ec8..a1bc57d00cee 100644 --- a/src/tir/schedule/utils.h +++ b/src/tir/schedule/utils.h @@ -449,12 +449,11 @@ inline String BufferIndexType2Str(BufferIndexType buffer_index_type) { /******** Whether it's horizontal fuse function. ********/ -inline bool IsHorizontalFuse(const ScheduleState& state) { - IRModule mod = state->mod; - return mod->functions.Get(mod->GetGlobalVar("main")) - .value() - ->attrs->dict.Get("horizontal_fuse") - .defined(); +inline bool IsHorizontalFuse(const ScheduleState& state, const StmtSRef& sref) { + const StmtSRef& root_block_sref = GetSRefTreeRoot(sref); + const PrimFuncNode* func = GetRootPrimFunc(state->mod, root_block_sref->stmt, nullptr); + CHECK(func != nullptr) << "The given sref does not resides in the schedule state's module."; + return func->attrs.HasNonzeroAttr("horizontal_fuse"); } } // namespace tir