Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix #1

Merged
merged 1 commit into from
May 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 35 additions & 39 deletions src/meta_schedule/schedule_rule/multi_level_tiling.cc
Original file line number Diff line number Diff line change
Expand Up @@ -290,8 +290,6 @@ ScheduleRule ScheduleRule::MultiLevelTiling(String structure, Optional<Array<Str
}

Optional<LoopRV> MultiLevelTilingNode::TransformWithTensorIntrin(State& state, const String& intrin_name) const {
// Optional<tir::TensorizeInfo> opt_tensorize_info = GetTensorizeLoopMapping(
// sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc);
BlockRV block_rv = state.block_rv;
Optional<tir::LayoutInfo> opt_layout_info =
GetTensorizeLayoutInfo(state.sch->state(), state.sch->GetSRef(block_rv),
Expand All @@ -310,62 +308,60 @@ Optional<LoopRV> MultiLevelTilingNode::TransformWithTensorIntrin(State& state, c
for (size_t i = 0; i < block->writes.size(); ++i) {
buffers[block->writes[i]->buffer] = std::move(std::make_pair(i, false));
}

// Reindex buffers and insert reindex stage
state.tensor_core_reindex_store = state.sch->ReIndex(block_rv, 0, true);
state.tensor_core_reindex_A = state.sch->ReIndex(block_rv, 0, false);
state.tensor_core_reindex_B = state.sch->ReIndex(block_rv, 1, false);
state.sch->TransformBlockLayout(state.tensor_core_reindex_store.value(), info->mapping);
state.sch->TransformBlockLayout(state.tensor_core_reindex_A.value(), info->mapping);
state.sch->TransformBlockLayout(state.tensor_core_reindex_B.value(), info->mapping);
state.sch->TransformBlockLayout(state.block_rv, info->mapping);

size_t offset = info->mapping->final_indices.size() - info->rhs_iters.size();
// Transform the layout of reindex buffers accordingly
std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> unmapped_vars;
std::unordered_map<tir::Var, tir::Var, ObjectPtrHash, ObjectPtrEqual> representer_map;
std::unordered_map<tir::Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> tgt_iter_map;

size_t offset = info->mapping->final_indices.size() - info->rhs_iters.size();
ICHECK_EQ(info->lhs_iters.size(), info->mapping->initial_indices.size());
for (size_t i = 0; i < info->lhs_iters.size(); ++i) {
representer_map[info->lhs_iters[i]->var] = info->mapping->initial_indices[i];
}
for (size_t i = 0; i < offset; ++i) {
const tir::VarNode* var_ptr = info->mapping->final_indices[i].as<tir::VarNode>();
ICHECK(var_ptr != nullptr);
unmapped_vars.insert(Downcast<tir::Var>(info->mapping->final_indices[i]));
}
for (size_t i = offset; i < info->mapping->final_indices.size(); ++i) {
tgt_iter_map[info->rhs_iters[i - offset]->var] = info->mapping->final_indices[i];
}

for (const auto& it : buffers) {
// organize the mappings for buffer layout transformation
const tir::Buffer& rhs_buffer = info->lhs_buffer_map[it.first];
std::vector<tir::Var> new_representers;
std::vector<PrimExpr> new_tgt_iters;
std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> covered;
auto collect = [&](const ObjectRef& obj) -> bool {
if (const tir::VarNode* var = obj.as<tir::VarNode>()) {
covered.insert(GetRef<tir::Var>(var));
}
return true;
};
// new target iters
for (const PrimExpr& index : info->lhs_indices_map[it.first]) {
tir::PreOrderVisit(index, collect);
}
for (size_t i = 0; i < offset; ++i) {
if (covered.count(info->lhs_iters[i]->var)) {
covered.insert(info->mapping->initial_indices[i]);
new_tgt_iters.push_back(info->mapping->final_indices[i]);
std::vector<tir::Var> sub_representers;
std::vector<PrimExpr> sub_target_iters;
// Refresh block sref and handler
block_sref = state.sch->GetSRef(state.block_rv);
block = TVM_SREF_TO_BLOCK(block, block_sref);
const tir::BufferRegion& region = it.second.second ? block->reads[it.second.first] : block->writes[it.second.first];
for (const Range& range : region->region) {
ICHECK(tir::is_one(range->extent));
const tir::VarNode* var_ptr = range->min.as<tir::VarNode>();
ICHECK(var_ptr != nullptr);
sub_representers.push_back(representer_map[GetRef<tir::Var>(var_ptr)]);

if (unmapped_vars.find(GetRef<tir::Var>(var_ptr)) != unmapped_vars.end()) {
sub_target_iters.push_back(GetRef<tir::Var>(var_ptr));
}
}
for (size_t i = 0; i < info->rhs_indices_map[rhs_buffer].size(); ++i) {
const tir::VarNode* var = info->rhs_indices_map[rhs_buffer][i].as<tir::VarNode>();
ICHECK(var != nullptr);
new_tgt_iters.push_back(tgt_iter_map[GetRef<tir::Var>(var)]);
tir::PreOrderVisit(new_tgt_iters.back(), collect);
}
// new representers
for (const auto& representer : info->mapping->initial_indices) {
if (covered.count(representer)) {
new_representers.push_back(representer);
}
sub_target_iters.push_back(tgt_iter_map[GetRef<tir::Var>(var)]);
}
LOG(INFO) << "TransformaLayout " << it.second.first << it.first << " " << rhs_buffer;
state.sch->TransformLayout(state.block_rv, it.second.first,
it.second.second ? tir::BufferIndexType::kRead : tir::BufferIndexType::kWrite,
tir::IndexMap(new_representers, new_tgt_iters));
LOG(INFO) << "OK";
tir::IndexMap(sub_representers, sub_target_iters));
}
// Transform the layout of current block and reindex blocks
state.sch->TransformBlockLayout(state.tensor_core_reindex_store.value(), info->mapping);
state.sch->TransformBlockLayout(state.tensor_core_reindex_A.value(), info->mapping);
state.sch->TransformBlockLayout(state.tensor_core_reindex_B.value(), info->mapping);
state.sch->TransformBlockLayout(state.block_rv, info->mapping);

Array<LoopRV> loops = state.sch->GetLoops(state.block_rv);
return loops[loops.size() - info->rhs_iters.size()];
Expand Down
3 changes: 1 addition & 2 deletions src/tir/schedule/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -767,12 +767,11 @@ class LayoutInfoNode : public Object {
public:
IndexMap mapping;
Map<Buffer, Buffer> lhs_buffer_map;
Map<Buffer, Array<PrimExpr>> lhs_indices_map, rhs_indices_map;
Map<Buffer, Array<PrimExpr>> rhs_indices_map;
Array<IterVar> lhs_iters, rhs_iters;

void VisitAttrs(AttrVisitor* v) {
v->Visit("mapping", &mapping);
v->Visit("lhs_indices_map", &lhs_indices_map);
v->Visit("rhs_indices_map", &rhs_indices_map);
v->Visit("lhs_iters", &lhs_iters);
v->Visit("rhs_iters", &rhs_iters);
Expand Down
1 change: 0 additions & 1 deletion src/tir/schedule/analysis/analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1283,7 +1283,6 @@ Optional<LayoutInfo> GetTensorizeLayoutInfo(const tir::ScheduleState& self,
// Only using 1 layout now
ret->mapping = std::move(proposer.mappings_[0]);
ret->lhs_buffer_map = std::move(proposer.lhs_buffer_map_);
ret->lhs_indices_map = std::move(extractor.lhs_buffer_indices_map_);
ret->rhs_indices_map = std::move(extractor.rhs_buffer_indices_map_);
ret->lhs_iters = std::move(extractor.lhs_iters_);
ret->rhs_iters = std::move(extractor.rhs_iters_);
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/primitive/compute_inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ class ReverseComputeInliner : public BaseInliner {
// Failure: no BufferLoad from the `inlined_buffer_`
return false;
}
int n_vars = GetNumUndefinedNonpointerVars(GetRef<Stmt>(inlined_store_));
// int n_vars = GetNumUndefinedNonpointerVars(GetRef<Stmt>(inlined_store_));

//LOG(INFO) << "C " << n_vars;
//LOG(INFO) << "Store: " << GetRef<Stmt>(inlined_store_);
Expand Down