Skip to content

Commit

Permalink
rebase & bugfix
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterJH5574 committed Jun 17, 2021
1 parent d1b9bc1 commit dfb7112
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 43 deletions.
2 changes: 1 addition & 1 deletion src/tir/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,7 @@ CommReducer::CommReducer(Array<Var> lhs, Array<Var> rhs, Array<PrimExpr> result,
data_ = std::move(node);
}

`bool CommReducer::MatchReducer(const CommReducer& reducer, const PrimExpr& identity,
bool CommReducer::MatchReducer(const CommReducer& reducer, const PrimExpr& identity,
const PrimExpr& combiner, Optional<PrimExpr>& lhs,
Optional<PrimExpr>& rhs) {
ExprDeepEqual equal;
Expand Down
2 changes: 2 additions & 0 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#include <tvm/tir/stmt_functor.h>
#include <tvm/arith/iter_affine_map.h>

#include "../../runtime/thread_storage_scope.h"

namespace tvm {
namespace tir {

Expand Down
36 changes: 1 addition & 35 deletions src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ bool IsReductionBlock(const ScheduleState& self, const StmtSRef& block_sref,
}
return true;
});
return affected;
return !affected;
}

/******** Binding ********/
Expand Down Expand Up @@ -523,40 +523,6 @@ BufferRegion SubstituteBufferRegion(const BufferRegion& buffer_region,
return BufferRegion(new_buffer_region);
}

/******** Block Information Update ********/

void UpdateScope(ScheduleState self, const StmtSRef& block_sref) {
BlockScope scope(tir::GetChildBlocks(self, block_sref));
// The caller is responsible for correcting the flags
bool affine_binding = false;
bool region_cover = false;
self->block_info[block_sref] =
BlockInfo(std::move(scope), affine_binding, region_cover);
}

void UpdateAffineFlag(ScheduleState self, const StmtSRef& block_sref) {
if (block_sref->parent == nullptr) {
ICHECK(self->block_info.count(block_sref));
self->block_info[block_sref].affine_binding = true;
return;
}
BlockRealize realize = GetBlockRealize(block_sref);
const auto* block = TVM_SREF_TO_BLOCK(block, block_sref);
Map<Var, Range> loop_var_ranges;
for (StmtSRefNode* loop_sref = block_sref->parent; loop_sref != nullptr;
loop_sref = loop_sref->parent) {
if (const auto* loop = loop_sref->StmtAs<ForNode>()) {
loop_var_ranges.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
} else {
break;
}
}
ICHECK(self->block_info.count(block_sref));
arith::Analyzer analyzer;
self->block_info[block_sref].affine_binding =
IsAffineBinding(realize, loop_var_ranges, &analyzer);
}

/******** Pattern Matcher ********/

void PatternMatcher::VisitExpr_(const VarNode* op) {
Expand Down
12 changes: 5 additions & 7 deletions src/tir/schedule/primitives/reduction.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis)
BlockRealize block_realize = GetBlockRealize(block_sref);
Block block = block_realize->block;
Optional<StmtSRef> scope_root = GetScopeRoot(block_sref);
// Todo: comment out the two lines below out after Junru's PR getting merged
// CHECK(IsReductionBlock(self, block_sref, scope_root.value()))
// << "ValueError: We can only do rfactor for loops of a reduction block";
CHECK(IsReductionBlock(self, block_sref, scope_root.value()))
<< "ValueError: We can only do rfactor for loops of a reduction block";

// Collect the information of the reduction.
// Get the `init` identity and the `update` combiner of the reduction.
Expand Down Expand Up @@ -390,10 +389,9 @@ StmtSRef RFactor(ScheduleState self, const StmtSRef& loop_sref, int factor_axis)
self->Replace(scope_root.value(), new_scope_block, {{scope_block, new_scope_block}});
// Update scope information.
StmtSRef rf_block_sref = self->stmt2ref.at(rf_block.get());
UpdateScope(self, scope_root.value());
UpdateAffineFlag(self, scope_root.value());
UpdateAffineFlag(self, rf_block_sref);
// Todo: in which cases should we call UpdateScope & UpdateAffineFlag?
self->block_info[rf_block_sref].affine_binding = true;
self->block_info[rf_block_sref].region_cover = true;
self->block_info[rf_block_sref].scope->stage_pipeline = true;
return rf_block_sref;
}

Expand Down

0 comments on commit dfb7112

Please sign in to comment.