Skip to content

Commit

Permalink
[bugfix] Fix the behavior of IsHorizontalFuse (#49)
Browse files Browse the repository at this point in the history
* init

* upd

* upd
  • Loading branch information
yzh119 committed Oct 9, 2022
1 parent a3f59ba commit 843e3a3
Show file tree
Hide file tree
Showing 5 changed files with 21 additions and 19 deletions.
4 changes: 2 additions & 2 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ int CheckCompleteBlockErrorCode(const ScheduleState& self, const StmtSRef& block
// Cond 2. Dominant: the block is the only writer of its output,
// dominating the reader of its output buffers
if (!IsDominantBlock(self, scope_root_sref, block_sref)) {
if (!IsHorizontalFuse(self)) {
if (!IsHorizontalFuse(self, block_sref)) {
return 2;
}
}
Expand Down Expand Up @@ -358,7 +358,7 @@ int CheckReductionBlockErrorCode(const ScheduleState& self, const StmtSRef& bloc
// Cond 4. Dominant: the block is the only writer of its output, dominating the reader of its
// output buffers.
if (!IsDominantBlock(self, scope_root_sref, block_sref)) {
if (!IsHorizontalFuse(self)) {
if (!IsHorizontalFuse(self, block_sref)) {
return 4;
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/tir/schedule/primitive/cache_read_write.cc
Original file line number Diff line number Diff line change
Expand Up @@ -928,7 +928,7 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu
info.annotations = block->annotations;

// Step 3. Check the only writer block.
if (!IsHorizontalFuse(self)) {
if (!IsHorizontalFuse(self, block_sref)) {
ICHECK_EQ(block_sref.get(), GetOnlyWriteBlock(self, scope_sref, write_buffer).get());
}

Expand Down Expand Up @@ -1072,6 +1072,7 @@ StmtSRef ReverseCacheRead(ScheduleState self, const StmtSRef& block_sref, int re
StmtSRef ReverseCacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_buffer_index,
const String& storage_scope) {
LOG(FATAL) << "Not implemented yet.";
return StmtSRef();
}

/******** Instruction Registration ********/
Expand Down
13 changes: 7 additions & 6 deletions src/tir/schedule/primitive/compute_at.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,12 @@ class NotInSameScopeError : public ScheduleError {
* \throws ScheduleError if there is no such insertion point found
*/
template <bool require_all_producers_visited, bool require_all_consumers_visited>
int FindInsertionPoint(
const ScheduleState& self, const Array<Stmt>& subtrees, const Array<StmtSRef>& producer_srefs,
const Array<StmtSRef>& consumer_srefs,
std::unordered_map<const BlockNode*, const BlockRealizeNode*>* block2realize) {
int FindInsertionPoint(const ScheduleState& self, const Array<Stmt>& subtrees,
const Array<StmtSRef>& producer_srefs, const Array<StmtSRef>& consumer_srefs,
std::unordered_map<const BlockNode*, const BlockRealizeNode*>* block2realize,
bool horizontal_fuse) {
ProducerConsumerSplit split =
ProducerConsumerSplit::Find(self, subtrees, producer_srefs, consumer_srefs, block2realize);
bool horizontal_fuse = IsHorizontalFuse(self);
// Step 1. Check if all the producers are visited in the subtrees, if required to
if (require_all_producers_visited && !horizontal_fuse) {
int num_producers = producer_srefs.size();
Expand Down Expand Up @@ -582,11 +581,13 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
// Check condition 5): all the required block are under the given loop
std::unordered_map<const BlockNode*, const BlockRealizeNode*> block2realize;
block2realize.reserve(self->block_info.size());
bool is_horizontal_fuse = IsHorizontalFuse(self, block_sref);
int insert_position = FindInsertionPoint<!is_compute_at, is_compute_at>(
/*self=*/self,
/*subtrees=*/AsArray(loop->body),
/*producer_srefs=*/producer_srefs,
/*consumer_srefs=*/consumer_srefs, /*block2realize=*/&block2realize);
/*consumer_srefs=*/consumer_srefs, /*block2realize=*/&block2realize,
/*horizontal_fuse=*/is_horizontal_fuse);
// Step 4. Calculate the region provided by a single execution instance of `block`,
// as well as the region required by dependent blocks under `loop`.
// Here is the definition of `provide` and `require`:
Expand Down
9 changes: 5 additions & 4 deletions src/tir/schedule/primitive/loop_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,8 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
const PrimExpr& factor = factors[i];
Var var = loop->loop_var.copy_with_suffix("_" + std::to_string(i));
// NOTE(zihao): fix the bug in hybrid spmm.
if (!is_one(factor) || IsHorizontalFuse(self)) substitute_value = substitute_value * factor + var;
if (!is_one(factor) || IsHorizontalFuse(self, loop_sref))
substitute_value = substitute_value * factor + var;
analyzer.Bind(var, Range::FromMinExtent(0, factor));
new_loop_vars.emplace_back(std::move(var));
}
Expand All @@ -444,9 +445,9 @@ Array<StmtSRef> Split(ScheduleState self, const StmtSRef& loop_sref,
for (int i = n - 1; i >= 0; i--) {
new_stmt = For(new_loop_vars[i], 0, factors[i], ForKind::kSerial, new_stmt);
}
if (!IsHorizontalFuse(self)) {
new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(std::move(new_stmt), GetLoops(loop_sref),
opaque_block_reuse.CopyOnWrite());
if (!IsHorizontalFuse(self, loop_sref)) {
new_stmt = IterMapSimplifyBlockBinding::SimplifyBindings(
std::move(new_stmt), GetLoops(loop_sref), opaque_block_reuse.CopyOnWrite());
}
self->Replace(loop_sref, new_stmt, opaque_block_reuse);
Array<StmtSRef> result_srefs;
Expand Down
11 changes: 5 additions & 6 deletions src/tir/schedule/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -449,12 +449,11 @@ inline String BufferIndexType2Str(BufferIndexType buffer_index_type) {

/******** Whether it's horizontal fuse function. ********/

inline bool IsHorizontalFuse(const ScheduleState& state) {
IRModule mod = state->mod;
return mod->functions.Get(mod->GetGlobalVar("main"))
.value()
->attrs->dict.Get("horizontal_fuse")
.defined();
inline bool IsHorizontalFuse(const ScheduleState& state, const StmtSRef& sref) {
const StmtSRef& root_block_sref = GetSRefTreeRoot(sref);
const PrimFuncNode* func = GetRootPrimFunc(state->mod, root_block_sref->stmt, nullptr);
CHECK(func != nullptr) << "The given sref does not resides in the schedule state's module.";
return func->attrs.HasNonzeroAttr("horizontal_fuse");
}

} // namespace tir
Expand Down

0 comments on commit 843e3a3

Please sign in to comment.