diff --git a/include/tvm/meta_schedule/postproc.h b/include/tvm/meta_schedule/postproc.h index f382cd219e0c..3a489555c443 100644 --- a/include/tvm/meta_schedule/postproc.h +++ b/include/tvm/meta_schedule/postproc.h @@ -118,12 +118,23 @@ class Postproc : public runtime::ObjectRef { PyPostprocNode::FInitializeWithTuneContext f_initialize_with_tune_context, // PyPostprocNode::FApply f_apply, // PyPostprocNode::FAsString f_as_string); + /*! + * \brief Create a postprocessor that checks if all loops are static + * \return The postprocessor created + */ + TVM_DLL static Postproc DisallowDynamicLoop(); /*! * \brief Create a postprocessor that rewrites the cooperative fetch annotation to * actual vectorized cooperative fetching in loop bindings. * \return The postprocessor created. */ TVM_DLL static Postproc RewriteCooperativeFetch(); + /*! + * \brief Creates a postprocessor that applies parallelization, vectorization and auto unrolling + * according to the annotation of each block + * \return The postprocessor created + */ + TVM_DLL static Postproc RewriteParallelVectorizeUnroll(); /*! * \brief Create a postprocessor that rewrites reduction block by moving the init block out. * \return The postprocessor created. @@ -134,6 +145,11 @@ class Postproc : public runtime::ObjectRef { * \return The postprocessor created. */ TVM_DLL static Postproc RewriteUnboundBlock(); + /*! + * \brief Creates a postprocessor that verifies if the GPU code is correct + * \return The postprocessor created + */ + TVM_DLL static Postproc VerifyGPUCode(); TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Postproc, ObjectRef, PostprocNode); }; diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 2a748a4653fa..56c0863cf16f 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -1375,6 +1375,23 @@ constexpr const int meta_schedule_cache_type_read = 0; /*! \sa meta_schedule_cache_type */ constexpr const int meta_schedule_cache_type_write = 1; +/*! \brief Mark auto-parallel setting on the block. */ +constexpr const char* meta_schedule_parallel = "meta_schedule.parallel"; + +/*! \brief Mark auto-vectorize setting on the block. */ +constexpr const char* meta_schedule_vectorize = "meta_schedule.vectorize"; + +/*! \brief Mark auto-unroll setting on the block. */ +constexpr const char* meta_schedule_unroll_explicit = "meta_schedule.unroll_explicit"; + +/*! \brief Mark auto-unroll setting on the block. */ +constexpr const char* meta_schedule_unroll_implicit = "meta_schedule.unroll_implicit"; + +/*! \brief Pragma: auto-unroll, max_step */ +constexpr const char* pragma_auto_unroll_max_step = "pragma_auto_unroll_max_step"; + +/*! \brief Pragma: unroll explicit */ +constexpr const char* pragma_unroll_explicit = "pragma_unroll_explicit"; /*! * \brief Check if attr_key is a pragma key extension * \param attr_key The attr key to be compared diff --git a/python/tvm/meta_schedule/postproc/__init__.py b/python/tvm/meta_schedule/postproc/__init__.py index 440312812ec0..96361e739186 100644 --- a/python/tvm/meta_schedule/postproc/__init__.py +++ b/python/tvm/meta_schedule/postproc/__init__.py @@ -16,6 +16,9 @@ # under the License. """The tvm.meta_schedule.postproc package.""" from .postproc import Postproc, PyPostproc +from .disallow_dynamic_loop import DisallowDynamicLoop from .rewrite_cooperative_fetch import RewriteCooperativeFetch +from .rewrite_parallel_vectorize_unroll import RewriteParallelVectorizeUnroll from .rewrite_reduction_block import RewriteReductionBlock from .rewrite_unbound_block import RewriteUnboundBlock +from .verify_gpu_code import VerifyGPUCode diff --git a/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py b/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py new file mode 100644 index 000000000000..5515d288e0e7 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/disallow_dynamic_loop.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that checks if the IRModule has any loop with non-constant extent""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.DisallowDynamicLoop") +class DisallowDynamicLoop(Postproc): + """A postprocessor that checks if the IRModule has any loop with non-constant extent""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocDisallowDynamicLoop, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py b/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py new file mode 100644 index 000000000000..abe7288acba9 --- /dev/null +++ b/python/tvm/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.py @@ -0,0 +1,33 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that applies parallelization, vectorization and auto unrolling +according to the annotation of each block""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.RewriteParallelVectorizeUnroll") +class RewriteParallelVectorizeUnroll(Postproc): + """A postprocessor that applies parallelization, vectorization and auto unrolling + according to the annotation of each block""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocRewriteParallelVectorizeUnroll, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/postproc/verify_gpu_code.py b/python/tvm/meta_schedule/postproc/verify_gpu_code.py new file mode 100644 index 000000000000..501e4423196c --- /dev/null +++ b/python/tvm/meta_schedule/postproc/verify_gpu_code.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""A postprocessor that verifies if the GPU code is correct""" + +from tvm._ffi.registry import register_object +from .. import _ffi_api +from .postproc import Postproc + + +@register_object("meta_schedule.VerifyGPUCode") +class VerifyGPUCode(Postproc): + """A postprocessor that verifies if the GPU code is correct""" + + def __init__(self) -> None: + self.__init_handle_by_constructor__( + _ffi_api.PostprocVerifyGPUCode, # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/meta_schedule/tune_context.py b/python/tvm/meta_schedule/tune_context.py index 5d242290b123..ded55bba4baa 100644 --- a/python/tvm/meta_schedule/tune_context.py +++ b/python/tvm/meta_schedule/tune_context.py @@ -56,9 +56,9 @@ class TuneContext(Object): The search strategy. sch_rules: Optional[List[ScheduleRule]] = None, The schedule rules. - postproc: Optional[List[Postproc"]] = None, + postprocs: Optional[List[Postproc"]] = None, The postprocessors. - mutator: Optional[List[Mutator]] = None, + mutators: Optional[List[Mutator]] = None, The mutators. task_name : Optional[str] = None The name of the tuning task. @@ -81,8 +81,8 @@ class TuneContext(Object): space_generator: Optional["SpaceGenerator"] search_strategy: Optional["SearchStrategy"] sch_rules: Optional[List["ScheduleRule"]] - postproc: Optional[List["Postproc"]] - mutator: Optional[List["Mutator"]] + postprocs: Optional[List["Postproc"]] + mutators: Optional[List["Mutator"]] task_name: Optional[str] rand_state: int num_threads: int @@ -94,8 +94,8 @@ def __init__( space_generator: Optional["SpaceGenerator"] = None, search_strategy: Optional["SearchStrategy"] = None, sch_rules: Optional[List["ScheduleRule"]] = None, - postproc: Optional[List["Postproc"]] = None, - mutator: Optional[List["Mutator"]] = None, + postprocs: Optional[List["Postproc"]] = None, + mutators: Optional[List["Mutator"]] = None, task_name: Optional[str] = None, rand_state: int = -1, num_threads: Optional[int] = None, @@ -114,9 +114,9 @@ def __init__( The search strategy. sch_rules : List[ScheduleRule] = [] The schedule rules. - postproc : List[Postproc] = [] + postprocs : List[Postproc] = [] The postprocessors. - mutator : List[Mutator] = [] + mutators : List[Mutator] = [] The mutators. task_name : Optional[str] = None The name of the tuning task. @@ -138,8 +138,8 @@ def __init__( space_generator, search_strategy, sch_rules, - postproc, - mutator, + postprocs, + mutators, task_name, rand_state, num_threads, diff --git a/src/meta_schedule/postproc/disallow_dynamic_loop.cc b/src/meta_schedule/postproc/disallow_dynamic_loop.cc new file mode 100644 index 000000000000..715815843a84 --- /dev/null +++ b/src/meta_schedule/postproc/disallow_dynamic_loop.cc @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! \brief Check if the loop is dynamic. */ +struct DynamicExtentFinder : private StmtVisitor { + public: + static bool Find(const IRModule& mod) { + DynamicExtentFinder finder; + for (const auto& kv : mod->functions) { + const BaseFunc& func = kv.second; + if (const auto* prim_func = func.as()) { + finder(prim_func->body); + if (finder.found_) { + return true; + } + } + } + return false; + } + + private: + void VisitStmt_(const ForNode* loop) final { + if (!loop->extent->IsInstance()) { + found_ = true; + } else { + StmtVisitor::VisitStmt_(loop); + } + } + + void VisitStmt(const Stmt& stmt) final { + if (!found_) { + StmtVisitor::VisitStmt(stmt); + } + } + + bool found_ = false; +}; + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +using tir::Schedule; + +/*! \brief Check if the IRModule has any loop with non-constant extent. */ +class DisallowDynamicLoopNode : public PostprocNode { + public: + // Inherited from PostprocNode + void InitializeWithTuneContext(const TuneContext& context) final {} + // Inherited from PostprocNode + bool Apply(const tir::Schedule& sch) final { return !tir::DynamicExtentFinder::Find(sch->mod()); } + + static constexpr const char* _type_key = "meta_schedule.DisallowDynamicLoop"; + TVM_DECLARE_FINAL_OBJECT_INFO(DisallowDynamicLoopNode, PostprocNode); +}; + +Postproc Postproc::DisallowDynamicLoop() { + ObjectPtr n = make_object(); + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(DisallowDynamicLoopNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocDisallowDynamicLoop") + .set_body_typed(Postproc::DisallowDynamicLoop); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc new file mode 100644 index 000000000000..447837d36b7a --- /dev/null +++ b/src/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc @@ -0,0 +1,393 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +/*! + * \brief Check whether the block/loop has any annotation + * \param sref The sref of block/loop + * \return Whether the block/loop has any annotation + */ +inline bool HasAnnOrBinding(const ForNode* loop) { + return loop->kind == ForKind::kThreadBinding || !loop->annotations.empty(); +} + +class StrideExtractor : public StmtExprVisitor { + public: + static int64_t Extract(const PrimExpr& expr, const Var& var) { + StrideExtractor extractor(var); + extractor.VisitExpr(expr); + return extractor.strides_[expr.get()]; + } + + private: + explicit StrideExtractor(const Var& var) : var_(var) {} + + void VisitExpr_(const MulNode* node) final { + StmtExprVisitor::VisitExpr_(node); + + if (const auto* a = node->a.as()) { + if (strides_.count(node->b.get())) { + strides_[node] = strides_[node->b.get()] * a->value; + } + } else if (const auto* b = node->b.as()) { + if (strides_.count(node->a.get())) { + strides_[node] = strides_[node->a.get()] * b->value; + } + } + } + + void VisitExpr_(const AddNode* node) final { + StmtExprVisitor::VisitExpr_(node); + int64_t stride_a, stride_b; + if (strides_.count(node->a.get())) { + stride_a = strides_[node->a.get()]; + } else { + stride_a = INT64_MAX; + } + if (strides_.count(node->b.get())) { + stride_b = strides_[node->b.get()]; + } else { + stride_b = INT64_MAX; + } + if (stride_a != INT64_MAX || stride_b != INT64_MAX) { + strides_[node] = std::min(stride_a, stride_b); + } + } + + void VisitExpr_(const VarNode* node) final { + if (node == var_.get()) { + strides_[node] = 1; + } + } + + const Var& var_; + std::unordered_map strides_; +}; + +struct ParsedAnnotation { + int max_parallel_extent; + int max_vectorize_extent; + int unroll_explicit; + int unroll_implicit; + int num_parallel_loops; + int num_vectorize_loops; +}; + +bool ParseAnnotation(const Block& block, ParsedAnnotation* parsed) { + bool found = false; + *parsed = ParsedAnnotation{-1, -1, -1, -1, -1, -1}; + for (const auto& ann : block->annotations) { + if (ann.first == attr::meta_schedule_parallel) { + found = true; + if (const auto* str_imm = ann.second.as()) { + parsed->max_parallel_extent = std::atoi(str_imm->value.c_str()); + } + } else if (ann.first == attr::meta_schedule_vectorize) { + found = true; + if (const auto* str_imm = ann.second.as()) { + parsed->max_vectorize_extent = std::atoi(str_imm->value.c_str()); + } + } else if (ann.first == attr::meta_schedule_unroll_explicit) { + found = true; + if (const auto* str_imm = ann.second.as()) { + parsed->unroll_explicit = std::atoi(str_imm->value.c_str()); + } + } else if (ann.first == attr::meta_schedule_unroll_implicit) { + found = true; + if (const auto* str_imm = ann.second.as()) { + parsed->unroll_implicit = std::atoi(str_imm->value.c_str()); + } + } + } + return found; +} + +void RemoveParsedAnn(const Schedule& sch, const BlockRV& block_rv, const ParsedAnnotation& parsed) { + if (parsed.max_parallel_extent != -1) { + sch->Unannotate(block_rv, attr::meta_schedule_parallel); + } + if (parsed.max_vectorize_extent != -1) { + sch->Unannotate(block_rv, attr::meta_schedule_vectorize); + } + if (parsed.unroll_explicit != -1) { + sch->Unannotate(block_rv, attr::meta_schedule_unroll_explicit); + } + if (parsed.unroll_implicit != -1) { + sch->Unannotate(block_rv, attr::meta_schedule_unroll_implicit); + } +} + +void AdjustParallelVectorize(const Schedule& sch, const BlockRV& block_rv, + const Array& loop_rvs, ParsedAnnotation* parsed) { + StmtSRef block_sref = sch->GetSRef(block_rv); + if (parsed->max_parallel_extent == -1 && parsed->max_vectorize_extent == -1) { + return; + } + int n_loops = loop_rvs.size(); + if (n_loops == 0) { + parsed->max_parallel_extent = -1; + parsed->max_vectorize_extent = -1; + return; + } + // Extract loop_srefs, and calculate the iterator types + Array loop_srefs; + std::vector loop_types; + { + loop_srefs.reserve(n_loops); + loop_types.reserve(n_loops); + for (const LoopRV& loop_rv : loop_rvs) { + loop_srefs.push_back(sch->GetSRef(loop_rv)); + loop_types.push_back(GetLoopIterType(loop_srefs.back())); + } + } + // check the maximal number of axes that are vectorizable (contiguous memory access) + BlockRealize realize = GetBlockRealize(sch->state(), block_sref); + Array buffer_access(realize->block->reads); + buffer_access.insert(buffer_access.end(), realize->block->writes.begin(), + realize->block->writes.end()); + std::unordered_map binding_map; + for (size_t i = 0; i < realize->iter_values.size(); i++) { + binding_map[realize->block->iter_vars[i]->var.get()] = realize->iter_values[i]; + } + int max_fusible = INT32_MAX; + // for each block read/write, get the strides of the loop vars and find the fusible + // (vectorizable) axes + for (const BufferRegion& access : buffer_access) { + int fusible = 0; + std::vector strides; + // get strides for each loop var + for (const StmtSRef& loop_sref : loop_srefs) { + int64_t stride = 0, buffer_stride = 1; + const auto* var = loop_sref->StmtAs(); + arith::Analyzer analyzer; + for (int i = access->region.size() - 1; i >= 0; i--) { + PrimExpr idx = analyzer.Simplify(Substitute(access->region[i]->min, binding_map)); + int64_t coef = StrideExtractor::Extract(idx, var->loop_var); + if (coef != 0) { + stride = coef * buffer_stride; + break; + } + buffer_stride *= access->buffer->shape[i].as()->value; + } + strides.push_back(stride); + } + int prev_used_iter = -1; + // check the number of fusible loops + for (int i = strides.size() - 1; i >= 0; i--) { + if (strides[i] == 0) { + // not used in the buffer access, safe to fuse + fusible++; + continue; + } else if (prev_used_iter == -1) { + // the stride of last axis is not 1 means the memory access is not contiguous + if (strides[i] != 1) { + break; + } + fusible++; + prev_used_iter = i; + } else { + // contiguous memory access + const auto* prev_loop = loop_srefs[prev_used_iter]->StmtAs(); + int64_t prev_used_iter_extent = prev_loop->extent.as()->value; + if (strides[i] == strides[prev_used_iter] * prev_used_iter_extent) { + fusible++; + prev_used_iter = i; + } else { + break; + } + } + } + max_fusible = std::min(max_fusible, fusible); + } + // Calculate the parallelize extent + if (parsed->max_parallel_extent != -1) { + int max_extent = parsed->max_parallel_extent; + int& num_fusible = parsed->num_parallel_loops = 0; + int64_t prod_extent = 1; + for (int i = 0; i < n_loops && loop_types[i] == IterVarType::kDataPar; ++i) { + const StmtSRef& loop_sref = loop_srefs[i]; + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + if (HasAnnOrBinding(loop)) { + break; + } + // Check if the loop extent is valid + const int64_t* extent = GetLoopIntExtent(loop_sref); + if (extent == nullptr) { + break; + } + // Then we can fuse it in + ++num_fusible; + // Check if we need to break + prod_extent *= *extent; + if (prod_extent > max_extent || !IsSingleStmt(loop->body)) { + break; + } + } + if (prod_extent == 1) { + num_fusible = -1; + } + } + // Calculate the vectorize extent + if (parsed->max_vectorize_extent != -1) { + int max_extent = parsed->max_vectorize_extent; + int& num_fusible = parsed->num_vectorize_loops = 0; + int64_t prod_extent = 1; + for (int i = n_loops - 1; + i >= 0 && loop_types[i] == IterVarType::kDataPar && num_fusible < max_fusible; --i) { + const StmtSRef& loop_sref = loop_srefs[i]; + const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref); + if (HasAnnOrBinding(loop)) { + break; + } + // Cannot vectorize reduce axis + if (GetLoopIterType(loop_sref) != IterVarType::kDataPar) { + break; + } + // Cannot fuse with a loop with multiple children + if (!IsSingleStmt(loop->body)) { + break; + } + // Check if the loop extent is valid + const int64_t* extent = GetLoopIntExtent(loop_sref); + if (extent == nullptr) { + break; + } + // Check if the extent is still in a good range + prod_extent *= *extent; + if (prod_extent > max_extent) { + break; + } + ++num_fusible; + } + if (prod_extent == 1) { + num_fusible = -1; + } + } + // Prefer num_vectorize to num_parallel + if (parsed->num_parallel_loops != -1 && parsed->num_vectorize_loops != -1) { + parsed->num_parallel_loops = std::min(parsed->num_parallel_loops, // + n_loops - parsed->num_vectorize_loops); + } +} + +bool FindAnnotateRootBlock(const Schedule& sch, ParsedAnnotation* parsed, BlockRV* root_rv) { + IRModule mod = sch->mod(); + for (const auto& kv : mod->functions) { + const GlobalVar& g_var = kv.first; + const BaseFunc& base_func = kv.second; + if (const auto* prim_func = base_func.as()) { + Block block = Downcast(prim_func->body)->block; + if (ParseAnnotation(block, parsed)) { + *root_rv = sch->GetBlock(block->name_hint, g_var->name_hint); + RemoveParsedAnn(sch, *root_rv, *parsed); + return true; + } + } + } + return false; +} + +void RewriteParallel(const Schedule& sch, int n, Array* loop_rvs) { + LoopRV fused = sch->Fuse({loop_rvs->begin(), loop_rvs->begin() + n}); + sch->Parallel(fused); + for (int i = 0; i < n; ++i) { + loop_rvs->Set(i, fused); + } +} + +void RewriteVectorize(const Schedule& sch, int n, Array* loop_rvs) { + int n_loops = loop_rvs->size(); + LoopRV fused = sch->Fuse({loop_rvs->end() - n, loop_rvs->end()}); + sch->Vectorize(fused); + for (int i = n_loops - n; i < n_loops; ++i) { + loop_rvs->Set(i, fused); + } +} + +void RewriteUnroll(const Schedule& sch, int unroll_explicit, int max_step, const LoopRV& loop) { + if (max_step > 0) { + sch->Annotate(loop, attr::pragma_auto_unroll_max_step, IntImm(DataType::Int(32), max_step)); + sch->Annotate(loop, attr::pragma_unroll_explicit, IntImm(DataType::Int(32), unroll_explicit)); + } +} + +} // namespace tir +} // namespace tvm + +namespace tvm { +namespace meta_schedule { + +using tir::Schedule; + +class RewriteParallelVectorizeUnrollNode : public PostprocNode { + public: + void InitializeWithTuneContext(const TuneContext& context) final {} + + bool Apply(const Schedule& sch) final { + using tir::BlockRV; + using tir::LoopRV; + tir::ParsedAnnotation parsed_root; + BlockRV root_rv{nullptr}; + while (tir::FindAnnotateRootBlock(sch, &parsed_root, &root_rv)) { + for (BlockRV block_rv : sch->GetChildBlocks(root_rv)) { + Array loop_rvs = sch->GetLoops(block_rv); + if (loop_rvs.empty()) { + continue; + } + tir::ParsedAnnotation parsed = parsed_root; + tir::AdjustParallelVectorize(sch, block_rv, loop_rvs, &parsed); + // Parallel + if (parsed.num_parallel_loops > 0) { + tir::RewriteParallel(sch, parsed.num_parallel_loops, &loop_rvs); + } + // Vectorize + if (parsed.num_vectorize_loops > 0) { + tir::RewriteVectorize(sch, parsed.num_vectorize_loops, &loop_rvs); + } + // AutoUnroll + if (parsed.unroll_explicit != -1 || parsed.unroll_implicit != -1) { + ICHECK(parsed.unroll_explicit == -1 || parsed.unroll_implicit == -1); + int unroll_explicit = parsed.unroll_explicit != -1; + int max_step = parsed.unroll_explicit + parsed.unroll_implicit + 1; + tir::RewriteUnroll(sch, unroll_explicit, max_step, loop_rvs[0]); + } + } + } + return true; + } + + static constexpr const char* _type_key = "meta_schedule.RewriteParallelVectorizeUnroll"; + TVM_DECLARE_FINAL_OBJECT_INFO(RewriteParallelVectorizeUnrollNode, PostprocNode); +}; + +Postproc Postproc::RewriteParallelVectorizeUnroll() { + ObjectPtr n = + make_object(); + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(RewriteParallelVectorizeUnrollNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocRewriteParallelVectorizeUnroll") + .set_body_typed(Postproc::RewriteParallelVectorizeUnroll); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/postproc/verify_gpu_code.cc b/src/meta_schedule/postproc/verify_gpu_code.cc new file mode 100644 index 000000000000..71f6961a830a --- /dev/null +++ b/src/meta_schedule/postproc/verify_gpu_code.cc @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace meta_schedule { + +/*! \brief Verify the correctness of the generated GPU code. */ +Integer Extract(const Target& target, const char* name) { + ICHECK(target.defined()); + if (Optional v = target->GetAttr(name)) { + return v.value(); + } + LOG(FATAL) << "AttributedError: \"" << name << "\" is not defined in the target"; + throw; +} + +/*! \brief Verify the correctness of the generated GPU code. */ +class VerifyGPUCodeNode : public PostprocNode { + public: + Map target_constraints_{nullptr}; + + void InitializeWithTuneContext(const TuneContext& context) final { + ICHECK(context->target.defined()); + Target target = context->target.value(); + this->target_constraints_ = Map{ + {"max_shared_memory_per_block", Extract(target, "shared_memory_per_block")}, + {"max_local_memory_per_block", Extract(target, "registers_per_block")}, + {"max_threads_per_block", Extract(target, "max_threads_per_block")}, + {"max_vthread", Integer(8)}, + {"max_vector_bytes", Integer(16)}}; + } + + bool Verify(const IRModule& mod) const { + for (const auto& kv : mod->functions) { + if (const auto* prim_func = kv.second.as()) { + if (!tir::VerifyGPUCode(GetRef(prim_func), this->target_constraints_)) { + return false; + } + } + } + return true; + } + + bool Apply(const tir::Schedule& sch) final { + IRModule mod = sch->mod(); + for (const auto& kv : mod->functions) { + const GlobalVar& g_var = kv.first; + const BaseFunc& base_func = kv.second; + if (const auto* prim_func = base_func.as()) { + IRModule lowered{nullptr}; + try { + lowered = LowerPrimFunc(GetRef(prim_func), g_var->name_hint); + } catch (const dmlc::Error& e) { + return false; + } + if (!Verify(mod)) { + return false; + } + } + } + return true; + } + + static constexpr const char* _type_key = "meta_schedule.VerifyGPUCode"; + TVM_DECLARE_FINAL_OBJECT_INFO(VerifyGPUCodeNode, PostprocNode); +}; + +Postproc Postproc::VerifyGPUCode() { + ObjectPtr n = make_object(); + return Postproc(n); +} + +TVM_REGISTER_NODE_TYPE(VerifyGPUCodeNode); +TVM_REGISTER_GLOBAL("meta_schedule.PostprocVerifyGPUCode").set_body_typed(Postproc::VerifyGPUCode); + +} // namespace meta_schedule +} // namespace tvm diff --git a/src/meta_schedule/utils.h b/src/meta_schedule/utils.h index df335c4a15f1..15d0c3f9f874 100644 --- a/src/meta_schedule/utils.h +++ b/src/meta_schedule/utils.h @@ -20,6 +20,7 @@ #define TVM_META_SCHEDULE_UTILS_H_ #include +#include #include #include #include @@ -32,11 +33,7 @@ #include #include #include -#include -#include #include -#include -#include #include #include @@ -226,7 +223,7 @@ inline IRModule DeepCopyIRModule(IRModule mod) { * \brief Get the BlockRV from a block StmtSRef * \param sch The schedule * \param block_sref The block StmtSRef - * \param globla_var_name The global variable name + * \param global_var_name The global variable name * \return The BlockRV */ inline tir::BlockRV GetRVFromSRef(const tir::Schedule& sch, const tir::StmtSRef& block_sref, diff --git a/src/tir/analysis/verify_gpu_code.cc b/src/tir/analysis/verify_gpu_code.cc index dc1ed1c193e8..d01788e92c4c 100644 --- a/src/tir/analysis/verify_gpu_code.cc +++ b/src/tir/analysis/verify_gpu_code.cc @@ -92,7 +92,7 @@ class GPUCodeVerifier : public StmtExprVisitor { const auto* extent = op->value.as(); ICHECK(extent); - std::string name = var.get()->name_hint; + std::string name = op->node.as()->thread_tag; // record the number of threads in a block if (name == "threadIdx.x" || name == "threadIdx.y" || name == "threadIdx.z" || name == "vthread") { @@ -151,6 +151,7 @@ class GPUCodeVerifier : public StmtExprVisitor { errors_.push_back(s.str()); } }; + err("threads per block", thread_per_block_, max_threads_per_block_); err("local memory per block", local_memory_per_block_, max_local_memory_per_block_); err("shared memory per block", shared_memory_per_block_, max_shared_memory_per_block_); diff --git a/tests/python/unittest/test_meta_schedule_postproc.py b/tests/python/unittest/test_meta_schedule_postproc.py index 7a448ec09f07..a03cbbdbc3c7 100644 --- a/tests/python/unittest/test_meta_schedule_postproc.py +++ b/tests/python/unittest/test_meta_schedule_postproc.py @@ -19,15 +19,14 @@ import re import tvm -from tvm.script import tir as T - -from tvm.meta_schedule.postproc import PyPostproc from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.postproc import PyPostproc from tvm.meta_schedule.utils import _get_hex_address - +from tvm.script import tir as T +from tvm.target.target import Target from tvm.tir.schedule import Schedule -# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant # fmt: off @tvm.script.ir_module @@ -45,8 +44,46 @@ def main(a: T.handle, b: T.handle, c: T.handle) -> None: C[vi, vj] = 0.0 C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +@tvm.script.ir_module +class Conv_cuda0: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "T.noalias": True}) + # var definition + threadIdx_x = T.env_thread("threadIdx.x") + threadIdx_y = T.env_thread("threadIdx.y") + blockIdx_x = T.env_thread("blockIdx.x") + blockIdx_y = T.env_thread("blockIdx.y") + blockIdx_z = T.env_thread("blockIdx.z") + A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") + B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") + # body + T.launch_thread(blockIdx_z, 196) + B_local = T.allocate([64], "float32", "local") + Apad_shared = T.allocate([512], "float32", "shared") + Apad_shared_local = T.allocate([8], "float32", "local") + T.launch_thread(blockIdx_y, 8) + T.launch_thread(blockIdx_x, 4) + T.launch_thread(threadIdx_y, 8) + T.launch_thread(threadIdx_x, 8) + for ff_c_init, nn_c_init in T.grid(8, 8): + T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + for rc_outer, ry, rx in T.grid(32, 3, 3): + for ax3_inner_outer in T.serial(0, 2): + T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + for rc_inner in T.serial(0, 8): + for ax3 in T.serial(0, 8): + T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + for ff_c, nn_c in T.grid(8, 8): + T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): + T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on + + # fmt: on -# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable def _check_correct(schedule: Schedule): @@ -80,7 +117,7 @@ def apply(self, sch: Schedule) -> bool: try: tvm.ir.assert_structural_equal(sch.mod, mod) raise Exception("The postprocessors did not change the schedule.") - except (ValueError): + except ValueError: _check_correct(sch) diff --git a/tests/python/unittest/test_meta_schedule_postproc_disallow_dynamic_loop.py b/tests/python/unittest/test_meta_schedule_postproc_disallow_dynamic_loop.py new file mode 100644 index 000000000000..d27e3e61084f --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc_disallow_dynamic_loop.py @@ -0,0 +1,100 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +import tvm +from tvm import tir +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.postproc import DisallowDynamicLoop +from tvm.script import tir as T +from tvm.target import Target + + +def _target() -> Target: + return Target("cuda", host="llvm") + + +def _create_context(mod, target) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + postprocs=[ + DisallowDynamicLoop(), + ], + task_name="test", + ) + for rule in ctx.postprocs: + rule.initialize_with_tune_context(ctx) + return ctx + + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument +# fmt: off + +@tvm.script.ir_module +class Matmul: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + + +@tvm.script.ir_module +class DynamicLoop: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + for i, j in T.grid(1024, 1024): + for k in T.serial(0, i): + with T.block("matmul"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument + + +def test_postproc_disallow_dynamic_loops(): + mod = Matmul + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert ctx.postprocs[0].apply(sch) + + +def test_postproc_disallow_dynamic_loops_fail(): + mod = DynamicLoop + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert not ctx.postprocs[0].apply(sch) + + +if __name__ == "__main__": + test_postproc_disallow_dynamic_loops() + test_postproc_disallow_dynamic_loops_fail() diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py index 70efb402c372..ec40c592a82d 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_cooperative_fetch.py @@ -34,7 +34,7 @@ def _create_context(mod, target) -> TuneContext: ctx = TuneContext( mod=mod, target=target, - postproc=[ + postprocs=[ RewriteCooperativeFetch(), ], task_name="test", diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py new file mode 100644 index 000000000000..ae60803d08f3 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py @@ -0,0 +1,17 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py index ef654d0874bb..93ea76ec5da4 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_reduction_block.py @@ -32,7 +32,7 @@ def _create_context(mod, target) -> TuneContext: ctx = TuneContext( mod=mod, target=target, - postproc=[ + postprocs=[ RewriteReductionBlock(), ], task_name="test", diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py index efe0a41172bf..8b062a11b538 100644 --- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py +++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_unbound_block.py @@ -20,10 +20,8 @@ from tvm import tir from tvm.meta_schedule import TuneContext from tvm.meta_schedule.postproc import RewriteUnboundBlock -from tvm.meta_schedule.testing import te_workload from tvm.script import tir as T from tvm.target import Target -from tvm.te import create_prim_func def _target() -> Target: @@ -34,7 +32,7 @@ def _create_context(mod, target) -> TuneContext: ctx = TuneContext( mod=mod, target=target, - postproc=[ + postprocs=[ RewriteUnboundBlock(), ], task_name="test", diff --git a/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py new file mode 100644 index 000000000000..cdebcddf5d6d --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_postproc_verify_gpu_code.py @@ -0,0 +1,232 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring + +import tvm +from tvm import tir +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.postproc import VerifyGPUCode +from tvm.script import tir as T +from tvm.target import Target + + +def _target() -> Target: + return Target("nvidia/geforce-rtx-3080") + + +def _create_context(mod, target) -> TuneContext: + ctx = TuneContext( + mod=mod, + target=target, + postprocs=[ + VerifyGPUCode(), + ], + task_name="test", + ) + for rule in ctx.postprocs: + rule.initialize_with_tune_context(ctx) + return ctx + + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant +# fmt: off + +@tvm.script.ir_module +class Conv2dCuda0: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "T.noalias": True}) + # var definition + threadIdx_x = T.env_thread("threadIdx.x") + threadIdx_y = T.env_thread("threadIdx.y") + blockIdx_x = T.env_thread("blockIdx.x") + blockIdx_y = T.env_thread("blockIdx.y") + blockIdx_z = T.env_thread("blockIdx.z") + A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") + B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") + # body + T.launch_thread(blockIdx_z, 196) + B_local = T.allocate([64], "float32", "local") + Apad_shared = T.allocate([512], "float32", "shared") + Apad_shared_local = T.allocate([8], "float32", "local") + T.launch_thread(blockIdx_y, 8) + T.launch_thread(blockIdx_x, 4) + T.launch_thread(threadIdx_y, 8) + T.launch_thread(threadIdx_x, 8) + for ff_c_init, nn_c_init in T.grid(8, 8): + T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + for rc_outer, ry, rx in T.grid(32, 3, 3): + for ax3_inner_outer in T.serial(0, 2): + T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + for rc_inner in T.serial(0, 8): + for ax3 in T.serial(0, 8): + T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + for ff_c, nn_c in T.grid(8, 8): + T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): + T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on + + +@tvm.script.ir_module +class Conv2dCuda1: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "T.noalias": True}) + # var definition + threadIdx_x = T.env_thread("threadIdx.x") + threadIdx_y = T.env_thread("threadIdx.y") + blockIdx_x = T.env_thread("blockIdx.x") + blockIdx_y = T.env_thread("blockIdx.y") + blockIdx_z = T.env_thread("blockIdx.z") + A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") + B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") + # body + T.launch_thread(blockIdx_z, 196) + B_local = T.allocate([6400000], "float32", "local") + Apad_shared = T.allocate([512], "float32", "shared") + Apad_shared_local = T.allocate([8], "float32", "local") + T.launch_thread(blockIdx_y, 8) + T.launch_thread(blockIdx_x, 4) + T.launch_thread(threadIdx_y, 8) + T.launch_thread(threadIdx_x, 8) + for ff_c_init, nn_c_init in T.grid(8, 8): + T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + for rc_outer, ry, rx in T.grid(32, 3, 3): + for ax3_inner_outer in T.serial(0, 2): + T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + for rc_inner in T.serial(0, 8): + for ax3 in T.serial(0, 8): + T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + for ff_c, nn_c in T.grid(8, 8): + T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): + T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on + + +@tvm.script.ir_module +class Conv2dCuda2: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "T.noalias": True}) + # var definition + threadIdx_x = T.env_thread("threadIdx.x") + threadIdx_y = T.env_thread("threadIdx.y") + blockIdx_x = T.env_thread("blockIdx.x") + blockIdx_y = T.env_thread("blockIdx.y") + blockIdx_z = T.env_thread("blockIdx.z") + A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") + B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") + # body + T.launch_thread(blockIdx_z, 196) + B_local = T.allocate([64], "float32", "local") + Apad_shared = T.allocate([512000], "float32", "shared") + Apad_shared_local = T.allocate([8], "float32", "local") + T.launch_thread(blockIdx_y, 8) + T.launch_thread(blockIdx_x, 4) + T.launch_thread(threadIdx_y, 8) + T.launch_thread(threadIdx_x, 8) + for ff_c_init, nn_c_init in T.grid(8, 8): + T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + for rc_outer, ry, rx in T.grid(32, 3, 3): + for ax3_inner_outer in T.serial(0, 2): + T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + for rc_inner in T.serial(0, 8): + for ax3 in T.serial(0, 8): + T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + for ff_c, nn_c in T.grid(8, 8): + T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): + T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on + + +@tvm.script.ir_module +class Conv2dCuda3: + @T.prim_func + def main(a: T.handle, b: T.handle) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "T.noalias": True}) + # var definition + threadIdx_x = T.env_thread("threadIdx.x") + threadIdx_y = T.env_thread("threadIdx.y") + blockIdx_x = T.env_thread("blockIdx.x") + blockIdx_y = T.env_thread("blockIdx.y") + blockIdx_z = T.env_thread("blockIdx.z") + A = T.match_buffer(a, [14, 14, 256, 256], dtype="float32") + B = T.match_buffer(b, [14, 14, 512, 256], dtype="float32") + # body + T.launch_thread(blockIdx_z, 196) + B_local = T.allocate([64], "float32", "local") + Apad_shared = T.allocate([512], "float32", "shared") + Apad_shared_local = T.allocate([8], "float32", "local") + T.launch_thread(blockIdx_y, 8) + T.launch_thread(blockIdx_x, 4) + T.launch_thread(threadIdx_y, 8) + T.launch_thread(threadIdx_x, 800000) + for ff_c_init, nn_c_init in T.grid(8, 8): + T.store(B_local, ff_c_init * 8 + nn_c_init, T.float32(0), True) + for rc_outer, ry, rx in T.grid(32, 3, 3): + for ax3_inner_outer in T.serial(0, 2): + T.store(Apad_shared, T.ramp(threadIdx_y * 64 + threadIdx_x * 8 + ax3_inner_outer * 4, 1, 4), T.if_then_else(1 <= blockIdx_z // 14 + ry and blockIdx_z // 14 + ry < 15 and 1 <= rx + blockIdx_z % 14 and rx + blockIdx_z % 14 < 15, T.load("float32x4", A.data, T.ramp(ry * 917504 + blockIdx_z * 65536 + rx * 65536 + rc_outer * 2048 + threadIdx_y * 256 + blockIdx_x * 64 + threadIdx_x * 8 + ax3_inner_outer * 4 - 983040, 1, 4), T.broadcast(True, 4)), T.broadcast(T.float32(0), 4), dtype="float32x4"), T.broadcast(True, 4)) + for rc_inner in T.serial(0, 8): + for ax3 in T.serial(0, 8): + T.store(Apad_shared_local, ax3, T.load("float32", Apad_shared, rc_inner * 64 + threadIdx_x * 8 + ax3), True) + for ff_c, nn_c in T.grid(8, 8): + T.store(B_local, ff_c * 8 + nn_c, T.load("float32", B_local, ff_c * 8 + nn_c) + T.load("float32", Apad_shared_local, nn_c), True) + for ff_inner_inner_inner, nn_inner_inner_inner in T.grid(8, 8): + T.store(B.data, blockIdx_z * 131072 + blockIdx_y * 16384 + threadIdx_y * 2048 + ff_inner_inner_inner * 256 + blockIdx_x * 64 + threadIdx_x * 8 + nn_inner_inner_inner, T.load("float32", B_local, ff_inner_inner_inner * 8 + nn_inner_inner_inner), True)# fmt: on + + +# fmt: on +# pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable,misplaced-comparison-constant + + +def test_postproc_verify_gpu_0(): + mod = Conv2dCuda0 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert ctx.postprocs[0].apply(sch) + + +def test_postproc_verify_gpu_1(): + mod = Conv2dCuda1 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert not ctx.postprocs[0].apply(sch) + + +def test_postproc_verify_gpu_2(): + mod = Conv2dCuda2 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert not ctx.postprocs[0].apply(sch) + + +def test_postproc_verify_gpu_3(): + mod = Conv2dCuda3 + ctx = _create_context(mod, target=_target()) + sch = tir.Schedule(mod, debug_mask="all") + assert not ctx.postprocs[0].apply(sch) + + +if __name__ == "__main__": + test_postproc_verify_gpu_0() + test_postproc_verify_gpu_1() + test_postproc_verify_gpu_2() + test_postproc_verify_gpu_3()