diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
index 819994296639..6409f70e22e2 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc
@@ -290,8 +290,6 @@ ScheduleRule ScheduleRule::MultiLevelTiling(String structure,
 
 Optional<LoopRV> MultiLevelTilingNode::TransformWithTensorIntrin(State& state, const String& intrin_name) const {
-  // Optional<tir::TensorizeInfo> opt_tensorize_info = GetTensorizeLoopMapping(
-  //     sch->state(), sch->GetSRef(block_rv), tir::TensorIntrin::Get(intrin_name)->desc);
   BlockRV block_rv = state.block_rv;
   Optional<tir::LayoutInfo> opt_layout_info =
       GetTensorizeLayoutInfo(state.sch->state(), state.sch->GetSRef(block_rv),
@@ -310,62 +308,60 @@ Optional<LoopRV> MultiLevelTilingNode::TransformWithTensorIntrin(State& state, c
   for (size_t i = 0; i < block->writes.size(); ++i) {
     buffers[block->writes[i]->buffer] = std::move(std::make_pair(i, false));
   }
-
+  // Reindex buffers and insert reindex stage
   state.tensor_core_reindex_store = state.sch->ReIndex(block_rv, 0, true);
   state.tensor_core_reindex_A = state.sch->ReIndex(block_rv, 0, false);
   state.tensor_core_reindex_B = state.sch->ReIndex(block_rv, 1, false);
-  state.sch->TransformBlockLayout(state.tensor_core_reindex_store.value(), info->mapping);
-  state.sch->TransformBlockLayout(state.tensor_core_reindex_A.value(), info->mapping);
-  state.sch->TransformBlockLayout(state.tensor_core_reindex_B.value(), info->mapping);
-  state.sch->TransformBlockLayout(state.block_rv, info->mapping);
-
-  size_t offset = info->mapping->final_indices.size() - info->rhs_iters.size();
+  // Transform the layout of reindex buffers accordingly
+  std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> unmapped_vars;
+  std::unordered_map<tir::Var, tir::Var, ObjectPtrHash, ObjectPtrEqual> representer_map;
   std::unordered_map<tir::Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> tgt_iter_map;
-
+  size_t offset = info->mapping->final_indices.size() - info->rhs_iters.size();
+  ICHECK_EQ(info->lhs_iters.size(), info->mapping->initial_indices.size());
+  for (size_t i = 0; i < info->lhs_iters.size(); ++i) {
+    representer_map[info->lhs_iters[i]->var] = info->mapping->initial_indices[i];
+  }
+  for (size_t i = 0; i < offset; ++i) {
+    const tir::VarNode* var_ptr = info->mapping->final_indices[i].as<tir::VarNode>();
+    ICHECK(var_ptr != nullptr);
+    unmapped_vars.insert(Downcast<tir::Var>(info->mapping->final_indices[i]));
+  }
   for (size_t i = offset; i < info->mapping->final_indices.size(); ++i) {
     tgt_iter_map[info->rhs_iters[i - offset]->var] = info->mapping->final_indices[i];
   }
-
   for (const auto& it : buffers) {
     // organize the mappings for buffer layout transformation
     const tir::Buffer& rhs_buffer = info->lhs_buffer_map[it.first];
-    std::vector<tir::Var> new_representers;
-    std::vector<PrimExpr> new_tgt_iters;
-    std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> covered;
-    auto collect = [&](const ObjectRef& obj) -> bool {
-      if (const tir::VarNode* var = obj.as<tir::VarNode>()) {
-        covered.insert(GetRef<tir::Var>(var));
-      }
-      return true;
-    };
-    // new target iters
-    for (const PrimExpr& index : info->lhs_indices_map[it.first]) {
-      tir::PreOrderVisit(index, collect);
-    }
-    for (size_t i = 0; i < offset; ++i) {
-      if (covered.count(info->lhs_iters[i]->var)) {
-        covered.insert(info->mapping->initial_indices[i]);
-        new_tgt_iters.push_back(info->mapping->final_indices[i]);
+    std::vector<tir::Var> sub_representers;
+    std::vector<PrimExpr> sub_target_iters;
+    // Refresh block sref and handler
+    block_sref = state.sch->GetSRef(state.block_rv);
+    block = TVM_SREF_TO_BLOCK(block, block_sref);
+    const tir::BufferRegion& region =
+        it.second.second ? block->reads[it.second.first] : block->writes[it.second.first];
+    for (const Range& range : region->region) {
+      ICHECK(tir::is_one(range->extent));
+      const tir::VarNode* var_ptr = range->min.as<tir::VarNode>();
+      ICHECK(var_ptr != nullptr);
+      sub_representers.push_back(representer_map[GetRef<tir::Var>(var_ptr)]);
+
+      if (unmapped_vars.find(GetRef<tir::Var>(var_ptr)) != unmapped_vars.end()) {
+        sub_target_iters.push_back(GetRef<tir::Var>(var_ptr));
+      }
     }
     for (size_t i = 0; i < info->rhs_indices_map[rhs_buffer].size(); ++i) {
       const tir::VarNode* var = info->rhs_indices_map[rhs_buffer][i].as<tir::VarNode>();
       ICHECK(var != nullptr);
-      new_tgt_iters.push_back(tgt_iter_map[GetRef<tir::Var>(var)]);
-      tir::PreOrderVisit(new_tgt_iters.back(), collect);
-    }
-    // new representers
-    for (const auto& representer : info->mapping->initial_indices) {
-      if (covered.count(representer)) {
-        new_representers.push_back(representer);
-      }
+      sub_target_iters.push_back(tgt_iter_map[GetRef<tir::Var>(var)]);
     }
-    LOG(INFO) << "TransformaLayout " << it.second.first << it.first << " " << rhs_buffer;
     state.sch->TransformLayout(state.block_rv, it.second.first,
                                it.second.second ? tir::BufferIndexType::kRead
                                                 : tir::BufferIndexType::kWrite,
-                               tir::IndexMap(new_representers, new_tgt_iters));
-    LOG(INFO) << "OK";
+                               tir::IndexMap(sub_representers, sub_target_iters));
   }
+  // Transform the layout of current block and reindex blocks
+  state.sch->TransformBlockLayout(state.tensor_core_reindex_store.value(), info->mapping);
+  state.sch->TransformBlockLayout(state.tensor_core_reindex_A.value(), info->mapping);
+  state.sch->TransformBlockLayout(state.tensor_core_reindex_B.value(), info->mapping);
+  state.sch->TransformBlockLayout(state.block_rv, info->mapping);
 
   Array<LoopRV> loops = state.sch->GetLoops(state.block_rv);
   return loops[loops.size() - info->rhs_iters.size()];
diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h
index c7717fcd4355..81542cf04f1d 100644
--- a/src/tir/schedule/analysis.h
+++ b/src/tir/schedule/analysis.h
@@ -767,12 +767,11 @@ class LayoutInfoNode : public Object {
  public:
   IndexMap mapping;
   Map<Buffer, Buffer> lhs_buffer_map;
-  Map<Buffer, Array<PrimExpr>> lhs_indices_map, rhs_indices_map;
+  Map<Buffer, Array<PrimExpr>> rhs_indices_map;
   Array<IterVar> lhs_iters, rhs_iters;
 
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("mapping", &mapping);
-    v->Visit("lhs_indices_map", &lhs_indices_map);
     v->Visit("rhs_indices_map", &rhs_indices_map);
     v->Visit("lhs_iters", &lhs_iters);
     v->Visit("rhs_iters", &rhs_iters);
diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc
index 69b9094b8a10..72349c34785b 100644
--- a/src/tir/schedule/analysis/analysis.cc
+++ b/src/tir/schedule/analysis/analysis.cc
@@ -1283,7 +1283,6 @@ Optional<LayoutInfo> GetTensorizeLayoutInfo(const tir::ScheduleState& self,
   // Only using 1 layout now
   ret->mapping = std::move(proposer.mappings_[0]);
   ret->lhs_buffer_map = std::move(proposer.lhs_buffer_map_);
-  ret->lhs_indices_map = std::move(extractor.lhs_buffer_indices_map_);
   ret->rhs_indices_map = std::move(extractor.rhs_buffer_indices_map_);
   ret->lhs_iters = std::move(extractor.lhs_iters_);
   ret->rhs_iters = std::move(extractor.rhs_iters_);
diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc
index d2e4bad1f904..7324c98180f2 100644
--- a/src/tir/schedule/primitive/compute_inline.cc
+++ b/src/tir/schedule/primitive/compute_inline.cc
@@ -547,7 +547,7 @@ class ReverseComputeInliner : public BaseInliner {
       // Failure: no BufferLoad from the `inlined_buffer_`
      return false;
     }
-    int n_vars = GetNumUndefinedNonpointerVars(GetRef<Stmt>(inlined_store_));
+    // int n_vars = GetNumUndefinedNonpointerVars(GetRef<Stmt>(inlined_store_));
     //LOG(INFO) << "C " << n_vars;
     //LOG(INFO) << "Store: " << GetRef(inlined_store_);
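
Illustrative note (not part of the patch): each per-buffer TransformLayout call above is driven by a tir::IndexMap built from a list of "representer" Vars (initial indices) and the target index expressions (final indices). The standalone sketch below shows only that construction in isolation, assuming compilation and linking against libtvm with the public C++ headers available; the iterators i/j and the 16x16 tiling factors are hypothetical stand-ins chosen for the example, not values taken from this patch.

#include <tvm/node/repr_printer.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/op.h>
#include <tvm/tir/var.h>

#include <iostream>

int main() {
  using namespace tvm;
  using namespace tvm::tir;
  // Hypothetical block iterators, standing in for the "representer" vars collected above.
  Var i("i"), j("j");
  PrimExpr c16 = Integer(16);
  // initial_indices (i, j) -> final_indices (i / 16, j / 16, i % 16, j % 16):
  // the same (Vars, PrimExprs) pairing that tir::IndexMap is constructed from
  // and that TransformLayout consumes.
  IndexMap index_map({i, j},
                     {floordiv(i, c16), floordiv(j, c16), floormod(i, c16), floormod(j, c16)});
  std::cout << index_map << std::endl;  // print the mapping's repr
  return 0;
}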